-
Notifications
You must be signed in to change notification settings - Fork 53
/
full_chain.py
68 lines (50 loc) · 2.01 KB
/
full_chain.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import os
from dotenv import load_dotenv
from langchain.memory import ChatMessageHistory
from langchain_core.prompts import ChatPromptTemplate
from basic_chain import get_model
from filter import ensemble_retriever_from_docs
from local_loader import load_txt_files
from memory import create_memory_chain
from rag_chain import make_rag_chain
def create_full_chain(retriever, openai_api_key=None, chat_memory=None):
    """Build the full RAG chain wrapped with chat memory.

    Creates a chat model via ``get_model("ChatGPT", ...)``, builds a RAG chain
    from it and ``retriever`` using a system/human prompt, then wraps that RAG
    chain with ``create_memory_chain`` so it can use chat history.

    Args:
        retriever: Retriever passed through to ``make_rag_chain``.
        openai_api_key: Optional API key forwarded to ``get_model``.
        chat_memory: Message history handed to ``create_memory_chain``. When
            omitted (``None``), a fresh ``ChatMessageHistory`` is created for
            this chain.

    Returns:
        The chain returned by ``create_memory_chain``.
    """
    # A ``ChatMessageHistory()`` default would be evaluated only once, when the
    # function is defined. Every chain built without an explicit history would
    # then share (and pollute) the same conversation. Create one per call.
    if chat_memory is None:
        chat_memory = ChatMessageHistory()

    model = get_model("ChatGPT", openai_api_key=openai_api_key)
    # ``{context}`` is a template variable filled in by the RAG chain; the
    # user's question itself arrives in the human message below.
    system_prompt = (
        "You are a helpful AI assistant for busy professionals trying to improve their health.\n"
        "Use the following context and the users' chat history to help the user:\n"
        "If you don't know the answer, just say that you don't know.\n"
        "Context: {context}\n"
        "Question: "
    )
    prompt = ChatPromptTemplate.from_messages(
        [
            ("system", system_prompt),
            ("human", "{question}"),
        ]
    )
    rag_chain = make_rag_chain(model, retriever, rag_prompt=prompt)
    return create_memory_chain(model, rag_chain, chat_memory)
def ask_question(chain, query, session_id="foo"):
    """Send one question through ``chain`` and return its response.

    Args:
        chain: Object exposing ``invoke(input, config=...)``, such as the chain
            built by ``create_full_chain``.
        query: The question text, sent under the ``"question"`` key.
        session_id: Value passed as ``config["configurable"]["session_id"]``.
            Defaults to ``"foo"``, the value that used to be hard-coded, so
            existing callers behave the same.

    Returns:
        Whatever ``chain.invoke`` returns.
    """
    return chain.invoke(
        {"question": query},
        config={"configurable": {"session_id": session_id}},
    )
def main():
    """Run the demo: build the chain over local text files and print answers.

    Calls ``load_dotenv()``, loads documents with ``load_txt_files()``, builds
    an ensemble retriever and the full chain over them, then asks each demo
    question in turn. Each answer's ``content`` is printed as rendered
    Markdown.
    """
    load_dotenv()
    # Imported here so rich is only required when this demo is actually run.
    from rich.console import Console
    from rich.markdown import Markdown

    console = Console()
    docs = load_txt_files()
    ensemble_retriever = ensemble_retriever_from_docs(docs)
    chain = create_full_chain(ensemble_retriever)
    # Each entry is a separate question. Without the comma after the first
    # literal, Python implicitly concatenates the two adjacent strings, and
    # only one merged query would be asked.
    queries = [
        "Generate a grocery list for my family meal plan for the next week(following 7 days). Prefer local, in-season ingredients.",
        "Create a list of estimated calorie counts and grams of carbohydrates for each meal.",
    ]
    for query in queries:
        response = ask_question(chain, query)
        console.print(Markdown(response.content))
if __name__ == '__main__':
    # Turn off tokenizer parallelism before running the demo so the
    # "parallel tokenizers" warning is not printed.
    os.environ.update(TOKENIZERS_PARALLELISM="false")
    main()