-
Notifications
You must be signed in to change notification settings - Fork 43
/
app_6.py
105 lines (86 loc) · 3.45 KB
/
app_6.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import streamlit as st
import os
from langchain_openai import OpenAIEmbeddings
from langchain_openai import ChatOpenAI
from langchain_community.vectorstores import AstraDB
from langchain.schema.runnable import RunnableMap
from langchain.prompts import ChatPromptTemplate
from langchain.callbacks.base import BaseCallbackHandler
# Streaming callback handler that renders the model's response as it arrives
class StreamHandler(BaseCallbackHandler):
    """LangChain callback that live-renders streamed LLM tokens into a Streamlit element.

    Every new token is appended to the accumulated text, and the target
    element is redrawn as markdown followed by a "▌" cursor.
    """

    def __init__(self, container, initial_text=""):
        # Streamlit element to draw into (this app passes an ``st.empty()`` placeholder).
        self.container = container
        # Response text accumulated so far.
        self.text = initial_text

    def on_llm_new_token(self, token: str, **kwargs):
        """Append ``token`` to the running text and redraw it with a trailing cursor."""
        self.text = self.text + token
        rendered = self.text + "▌"
        self.container.markdown(rendered)
# Build the prompt once; st.cache_data returns the cached value on later reruns
@st.cache_data()
def load_prompt():
    """Return the RAG chat prompt.

    The result is a ChatPromptTemplate with a single system message that
    embeds the retrieved ``{context}`` and the user's ``{question}``.
    """
    prompt_lines = [
        "You're a helpful AI assistent tasked to answer the user's questions.",
        "You're friendly and you answer extensively with multiple sentences. You prefer to use bulletpoints to summarize.",
        "CONTEXT:",
        "{context}",
        "QUESTION:",
        "{question}",
        "YOUR ANSWER:",
    ]
    system_message = ("system", "\n".join(prompt_lines))
    return ChatPromptTemplate.from_messages([system_message])


prompt = load_prompt()
# Create the chat model once; st.cache_resource reuses the same object on later reruns
@st.cache_resource()
def load_chat_model():
    """Return the OpenAI chat model used to answer questions.

    Streaming is enabled so per-token callbacks (see StreamHandler) can
    render the answer incrementally.
    """
    chat_settings = {
        'model': 'gpt-3.5-turbo',
        'temperature': 0.3,
        'streaming': True,
        'verbose': True,
    }
    return ChatOpenAI(**chat_settings)


chat_model = load_chat_model()
# Connect to Astra DB once; st.cache_resource reuses the connection on later reruns
@st.cache_resource(show_spinner='Connecting to Astra')
def load_retriever():
    """Connect to the Astra DB vector store and return a retriever over it.

    The endpoint and token come from Streamlit secrets (``ASTRA_API_ENDPOINT``,
    ``ASTRA_TOKEN``), embeddings are computed with ``OpenAIEmbeddings``, and
    each query returns up to 5 documents.
    """
    store = AstraDB(
        embedding=OpenAIEmbeddings(),
        collection_name="my_store",
        api_endpoint=st.secrets['ASTRA_API_ENDPOINT'],
        token=st.secrets['ASTRA_TOKEN'],
    )
    search_kwargs = {"k": 5}
    return store.as_retriever(search_kwargs=search_kwargs)


retriever = load_retriever()
# Keep the chat history in session state, starting empty on the first run
if "messages" not in st.session_state:
    st.session_state.messages = []

# Page header
st.title("Your personal Efficiency Booster")
st.markdown("""Generative AI is considered to bring the next Industrial Revolution.
Why? Studies show a **37% efficiency boost** in day to day work activities!""")

# Replay every stored message (user and bot) each time the script reruns
for entry in st.session_state.messages:
    role, content = entry['role'], entry['content']
    st.chat_message(role).markdown(content)
def _format_docs(docs):
    """Join the page text of retrieved documents into one prompt-ready string."""
    return "\n\n".join(doc.page_content for doc in docs)


# Draw the chat input box; the block runs only when the user submits a question
if question := st.chat_input("What's up?"):
    # Store the user's question in session state so it is redrawn on later reruns
    st.session_state.messages.append({"role": "human", "content": question})

    # Draw the user's question
    with st.chat_message('human'):
        st.markdown(question)

    # UI placeholder that StreamHandler fills token-by-token with the response
    with st.chat_message('assistant'):
        response_placeholder = st.empty()

    # Retrieve relevant documents and hand the prompt their plain page text.
    # Passing the raw list of Document objects would make the prompt template
    # str() the list, injecting each Document's repr (metadata, escaped "\n")
    # instead of readable content. `invoke` replaces the deprecated
    # `get_relevant_documents`.
    inputs = RunnableMap({
        'context': lambda x: _format_docs(retriever.invoke(x['question'])),
        'question': lambda x: x['question'],
    })
    chain = inputs | prompt | chat_model
    response = chain.invoke(
        {'question': question},
        config={'callbacks': [StreamHandler(response_placeholder)]},
    )
    answer = response.content

    # Store the bot's answer in session state so it is redrawn on later reruns
    st.session_state.messages.append({"role": "ai", "content": answer})

    # Write the final answer without the streaming cursor
    response_placeholder.markdown(answer)