From 64b20d45ba99d926d9c75bb1d507814cd205d8cd Mon Sep 17 00:00:00 2001
From: Konrad Kolka <s180440@student.pg.edu.pl>
Date: Sun, 12 Jan 2025 23:04:12 +0100
Subject: [PATCH] allow user to change number of docuemtents retrieved from db
 and reranking

---
 streamlit_rag.py | 65 ++++++++++++++++++++++++++++++++++++------------
 1 file changed, 49 insertions(+), 16 deletions(-)

diff --git a/streamlit_rag.py b/streamlit_rag.py
index 8a9ecc3..2021384 100755
--- a/streamlit_rag.py
+++ b/streamlit_rag.py
@@ -242,8 +242,15 @@ def populate_database(files, provider="OpenAI"):
 def get_reranked_documents(query: str, provider="OpenAI"):
     initial_results = get_similar_documents(query, provider);
     
-    return initial_results
+    # Sort by score (distance) - lowest distance first
+    sorted_results = sorted(initial_results, key=lambda x: x[1])
+    
+    # Print similarity scores for debugging
+    #for doc, score in sorted_results:
+        #st.write(f"Similarity score: {score:.4f} - {doc.page_content[:100]}...")
 
+    return sorted_results
+    
     """
     query_doc_pairs = [(query, doc.page_content) for doc, _ in initial_results]
     scores = reranker.predict(query_doc_pairs)
@@ -268,7 +275,8 @@ def get_similar_documents(query: str, provider="OpenAI"):
          collection_name=collection_name
     )
     
-    return db.similarity_search_with_score(query, k=10)
+    k = st.session_state.get('chroma_k', 10)
+    return db.similarity_search_with_score(query, k=k)
 
 def PG(prompt):
     base_url = "https://153.19.239.239/api/llm/prompt/chat"
@@ -303,7 +311,8 @@ def query_rag(query, provider="OpenAI", model="GPT-4o"):
     
     reranked_results = get_reranked_documents(query, provider)
 
-    top_results = reranked_results[:5]
+    k = st.session_state.get('rerank_k', 5)
+    top_results = reranked_results[:k]
 
     prompt_template = ChatPromptTemplate.from_template(
         """
@@ -429,25 +438,49 @@ def query_database():
         st.warning("No active session. Please create or select a session in 'Manage Sessions'.")
     else:
         st.write(f"### Active Session: {st.session_state.current_session}")
-        
+         
         session_name = st.session_state.current_session
 
         query_text = st.text_input("Enter your query:")
-        provider = st.selectbox("Provider", ["OpenAI", "Ollama", "PG"])
-
-        if provider == "OpenAI":
-            model = st.selectbox("Model", ["OpenAI GPT-4o"])
-        elif provider == "Ollama":
-            installed_models = get_installed_ollama_models()
-            if not installed_models:
-                st.error("No Ollama models found. Please ensure Ollama is running and has models installed.")
-                return
-            model = st.selectbox("Model", installed_models)
-        elif provider == "PG":
-            model = st.selectbox("Model", ["Bielik-11B-v2.2-Instruct model"])
+        
+        # Create columns for provider and model selection
+        col1, col2 = st.columns(2)
+        with col1:
+            provider = st.selectbox("Provider", ["OpenAI", "Ollama", "PG"])
+        with col2:
+            if provider == "OpenAI":
+                model = st.selectbox("Model", ["OpenAI GPT-4o"])
+            elif provider == "Ollama":
+                installed_models = get_installed_ollama_models()
+                if not installed_models:
+                    st.error("No Ollama models found. Please ensure Ollama is running and has models installed.")
+                    return
+                model = st.selectbox("Model", installed_models)
+            elif provider == "PG":
+                model = st.selectbox("Model", ["Bielik-11B-v2.2-Instruct model"])
+
+        # Create columns for document retrieval settings
+        col3, col4 = st.columns(2)
+        with col3:
+            chroma_k = st.number_input(
+                "Documents retrieved from ChromaDB",
+                min_value=1,
+                max_value=20,
+                value=10
+            )
+        with col4:
+            rerank_k = st.number_input(
+                "Documents retrieved from reranking",
+                min_value=1,
+                max_value=chroma_k,
+                value=min(5, chroma_k)
+            )
 
         if st.button("Submit Query") and query_text:
             with st.spinner("Retrieving information..."):
+                # Update the function calls to use the new k values
+                st.session_state.chroma_k = chroma_k
+                st.session_state.rerank_k = rerank_k
                 response, sources = query_rag(query_text, provider, model)
             modelInfo = f"Provider: {provider}, Model: {model}"
             
-- 
GitLab