-
Notifications
You must be signed in to change notification settings - Fork 53
/
vector_store.py
90 lines (67 loc) · 2.88 KB
/
vector_store.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import logging
import os
from typing import List
from langchain_openai import OpenAIEmbeddings
from langchain_community.vectorstores import Chroma
from local_loader import get_document_text
from remote_loader import download_file
from splitter import split_documents
from dotenv import load_dotenv
from time import sleep
# Pause, in seconds, taken before every embedding call (20 milliseconds).
EMBED_DELAY = 0.02


class EmbeddingProxy:
    """Forward embedding requests to a wrapped model, pausing before each one.

    The pause is there to get the Streamlit app to use less CPU while
    embedding documents into Chromadb.

    Attributes:
        embedding: the wrapped model; it must provide ``embed_documents`` and
            ``embed_query``.
    """

    def __init__(self, embedding):
        self.embedding = embedding

    @staticmethod
    def _pause() -> None:
        """Sleep for ``EMBED_DELAY`` seconds before delegating a call."""
        sleep(EMBED_DELAY)

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Return one embedding vector per string in ``texts``, after the pause."""
        self._pause()
        return self.embedding.embed_documents(texts)

    def embed_query(self, text: str) -> List[float]:
        """Return the embedding vector for a single query string, after the pause."""
        self._pause()
        return self.embedding.embed_query(text)
# This happens all at once, not ideal for large datasets.
def create_vector_db(texts, embeddings=None, collection_name="chroma"):
    """Create a Chroma vector store persisted under ``store/<collection_name>`` and add ``texts``.

    All documents are embedded in a single ``add_documents`` call, which is
    not ideal for large datasets.

    Args:
        texts: documents to index (``main`` passes the output of
            ``split_documents``). If empty or ``None``, a warning is logged and
            the store is returned without adding anything.
        embeddings: embedding model to use. When ``None``, an OpenAI
            ``text-embedding-3-small`` model is built with the API key read
            from the ``OPENAI_API_KEY`` environment variable.
        collection_name: name of the Chroma collection; also used as the
            sub-directory of ``store/`` it is persisted to.

    Returns:
        The ``Chroma`` store.

    Raises:
        KeyError: if ``embeddings`` is ``None`` and ``OPENAI_API_KEY`` is unset.
    """
    if not texts:
        logging.warning("Empty texts passed in to create vector database")
    # Select embeddings
    if embeddings is None:
        # To use HuggingFace embeddings instead:
        # from langchain_community.embeddings import HuggingFaceEmbeddings
        # embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
        openai_api_key = os.environ["OPENAI_API_KEY"]
        embeddings = OpenAIEmbeddings(openai_api_key=openai_api_key, model="text-embedding-3-small")
    # Wrap the model so every embedding call sleeps briefly first (throttles CPU use).
    proxy_embeddings = EmbeddingProxy(embeddings)
    db = Chroma(collection_name=collection_name,
                embedding_function=proxy_embeddings,
                persist_directory=os.path.join("store/", collection_name))
    # Adding an empty batch would hand Chroma an empty id list, which it
    # rejects; return the (possibly pre-populated) store unchanged instead.
    if texts:
        # NOTE(review): no ids are passed, so documents likely get fresh random
        # ids on each call, and re-running against the same persisted
        # collection would add duplicates. Confirm before relying on re-runs.
        db.add_documents(texts)
    return db
def find_similar(vs, query, k=None):
    """Return the documents in vector store ``vs`` most similar to ``query``.

    Args:
        vs: object exposing ``similarity_search`` (e.g. the ``Chroma`` store
            returned by ``create_vector_db``).
        query: search text.
        k: number of results to request. ``None`` (the default) does not pass
            ``k`` at all, leaving the count to the store's own default.

    Returns:
        The result of ``vs.similarity_search``; ``main`` reads
        ``page_content`` from each item.
    """
    if k is None:
        return vs.similarity_search(query)
    return vs.similarity_search(query, k=k)
def _clip_at_word(content, limit):
    """Cut ``content`` at the first space found at or after index ``limit``.

    When no such space exists, the text is cut at exactly ``limit``
    characters (which leaves shorter text untouched).
    """
    space_at = content.find(' ', limit)
    end = limit if space_at == -1 else space_at
    return content[:end]


def main():
    """Index Boole's *Mathematical Analysis of Logic* and print one sample query.

    The PDF is fetched from Project Gutenberg into ``examples/mal_boole.pdf``
    unless that file already exists. It is then loaded, split into chunks,
    embedded into a Chroma store and queried once. Each hit is printed, capped
    at about 300 characters on a word boundary.
    """
    load_dotenv()

    target = "examples/mal_boole.pdf"
    if os.path.exists(target):
        local_pdf_path = target
    else:
        source_url = "https://www.gutenberg.org/files/36884/36884-pdf.pdf"
        local_pdf_path = download_file(source_url, target)
    print(f"PDF path is {local_pdf_path}")

    with open(local_pdf_path, "rb") as pdf_file:
        docs = get_document_text(pdf_file, title="Analysis of Logic")
        chunks = split_documents(docs)

    store = create_vector_db(chunks)
    hits = find_similar(store, query="What is meant by the simple conversion of a proposition?")

    MAX_CHARS = 300
    print("=== Results ===")
    for rank, hit in enumerate(hits, start=1):
        snippet = _clip_at_word(hit.page_content, MAX_CHARS)
        print(f"Result {rank}:\n {snippet}\n")


if __name__ == "__main__":
    main()