From bc528b4541f6e835b339361bc03cb02138c269f7 Mon Sep 17 00:00:00 2001 From: Anirudha Acharya <127017012+anirudhaacharyap@users.noreply.github.com> Date: Mon, 11 May 2026 12:36:13 +0530 Subject: [PATCH 1/3] feat: add per-user ingestion coordination to prevent race conditions Introduce UserIngestionCoordinator that serialises ingestion pipeline execution per user_id using async FIFO locks, while allowing different users to proceed in parallel. - New src/api/ingestion_coordinator.py with lazy lock creation, automatic cleanup, and a clean async context-manager interface - Wrap /v1/memory/ingest and /v1/memory/batch-ingest routes with per-user lock (existing global Semaphore(5) retained as backpressure) - Wrap both server.py test-frontend ingest routes with coordinator - Prevents profile overwrites, temporal duplicates, and summary drift caused by concurrent cross-source requests for the same user Closes #per-user-coordination --- server.py | 30 +++++----- src/api/ingestion_coordinator.py | 94 ++++++++++++++++++++++++++++++++ src/api/routes/memory.py | 70 +++++++++++++----------- 3 files changed, 148 insertions(+), 46 deletions(-) create mode 100644 src/api/ingestion_coordinator.py diff --git a/server.py b/server.py index 4bd1859..abceff3 100644 --- a/server.py +++ b/server.py @@ -50,6 +50,7 @@ from src.pipelines.ingest import IngestPipeline from src.pipelines.retrieval import RetrievalPipeline +from src.api.ingestion_coordinator import UserIngestionCoordinator # ═══════════════════════════════════════════════════════════════════ @@ -82,6 +83,7 @@ def emit(self, record: logging.LogRecord) -> None: _pipelines_ready = asyncio.Event() _init_error: str | None = None SKIP_PIPELINES = os.getenv("XMEM_SKIP_PIPELINES", "").lower() in {"1", "true", "yes"} +_user_coordinator = UserIngestionCoordinator() def _init_pipelines_sync() -> None: @@ -315,14 +317,15 @@ async def v1_ingest_memory(req: IngestRequest): ) try: - result = await ingest_pipeline.run( - 
user_query=req.user_query, - agent_response=req.agent_response or "Acknowledged.", - user_id=req.user_id, - session_datetime=req.session_datetime, - image_url=req.image_url, - effort_level=req.effort_level, - ) + async with _user_coordinator.acquire(req.user_id): + result = await ingest_pipeline.run( + user_query=req.user_query, + agent_response=req.agent_response or "Acknowledged.", + user_id=req.user_id, + session_datetime=req.session_datetime, + image_url=req.image_url, + effort_level=req.effort_level, + ) data = { "model": _get_model_name(ingest_pipeline.model), @@ -368,11 +371,12 @@ async def api_ingest(req: IngestRequest): lg.addHandler(capture) try: - result = await ingest_pipeline.run( - user_query=req.user_query, - agent_response=req.agent_response or "Acknowledged.", - user_id=req.user_id, - ) + async with _user_coordinator.acquire(req.user_id): + result = await ingest_pipeline.run( + user_query=req.user_query, + agent_response=req.agent_response or "Acknowledged.", + user_id=req.user_id, + ) # Build structured response response: Dict[str, Any] = { diff --git a/src/api/ingestion_coordinator.py b/src/api/ingestion_coordinator.py new file mode 100644 index 0000000..1503a3b --- /dev/null +++ b/src/api/ingestion_coordinator.py @@ -0,0 +1,94 @@ +""" +Per-user ingestion coordinator — serialises ingestion for each user. + +Guarantees that only one ingestion pipeline runs at a time for any given +``user_id``, while allowing different users to proceed in parallel. +Requests for the same user are processed in strict FIFO order. + +This is the **in-memory** implementation (Option 1). A future distributed +lock (Redis, etc.) can be swapped in by implementing the same ``acquire()`` +context-manager interface. + +Usage:: + + from src.api.ingestion_coordinator import UserIngestionCoordinator + + coordinator = UserIngestionCoordinator() + + async with coordinator.acquire(user_id): + result = await pipeline.run(...) 
+""" + +from __future__ import annotations + +import asyncio +import logging +from contextlib import asynccontextmanager +from typing import AsyncIterator, Dict + +logger = logging.getLogger("xmem.api.ingestion_coordinator") + + +class UserIngestionCoordinator: + """Per-user FIFO ingestion lock. + + Internally maintains a ``dict[str, asyncio.Lock]`` keyed by ``user_id``. + Locks are created lazily on first access and removed once no tasks are + waiting or holding them, preventing unbounded memory growth. + + Thread-safety note + ------------------ + All mutations to the internal registry are protected by a single + ``asyncio.Lock`` (the *registry lock*). Since this code runs on the + asyncio event loop, ``asyncio.Lock`` is sufficient — no OS-level + threading primitives are needed. + """ + + def __init__(self) -> None: + # Maps user_id -> (asyncio.Lock, active_count) + # active_count tracks how many tasks are either holding or waiting + # for the lock so we know when it's safe to clean up. + self._locks: Dict[str, asyncio.Lock] = {} + self._waiters: Dict[str, int] = {} + self._registry_lock = asyncio.Lock() + + @asynccontextmanager + async def acquire(self, user_id: str) -> AsyncIterator[None]: + """Acquire the per-user ingestion lock. + + Usage:: + + async with coordinator.acquire("user_123"): + # Only one coroutine per user_id reaches here at a time. + await do_work() + + The lock is automatically released (and cleaned up if idle) when + the ``async with`` block exits, even if an exception is raised. 
+ """ + # ── Get-or-create the user lock ────────────────────────────── + async with self._registry_lock: + if user_id not in self._locks: + self._locks[user_id] = asyncio.Lock() + self._waiters[user_id] = 0 + self._waiters[user_id] += 1 + user_lock = self._locks[user_id] + + logger.debug("User %s: waiting for ingestion lock (waiters=%d)", user_id, self._waiters.get(user_id, 0)) + + try: + async with user_lock: + logger.debug("User %s: ingestion lock acquired", user_id) + yield + finally: + # ── Cleanup: remove the lock if nobody else is waiting ──── + async with self._registry_lock: + self._waiters[user_id] -= 1 + if self._waiters[user_id] <= 0: + self._locks.pop(user_id, None) + self._waiters.pop(user_id, None) + logger.debug("User %s: ingestion lock cleaned up", user_id) + + @property + def active_users(self) -> int: + """Return the number of users with active or pending ingestion locks.""" + return len(self._locks) diff --git a/src/api/routes/memory.py b/src/api/routes/memory.py index 3397a69..d8a2f0e 100644 --- a/src/api/routes/memory.py +++ b/src/api/routes/memory.py @@ -41,6 +41,7 @@ WeaverSummary, ) from src.pipelines.retrieval import RetrievalPipeline +from src.api.ingestion_coordinator import UserIngestionCoordinator from bs4 import BeautifulSoup import json @@ -50,6 +51,7 @@ logger = logging.getLogger("xmem.api.routes.memory") _ingest_semaphore = asyncio.Semaphore(5) +_user_coordinator = UserIngestionCoordinator() router = APIRouter( prefix="/v1/memory", @@ -540,17 +542,18 @@ async def ingest_memory(req: IngestRequest, request: Request, user: dict = Depen try: async with _ingest_semaphore: - result = await asyncio.wait_for( - pipeline.run( - user_query=req.user_query, - agent_response=req.agent_response or "Acknowledged.", - user_id=user_id, - session_datetime=req.session_datetime, - image_url=req.image_url, - effort_level=req.effort_level, - ), - timeout=120.0 - ) + async with _user_coordinator.acquire(user_id): + result = await asyncio.wait_for( + 
pipeline.run( + user_query=req.user_query, + agent_response=req.agent_response or "Acknowledged.", + user_id=user_id, + session_datetime=req.session_datetime, + image_url=req.image_url, + effort_level=req.effort_level, + ), + timeout=120.0 + ) data = IngestResponse( model=_model_name(pipeline.model), classification=_safe_classifications(result), @@ -588,28 +591,29 @@ async def batch_ingest_memory(req: BatchIngestRequest, request: Request, user: d results = [] - for item in req.items: - result = await asyncio.wait_for( - pipeline.run( - user_query=item.user_query, - agent_response=item.agent_response or "Acknowledged.", - user_id=user_id, - session_datetime=item.session_datetime, - image_url=item.image_url, - effort_level=item.effort_level, - ), - timeout=120.0 - ) - - data = IngestResponse( - model=_model_name(pipeline.model), - classification=_safe_classifications(result), - profile=_build_domain_result(result.get("profile_judge"), result.get("profile_weaver")), - temporal=_build_domain_result(result.get("temporal_judge"), result.get("temporal_weaver")), - summary=_build_domain_result(result.get("summary_judge"), result.get("summary_weaver")), - image=_build_domain_result(result.get("image_judge"), result.get("image_weaver")), - ) - results.append(data) + async with _user_coordinator.acquire(user_id): + for item in req.items: + result = await asyncio.wait_for( + pipeline.run( + user_query=item.user_query, + agent_response=item.agent_response or "Acknowledged.", + user_id=user_id, + session_datetime=item.session_datetime, + image_url=item.image_url, + effort_level=item.effort_level, + ), + timeout=120.0 + ) + + data = IngestResponse( + model=_model_name(pipeline.model), + classification=_safe_classifications(result), + profile=_build_domain_result(result.get("profile_judge"), result.get("profile_weaver")), + temporal=_build_domain_result(result.get("temporal_judge"), result.get("temporal_weaver")), + summary=_build_domain_result(result.get("summary_judge"), 
result.get("summary_weaver")), + image=_build_domain_result(result.get("image_judge"), result.get("image_weaver")), + ) + results.append(data) response_data = BatchIngestResponse(results=results) From cae0a31e984fdefa94365e5aff77734c1c64f504 Mon Sep 17 00:00:00 2001 From: Anirudha Acharya <127017012+anirudhaacharyap@users.noreply.github.com> Date: Fri, 22 May 2026 09:21:12 +0530 Subject: [PATCH 2/3] Optimize locking order in ingest_memory and add backpressure & structured error handling to batch_ingest_memory --- debug_test.py | 36 ++++++++++++++ src/api/routes/memory.py | 66 ++++++++++++++------------ test_output.txt | Bin 0 -> 30148 bytes xlsx.py | 98 +++++++++++++++++++++++++++++++++++++++ 4 files changed, 170 insertions(+), 30 deletions(-) create mode 100644 debug_test.py create mode 100644 test_output.txt create mode 100644 xlsx.py diff --git a/debug_test.py b/debug_test.py new file mode 100644 index 0000000..2ffd5b2 --- /dev/null +++ b/debug_test.py @@ -0,0 +1,36 @@ +import asyncio +from fastapi.testclient import TestClient +from unittest.mock import AsyncMock, patch + +from src.api.app import create_app + +app = create_app() +client = TestClient(app) + +with patch("src.api.routes.memory.require_api_key", return_value={"username": "test_user"}): + from src.api.dependencies import require_api_key, enforce_rate_limit, require_ready + app.dependency_overrides[require_api_key] = lambda: {"username": "test_user"} + app.dependency_overrides[enforce_rate_limit] = lambda: True + app.dependency_overrides[require_ready] = lambda: True + + payload = { + "items": [ + { + "user_query": "Hello world", + "agent_response": "Hi there", + "user_id": "test_user_1", + } + ] + } + + try: + response = client.post( + "/v1/memory/batch-ingest", + json=payload, + headers={"Authorization": "Bearer test-key"} + ) + print("Status code:", response.status_code) + import json + print(json.dumps(response.json(), indent=2)) + except Exception as e: + print("Exception:", e) diff --git 
a/src/api/routes/memory.py b/src/api/routes/memory.py index d8a2f0e..35231c8 100644 --- a/src/api/routes/memory.py +++ b/src/api/routes/memory.py @@ -541,8 +541,8 @@ async def ingest_memory(req: IngestRequest, request: Request, user: dict = Depen user_id = user.get("username") or user.get("name") or user["id"] try: - async with _ingest_semaphore: - async with _user_coordinator.acquire(user_id): + async with _user_coordinator.acquire(user_id): + async with _ingest_semaphore: result = await asyncio.wait_for( pipeline.run( user_query=req.user_query, @@ -591,34 +591,40 @@ async def batch_ingest_memory(req: BatchIngestRequest, request: Request, user: d results = [] - async with _user_coordinator.acquire(user_id): - for item in req.items: - result = await asyncio.wait_for( - pipeline.run( - user_query=item.user_query, - agent_response=item.agent_response or "Acknowledged.", - user_id=user_id, - session_datetime=item.session_datetime, - image_url=item.image_url, - effort_level=item.effort_level, - ), - timeout=120.0 - ) - - data = IngestResponse( - model=_model_name(pipeline.model), - classification=_safe_classifications(result), - profile=_build_domain_result(result.get("profile_judge"), result.get("profile_weaver")), - temporal=_build_domain_result(result.get("temporal_judge"), result.get("temporal_weaver")), - summary=_build_domain_result(result.get("summary_judge"), result.get("summary_weaver")), - image=_build_domain_result(result.get("image_judge"), result.get("image_weaver")), - ) - results.append(data) - - response_data = BatchIngestResponse(results=results) - - elapsed = round((time.perf_counter() - start) * 1000, 2) - return _wrap(request, response_data, elapsed) + try: + async with _user_coordinator.acquire(user_id): + for item in req.items: + async with _ingest_semaphore: + result = await asyncio.wait_for( + pipeline.run( + user_query=item.user_query, + agent_response=item.agent_response or "Acknowledged.", + user_id=user_id, + 
session_datetime=item.session_datetime, + image_url=item.image_url, + effort_level=item.effort_level, + ), + timeout=120.0 + ) + + data = IngestResponse( + model=_model_name(pipeline.model), + classification=_safe_classifications(result), + profile=_build_domain_result(result.get("profile_judge"), result.get("profile_weaver")), + temporal=_build_domain_result(result.get("temporal_judge"), result.get("temporal_weaver")), + summary=_build_domain_result(result.get("summary_judge"), result.get("summary_weaver")), + image=_build_domain_result(result.get("image_judge"), result.get("image_weaver")), + ) + results.append(data) + + response_data = BatchIngestResponse(results=results) + elapsed = round((time.perf_counter() - start) * 1000, 2) + return _wrap(request, response_data, elapsed) + + except Exception as exc: + elapsed = round((time.perf_counter() - start) * 1000, 2) + logger.exception("Batch ingest failed for user=%s", user_id) + return _error(request, str(exc), 500, elapsed) diff --git a/test_output.txt b/test_output.txt new file mode 100644 index 0000000000000000000000000000000000000000..77f325d02e72079851c7f77e7ae91dc1873cf335 GIT binary patch literal 30148 zcmeI5Yi}IKm4^HC0{b5ZGy==nXvB*sNabeqeWrJg7S&O8QXQx9 zd38{o>D)r2sSxf}|62V^by~gAT<6tG&3#tA)qLLz&T+LOxn1a|wN? 
znpayoZ|}d-`O|dXu)SX0)|s2>tetyPy{lg7b)mVo^!+XUZ6?@fn(;(4zYxT;>aJ$` zMDNe4|ExY2#c#7CFnyljztpv_gzb^=K1}xj{Z#j#RNv~G=bCS>zZ2d6=88fvfT~5h z<3L9j8lMOw!*o001is(s%nQwSpmRVx)0~Dk6mM5t1ILav<3ccw5@)W9qItD@HOqm{ zUFe$Umr$Uazs+<%^M0pU_pZK$zVaTp@mANp5d3@MhaJtU`_)&~lj@IYp7M_UbnQ#c z^G+C^SNn-e?{)MpQL(RY;rFS|ED~*}Y3@JhyJHKGVlgs9MDOgZ_p6EVK{D;~% zdYNnvh2<>i7Mw%Mi*ywn@YxrVydUpZUkj^q&H0V+Mux9+7wcl4^lnytrq}HM==D@m z+tlk9)t|1goW9=TPO!dTJ**x{nqO8As>i}V&aQFjQ`?-Q-NJCf_C)yF#Ck{mXx-z1D!Z_lgOH9gyxo*n7$FhOFaIJWnNudQ;{CwI<7 z&4F~9HTz-GRMugxwawYCRUheEtUNHDr#l~Jdj3TXI@ca)6#-#KxbMoMZt3r~Ub~6L zI~sj>NfVs*)qyQ}sgZMZxTLQY(!x><<<&>BxH%4v%SOe^r3g;nT*~Y1rTv2E=mFG% zF_vN`Jg~hp$(b{yK4N{aG#BZfuoPx3zF35$(Z3{Q2{g=Ra^CpHTW{afpm zK~rG5kvz@Q&g=cY?pggR|J)FjL@~1K*E@5|p6PPDAKtoZlKe`ws)LSSQyXi+YPs z&=l4`@tW~)JXePB{wv9X^?W8cL_5R>;~@oq*UA63_3oNwe$+?)`9q@rRbkz5YPZ8+|pi?b}mHYb9r>*>s%w` zY1q#bPrwRm#<{~}-P&b=rW@uK-5yIy-)UZQ4Set41kH8ZJ@RtvvQhmYAKdPzoTd*1 z?>{s*9{x}-c+a}Sn~;AHIm6dwd_ouDqstZ^oN63AI4jnG!}U+U9xqfzLUso2{Eg1F zz3hs43vL(UTV4Ma)zc)Q(*UFZ8{6PI9KV;?)a%;;Ej9a%9!E zs@iB6_T8VW@h>oV(8{-&Y>8Ql_qu`=igtoMk*irCRtJ1}uWyZ4d=h1mLs|?11pa)T zj>&KCDwf+xFxzlp)-f@7*ov;$4{fHpaHP*Sn!k*)#zp;3d#-Hu;_U}H>+0TTb#a=l^u~P65i{zz<0U|cUDquUn zd!~1{Ds6Dcb%^En)BRLdJX5y%nd9T1QN1xggcWi>PZZA#CN&XX^IqjzdUg zR{eXT1K$)jpXlXDk{J>RskvA8ybp*Q;#n8HL^d9^^ktNe-$wHJ7Gps5QR22&**rtD zx}U{4Ue)Wuq(ABo?ZMA^<*ix6AYMbe=F%8438H6e4KMXhHP9@6%+{`pIj;G!a{Ma(7%#0&qS?gnXPL=1t%uKF%(Q4Di zXBORd9eH@Z@n&E3qR6l2fh_dqKkVeW(mvRZ9S+)qOo}C#bqtvzPuAL zA?xL+`fJjC>!x~D;Nx@`yr5#_ncDTs+7A_iHp)v+-*$!hY-aU!n)4ceNe)Cju+Um0 zP6TsgVDSq*i?(^mqq);2^_98zz)TdY{?tWtT{ z7Jp{ylX(>{&vo*Z966g+#=N5HHMo0M653Ci4yVBy8+jlu93~rK9kgr7a-fCg*EnmL zPjHKB5?AEe2^pF7i#C?-BY38}@UYL4M2_?s^*3}Uw_C;qt2E*xt@C$F9J!M+CBDgf zN%cNoJ{PpQTzo~?%6+I=Jg)vyl;+l<-P&STu@P@%A@&omY_zYLtoKN!TXrm|uKHac z)uOMtqV~CZuQ!FE)49+Jx9@YzodV(o>rsHg*ODllf1zVCa?ZhA=RdnvFQy3PosUu$rC$tgiQ?Dbdf2ZU7;^C>P zMP_B6XyZi``a(SbOYWn2y0?ytSJHh3d?qmFmlbyZp4icL z{@Bi-YnKoYlH44I;b}gFehbxkKWB1j!p94uE#TL6)-rX+$NYKNZR)gJ`YZY2d#I4P 
z+4t!GJWR4hOV1_maxKE)SVPWl?x5R)ty-k`0{zCqTYnn=9Pqj3Tu_l5*8izS^bU+C zac2MOdzZMuaLHJYyvU+&6V?ycl6bH8Es>mb!9O;3meVyE5w>I}Qzup;d~)IHg3WWUJQ zJ@>ocyrbi8zZw3A`~G}ot(y`U>WvyB?;>#yApvA-ynCUs)-LB{QzKgccWoVm=f2j^E@2wC`1vxn#N{{M4~Dkl_s z#)0Ryyla@S*Vgs2|BK#XBF%l1=1$wINZ(?ksS z--sjVJ;JY3h_2c$b*fXg z?}R!LSfKsvAp&Qha4oyl*^>^{>{n^>%x0WvZ#mRiJ=mk;--;IEYVL!d>xjq5ok}=2 zuJnU-NyX8mFw?GR^a}4y?SbRZMWbn$=lrahaT>WRuDB&wmqs$@&l>sD*GSu0_s$6JQ6oCOtoa-WCcS_gsV4WmW`^fCba&o~x!#DH z=c(@GPY$jYF-I4Xn2|`*_IXfLiwWqe~eZcnUG@j=m zOYz;@4W89UX$^(Zx^~&(!JG5H$Kt#g<++~OPw9%!JS>LG9Q6Qo#*tK*JIHPN@y2j zmv_5eGuXb9YP=0~y7;7H6qbzD9;022+Uceq(e_N#IIeGL{@^mS#;8rtjc%$bJOReq zvf3a0((zo_%Pbju2upKYqjtK>Jrk?=v#?}CG}?VO(cuvkA)4ie0|dw zbS^BnEv39|ClD4D?_$qU48dckSY!CH-~!fm2;H`?kYCmq%Xgm6cS`vTy3vFou=Bk{ zL+EvA=y|}lUkmVa+XL5k4txKyf_plE2gF4|DLe@26T1zeo8Kv6mkfKa#zYYGDZ`T8 z)F>>?P+h~9VqKx1y%{`}I?TT2_M9jtFrok2PB)aIVf1Up*b145Yz@SdiG=8}zmqFOnI}N&uM>knX@SM0MV8ewz*w&F||K=EOn|_fCgk}@H<(SmzZi~!T^C>}TpRTpTwzz@c zmy11ywiEU&?KEU(K)dGg`%uOM-g*AsN6pxw9X7rksQkJD`_1v}NWi-6eAf=ye#gn7 z*9lv zyP>eL?O2sOD&`p;b^{uJ+NB(Dx5eX#(1@7Bx;O1wPIti0dEN&vAaj(%W*jT7Y1eYt zMB^vYk8ef!7-|DLwk4+_(X^$mC~WSPB=d6ymKc8$0-uIq-) z^BPuPI+{c^l!kWIN8PabB@3(wyA`K!92#><%h52{QGHqFv4WHJqp7fo2+2y!Uk=Z@ z1AaX|4@{c*H#I!?96PvIUFEsFyVqD7?6B+b!JY{@RJ_t(&8K!47Q7sRM`_pEU|UvV zk;*uYBi9VuLPaesEIzIu_E>5|+Ef*TZYoO=w={c_ZFJ|)kBy}ppY6C%ZOg4(y{>4P z0R0&{*Uysgv>lJrI-*oP95$aZ+7tilJZCL;!}edgv)_&6_iMVi5j+QW#4Yfk5B9VQ za?lzpn(spM_M{2ju-SXh&J_D?(KadA4r1F0+eUwI8Q3mwJX#;!P1s}mv+IOi_GjBD z4SsVp#;%pR)}(vvYF{Vad4G0@cM0uUUEX8+v(xEbb=N`O354#ja?QG8-L8$*>Qc|@ z(P3%V`qo&f)#Iaqs?ptsM2kCNCiS=BdT8R1=-;X0by#>W6K=6gRg2>$(P| z;X}(&A8eks-&(Q47*R18<)gZ7DfQf8ud>pDZeZu5VX&!uuydPTM&snp{~>S*P<6;h zLt#%_X{@S%R*r_sdzF=zb9oBSr>-%3!ROd`L|~PPB)o0ks>()pW`&09=B79&X;}v z=K3z_CKsfx$bSuCzhZOWL6?^yE6H_@|KQ|YAc`sF&E8+0s#w?4L}jGXE$Jyl7$NziZ13j;rKj(Qv5DjkM@7dhdbI|?QvR^6xy;OcDwEMr3 z#IH8~csQ2-QJ@`X>vNGuvTk_6$8nOjsK@_11eqyMtfEiNU%x@hoU{Kpx+y-6mp2&x E9V&+}AOHXW literal 0 HcmV?d00001 diff --git a/xlsx.py b/xlsx.py new file mode 100644 
index 0000000..bdaa05d --- /dev/null +++ b/xlsx.py @@ -0,0 +1,98 @@ +from openpyxl import Workbook +from openpyxl.styles import Font, PatternFill, Alignment, Border, Side +from openpyxl.utils import get_column_letter + +wb = Workbook() +wb.remove(wb.active) + +NAVY = "1B3A6B" +GOLD = "C9A84C" +WHITE = "FFFFFF" +LGRAY = "F4F6FA" +MGRAY = "8A99B0" +DGRAY = "2D3748" +TEAL = "0D9488" +RED = "C0392B" +GREEN_TX = "007A3D" +BLUE_TX = "0000FF" + +INR = '₹#,##0;(₹#,##0);"-"' +PCT = '0.0%;(0.0%);"-"' + +def side(style="thin", color="D1D5DB"): + return Side(style=style, color=color) + +def border(): + s = side() + return Border(top=s, bottom=s, left=s, right=s) + +def cell(ws, row, col, val=None, bold=False, bg=WHITE, fg=BLACK_TX, + size=10, align="left", fmt=None, italic=False): + c = ws.cell(row=row, column=col, value=val) + c.font = Font(name="Arial", bold=bold, color=fg, size=size, italic=italic) + c.fill = PatternFill("solid", start_color=bg) + c.alignment = Alignment(horizontal=align, vertical="center") + c.border = border() + if fmt: + c.number_format = fmt + return c + +def merge(ws, r1, c1, r2, c2, val=None, bold=False, bg=WHITE, fg=BLACK_TX): + ws.merge_cells(start_row=r1, start_column=c1, end_row=r2, end_column=c2) + c = ws.cell(row=r1, column=c1, value=val) + c.font = Font(name="Arial", bold=bold, color=fg, size=11) + c.fill = PatternFill("solid", start_color=bg) + c.alignment = Alignment(horizontal="center", vertical="center") + return c + +# ================= MASTER SHEET ================= +ws = wb.create_sheet("Master Budget") + +headers = ["Item", "Amount", "%"] +for i, h in enumerate(headers, 1): + cell(ws, 1, i, h, bold=True, bg=NAVY, fg=WHITE) + +data = [ + ("1st Prize", 25000), + ("2nd Prize", 12000), + ("3rd Prize", 6000), + ("Marketing", 78000), + ("Team", 25000), + ("Misc", 25000), +] + +row = 2 +for label, amount in data: + cell(ws, row, 1, label) + cell(ws, row, 2, amount, fg=BLUE_TX, align="center", fmt=INR) + row += 1 + +# Total +cell(ws, row, 1, 
"TOTAL", bold=True) +cell(ws, row, 2, f"=SUM(B2:B{row-1})", bold=True, fmt=INR) + +# Percent column +for r in range(2, row): + cell(ws, r, 3, f"=B{r}/B{row}", fmt=PCT, align="center") + +# ================= PRIZE SHEET ================= +ws2 = wb.create_sheet("Prize Structure") + +headers = ["Rank", "Cash Prize"] +for i, h in enumerate(headers, 1): + cell(ws2, 1, i, h, bold=True, bg=NAVY, fg=WHITE) + +prizes = [ + ("1st", 25000), + ("2nd", 12000), + ("3rd", 6000), +] + +for i, (rank, val) in enumerate(prizes, start=2): + cell(ws2, i, 1, rank) + cell(ws2, i, 2, val, fmt=INR, align="center") + +# ================= SAVE FILE ================= +wb.save("DSA_Budget.xlsx") + +print("✅ Excel file saved as DSA_Budget.xlsx in your folder") \ No newline at end of file From 02d79e902b3cc7ea0a632aafce3a9f8631da4cff Mon Sep 17 00:00:00 2001 From: Anirudha Acharya <127017012+anirudhaacharyap@users.noreply.github.com> Date: Fri, 22 May 2026 10:06:56 +0530 Subject: [PATCH 3/3] feat: staged parallel hybrid ingestion pipeline and starvation prevention locking order --- src/agents/judge.py | 147 +++++++++- src/api/routes/memory.py | 60 +++- src/pipelines/ingest.py | 542 +++++++++++++++++++++++++++++-------- tests/test_batch_ingest.py | 132 ++++++++- 4 files changed, 742 insertions(+), 139 deletions(-) diff --git a/src/agents/judge.py b/src/agents/judge.py index 8ca8c57..7a74a25 100644 --- a/src/agents/judge.py +++ b/src/agents/judge.py @@ -19,6 +19,7 @@ import asyncio import json +from difflib import SequenceMatcher from typing import Any, Callable, Dict, List, Optional from langchain_core.language_models import BaseChatModel @@ -162,7 +163,9 @@ def __init__( # Public entry point # ------------------------------------------------------------------ - async def arun(self, state: Dict[str, Any]) -> JudgeResult: + async def arun( + self, state: Dict[str, Any], pending_ops: Optional[List[Operation]] = None + ) -> JudgeResult: domain_str = state.get("domain", "") try: domain = 
JudgeDomain(domain_str) @@ -186,6 +189,7 @@ async def arun(self, state: Dict[str, Any]) -> JudgeResult: new_items=new_items, user_id=user_id, domain=domain, + pending_ops=pending_ops, ) if domain == JudgeDomain.SUMMARY and not _has_summary_judge_candidates(matches_per_item): @@ -218,7 +222,9 @@ async def arun(self, state: Dict[str, Any]) -> JudgeResult: return result - async def arun_deterministic(self, state: Dict[str, Any]) -> JudgeResult: + async def arun_deterministic( + self, state: Dict[str, Any], pending_ops: Optional[List[Operation]] = None + ) -> JudgeResult: """Build operations without an LLM for structured domains. Profile and temporal extraction already returns normalized structured @@ -238,15 +244,15 @@ async def arun_deterministic(self, state: Dict[str, Any]) -> JudgeResult: return JudgeResult() if domain == JudgeDomain.PROFILE: - result = await self._deterministic_profile(new_items, user_id) + result = await self._deterministic_profile(new_items, user_id, pending_ops=pending_ops) elif domain == JudgeDomain.TEMPORAL: - result = await self._deterministic_temporal(new_items, user_id) + result = await self._deterministic_temporal(new_items, user_id, pending_ops=pending_ops) else: self.logger.warning( "Deterministic judge unsupported for %s; falling back to LLM judge.", domain.value, ) - return await self.arun(state) + return await self.arun(state, pending_ops=pending_ops) self._log_result(domain, result) return result @@ -275,11 +281,16 @@ async def _fetch_similar( new_items: list, user_id: str, domain: JudgeDomain, + pending_ops: Optional[List[Operation]] = None, ) -> Dict[str, List[SearchResult]]: if domain == JudgeDomain.TEMPORAL: - return await self._fetch_similar_temporal(items_strings, new_items, user_id) + return await self._fetch_similar_temporal( + items_strings, new_items, user_id, pending_ops=pending_ops + ) else: - return await self._fetch_similar_vector(items_strings, new_items, user_id, domain) + return await self._fetch_similar_vector( + 
items_strings, new_items, user_id, domain, pending_ops=pending_ops + ) # -- Profile / Summary: Pinecone (vector store) ----------------------- @@ -289,6 +300,7 @@ async def _fetch_similar_vector( new_items: list, user_id: str, domain: JudgeDomain, + pending_ops: Optional[List[Operation]] = None, ) -> Dict[str, List[SearchResult]]: if not self.vector_store: self.logger.debug("No vector store attached — skipping similarity search.") @@ -297,7 +309,7 @@ async def _fetch_similar_vector( # Profile domain: use deterministic metadata lookup if domain == JudgeDomain.PROFILE: return await self._fetch_similar_profile_metadata( - items_strings, new_items, user_id, + items_strings, new_items, user_id, pending_ops=pending_ops ) # Summary / other: parallel semantic search across all items @@ -320,7 +332,43 @@ async def _search_one(item_str: str) -> tuple[str, List[SearchResult]]: return item_str, [] pairs = await asyncio.gather(*(_search_one(s) for s in items_strings)) - return dict(pairs) + matches_per_item = dict(pairs) + + if pending_ops: + for item_str in items_strings: + matches = matches_per_item.get(item_str, []) + + # Apply deletes first + deletes = {op.embedding_id for op in pending_ops if op.type == OperationType.DELETE and op.embedding_id} + if deletes: + matches = [m for m in matches if m.id not in deletes] + + # Apply adds/updates + for op in pending_ops: + if op.type in (OperationType.ADD, OperationType.UPDATE): + ratio = SequenceMatcher(None, _norm_text(item_str), _norm_text(op.content)).ratio() + if ratio > 0.5: + simulated = SearchResult( + id=op.embedding_id or f"pending_{domain.value}_{hash(op.content)}", + content=op.content, + score=ratio, + metadata={ + "domain": domain.value, + "user_id": user_id, + } + ) + # Replace if same ID already exists in matches, otherwise prepend + existing_idx = next((i for i, m in enumerate(matches) if m.id == simulated.id), None) + if existing_idx is not None: + matches[existing_idx] = simulated + else: + matches.insert(0, 
simulated) + + # Sort matches by score descending + matches = sorted(matches, key=lambda x: x.score or 0.0, reverse=True)[:self.top_k] + matches_per_item[item_str] = matches + + return matches_per_item # -- Profile: deterministic metadata lookup ---------------------------- @@ -329,6 +377,7 @@ async def _fetch_similar_profile_metadata( items_strings: List[str], new_items: list, user_id: str, + pending_ops: Optional[List[Operation]] = None, ) -> Dict[str, List[SearchResult]]: """Fetch existing profile records by exact topic_subtopic match (parallel). @@ -377,7 +426,36 @@ async def _lookup_one(idx: int, item_str: str) -> tuple[str, List[SearchResult]] pairs = await asyncio.gather( *(_lookup_one(i, s) for i, s in enumerate(items_strings)) ) - return dict(pairs) + matches_per_item = dict(pairs) + + if pending_ops: + for idx, item_str in enumerate(items_strings): + item = new_items[idx] if idx < len(new_items) else {} + meta_key = _build_profile_metadata_key(item) + if not meta_key: + continue + + for op in pending_ops: + op_meta_key = _profile_meta_key_from_content(op.content) + if op_meta_key == meta_key: + if op.type in (OperationType.ADD, OperationType.UPDATE): + _, _, memo = _parse_profile_content(op.content) + simulated = SearchResult( + id=op.embedding_id or f"pending_profile_{meta_key}", + content=op.content, + score=1.0, + metadata={ + "main_content": meta_key, + "subcontent": memo, + "domain": "profile", + "user_id": user_id, + } + ) + matches_per_item[item_str] = [simulated] + elif op.type == OperationType.DELETE: + matches_per_item[item_str] = [] + + return matches_per_item # -- Semantic search fallback (summary domain) ------------------------- @@ -406,6 +484,7 @@ async def _fetch_similar_temporal( items_strings: List[str], new_items: list, user_id: str, + pending_ops: Optional[List[Operation]] = None, ) -> Dict[str, List[SearchResult]]: if not self.graph_event_search: self.logger.debug("No graph_event_search provided — skipping Neo4j lookup.") @@ -434,7 
+513,36 @@ async def _lookup_one(idx: int, item_str: str) -> tuple[str, List[SearchResult]] pairs = await asyncio.gather( *(_lookup_one(i, s) for i, s in enumerate(items_strings)) ) - return dict(pairs) + matches_per_item = dict(pairs) + + if pending_ops: + for idx, item_str in enumerate(items_strings): + event = new_items[idx] if idx < len(new_items) else {} + event_name = event.get("event_name", "") if isinstance(event, dict) else "" + if not event_name: + continue + norm_event_name = _norm_text(event_name) + + for op in pending_ops: + fields = _temporal_fields_from_content(op.content) + op_event_name = fields.get("event_name", "") + if _norm_text(op_event_name) == norm_event_name: + if op.type in (OperationType.ADD, OperationType.UPDATE): + simulated = SearchResult( + id=op.embedding_id or f"pending_temporal_{norm_event_name}", + content=op.content, + score=1.0, + metadata={ + **fields, + "domain": "temporal", + "user_id": user_id, + } + ) + matches_per_item[item_str] = [simulated] + elif op.type == OperationType.DELETE: + matches_per_item[item_str] = [] + + return matches_per_item # -- Deterministic operation builders --------------------------------- @@ -687,3 +795,20 @@ def _temporal_fields_from_match(match: SearchResult) -> Dict[str, str]: def _same_temporal_event(incoming: Dict[str, str], existing: Dict[str, str]) -> bool: keys = ["date", "event_name", "desc", "year", "time", "date_expression"] return all(_norm_text(incoming.get(key)) == _norm_text(existing.get(key)) for key in keys) + + +def _parse_profile_content(content: str) -> tuple[str, str, str]: + if " = " not in content: + return "", "", content + left, memo = content.split(" = ", 1) + if " / " not in left: + return left.strip(), "", memo.strip() + topic, sub_topic = left.split(" / ", 1) + return topic.strip(), sub_topic.strip(), memo.strip() + + +def _profile_meta_key_from_content(content: str) -> str: + topic, sub_topic, _ = _parse_profile_content(content) + if not topic or not sub_topic: + 
return "" + return f"{topic}_{sub_topic}".replace(" ", "_").lower() diff --git a/src/api/routes/memory.py b/src/api/routes/memory.py index 1fe4f34..0c7d91b 100644 --- a/src/api/routes/memory.py +++ b/src/api/routes/memory.py @@ -212,14 +212,55 @@ async def _run_ingest_payload( return data.model_dump() -async def _run_batch_ingest_payload( +async def _run_staged_batch_payload( payload: Dict[str, Any], user_id: str, ) -> Dict[str, Any]: - results = [] + pipeline = get_ingest_pipeline() + items = [] for item in payload["items"]: - results.append(await _run_ingest_payload(item, user_id)) - return {"results": results} + if hasattr(item, "model_dump"): + items.append(item.model_dump()) + elif isinstance(item, dict): + items.append(item) + else: + items.append(dict(item)) + + batch_results = await pipeline.run_staged_batch(items, user_id=user_id) + + results = [] + for result in batch_results: + data = IngestResponse( + model=_model_name(pipeline.model), + classification=_safe_classifications(result), + profile=_build_domain_result( + result.get("profile_judge"), + result.get("profile_weaver"), + ), + temporal=_build_domain_result( + result.get("temporal_judge"), + result.get("temporal_weaver"), + ), + summary=_build_domain_result( + result.get("summary_judge"), + result.get("summary_weaver"), + ), + image=_build_domain_result( + result.get("image_judge"), + result.get("image_weaver"), + ), + ) + results.append(data) + + return {"results": [r.model_dump() for r in results]} + + +async def _run_batch_ingest_payload( + payload: Dict[str, Any], + user_id: str, +) -> Dict[str, Any]: + async with _ingest_semaphore: + return await _run_staged_batch_payload(payload, user_id) async def _run_scrape_payload(payload: Dict[str, Any]) -> Dict[str, Any]: @@ -803,16 +844,15 @@ async def batch_ingest_memory(req: BatchIngestRequest, request: Request, user: d user_id = _current_user_id(user) try: - results = [] + payload = req.model_dump() async with _user_coordinator.acquire(user_id): - 
for item in req.items: + async with _ingest_semaphore: data = await asyncio.wait_for( - _run_ingest_payload(item.model_dump(), user_id), - timeout=120.0, + _run_staged_batch_payload(payload, user_id), + timeout=max(120.0, len(req.items) * 120.0), ) - results.append(IngestResponse(**data)) - response_data = BatchIngestResponse(results=results) + response_data = BatchIngestResponse(**data) elapsed = round((time.perf_counter() - start) * 1000, 2) return _wrap(request, response_data, elapsed) diff --git a/src/pipelines/ingest.py b/src/pipelines/ingest.py index 446440a..f2b8c87 100644 --- a/src/pipelines/ingest.py +++ b/src/pipelines/ingest.py @@ -82,7 +82,7 @@ ) from src.schemas.events import EventResult from src.schemas.image import ImageResult -from src.schemas.judge import JudgeDomain, JudgeResult, OperationType +from src.schemas.judge import JudgeDomain, JudgeResult, OperationType, Operation from src.schemas.profile import ProfileResult from src.schemas.summary import SummaryResult from src.schemas.weaver import WeaverResult @@ -762,33 +762,130 @@ def _route_after_classify(self, state: IngestState) -> List[Send]: # ── Extraction nodes ────────────────────────────────────────────── + # ── Decoupled helpers ───────────────────────────────────────────── + + async def _extract_profile(self, combined_query: str) -> ProfileResult: + return await self.profiler.arun({"classifier_output": combined_query}) + + async def _judge_profile( + self, items: list, user_id: str, pending_ops: Optional[List[Operation]] = None + ) -> JudgeResult: + return await self.judge.arun_deterministic({ + "domain": "profile", + "new_items": items, + "user_id": user_id, + }, pending_ops=pending_ops) + + async def _weave_profile(self, judge_result: JudgeResult, user_id: str) -> WeaverResult: + return await self.weaver.execute( + judge_result=judge_result, + domain=JudgeDomain.PROFILE, + user_id=user_id, + ) + + async def _extract_temporal(self, combined_query: str, session_dt: str) -> EventResult: + 
return await self.temporal.arun({ + "classifier_output": combined_query, + "session_datetime": session_dt, + }) + + async def _judge_temporal( + self, items: list, user_id: str, pending_ops: Optional[List[Operation]] = None + ) -> JudgeResult: + return await self.judge.arun_deterministic({ + "domain": "temporal", + "new_items": items, + "user_id": user_id, + }, pending_ops=pending_ops) + + async def _weave_temporal(self, judge_result: JudgeResult, user_id: str) -> WeaverResult: + return await self.weaver.execute( + judge_result=judge_result, + domain=JudgeDomain.TEMPORAL, + user_id=user_id, + ) + + async def _extract_image(self, state: IngestState) -> ImageResult: + return await self.image_agent.arun(state) + + async def _extract_code(self, combined_query: str) -> CodeAnnotationResult: + return await self.code_agent.arun({"classifier_output": combined_query}) + + async def _judge_code( + self, items: list, user_id: str, pending_ops: Optional[List[Operation]] = None + ) -> JudgeResult: + return await self.judge.arun({ + "domain": JudgeDomain.CODE, + "new_items": items, + "user_id": user_id, + }, pending_ops=pending_ops) + + async def _weave_code(self, judge_result: JudgeResult, user_id: str) -> WeaverResult: + return await self.weaver.execute( + judge_result=judge_result, + domain=JudgeDomain.CODE, + user_id=user_id, + ) + + async def _extract_snippet(self, combined_query: str) -> SnippetExtractionResult: + return await self.snippet_agent.arun({"classifier_output": combined_query}) + + async def _judge_snippet( + self, items: list, user_id: str, pending_ops: Optional[List[Operation]] = None + ) -> JudgeResult: + return await self.judge.arun({ + "domain": JudgeDomain.SNIPPET, + "new_items": items, + "user_id": user_id, + }, pending_ops=pending_ops) + + async def _weave_snippet(self, judge_result: JudgeResult, user_id: str) -> WeaverResult: + self.weaver.snippet_vector_store = self._get_snippet_store(user_id) + return await self.weaver.execute( + 
judge_result=judge_result, + domain=JudgeDomain.SNIPPET, + user_id=user_id, + ) + + async def _extract_summary(self, user_query: str, agent_response: str) -> SummaryResult: + return await self.summarizer.arun({ + "user_query": user_query, + "agent_response": agent_response, + }) + + async def _judge_summary( + self, items: list, user_id: str, pending_ops: Optional[List[Operation]] = None + ) -> JudgeResult: + return await self.judge.arun({ + "domain": JudgeDomain.SUMMARY, + "new_items": items, + "user_id": user_id, + }, pending_ops=pending_ops) + + async def _weave_summary(self, judge_result: JudgeResult, user_id: str) -> WeaverResult: + return await self.weaver.execute( + judge_result=judge_result, + domain=JudgeDomain.SUMMARY, + user_id=user_id, + ) + + # ── Extraction nodes ────────────────────────────────────────────── + async def _node_extract_profile(self, state: IngestState) -> Dict[str, Any]: """Extract profile facts from the classifier query.""" queries = state.get("profile_queries", []) user_id = state.get("user_id", "default") - # Merge into a single query (safety net if classifier outputs duplicate lines) combined_query = " ".join(queries) - result = await self.profiler.arun({"classifier_output": combined_query}) + result = await self._extract_profile(combined_query) if result.is_empty: return {"status": "no_profile_facts"} - # Profile facts are already structured; exact metadata lookup avoids - # an extra judge LLM call on the hot path. 
items = [f.model_dump() for f in result.facts] - judge_result = await self.judge.arun_deterministic({ - "domain": "profile", - "new_items": items, - "user_id": user_id, - }) + judge_result = await self._judge_profile(items, user_id) - # Weave - weaver_result = await self.weaver.execute( - judge_result=judge_result, - domain=JudgeDomain.PROFILE, - user_id=user_id, - ) + weaver_result = await self._weave_profile(judge_result, user_id) return { "profile_result": result, "profile_judge": judge_result, @@ -801,12 +898,8 @@ async def _node_extract_temporal(self, state: IngestState) -> Dict[str, Any]: user_id = state.get("user_id", "default") session_dt = state.get("session_datetime", "") - # Merge into a single query combined_query = " ".join(queries) - result = await self.temporal.arun({ - "classifier_output": combined_query, - "session_datetime": session_dt, - }) + result = await self._extract_temporal(combined_query, session_dt) if result.is_empty: return {"status": "no_temporal_event"} @@ -822,17 +915,9 @@ async def _node_extract_temporal(self, state: IngestState) -> Dict[str, Any]: "date_expression": event.date_expression or "", }) - judge_result = await self.judge.arun_deterministic({ - "domain": "temporal", - "new_items": all_items, - "user_id": user_id, - }) + judge_result = await self._judge_temporal(all_items, user_id) - weaver_result = await self.weaver.execute( - judge_result=judge_result, - domain=JudgeDomain.TEMPORAL, - user_id=user_id, - ) + weaver_result = await self._weave_temporal(judge_result, user_id) return { "temporal_result": result, "temporal_judge": judge_result, @@ -843,16 +928,11 @@ async def _node_extract_image(self, state: IngestState) -> Dict[str, Any]: """Extract visual observations from the image and store them as summary.""" user_id = state.get("user_id", "default") - # ImageAgent reads classifier_output and image_url from state - result = await self.image_agent.arun(state) + result = await self._extract_image(state) if result.is_empty: 
return {"status": "no_image_observations"} - # Convert observations to list of dicts for Judge - # items = [obs.model_dump() for obs in result.observations] - - #converted observation of images to summary and stored as summary items = [] if result.description: items.append(f"[Image] {result.description}") @@ -863,17 +943,9 @@ async def _node_extract_image(self, state: IngestState) -> Dict[str, Any]: if not items: return {"status": "no_image_observations"} - judge_result = await self.judge.arun({ - "domain": JudgeDomain.SUMMARY, - "new_items": items, - "user_id": user_id, - }) + judge_result = await self._judge_summary(items, user_id) - weaver_result = await self.weaver.execute( - judge_result=judge_result, - domain=JudgeDomain.SUMMARY, - user_id=user_id, - ) + weaver_result = await self._weave_summary(judge_result, user_id) return { "image_result": result, @@ -886,9 +958,8 @@ async def _node_extract_code(self, state: IngestState) -> Dict[str, Any]: queries = state.get("code_queries", []) user_id = state.get("user_id", "default") - # Merge into a single query combined_query = " ".join(queries) - result = await self.code_agent.arun({"classifier_output": combined_query}) + result = await self._extract_code(combined_query) if result.is_empty: return {"status": "no_code_annotations"} @@ -905,17 +976,9 @@ async def _node_extract_code(self, state: IngestState) -> Dict[str, Any]: ] all_items.append(" | ".join(parts)) - judge_result = await self.judge.arun({ - "domain": JudgeDomain.CODE, - "new_items": all_items, - "user_id": user_id, - }) + judge_result = await self._judge_code(all_items, user_id) - weaver_result = await self.weaver.execute( - judge_result=judge_result, - domain=JudgeDomain.CODE, - user_id=user_id, - ) + weaver_result = await self._weave_code(judge_result, user_id) return { "code_result": result, "code_judge": judge_result, @@ -927,9 +990,8 @@ async def _node_extract_snippet(self, state: IngestState) -> Dict[str, Any]: queries = state.get("code_queries", 
[]) user_id = state.get("user_id", "default") - # Merge into a single query combined_query = " ".join(queries) - result = await self.snippet_agent.arun({"classifier_output": combined_query}) + result = await self._extract_snippet(combined_query) if result.is_empty: return {"status": "no_snippets"} @@ -945,20 +1007,9 @@ async def _node_extract_snippet(self, state: IngestState) -> Dict[str, Any]: ] all_items.append(" | ".join(parts)) - judge_result = await self.judge.arun({ - "domain": JudgeDomain.SNIPPET, - "new_items": all_items, - "user_id": user_id, - }) - - # Bind the user-scoped snippet store before executing - self.weaver.snippet_vector_store = self._get_snippet_store(user_id) + judge_result = await self._judge_snippet(all_items, user_id) - weaver_result = await self.weaver.execute( - judge_result=judge_result, - domain=JudgeDomain.SNIPPET, - user_id=user_id, - ) + weaver_result = await self._weave_snippet(judge_result, user_id) return { "snippet_result": result, "snippet_judge": judge_result, @@ -966,14 +1017,14 @@ async def _node_extract_snippet(self, state: IngestState) -> Dict[str, Any]: } async def _node_extract_summary(self, state: IngestState) -> Dict[str, Any]: - result = await self.summarizer.arun({ - "user_query": state.get("user_query", ""), - "agent_response": state.get("agent_response", ""), - }) + user_id = state.get("user_id", "default") + result = await self._extract_summary( + user_query=state.get("user_query", ""), + agent_response=state.get("agent_response", ""), + ) if result.is_empty: return {"status": "no_summary"} - # Split bullet summary into individual items items = [ line.lstrip("- •").strip() for line in result.summary.strip().splitlines() @@ -982,17 +1033,9 @@ async def _node_extract_summary(self, state: IngestState) -> Dict[str, Any]: if not items: return {"status": "no_summary_items"} - judge_result = await self.judge.arun({ - "domain": "summary", - "new_items": items, - "user_id": state.get("user_id", "default"), - }) + 
judge_result = await self._judge_summary(items, user_id) - weaver_result = await self.weaver.execute( - judge_result=judge_result, - domain=JudgeDomain.SUMMARY, - user_id=state.get("user_id", "default"), - ) + weaver_result = await self._weave_summary(judge_result, user_id) return { "summary_result": result, "summary_judge": judge_result, @@ -1143,6 +1186,285 @@ async def _invoke_graph( } return await self.graph.ainvoke(initial_state) + async def _process_item_phase_a(self, idx: int, item: Dict[str, Any], user_id: str) -> Dict[str, Any]: + """Phase A - Classification and domain extraction concurrently for a single item.""" + user_query = item.get("user_query", "") + agent_response = item.get("agent_response", "") or "Acknowledged." + session_dt = item.get("session_datetime", "") + image_url = item.get("image_url", "") + disabled_domains = set(item.get("disabled_domains") or []) + + # Run Classifier + classifier_query = user_query + if image_url: + classifier_query += " [User has attached an image]" + + classification_result = await self.classifier.arun({ + "user_query": classifier_query, + }) + + # Collect sub-queries per domain + profile_queries = [] + temporal_queries = [] + image_queries = [] + code_queries = [] + + if classification_result and classification_result.classifications: + for c in classification_result.classifications: + if c["source"] == "profile": + profile_queries.append(c["query"]) + elif c["source"] == "event": + temporal_queries.append(c["query"]) + elif c["source"] == "image": + image_queries.append(c["query"]) + elif c["source"] == "code": + code_queries.append(c["query"]) + + words = user_query.split() + is_trivial = len(words) < 4 and not any([profile_queries, temporal_queries, code_queries, image_queries]) + + tasks = [] + task_names = [] + + if not is_trivial: + tasks.append(self._extract_summary(user_query, agent_response)) + task_names.append("summary") + + if profile_queries: + combined_profile = " ".join(profile_queries) + 
tasks.append(self._extract_profile(combined_profile)) + task_names.append("profile") + + if temporal_queries: + combined_temporal = " ".join(temporal_queries) + tasks.append(self._extract_temporal(combined_temporal, session_dt)) + task_names.append("temporal") + + if code_queries and not {"code", "snippet"}.issubset(disabled_domains): + is_enterprise = self.org_id != "default" + if is_enterprise and "code" not in disabled_domains: + combined_code = " ".join(code_queries) + tasks.append(self._extract_code(combined_code)) + task_names.append("code") + elif not is_enterprise and "snippet" not in disabled_domains: + combined_snippet = " ".join(code_queries) + tasks.append(self._extract_snippet(combined_snippet)) + task_names.append("snippet") + + if image_url: + if not image_queries: + image_queries.append("Analyze this image for memory-relevant details.") + combined_image = " ".join(image_queries) + image_state = { + "classifier_output": combined_image, + "image_url": image_url, + "user_id": user_id, + } + tasks.append(self._extract_image(image_state)) + task_names.append("image") + + extraction_results = await asyncio.gather(*tasks, return_exceptions=True) + + item_state = { + "user_query": user_query, + "agent_response": agent_response, + "user_id": user_id, + "session_datetime": session_dt, + "image_url": image_url, + "disabled_domains": list(disabled_domains), + "classification_result": classification_result, + "errors": [], + "status": "extracted", + } + + for name, result in zip(task_names, extraction_results): + if isinstance(result, Exception): + logger.error(f"Error during {name} extraction for batch item {idx}: {result}") + item_state["errors"].append(f"{name}_extraction_error: {str(result)}") + else: + item_state[f"{name}_result"] = result + + return {"idx": idx, "item_state": item_state} + + async def run_staged_batch( + self, + items: List[Dict[str, Any]], + user_id: str, + ) -> List[Dict[str, Any]]: + """Run batch memory ingestion using a staged 
parallel/sequential hybrid pipeline.""" + logger.info("=" * 60) + logger.info("RUN STAGED BATCH: %d items", len(items)) + logger.info("=" * 60) + + # Phase A: Concurrently run classification + domain extraction across all items + phase_a_tasks = [self._process_item_phase_a(idx, item, user_id) for idx, item in enumerate(items)] + phase_a_outputs = await asyncio.gather(*phase_a_tasks) + + # Phase B: Sequentially run Judge across all items with pending_ops tracking + pending_ops: List[Operation] = [] + + for phase_a_out in phase_a_outputs: + item_state = phase_a_out["item_state"] + idx = phase_a_out["idx"] + + judge_tasks = [] + judge_domains = [] + + # 1. Profile facts + profile_res = item_state.get("profile_result") + if profile_res and not profile_res.is_empty: + items_data = [f.model_dump() for f in profile_res.facts] + judge_tasks.append(self._judge_profile(items_data, user_id, pending_ops=pending_ops)) + judge_domains.append("profile") + + # 2. Temporal events + temporal_res = item_state.get("temporal_result") + if temporal_res and not temporal_res.is_empty: + items_data = [] + for event in temporal_res.events: + items_data.append({ + "date": event.date, + "event_name": event.event_name or "", + "desc": event.desc or "", + "year": event.year or "", + "time": event.time or "", + "date_expression": event.date_expression or "", + }) + judge_tasks.append(self._judge_temporal(items_data, user_id, pending_ops=pending_ops)) + judge_domains.append("temporal") + + # 3. 
Summary (and Image) + summary_res = item_state.get("summary_result") + image_res = item_state.get("image_result") + + summary_items = [] + if summary_res and not summary_res.is_empty: + summary_items.extend([ + line.lstrip("- •").strip() + for line in summary_res.summary.strip().splitlines() + if line.strip() and line.strip() not in ("-", "•") + ]) + + if image_res and not image_res.is_empty: + if image_res.description: + summary_items.append(f"[Image] {image_res.description}") + for obs in image_res.observations: + conf = f" ({obs.confidence})" if obs.confidence else "" + summary_items.append(f"[Image/{obs.category}] {obs.description}{conf}") + + if summary_items: + judge_tasks.append(self._judge_summary(summary_items, user_id, pending_ops=pending_ops)) + judge_domains.append("summary") + + # 4. Code annotations + code_res = item_state.get("code_judge") or item_state.get("code_result") + # NOTE(review): the lookup above is dead code; Phase A stores the extraction under "code_result", which is re-read below + code_res = item_state.get("code_result") + if code_res and not code_res.is_empty: + items_data = [] + for ann in code_res.annotations: + parts = [ + ann.annotation_type.value, + ann.target_symbol or "", + ann.target_file or "", + ann.repo or "", + ann.severity.value if ann.severity else "", + ann.content, + ] + items_data.append(" | ".join(parts)) + judge_tasks.append(self._judge_code(items_data, user_id, pending_ops=pending_ops)) + judge_domains.append("code") + + # 5. 
Personal code snippets + snippet_res = item_state.get("snippet_result") + if snippet_res and not snippet_res.is_empty: + items_data = [] + for snip in snippet_res.snippets: + parts = [ + snip.content, + snip.code_snippet.replace("\n", "\\n") if snip.code_snippet else "", + snip.language, + snip.snippet_type.value, + ",".join(snip.tags), + ] + items_data.append(" | ".join(parts)) + judge_tasks.append(self._judge_snippet(items_data, user_id, pending_ops=pending_ops)) + judge_domains.append("snippet") + + if judge_tasks: + judge_results = await asyncio.gather(*judge_tasks, return_exceptions=True) + for domain_name, jr in zip(judge_domains, judge_results): + if isinstance(jr, Exception): + logger.error(f"Error during {domain_name} judge for batch item {idx}: {jr}") + item_state["errors"].append(f"{domain_name}_judge_error: {str(jr)}") + else: + item_state[f"{domain_name}_judge"] = jr + if domain_name == "summary" and image_res and not image_res.is_empty: + item_state["image_judge"] = jr + + if jr and jr.operations: + pending_ops.extend(jr.operations) + + # Phase C: Concurrently run Weaver to write changes in parallel across all items + weave_tasks = [] + weave_mappings = [] + + for phase_a_out in phase_a_outputs: + item_state = phase_a_out["item_state"] + idx = phase_a_out["idx"] + + # Profile + profile_judge = item_state.get("profile_judge") + if profile_judge: + weave_tasks.append(self._weave_profile(profile_judge, user_id)) + weave_mappings.append((item_state, "profile_weaver")) + + # Temporal + temporal_judge = item_state.get("temporal_judge") + if temporal_judge: + weave_tasks.append(self._weave_temporal(temporal_judge, user_id)) + weave_mappings.append((item_state, "temporal_weaver")) + + # Summary + summary_judge = item_state.get("summary_judge") + if summary_judge: + weave_tasks.append(self._weave_summary(summary_judge, user_id)) + weave_mappings.append((item_state, "summary_weaver")) + + # Image - NOTE(review): Phase B stores the same JudgeResult under both "summary_judge" and "image_judge", so this schedules a second weave of the same operations; confirm this is intended + image_judge = item_state.get("image_judge") + if image_judge 
and "image_result" in item_state: + weave_tasks.append(self._weave_summary(image_judge, user_id)) + weave_mappings.append((item_state, "image_weaver")) + + # Code + code_judge = item_state.get("code_judge") + if code_judge: + weave_tasks.append(self._weave_code(code_judge, user_id)) + weave_mappings.append((item_state, "code_weaver")) + + # Snippet + snippet_judge = item_state.get("snippet_judge") + if snippet_judge: + weave_tasks.append(self._weave_snippet(snippet_judge, user_id)) + weave_mappings.append((item_state, "snippet_weaver")) + + if weave_tasks: + weave_results = await asyncio.gather(*weave_tasks, return_exceptions=True) + for (item_state, key), wr in zip(weave_mappings, weave_results): + if isinstance(wr, Exception): + logger.error(f"Error during weaving for key {key}: {wr}") + item_state["errors"].append(f"{key}_error: {str(wr)}") + else: + item_state[key] = wr + + # Complete all items + for phase_a_out in phase_a_outputs: + item_state = phase_a_out["item_state"] + item_state["status"] = "completed" + + return [out["item_state"] for out in phase_a_outputs] + async def _run_high_effort( self, user_query: str, @@ -1153,13 +1475,7 @@ async def _run_high_effort( cfg: EffortConfig, disabled_domains: Optional[List[str]] = None, ) -> Dict[str, Any]: - """HIGH-effort path: chunk user_query → sequential pipeline calls → merge. - - Each chunk gets the full pipeline run independently and sequentially. - The ``agent_response`` is passed to every chunk so summary extraction - always has the full assistant context. Image is only forwarded to the - first chunk to avoid duplicate image processing. - """ + """HIGH-effort path: chunk user_query -> parallel staged batch run -> merge.""" chunks = chunk_text( user_query, chunk_size_tokens=cfg.chunk_size_tokens, @@ -1175,25 +1491,21 @@ async def _run_high_effort( cfg.chunk_threshold_tokens, ) - # Process every chunk through the pipeline sequentially to avoid duplicates. 
- # Image is only sent with chunk[0] to avoid duplicate processing. - chunk_results: List[Dict[str, Any]] = [] - for idx, chunk in enumerate(chunks): - logger.info("Processing chunk %d/%d...", idx + 1, len(chunks)) - res = await self._invoke_graph( - user_query=chunk, - agent_response=agent_response, - user_id=user_id, - session_datetime=session_datetime, - image_url=image_url if idx == 0 else "", - disabled_domains=disabled_domains, - ) - chunk_results.append(res) + batch_items = [ + { + "user_query": chunk, + "agent_response": agent_response, + "user_id": user_id, + "session_datetime": session_datetime, + "image_url": image_url if idx == 0 else "", + "disabled_domains": disabled_domains or [], + } + for idx, chunk in enumerate(chunks) + ] + + chunk_results = await self.run_staged_batch(batch_items, user_id=user_id) # ── Merge states ───────────────────────────────────────────── - # All writes (Pinecone / Neo4j) already happened inside each chunk's - # pipeline run. We merge the returned state dicts so callers get a - # sensible aggregate view. merged: Dict[str, Any] = {} all_errors: List[str] = [] @@ -1201,9 +1513,7 @@ async def _run_high_effort( # Accumulate errors from every chunk. all_errors.extend(state.get("errors") or []) - # For every key, prefer the last non-None value; this gives - # callers the final-chunk's extraction results while retaining - # earlier chunks' results when a later chunk produced nothing. 
+ # For every key, prefer the last non-None value for key, value in state.items(): if key == "errors": continue diff --git a/tests/test_batch_ingest.py b/tests/test_batch_ingest.py index 2e885ff..7d474d1 100644 --- a/tests/test_batch_ingest.py +++ b/tests/test_batch_ingest.py @@ -33,7 +33,24 @@ async def mock_run(*args, **kwargs): "image_weaver": None, } + async def mock_run_staged_batch(items, user_id): + return [ + { + "classification_result": SimpleNamespace(classifications=["test"]), + "profile_judge": None, + "profile_weaver": None, + "temporal_judge": None, + "temporal_weaver": None, + "summary_judge": None, + "summary_weaver": None, + "image_judge": None, + "image_weaver": None, + } + for _ in items + ] + mock_pipeline.run.side_effect = mock_run + mock_pipeline.run_staged_batch.side_effect = mock_run_staged_batch mock_get_pipeline.return_value = mock_pipeline yield mock_pipeline @@ -130,6 +147,117 @@ def _send_batch(idx): assert r is not None assert r.status_code == 200, r.json() - # All 4 pipeline.run calls (2 items × 2 batches) should have been made - assert mock_ingest_pipeline.run.call_count == 4 + # All 2 run_staged_batch calls (2 batches) should have been made + assert mock_ingest_pipeline.run_staged_batch.call_count == 2 + + +@pytest.mark.asyncio +async def test_run_staged_batch_overlays(): + """Verify that JudgeAgent applies simulated overlays for profile, temporal, and semantic domains when pending_ops are passed.""" + from unittest.mock import MagicMock + from src.agents.judge import JudgeAgent, JudgeDomain + from src.schemas.judge import Operation, OperationType + from src.storage.base import SearchResult + + # Mock the vector store and graph event search + mock_vector_store = MagicMock() + mock_graph_search = MagicMock() + + agent = JudgeAgent(model=MagicMock(), name="judge", system_prompt="system") + agent.vector_store = mock_vector_store + agent.graph_event_search = mock_graph_search + agent.top_k = 5 + + # 1. 
Test profile topic/sub-topic exact match + pending_profile_ops = [ + Operation( + type=OperationType.ADD, + content="work / company = XMem", + reason="User works at XMem", + embedding_id="pending_prof_1" + ) + ] + + mock_vector_store.search_by_metadata.return_value = [] + + res = await agent._fetch_similar( + items_strings=["work / company = XMem"], + new_items=[{"topic": "work", "sub_topic": "company", "memo": "XMem"}], + user_id="user_1", + domain=JudgeDomain.PROFILE, + pending_ops=pending_profile_ops + ) + + assert "work / company = XMem" in res + assert len(res["work / company = XMem"]) == 1 + match = res["work / company = XMem"][0] + assert match.id == "pending_prof_1" + assert match.score == 1.0 + assert match.metadata["domain"] == "profile" + + # 2. Test profile delete overlay (should clear matches) + pending_delete_profile_ops = [ + Operation( + type=OperationType.DELETE, + content="work / company = XMem", + reason="User left XMem", + embedding_id="pending_prof_1" + ) + ] + + res_del = await agent._fetch_similar( + items_strings=["work / company = XMem"], + new_items=[{"topic": "work", "sub_topic": "company", "memo": "XMem"}], + user_id="user_1", + domain=JudgeDomain.PROFILE, + pending_ops=pending_delete_profile_ops + ) + assert len(res_del["work / company = XMem"]) == 0 + + # 3. 
Test temporal event name match + pending_temporal_ops = [ + Operation( + type=OperationType.ADD, + content="Date: 05-22, 2026 | Event: Launch | Description: Final Release | Time: | Date expression: today", + reason="Launch event", + embedding_id="pending_temp_1" + ) + ] + + mock_graph_search.search_events_by_embedding.return_value = [] + res_temp = await agent._fetch_similar( + items_strings=["Launch event"], + new_items=[{"date": "05-22", "event_name": "Launch", "desc": "Final Release", "year": "2026", "time": "", "date_expression": "today"}], + user_id="user_1", + domain=JudgeDomain.TEMPORAL, + pending_ops=pending_temporal_ops + ) + assert len(res_temp["Launch event"]) == 1 + match_temp = res_temp["Launch event"][0] + assert match_temp.id == "pending_temp_1" + assert match_temp.score == 1.0 + + # 4. Test semantic similarity match (SequenceMatcher) + pending_summary_ops = [ + Operation( + type=OperationType.ADD, + content="Likes clean coding and unit testing", + reason="Clean code preference", + embedding_id="pending_sum_1" + ) + ] + + mock_vector_store.search_by_text = MagicMock(return_value=[]) + res_sum = await agent._fetch_similar( + items_strings=["Likes clean coding"], + new_items=[], + user_id="user_1", + domain=JudgeDomain.SUMMARY, + pending_ops=pending_summary_ops + ) + assert len(res_sum["Likes clean coding"]) == 1 + match_sum = res_sum["Likes clean coding"][0] + assert match_sum.id == "pending_sum_1" + assert match_sum.score > 0.5 +