Merged
3 changes: 2 additions & 1 deletion surfsense_backend/app/agents/new_chat/llm_config.py
@@ -22,6 +22,7 @@
AUTO_MODE_ID,
ChatLiteLLMRouter,
LLMRouterService,
get_auto_mode_llm,
is_auto_mode,
)

@@ -389,7 +390,7 @@ def create_chat_litellm_from_agent_config(
print("Error: Auto mode requested but LLM Router not initialized")
return None
try:
return ChatLiteLLMRouter()
return get_auto_mode_llm()
except Exception as e:
print(f"Error creating ChatLiteLLMRouter: {e}")
return None
7 changes: 7 additions & 0 deletions surfsense_backend/app/agents/new_chat/sandbox.py
@@ -58,6 +58,7 @@ async def aexecute(

_daytona_client: Daytona | None = None
_sandbox_cache: dict[str, _TimeoutAwareSandbox] = {}
_SANDBOX_CACHE_MAX_SIZE = 20
THREAD_LABEL_KEY = "surfsense_thread"


@@ -144,6 +145,12 @@ async def get_or_create_sandbox(thread_id: int | str) -> _TimeoutAwareSandbox:
return cached
sandbox = await asyncio.to_thread(_find_or_create, key)
_sandbox_cache[key] = sandbox

if len(_sandbox_cache) > _SANDBOX_CACHE_MAX_SIZE:
oldest_key = next(iter(_sandbox_cache))
_sandbox_cache.pop(oldest_key, None)
logger.debug("Evicted oldest sandbox cache entry: %s", oldest_key)

return sandbox


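The new size cap on _sandbox_cache relies on Python dict insertion order: next(iter(_sandbox_cache)) is the oldest-inserted key, so the cache evicts FIFO rather than LRU (a cache hit does not move an entry to the back), and eviction here only removes the cache entry; nothing in this hunk stops the evicted sandbox. A minimal standalone sketch of that eviction behavior; the evict_if_full helper and the 3-entry cap are illustrative, not part of this PR:

cache: dict[str, object] = {}
MAX_SIZE = 3

def evict_if_full(cache: dict[str, object], max_size: int) -> str | None:
    """Drop the oldest-inserted entry once the cache exceeds max_size."""
    if len(cache) > max_size:
        oldest_key = next(iter(cache))  # insertion order, not access order
        cache.pop(oldest_key, None)
        return oldest_key
    return None

for key in ("t1", "t2", "t3", "t4"):
    cache[key] = object()
    evicted = evict_if_full(cache, MAX_SIZE)
    if evicted:
        print(f"evicted {evicted}")  # prints "evicted t1" on the fourth insert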
22 changes: 22 additions & 0 deletions surfsense_backend/app/agents/new_chat/tools/mcp_tool.py
@@ -27,9 +27,24 @@
logger = logging.getLogger(__name__)

_MCP_CACHE_TTL_SECONDS = 300 # 5 minutes
_MCP_CACHE_MAX_SIZE = 50
_mcp_tools_cache: dict[int, tuple[float, list[StructuredTool]]] = {}


def _evict_expired_mcp_cache() -> None:
"""Remove expired entries from the MCP tools cache to prevent unbounded growth."""
now = time.monotonic()
expired = [
k
for k, (ts, _) in _mcp_tools_cache.items()
if now - ts >= _MCP_CACHE_TTL_SECONDS
]
for k in expired:
del _mcp_tools_cache[k]
if expired:
logger.debug("Evicted %d expired MCP cache entries", len(expired))


def _create_dynamic_input_model_from_schema(
tool_name: str,
input_schema: dict[str, Any],
@@ -392,6 +407,8 @@ async def load_mcp_tools(
List of LangChain StructuredTool instances

"""
_evict_expired_mcp_cache()

now = time.monotonic()
cached = _mcp_tools_cache.get(search_space_id)
if cached is not None:
@@ -445,6 +462,11 @@ async def load_mcp_tools(
)

_mcp_tools_cache[search_space_id] = (now, tools)

if len(_mcp_tools_cache) > _MCP_CACHE_MAX_SIZE:
oldest_key = min(_mcp_tools_cache, key=lambda k: _mcp_tools_cache[k][0])
del _mcp_tools_cache[oldest_key]

logger.info(f"Loaded {len(tools)} MCP tools for search space {search_space_id}")
return tools

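Together these two additions bound _mcp_tools_cache in both dimensions: _evict_expired_mcp_cache() purges entries older than the 5-minute TTL on every load_mcp_tools call, and after a fresh insert the entry with the oldest timestamp is dropped whenever the cache still holds more than 50 search spaces. A rough standalone sketch of the same TTL-plus-size-cap pattern; the put/get helpers and constants are illustrative, not the module's API:

import time

TTL_SECONDS = 300
MAX_SIZE = 50
# maps key -> (monotonic timestamp, payload)
cache: dict[int, tuple[float, list[str]]] = {}

def put(key: int, value: list[str]) -> None:
    now = time.monotonic()
    # 1) drop entries past their TTL
    for k in [k for k, (ts, _) in cache.items() if now - ts >= TTL_SECONDS]:
        del cache[k]
    cache[key] = (now, value)
    # 2) enforce the size cap by evicting the entry with the oldest timestamp
    if len(cache) > MAX_SIZE:
        oldest = min(cache, key=lambda k: cache[k][0])
        del cache[oldest]

def get(key: int) -> list[str] | None:
    entry = cache.get(key)
    if entry is None or time.monotonic() - entry[0] >= TTL_SECONDS:
        return None
    return entry[1]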
16 changes: 9 additions & 7 deletions surfsense_backend/app/app.py
@@ -103,22 +103,24 @@ def _check_rate_limit_memory(
now = time.monotonic()

with _memory_lock:
# Evict timestamps outside the current window
_memory_rate_limits[key] = [
t for t in _memory_rate_limits[key] if now - t < window_seconds
]
timestamps = [t for t in _memory_rate_limits[key] if now - t < window_seconds]

if not timestamps:
_memory_rate_limits.pop(key, None)
else:
_memory_rate_limits[key] = timestamps

if len(_memory_rate_limits[key]) >= max_requests:
if len(timestamps) >= max_requests:
rate_limit_logger.warning(
f"Rate limit exceeded (in-memory fallback) on {scope} for IP {client_ip} "
f"({len(_memory_rate_limits[key])}/{max_requests} in {window_seconds}s)"
f"({len(timestamps)}/{max_requests} in {window_seconds}s)"
)
raise HTTPException(
status_code=429,
detail="RATE_LIMIT_EXCEEDED",
)

_memory_rate_limits[key].append(now)
_memory_rate_limits[key] = [*timestamps, now]


def _check_rate_limit(
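The reworked in-memory fallback is a sliding-window counter: timestamps outside the window are pruned, a key whose window is empty is removed from the dict entirely (so idle IPs no longer accumulate), and the request is rejected once the surviving count reaches max_requests; otherwise the current timestamp is appended. A compact sketch of the same logic without the lock, logging, and FastAPI exception; the allow function name is illustrative:

import time
from collections import defaultdict

_hits: dict[str, list[float]] = defaultdict(list)

def allow(key: str, max_requests: int, window_seconds: float) -> bool:
    """Sliding-window check: True if the request may proceed."""
    now = time.monotonic()
    # keep only timestamps inside the current window
    timestamps = [t for t in _hits[key] if now - t < window_seconds]
    if not timestamps:
        _hits.pop(key, None)        # drop idle keys so the dict stays bounded
    else:
        _hits[key] = timestamps
    if len(timestamps) >= max_requests:
        return False                # caller would raise HTTP 429 here
    _hits[key] = [*timestamps, now]
    return True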
118 changes: 67 additions & 51 deletions surfsense_backend/app/services/llm_router_service.py
@@ -250,6 +250,48 @@ def get_model_count(cls) -> int:
return len(instance._model_list)


_cached_context_profile: dict | None = None
_cached_context_profile_computed: bool = False

# Cached singleton instances keyed by the streaming flag to avoid re-creating them on every call
_router_instance_cache: dict[bool, "ChatLiteLLMRouter"] = {}


def _get_cached_context_profile(router: Router) -> dict | None:
"""Compute and cache the min context profile across all router deployments.

Called once on first ChatLiteLLMRouter creation; subsequent calls return
the cached value. This avoids calling litellm.get_model_info() for every
deployment on every request.
"""
global _cached_context_profile, _cached_context_profile_computed
if _cached_context_profile_computed:
return _cached_context_profile

from litellm import get_model_info

min_ctx: int | None = None
for deployment in router.model_list:
params = deployment.get("litellm_params", {})
base_model = params.get("base_model") or params.get("model", "")
try:
info = get_model_info(base_model)
ctx = info.get("max_input_tokens")
if isinstance(ctx, int) and ctx > 0 and (min_ctx is None or ctx < min_ctx):
min_ctx = ctx
except Exception:
continue

if min_ctx is not None:
logger.info("ChatLiteLLMRouter profile: max_input_tokens=%d", min_ctx)
_cached_context_profile = {"max_input_tokens": min_ctx}
else:
_cached_context_profile = None

_cached_context_profile_computed = True
return _cached_context_profile


class ChatLiteLLMRouter(BaseChatModel):
"""
A LangChain-compatible chat model that uses LiteLLM Router for load balancing.
@@ -260,6 +302,10 @@ class ChatLiteLLMRouter(BaseChatModel):
Exposes a ``profile`` with ``max_input_tokens`` set to the smallest context
window across all router deployments so that deepagents
SummarizationMiddleware can use fraction-based triggers.

**Singleton-ish**: Prefer ``get_auto_mode_llm()``, which caches base (no-tools)
instances per streaming flag to avoid per-request re-initialization overhead and
memory growth; constructing ``ChatLiteLLMRouter()`` directly bypasses that cache.
"""

# Use model_config for Pydantic v2 compatibility
@@ -281,14 +327,6 @@ def __init__(
tool_choice: str | dict | None = None,
**kwargs,
):
"""
Initialize the ChatLiteLLMRouter.

Args:
router: LiteLLM Router instance. If None, uses the global singleton.
bound_tools: Pre-bound tools for tool calling
tool_choice: Tool choice configuration
"""
try:
super().__init__(**kwargs)
resolved_router = router or LLMRouterService.get_router()
@@ -300,51 +338,20 @@ def __init__(
"LLM Router not initialized. Call LLMRouterService.initialize() first."
)

# Set profile so deepagents SummarizationMiddleware gets fraction-based triggers
computed_profile = self._compute_min_context_profile()
computed_profile = _get_cached_context_profile(self._router)
if computed_profile is not None:
object.__setattr__(self, "profile", computed_profile)

logger.info(
f"ChatLiteLLMRouter initialized with {LLMRouterService.get_model_count()} models"
logger.debug(
"ChatLiteLLMRouter ready (models=%d, streaming=%s, has_tools=%s)",
LLMRouterService.get_model_count(),
self.streaming,
bound_tools is not None,
)
except Exception as e:
logger.error(f"Failed to initialize ChatLiteLLMRouter: {e}")
raise

def _compute_min_context_profile(self) -> dict | None:
"""Derive a profile dict with max_input_tokens from router deployments.

Uses litellm.get_model_info to look up each deployment's context window
and picks the *minimum* so that summarization triggers before ANY model
in the pool overflows.
"""
from litellm import get_model_info

if not self._router:
return None

min_ctx: int | None = None
for deployment in self._router.model_list:
params = deployment.get("litellm_params", {})
base_model = params.get("base_model") or params.get("model", "")
try:
info = get_model_info(base_model)
ctx = info.get("max_input_tokens")
if (
isinstance(ctx, int)
and ctx > 0
and (min_ctx is None or ctx < min_ctx)
):
min_ctx = ctx
except Exception:
continue

if min_ctx is not None:
logger.info(f"ChatLiteLLMRouter profile: max_input_tokens={min_ctx}")
return {"max_input_tokens": min_ctx}
return None

@property
def _llm_type(self) -> str:
return "litellm-router"
@@ -770,19 +777,28 @@ def _convert_delta_to_chunk(self, delta: Any) -> AIMessageChunk | None:
return None


def get_auto_mode_llm() -> ChatLiteLLMRouter | None:
"""
Get a ChatLiteLLMRouter instance for auto mode.
def get_auto_mode_llm(
*,
streaming: bool = True,
) -> ChatLiteLLMRouter | None:
"""Return a cached ChatLiteLLMRouter for auto mode.

Returns:
ChatLiteLLMRouter instance or None if router not initialized
Base (no tools) instances are cached per ``streaming`` flag so we
avoid re-constructing them on every request. ``bind_tools()`` still
returns a fresh instance because bound tools differ per agent.
"""
if not LLMRouterService.is_initialized():
logger.warning("LLM Router not initialized for auto mode")
return None

cached = _router_instance_cache.get(streaming)
if cached is not None:
return cached

try:
return ChatLiteLLMRouter()
instance = ChatLiteLLMRouter(streaming=streaming)
_router_instance_cache[streaming] = instance
return instance
except Exception as e:
logger.error(f"Failed to create ChatLiteLLMRouter: {e}")
return None
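The refactor replaces two per-request costs with module-level caches: the minimum-context profile is computed once by _get_cached_context_profile() and reused for every subsequent ChatLiteLLMRouter, and get_auto_mode_llm() hands out one cached base instance per streaming flag. A rough usage sketch of the cached accessor; the import path is assumed from the file's location (app/services/llm_router_service.py) and is not shown in the diff itself:

# Import path assumed from the file's location; not shown in the diff.
from app.services.llm_router_service import LLMRouterService, get_auto_mode_llm

if LLMRouterService.is_initialized():
    llm = get_auto_mode_llm()                       # cached streaming instance
    assert get_auto_mode_llm() is llm               # repeat calls reuse the same object
    non_streaming = get_auto_mode_llm(streaming=False)
    assert non_streaming is not llm                 # one cached instance per streaming flag
    # bind_tools() still returns a fresh wrapper, so per-agent tool sets
    # never end up on the shared cached instances.
else:
    llm = None                                      # get_auto_mode_llm() would also return None here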
3 changes: 2 additions & 1 deletion surfsense_backend/app/services/llm_service.py
@@ -12,6 +12,7 @@
AUTO_MODE_ID,
ChatLiteLLMRouter,
LLMRouterService,
get_auto_mode_llm,
is_auto_mode,
)

@@ -221,7 +222,7 @@ async def get_search_space_llm_instance(
logger.debug(
f"Using Auto mode (LLM Router) for search space {search_space_id}, role {role}"
)
return ChatLiteLLMRouter(disable_streaming=disable_streaming)
return get_auto_mode_llm(streaming=not disable_streaming)
except Exception as e:
logger.error(f"Failed to create ChatLiteLLMRouter: {e}")
return None
18 changes: 18 additions & 0 deletions surfsense_backend/app/tasks/chat/stream_new_chat.py
@@ -10,6 +10,7 @@
"""

import asyncio
import gc
import json
import logging
import re
@@ -1476,6 +1477,16 @@ async def stream_new_chat(

_try_persist_and_delete_sandbox(chat_id, stream_result.sandbox_files)

# Trigger a GC pass so LangGraph agent graphs, tool closures, and
# LLM wrappers with potential circular refs are reclaimed promptly.
collected = gc.collect()
if collected:
_perf_log.info(
"[stream_new_chat] gc.collect() reclaimed %d objects (chat_id=%s)",
collected,
chat_id,
)


async def stream_resume_chat(
chat_id: int,
@@ -1662,3 +1673,10 @@ async def stream_resume_chat(
)

_try_persist_and_delete_sandbox(chat_id, stream_result.sandbox_files)
collected = gc.collect()
if collected:
_perf_log.info(
"[stream_resume] gc.collect() reclaimed %d objects (chat_id=%s)",
collected,
chat_id,
)
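Both streaming entry points now finish with an explicit gc.collect(). The rationale in the comment is that LangGraph agent graphs, tool closures, and LLM wrappers tend to form reference cycles, which CPython's reference counting alone cannot reclaim; logging the return value gives a per-chat measure of how much cyclic garbage each run leaves behind. A tiny standalone illustration of why that count is informative (not from the PR):

import gc

class Node:
    def __init__(self) -> None:
        self.ref = None

# Build a reference cycle and drop the last external references to it.
a, b = Node(), Node()
a.ref, b.ref = b, a
del a, b

# Reference counting cannot free the cycle; an explicit collection pass can.
collected = gc.collect()
print(f"gc.collect() reclaimed {collected} objects")  # > 0 when cycles existed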