diff --git a/chat_history.db b/chat_history.db index c145c6e4c322d0352570f003664828d908ec7ea7..25e601faa32abe18c62246f1fbca27d864e58994 100644 Binary files a/chat_history.db and b/chat_history.db differ diff --git a/streamlit_rag.py b/streamlit_rag.py index 95ae3d85a29f1d532e328bfa6f827e547fd704db..8a9ecc3ee4f69792f31a12fc245deb74057a35f9 100755 --- a/streamlit_rag.py +++ b/streamlit_rag.py @@ -30,7 +30,6 @@ auth_kwargs = { } CHROMA_PATH = "chroma" -DATA_PATH = "data" #RERANKER_MODEL = "cross-encoder/ms-marco-MiniLM-L-6-v2" #reranker = CrossEncoder(RERANKER_MODEL) @@ -79,12 +78,26 @@ def load_chat_history(): return chat_sessions def delete_session(session_name): + # Delete from SQLite database conn = sqlite3.connect(DATABASE_PATH) c = conn.cursor() c.execute("DELETE FROM chat_sessions WHERE session_name = ?", (session_name,)) conn.commit() conn.close() + # Clear Chroma collections for all providers + providers = ["OpenAI", "Ollama", "PG"] + for provider in providers: + try: + collection_name = get_collection_name(provider, session_name) + db = Chroma( + persist_directory=CHROMA_PATH, + embedding_function=get_embedding_function(provider), + collection_name=collection_name + ) + db.delete_collection() + except Exception as e: + st.warning(f"Could not delete collection for {provider}: {str(e)}") def get_embedding_function(provider): if provider == "OpenAI": @@ -97,19 +110,23 @@ def get_embedding_function(provider): def clear_database(provider="OpenAI"): try: - collection_name = f"documents_{provider.lower()}" + if not st.session_state.current_session: + st.error("Please select or create a session first!") + return + + collection_name = get_collection_name(provider, st.session_state.current_session) + # Clear Chroma collection db = Chroma( persist_directory=CHROMA_PATH, embedding_function=get_embedding_function(provider), collection_name=collection_name ) - db.delete_collection() - st.success(f"Successfully cleared {provider} collection!") + st.success(f"Successfully cleared {collection_name} collection!") except Exception as e: - st.error(f"Error clearing {provider} collection: {str(e)}") + st.error(f"Error clearing collection: {str(e)}") def split_documents(documents: list[Document]): text_splitter = RecursiveCharacterTextSplitter( @@ -135,14 +152,21 @@ def get_installed_ollama_models(): return [] +def get_collection_name(provider, session_name): + return f"documents_{provider.lower()}_{session_name}" + def add_to_chroma(chunks: list[Document], provider="OpenAI"): try: if not chunks: st.warning("No documents to add to the database.") return + if not st.session_state.current_session: + st.error("Please select or create a session first!") + return + embedding_function = get_embedding_function(provider) - collection_name = f"documents_{provider.lower()}" # Create separate collections + collection_name = get_collection_name(provider, st.session_state.current_session) db = Chroma( persist_directory=CHROMA_PATH, @@ -157,7 +181,7 @@ def add_to_chroma(chunks: list[Document], provider="OpenAI"): # Verify documents were added collection_size = len(db.get()['ids']) - st.info(f"Added {collection_size} documents to the {provider} collection.") + st.info(f"Added {collection_size} documents to the {collection_name} collection.") except Exception as e: st.error(f"Error adding documents to database: {str(e)}") @@ -184,17 +208,26 @@ def calculate_chunk_ids(chunks): return chunks +def get_data_path(provider, session_name): + return os.path.join("data", provider.lower(), session_name) + def populate_database(files, provider="OpenAI"): + if not st.session_state.current_session: + st.error("Please select or create a session first!") + return + if not os.path.exists(CHROMA_PATH): os.makedirs(CHROMA_PATH) - if not os.path.exists(DATA_PATH): - os.makedirs(DATA_PATH) + # Use the new dynamic data path + data_path = get_data_path(provider, st.session_state.current_session) + if not os.path.exists(data_path): + os.makedirs(data_path) documents = [] for uploaded_file in files: - file_path = os.path.join(DATA_PATH, uploaded_file.name) + file_path = os.path.join(data_path, uploaded_file.name) with open(file_path, "wb") as f: f.write(uploaded_file.getbuffer()) @@ -202,11 +235,10 @@ def populate_database(files, provider="OpenAI"): documents.extend(loader.load()) chunks = split_documents(documents) - add_to_chroma(chunks, provider) - st.success("Database populated successfully!") - + st.success(f"Database populated successfully! Documents saved in {data_path}") + def get_reranked_documents(query: str, provider="OpenAI"): initial_results = get_similar_documents(query, provider); @@ -224,7 +256,11 @@ def get_reranked_documents(query: str, provider="OpenAI"): """ def get_similar_documents(query: str, provider="OpenAI"): - collection_name = f"documents_{provider.lower()}" + if not st.session_state.current_session: + st.error("Please select or create a session first!") + return [] + + collection_name = get_collection_name(provider, st.session_state.current_session) db = Chroma( persist_directory=CHROMA_PATH, @@ -372,14 +408,19 @@ def manage_sessions(): def upload_files(): st.header("Upload Documents") - uploaded_files = st.file_uploader("Upload PDF files", type="pdf", accept_multiple_files=True) - provider = st.selectbox("Provider", ["OpenAI", "Ollama", "PG"]) - - if st.button("Reset Database"): - clear_database(provider) + if not st.session_state.current_session: + st.warning("No active session. Please create or select a session in 'Manage Sessions'.") + else: + st.write(f"### Active Session: {st.session_state.current_session}") + + uploaded_files = st.file_uploader("Upload PDF files", type="pdf", accept_multiple_files=True) + provider = st.selectbox("Provider", ["OpenAI", "Ollama", "PG"]) + + if st.button("Reset Database"): + clear_database(provider) - if st.button("Populate Database") and uploaded_files: - populate_database(uploaded_files, provider) + if st.button("Populate Database") and uploaded_files: + populate_database(uploaded_files, provider) def query_database(): st.header("Query") @@ -387,6 +428,8 @@ def query_database(): if not st.session_state.current_session: st.warning("No active session. Please create or select a session in 'Manage Sessions'.") else: + st.write(f"### Active Session: {st.session_state.current_session}") + session_name = st.session_state.current_session query_text = st.text_input("Enter your query:")