Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 55 additions & 0 deletions services/chatbot/src/chatbot/agent_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import json
import logging

from langchain.agents import AgentState
from langchain.agents.middleware.types import before_model
Expand All @@ -7,7 +8,11 @@

from .config import Config

logger = logging.getLogger(__name__)

INDIVIDUAL_MIN_LENGTH = 100
# Approximate characters per token across providers
CHARS_PER_TOKEN = 4


def collect_long_strings(obj):
Expand Down Expand Up @@ -88,3 +93,53 @@ def truncate_tool_messages(state: AgentState, runtime: Runtime) -> AgentState:
else:
modified_messages.append(msg)
return {"messages": modified_messages}


def _estimate_tokens(text):
    """Approximate the token count of *text* (1 token ≈ CHARS_PER_TOKEN chars)."""
    char_count = len(text)
    return char_count // CHARS_PER_TOKEN


def _message_content(msg):
"""Extract text content from a message dict or object."""
if isinstance(msg, dict):
return msg.get("content", "")
return getattr(msg, "content", "")


def trim_messages_to_token_limit(messages):
    """
    Trim conversation history from the oldest messages to fit within the token
    budget derived from MAX_CONTENT_LENGTH.

    The most recent message (the new user turn) is always kept, even if it
    alone exceeds the budget.

    Args:
        messages: Sequence of chat messages (dicts or objects with ``content``).

    Returns:
        ``messages`` unchanged when it already fits the budget; otherwise a new
        list holding the largest suffix of recent messages that fits (at
        minimum the last message). The input list is never mutated.
    """
    max_tokens = Config.MAX_CONTENT_LENGTH // CHARS_PER_TOKEN

    if not messages:
        return messages

    # Estimate per-message tokens once, up front.
    token_counts = [_estimate_tokens(_message_content(m)) for m in messages]
    total_tokens = sum(token_counts)

    if total_tokens <= max_tokens:
        return messages

    # Walk backwards from the newest message, accumulating as many older
    # messages as still fit. Token counts are non-negative, so this finds the
    # same cut point as repeatedly dropping the oldest message, but in a
    # single O(n) pass instead of O(n^2) re-summing plus O(n) pop(0) calls.
    kept_tokens = token_counts[-1]  # last message is always kept
    start = len(messages) - 1
    while start > 0 and kept_tokens + token_counts[start - 1] <= max_tokens:
        start -= 1
        kept_tokens += token_counts[start]

    trimmed = list(messages[start:])

    logger.info(
        "Trimmed conversation history from %d to %d messages "
        "(estimated tokens: %d -> %d, limit: %d)",
        len(messages),
        len(trimmed),
        total_tokens,
        kept_tokens,
        max_tokens,
    )
    return trimmed
9 changes: 4 additions & 5 deletions services/chatbot/src/chatbot/chat_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from quart import Blueprint, jsonify, request

from .agent_utils import trim_messages_to_token_limit
from .chat_service import (delete_chat_history, get_chat_history,
process_user_message)
from .config import Config
Expand Down Expand Up @@ -229,8 +230,7 @@ async def state():
"Provider API key for session %s: %s", session_id, provider_api_key[:5]
)
chat_history = await get_chat_history(session_id)
# Limit chat history to last 20 messages
chat_history = chat_history[-20:]
chat_history = trim_messages_to_token_limit(chat_history)
return (
jsonify(
{
Expand Down Expand Up @@ -259,16 +259,15 @@ async def history():
provider_api_key = await get_api_key(session_id)
if provider in {"openai", "anthropic"} and provider_api_key:
chat_history = await get_chat_history(session_id)
# Limit chat history to last 20 messages
chat_history = chat_history[-20:]
chat_history = trim_messages_to_token_limit(chat_history)
return jsonify({"chat_history": chat_history}), 200
if provider in {"openai", "anthropic"}:
return (
jsonify({"chat_history": []}),
200,
)
chat_history = await get_chat_history(session_id)
chat_history = chat_history[-20:] if chat_history else []
chat_history = trim_messages_to_token_limit(chat_history) if chat_history else []
return jsonify({"chat_history": chat_history}), 200


Expand Down
4 changes: 2 additions & 2 deletions services/chatbot/src/chatbot/chat_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from langgraph.graph.message import Messages

from .agent_utils import trim_messages_to_token_limit
from .config import Config
from .extensions import db
from .langgraph_agent import execute_langgraph_agent
Expand Down Expand Up @@ -80,8 +81,7 @@ async def process_user_message(session_id, user_message, api_key, model_name, us
)
logger.debug("Added messages to Chroma collection - session_id: %s", session_id)

# Limit chat history to last 20 messages
history = history[-20:]
history = trim_messages_to_token_limit(history)
await update_chat_history(session_id, history)
logger.info(
"Message processing complete - session_id: %s, response_id: %s, history_count: %d",
Expand Down
2 changes: 1 addition & 1 deletion services/chatbot/src/chatbot/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,6 @@ class Config:
AWS_ROLE_SESSION_NAME = os.getenv("AWS_ROLE_SESSION_NAME", "crapi-chatbot-session")
VERTEX_PROJECT = os.getenv("VERTEX_PROJECT", "")
VERTEX_LOCATION = os.getenv("VERTEX_LOCATION", "")
MAX_CONTENT_LENGTH = int(os.getenv("MAX_CONTENT_LENGTH", 50000))
MAX_CONTENT_LENGTH = int(os.getenv("MAX_CONTENT_LENGTH", 100000))
CHROMA_HOST = CHROMA_HOST
CHROMA_PORT = CHROMA_PORT
3 changes: 2 additions & 1 deletion services/chatbot/src/chatbot/langgraph_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from langchain_mistralai import ChatMistralAI
from langchain_openai import AzureChatOpenAI, ChatOpenAI

from .agent_utils import truncate_tool_messages
from .agent_utils import trim_messages_to_token_limit, truncate_tool_messages
from .aws_credentials import get_bedrock_credentials_kwargs
from .config import Config
from .extensions import postgresdb
Expand Down Expand Up @@ -263,6 +263,7 @@ async def execute_langgraph_agent(
len(messages),
)
agent = await build_langgraph_agent(api_key, model_name, user_jwt)
messages = trim_messages_to_token_limit(messages)
logger.debug("Invoking agent with %d messages", len(messages))
response = await agent.ainvoke({"messages": messages})
logger.info(
Expand Down
Empty file.
192 changes: 192 additions & 0 deletions services/chatbot/src/chatbot/tests/test_agent_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,192 @@
"""Tests for trim_messages_to_token_limit and supporting helpers."""
import sys
from types import ModuleType
from unittest.mock import MagicMock, patch

import pytest

# ---------------------------------------------------------------------------
# Stub out heavy third-party deps so the module can be imported without them.
# ---------------------------------------------------------------------------
_STUBS = {}
for mod_name in [
"langchain", "langchain.agents", "langchain.agents.middleware",
"langchain.agents.middleware.types",
"langchain_core", "langchain_core.messages",
"langgraph", "langgraph.runtime",
"motor", "motor.motor_asyncio",
"langchain_community", "langchain_community.agent_toolkits",
"langchain_community.utilities",
"pymongo",
]:
if mod_name not in sys.modules:
stub = ModuleType(mod_name)
sys.modules[mod_name] = stub
_STUBS[mod_name] = stub

# Provide the decorator used by agent_utils at import time
sys.modules["langchain.agents"].AgentState = dict
sys.modules["langchain.agents.middleware.types"].before_model = (
lambda **kw: (lambda fn: fn)
)
sys.modules["langchain_core.messages"].ToolMessage = type("ToolMessage", (), {})
sys.modules["langgraph.runtime"].Runtime = type("Runtime", (), {})

# Stub dotenv
dotenv_stub = ModuleType("dotenv")
dotenv_stub.load_dotenv = lambda *a, **kw: None
sys.modules["dotenv"] = dotenv_stub

# Stub dbconnections before config is imported
db_stub = ModuleType("chatbot.dbconnections")
db_stub.CHROMA_HOST = "localhost"
db_stub.CHROMA_PORT = 8000
db_stub.MONGO_CONNECTION_URI = "mongodb://localhost"
db_stub.POSTGRES_URI = "postgresql://localhost"
sys.modules["chatbot.dbconnections"] = db_stub

# Now safe to import the module under test
from chatbot.agent_utils import (
CHARS_PER_TOKEN,
_estimate_tokens,
_message_content,
trim_messages_to_token_limit,
)


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

def _make_msg(role, content):
"""Return a plain dict message like those stored in chat history."""
return {"role": role, "content": content}


# ---------------------------------------------------------------------------
# _estimate_tokens
# ---------------------------------------------------------------------------

class TestEstimateTokens:
    """Behaviour of the character-based token estimator."""

    def test_empty_string(self):
        assert _estimate_tokens("") == 0

    def test_known_length(self):
        sample = "a" * 400  # 400 chars should map to 400 // CHARS_PER_TOKEN tokens
        assert _estimate_tokens(sample) == 400 // CHARS_PER_TOKEN

    def test_short_string(self):
        # Strings shorter than one token's worth of chars round down to zero.
        assert _estimate_tokens("hi") == 0


# ---------------------------------------------------------------------------
# _message_content
# ---------------------------------------------------------------------------

class TestMessageContent:
    """Content extraction from dict-style and object-style messages."""

    def test_dict_message(self):
        msg = {"role": "user", "content": "hello"}
        assert _message_content(msg) == "hello"

    def test_dict_missing_content(self):
        assert _message_content({"role": "user"}) == ""

    def test_object_message(self):
        holder = type("Msg", (), {"content": "from object"})()
        assert _message_content(holder) == "from object"

    def test_object_no_content(self):
        bare = type("Msg", (), {})()
        assert _message_content(bare) == ""


# ---------------------------------------------------------------------------
# trim_messages_to_token_limit
# ---------------------------------------------------------------------------

MAX_CONTENT_LENGTH = 100000  # mirrors the Config default


class TestTrimMessagesToTokenLimit:
    """Tests run with a patched MAX_CONTENT_LENGTH so fixtures stay small."""

    @patch("chatbot.agent_utils.Config.MAX_CONTENT_LENGTH", 400)
    def test_under_limit_returns_all(self):
        """A history below the token budget passes through unchanged."""
        messages = [_make_msg("user", "a" * 100), _make_msg("assistant", "b" * 100)]
        result = trim_messages_to_token_limit(messages)
        assert len(result) == 2
        assert result == messages

    @patch("chatbot.agent_utils.Config.MAX_CONTENT_LENGTH", 400)
    def test_over_limit_trims_oldest(self):
        """With a 100-token budget, three 50-token messages cannot all fit."""
        messages = [
            _make_msg("user", "a" * 200),
            _make_msg("assistant", "b" * 200),
            _make_msg("user", "c" * 200),
        ]
        trimmed = trim_messages_to_token_limit(messages)
        assert len(trimmed) < 3
        # The newest turn must survive trimming.
        assert trimmed[-1]["content"] == "c" * 200

    @patch("chatbot.agent_utils.Config.MAX_CONTENT_LENGTH", 400)
    def test_last_message_always_kept(self):
        """An oversized single message is still returned intact."""
        messages = [_make_msg("user", "x" * 800)]
        trimmed = trim_messages_to_token_limit(messages)
        assert len(trimmed) == 1
        assert trimmed[0]["content"] == "x" * 800

    @patch("chatbot.agent_utils.Config.MAX_CONTENT_LENGTH", 400)
    def test_trims_from_front_not_back(self):
        """Older (front) messages go first; newer (back) messages remain."""
        messages = [
            _make_msg("user", "first-" + "a" * 194),
            _make_msg("assistant", "second-" + "b" * 193),
            _make_msg("user", "third-" + "c" * 194),
        ]
        trimmed = trim_messages_to_token_limit(messages)
        assert trimmed[-1]["content"].startswith("third-")
        assert all(not m["content"].startswith("first-") for m in trimmed)

    def test_empty_messages(self):
        assert trim_messages_to_token_limit([]) == []

    @patch("chatbot.agent_utils.Config.MAX_CONTENT_LENGTH", MAX_CONTENT_LENGTH)
    def test_default_limit_is_derived_from_max_content_length(self):
        """Token budget should be MAX_CONTENT_LENGTH // CHARS_PER_TOKEN."""
        budget = MAX_CONTENT_LENGTH // CHARS_PER_TOKEN
        # A single message just under the budget must not be trimmed.
        just_under = (budget - 1) * CHARS_PER_TOKEN
        result = trim_messages_to_token_limit([_make_msg("user", "a" * just_under)])
        assert len(result) == 1

    @patch("chatbot.agent_utils.Config.MAX_CONTENT_LENGTH", MAX_CONTENT_LENGTH)
    def test_result_fits_within_token_budget(self):
        """After trimming, estimated tokens must be <= the budget."""
        budget = MAX_CONTENT_LENGTH // CHARS_PER_TOKEN
        # 20 messages of ~2500 tokens each = ~50000 tokens, over a 25000 budget.
        messages = [_make_msg("user" if i % 2 == 0 else "assistant", "x" * 10000)
                    for i in range(20)]
        trimmed = trim_messages_to_token_limit(messages)
        assert sum(_estimate_tokens(m["content"]) for m in trimmed) <= budget

    @patch("chatbot.agent_utils.Config.MAX_CONTENT_LENGTH", 400)
    def test_does_not_mutate_original(self):
        """The original message list must not be modified."""
        messages = [
            _make_msg("user", "a" * 200),
            _make_msg("assistant", "b" * 200),
            _make_msg("user", "c" * 200),
        ]
        before = len(messages)
        trim_messages_to_token_limit(messages)
        assert len(messages) == before
Loading