114 lines
3.9 KiB
Python
114 lines
3.9 KiB
Python
import os
|
|
import streamlit as st
|
|
from dotenv import load_dotenv
|
|
from langchain_community.vectorstores import FAISS
|
|
from sentence_transformers import SentenceTransformer
|
|
from langchain.embeddings.base import Embeddings
|
|
import requests # To handle HTTP requests for Groq API
|
|
|
|
# Load environment variables from a .env file (if one exists) into os.environ.
load_dotenv()

# Configure the page. This is issued before any other Streamlit command, which
# older Streamlit versions require.
st.set_page_config(page_title="Ayurveda Chatbot", layout="wide")

# Read the Groq API key. A missing key is reported in the UI, but the script
# keeps running so the rest of the page still renders.
groq_key = os.getenv("GROQ_API_KEY")
if not groq_key:
    st.error("The 'GROQ_API_KEY' environment variable is not set. Please set it in the .env file or the environment.")
else:
    # Plain string literal: the previous f-string had no placeholders.
    st.write("GROQ_KEY loaded successfully")
|
|
|
|
# Custom LangChain embedding wrapper around a sentence-transformers model.
class SentenceTransformerEmbeddings(Embeddings):
    """LangChain ``Embeddings`` adapter backed by a ``SentenceTransformer``.

    LangChain's ``Embeddings`` interface declares ``embed_documents`` as
    returning ``List[List[float]]`` and ``embed_query`` as returning
    ``List[float]``. ``SentenceTransformer.encode`` returns NumPy arrays by
    default, so results are converted to plain Python lists with ``.tolist()``.
    """

    def __init__(self, model_name="all-MiniLM-L6-v2"):
        """Load the sentence-transformers model named ``model_name``."""
        self.model = SentenceTransformer(model_name)

    def embed_documents(self, texts):
        """Return one embedding vector (a list of floats) per text in ``texts``."""
        return self.model.encode(texts, show_progress_bar=True).tolist()

    def embed_query(self, text):
        """Return the embedding vector (a list of floats) for a single query."""
        # Encode as a one-element batch and take the single resulting row.
        return self.model.encode([text], show_progress_bar=False)[0].tolist()
|
|
|
|
# Path to the on-disk FAISS index.
faiss_index_path = "faiss_index"


@st.cache_resource
def _load_embedding_model():
    """Create the embedding model once and reuse it across script reruns.

    Streamlit re-executes this script on every widget interaction; without
    caching, the SentenceTransformer model would be reloaded for each query.
    """
    return SentenceTransformerEmbeddings()


@st.cache_resource
def _load_faiss_index(index_path, _embeddings):
    """Load the FAISS index at ``index_path`` once and reuse it across reruns.

    The leading underscore on ``_embeddings`` tells ``st.cache_resource`` not
    to hash that argument, so the cache is keyed on ``index_path`` only.

    NOTE(security): ``allow_dangerous_deserialization=True`` lets LangChain
    unpickle the stored docstore; only load index files you built or trust.
    """
    return FAISS.load_local(index_path, _embeddings, allow_dangerous_deserialization=True)


embedding_model = _load_embedding_model()

try:
    db = _load_faiss_index(faiss_index_path, embedding_model)
except Exception as e:
    # Keep the app running without retrieval; db=None marks "index not loaded".
    # Failed loads are not cached, so the next rerun retries.
    st.error(f"Failed to load FAISS index: {str(e)}")
    db = None
|
|
|
|
# Client for Groq's chat-completions HTTP endpoint.
class GroqAPI:
    """Minimal client for Groq's OpenAI-compatible chat-completions endpoint.

    All failures are reported as strings starting with ``"Error: "`` rather
    than raised, so callers can display them directly.
    """

    def __init__(self, api_key):
        """Store the API key and the chat-completions endpoint URL.

        Args:
            api_key: Groq API key sent as a Bearer token. If it is ``None`` or
                empty, ``generate_answer`` returns an error string without
                making a request.
        """
        self.api_key = api_key
        self.endpoint = "https://api.groq.com/openai/v1/chat/completions"

    def generate_answer(self, query, context, model="llama-3.3-70b-versatile", timeout=60):
        """Ask the model to answer ``query`` using the retrieved ``context``.

        Args:
            query: The user's question.
            context: Retrieved reference text inserted into the system prompt.
            model: Groq model identifier.
            timeout: Seconds to wait for the HTTP request (``None`` waits
                indefinitely). Without a timeout a stalled connection would
                block the app forever.

        Returns:
            The assistant message content on success, otherwise a string of
            the form ``"Error: ..."``.
        """
        if not self.api_key:
            # Otherwise the header would be "Bearer None" and the call is futile.
            return "Error: GROQ_API_KEY is not set."

        # Instructions plus retrieved context go in the system message.
        system_message = (
            "You are an Ayurvedic expert with deep knowledge of Ayurvedic practices, remedies, and diagnostics. "
            "Use the provided Ayurvedic context to answer the question thoughtfully and accurately.\n\n"
            f"Context:\n{context}\n\n"
            f"Question:\n{query}\n\n"
            "Answer as an Ayurvedic expert:"
        )

        payload = {
            "model": model,
            "messages": [
                {"role": "system", "content": system_message},
                {"role": "user", "content": query},
            ],
        }
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }

        try:
            response = requests.post(self.endpoint, json=payload, headers=headers, timeout=timeout)
        except requests.RequestException as e:
            # Connection errors, timeouts, invalid URLs, etc.
            return f"Error: {str(e)}"

        if response.status_code != 200:
            return f"Error: {response.status_code} - {response.text}"

        try:
            return response.json()["choices"][0]["message"]["content"]
        except (ValueError, KeyError, IndexError, TypeError) as e:
            # Body was not JSON or lacked the expected chat-completions shape.
            return f"Error: {str(e)}"
|
|
|
|
# One shared Groq client, built from the key read at startup (groq_key may be
# None when GROQ_API_KEY is unset).
groq_api = GroqAPI(groq_key)
|
|
|
|
# Retrieval-augmented QA: FAISS similarity search feeding the Groq LLM.
def custom_qa_chain(query, k=3):
    """Answer ``query`` using FAISS-retrieved context and the Groq model.

    Args:
        query: The user's question.
        k: Number of most similar documents to retrieve from the index
            (defaults to 3, the previously hard-coded value).

    Returns:
        The model's answer, or a human-readable error string. Errors are
        returned rather than raised so the UI can display them directly.
    """
    if db is None:
        return "FAISS index is not loaded."

    try:
        # Retrieve the k nearest chunks and join their text into one context.
        docs = db.similarity_search(query, k=k)
        context_text = "\n".join(doc.page_content for doc in docs)

        # Ask the LLM, grounded in the retrieved context.
        response = groq_api.generate_answer(query, context_text)
    except Exception as e:
        response = f"Error during QA chain: {str(e)}"

    return response
|
|
|
|
# Streamlit UI
st.title("Ayurveda Chatbot")
st.subheader("Ask your Ayurvedic Question")

# Strip surrounding whitespace so blank input doesn't trigger a FAISS search
# and an LLM call with an empty question.
query = st.text_input("Enter your query:").strip()

if query:
    with st.spinner("Retrieving answer..."):
        st.write(f"Processing query: {query}")

        # Get the response from the retrieval + LLM chain.
        response = custom_qa_chain(query)

    st.markdown(f"### Answer:\n{response}")
|