From 649ad5569a7bc0b31ce30c10dbb3b09b1859582f Mon Sep 17 00:00:00 2001 From: Konrad Kolka <s180440@student.pg.edu.pl> Date: Sun, 12 Jan 2025 22:01:56 +0100 Subject: [PATCH] add collections per provider and session name --- chat_history.db | Bin 16384 -> 36864 bytes streamlit_rag.py | 85 +++++++++++++++++++++++++++++++++++------------ 2 files changed, 64 insertions(+), 21 deletions(-) diff --git a/chat_history.db b/chat_history.db index c145c6e4c322d0352570f003664828d908ec7ea7..25e601faa32abe18c62246f1fbca27d864e58994 100644 GIT binary patch literal 36864 zcmeI5OOM=EcE`IV+j;;V#|Vs!kj;e#16x3=rD~1Fc5fV6BRC#QmSQPTCWFAZO;ROR z6^k`|3_}J%U|DS$2@oV%1ZZfVAd9TLGqh9Jm&iBBZf28BkRZQvFUcbNElZlo0JqT3 zD%O4c&+FDb_xySLPp*eqpzg&<q%(DC>GzhFmzVxXsimbQkMC1_pXB>E-^+Zz!guqK z{||e3ORL`U>s+<`z5iL_%RL+d4grUNL%<>65O4@M1RMem0f&G?z#;IPN8rKI@`*oq z`|af~KgjgPB=GL*?DI58(=eW<y?>5f{p{M6JJ-~mD<59Jrg|@_cTTH49(6OHPjwWi zJJ<f~&dQ#b{UQ&N@_^TqAe}LogPKU=Jn@3`fJUO&4<_dIU1K=+DyScS<JW@yD{uVw z8~<hh{LPEP_2>|A2si{B0uBL(fJ49`;1F;KI0PI54grV2i$&m+Wc6QtjUUI4NpAkw z(!VdgKH$gy_s7=qu~TQhIR39!zy3UpClej%S?uePO5$oG_v1oO1C?e;6{M9{1=BE4 zQ?0z1pQDKXMrv^J;zd<W_4w&G#SJ}wP^wSTaC>w9WP58x4d)MddRnTDn2n$yQDghs zI8;}rSsYG-x3%&Ebz6ti?EJOA%x6#kEh*m)hVhEZVinCF#D1y9aT<l0R*Wop`px`F zKHA>&O8!;m-)ZUHSF;Ht&9rxaRsEE=%lHaCe<DLntL;sd1QYG?J~ax`OwE$29O?PD z^M|TZlQ0dU6*bPDev@d$dvmWUW$4>m8}V4?7=;XE{@4pusA7K>7h&L6p`J|5ldGps zUk_7hGZ^bC9|cNq2%_8|vE_yNikd$O=$W?7h_xL>Q8J)x49K)Q3j7ji`I<kbeJ>1F zm?7hrX{ahrGaA(@PK$UHPxXojAqID^;%{lbz5sx?lvizU$#|J>8YCVPyUz?}o?ca- z34($p*O<=13aKLRJmPm|u-B8oz#CVQOcs!gXrpFEL8L0QgihlPUQTC;E=1mTngDFD zmuW5L2uz>SugHnXsWC`Qv~72xGOE(Kyq7tR^jK}^#IIroHB}LlPS_??9$Kh$p2bO- z1sPq;;&lEXPpe@5FwKg1RsDo^>*mL_g??4QI0YE#*eC#hjCBF1<L%8Z&gMllFR8ST z=n+REEek?HR`N_%GLJtFQ@RheLMcUA?8F5G)S?JfL06S13oqs#*A-=?hg#GRai!?P z49$_NWHJKNLfg=hnaEY1_mg}m`WFLpxw;E0QJRHHA%E$>$Y1n(jJo)u*dml#gSQ_E zUJK=ynHs}D*^K@pc2VO*)Jj_81XI97CwPI4c&I(_HHc^1Tf^-wptzl}pLzt>uzvdV z*&XiTte7~OHbjyhGgh&&todU`8LL!Bab^un;7Y4}9O8|_e1WdfegMpy>3QL*`e=vC zS=r%5R<K}Wzq<AwOxgI3l8Lc1jFh#itYZm1>j3E<*AfQnYF4Hdb~r{q<5@Toj69Kn zMV%+)m0=JM6Mc_nz;C9ba79gYrjuzXuZlV2x@rZD3lWp(rdXbm|I#VIGZ2Azb5&t9 z_=YKJP}L)Chl%H+K_j@b_F6bjk$1*#@zUshXk8T&Hp<;|j6v<Yuu?sLEGoCR>#|;} z?n*`%1`~!i#o5)(F6_K(PP^QeI=`QvGOBEx)vCH`ycJGxmunLL*tjmtNc<Gj+ukCe z;K3SaWjs2r1k>btAL*ggL%sdTdX?Ha88|u?w6AaUV{4{8--xazvG@}|<_l-fRbfB? z;eTL9TH>3H=fFJ?1trdC@cv|iSBUWlDwqUOFwJnLjph?!%&5c&c^B#)U<yr3ImT&S zj6}TF$@F~5ZZBb?ndqn0jB$(;#lo;Xe?$I272_f9Xa#Hw9O@{TKc0jJ`NkU-ID;UX z2CI#u#ze;P{Gr52qxoQug)rKvjWlSh(Hilub!e3UlZXYF=wfuEK;<w78%B6P-Z9(W zoRKZUi(-wk#;_i8s`D|CjYUBljC-!i1M&h_JrTFP+ieN*dpe-T^NJ_DHwngRIqJEh zcJxsgF;Jo}9nBJS3RfX0rS#XASD};4tPEI=$QA8q5F6_nJq~$Z!dvT=v6+Qa-o3`y zene2K)tvT5L=j@&1b<?!-pHWVQX^Vx{Mx7spc#b+P+UW4mLse)7<E=o#^AYd@i7f; z4sD<$&>~&qomfn){`{}bhW^=gb=Jp|T}aJ(EnvMyuI+rJC*^rocDf2yXa2qQ!P@$z zvn%TCOb-LzGFanpn^HSk184twMLk1zmoBYe>~weO(#!2`@ZS1br@O&>FSon3i|d0< zcWW14Zg=nYdVBYoJNt`UN&L&u50Z6t>!THQL-M5S>O<m1I6glZe0ctsYpZMLKc3=} zb1!36$LpKdmrtJg;EO-}+N(c;#G-qVb%iXknhy>7gFBsn#Dnd&XtOY68K@8XfblHK zy!KvCnb%&jxCig|1Tc92nI-T{w!`eY%KXx2cX|Cs9c3O|e#x`D*z4{)*v@xR=hykn zzx=&_S>nq*90Cpjhk!%CA>a^j2si{B0uBL(fJ49`@LNXU!QIm*-Wog{y!o??jA(bq zrC<woYxXG{gtGt3mhOxsLNF!cA)A&VJHOL$5@cj`?Uii##%#iO_x)veQMTR5*j4dV zHptmFB+)4;4&I)z&6?$k4P#|;M?-RwW{ccsPgU}go@{W_wdN|@NxrJCY>?!jBes&* z<n-gQ$(*nox=6{&7-X-1L-JK6$!Ol<>j~S!Dx8{ve2=6h`@5!XMefREbtn0TeTEEQ z)1OSY&aBMJOsCcA&pVp|H)p~0%E#^gmi<?KbnDLfAI0+d|COchFP-||WB<7PpZrlj zzWnsH6K4irEx)<PE|;C;%{sR{|CV&wT_XnB<!TZhk|?&BUD@O{7`D69RZ))FR<AUo zA_)-#sU+<dv*yYgcx9$bK3MrAO4<LG6g@E3Xw(UGPqGts3$9?hyU@N(X0%(+WHl{l zX7f7mk+!G_lCPT-8A`G&CLu+l-GWRGi{xUd;zU=Hp}B8!pf-(SB!!AdA1DXe>KcSd z+K_w#S=>O<;dYw9u$Un7GMj%DDJfr$Eo_7vlh9=1Leh(Y8GvNnGRY8dmcVSNOM;J% zy0Sg;zU3D`d=-Iy@$s8KvnMC^HX=YF32SX9s<n#&kgH;Tx2(EyqJdp{lMf!Xa+3TO z`Wrp{#_Q<)LfLzRK1d<|e)P)8GlPHr;a5M{SMzF|$z5$fDuBx*U92^7cteAyBx-a* z=}jhLAY&z#YK~qMDOvV!Nk+)?l2}LN<D^6%CIY;u!8_Nl{^Z;W3e%W-eRP7vxH;{R zaX5o)133*`4M#KNO{;4At9mGs&=IReCv$SM_Nanrnk4_DP$zXNSGtmng&1{VGB7(j zU>Lmwj3id$B<6?7I#0+@k$Lw<lDjAUByNduO(oGH){~H(LMlbKI}Qj;$CBuw2}GSu zfR5}Q<3D|RVb3OSX_Z>#WMRY^N;()jA7F6eR*M<_sxTMYIrRr18Mk@_f<~ve-~*{a zy1D4yz-5w#=A6vY8}7RBe(#;m-?{qPIqMhv;9e+`)r0~f1KF`GPV;Fn#?)FM;|(l5 zM+4&NImwN7Udj-aym0T{IV_zs7Xe@woUJf&+!t8#m>{irH>pw1b&xtw#J;VQu(OrK zCJJLtr@+P1X5>~>MG_Kamc%^?ut?4s@UtjI##WgKHGhsO;)lrb#y$eCm24pZIjDgC zO$jn|B-y{(MG6#!1uLEyWJDzon2K0<UX_>-4Kb^v9HNelSOyfLLu<8qm|QHVhGg65 zq!|mYknmhjNS=-ea~TRqRNM$T(X{Z8s90F#a+YIcH0H(&YK$RFaLUr749lzy1U<>+ za&pBuSg};{t3eVN>*|_-h+bUZIIU1>GT@R5opJmIcY?TR+Y?mxm^v7du~adsXjI|y z`oL(}g<7G0ifxl)E$ZRfqa-rMCh9MIk3>JTMb5n-)s>Eg5~-q?V;J@5nT&|3oIgo% zSB(rLb8bd$RwSH=V<ZyoYKKlZBF;{N8OxA1b!(?ex*xkSvu<UxBh_kCWn*uoO>Xs4 zFFmwJs<VInu2gFO<@I-asnb2uef{eXYwpLr=6<|8yYp<#t@V>r&)eMlz0vKFS=%|f z&VujAi-N!Vo8u?X-2R92ZLpQNCE3hb%uy%K1ZiKbaX6wAX9suXSyV6I^$MD0ZL8yB zWc6Enq;C8Y%jpE@LJ?0`w$ir*XBJptI-3L|mTiZ|Ygqzvh>xXbT4p)tCONc$SCa#h zXt12(Xb*nKX+m*V&5?)R7`lsuPRyi%QR+#|T$rFa#wgCvoF$aC8HaTkyGk)WISpc_ zIuT+mdp4zDfx*WVISXnJ$jC{ox|5K<B3z2k`{n2u>lw2iE+f3f1}C^!$B1wJr14hm zs5#1Kr&*i5@$Y8BjKQ2q?U{NT%&>1U_<fcPtn(Hj!Yq;H5QvR1vV7rSPh1rmFG1^a z*of6&>loxn&|;N>ly`u}F)tGv<V<1)P~vRrpkP*UZAg&*40e^qQL(CIN<1ZTB4!Bf zO1cv#!01-&%b7)b#od+GnVaBfRTj-?R5UNDA*aIRFx#VA5JoXrA2sVVvv`uI%f~c{ z1SXW#>l55yvwku-*r>P&g*#R_<`Y3L*tfy}J@YBmtk?E{tHN_ng0NmNR2T=bM~>He z*1JpOtu2fgWX_Y?!+9LtGn$xYc4=;{kOi`Q)LbCMtZi&qECIs)k-~!)I0r|77wltS za6s8{7Uy|g_T`?-zWlt+z1M4QkE0wfA_6p@_uX%YfJ49`;1F;KI0PI54grUNL%<>6 z5O4@M1RMem0f&G?z#-rea0oaA90Cpjhk!%CA>a^j2si|O=Mgx@41^HyMEzUt^Cv8U z@|=4(1RMem0f&G?z#-rea0oaA90Cpjhk!%CA@Dnn!0U|v_|lh4ul>jIU-9d=@%U<V z>f~D&9{%aqzrPo|Ltw}qn!0pToMtMBED1rk2-wt(36F#D*;C+R868aVtdy{HN`@$0 z5|(z$t3d4;gb!2de56E@f^f<@ArHP`YIZDJ1SLY09@Q)bJ%Ryhgbyn??-h;jS#fSj zhMUSSsm?J2git~Xex>ZmPBu40xHU~RToEm2MX;j?Cx-ArhZlq^<3hhkWc|<zn*HCq ztDbWIOV<yr``Q1!&(h7I6*n)Yn?bJ<XRy1l=f!k$Xu;1x-E_)7Zyvz+u{k^mlgi(| z`zB;5doiih+zN(cWt-bFp!z$=J7%$zcluS!#z9S^sZ!=rAv~dwlbUo5$A{4JQ{9K{ zLh`ZGY%F2Qg6M++N^VjL#|m7#<1A08i!<aKHN%H3;A==er0RMYQ#mdr(1wA-5KJ3B z^9ALC+^e9yJ;5&OkcQQRn|V;_ZA-tQ1Cbjz2+MTT@_FoFD{1`*Z0EI0Xw8+$wzZ`V zY;xm;u`q(zl79&OsQ{-o5x8n9Y=|a;gm39vSO%mz^S-@n1EqKrW^Yzdgb1a6J+vAb z>K(y>@<=WBKA|?!+>*c-LY7*L)>C#EQ6gbA*bPN3x3-?tNZ4?s(5$USm74K@lF3bK z(hb+m;iAK@HFIwTCU+dciCaSOgXGnLvd0~A<l5$c`oW2lZw(#~AA!XIs=gqvX&6Kn z#m+4r;^P-(w+){NrIUwNCtEfVYAmno4QOQKP8@lpZWEjWxKavr*RULfkjud1@o_`@ z13ys}CquLw@_iGV1pP{Vf@1n&8aH4zp<r_u3i-#)Sr$V^GUh2$4T?J6EfJ|u*I!#a zzoPEQRnY%)mFchTQSKH3y^B;Vd-qKnANiU)lu-9O+K<{qQ%W1cw%+0*)9LWVcME~h z^|s*kAi9!X{cEXUKmj?!%x}qC_T{GN@Rx8e695nzB#0&~WyxL3R<g5bx2{X?7eaTa zOL+6mMV;@<AS#hnh_TTlvXe+Pcuuu-YKo!z1SPgsJQq?M?tUPPhR(_Zm)x7mp;RQY zLaNKpDKqc#kjVW2xK9VS<f(~h7Uj!^9ix@eq|tawkz=cn_ook8)Mgv13R8}Z-gJaU zP5X4Vic7b~AiNw*)-qv;@|l@)y=OUXj2}(QC^y_DT)0z_<UO6{8zLM?OkNb*2pY7z zsoJZQP2sy}+>wY-EWt=kUIGyi;T6WQu9R^GmcM1ssli=RJ99*6xJsnvUT*WyHvwA@ zF7!cQ-Q|v<%pcnt(WOv|Sb^5n`x5+N&tWZ4dv}24IU58SmSNUmRQ85P6ubMSXZ4~t z0~+nP5N@Ls%B_>c5$wsXvhTQSww0at{#qCQqK0Y(kIGn&Lg8a0H)|rpuwbXpipGX| z;eHUc9IRd3!KKnS2R@L<4Z1MQj?Te;upkAJ8;PO&+)8G_5M+96G`%Zp7;MCK2wj9A zC;}0}t}YE!z7XV&4+yj2h2T4VI@srpyVmUcTz5Ey%RX;BatDWg@Oa)14yAc{-VP2W Pe%ZH!&PRYFuaf=`C7<va delta 1851 zcmcJP&rcIU6vtcIR3y-1JQ(Hj=mktlrL-1U4iLR~P-zG;QA5bEGus{6?kuyjEmh1y z{6*ryADq1TCy4z={8PLW65lLXTLJ<RH`!#r`)1x}-pu>X=GUIh&G^Tmfl@3MOX3?n zpGGG$rC7fHegI`Wy&p^Ow}uBlq+9#xOYy#VW>1ewX7*a~%FoPTBB6&RQZ&~8<8m=Y zch?@J`rh16^j@Fc8rn8@hUc!f9*%wQ3E$o5>Ag1py1G5H#1!*Y!xqpWKFGjzxmyL| z0(a{e;8miyaAy~71|Ach!#V>vQemJLNXjHMm}Ia({8iz&d{xv(e6S>EF7<Ob$#ep$ zQ!kL`L0IE9cacKjiST`16ZISrM_{<Yl^rGZgu_h+qmp@&`OH;}!b&p=3CrRpCk{l* zwz=;inSv6XbRa7tP=GTdtR*B6pu$u_-VF1A1vL*=tYlNK3uk*$)EN?p(ikj8Z;2~_ z)j4G@E{0SIcl{^|9)2pUnwVzbBVAb0+Dyr7EdTUHmeZ^OSxOYSbyxpe#%h-<fy;fn z7GgQei+PV)#zbk-7|-S)>yawM8K_bCB9Nw*CnRXk&%VsT-&yUUh?+V?g}+%-|BF>D z-0mt1bq6ZE(7Gne<t`TL4%Ts~P#;2jW6raO#ffef>Qh)pP)ikTvJm@88ZhrTq(*Y^ zKv3ow5UFCmU^qm3^6gk0J^Jd(_RQn6d9)&=#bpyJOd*3S(>}@5VPxt&+(S3}K$(wS zyMjH8?XJkGc7sQxLem0UCpCU}B~Vxk^pzmAM!%wa(fpWzIKBYMEG7?ah<UR5Vz3!$ z8-rpIrs8NY%0iu9A3#<!^hi^nnby^5#%Z6KQ64WCg%01KKKnt0)H=_hPZy`VSg6m= zIt~>cy%7G?HBlb#YG%~`#ySoa9^5_UVyUYx)ECb>K$WIC%PJaxKhDS6d3I;y^IS3& VPi({z8@p>Vk@yl%w!Xnu|1b2CYtaAz diff --git a/streamlit_rag.py b/streamlit_rag.py index 95ae3d8..8a9ecc3 100755 --- a/streamlit_rag.py +++ b/streamlit_rag.py @@ -30,7 +30,6 @@ auth_kwargs = { } CHROMA_PATH = "chroma" -DATA_PATH = "data" #RERANKER_MODEL = "cross-encoder/ms-marco-MiniLM-L-6-v2" #reranker = CrossEncoder(RERANKER_MODEL) @@ -79,12 +78,26 @@ def load_chat_history(): return chat_sessions def delete_session(session_name): + # Delete from SQLite database conn = sqlite3.connect(DATABASE_PATH) c = conn.cursor() c.execute("DELETE FROM chat_sessions WHERE session_name = ?", (session_name,)) conn.commit() conn.close() + # Clear Chroma collections for all providers + providers = ["OpenAI", "Ollama", "PG"] + for provider in providers: + try: + collection_name = get_collection_name(provider, session_name) + db = Chroma( + persist_directory=CHROMA_PATH, + embedding_function=get_embedding_function(provider), + collection_name=collection_name + ) + db.delete_collection() + except Exception as e: + st.warning(f"Could not delete collection for {provider}: {str(e)}") def get_embedding_function(provider): if provider == "OpenAI": @@ -97,19 +110,23 @@ def get_embedding_function(provider): def clear_database(provider="OpenAI"): try: - collection_name = f"documents_{provider.lower()}" + if not st.session_state.current_session: + st.error("Please select or create a session first!") + return + + collection_name = get_collection_name(provider, st.session_state.current_session) + # Clear Chroma collection db = Chroma( persist_directory=CHROMA_PATH, embedding_function=get_embedding_function(provider), collection_name=collection_name ) - db.delete_collection() - st.success(f"Successfully cleared {provider} collection!") + st.success(f"Successfully cleared {collection_name} collection!") except Exception as e: - st.error(f"Error clearing {provider} collection: {str(e)}") + st.error(f"Error clearing collection: {str(e)}") def split_documents(documents: list[Document]): text_splitter = RecursiveCharacterTextSplitter( @@ -135,14 +152,21 @@ def get_installed_ollama_models(): return [] +def get_collection_name(provider, session_name): + return f"documents_{provider.lower()}_{session_name}" + def add_to_chroma(chunks: list[Document], provider="OpenAI"): try: if not chunks: st.warning("No documents to add to the database.") return + if not st.session_state.current_session: + st.error("Please select or create a session first!") + return + embedding_function = get_embedding_function(provider) - collection_name = f"documents_{provider.lower()}" # Create separate collections + collection_name = get_collection_name(provider, st.session_state.current_session) db = Chroma( persist_directory=CHROMA_PATH, @@ -157,7 +181,7 @@ def add_to_chroma(chunks: list[Document], provider="OpenAI"): # Verify documents were added collection_size = len(db.get()['ids']) - st.info(f"Added {collection_size} documents to the {provider} collection.") + st.info(f"Added {collection_size} documents to the {collection_name} collection.") except Exception as e: st.error(f"Error adding documents to database: {str(e)}") @@ -184,17 +208,26 @@ def calculate_chunk_ids(chunks): return chunks +def get_data_path(provider, session_name): + return os.path.join("data", provider.lower(), session_name) + def populate_database(files, provider="OpenAI"): + if not st.session_state.current_session: + st.error("Please select or create a session first!") + return + if not os.path.exists(CHROMA_PATH): os.makedirs(CHROMA_PATH) - if not os.path.exists(DATA_PATH): - os.makedirs(DATA_PATH) + # Use the new dynamic data path + data_path = get_data_path(provider, st.session_state.current_session) + if not os.path.exists(data_path): + os.makedirs(data_path) documents = [] for uploaded_file in files: - file_path = os.path.join(DATA_PATH, uploaded_file.name) + file_path = os.path.join(data_path, uploaded_file.name) with open(file_path, "wb") as f: f.write(uploaded_file.getbuffer()) @@ -202,11 +235,10 @@ def populate_database(files, provider="OpenAI"): documents.extend(loader.load()) chunks = split_documents(documents) - add_to_chroma(chunks, provider) - st.success("Database populated successfully!") - + st.success(f"Database populated successfully! Documents saved in {data_path}") + def get_reranked_documents(query: str, provider="OpenAI"): initial_results = get_similar_documents(query, provider); @@ -224,7 +256,11 @@ def get_reranked_documents(query: str, provider="OpenAI"): """ def get_similar_documents(query: str, provider="OpenAI"): - collection_name = f"documents_{provider.lower()}" + if not st.session_state.current_session: + st.error("Please select or create a session first!") + return [] + + collection_name = get_collection_name(provider, st.session_state.current_session) db = Chroma( persist_directory=CHROMA_PATH, @@ -372,14 +408,19 @@ def manage_sessions(): def upload_files(): st.header("Upload Documents") - uploaded_files = st.file_uploader("Upload PDF files", type="pdf", accept_multiple_files=True) - provider = st.selectbox("Provider", ["OpenAI", "Ollama", "PG"]) - - if st.button("Reset Database"): - clear_database(provider) + if not st.session_state.current_session: + st.warning("No active session. Please create or select a session in 'Manage Sessions'.") + else: + st.write(f"### Active Session: {st.session_state.current_session}") + + uploaded_files = st.file_uploader("Upload PDF files", type="pdf", accept_multiple_files=True) + provider = st.selectbox("Provider", ["OpenAI", "Ollama", "PG"]) + + if st.button("Reset Database"): + clear_database(provider) - if st.button("Populate Database") and uploaded_files: - populate_database(uploaded_files, provider) + if st.button("Populate Database") and uploaded_files: + populate_database(uploaded_files, provider) def query_database(): st.header("Query") @@ -387,6 +428,8 @@ def query_database(): if not st.session_state.current_session: st.warning("No active session. Please create or select a session in 'Manage Sessions'.") else: + st.write(f"### Active Session: {st.session_state.current_session}") + session_name = st.session_state.current_session query_text = st.text_input("Enter your query:") -- GitLab