From 64b20d45ba99d926d9c75bb1d507814cd205d8cd Mon Sep 17 00:00:00 2001
From: Konrad Kolka <s180440@student.pg.edu.pl>
Date: Sun, 12 Jan 2025 23:04:12 +0100
Subject: [PATCH] allow user to change number of documents retrieved from db
 and reranking

---
 streamlit_rag.py | 65 ++++++++++++++++++++++++++++++++++++------------
 1 file changed, 49 insertions(+), 16 deletions(-)

diff --git a/streamlit_rag.py b/streamlit_rag.py
index 8a9ecc3..2021384 100755
--- a/streamlit_rag.py
+++ b/streamlit_rag.py
@@ -242,8 +242,15 @@ def populate_database(files, provider="OpenAI"):
 
 def get_reranked_documents(query: str, provider="OpenAI"):
     initial_results = get_similar_documents(query, provider);
-    return initial_results
+    # Sort by score (distance) - lowest distance first
+    sorted_results = sorted(initial_results, key=lambda x: x[1])
+
+    # Print similarity scores for debugging
+    #for doc, score in sorted_results:
+        #st.write(f"Similarity score: {score:.4f} - {doc.page_content[:100]}...")
 
+    return sorted_results
+
     """
     query_doc_pairs = [(query, doc.page_content) for doc, _ in initial_results]
     scores = reranker.predict(query_doc_pairs)
@@ -268,7 +275,8 @@ def get_similar_documents(query: str, provider="OpenAI"):
         collection_name=collection_name
     )
 
-    return db.similarity_search_with_score(query, k=10)
+    k = st.session_state.get('chroma_k', 10)
+    return db.similarity_search_with_score(query, k=k)
 
 def PG(prompt):
     base_url = "https://153.19.239.239/api/llm/prompt/chat"
@@ -303,7 +311,8 @@ def query_rag(query, provider="OpenAI", model="GPT-4o"):
 
     reranked_results = get_reranked_documents(query, provider)
 
-    top_results = reranked_results[:5]
+    k = st.session_state.get('rerank_k', 5)
+    top_results = reranked_results[:k]
 
     prompt_template = ChatPromptTemplate.from_template(
         """
@@ -429,25 +438,49 @@ def query_database():
         st.warning("No active session. Please create or select a session in 'Manage Sessions'.")
     else:
         st.write(f"### Active Session: {st.session_state.current_session}")
-
+        

         session_name = st.session_state.current_session
         query_text = st.text_input("Enter your query:")
-        provider = st.selectbox("Provider", ["OpenAI", "Ollama", "PG"])
-
-        if provider == "OpenAI":
-            model = st.selectbox("Model", ["OpenAI GPT-4o"])
-        elif provider == "Ollama":
-            installed_models = get_installed_ollama_models()
-            if not installed_models:
-                st.error("No Ollama models found. Please ensure Ollama is running and has models installed.")
-                return
-            model = st.selectbox("Model", installed_models)
-        elif provider == "PG":
-            model = st.selectbox("Model", ["Bielik-11B-v2.2-Instruct model"])
+        
+        # Create columns for provider and model selection
+        col1, col2 = st.columns(2)
+        with col1:
+            provider = st.selectbox("Provider", ["OpenAI", "Ollama", "PG"])
+        with col2:
+            if provider == "OpenAI":
+                model = st.selectbox("Model", ["OpenAI GPT-4o"])
+            elif provider == "Ollama":
+                installed_models = get_installed_ollama_models()
+                if not installed_models:
+                    st.error("No Ollama models found. Please ensure Ollama is running and has models installed.")
+                    return
+                model = st.selectbox("Model", installed_models)
+            elif provider == "PG":
+                model = st.selectbox("Model", ["Bielik-11B-v2.2-Instruct model"])
+
+        # Create columns for document retrieval settings
+        col3, col4 = st.columns(2)
+        with col3:
+            chroma_k = st.number_input(
+                "Documents retrieved from ChromaDB",
+                min_value=1,
+                max_value=20,
+                value=10
+            )
+        with col4:
+            rerank_k = st.number_input(
+                "Documents retrieved from reranking",
+                min_value=1,
+                max_value=chroma_k,
+                value=min(5, chroma_k)
+            )

         if st.button("Submit Query") and query_text:
             with st.spinner("Retrieving information..."):
+                # Update the function calls to use the new k values
+                st.session_state.chroma_k = chroma_k
+                st.session_state.rerank_k = rerank_k
                 response, sources = query_rag(query_text, provider, model)

                 modelInfo = f"Provider: {provider}, Model: {model}"
-- 
GitLab
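
For reference, the wiring this patch relies on reduces to a small standalone sketch: the number inputs write their values into st.session_state when the query is submitted, and the retrieval helpers read them back with the previous hard-coded values (10 and 5) as defaults, so no function signatures change. The sketch below is an illustration of that pattern only, assuming a hypothetical search(query, k) stand-in for the app's db.similarity_search_with_score call; it is not the application code itself.

    import streamlit as st

    def search(query: str, k: int):
        # Hypothetical stand-in for db.similarity_search_with_score(query, k=k):
        # returns (document, distance) pairs, lower distance = closer match
        return [(f"doc {i} for {query!r}", i / 10) for i in range(k)]

    def get_similar_documents(query: str):
        # Fall back to the old hard-coded value of 10 if the widget has not run yet
        k = st.session_state.get("chroma_k", 10)
        return search(query, k)

    query_text = st.text_input("Enter your query:")
    chroma_k = st.number_input("Documents retrieved from ChromaDB",
                               min_value=1, max_value=20, value=10)
    rerank_k = st.number_input("Documents retrieved from reranking",
                               min_value=1, max_value=chroma_k,
                               value=min(5, chroma_k))

    if st.button("Submit Query") and query_text:
        # Persist the widget values so helper functions can read them
        # from session state instead of taking new parameters
        st.session_state.chroma_k = chroma_k
        st.session_state.rerank_k = rerank_k
        results = sorted(get_similar_documents(query_text), key=lambda x: x[1])
        st.write(results[: st.session_state.get("rerank_k", 5)])

Routing the values through st.session_state keeps the change local to query_database() while get_similar_documents() and query_rag() pick them up on the same script rerun.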