diff --git a/pyproject.toml b/pyproject.toml index 3ae9b2298..bc1793386 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -62,6 +62,9 @@ issues = "https://github.com/MemTensor/MemOS/issues" [project.scripts] memos = "memos.cli:main" +[project.entry-points."memos.plugins"] +dream = "memos.dream:CommunityDreamPlugin" + [project.optional-dependencies] # These are optional dependencies for various features of MemoryOS. # Developers install: `poetry install --extras `. e.g., `poetry install --extras general-mem` diff --git a/src/memos/api/handlers/chat_handler.py b/src/memos/api/handlers/chat_handler.py index 2e9032f11..68dcfe6fb 100644 --- a/src/memos/api/handlers/chat_handler.py +++ b/src/memos/api/handlers/chat_handler.py @@ -60,6 +60,7 @@ def __init__( self, dependencies: HandlerDependencies, chat_llms: dict[str, Any], + playground_chat_llms: dict[str, Any] | None = None, search_handler=None, add_handler=None, online_bot=None, @@ -70,6 +71,7 @@ def __init__( Args: dependencies: HandlerDependencies instance chat_llms: Dictionary mapping model names to LLM instances + playground_chat_llms: Optional model map for /chat/stream/playground search_handler: Optional SearchHandler instance (created if not provided) add_handler: Optional AddHandler instance (created if not provided) online_bot: Optional DingDing bot function for notifications @@ -89,6 +91,7 @@ def __init__( add_handler = AddHandler(dependencies) self.chat_llms = chat_llms + self.playground_chat_llms = playground_chat_llms or chat_llms self.search_handler = search_handler self.add_handler = add_handler self.online_bot = online_bot @@ -630,10 +633,11 @@ def generate_chat_response() -> Generator[str, None, None]: # Step 3: Generate streaming response from LLM try: - model = next(iter(self.chat_llms.keys())) + chat_llms = self.playground_chat_llms + model = next(iter(chat_llms.keys())) self.logger.info(f"[PLAYGROUND CHAT] Chat Playground Stream Model: {model}") start = time.time() - response_stream = self.chat_llms[model].generate_stream( + response_stream = chat_llms[model].generate_stream( current_messages, model_name_or_path=model ) diff --git a/src/memos/api/handlers/component_init.py b/src/memos/api/handlers/component_init.py index 03dcc8412..b9c209e61 100644 --- a/src/memos/api/handlers/component_init.py +++ b/src/memos/api/handlers/component_init.py @@ -157,6 +157,7 @@ def init_server() -> dict[str, Any]: graph_db_config = build_graph_db_config() llm_config = build_llm_config() chat_llm_config = build_chat_llm_config() + playground_chat_llm_config = build_chat_llm_config("PLAYGROUND_CHAT_MODEL_LIST") embedder_config = build_embedder_config() nli_client_config = build_nli_client_config() mem_reader_config = build_mem_reader_config() @@ -174,6 +175,11 @@ def init_server() -> dict[str, Any]: if os.getenv("ENABLE_CHAT_API", "false") == "true" else None ) + playground_chat_llms = ( + _init_chat_llms(playground_chat_llm_config) + if os.getenv("ENABLE_CHAT_API", "false") == "true" and playground_chat_llm_config + else chat_llms + ) embedder = EmbedderFactory.from_config(embedder_config) plugin_context = build_plugin_context( @@ -317,6 +323,7 @@ def init_server() -> dict[str, Any]: "mem_reader": mem_reader, "llm": llm, "chat_llms": chat_llms, + "playground_chat_llms": playground_chat_llms, "embedder": embedder, "reranker": reranker, "internet_retriever": internet_retriever, diff --git a/src/memos/api/handlers/config_builders.py b/src/memos/api/handlers/config_builders.py index d29429fc9..0a083e284 100644 --- a/src/memos/api/handlers/config_builders.py +++ b/src/memos/api/handlers/config_builders.py @@ -85,14 +85,17 @@ def build_llm_config() -> dict[str, Any]: ) -def build_chat_llm_config() -> list[dict[str, Any]]: +def build_chat_llm_config(env_name: str = "CHAT_MODEL_LIST") -> list[dict[str, Any]]: """ Build chat LLM configuration. Returns: Validated chat LLM configuration dictionary + Args: + env_name: Environment variable that contains the JSON chat model list. + """ - configs = json.loads(os.getenv("CHAT_MODEL_LIST", "[]")) + configs = json.loads(os.getenv(env_name, "[]")) return [ { "config_class": LLMConfigFactory.model_validate( diff --git a/src/memos/api/handlers/search_handler.py b/src/memos/api/handlers/search_handler.py index 15eb7c38e..7f4bad798 100644 --- a/src/memos/api/handlers/search_handler.py +++ b/src/memos/api/handlers/search_handler.py @@ -20,7 +20,8 @@ from memos.multi_mem_cube.composite_cube import CompositeCubeView from memos.multi_mem_cube.single_cube import SingleCubeView from memos.multi_mem_cube.views import MemCubeView -from memos.plugins.hooks import hookable +from memos.plugins.hook_defs import H +from memos.plugins.hooks import hookable, trigger_hook logger = get_logger(__name__) @@ -71,6 +72,14 @@ def handle_search_memories(self, search_req: APISearchRequest) -> SearchResponse # Search and deduplicate cube_view = self._build_cube_view(search_req_local) results = cube_view.search_memories(search_req_local) + hooked_results = trigger_hook( + H.SEARCH_MEMORY_RESULTS, + handler=self, + search_req=search_req_local, + results=results, + ) + if hooked_results is not None: + results = hooked_results if not search_req_local.relativity: search_req_local.relativity = 0 self.logger.info(f"[SearchHandler] Relativity filter: {search_req_local.relativity}") diff --git a/src/memos/api/routers/server_router.py b/src/memos/api/routers/server_router.py index fa8a0b396..351d3a54e 100644 --- a/src/memos/api/routers/server_router.py +++ b/src/memos/api/routers/server_router.py @@ -76,10 +76,11 @@ add_handler = AddHandler(dependencies) chat_handler = ( ChatHandler( - dependencies, - components["chat_llms"], - search_handler, - add_handler, + dependencies=dependencies, + chat_llms=components["chat_llms"], + playground_chat_llms=components.get("playground_chat_llms"), + search_handler=search_handler, + add_handler=add_handler, online_bot=components.get("online_bot"), ) if os.getenv("ENABLE_CHAT_API", "false") == "true" diff --git a/src/memos/dream/contextualization.py b/src/memos/dream/contextualization.py new file mode 100644 index 000000000..911caf2ac --- /dev/null +++ b/src/memos/dream/contextualization.py @@ -0,0 +1,765 @@ +from __future__ import annotations + +import json +import logging +import os + +from contextlib import suppress +from dataclasses import dataclass, field +from datetime import datetime, timezone +from typing import Any +from uuid import uuid4 + +from memos.dream.enrichment import DREAM_INTERNAL_INFO_KEY +from memos.dream.prompts import CONTEXT_BINDING_PROMPT, CONTEXT_SUMMARY_PROMPT + + +logger = logging.getLogger(__name__) + +CONTEXT_MEMORY_TYPE = "Context" +_CONTEXT_ID_PREFIX = "ctx_" +_DREAM_OUTPUT_MEMORY_TYPES = {CONTEXT_MEMORY_TYPE, "InsightMemory", "DreamDiary"} +_DEFAULT_BINDING_MIN_GROUP_SIZE = 2 +_DEFAULT_BINDING_MAX_GROUP_SIZE = 30 +_DEFAULT_BINDING_CONFIDENCE_THRESHOLD = 0.65 + + +def _env_enabled(name: str, default: str = "on") -> bool: + return os.getenv(name, default).strip().lower() not in {"0", "false", "no", "off"} + + +def _env_int(name: str, default: int) -> int: + with suppress(TypeError, ValueError): + return int(os.getenv(name, str(default))) + return default + + +def _env_float(name: str, default: float) -> float: + with suppress(TypeError, ValueError): + return float(os.getenv(name, str(default))) + return default + + +@dataclass +class DreamContextSource: + id: str + memory: str + metadata: dict[str, Any] + dream: dict[str, Any] + created_at: str | None = None + + +@dataclass +class DreamContextGroup: + group_key: str + memories: list[DreamContextSource] + strategy: str = "heuristic" + confidence: float = 0.7 + proposed_key: str | None = None + should_persist: bool = True + + @property + def memory_ids(self) -> list[str]: + return [memory.id for memory in self.memories] + + +@dataclass +class DreamBindingUnit: + short_id: str + memories: list[DreamContextSource] + + @property + def real_ids(self) -> list[str]: + return [memory.id for memory in self.memories] + + +@dataclass +class DreamContextReport: + processed_memory_count: int = 0 + created_context_count: int = 0 + updated_context_count: int = 0 + bound_memory_count: int = 0 + skipped_memory_count: int = 0 + contexts: list[dict[str, Any]] = field(default_factory=list) + + def model_dump(self) -> dict[str, Any]: + return { + "processed_memory_count": self.processed_memory_count, + "created_context_count": self.created_context_count, + "updated_context_count": self.updated_context_count, + "bound_memory_count": self.bound_memory_count, + "skipped_memory_count": self.skipped_memory_count, + "contexts": self.contexts, + } + + +class DreamContextualizer: + """Create or update `Context` memory nodes from Dream pending memories. + + The v1 implementation is deliberately conservative: + - weak context IDs build the initial candidate pools; + - LLM binding can split broad pools into tighter context groups; + - existing LLM-bound contexts are matched by source memory overlap. + """ + + def __init__( + self, + *, + enabled: bool | None = None, + summary_llm_enabled: bool | None = None, + binding_llm_enabled: bool | None = None, + binding_min_group_size: int | None = None, + binding_max_group_size: int | None = None, + binding_confidence_threshold: float | None = None, + ): + self.enabled = ( + _env_enabled("MEMOS_DREAM_CONTEXT_ENABLED", "on") if enabled is None else enabled + ) + self.summary_llm_enabled = ( + _env_enabled("MEMOS_DREAM_CONTEXT_SUMMARY_LLM", "on") + if summary_llm_enabled is None + else summary_llm_enabled + ) + self.binding_llm_enabled = ( + _env_enabled("MEMOS_DREAM_CONTEXT_BINDING_LLM", "on") + if binding_llm_enabled is None + else binding_llm_enabled + ) + self.binding_min_group_size = ( + _env_int("MEMOS_DREAM_CONTEXT_BINDING_MIN_GROUP_SIZE", _DEFAULT_BINDING_MIN_GROUP_SIZE) + if binding_min_group_size is None + else binding_min_group_size + ) + self.binding_max_group_size = ( + _env_int("MEMOS_DREAM_CONTEXT_BINDING_MAX_GROUP_SIZE", _DEFAULT_BINDING_MAX_GROUP_SIZE) + if binding_max_group_size is None + else binding_max_group_size + ) + self.binding_confidence_threshold = ( + _env_float( + "MEMOS_DREAM_CONTEXT_BINDING_CONFIDENCE_THRESHOLD", + _DEFAULT_BINDING_CONFIDENCE_THRESHOLD, + ) + if binding_confidence_threshold is None + else binding_confidence_threshold + ) + self.context: dict[str, Any] = {} + + def bind_context(self, context: dict[str, Any]) -> None: + self.context = context + + def run(self, *, signal_snapshot, text_mem, cube_id: str) -> DreamContextReport: + report = DreamContextReport() + if not self.enabled: + return report + + memory_ids = list(dict.fromkeys(getattr(signal_snapshot, "pending_memory_ids", []) or [])) + if not memory_ids: + return report + + graph_db = self.context.get("shared", {}).get("graph_db") + if graph_db is None: + logger.info("[Dream Context] graph_db unavailable; skip context stage.") + return report + + memories = self._load_memories(graph_db=graph_db, memory_ids=memory_ids, cube_id=cube_id) + report.processed_memory_count = len(memories) + report.skipped_memory_count = max(0, len(memory_ids) - len(memories)) + if not memories: + return report + + existing_contexts = self._load_existing_contexts(graph_db=graph_db, cube_id=cube_id) + for group in self._build_groups(memories): + if not group.should_persist: + report.skipped_memory_count += len(group.memory_ids) + continue + context_node = self._match_existing_context(group, existing_contexts) + context_event = self._persist_group( + graph_db=graph_db, + group=group, + context_node=context_node, + cube_id=cube_id, + ) + action = context_event["action"] + if action == "created": + report.created_context_count += 1 + else: + report.updated_context_count += 1 + report.bound_memory_count += len(group.memory_ids) + report.contexts.append(context_event) + return report + + def _load_memories( + self, *, graph_db, memory_ids: list[str], cube_id: str + ) -> list[DreamContextSource]: + try: + nodes = graph_db.get_nodes(memory_ids, include_embedding=True, user_name=cube_id) + except Exception: + logger.warning("[Dream Context] failed to load pending memories.", exc_info=True) + return [] + + loaded: list[DreamContextSource] = [] + for node in nodes or []: + if not isinstance(node, dict): + continue + metadata = node.get("metadata") if isinstance(node.get("metadata"), dict) else {} + memory_type = metadata.get("memory_type") + if memory_type in _DREAM_OUTPUT_MEMORY_TYPES: + continue + if metadata.get("status") not in (None, "activated"): + continue + if metadata.get("source") == "dream": + continue + + internal_info = _coerce_json_dict(metadata.get("internal_info")) + dream = internal_info.get(DREAM_INTERNAL_INFO_KEY) + if not isinstance(dream, dict): + dream = {} + loaded.append( + DreamContextSource( + id=str(node.get("id", "")), + memory=node.get("memory", "") or node.get("content", ""), + metadata=metadata, + dream=dream, + created_at=metadata.get("created_at"), + ) + ) + return [memory for memory in loaded if memory.id and memory.memory] + + def _load_existing_contexts(self, *, graph_db, cube_id: str) -> list[dict[str, Any]]: + filters = [{"field": "memory_type", "op": "=", "value": CONTEXT_MEMORY_TYPE}] + try: + ids = graph_db.get_by_metadata(filters, user_name=cube_id, status="activated") + if not ids: + return [] + nodes = graph_db.get_nodes(ids, include_embedding=True, user_name=cube_id) + except Exception: + logger.info( + "[Dream Context] existing Context lookup unavailable; will create new contexts." + ) + return [] + return [node for node in nodes or [] if isinstance(node, dict)] + + def _build_groups(self, memories: list[DreamContextSource]) -> list[DreamContextGroup]: + candidate_pools = self._build_candidate_pools(memories) + groups: list[DreamContextGroup] = [] + for pool_key, pool_memories in candidate_pools: + groups.extend(self._bind_candidate_pool(pool_key=pool_key, memories=pool_memories)) + return groups + + def _build_candidate_pools( + self, memories: list[DreamContextSource] + ) -> list[tuple[str, list[DreamContextSource]]]: + grouped: dict[str, list[DreamContextSource]] = {} + unbound: list[DreamContextSource] = [] + for memory in memories: + weak_context_id = memory.dream.get("weak_context_id") + if weak_context_id: + grouped.setdefault(str(weak_context_id), []).append(memory) + else: + unbound.append(memory) + + pools = list(grouped.items()) + if unbound: + pools.append(("unbound", unbound)) + return pools + + def _bind_candidate_pool( + self, *, pool_key: str, memories: list[DreamContextSource] + ) -> list[DreamContextGroup]: + if len(memories) == 1: + return [ + DreamContextGroup( + group_key=f"{pool_key}:singleton", + memories=memories, + strategy="singleton_skipped", + confidence=0.0, + should_persist=False, + ) + ] + + llm = self.context.get("shared", {}).get("llm") + if ( + not self.binding_llm_enabled + or llm is None + or len(memories) < self.binding_min_group_size + or len(memories) > self.binding_max_group_size + ): + return [self._fallback_group(pool_key=pool_key, memories=memories)] + + try: + return self._llm_bind_candidate_pool(llm=llm, pool_key=pool_key, memories=memories) + except Exception: + logger.warning( + "[Dream Context] binding LLM failed; using heuristic groups.", exc_info=True + ) + return [self._fallback_group(pool_key=pool_key, memories=memories)] + + def _llm_bind_candidate_pool( + self, *, llm, pool_key: str, memories: list[DreamContextSource] + ) -> list[DreamContextGroup]: + units = _build_binding_units(memories) + prompt = CONTEXT_BINDING_PROMPT.format(memories_block=_format_binding_units_block(units)) + response = llm.generate([{"role": "user", "content": prompt}]) + raw = _parse_json_object(response) + groups = _parse_binding_groups( + raw=raw, + pool_key=pool_key, + units=units, + confidence_threshold=self.binding_confidence_threshold, + ) + if groups: + return groups + return [self._fallback_group(pool_key=pool_key, memories=memories)] + + @staticmethod + def _fallback_group(*, pool_key: str, memories: list[DreamContextSource]) -> DreamContextGroup: + if _is_batch_pool(pool_key) and len(memories) > 1: + return DreamContextGroup( + group_key=pool_key, + memories=memories, + strategy="batch", + confidence=0.85, + should_persist=True, + ) + return DreamContextGroup( + group_key=pool_key, + memories=memories, + strategy="weak_skipped", + confidence=0.0, + should_persist=False, + ) + + @staticmethod + def _match_existing_context( + group: DreamContextGroup, existing_contexts: list[dict[str, Any]] + ) -> dict[str, Any] | None: + if group.strategy.startswith("llm"): + return _match_context_by_memory_overlap(group, existing_contexts) + + group_weak_ids = _group_weak_context_ids(group) + if not group_weak_ids: + return None + + best_context = None + best_overlap = 0 + for context_node in existing_contexts: + dream = _node_dream_info(context_node) + weak_ids = set(dream.get("weak_context_ids") or []) + overlap = len(group_weak_ids & weak_ids) + if overlap > best_overlap: + best_context = context_node + best_overlap = overlap + return best_context + + def _persist_group(self, *, graph_db, group: DreamContextGroup, context_node, cube_id: str): + existing_metadata = ( + context_node.get("metadata", {}) if isinstance(context_node, dict) else {} + ) + existing_dream = _node_dream_info(context_node) if context_node else {} + existing_memory_ids = existing_dream.get("memory_ids") or [] + memory_ids = _unique([*existing_memory_ids, *group.memory_ids]) + + key, summary, summary_confidence, summary_strategy = self._summarize_group( + group=group, + existing_key=( + existing_metadata.get("key", "") if existing_metadata else group.proposed_key or "" + ), + existing_memory=context_node.get("memory", "") if context_node else "", + ) + now = datetime.now(timezone.utc).isoformat() + context_id = ( + context_node.get("id") if context_node else f"{_CONTEXT_ID_PREFIX}{uuid4().hex}" + ) + created_at = existing_metadata.get("created_at") or now + confidence = max(group.confidence, summary_confidence) + + metadata = { + "memory_type": CONTEXT_MEMORY_TYPE, + "status": "activated", + "key": key, + "embedding": self._embed_key(key), + "source": "system", + "visibility": "private", + "tags": ["dream", "context"], + "confidence": confidence, + "created_at": created_at, + "updated_at": now, + "sources": _source_refs(group.memories), + "internal_info": { + DREAM_INTERNAL_INFO_KEY: { + "kind": "context", + "memory_ids": memory_ids, + "weak_context_ids": sorted(_group_weak_context_ids(group)), + "salience": _group_salience(group.memories), + "time_range": _group_time_range(group.memories), + "binding": { + "strategy": group.strategy, + "confidence": group.confidence, + }, + "summary": { + "strategy": summary_strategy, + "updated_from_memory_ids": group.memory_ids, + }, + } + }, + } + + if context_node: + graph_db.add_node(context_id, summary, metadata, user_name=cube_id) + return _build_context_event( + context_id=context_id, + action="updated", + key=key, + summary=summary, + metadata=metadata, + group=group, + summary_strategy=summary_strategy, + ) + + graph_db.add_node(context_id, summary, metadata, user_name=cube_id) + return _build_context_event( + context_id=context_id, + action="created", + key=key, + summary=summary, + metadata=metadata, + group=group, + summary_strategy=summary_strategy, + ) + + def _summarize_group( + self, *, group: DreamContextGroup, existing_key: str = "", existing_memory: str = "" + ) -> tuple[str, str, float, str]: + llm = self.context.get("shared", {}).get("llm") + if self.summary_llm_enabled and llm is not None: + prompt = CONTEXT_SUMMARY_PROMPT.format( + existing_key=existing_key or "(none)", + existing_memory=existing_memory or "(none)", + memories_block=_format_memories_block(group.memories), + ) + try: + response = llm.generate([{"role": "user", "content": prompt}]) + raw = _parse_json_object(response) + key = str(raw.get("key") or "").strip() + memory = str(raw.get("memory") or "").strip() + confidence = float(raw.get("confidence") or 0.0) + if key and memory: + return key, memory, max(0.0, min(1.0, confidence)), "llm" + except Exception: + logger.warning("[Dream Context] summary LLM failed; using fallback.", exc_info=True) + + key = existing_key or _fallback_key(group) + summary = _fallback_summary(group, existing_memory=existing_memory) + return key, summary, 0.5, "fallback" + + def _embed_key(self, key: str) -> list[float] | None: + embedder = self.context.get("shared", {}).get("embedder") + if embedder is None: + embedder = getattr(self.context.get("shared", {}).get("text_mem"), "embedder", None) + if embedder is None: + return None + try: + return embedder.embed([key])[0] + except Exception: + logger.info("[Dream Context] key embedding unavailable; continue without embedding.") + return None + + +def _node_dream_info(node: dict[str, Any] | None) -> dict[str, Any]: + metadata = node.get("metadata", {}) if isinstance(node, dict) else {} + internal_info = _coerce_json_dict( + metadata.get("internal_info") if isinstance(metadata, dict) else {} + ) + dream = internal_info.get(DREAM_INTERNAL_INFO_KEY) + return dream if isinstance(dream, dict) else {} + + +def _build_context_event( + *, + context_id: str, + action: str, + key: str, + summary: str, + metadata: dict[str, Any], + group: DreamContextGroup, + summary_strategy: str, +) -> dict[str, Any]: + dream = (metadata.get("internal_info") or {}).get(DREAM_INTERNAL_INFO_KEY) or {} + binding = dream.get("binding") if isinstance(dream, dict) else {} + weak_context_ids = dream.get("weak_context_ids") if isinstance(dream, dict) else [] + return { + "context_id": context_id, + "action": action, + "key": key, + "label": key, + "summary": summary, + "memory_ids": list(group.memory_ids), + "source_memory_ids": list(group.memory_ids), + "weak_context_ids": list(weak_context_ids or []), + "binding": binding if isinstance(binding, dict) else {}, + "binding_strategy": group.strategy, + "binding_confidence": group.confidence, + "summary_strategy": summary_strategy, + "confidence": metadata.get("confidence"), + } + + +def _match_context_by_memory_overlap( + group: DreamContextGroup, existing_contexts: list[dict[str, Any]] +) -> dict[str, Any] | None: + group_ids = set(group.memory_ids) + best_context = None + best_overlap = 0 + for context_node in existing_contexts: + dream = _node_dream_info(context_node) + existing_ids = set(dream.get("memory_ids") or []) + overlap = len(group_ids & existing_ids) + if overlap > best_overlap: + best_context = context_node + best_overlap = overlap + return best_context + + +def _group_weak_context_ids(group: DreamContextGroup) -> set[str]: + return { + str(memory.dream["weak_context_id"]) + for memory in group.memories + if memory.dream.get("weak_context_id") + } + + +def _is_batch_pool(pool_key: str) -> bool: + return pool_key.startswith("batch:") + + +def _group_salience(memories: list[DreamContextSource]) -> float: + score = len(memories) * 0.2 + for memory in memories: + salience = memory.dream.get("salience") if isinstance(memory.dream, dict) else {} + if not isinstance(salience, dict): + continue + score += 2.0 if salience.get("has_feedback") else 0.0 + score += 1.5 if salience.get("unresolved") else 0.0 + with suppress(TypeError, ValueError): + score += float(salience.get("emotional") or 0) + return round(min(10.0, score), 3) + + +def _group_time_range(memories: list[DreamContextSource]) -> dict[str, str | None]: + times = sorted(t for t in (memory.created_at for memory in memories) if t) + return { + "start": times[0] if times else None, + "end": times[-1] if times else None, + } + + +def _source_refs(memories: list[DreamContextSource]) -> list[dict[str, Any]]: + refs: list[dict[str, Any]] = [] + for memory in memories[:20]: + refs.append( + { + "type": "memory", + "message_id": memory.id, + "content": memory.memory[:300], + } + ) + return refs + + +def _build_binding_units(memories: list[DreamContextSource]) -> list[DreamBindingUnit]: + batch_groups: dict[str, list[DreamContextSource]] = {} + free_memories: list[DreamContextSource] = [] + for memory in memories: + weak_id = memory.dream.get("weak_context_id") if isinstance(memory.dream, dict) else None + if isinstance(weak_id, str) and weak_id.startswith("batch:"): + batch_groups.setdefault(weak_id, []).append(memory) + else: + free_memories.append(memory) + + raw_units = [*batch_groups.values(), *[[memory] for memory in free_memories]] + return [ + DreamBindingUnit(short_id=f"m{idx}", memories=unit_memories) + for idx, unit_memories in enumerate(raw_units, start=1) + ] + + +def _format_binding_units_block(units: list[DreamBindingUnit]) -> str: + lines: list[str] = [] + for unit in units: + unit_label = "batch" if len(unit.memories) > 1 else "memory" + lines.append(f"ID: {unit.short_id} ({unit_label}; real_ids={unit.real_ids})") + for memory in unit.memories: + metadata = memory.metadata or {} + dream = memory.dream or {} + key = metadata.get("key") + created_at = memory.created_at or metadata.get("created_at") + weak_id = dream.get("weak_context_id") + details = [] + if key: + details.append(f"key={key}") + if created_at: + details.append(f"created_at={created_at}") + if weak_id: + details.append(f"weak_context_id={weak_id}") + for hint_field in ("context_hint", "goal_hint"): + if dream.get(hint_field): + details.append(f"{hint_field}={dream[hint_field]}") + if dream.get("entity_hints"): + details.append(f"entity_hints={dream['entity_hints']}") + prefix = f"- real_id={memory.id}" + if details: + prefix += f" ({'; '.join(details)})" + lines.append(f"{prefix}: {memory.memory[:1000]}") + lines.append("") + return "\n".join(lines).strip() + + +def _parse_binding_groups( + *, + raw: dict[str, Any], + pool_key: str, + units: list[DreamBindingUnit], + confidence_threshold: float, +) -> list[DreamContextGroup]: + unit_by_short_id = {unit.short_id: unit for unit in units} + assigned: set[str] = set() + groups: list[DreamContextGroup] = [] + + contexts = raw.get("contexts") + if not isinstance(contexts, list): + contexts = [] + + for idx, context in enumerate(contexts, start=1): + if not isinstance(context, dict): + continue + confidence = _safe_confidence(context.get("confidence")) + if confidence < confidence_threshold: + continue + short_ids = context.get("ids") + if not isinstance(short_ids, list): + short_ids = context.get("memory_ids") + if not isinstance(short_ids, list): + continue + + selected_units: list[DreamBindingUnit] = [] + valid = True + for raw_short_id in short_ids: + short_id = str(raw_short_id) + unit = unit_by_short_id.get(short_id) + if unit is None or short_id in assigned: + valid = False + break + selected_units.append(unit) + if not valid or not selected_units: + continue + + for unit in selected_units: + assigned.add(unit.short_id) + group_memories = [memory for unit in selected_units for memory in unit.memories] + should_persist = len(selected_units) > 1 or any( + len(unit.memories) > 1 for unit in selected_units + ) + groups.append( + DreamContextGroup( + group_key=f"{pool_key}:llm:{idx}", + memories=group_memories, + strategy="llm", + confidence=confidence, + proposed_key=str(context.get("key") or "").strip() or None, + should_persist=should_persist, + ) + ) + + for unit in units: + if unit.short_id in assigned: + continue + groups.append( + DreamContextGroup( + group_key=f"{pool_key}:unassigned:{unit.short_id}", + memories=unit.memories, + strategy="llm_unassigned", + confidence=0.0, + should_persist=False, + ) + ) + return groups + + +def _safe_confidence(value: Any) -> float: + with suppress(TypeError, ValueError): + return max(0.0, min(1.0, float(value))) + return 0.0 + + +def _format_memories_block(memories: list[DreamContextSource]) -> str: + lines: list[str] = [] + for memory in memories: + dream = memory.dream or {} + hints = [] + for hint_field in ("context_hint", "goal_hint"): + if dream.get(hint_field): + hints.append(f"{hint_field}={dream[hint_field]}") + if dream.get("entity_hints"): + hints.append(f"entity_hints={dream['entity_hints']}") + hint_text = f" ({'; '.join(hints)})" if hints else "" + created_at = f" created_at={memory.created_at}" if memory.created_at else "" + lines.append(f"- [{memory.id}]{created_at}{hint_text} {memory.memory[:1200]}") + return "\n".join(lines) + + +def _fallback_key(group: DreamContextGroup) -> str: + for memory in group.memories: + context_hint = memory.dream.get("context_hint") if isinstance(memory.dream, dict) else None + if context_hint: + return str(context_hint)[:80] + if group.group_key.startswith(("project:", "session:", "batch:")): + return group.group_key + first = group.memories[0].memory.strip().replace("\n", " ") if group.memories else "Context" + return first[:40] or "Context" + + +def _fallback_summary(group: DreamContextGroup, *, existing_memory: str = "") -> str: + parts = [] + if existing_memory: + parts.append(existing_memory.strip()) + parts.extend(memory.memory.strip() for memory in group.memories[:8] if memory.memory.strip()) + return "\n".join(_unique(parts))[:2000] + + +def _unique(values: list[Any]) -> list[Any]: + seen = set() + result = [] + for value in values: + marker = ( + json.dumps(value, sort_keys=True, ensure_ascii=False) + if isinstance(value, dict) + else value + ) + if marker in seen: + continue + seen.add(marker) + result.append(value) + return result + + +def _parse_json_object(text: str) -> dict[str, Any]: + cleaned = text.strip() + if cleaned.startswith("```"): + cleaned = cleaned.removeprefix("```json").removeprefix("```").removesuffix("```").strip() + raw = json.loads(cleaned) + if not isinstance(raw, dict): + raise ValueError("Expected JSON object") + return raw + + +def _coerce_json_dict(value: Any) -> dict[str, Any]: + if isinstance(value, dict): + return value + if isinstance(value, str) and value.strip().startswith("{"): + try: + parsed = json.loads(value) + except json.JSONDecodeError: + return {} + return parsed if isinstance(parsed, dict) else {} + return {} diff --git a/src/memos/dream/enrichment.py b/src/memos/dream/enrichment.py new file mode 100644 index 000000000..eb261a47c --- /dev/null +++ b/src/memos/dream/enrichment.py @@ -0,0 +1,284 @@ +from __future__ import annotations + +import os +import re + +from typing import Any + + +DREAM_INTERNAL_INFO_KEY = "dream" +DREAM_HEURISTIC_ENRICHER_VERSION = "0.1.0" + +_DEFAULT_SESSION_ID = "default_session" +_ENV_HEURISTIC_ENRICHER = "MEMOS_DREAM_HEURISTIC_ENRICHER" +_ENV_ENRICH_OVERWRITE = "MEMOS_DREAM_ENRICH_OVERWRITE" + +_QUESTION_RE = re.compile( + r"[??]|(?:\b(?:what|why|how|when|where|who|which|can|could|should)\b)", re.I +) +_AGENT_FEEDBACK_RE = re.compile( + "|".join( + [ + r"(?:你|您|助手|助理|agent|Agent|模型|系统).{0,12}(?:说错|弄错|搞错|理解错|误解|答错|回答错)", + r"(?:你|您|助手|助理|agent|Agent|模型|系统).{0,12}(?:回答|回复|理解).{0,8}(?:不对|错误|有误|不准确)", + r"(?:刚才|上面|前面|上一条).{0,12}(?:回答|回复|说法|理解).{0,8}(?:不对|错了|错误|有误|不准确)", + r"(?:这|那).{0,6}(?:不是|并不是).{0,8}(?:我说的意思|我的意思|我要的|我想要的)", + r"(?:你|您).{0,8}(?:没|没有).{0,6}(?:理解|明白|懂).{0,8}(?:我|我的意思)", + r"\bnot quite\b", + r"\bthat(?:'s| is) (?:wrong|incorrect|not right)\b", + r"\byou(?:'re| are)? wrong\b", + r"\byou (?:misunderstood|got it wrong)\b", + r"\byour (?:answer|response|reply|understanding) (?:is|was) (?:wrong|incorrect|not right|inaccurate)\b", + ] + ), + re.I, +) + + +def is_dream_heuristic_enricher_enabled() -> bool: + """Return whether the built-in rule-based Dream enricher should run.""" + return os.getenv(_ENV_HEURISTIC_ENRICHER, "on").strip().lower() not in { + "0", + "false", + "no", + "off", + } + + +def should_overwrite_dream_enrichment() -> bool: + """Return whether heuristic enrichment may overwrite existing Dream fields.""" + return os.getenv(_ENV_ENRICH_OVERWRITE, "off").strip().lower() in { + "1", + "true", + "yes", + "on", + } + + +class DreamHeuristicEnricher: + """Rule-based enrichment for Dream context binding. + + This stage deliberately avoids LLM calls. It only writes deterministic, + cheap signals that help a later semantic enricher or Dream binding stage + reason about context membership. + """ + + def __init__(self, *, enabled: bool | None = None, overwrite: bool | None = None) -> None: + self.enabled = is_dream_heuristic_enricher_enabled() if enabled is None else enabled + self.overwrite = should_overwrite_dream_enrichment() if overwrite is None else overwrite + + def enrich_items(self, *, items, user_context=None, extract_mode: str = "fine", **_: Any): + if not self.enabled or extract_mode != "fine" or not items: + return items + + batch_context_id = self._batch_context_id(items) + for item in items: + self._enrich_item( + item=item, + user_context=user_context, + batch_context_id=batch_context_id, + ) + return items + + def _enrich_item(self, *, item, user_context, batch_context_id: str | None) -> None: + metadata = getattr(item, "metadata", None) + if metadata is None: + return + + internal_info = getattr(metadata, "internal_info", None) + if not isinstance(internal_info, dict): + internal_info = {} + + dream_info = internal_info.get(DREAM_INTERNAL_INFO_KEY) + if not isinstance(dream_info, dict): + dream_info = {} + + sources = list(_iter_sources(getattr(metadata, "sources", None))) + source_roles = _source_roles(sources) + user_text = _joined_source_text(sources, roles={"user"}) + all_text = _joined_source_text(sources) or getattr(item, "memory", "") or "" + chunk_index = _first_present( + internal_info.get("chunk_index"), + _first_source_value(sources, "chunk_index"), + ) + chunk_total = _first_present( + internal_info.get("chunk_total"), + _first_source_value(sources, "chunk_total"), + ) + ingest_batch_id = _first_present( + internal_info.get("ingest_batch_id"), + _first_source_value(sources, "ingest_batch_id"), + ) + is_chunk = chunk_index is not None or (isinstance(chunk_total, int) and chunk_total > 1) + + weak_context_id = self._weak_context_id( + metadata=metadata, + user_context=user_context, + batch_context_id=batch_context_id, + ingest_batch_id=ingest_batch_id, + is_chunk=is_chunk, + ) + correction_text = user_text or all_text + has_correction = _has_agent_feedback(correction_text) + + self._set_if_missing(dream_info, "weak_context_id", weak_context_id) + if batch_context_id: + self._set_if_missing(dream_info, "batch_context_id", batch_context_id) + + signals = dream_info.get("signals") + if not isinstance(signals, dict): + signals = {} + dream_info["signals"] = signals + self._set_if_missing(signals, "source_roles", source_roles) + self._set_if_missing(signals, "is_chunk", bool(is_chunk)) + self._set_if_missing(signals, "chunk_index", chunk_index) + self._set_if_missing(signals, "chunk_total", chunk_total) + self._set_if_missing( + signals, + "has_question", + bool(_QUESTION_RE.search(user_text or all_text)), + ) + self._set_if_missing(signals, "has_correction", has_correction) + + salience = dream_info.get("salience") + if not isinstance(salience, dict): + salience = {} + dream_info["salience"] = salience + self._set_if_missing(salience, "has_feedback", has_correction) + + enriched_by = dream_info.get("enriched_by") + if not isinstance(enriched_by, dict): + enriched_by = {} + dream_info["enriched_by"] = enriched_by + self._set_if_missing(enriched_by, "heuristic", DREAM_HEURISTIC_ENRICHER_VERSION) + + internal_info[DREAM_INTERNAL_INFO_KEY] = dream_info + metadata.internal_info = internal_info + + def _set_if_missing(self, target: dict[str, Any], key: str, value: Any) -> None: + if self.overwrite or key not in target: + target[key] = value + + @staticmethod + def _batch_context_id(items) -> str | None: + batch_ids: set[str] = set() + for item in items: + metadata = getattr(item, "metadata", None) + internal_info = getattr(metadata, "internal_info", None) + if isinstance(internal_info, dict) and internal_info.get("ingest_batch_id"): + batch_ids.add(str(internal_info["ingest_batch_id"])) + if len(batch_ids) == 1: + return f"batch:{next(iter(batch_ids))}" + return None + + @staticmethod + def _weak_context_id( + *, + metadata, + user_context, + batch_context_id: str | None, + ingest_batch_id: Any, + is_chunk: bool, + ) -> str | None: + if is_chunk: + if batch_context_id: + return batch_context_id + if ingest_batch_id: + return f"batch:{ingest_batch_id}" + + project_id = _first_present( + getattr(metadata, "project_id", None), + getattr(user_context, "project_id", None), + ) + if project_id: + return f"project:{project_id}" + + session_id = _first_present( + getattr(metadata, "session_id", None), + getattr(user_context, "session_id", None), + ) + if session_id and session_id != _DEFAULT_SESSION_ID: + return f"session:{session_id}" + + return None + + +def on_memory_items_after_fine_extract( + plugin, *, items, user_context, mem_reader, extract_mode, **kw +): + enricher = getattr(plugin, "heuristic_enricher", None) + if enricher is None: + return items + return enricher.enrich_items( + items=items, + user_context=user_context, + mem_reader=mem_reader, + extract_mode=extract_mode, + **kw, + ) + + +def _iter_sources(sources: Any) -> list[Any]: + if not sources: + return [] + if isinstance(sources, list): + return sources + return [sources] + + +def _source_roles(sources: list[Any]) -> list[str]: + roles: list[str] = [] + for source in sources: + role = _source_value(source, "role") + if role and role not in roles: + roles.append(str(role)) + return roles + + +def _joined_source_text(sources: list[Any], roles: set[str] | None = None) -> str: + parts: list[str] = [] + for source in sources: + role = _source_value(source, "role") + if roles is not None and role not in roles: + continue + content = _source_value(source, "content") + if content: + parts.append(str(content)) + return "\n".join(parts) + + +def _has_agent_feedback(text: str) -> bool: + """Detect only high-confidence user feedback about the agent's response. + + This intentionally avoids broad discourse markers such as "不对", "其实是", + "actually", or "I mean". Those are common in ordinary user statements and + cause false feedback signals. Dream can afford to miss weak signals here; + strong feedback should explicitly target the assistant/answer/understanding. + """ + if not text: + return False + return bool(_AGENT_FEEDBACK_RE.search(text)) + + +def _first_source_value(sources: list[Any], key: str) -> Any: + for source in sources: + value = _source_value(source, key) + if value is not None: + return value + file_info = _source_value(source, "file_info") + if isinstance(file_info, dict) and file_info.get(key) is not None: + return file_info[key] + return None + + +def _source_value(source: Any, key: str) -> Any: + if isinstance(source, dict): + return source.get(key) + return getattr(source, key, None) + + +def _first_present(*values: Any) -> Any: + for value in values: + if value is not None and value != "": + return value + return None diff --git a/src/memos/dream/pipeline/__init__.py b/src/memos/dream/pipeline/__init__.py index 23736fddd..0a7dea98c 100644 --- a/src/memos/dream/pipeline/__init__.py +++ b/src/memos/dream/pipeline/__init__.py @@ -1,3 +1,4 @@ +from memos.dream.contextualization import DreamContextualizer from memos.dream.pipeline.base import AbstractDreamPipeline from memos.dream.pipeline.diary import StructuredDiarySummary from memos.dream.pipeline.motive import MotiveFormation @@ -10,6 +11,7 @@ "AbstractDreamPipeline", "ConsolidationReasoning", "DirectRecall", + "DreamContextualizer", "DreamPersistence", "MotiveFormation", "StructuredDiarySummary", diff --git a/src/memos/dream/pipeline/base.py b/src/memos/dream/pipeline/base.py index b00015781..9c9235853 100644 --- a/src/memos/dream/pipeline/base.py +++ b/src/memos/dream/pipeline/base.py @@ -18,7 +18,9 @@ def __init__( reasoning_strategy, diary_strategy, persistence_strategy, + context_strategy=None, ) -> None: + self.context_strategy = context_strategy self.motive_strategy = motive_strategy self.recall_strategy = recall_strategy self.reasoning_strategy = reasoning_strategy @@ -29,12 +31,15 @@ def __init__( def bind_context(self, context: dict[str, Any]) -> None: self.context = context for component in ( + self.context_strategy, self.motive_strategy, self.recall_strategy, self.reasoning_strategy, self.diary_strategy, self.persistence_strategy, ): + if component is None: + continue bind_context = getattr(component, "bind_context", None) if callable(bind_context): bind_context(context) @@ -48,6 +53,17 @@ def run( signal_snapshot, text_mem, ): + # Step 0: materialize Context nodes from pending memories. This stage is + # intentionally independent from insight reasoning; failure should not + # prevent the rest of Dream from running. + self.last_context_report = None + if self.context_strategy is not None: + self.last_context_report = self.context_strategy.run( + signal_snapshot=signal_snapshot, + text_mem=text_mem, + cube_id=cube_id, + ) + # Step 1: build Dream clusters from the scheduler payload. clusters = self.motive_strategy.form( signal_snapshot=signal_snapshot, @@ -74,6 +90,7 @@ def run( clusters=clusters, results=results, mem_cube_id=mem_cube_id, + context_report=self.last_context_report, ) # Step 4: hand persistence over to the final strategy. diff --git a/src/memos/dream/pipeline/diary.py b/src/memos/dream/pipeline/diary.py index 15f0a5d5d..64a7b9cc5 100644 --- a/src/memos/dream/pipeline/diary.py +++ b/src/memos/dream/pipeline/diary.py @@ -2,13 +2,13 @@ import logging -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any -from memos.dream.types import DreamDiaryEntry +from memos.dream.types import DreamDiaryEntry, DreamResult if TYPE_CHECKING: - from memos.dream.types import DreamAction, DreamCluster, DreamResult + from memos.dream.types import DreamAction, DreamCluster logger = logging.getLogger(__name__) @@ -77,11 +77,22 @@ def generate( clusters: list[DreamCluster], results: list[DreamResult], mem_cube_id: str, + context_report: Any | None = None, ) -> list[DreamResult]: cluster_map = {c.cluster_id: c for c in clusters} for result in results: cluster = cluster_map.get(result.cluster_id) result.diary_entry = self._build_entry(cluster, result) + + context_events = _context_events(context_report) + if context_events: + results.append( + DreamResult( + cluster_id=f"context:{mem_cube_id}", + actions=[], + diary_entry=self._build_context_entry(context_report, context_events), + ) + ) return results def _build_entry(self, cluster: DreamCluster | None, result: DreamResult) -> DreamDiaryEntry: @@ -112,6 +123,39 @@ def _build_entry(self, cluster: DreamCluster | None, result: DreamResult) -> Dre status=status, ) + def _build_context_entry( + self, context_report: Any, context_events: list[dict[str, Any]] + ) -> DreamDiaryEntry: + created_count = getattr(context_report, "created_context_count", 0) + updated_count = getattr(context_report, "updated_context_count", 0) + bound_count = getattr(context_report, "bound_memory_count", 0) + skipped_count = getattr(context_report, "skipped_memory_count", 0) + + labels = [event.get("label") or event.get("key") for event in context_events] + labels = [label for label in labels if label] + summary = ( + f"Dream processed Context bindings: created {created_count}, updated {updated_count}, " + f"bound {bound_count} source memories, skipped {skipped_count}." + ) + if labels: + summary = f"{summary} Labels: {', '.join(labels[:5])}." + + dream_entry = "\n".join(_format_context_event_for_diary(event) for event in context_events) + return DreamDiaryEntry( + title="Dream Context Summary", + summary=summary, + dream_entry=dream_entry, + motive={ + "type": "context", + "why_now": "Context binding and summary were produced during this Dream run.", + "source_memory_count": bound_count, + "related_memory_count": len(context_events), + }, + context_events=context_events, + themes=["context"], + status="context_only", + ) + @staticmethod def _first_real_dream(actions: list[DreamAction]) -> DreamAction | None: """Return the first action that carries genuine dream content. @@ -146,3 +190,21 @@ def _make_title(motive_description: str) -> str: if len(first_sentence) <= _TITLE_MAX_LEN: return first_sentence return first_sentence[:_TITLE_MAX_LEN].rstrip() + "…" + + +def _context_events(context_report: Any | None) -> list[dict[str, Any]]: + if context_report is None: + return [] + contexts = getattr(context_report, "contexts", None) + if not isinstance(contexts, list): + return [] + return [event for event in contexts if isinstance(event, dict)] + + +def _format_context_event_for_diary(event: dict[str, Any]) -> str: + label = event.get("label") or event.get("key") or event.get("context_id", "Context") + action = event.get("action", "processed") + summary = event.get("summary", "") + source_ids = event.get("source_memory_ids") or event.get("memory_ids") or [] + source_count = len(source_ids) if isinstance(source_ids, list) else 0 + return f"- {action} Context `{label}` from {source_count} memories: {summary}" diff --git a/src/memos/dream/pipeline/persistence.py b/src/memos/dream/pipeline/persistence.py index a190fb4ed..1ab5adcdc 100644 --- a/src/memos/dream/pipeline/persistence.py +++ b/src/memos/dream/pipeline/persistence.py @@ -328,6 +328,7 @@ def _build_diary_metadata(*, entry, result, user_id, mem_cube_id, signal_snapsho "summary": entry.summary, "dream_entry": entry.dream_entry, "motive": entry.motive, + "context_events": entry.context_events, "themes": entry.themes, "tags": ["dream", "diary"], "created_at": entry.created_at.isoformat(), diff --git a/src/memos/dream/pipeline/recall.py b/src/memos/dream/pipeline/recall.py index fee9cfcd0..d1fa01742 100644 --- a/src/memos/dream/pipeline/recall.py +++ b/src/memos/dream/pipeline/recall.py @@ -27,7 +27,8 @@ class DirectRecall: Scope is restricted on purpose: Dream-produced nodes (DreamDiary, InsightMemory, …) and short-lived WorkingMemory are excluded so that each Dream run reflects on the user's real daytime experiences rather - than its own previous outputs. + than its own previous outputs. Context nodes are also excluded because + they are intermediate indexes, not original user memories. """ def __init__(self, *, recall_top_k: int = _RECALL_TOP_K) -> None: diff --git a/src/memos/dream/plugin.py b/src/memos/dream/plugin.py index e44a68df1..fbe098df7 100644 --- a/src/memos/dream/plugin.py +++ b/src/memos/dream/plugin.py @@ -6,6 +6,11 @@ from functools import partial from typing import Any +from memos.dream.contextualization import DreamContextualizer +from memos.dream.enrichment import ( + DreamHeuristicEnricher, + on_memory_items_after_fine_extract, +) from memos.dream.hooks import on_add_signal, on_dream_execute from memos.dream.pipeline import ( AbstractDreamPipeline, @@ -17,6 +22,7 @@ ) from memos.dream.routers.diary_router import create_diary_router from memos.dream.routers.trigger_router import create_trigger_router +from memos.dream.search import DreamContextSearchExtension from memos.dream.signal_store import DreamSignalStore from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem from memos.mem_scheduler.schemas.task_schemas import MEM_DREAM_TASK_LABEL @@ -44,7 +50,10 @@ class CommunityDreamPlugin(MemOSPlugin): def on_load(self) -> None: self.context: dict[str, Any] = {"shared": {}, "configs": {}} self.signal_store = DreamSignalStore() + self.heuristic_enricher = DreamHeuristicEnricher() + self.search_extension = DreamContextSearchExtension() self.pipeline = AbstractDreamPipeline( + context_strategy=DreamContextualizer(), motive_strategy=MotiveFormation(), recall_strategy=DirectRecall(), reasoning_strategy=ConsolidationReasoning(), @@ -55,7 +64,12 @@ def on_load(self) -> None: # Hook registration happens at load time because scheduler-triggered Dream # execution does not depend on FastAPI route binding. self.register_hook(H.DREAM_EXECUTE, partial(on_dream_execute, self)) + self.register_hook(H.SEARCH_MEMORY_RESULTS, self.search_extension.merge_context_recall) self.register_hook(H.ADD_AFTER, partial(on_add_signal, self)) + self.register_hook( + H.MEMORY_ITEMS_AFTER_FINE_EXTRACT, + partial(on_memory_items_after_fine_extract, self), + ) logger.info("[Dream] plugin loaded") def init_components(self, context: dict) -> None: @@ -71,6 +85,7 @@ def init_app(self) -> None: def on_shutdown(self) -> None: self.context = {"shared": {}, "configs": {}} + self.heuristic_enricher = None logger.info("[Dream] plugin shutdown") def submit_dream_task( diff --git a/src/memos/dream/prompts/__init__.py b/src/memos/dream/prompts/__init__.py index 07bd1415d..9b79b8167 100644 --- a/src/memos/dream/prompts/__init__.py +++ b/src/memos/dream/prompts/__init__.py @@ -1,8 +1,12 @@ +from memos.dream.prompts.context_binding_prompt import CONTEXT_BINDING_PROMPT +from memos.dream.prompts.context_summary_prompt import CONTEXT_SUMMARY_PROMPT from memos.dream.prompts.motive_prompt import MOTIVE_FORMATION_PROMPT from memos.dream.prompts.reasoning_prompt import CONSOLIDATION_REASONING_PROMPT __all__ = [ "CONSOLIDATION_REASONING_PROMPT", + "CONTEXT_BINDING_PROMPT", + "CONTEXT_SUMMARY_PROMPT", "MOTIVE_FORMATION_PROMPT", ] diff --git a/src/memos/dream/prompts/context_binding_prompt.py b/src/memos/dream/prompts/context_binding_prompt.py new file mode 100644 index 000000000..bb77c5daa --- /dev/null +++ b/src/memos/dream/prompts/context_binding_prompt.py @@ -0,0 +1,31 @@ +CONTEXT_BINDING_PROMPT = """You are grouping memories into Contexts for a long-term memory system. + +A Context is a continuing task, goal, topic, project thread, relationship, or unresolved problem. +Group memories only when they are about the same continuing context. + +Rules: +- Use the short IDs exactly as provided, such as "m1" or "m2". +- Each short ID can appear in at most one context. +- Do not group memories just because they share a session, project, entity, or broad topic. +- Group only when they are part of the same continuing user goal, task, decision, problem, or concrete theme. +- If unsure, leave the memory unassigned. +- Batch/chunk units should stay together unless the unit content clearly contains unrelated material. +- Do not invent facts or IDs. +- The key should be concise and specific. + +Candidate memories: +{memories_block} + +Return strict JSON only: +{{ + "contexts": [ + {{ + "key": "short context label", + "ids": ["m1", "m2"], + "confidence": 0.0, + "reason": "brief reason" + }} + ], + "unassigned_ids": ["m3"] +}} +""" diff --git a/src/memos/dream/prompts/context_summary_prompt.py b/src/memos/dream/prompts/context_summary_prompt.py new file mode 100644 index 000000000..1943e3717 --- /dev/null +++ b/src/memos/dream/prompts/context_summary_prompt.py @@ -0,0 +1,27 @@ +CONTEXT_SUMMARY_PROMPT = """You are maintaining a Context memory for a long-term memory system. + +A Context is a compact index node. Its `key` is a short label, and its `memory` +is a faithful summary of the memories already bound to that context. + +Rules: +- Use only the provided memories and existing context text. +- Do not infer personality traits or hidden motives. +- Preserve concrete project names, people, objects, decisions, constraints, and unresolved questions. +- Prefer specificity over broad topics like "work" or "planning". +- The key should be concise: 8-15 Chinese characters or 3-8 English words. +- The memory summary should be compact but complete: 200-500 Chinese characters or 120-250 English words. + +Existing context: +Key: {existing_key} +Memory: {existing_memory} + +Bound memories: +{memories_block} + +Return strict JSON only: +{{ + "key": "short context label", + "memory": "faithful context summary", + "confidence": 0.0 +}} +""" diff --git a/src/memos/dream/routers/diary_router.py b/src/memos/dream/routers/diary_router.py index e4b301953..f82e16c17 100644 --- a/src/memos/dream/routers/diary_router.py +++ b/src/memos/dream/routers/diary_router.py @@ -55,7 +55,7 @@ def query_diaries(req: DiaryQueryRequest) -> dict[str, object]: return {"code": 200, "message": "Dream diary retrieved successfully", "data": items} filters = [{"field": "memory_type", "op": "=", "value": "DreamDiary"}] - ids = graph_db.get_by_metadata(filters, user_name=req.cube_id, status="activated") + ids = graph_db.get_by_metadata(filters, user_name=req.cube_id) if not ids: return {"code": 200, "message": "Dream diary retrieved successfully", "data": []} @@ -88,5 +88,6 @@ def _format_item(node: dict) -> dict: "summary": meta.get("summary", ""), "dream_entry": meta.get("dream_entry", ""), "motive": meta.get("motive"), + "context_events": meta.get("context_events", []), "themes": meta.get("themes", []), } diff --git a/src/memos/dream/search.py b/src/memos/dream/search.py new file mode 100644 index 000000000..7e33b0381 --- /dev/null +++ b/src/memos/dream/search.py @@ -0,0 +1,139 @@ +from __future__ import annotations + +import logging + +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any + +from memos.dream.contextualization import CONTEXT_MEMORY_TYPE + + +if TYPE_CHECKING: + from memos.api.product_models import APISearchRequest + + +logger = logging.getLogger(__name__) + +_DEFAULT_CONTEXT_RECALL_TOP_K = 2 +_CONTEXT_RETURN_FIELDS = [ + "memory", + "key", + "created_at", + "updated_at", + "source", + "internal_info", +] + + +@dataclass +class DreamContextSearchExtension: + """Dream-owned search extension for recalling Context nodes. + + The core SearchHandler only exposes a generic plugin hook. This extension + owns Dream-specific retrieval details such as the Context memory type, + graph scope, metadata formatting, and fallback behavior. + """ + + top_k: int = _DEFAULT_CONTEXT_RECALL_TOP_K + + def merge_context_recall( + self, + *, + handler, + search_req: APISearchRequest, + results: dict[str, Any], + ) -> dict[str, Any]: + top_k = max(0, int(self.top_k or 0)) + if top_k <= 0: + return results + + context_buckets = self._recall_context_buckets( + handler=handler, + search_req=search_req, + top_k=top_k, + ) + if context_buckets: + results.setdefault("text_mem", []).extend(context_buckets) + return results + + def _recall_context_buckets( + self, *, handler, search_req: APISearchRequest, top_k: int + ) -> list[dict[str, Any]]: + graph_db = getattr(handler, "graph_db", None) or getattr( + handler.searcher, "graph_store", None + ) + embedder = getattr(handler, "embedder", None) or getattr(handler.searcher, "embedder", None) + if graph_db is None or embedder is None: + logger.info("[Dream Search] Context recall skipped: graph_db or embedder unavailable.") + return [] + + try: + query_embedding = embedder.embed([search_req.query])[0] + except Exception: + logger.warning("[Dream Search] Context recall embedding failed.", exc_info=True) + return [] + + buckets: list[dict[str, Any]] = [] + for cube_id in _resolve_cube_ids(search_req): + try: + hits = graph_db.search_by_embedding( + query_embedding, + top_k=top_k, + scope=CONTEXT_MEMORY_TYPE, + status="activated", + user_name=cube_id, + return_fields=_CONTEXT_RETURN_FIELDS, + ) + except Exception: + logger.warning( + "[Dream Search] Context recall search failed for cube=%s.", + cube_id, + exc_info=True, + ) + continue + + memories = [_format_context_hit(hit) for hit in hits or [] if hit.get("memory")] + if not memories: + continue + buckets.append( + { + "cube_id": cube_id, + "memories": memories, + "total_nodes": len(memories), + } + ) + return buckets + + +def _resolve_cube_ids(search_req: APISearchRequest) -> list[str]: + if search_req.readable_cube_ids: + return list(dict.fromkeys(search_req.readable_cube_ids)) + return [search_req.user_id] + + +def _format_context_hit(hit: dict[str, Any]) -> dict[str, Any]: + context_id = str(hit.get("id", "")) + score = float(hit.get("score", 0.0) or 0.0) + metadata = { + "id": context_id, + "memory": hit.get("memory", ""), + "memory_type": CONTEXT_MEMORY_TYPE, + "source": hit.get("source") or "dream", + "key": hit.get("key", ""), + "relativity": score, + "score": score, + "embedding": [], + "sources": [], + "usage": [], + "ref_id": f"[{context_id.split('-')[0]}]" if context_id else "[context]", + } + for field in ("created_at", "updated_at", "internal_info"): + if hit.get(field) is not None: + metadata[field] = hit[field] + + return { + "id": context_id, + "memory": hit.get("memory", ""), + "metadata": metadata, + "ref_id": metadata["ref_id"], + } diff --git a/src/memos/dream/types.py b/src/memos/dream/types.py index 40587b70b..386323d53 100644 --- a/src/memos/dream/types.py +++ b/src/memos/dream/types.py @@ -128,6 +128,7 @@ class DreamDiaryEntry(BaseModel): summary: str dream_entry: str = "" motive: dict | None = None + context_events: list[dict[str, Any]] = Field(default_factory=list) themes: list[str] = Field(default_factory=list) created_at: datetime = Field(default_factory=datetime.utcnow) status: str = "completed" @@ -137,6 +138,13 @@ def format_content(self) -> str: parts = [self.title, self.summary] if self.dream_entry: parts.append(self.dream_entry) + if self.context_events: + context_lines = [] + for event in self.context_events: + label = event.get("label") or event.get("key") or event.get("context_id", "Context") + summary = event.get("summary", "") + context_lines.append(f"- {label}: {summary}".strip()) + parts.append("Context updates:\n" + "\n".join(context_lines)) return "\n\n".join(parts) diff --git a/src/memos/graph_dbs/polardb.py b/src/memos/graph_dbs/polardb.py index 856f94f2a..33a79aa75 100644 --- a/src/memos/graph_dbs/polardb.py +++ b/src/memos/graph_dbs/polardb.py @@ -1956,16 +1956,21 @@ def get_by_metadata( filter: dict | None = None, knowledgebase_ids: list | None = None, user_name_flag: bool = True, + status: str | None = None, ) -> list[str]: start_time = time.perf_counter() logger.info( - f" get_by_metadata user_name:{user_name},filter: {filter}, knowledgebase_ids: {knowledgebase_ids},filters:{filters}" + f" get_by_metadata user_name:{user_name},filter: {filter}, knowledgebase_ids: {knowledgebase_ids},filters:{filters},status:{status}" ) user_name = user_name if user_name else self._get_config_value("user_name") where_conditions = [] + if status: + escaped_status = status.replace("'", "\\'") + where_conditions.append(f"n.status = '{escaped_status}'") + for f in filters: field = f["field"] op = f.get("op", "=") diff --git a/src/memos/mem_reader/read_multi_modal/image_parser.py b/src/memos/mem_reader/read_multi_modal/image_parser.py index 43674a9cc..ec19ba304 100644 --- a/src/memos/mem_reader/read_multi_modal/image_parser.py +++ b/src/memos/mem_reader/read_multi_modal/image_parser.py @@ -101,10 +101,43 @@ def parse_fast( info: dict[str, Any], **kwargs, ) -> list[TextualMemoryItem]: - """Parse image_url in fast mode - returns empty list as images need fine mode processing.""" - # In fast mode, images are not processed (they need vision models) - # They will be processed in fine mode via process_transfer - return [] + """Parse image_url in fast mode by preserving the source for fine mode.""" + if not isinstance(message, dict): + logger.warning(f"[ImageParser] Expected dict, got {type(message)}") + return [] + + source = self.create_source(message, info) + url = getattr(source, "url", None) or getattr(source, "content", "") + if not url: + logger.warning("[ImageParser] No image URL found in fast mode message") + return [] + + info_ = info.copy() + user_id = info_.pop("user_id", "") + session_id = info_.pop("session_id", "") + content = f"[image_url]: {url}" + need_emb = kwargs.get("need_emb", True) + + return [ + TextualMemoryItem( + memory=content, + metadata=TreeNodeTextualMemoryMetadata( + user_id=user_id, + session_id=session_id, + memory_type="UserMemory", + status="activated", + tags=["mode:fast", "multimodal:image"], + key=_derive_key(content), + embedding=self.embedder.embed([content])[0] if need_emb else None, + usage=[], + sources=[source], + background="", + confidence=0.99, + type="fact", + info=info_, + ), + ) + ] def parse_fine( self, diff --git a/src/memos/mem_reader/read_skill_memory/process_skill_memory.py b/src/memos/mem_reader/read_skill_memory/process_skill_memory.py index 269372f25..dde6241b4 100644 --- a/src/memos/mem_reader/read_skill_memory/process_skill_memory.py +++ b/src/memos/mem_reader/read_skill_memory/process_skill_memory.py @@ -1022,6 +1022,10 @@ def process_skill_memory_fine( chat_history = [] messages = _reconstruct_messages_from_memory_items(fast_memory_items) + tool_rounds = sum(1 for message in messages if message.get("role") == "tool") + if tool_rounds < 5: + logger.info(f"[PROCESS_SKILLS] Skip skill extraction: tool rounds {tool_rounds} < 5") + return [] chat_history, messages = _preprocess_extract_messages(chat_history, messages) if not messages: diff --git a/src/memos/mem_scheduler/base_mixins/queue_ops.py b/src/memos/mem_scheduler/base_mixins/queue_ops.py index 13de79b3d..e8d215dc6 100644 --- a/src/memos/mem_scheduler/base_mixins/queue_ops.py +++ b/src/memos/mem_scheduler/base_mixins/queue_ops.py @@ -10,6 +10,7 @@ from memos.context.context import ( ContextThread, RequestContext, + get_current_api_path, get_current_context, get_current_trace_id, set_request_context, @@ -38,6 +39,7 @@ def submit_messages(self, messages: ScheduleMessageItem | list[ScheduleMessageIt return current_trace_id = get_current_trace_id() + current_api_path = get_current_api_path() immediate_msgs: list[ScheduleMessageItem] = [] queued_msgs: list[ScheduleMessageItem] = [] @@ -45,6 +47,8 @@ def submit_messages(self, messages: ScheduleMessageItem | list[ScheduleMessageIt for msg in messages: if current_trace_id: msg.trace_id = current_trace_id + if current_api_path and not getattr(msg, "api_path", None): + msg.api_path = current_api_path with suppress(Exception): self.metrics.task_enqueued(user_id=msg.user_id, task_type=msg.label) @@ -173,6 +177,7 @@ def _message_consumer(self) -> None: try: msg_context = RequestContext( trace_id=msg.trace_id, + api_path=msg.api_path, user_name=msg.user_name, ) set_request_context(msg_context) diff --git a/src/memos/mem_scheduler/base_mixins/web_log_ops.py b/src/memos/mem_scheduler/base_mixins/web_log_ops.py index 64b5348d3..7fba9916b 100644 --- a/src/memos/mem_scheduler/base_mixins/web_log_ops.py +++ b/src/memos/mem_scheduler/base_mixins/web_log_ops.py @@ -1,6 +1,7 @@ from __future__ import annotations from memos.log import get_logger +from memos.context.context import get_current_api_path from memos.mem_scheduler.schemas.message_schemas import ScheduleLogForWebItem from memos.mem_scheduler.schemas.task_schemas import ( ADD_TASK_LABEL, @@ -28,6 +29,9 @@ def _submit_web_logs( if self.rabbitmq_config is None: return try: + current_api_path = get_current_api_path() + if current_api_path and not getattr(message, "api_path", None): + message.api_path = current_api_path logger.info( "[DIAGNOSTIC] base_scheduler._submit_web_logs: enqueue publish %s", message.model_dump_json(indent=2), diff --git a/src/memos/mem_scheduler/general_modules/scheduler_logger.py b/src/memos/mem_scheduler/general_modules/scheduler_logger.py index fd83ec86f..495b7b9ea 100644 --- a/src/memos/mem_scheduler/general_modules/scheduler_logger.py +++ b/src/memos/mem_scheduler/general_modules/scheduler_logger.py @@ -3,6 +3,7 @@ from collections.abc import Callable from memos.log import get_logger +from memos.context.context import get_current_api_path from memos.mem_cube.general import GeneralMemCube from memos.mem_scheduler.general_modules.base import BaseSchedulerModule from memos.mem_scheduler.schemas.general_schemas import ( @@ -125,6 +126,7 @@ def create_autofilled_log_item( log_content=log_content, current_memory_sizes=current_memory_sizes, memory_capacities=memory_capacities, + api_path=get_current_api_path(), ) return log_message diff --git a/src/memos/mem_scheduler/schemas/message_schemas.py b/src/memos/mem_scheduler/schemas/message_schemas.py index c11d30470..29e15198f 100644 --- a/src/memos/mem_scheduler/schemas/message_schemas.py +++ b/src/memos/mem_scheduler/schemas/message_schemas.py @@ -53,6 +53,7 @@ class ScheduleMessageItem(BaseModel, DictConversionMixin): description="user name / display name (optional)", ) info: dict | None = Field(default=None, description="user custom info") + api_path: str | None = Field(default=None, description="source HTTP API path") task_id: str | None = Field( default=None, description="Optional business-level task ID. Multiple items can share the same task_id.", @@ -88,12 +89,15 @@ def to_dict(self) -> dict: "user_id": self.user_id, "cube_id": self.mem_cube_id, "trace_id": self.trace_id, + "session_id": self.session_id, "label": self.label, "cube": "Not Applicable", # Custom cube serialization "content": self.content, "timestamp": self.timestamp.isoformat(), "user_name": self.user_name, + "info": self.info if self.info is not None else {}, "task_id": self.task_id if self.task_id is not None else "", + "api_path": self.api_path if self.api_path is not None else "", "chat_history": self.chat_history if self.chat_history is not None else [], "user_context": self.user_context.model_dump(exclude_none=True) if self.user_context @@ -130,6 +134,18 @@ def _decode(val: Any) -> Any: else: chat_history = raw_chat_history + raw_info = _decode(data.get("info")) + if isinstance(raw_info, str): + if raw_info: + try: + info = json.loads(raw_info) + except Exception: + info = None + else: + info = None + else: + info = raw_info + raw_user_context = _decode(data.get("user_context")) if isinstance(raw_user_context, str): if raw_user_context: @@ -147,11 +163,14 @@ def _decode(val: Any) -> Any: user_id=_decode(data["user_id"]), mem_cube_id=_decode(data["cube_id"]), trace_id=_decode(data.get("trace_id", generate_trace_id())), + session_id=_decode(data.get("session_id", "")), label=_decode(data["label"]), content=_decode(data["content"]), timestamp=timestamp, user_name=_decode(data.get("user_name")), - task_id=_decode(data.get("task_id")), + info=info, + task_id=_decode(data.get("task_id")) or None, + api_path=_decode(data.get("api_path")), chat_history=chat_history, user_context=UserContext.model_validate(raw_user_context) if raw_user_context else None, ) @@ -209,6 +228,7 @@ class ScheduleLogForWebItem(BaseModel, DictConversionMixin): ) source_doc_id: str | None = Field(default=None, description="Source document ID") chat_history: list | None = Field(default=None, description="user chat history") + api_path: str | None = Field(default=None, description="source HTTP API path") def debug_info(self) -> dict[str, Any]: """Return structured debug information for logging purposes.""" diff --git a/src/memos/mem_scheduler/task_schedule_modules/dispatcher.py b/src/memos/mem_scheduler/task_schedule_modules/dispatcher.py index 02cd59e8c..60dc04fcd 100644 --- a/src/memos/mem_scheduler/task_schedule_modules/dispatcher.py +++ b/src/memos/mem_scheduler/task_schedule_modules/dispatcher.py @@ -24,7 +24,7 @@ from memos.mem_scheduler.task_schedule_modules.orchestrator import SchedulerOrchestrator from memos.mem_scheduler.task_schedule_modules.redis_queue import SchedulerRedisQueue from memos.mem_scheduler.task_schedule_modules.task_queue import ScheduleTaskQueue -from memos.mem_scheduler.utils.misc_utils import group_messages_by_user_and_mem_cube, is_cloud_env +from memos.mem_scheduler.utils.misc_utils import group_messages_by_user_and_mem_cube, is_playground_api from memos.mem_scheduler.utils.monitor_event_utils import emit_monitor_event, to_iso from memos.mem_scheduler.utils.status_tracker import TaskStatusTracker @@ -140,6 +140,7 @@ def wrapped_handler(messages: list[ScheduleMessageItem]): # Propagate trace_id and user info to logging context for this handler execution ctx = RequestContext( trace_id=trace_id, + api_path=getattr(first_msg, "api_path", None), user_name=getattr(first_msg, "user_name", None), user_type=None, ) @@ -317,8 +318,7 @@ def _maybe_emit_task_completion( mem_cube_id = first.mem_cube_id try: - cloud_env = is_cloud_env() - if not cloud_env: + if is_playground_api(): return for task_id in task_ids: @@ -345,6 +345,7 @@ def _maybe_emit_task_completion( log_content=f"Task {task_id} completed", status="completed", source_doc_id=source_doc_id, + api_path=getattr(messages[0], "api_path", None) if messages else None, ) self.submit_web_logs(event) @@ -369,6 +370,7 @@ def _maybe_emit_task_completion( log_content=f"Task {task_id} failed: {error_msg}", status="failed", source_doc_id=source_doc_id, + api_path=getattr(messages[0], "api_path", None) if messages else None, ) self.submit_web_logs(event) except Exception: diff --git a/src/memos/mem_scheduler/task_schedule_modules/handlers/add_handler.py b/src/memos/mem_scheduler/task_schedule_modules/handlers/add_handler.py index e4a88a635..81e78d69f 100644 --- a/src/memos/mem_scheduler/task_schedule_modules/handlers/add_handler.py +++ b/src/memos/mem_scheduler/task_schedule_modules/handlers/add_handler.py @@ -12,7 +12,7 @@ ) from memos.mem_scheduler.task_schedule_modules.base_handler import BaseSchedulerHandler from memos.mem_scheduler.utils.filter_utils import transform_name_to_key -from memos.mem_scheduler.utils.misc_utils import is_cloud_env +from memos.mem_scheduler.utils.misc_utils import is_playground_api if TYPE_CHECKING: @@ -38,14 +38,14 @@ def batch_handler( prepared_add_items, prepared_update_items_with_original, ) - cloud_env = is_cloud_env() + playground_api = is_playground_api() - if cloud_env: - self.send_add_log_messages_to_cloud_env( + if playground_api: + self.send_add_log_messages_to_local_env( msg, prepared_add_items, prepared_update_items_with_original ) else: - self.send_add_log_messages_to_local_env( + self.send_add_log_messages_to_memory_change( msg, prepared_add_items, prepared_update_items_with_original ) @@ -231,10 +231,10 @@ def send_add_log_messages_to_local_env( logger.info("send_add_log_messages_to_local_env: %s", len(events)) if events: self.scheduler_context.services.submit_web_logs( - events, additional_log_info="send_add_log_messages_to_cloud_env" + events, additional_log_info="send_add_log_messages_to_local_env" ) - def send_add_log_messages_to_cloud_env( + def send_add_log_messages_to_memory_change( self, msg: ScheduleMessageItem, prepared_add_items, @@ -278,7 +278,7 @@ def send_add_log_messages_to_cloud_env( if kb_log_content: logger.info( - "[DIAGNOSTIC] add_handler.send_add_log_messages_to_cloud_env: Creating event log for KB update. Label: knowledgeBaseUpdate, user_id: %s, mem_cube_id: %s, task_id: %s. KB content: %s", + "[DIAGNOSTIC] add_handler.send_add_log_messages_to_memory_change: Creating event log for KB update. Label: knowledgeBaseUpdate, user_id: %s, mem_cube_id: %s, task_id: %s. KB content: %s", msg.user_id, msg.mem_cube_id, msg.task_id, diff --git a/src/memos/mem_scheduler/task_schedule_modules/handlers/feedback_handler.py b/src/memos/mem_scheduler/task_schedule_modules/handlers/feedback_handler.py index 173d37b50..445a50d8e 100644 --- a/src/memos/mem_scheduler/task_schedule_modules/handlers/feedback_handler.py +++ b/src/memos/mem_scheduler/task_schedule_modules/handlers/feedback_handler.py @@ -11,7 +11,7 @@ USER_INPUT_TYPE, ) from memos.mem_scheduler.task_schedule_modules.base_handler import BaseSchedulerHandler -from memos.mem_scheduler.utils.misc_utils import is_cloud_env +from memos.mem_scheduler.utils.misc_utils import is_playground_api logger = get_logger(__name__) @@ -75,8 +75,8 @@ def process_single_feedback(self, message: ScheduleMessageItem) -> None: mem_cube_id, ) - cloud_env = is_cloud_env() - if cloud_env: + playground_api = is_playground_api() + if not playground_api: record = feedback_result.get("record") if isinstance(feedback_result, dict) else {} add_records = record.get("add") if isinstance(record, dict) else [] update_records = record.get("update") if isinstance(record, dict) else [] @@ -191,6 +191,7 @@ def _extract_fields(mem_item): ) else: logger.info( - "Skipping web log for feedback. Not in a cloud environment (is_cloud_env=%s)", - cloud_env, + "Skipping memory-change web log for feedback on playground API " + "(is_playground_api=%s)", + playground_api, ) diff --git a/src/memos/mem_scheduler/task_schedule_modules/handlers/mem_read_handler.py b/src/memos/mem_scheduler/task_schedule_modules/handlers/mem_read_handler.py index 36cc97bdf..90ace633c 100644 --- a/src/memos/mem_scheduler/task_schedule_modules/handlers/mem_read_handler.py +++ b/src/memos/mem_scheduler/task_schedule_modules/handlers/mem_read_handler.py @@ -19,7 +19,7 @@ ) from memos.mem_scheduler.task_schedule_modules.base_handler import BaseSchedulerHandler from memos.mem_scheduler.utils.filter_utils import transform_name_to_key -from memos.mem_scheduler.utils.misc_utils import is_cloud_env +from memos.mem_scheduler.utils.misc_utils import is_playground_api from memos.memories.textual.tree import TreeTextMemory @@ -268,8 +268,8 @@ def _process_memories_with_reader( "[Scheduler] merged_from provided but graph_db is unavailable; skip archiving." ) - cloud_env = is_cloud_env() - if cloud_env: + playground_api = is_playground_api() + if not playground_api: kb_log_content = [] for item in flattened_memories: metadata = getattr(item, "metadata", None) @@ -448,8 +448,8 @@ def _process_memories_with_reader( exc_info=True, ) with contextlib.suppress(Exception): - cloud_env = is_cloud_env() - if cloud_env: + playground_api = is_playground_api() + if not playground_api: if not kb_log_content: trigger_source = ( info.get("trigger_source", "Messages") if info else "Messages" diff --git a/src/memos/mem_scheduler/task_schedule_modules/redis_queue.py b/src/memos/mem_scheduler/task_schedule_modules/redis_queue.py index 561d7931f..79f40def4 100644 --- a/src/memos/mem_scheduler/task_schedule_modules/redis_queue.py +++ b/src/memos/mem_scheduler/task_schedule_modules/redis_queue.py @@ -124,7 +124,9 @@ def __init__( self.seen_streams = set() # Task Orchestrator — cap in-memory cache to avoid unbounded growth - self._cache_max_packs = int(os.getenv("MEMSCHEDULER_REDIS_CACHE_MAX_PACKS", "50") or 50) + self._cache_max_packs = max( + 1, int(os.getenv("MEMSCHEDULER_REDIS_CACHE_MAX_PACKS", "50") or 50) + ) self.message_pack_cache: deque[list[ScheduleMessageItem]] = deque( maxlen=self._cache_max_packs ) @@ -138,6 +140,7 @@ def __init__( self._stream_keys_lock = threading.Lock() self._stream_keys_refresh_thread: ContextThread | None = None self._stream_keys_refresh_stop_event = threading.Event() + self._stream_read_offset = 0 self._initial_scan_max_keys = int( os.getenv("MEMSCHEDULER_REDIS_INITIAL_SCAN_MAX_KEYS", "1000") or 1000 ) @@ -315,6 +318,20 @@ def task_broker( stream_keys = self.get_stream_keys(stream_key_prefix=self.stream_key_prefix) if not stream_keys: return [] + stream_key_count = len(stream_keys) + if stream_key_count > self._cache_max_packs: + start = self._stream_read_offset % stream_key_count + end = start + self._cache_max_packs + if end <= stream_key_count: + stream_keys = stream_keys[start:end] + else: + stream_keys = stream_keys[start:] + stream_keys[: end % stream_key_count] + self._stream_read_offset = (start + self._cache_max_packs) % stream_key_count + logger.debug( + "[REDIS_QUEUE] Broker stream scan capped. scanned_streams=%s cache_max_packs=%s", + len(stream_keys), + self._cache_max_packs, + ) # Determine per-stream quotas for this cycle stream_quotas = self.orchestrator.get_stream_quotas( @@ -353,6 +370,28 @@ def task_broker( if claimed_messages: messages.extend(claimed_messages) + max_cached_messages = max(consume_batch_size, consume_batch_size * self._cache_max_packs) + limited_messages: list[tuple[str, list[tuple[str, dict]]]] = [] + remaining = max_cached_messages + for stream_key, stream_messages in messages: + if remaining <= 0: + break + if len(stream_messages) <= remaining: + limited_messages.append((stream_key, stream_messages)) + remaining -= len(stream_messages) + else: + limited_messages.append((stream_key, stream_messages[:remaining])) + remaining = 0 + if remaining == 0 and len(limited_messages) < len(messages): + logger.debug( + "[REDIS_QUEUE] Broker prefetch capped. streams=%s capped_messages=%s cache_max_packs=%s consume_batch=%s", + len(messages), + max_cached_messages, + self._cache_max_packs, + consume_batch_size, + ) + messages = limited_messages + cache: list[ScheduleMessageItem] = self._convert_messages(messages) # pack messages @@ -400,9 +439,9 @@ def _is_refill_thread_available(self) -> bool: return True if (time.time() - self._refill_thread_start) > self._refill_thread_timeout: logger.warning( - f"Refill thread has been running for >{self._refill_thread_timeout}s, treating as stale" + f"Refill thread has been running for >{self._refill_thread_timeout}s; " + "skip starting another refill thread to avoid duplicate memory growth" ) - return True return False def get_messages(self, batch_size: int) -> list[ScheduleMessageItem]: diff --git a/src/memos/mem_scheduler/task_schedule_modules/task_queue.py b/src/memos/mem_scheduler/task_schedule_modules/task_queue.py index 6bcf0023c..17a787895 100644 --- a/src/memos/mem_scheduler/task_schedule_modules/task_queue.py +++ b/src/memos/mem_scheduler/task_schedule_modules/task_queue.py @@ -5,7 +5,7 @@ the local memos_message_queue functionality in BaseScheduler. """ -from memos.context.context import get_current_trace_id +from memos.context.context import get_current_api_path, get_current_trace_id from memos.log import get_logger from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem from memos.mem_scheduler.task_schedule_modules.local_queue import SchedulerLocalQueue @@ -104,11 +104,14 @@ def submit_messages(self, messages: ScheduleMessageItem | list[ScheduleMessageIt return current_trace_id = get_current_trace_id() + current_api_path = get_current_api_path() for msg in messages: if current_trace_id: # Prefer current request trace_id so logs can be correlated msg.trace_id = current_trace_id + if current_api_path and not getattr(msg, "api_path", None): + msg.api_path = current_api_path msg.stream_key = self.memos_message_queue.get_stream_key( user_id=msg.user_id, mem_cube_id=msg.mem_cube_id, task_label=msg.label ) diff --git a/src/memos/mem_scheduler/utils/misc_utils.py b/src/memos/mem_scheduler/utils/misc_utils.py index 3ce727b5c..e0ccc238d 100644 --- a/src/memos/mem_scheduler/utils/misc_utils.py +++ b/src/memos/mem_scheduler/utils/misc_utils.py @@ -1,14 +1,13 @@ import json -import os import re import traceback - from collections import defaultdict from functools import wraps from pathlib import Path import yaml +from memos.context.context import get_current_api_path from memos.log import get_logger from memos.mem_scheduler.schemas.message_schemas import ( ScheduleMessageItem, @@ -17,39 +16,15 @@ logger = get_logger(__name__) - -def _normalize_env_value(value: str | None) -> str: - """Normalize environment variable values for comparison.""" - return value.strip().lower() if isinstance(value, str) else "" - - -def is_playground_env() -> bool: - """Return True when ENV_NAME indicates a Playground environment.""" - env_name = _normalize_env_value(os.getenv("ENV_NAME")) - return env_name.startswith("playground") +PLAYGROUND_CHAT_STREAM_PATH = "/product/chat/stream/playground" -def is_cloud_env() -> bool: +def is_playground_api() -> bool: """ - Determine whether the scheduler should treat the runtime as a cloud environment. - - Rules: - - Any Playground ENV_NAME is explicitly NOT cloud. - - MEMSCHEDULER_RABBITMQ_EXCHANGE_NAME must be set to enable cloud behavior. - - The default memos-fanout/fanout combination is treated as non-cloud. + Determine whether the scheduler should use old playground behavior. """ - if is_playground_env(): - return False - - exchange_name = _normalize_env_value(os.getenv("MEMSCHEDULER_RABBITMQ_EXCHANGE_NAME")) - exchange_type = _normalize_env_value(os.getenv("MEMSCHEDULER_RABBITMQ_EXCHANGE_TYPE")) - - if not exchange_name: - return False - - return not ( - exchange_name == "memos-fanout" and (not exchange_type or exchange_type == "fanout") - ) + api_path = get_current_api_path() + return api_path == PLAYGROUND_CHAT_STREAM_PATH def extract_json_obj(text: str): diff --git a/src/memos/mem_scheduler/webservice_modules/rabbitmq_service.py b/src/memos/mem_scheduler/webservice_modules/rabbitmq_service.py index a07934b8e..3123bfdee 100644 --- a/src/memos/mem_scheduler/webservice_modules/rabbitmq_service.py +++ b/src/memos/mem_scheduler/webservice_modules/rabbitmq_service.py @@ -18,6 +18,8 @@ logger = get_logger(__name__) +PLAYGROUND_CHAT_STREAM_PATH = "/product/chat/stream/playground" + class RabbitMQSchedulerModule(BaseSchedulerModule): @require_python_package( @@ -37,6 +39,8 @@ def __init__(self): self.rabbit_queue_name = "memos-scheduler" self.rabbitmq_exchange_name = "memos-fanout" # Default, will be overridden by config self.rabbitmq_exchange_type = FANOUT_EXCHANGE_TYPE # Default, will be overridden by config + self.rabbitmq_playground_chat_exchange_name: str | None = None + self.rabbitmq_playground_chat_exchange_type = FANOUT_EXCHANGE_TYPE self.rabbitmq_connection = None self.rabbitmq_channel = None @@ -154,6 +158,27 @@ def initialize_rabbitmq( self.rabbitmq_exchange_type = env_exchange_type logger.info(f"Using env exchange type override: {self.rabbitmq_exchange_type}") + playground_exchange_name = os.getenv( + "MEMSCHEDULER_RABBITMQ_PLAYGROUND_CHAT_EXCHANGE_NAME", "" + ).strip() + playground_exchange_type = os.getenv( + "MEMSCHEDULER_RABBITMQ_PLAYGROUND_CHAT_EXCHANGE_TYPE", + FANOUT_EXCHANGE_TYPE, + ).strip() + if playground_exchange_name: + self.rabbitmq_playground_chat_exchange_name = playground_exchange_name + self.rabbitmq_playground_chat_exchange_type = ( + playground_exchange_type or FANOUT_EXCHANGE_TYPE + ) + logger.info( + "Using playground chat exchange override: name=%s, type=%s", + self.rabbitmq_playground_chat_exchange_name, + self.rabbitmq_playground_chat_exchange_type, + ) + else: + self.rabbitmq_playground_chat_exchange_name = None + self.rabbitmq_playground_chat_exchange_type = FANOUT_EXCHANGE_TYPE + # Start connection process parameters = self.get_rabbitmq_connection_param() self.rabbitmq_connection = SelectConnection( @@ -260,16 +285,33 @@ def on_rabbitmq_channel_open(self, channel): self.rabbitmq_channel = channel logger.info("[DIAGNOSTIC] RabbitMQ channel opened") - # Setup exchange and queue + # Setup primary/direct exchange and optional playground/fanout exchange. channel.exchange_declare( exchange=self.rabbitmq_exchange_name, exchange_type=self.rabbitmq_exchange_type, durable=True, - callback=self.on_rabbitmq_exchange_declared, + callback=self.on_rabbitmq_primary_exchange_declared, ) - def on_rabbitmq_exchange_declared(self, frame): - """Called when exchange is ready.""" + def on_rabbitmq_primary_exchange_declared(self, frame): + """Called when primary exchange is ready.""" + if self.rabbitmq_playground_chat_exchange_name: + self.rabbitmq_channel.exchange_declare( + exchange=self.rabbitmq_playground_chat_exchange_name, + exchange_type=self.rabbitmq_playground_chat_exchange_type, + durable=True, + callback=self.on_rabbitmq_playground_exchange_declared, + ) + return + + self._rabbitmq_continue_queue_setup() + + def on_rabbitmq_playground_exchange_declared(self, frame): + """Called when optional playground exchange is ready.""" + self._rabbitmq_continue_queue_setup() + + def _rabbitmq_continue_queue_setup(self): + """Declare scheduler queue and bind it to the primary exchange.""" self.rabbitmq_channel.queue_declare( queue=self.rabbit_queue_name, durable=True, callback=self.on_rabbitmq_queue_declared ) @@ -289,6 +331,13 @@ def on_rabbitmq_bind_ok(self, frame): # Flush any cached publish messages now that connection is ready self._flush_cached_publish_messages() + def resolve_publish_route(self, message: dict) -> tuple[str, str]: + api_path = message.get("api_path") + if api_path == PLAYGROUND_CHAT_STREAM_PATH and self.rabbitmq_playground_chat_exchange_name: + return self.rabbitmq_playground_chat_exchange_name, "" + + return self.rabbitmq_exchange_name, "" + def on_rabbitmq_message(self, channel, method, properties, body): """Handle incoming messages. Only for test.""" try: @@ -327,34 +376,17 @@ def rabbitmq_publish_message(self, message: dict): """ import pika - exchange_name = self.rabbitmq_exchange_name - routing_key = self.rabbit_queue_name + exchange_name, routing_key = self.resolve_publish_route(message) label = message.get("label") - # Special handling for knowledgeBaseUpdate in local environment: always empty routing key - if label == "knowledgeBaseUpdate": - routing_key = "" - - # Env override: apply to all message types when MEMSCHEDULER_RABBITMQ_EXCHANGE_NAME is set - env_exchange_name = os.getenv("MEMSCHEDULER_RABBITMQ_EXCHANGE_NAME") - env_routing_key = os.getenv("MEMSCHEDULER_RABBITMQ_ROUTING_KEY") - if env_exchange_name: - exchange_name = env_exchange_name - routing_key = ( - env_routing_key if env_routing_key is not None and env_routing_key != "" else "" - ) - logger.info( - f"[DIAGNOSTIC] Publishing {label} message with env exchange override. " - f"Exchange: {exchange_name}, Routing Key: '{routing_key}'." - ) - logger.info(f" - Message Content: {json.dumps(message, indent=2, ensure_ascii=False)}") - elif label == "knowledgeBaseUpdate": - # Original diagnostic logging for knowledgeBaseUpdate if NOT in cloud env - logger.info( - f"[DIAGNOSTIC] Publishing knowledgeBaseUpdate message (Local Env). " - f"Current configured Exchange: {exchange_name}, Routing Key: '{routing_key}'." - ) - logger.info(f" - Message Content: {json.dumps(message, indent=2, ensure_ascii=False)}") + logger.info( + "[DIAGNOSTIC] Publishing %s message. api_path=%s Exchange: %s, Routing Key: '%s'.", + label, + message.get("api_path"), + exchange_name, + routing_key, + ) + logger.info(f" - Message Content: {json.dumps(message, indent=2, ensure_ascii=False)}") with self._rabbitmq_lock: logger.info( diff --git a/src/memos/memories/textual/item.py b/src/memos/memories/textual/item.py index b7004c84c..23474cc20 100644 --- a/src/memos/memories/textual/item.py +++ b/src/memos/memories/textual/item.py @@ -185,6 +185,7 @@ class TreeNodeTextualMemoryMetadata(TextualMemoryMetadata): "RawFileMemory", "SkillMemory", "PreferenceMemory", + "Context", ] = Field(default="WorkingMemory", description="Memory lifecycle type.") sources: list[SourceMessage] | None = Field( default=None, description="Multiple origins of the memory (e.g., URLs, notes)." diff --git a/src/memos/plugins/hook_defs.py b/src/memos/plugins/hook_defs.py index 5ec73cc86..3650e60c0 100644 --- a/src/memos/plugins/hook_defs.py +++ b/src/memos/plugins/hook_defs.py @@ -72,6 +72,9 @@ class H: SEARCH_BEFORE = "search.before" SEARCH_AFTER = "search.after" + # Search extension point before core threshold/dedup/rerank processing. + SEARCH_MEMORY_RESULTS = "search.memory_results" + # Custom Hook (manually triggered via trigger_hook) ADD_MEMORIES_POST_PROCESS = "add.memories.post_process" @@ -106,6 +109,16 @@ class H: pipe_key="prompt", ) +define_hook( + H.SEARCH_MEMORY_RESULTS, + description=( + "Allow plugins to merge additional search result buckets before core " + "threshold, deduplication, and reranking." + ), + params=["handler", "search_req", "results"], + pipe_key="results", +) + define_hook( H.MEMORY_ITEMS_AFTER_FINE_EXTRACT, description="Post-process memory items after mem_reader fine extraction completes", diff --git a/src/memos/plugins/manager.py b/src/memos/plugins/manager.py index fb473e0d9..a4478a22d 100644 --- a/src/memos/plugins/manager.py +++ b/src/memos/plugins/manager.py @@ -4,6 +4,7 @@ import importlib.metadata import logging +import os from typing import TYPE_CHECKING @@ -29,6 +30,17 @@ def __init__(self): def plugins(self) -> dict[str, MemOSPlugin]: return dict(self._plugins) + @staticmethod + def _parse_plugin_names(value: str | None) -> set[str]: + if not value: + return set() + return {item.strip() for item in value.split(",") if item.strip()} + + @classmethod + def _is_plugin_enabled(cls, plugin: MemOSPlugin) -> bool: + disabled = cls._parse_plugin_names(os.getenv("MEMOS_DISABLED_PLUGINS")) + return plugin.name not in disabled + @staticmethod def _select_plugin_winners( candidates: list[tuple[str, MemOSPlugin]], @@ -107,6 +119,13 @@ def discover(self) -> None: winners = self._select_plugin_winners(candidates) for plugin_name, plugin in winners.items(): + if not self._is_plugin_enabled(plugin): + logger.info( + "Plugin discovered but disabled: %s v%s (MEMOS_DISABLED_PLUGINS)", + plugin.name, + plugin.version, + ) + continue plugin.on_load() self._plugins[plugin_name] = plugin logger.info( diff --git a/tests/dream/test_context_recall.py b/tests/dream/test_context_recall.py new file mode 100644 index 000000000..1aafec833 --- /dev/null +++ b/tests/dream/test_context_recall.py @@ -0,0 +1,174 @@ +from __future__ import annotations + +from memos.api.handlers.base_handler import HandlerDependencies +from memos.api.handlers.search_handler import SearchHandler +from memos.api.product_models import APISearchRequest +from memos.dream.search import DreamContextSearchExtension +from memos.plugins.hook_defs import H +from memos.plugins.hooks import _hooks, register_hook + + +class FakeEmbedder: + def __init__(self): + self.calls: list[list[str]] = [] + + def embed(self, texts): + self.calls.append(texts) + return [[0.1, 0.2, 0.3] for _ in texts] + + +class FailingEmbedder: + def embed(self, texts): + raise RuntimeError("embed failed") + + +class FakeGraphDB: + def __init__(self, hits=None): + self.hits = hits or [] + self.calls: list[dict] = [] + + def search_by_embedding(self, vector, **kwargs): + self.calls.append({"vector": vector, **kwargs}) + return self.hits + + +class FakeCubeView: + def __init__(self, results=None): + self.results = results or _empty_results() + + def search_memories(self, search_req): + return self.results + + +def _empty_results(): + return { + "text_mem": [], + "act_mem": [], + "para_mem": [], + "pref_mem": [], + "pref_note": "", + "tool_mem": [], + "skill_mem": [], + } + + +def _handler(*, graph_db=None, embedder=None, cube_view=None) -> SearchHandler: + embedder = embedder or FakeEmbedder() + handler = SearchHandler( + HandlerDependencies( + naive_mem_cube=object(), + mem_scheduler=object(), + searcher=type("FakeSearcher", (), {"embedder": embedder})(), + deepsearch_agent=object(), + graph_db=graph_db, + embedder=embedder, + ) + ) + handler._build_cube_view = lambda _search_req: cube_view or FakeCubeView() + return handler + + +def _search_req(): + return APISearchRequest( + query="what is the user designing?", + user_id="user-a", + readable_cube_ids=["cube-a"], + top_k=5, + relativity=0, + dedup="no", + ) + + +def setup_function(): + _hooks.clear() + + +def test_context_recall_disabled_without_dream_search_hook(): + graph = FakeGraphDB( + hits=[ + { + "id": "ctx_1", + "memory": "Context summary", + "score": 0.9, + } + ] + ) + handler = _handler(graph_db=graph) + + response = handler.handle_search_memories(_search_req()) + + assert graph.calls == [] + assert response.data["text_mem"] == [] + + +def test_context_recall_searches_context_scope_and_returns_summary(): + register_hook( + H.SEARCH_MEMORY_RESULTS, + DreamContextSearchExtension(top_k=1).merge_context_recall, + ) + graph = FakeGraphDB( + hits=[ + { + "id": "ctx_1", + "memory": "The user is designing Dream Context recall.", + "score": 0.93, + "key": "Dream context recall", + "source": "system", + "internal_info": {"dream": {"memory_ids": ["m1", "m2"]}}, + } + ] + ) + handler = _handler(graph_db=graph) + + response = handler.handle_search_memories(_search_req()) + + assert graph.calls + assert graph.calls[0]["scope"] == "Context" + assert graph.calls[0]["status"] == "activated" + assert graph.calls[0]["top_k"] == 1 + assert graph.calls[0]["user_name"] == "cube-a" + assert graph.calls[0]["return_fields"] == [ + "memory", + "key", + "created_at", + "updated_at", + "source", + "internal_info", + ] + + text_mem = response.data["text_mem"] + assert len(text_mem) == 1 + memories = text_mem[0]["memories"] + assert len(memories) == 1 + assert memories[0]["id"] == "ctx_1" + assert memories[0]["memory"] == "The user is designing Dream Context recall." + assert memories[0]["metadata"]["memory_type"] == "Context" + assert memories[0]["metadata"]["key"] == "Dream context recall" + assert memories[0]["metadata"]["relativity"] == 0.93 + assert memories[0]["metadata"]["internal_info"] == {"dream": {"memory_ids": ["m1", "m2"]}} + + +def test_context_recall_gracefully_skips_without_graph_db(): + register_hook( + H.SEARCH_MEMORY_RESULTS, + DreamContextSearchExtension(top_k=1).merge_context_recall, + ) + handler = _handler(graph_db=None) + + response = handler.handle_search_memories(_search_req()) + + assert response.data["text_mem"] == [] + + +def test_context_recall_gracefully_skips_on_embedding_failure(): + register_hook( + H.SEARCH_MEMORY_RESULTS, + DreamContextSearchExtension(top_k=1).merge_context_recall, + ) + graph = FakeGraphDB() + handler = _handler(graph_db=graph, embedder=FailingEmbedder()) + + response = handler.handle_search_memories(_search_req()) + + assert graph.calls == [] + assert response.data["text_mem"] == [] diff --git a/tests/dream/test_contextualization.py b/tests/dream/test_contextualization.py new file mode 100644 index 000000000..78e0b2d0d --- /dev/null +++ b/tests/dream/test_contextualization.py @@ -0,0 +1,306 @@ +from __future__ import annotations + +import json + +from memos.dream.contextualization import CONTEXT_MEMORY_TYPE, DreamContextualizer +from memos.dream.types import DreamSignalSnapshot +from memos.memories.textual.item import TreeNodeTextualMemoryMetadata + + +class FakeGraphDB: + def __init__(self, nodes: list[dict] | None = None): + self.nodes = {node["id"]: node for node in nodes or []} + self.added: list[tuple[str, str, dict, str | None]] = [] + self.updated: list[tuple[str, dict, str | None]] = [] + + def get_nodes(self, ids, include_embedding=False, user_name=None): + return [self.nodes[node_id] for node_id in ids if node_id in self.nodes] + + def get_by_metadata(self, filters, user_name=None, status=None): + matched = [] + for node_id, node in self.nodes.items(): + metadata = node.get("metadata", {}) + if status and metadata.get("status") != status: + continue + ok = True + for item in filters: + if item.get("op") == "=" and metadata.get(item.get("field")) != item.get("value"): + ok = False + break + if ok: + matched.append(node_id) + return matched + + def add_node(self, id, memory, metadata, user_name=None): + self.added.append((id, memory, metadata, user_name)) + self.nodes[id] = {"id": id, "memory": memory, "metadata": metadata} + + def update_node(self, id, fields, user_name=None): + self.updated.append((id, fields, user_name)) + node = self.nodes[id] + memory = fields.pop("memory", node["memory"]) + node["memory"] = memory + node["metadata"].update(fields) + + +class FakeEmbedder: + def embed(self, texts): + return [[float(len(text)), 0.1] for text in texts] + + +class FakeLLM: + def generate(self, messages): + assert "Return strict JSON only" in messages[0]["content"] + return json.dumps( + { + "key": "MemOS Dream Context", + "memory": "The user is designing the MemOS Dream Context pipeline.", + "confidence": 0.91, + } + ) + + +class FakeBindingAndSummaryLLM: + def __init__(self): + self.calls: list[str] = [] + + def generate(self, messages): + prompt = messages[0]["content"] + self.calls.append(prompt) + if '"unassigned_ids"' in prompt: + assert "ID: m1" in prompt + assert "ID: m2" in prompt + assert "ID: m3" in prompt + assert "real_ids=[" in prompt + return json.dumps( + { + "contexts": [ + { + "key": "Dream Enricher Design", + "ids": ["m1", "m3"], + "confidence": 0.88, + "reason": "same implementation thread", + } + ], + "unassigned_ids": ["m2"], + } + ) + return json.dumps( + { + "key": "Summary Key", + "memory": "Summary text", + "confidence": 0.8, + } + ) + + +def _memory_node(node_id: str, weak_context_id: str | None = "project:memos") -> dict: + dream = {"weak_context_id": weak_context_id} if weak_context_id else {} + return { + "id": node_id, + "memory": f"Memory {node_id} about Dream Context.", + "metadata": { + "memory_type": "LongTermMemory", + "status": "activated", + "created_at": f"2026-05-18T00:00:0{node_id[-1]}", + "internal_info": {"dream": dream}, + }, + } + + +def _context_node() -> dict: + return { + "id": "ctx_existing", + "memory": "Old summary", + "metadata": { + "memory_type": CONTEXT_MEMORY_TYPE, + "status": "activated", + "key": "Old Context", + "created_at": "2026-05-17T00:00:00", + "internal_info": json.dumps( + { + "dream": { + "kind": "context", + "memory_ids": ["m0"], + "weak_context_ids": ["project:memos"], + } + } + ), + }, + } + + +def test_contextualizer_skips_project_pool_when_binding_llm_unavailable(): + graph = FakeGraphDB(nodes=[_memory_node("m1"), _memory_node("m2")]) + contextualizer = DreamContextualizer( + enabled=True, + binding_llm_enabled=False, + summary_llm_enabled=True, + ) + contextualizer.bind_context( + {"shared": {"graph_db": graph, "embedder": FakeEmbedder(), "llm": FakeLLM()}} + ) + + report = contextualizer.run( + signal_snapshot=DreamSignalSnapshot(mem_cube_id="cube-a", pending_memory_ids=["m1", "m2"]), + text_mem=None, + cube_id="cube-a", + ) + + assert report.created_context_count == 0 + assert report.bound_memory_count == 0 + assert report.skipped_memory_count == 2 + assert graph.added == [] + + +def test_contextualizer_uses_short_ids_for_llm_binding_and_maps_back_to_real_ids(): + graph = FakeGraphDB( + nodes=[ + _memory_node("uuid-alpha-1", "session:s1"), + _memory_node("uuid-beta-2", "session:s1"), + _memory_node("uuid-alpha-3", "session:s1"), + ] + ) + llm = FakeBindingAndSummaryLLM() + contextualizer = DreamContextualizer( + enabled=True, + binding_llm_enabled=True, + summary_llm_enabled=True, + ) + contextualizer.bind_context( + {"shared": {"graph_db": graph, "embedder": FakeEmbedder(), "llm": llm}} + ) + + report = contextualizer.run( + signal_snapshot=DreamSignalSnapshot( + mem_cube_id="cube-a", + pending_memory_ids=["uuid-alpha-1", "uuid-beta-2", "uuid-alpha-3"], + ), + text_mem=None, + cube_id="cube-a", + ) + + assert report.created_context_count == 1 + assert report.bound_memory_count == 2 + assert report.skipped_memory_count == 1 + grouped_ids = [ + metadata["internal_info"]["dream"]["memory_ids"] for _, _, metadata, _ in graph.added + ] + assert ["uuid-alpha-1", "uuid-alpha-3"] in grouped_ids + assert ["uuid-beta-2"] not in grouped_ids + assert any(ctx["binding_strategy"] == "llm" for ctx in report.contexts) + assert report.contexts[0]["label"] == "Summary Key" + assert report.contexts[0]["summary"] == "Summary text" + assert report.contexts[0]["source_memory_ids"] == ["uuid-alpha-1", "uuid-alpha-3"] + assert report.contexts[0]["summary_strategy"] == "llm" + assert "uuid-alpha-1" in llm.calls[0] + assert "ids" in llm.calls[0] + + +def test_contextualizer_persists_batch_memories_without_llm_binding(): + graph = FakeGraphDB( + nodes=[ + _memory_node("chunk-1", "batch:file-1"), + _memory_node("chunk-2", "batch:file-1"), + _memory_node("other-3", "batch:file-1"), + ] + ) + contextualizer = DreamContextualizer( + enabled=True, + binding_llm_enabled=False, + summary_llm_enabled=False, + ) + contextualizer.bind_context({"shared": {"graph_db": graph, "embedder": FakeEmbedder()}}) + + report = contextualizer.run( + signal_snapshot=DreamSignalSnapshot( + mem_cube_id="cube-a", + pending_memory_ids=["chunk-1", "chunk-2", "other-3"], + ), + text_mem=None, + cube_id="cube-a", + ) + + assert report.created_context_count == 1 + assert report.bound_memory_count == 3 + assert len(graph.added) == 1 + assert graph.added[0][2]["internal_info"]["dream"]["memory_ids"] == [ + "chunk-1", + "chunk-2", + "other-3", + ] + assert graph.added[0][2]["internal_info"]["dream"]["binding"]["strategy"] == "batch" + + +def test_contextualizer_skips_singleton_even_when_existing_context_matches_weak_id(): + graph = FakeGraphDB(nodes=[_memory_node("m1"), _context_node()]) + contextualizer = DreamContextualizer(enabled=True, summary_llm_enabled=False) + contextualizer.bind_context({"shared": {"graph_db": graph, "embedder": FakeEmbedder()}}) + + report = contextualizer.run( + signal_snapshot=DreamSignalSnapshot(mem_cube_id="cube-a", pending_memory_ids=["m1"]), + text_mem=None, + cube_id="cube-a", + ) + + assert report.updated_context_count == 0 + assert report.bound_memory_count == 0 + assert report.skipped_memory_count == 1 + assert not graph.updated + assert graph.added == [] + + +def test_contextualizer_skips_unbound_memories_instead_of_singletons(): + graph = FakeGraphDB(nodes=[_memory_node("m1", None), _memory_node("m2", None)]) + contextualizer = DreamContextualizer(enabled=True, summary_llm_enabled=False) + contextualizer.bind_context({"shared": {"graph_db": graph, "embedder": FakeEmbedder()}}) + + report = contextualizer.run( + signal_snapshot=DreamSignalSnapshot(mem_cube_id="cube-a", pending_memory_ids=["m1", "m2"]), + text_mem=None, + cube_id="cube-a", + ) + + assert report.created_context_count == 0 + assert report.bound_memory_count == 0 + assert report.skipped_memory_count == 2 + assert graph.added == [] + + +def test_contextualizer_skips_oversized_project_pool_without_fallback_context(): + nodes = [_memory_node(f"m{idx}") for idx in range(1, 5)] + graph = FakeGraphDB(nodes=nodes) + contextualizer = DreamContextualizer( + enabled=True, + binding_llm_enabled=True, + summary_llm_enabled=False, + binding_max_group_size=3, + ) + contextualizer.bind_context( + { + "shared": { + "graph_db": graph, + "embedder": FakeEmbedder(), + "llm": FakeBindingAndSummaryLLM(), + } + } + ) + + report = contextualizer.run( + signal_snapshot=DreamSignalSnapshot( + mem_cube_id="cube-a", + pending_memory_ids=[node["id"] for node in nodes], + ), + text_mem=None, + cube_id="cube-a", + ) + + assert report.created_context_count == 0 + assert report.bound_memory_count == 0 + assert report.skipped_memory_count == 4 + assert graph.added == [] + + +def test_context_memory_type_is_valid_textual_metadata(): + metadata = TreeNodeTextualMemoryMetadata(memory_type="Context") + assert metadata.memory_type == "Context" diff --git a/tests/dream/test_diary_pipeline.py b/tests/dream/test_diary_pipeline.py new file mode 100644 index 000000000..f9110033d --- /dev/null +++ b/tests/dream/test_diary_pipeline.py @@ -0,0 +1,41 @@ +from __future__ import annotations + +from memos.dream.contextualization import DreamContextReport +from memos.dream.pipeline.diary import StructuredDiarySummary + + +def test_diary_summary_creates_context_only_entry_from_context_report(): + report = DreamContextReport( + processed_memory_count=2, + created_context_count=1, + updated_context_count=0, + bound_memory_count=2, + skipped_memory_count=0, + contexts=[ + { + "context_id": "ctx_1", + "action": "created", + "label": "first-week-closed-loop", + "summary": "第一周闭环包含 Context binding 和 Search A。", + "source_memory_ids": ["m1", "m2"], + "binding_strategy": "llm", + "summary_strategy": "llm", + } + ], + ) + + results = StructuredDiarySummary().generate( + clusters=[], + results=[], + mem_cube_id="cube-a", + context_report=report, + ) + + assert len(results) == 1 + entry = results[0].diary_entry + assert entry is not None + assert entry.status == "context_only" + assert entry.title == "Dream Context Summary" + assert entry.context_events == report.contexts + assert "first-week-closed-loop" in entry.summary + assert "第一周闭环" in entry.dream_entry diff --git a/tests/dream/test_diary_router.py b/tests/dream/test_diary_router.py new file mode 100644 index 000000000..ca8910a4d --- /dev/null +++ b/tests/dream/test_diary_router.py @@ -0,0 +1,98 @@ +from __future__ import annotations + +from fastapi import FastAPI +from fastapi.testclient import TestClient + +from memos.dream.routers.diary_router import create_diary_router + + +class FakeGraphDB: + def __init__(self, nodes=None): + self.nodes = nodes or [] + self.calls = [] + + def get_by_metadata(self, filters, user_name=None, status=None): + self.calls.append({"filters": filters, "user_name": user_name, "status": status}) + ids = [] + for node in self.nodes: + metadata = node.get("metadata") or {} + if status and metadata.get("status") != status: + continue + ids.append(node["id"]) + return ids + + def get_nodes(self, ids, user_name=None): + return [node for node in self.nodes if node["id"] in ids] + + +class FakePlugin: + name = "dream" + version = "0.1.0" + + def __init__(self, graph_db): + self.context = {"shared": {"graph_db": graph_db}} + self.signal_store = type("SignalStore", (), {"trigger_threshold": 100})() + + +def _client(graph_db) -> TestClient: + app = FastAPI() + app.include_router(create_diary_router(FakePlugin(graph_db))) + return TestClient(app) + + +def test_diary_query_does_not_filter_by_activated_status_in_graph_db_call(): + node = { + "id": "dream_diary_1", + "memory": "Diary content", + "metadata": { + "memory_type": "DreamDiary", + "status": "completed", + "created_at": "2026-05-19T10:00:00", + "title": "Dream title", + "summary": "Dream summary", + "dream_entry": "Dream entry", + "motive": {"type": "newness"}, + "context_events": [{"context_id": "ctx_1", "label": "Context label"}], + "themes": ["dream"], + }, + } + graph_db = FakeGraphDB(nodes=[node]) + + response = _client(graph_db).post("/dream/diary", json={"cube_id": "cube-a"}) + + assert response.status_code == 200 + assert graph_db.calls == [ + { + "filters": [{"field": "memory_type", "op": "=", "value": "DreamDiary"}], + "user_name": "cube-a", + "status": None, + } + ] + data = response.json()["data"] + assert len(data) == 1 + assert data[0]["task_id"] == "dream_diary_1" + assert data[0]["title"] == "Dream title" + assert data[0]["context_events"] == [{"context_id": "ctx_1", "label": "Context label"}] + + +def test_diary_query_returns_completed_and_skipped_entries(): + graph_db = FakeGraphDB( + nodes=[ + { + "id": "completed", + "memory": "Completed", + "metadata": {"memory_type": "DreamDiary", "status": "completed"}, + }, + { + "id": "skipped", + "memory": "Skipped", + "metadata": {"memory_type": "DreamDiary", "status": "skipped"}, + }, + ] + ) + + response = _client(graph_db).post("/dream/diary", json={"cube_id": "cube-a"}) + + assert response.status_code == 200 + data = response.json()["data"] + assert [item["task_id"] for item in data] == ["completed", "skipped"] diff --git a/tests/dream/test_heuristic_enricher.py b/tests/dream/test_heuristic_enricher.py new file mode 100644 index 000000000..9e8384021 --- /dev/null +++ b/tests/dream/test_heuristic_enricher.py @@ -0,0 +1,179 @@ +from __future__ import annotations + +from memos.dream.enrichment import DreamHeuristicEnricher +from memos.memories.textual.item import ( + SourceMessage, + TextualMemoryItem, + TreeNodeTextualMemoryMetadata, +) +from memos.types.general_types import UserContext + + +def _item( + memory: str = "User is planning MemOS Dream v1.", + *, + role: str = "user", + session_id: str = "session-a", + project_id: str | None = "project-a", + internal_info: dict | None = None, + sources: list[SourceMessage] | None = None, +) -> TextualMemoryItem: + return TextualMemoryItem( + memory=memory, + metadata=TreeNodeTextualMemoryMetadata( + user_id="user-a", + session_id=session_id, + memory_type="LongTermMemory", + project_id=project_id, + sources=sources or [SourceMessage(type="chat", role=role, content=memory)], + internal_info=internal_info, + ), + ) + + +def test_heuristic_enricher_writes_dream_subdict_for_project_context(): + item = _item("Can we make the Dream v1 plan use plugins?") + enricher = DreamHeuristicEnricher(enabled=True, overwrite=False) + + enriched = enricher.enrich_items( + items=[item], + user_context=UserContext(session_id="session-a", project_id="project-a"), + extract_mode="fine", + ) + + assert enriched == [item] + dream = item.metadata.internal_info["dream"] + assert dream["weak_context_id"] == "project:project-a" + assert dream["signals"]["source_roles"] == ["user"] + assert dream["signals"]["has_question"] is True + assert dream["signals"]["has_correction"] is False + assert dream["signals"]["is_chunk"] is False + assert dream["salience"]["has_feedback"] is False + assert dream["enriched_by"]["heuristic"] == "0.1.0" + + +def test_heuristic_enricher_ignores_common_clarification_markers(): + examples = [ + "其实是我想用 Context summary 做背景。", + "不对称的设计会引入噪声。", + "这不是为了生成漂亮反思,而是为了提升 search。", + "Actually, I want Context recall to be gated.", + "I mean the search path should return summaries.", + ] + enricher = DreamHeuristicEnricher(enabled=True, overwrite=False) + + for text in examples: + item = _item(text) + enricher.enrich_items(items=[item], user_context=None, extract_mode="fine") + dream = item.metadata.internal_info["dream"] + assert dream["signals"]["has_correction"] is False + assert dream["salience"]["has_feedback"] is False + + +def test_heuristic_enricher_detects_strong_agent_feedback(): + examples = [ + "你刚才说错了,我不是要展开 memory_ids。", + "你的回答不对,我说的是 Context summary。", + "你没有理解我的意思,我要默认关闭。", + "Your answer is wrong; context recall should be gated.", + "You misunderstood me: do not expand the raw memories.", + ] + enricher = DreamHeuristicEnricher(enabled=True, overwrite=False) + + for text in examples: + item = _item(text) + enricher.enrich_items(items=[item], user_context=None, extract_mode="fine") + dream = item.metadata.internal_info["dream"] + assert dream["signals"]["has_correction"] is True + assert dream["salience"]["has_feedback"] is True + + +def test_heuristic_enricher_falls_back_to_non_default_session(): + item = _item(project_id=None, session_id="session-b") + enricher = DreamHeuristicEnricher(enabled=True, overwrite=False) + + enricher.enrich_items(items=[item], user_context=None, extract_mode="fine") + + assert item.metadata.internal_info["dream"]["weak_context_id"] == "session:session-b" + + +def test_heuristic_enricher_keeps_default_session_unbound_without_project(): + item = _item(project_id=None, session_id="default_session") + enricher = DreamHeuristicEnricher(enabled=True, overwrite=False) + + enricher.enrich_items(items=[item], user_context=None, extract_mode="fine") + + assert item.metadata.internal_info["dream"]["weak_context_id"] is None + + +def test_heuristic_enricher_uses_batch_context_for_chunks(): + item_a = _item( + "chunk one", + internal_info={"ingest_batch_id": "ingest-1", "chunk_index": 0, "chunk_total": 2}, + ) + item_b = _item( + "chunk two", + internal_info={"ingest_batch_id": "ingest-1", "chunk_index": 1, "chunk_total": 2}, + ) + enricher = DreamHeuristicEnricher(enabled=True, overwrite=False) + + enricher.enrich_items(items=[item_a, item_b], user_context=None, extract_mode="fine") + + for item in (item_a, item_b): + dream = item.metadata.internal_info["dream"] + assert dream["batch_context_id"] == "batch:ingest-1" + assert dream["weak_context_id"] == "batch:ingest-1" + assert dream["signals"]["is_chunk"] is True + assert dream["signals"]["chunk_total"] == 2 + + +def test_heuristic_enricher_keeps_mixed_ingest_batches_separate(): + item_a = _item( + "chunk one", + internal_info={"ingest_batch_id": "ingest-1", "chunk_index": 0, "chunk_total": 2}, + ) + item_b = _item( + "chunk two", + internal_info={"ingest_batch_id": "ingest-2", "chunk_index": 1, "chunk_total": 2}, + ) + enricher = DreamHeuristicEnricher(enabled=True, overwrite=False) + + enricher.enrich_items(items=[item_a, item_b], user_context=None, extract_mode="fine") + + dream_a = item_a.metadata.internal_info["dream"] + dream_b = item_b.metadata.internal_info["dream"] + assert "batch_context_id" not in dream_a + assert "batch_context_id" not in dream_b + assert dream_a["weak_context_id"] == "batch:ingest-1" + assert dream_b["weak_context_id"] == "batch:ingest-2" + + +def test_heuristic_enricher_preserves_existing_semantic_fields_by_default(): + item = _item( + internal_info={ + "dream": { + "context_hint": "Existing semantic hint", + "salience": {"unresolved": True, "has_feedback": False}, + } + } + ) + enricher = DreamHeuristicEnricher(enabled=True, overwrite=False) + + enricher.enrich_items(items=[item], user_context=None, extract_mode="fine") + + dream = item.metadata.internal_info["dream"] + assert dream["context_hint"] == "Existing semantic hint" + assert dream["salience"]["unresolved"] is True + assert dream["salience"]["has_feedback"] is False + + +def test_heuristic_enricher_skips_when_disabled_or_not_fine(): + disabled_item = _item() + disabled = DreamHeuristicEnricher(enabled=False) + disabled.enrich_items(items=[disabled_item], user_context=None, extract_mode="fine") + assert disabled_item.metadata.internal_info is None + + fast_item = _item() + enabled = DreamHeuristicEnricher(enabled=True) + enabled.enrich_items(items=[fast_item], user_context=None, extract_mode="fast") + assert fast_item.metadata.internal_info is None diff --git a/tests/graph_dbs/test_search_return_fields.py b/tests/graph_dbs/test_search_return_fields.py index 82a50308b..fc95d5a81 100644 --- a/tests/graph_dbs/test_search_return_fields.py +++ b/tests/graph_dbs/test_search_return_fields.py @@ -8,6 +8,7 @@ import uuid +from contextlib import contextmanager from unittest.mock import MagicMock, patch import pytest @@ -195,6 +196,37 @@ def test_extract_invalid_json(self, polardb_instance): result = polardb_instance._extract_fields_from_properties("not-json", ["memory"]) assert result == {} + def test_get_by_metadata_accepts_status_filter(self, polardb_instance): + """PolarDB get_by_metadata honors the BaseGraphDB status contract.""" + polardb_instance.db_name = "test_db" + polardb_instance.config = {"user_name": "default_user"} + polardb_instance._build_user_name_and_kb_ids_conditions_cypher = MagicMock( + return_value=["n.user_name = 'cube-a'"] + ) + polardb_instance._build_filter_conditions_cypher = MagicMock(return_value="") + + cursor = MagicMock() + cursor.fetchall.return_value = [('"node-1"',)] + conn = MagicMock() + conn.cursor.return_value.__enter__.return_value = cursor + + @contextmanager + def fake_connection(): + yield conn + + polardb_instance._get_connection = fake_connection + + ids = polardb_instance.get_by_metadata( + [{"field": "memory_type", "op": "=", "value": "DreamDiary"}], + user_name="cube-a", + status="activated", + ) + + query = cursor.execute.call_args[0][0] + assert "n.status = 'activated'" in query + assert "n.memory_type = 'DreamDiary'" in query + assert ids == ["node-1"] + class TestFieldNameValidation: """Tests for _validate_return_fields injection prevention."""