"""Build and incrementally update a local FAISS vector index from PDF books."""
import os

import faiss
from langchain.docstore.in_memory import InMemoryDocstore
from langchain.embeddings.base import Embeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import PyPDFLoader
from langchain_community.vectorstores import FAISS
from sentence_transformers import SentenceTransformer


# Define a custom embedding wrapper for LangChain
class SentenceTransformerEmbeddings(Embeddings):
|
|
def __init__(self, model_name="all-MiniLM-L6-v2"):
|
|
self.model = SentenceTransformer(model_name)
|
|
|
|
def embed_documents(self, texts):
|
|
return self.model.encode(texts, show_progress_bar=True)
|
|
|
|
def embed_query(self, text):
|
|
return self.model.encode([text], show_progress_bar=False)[0]
|
|
|
|
# Function to create an empty FAISS index
|
|
def create_empty_faiss_index(embedding_model):
|
|
embedding_dimension = embedding_model.model.get_sentence_embedding_dimension()
|
|
index = faiss.IndexFlatL2(embedding_dimension) # Initialize FAISS index
|
|
docstore = InMemoryDocstore({})
|
|
index_to_docstore_id = {}
|
|
return FAISS(index=index, docstore=docstore, index_to_docstore_id=index_to_docstore_id, embedding_function=embedding_model)
|
|
|
|
# Function to update the FAISS index with new books
|
|
def update_faiss_index(book_paths, faiss_index_path="faiss_index"):
|
|
# Load or initialize FAISS index
|
|
embedding_model = SentenceTransformerEmbeddings()
|
|
if os.path.exists(faiss_index_path):
|
|
print("Loading existing FAISS index...")
|
|
db = FAISS.load_local(faiss_index_path, embedding_model)
|
|
else:
|
|
print("Creating a new FAISS index...")
|
|
db = create_empty_faiss_index(embedding_model)
|
|
|
|
# Process each book
|
|
for book_path in book_paths:
|
|
print(f"Processing book: {book_path}")
|
|
loader = PyPDFLoader(book_path)
|
|
documents = loader.load()
|
|
|
|
# Split text into chunks
|
|
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
|
|
chunks = text_splitter.split_documents(documents)
|
|
texts = [chunk.page_content for chunk in chunks]
|
|
|
|
# Add embeddings to FAISS index
|
|
db.add_texts(texts)
|
|
|
|
# Save the updated FAISS index
|
|
db.save_local(faiss_index_path)
|
|
print(f"FAISS index updated and saved at: {faiss_index_path}")
|
|
|
|
# Command-line interface
|
|
if __name__ == "__main__":
|
|
import argparse
|
|
parser = argparse.ArgumentParser(description="Update FAISS index with new books")
|
|
parser.add_argument("books", nargs="+", help="Path(s) to the PDF book(s)")
|
|
parser.add_argument("--index-path", default="faiss_index", help="Path to FAISS index directory")
|
|
args = parser.parse_args()
|
|
|
|
update_faiss_index(args.books, args.index_path)
|