-
Notifications
You must be signed in to change notification settings - Fork 2
/
gemini.py
166 lines (143 loc) · 6.57 KB
/
gemini.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
import os
import json
import google.generativeai as genai
from dotenv import load_dotenv
from datetime import datetime
def gemini_request(input="", system="Answer in detail", model_name="gemini-1.5-flash-exp-0827",
                   max_tokens=8192, response_type="text", history=None, stream=False,
                   API_KEY=None):
    """Send one message to a fresh Gemini chat session and yield the reply text.

    This is always a generator. With ``stream`` truthy it yields the text of
    each streamed chunk as it arrives. Otherwise it yields the complete reply
    exactly once. No work (including API configuration) happens until the
    caller starts iterating.

    Args:
        input: The user message to send. This shadows the builtin ``input``;
            the name is kept for backward compatibility with keyword callers.
        system: System instruction passed to the model.
        model_name: Gemini model identifier.
        max_tokens: Used as ``max_output_tokens`` in the generation config.
        response_type: ``"json"`` or ``"application/json"`` requests a JSON
            response. Any other value requests ``"text/plain"``.
        history: Prior chat turns passed to ``start_chat``. ``None`` (the
            default) means an empty history.
        stream: If truthy, stream the reply chunk by chunk.
        API_KEY: API key used to configure the client. If falsy, a ``.env``
            file is loaded and the key is read from ``GEMINI_API_KEY``.

    Yields:
        Reply text (one item per streamed chunk, or a single item).

    Raises:
        KeyError: If no ``API_KEY`` is given and ``GEMINI_API_KEY`` is unset.
    """
    # Accept both the short alias and the full MIME type. GeminiPlus stores the
    # already-converted "application/json" and forwards it here. The previous
    # mapping only recognised "json", so it silently downgraded that to text.
    if response_type in ("json", "application/json"):
        mime_type = "application/json"
    else:
        mime_type = "text/plain"

    # A fresh list per call instead of a shared mutable default argument.
    if history is None:
        history = []

    if API_KEY:
        genai.configure(api_key=API_KEY)
    else:
        # Load environment variables from a .env file before reading the key.
        load_dotenv()
        genai.configure(api_key=os.environ["GEMINI_API_KEY"])

    generation_config = {
        "temperature": 1,
        "top_p": 0.95,
        "top_k": 64,
        "max_output_tokens": max_tokens,
        "response_mime_type": mime_type,
    }
    model = genai.GenerativeModel(
        model_name=model_name,
        generation_config=generation_config,
        system_instruction=system,
    )
    chat_session = model.start_chat(history=history)

    if stream:
        for chunk in chat_session.send_message(input, stream=True):
            yield chunk.text
    else:
        yield chat_session.send_message(input, stream=False).text
class GeminiPlus:
    """Manage several Gemini chat conversations on top of ``gemini_request``.

    Named conversations are kept in ``chat_histories`` as a mapping of
    conversation name -> list of ``{"role": ..., "parts": [...]}`` turns. They
    are saved as JSON to ``history_file``. The "temporary" conversation keeps
    its turns in memory only and is never written to disk.

    Note: ``convo_timestamps`` lives only in memory and is not saved with the
    histories. After a restart, every conversation therefore sorts with the
    same fallback timestamp in ``display_conversation_list``.
    """

    # Default persistence path. __init__ overrides it per instance.
    history_file = "chat_histories.json"

    def __init__(self, model_name="gemini-1.5-flash", system="Answer in detail",
                 max_tokens=8192, response_type="text", API_KEY=None,
                 history_file="chat_histories.json"):
        """Create a chat manager and load any saved conversations.

        Args:
            model_name: Gemini model used for every request.
            system: System instruction sent with every request.
            max_tokens: Maximum output tokens per reply.
            response_type: ``"json"`` for JSON replies; anything else means
                plain text. Stored on ``self.response_type`` as a MIME type.
            API_KEY: Key forwarded to ``gemini_request`` with every request.
            history_file: JSON file that named conversations are loaded from
                and saved to.
        """
        self.model_name = model_name
        self.system = system
        self.history_file = history_file
        self.chat_histories = self.load_chat_histories(history_file)
        self.current_convo_name = None
        # Turns of the active conversation; replaced by start_*conversation().
        self.chat_history = []
        self.max_tokens = max_tokens
        self.response_type = "application/json" if response_type == "json" else "text/plain"
        self.convo_timestamps = {}
        self.API_KEY = API_KEY

    @staticmethod
    def load_chat_histories(filename):
        """Return the conversation mapping stored in ``filename``.

        Returns an empty dict in any of these cases, so the app can start
        fresh:
        - the file is missing;
        - it is not valid UTF-8 JSON;
        - it does not contain a JSON object.
        """
        try:
            with open(filename, "r", encoding="utf-8") as file:
                data = json.load(file)
        except (FileNotFoundError, json.JSONDecodeError, UnicodeDecodeError):
            return {}
        return data if isinstance(data, dict) else {}

    @staticmethod
    def save_chat_histories(filename, histories):
        """Write ``histories`` to ``filename`` as JSON, replacing its contents."""
        with open(filename, "w", encoding="utf-8") as file:
            json.dump(histories, file)

    def start_conversation(self, convo_name):
        """Make ``convo_name`` the active conversation, resuming saved turns if any."""
        self.current_convo_name = convo_name
        self.chat_history = self.chat_histories.get(convo_name, [])
        self.update_conversation_timestamp(convo_name)

    def start_temp_conversation(self):
        """Start an empty conversation that is never saved to disk."""
        self.current_convo_name = "temporary"
        self.chat_history = []

    def update_conversation_timestamp(self, convo_name):
        """Record the current local time (ISO 8601) as ``convo_name``'s last use."""
        self.convo_timestamps[convo_name] = datetime.now().isoformat()

    def display_conversation_list(self):
        """Print saved conversation names, most recently used first.

        Returns:
            The conversation names in the printed order. The list is empty
            when nothing is saved.
        """
        if not self.chat_histories:
            print("No previous conversations found.")
            return []
        sorted_convos = sorted(self.chat_histories.keys(),
                               key=lambda x: self.convo_timestamps.get(x, "1970-01-01T00:00:00"),
                               reverse=True)
        print("Available conversations:")
        for index, convo in enumerate(sorted_convos):
            print(f"{index + 1}: {convo}")
        return sorted_convos

    def send_message(self, user_input):
        """Send ``user_input`` in the active conversation and stream the reply.

        This is a generator that yields reply chunks as they arrive. If no
        conversation has been started, or the request fails, it yields a
        single explanatory message instead. These messages used to be
        ``return``ed from inside the generator, which silently discarded them.

        After a successful reply, both turns are appended to the in-memory
        history, so the temporary chat also remembers earlier turns. Named
        conversations are additionally saved to ``history_file``, and their
        timestamp is refreshed.
        """
        if self.current_convo_name is None:
            yield "Please start a conversation first."
            return
        try:
            reply_parts = []
            for chunk in gemini_request(
                user_input,
                system=self.system,
                model_name=self.model_name,
                max_tokens=self.max_tokens,
                # self.response_type holds a MIME type; pass the alias that
                # gemini_request understands so JSON mode is not lost.
                response_type="json" if self.response_type == "application/json" else "text",
                history=self.chat_history,
                stream=True,
                API_KEY=self.API_KEY,
            ):
                reply_parts.append(chunk)
                yield chunk
            self.chat_history.append({"role": "user", "parts": [user_input]})
            self.chat_history.append({"role": "model", "parts": ["".join(reply_parts)]})
            if self.current_convo_name != "temporary":
                self.chat_histories[self.current_convo_name] = self.chat_history
                self.save_chat_histories(self.history_file, self.chat_histories)
                self.update_conversation_timestamp(self.current_convo_name)
        except Exception as e:
            # UI boundary: report the failure to the user instead of crashing the chat loop.
            print(f"Error sending message: {e}")
            yield "Sorry, there was an error processing your request."

    def delete_conversation(self, convo_name):
        """Remove ``convo_name`` from memory and from ``history_file``, if present."""
        if convo_name in self.chat_histories:
            del self.chat_histories[convo_name]
            self.convo_timestamps.pop(convo_name, None)
            self.save_chat_histories(self.history_file, self.chat_histories)
            print(f"Conversation '{convo_name}' deleted successfully.")
        else:
            print(f"Conversation '{convo_name}' not found.")
def run_chat():
    """Run the interactive command-line chat loop.

    The outer loop asks which conversation to use:
    - ``exit``/``quit`` leaves;
    - ``old`` lists saved conversations and lets the user pick one by number;
    - ``temp`` (or an empty name) starts an unsaved conversation;
    - any other name starts or resumes that named conversation.

    An invalid or impossible ``old`` selection returns to this prompt. The
    inner loop sends each line to the model and streams the reply until the
    user types ``exit``/``quit``, which returns to the conversation prompt.
    """
    gemini_instance = GeminiPlus(model_name="gemini-1.5-flash")
    while True:
        convo_name = input("Enter conversation name (or type 'old' to choose from history, 'temp' for temporary chat): ")
        command = convo_name.lower()
        if command in ['exit', 'quit']:
            break
        if command == 'old':
            sorted_convos = gemini_instance.display_conversation_list()
            if not sorted_convos:
                # Nothing to resume; ask for a conversation name again.
                continue
            try:
                choice = int(input("Choose a conversation by number: ")) - 1
            except ValueError:
                print("Invalid input. Please enter a number.")
                continue
            if not 0 <= choice < len(sorted_convos):
                print("Invalid choice. Please try again.")
                continue
            gemini_instance.start_conversation(sorted_convos[choice])
        elif command == 'temp' or not convo_name:
            gemini_instance.start_temp_conversation()
        else:
            gemini_instance.start_conversation(convo_name)

        while True:
            user_input = input("You: ")
            if user_input.lower() in ['exit', 'quit']:
                break
            # Flush so streamed chunks show up as they arrive. Terminal stdout
            # is typically line-buffered, and these prints end without a newline.
            print("Bot: ", end="", flush=True)
            for chunk in gemini_instance.send_message(user_input):
                print(chunk, end="", flush=True)
            # Finish the reply line so the next prompt starts on its own line.
            print()
if __name__ == "__main__":
    # Smoke test: one non-streamed request, each yielded piece printed on its
    # own line. Call run_chat() here instead for the interactive chat.
    for piece in gemini_request(input="Explain friction in detail", stream=False):
        print(piece)