diff --git a/services/chatbot/src/chatbot/agent_utils.py b/services/chatbot/src/chatbot/agent_utils.py index 5aa8de3f..16cd5719 100644 --- a/services/chatbot/src/chatbot/agent_utils.py +++ b/services/chatbot/src/chatbot/agent_utils.py @@ -1,4 +1,5 @@ import json +import logging from langchain.agents import AgentState from langchain.agents.middleware.types import before_model @@ -7,7 +8,11 @@ from .config import Config +logger = logging.getLogger(__name__) + INDIVIDUAL_MIN_LENGTH = 100 +# Approximate characters per token across providers +CHARS_PER_TOKEN = 4 def collect_long_strings(obj): @@ -88,3 +93,53 @@ def truncate_tool_messages(state: AgentState, runtime: Runtime) -> AgentState: else: modified_messages.append(msg) return {"messages": modified_messages} + + +def _estimate_tokens(text): + """Estimate token count using character-based approximation.""" + return len(text) // CHARS_PER_TOKEN + + +def _message_content(msg): + """Extract text content from a message dict or object.""" + if isinstance(msg, dict): + return msg.get("content", "") + return getattr(msg, "content", "") + + +def trim_messages_to_token_limit(messages): + """ + Trim conversation history from the oldest messages to fit within the token + budget derived from MAX_CONTENT_LENGTH. + The most recent message (the new user turn) is always kept. + """ + max_tokens = Config.MAX_CONTENT_LENGTH // CHARS_PER_TOKEN + + if not messages: + return messages + + # Estimate per-message tokens + token_counts = [_estimate_tokens(_message_content(m)) for m in messages] + total_tokens = sum(token_counts) + + if total_tokens <= max_tokens: + return messages + + # Always keep the last message; trim from the front + trimmed = list(messages) + trimmed_tokens = list(token_counts) + + while len(trimmed) > 1 and sum(trimmed_tokens) > max_tokens: + trimmed.pop(0) + trimmed_tokens.pop(0) + + logger.info( + "Trimmed conversation history from %d to %d messages " + "(estimated tokens: %d -> %d, limit: %d)", + len(messages), + len(trimmed), + total_tokens, + sum(trimmed_tokens), + max_tokens, + ) + return trimmed diff --git a/services/chatbot/src/chatbot/chat_api.py b/services/chatbot/src/chatbot/chat_api.py index 28a8bd4d..7f439835 100644 --- a/services/chatbot/src/chatbot/chat_api.py +++ b/services/chatbot/src/chatbot/chat_api.py @@ -4,6 +4,7 @@ from quart import Blueprint, jsonify, request +from .agent_utils import trim_messages_to_token_limit from .chat_service import (delete_chat_history, get_chat_history, process_user_message) from .config import Config @@ -229,8 +230,7 @@ async def state(): "Provider API key for session %s: %s", session_id, provider_api_key[:5] ) chat_history = await get_chat_history(session_id) - # Limit chat history to last 20 messages - chat_history = chat_history[-20:] + chat_history = trim_messages_to_token_limit(chat_history) return ( jsonify( { @@ -259,8 +259,7 @@ async def history(): provider_api_key = await get_api_key(session_id) if provider in {"openai", "anthropic"} and provider_api_key: chat_history = await get_chat_history(session_id) - # Limit chat history to last 20 messages - chat_history = chat_history[-20:] + chat_history = trim_messages_to_token_limit(chat_history) return jsonify({"chat_history": chat_history}), 200 if provider in {"openai", "anthropic"}: return ( @@ -268,7 +267,7 @@ async def history(): 200, ) chat_history = await get_chat_history(session_id) - chat_history = chat_history[-20:] if chat_history else [] + chat_history = trim_messages_to_token_limit(chat_history) if chat_history else [] return jsonify({"chat_history": chat_history}), 200 diff --git a/services/chatbot/src/chatbot/chat_service.py b/services/chatbot/src/chatbot/chat_service.py index 024878f4..8f5de1f7 100644 --- a/services/chatbot/src/chatbot/chat_service.py +++ b/services/chatbot/src/chatbot/chat_service.py @@ -3,6 +3,7 @@ from langgraph.graph.message import Messages +from .agent_utils import trim_messages_to_token_limit from .config import Config from .extensions import db from .langgraph_agent import execute_langgraph_agent @@ -80,8 +81,7 @@ async def process_user_message(session_id, user_message, api_key, model_name, us ) logger.debug("Added messages to Chroma collection - session_id: %s", session_id) - # Limit chat history to last 20 messages - history = history[-20:] + history = trim_messages_to_token_limit(history) await update_chat_history(session_id, history) logger.info( "Message processing complete - session_id: %s, response_id: %s, history_count: %d", diff --git a/services/chatbot/src/chatbot/config.py b/services/chatbot/src/chatbot/config.py index 7645e501..6e79d3ba 100644 --- a/services/chatbot/src/chatbot/config.py +++ b/services/chatbot/src/chatbot/config.py @@ -34,6 +34,6 @@ class Config: AWS_ROLE_SESSION_NAME = os.getenv("AWS_ROLE_SESSION_NAME", "crapi-chatbot-session") VERTEX_PROJECT = os.getenv("VERTEX_PROJECT", "") VERTEX_LOCATION = os.getenv("VERTEX_LOCATION", "") - MAX_CONTENT_LENGTH = int(os.getenv("MAX_CONTENT_LENGTH", 50000)) + MAX_CONTENT_LENGTH = int(os.getenv("MAX_CONTENT_LENGTH", 100000)) CHROMA_HOST = CHROMA_HOST CHROMA_PORT = CHROMA_PORT diff --git a/services/chatbot/src/chatbot/langgraph_agent.py b/services/chatbot/src/chatbot/langgraph_agent.py index 9e728466..3f7a96a9 100644 --- a/services/chatbot/src/chatbot/langgraph_agent.py +++ b/services/chatbot/src/chatbot/langgraph_agent.py @@ -11,7 +11,7 @@ from langchain_mistralai import ChatMistralAI from langchain_openai import AzureChatOpenAI, ChatOpenAI -from .agent_utils import truncate_tool_messages +from .agent_utils import trim_messages_to_token_limit, truncate_tool_messages from .aws_credentials import get_bedrock_credentials_kwargs from .config import Config from .extensions import postgresdb @@ -263,6 +263,7 @@ async def execute_langgraph_agent( len(messages), ) agent = await build_langgraph_agent(api_key, model_name, user_jwt) + messages = trim_messages_to_token_limit(messages) logger.debug("Invoking agent with %d messages", len(messages)) response = await agent.ainvoke({"messages": messages}) logger.info( diff --git a/services/chatbot/src/chatbot/tests/__init__.py b/services/chatbot/src/chatbot/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/services/chatbot/src/chatbot/tests/test_agent_utils.py b/services/chatbot/src/chatbot/tests/test_agent_utils.py new file mode 100644 index 00000000..79023411 --- /dev/null +++ b/services/chatbot/src/chatbot/tests/test_agent_utils.py @@ -0,0 +1,192 @@ +"""Tests for trim_messages_to_token_limit and supporting helpers.""" +import sys +from types import ModuleType +from unittest.mock import MagicMock, patch + +import pytest + +# --------------------------------------------------------------------------- +# Stub out heavy third-party deps so the module can be imported without them. +# --------------------------------------------------------------------------- +_STUBS = {} +for mod_name in [ + "langchain", "langchain.agents", "langchain.agents.middleware", + "langchain.agents.middleware.types", + "langchain_core", "langchain_core.messages", + "langgraph", "langgraph.runtime", + "motor", "motor.motor_asyncio", + "langchain_community", "langchain_community.agent_toolkits", + "langchain_community.utilities", + "pymongo", +]: + if mod_name not in sys.modules: + stub = ModuleType(mod_name) + sys.modules[mod_name] = stub + _STUBS[mod_name] = stub + +# Provide the decorator used by agent_utils at import time +sys.modules["langchain.agents"].AgentState = dict +sys.modules["langchain.agents.middleware.types"].before_model = ( + lambda **kw: (lambda fn: fn) +) +sys.modules["langchain_core.messages"].ToolMessage = type("ToolMessage", (), {}) +sys.modules["langgraph.runtime"].Runtime = type("Runtime", (), {}) + +# Stub dotenv +dotenv_stub = ModuleType("dotenv") +dotenv_stub.load_dotenv = lambda *a, **kw: None +sys.modules["dotenv"] = dotenv_stub + +# Stub dbconnections before config is imported +db_stub = ModuleType("chatbot.dbconnections") +db_stub.CHROMA_HOST = "localhost" +db_stub.CHROMA_PORT = 8000 +db_stub.MONGO_CONNECTION_URI = "mongodb://localhost" +db_stub.POSTGRES_URI = "postgresql://localhost" +sys.modules["chatbot.dbconnections"] = db_stub + +# Now safe to import the module under test +from chatbot.agent_utils import ( + CHARS_PER_TOKEN, + _estimate_tokens, + _message_content, + trim_messages_to_token_limit, +) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _make_msg(role, content): + """Return a plain dict message like those stored in chat history.""" + return {"role": role, "content": content} + + +# --------------------------------------------------------------------------- +# _estimate_tokens +# --------------------------------------------------------------------------- + +class TestEstimateTokens: + def test_empty_string(self): + assert _estimate_tokens("") == 0 + + def test_known_length(self): + text = "a" * 400 # 400 chars -> 100 tokens + assert _estimate_tokens(text) == 400 // CHARS_PER_TOKEN + + def test_short_string(self): + assert _estimate_tokens("hi") == 0 # 2 // 4 == 0 + + +# --------------------------------------------------------------------------- +# _message_content +# --------------------------------------------------------------------------- + +class TestMessageContent: + def test_dict_message(self): + assert _message_content({"role": "user", "content": "hello"}) == "hello" + + def test_dict_missing_content(self): + assert _message_content({"role": "user"}) == "" + + def test_object_message(self): + class Msg: + content = "from object" + assert _message_content(Msg()) == "from object" + + def test_object_no_content(self): + class Msg: + pass + assert _message_content(Msg()) == "" + + +# --------------------------------------------------------------------------- +# trim_messages_to_token_limit +# --------------------------------------------------------------------------- + +MAX_CONTENT_LENGTH = 100000 # default + + +class TestTrimMessagesToTokenLimit: + """Tests use a patched MAX_CONTENT_LENGTH to keep fixtures small.""" + + @patch("chatbot.agent_utils.Config.MAX_CONTENT_LENGTH", 400) + def test_under_limit_returns_all(self): + """Messages totalling fewer tokens than the budget are untouched.""" + msgs = [_make_msg("user", "a" * 100), _make_msg("assistant", "b" * 100)] + result = trim_messages_to_token_limit(msgs) + assert len(result) == 2 + assert result == msgs + + @patch("chatbot.agent_utils.Config.MAX_CONTENT_LENGTH", 400) + def test_over_limit_trims_oldest(self): + """Oldest messages are dropped first to fit within budget.""" + # budget = 400 // 4 = 100 tokens + # Each message = 200 chars = 50 tokens -> 3 msgs = 150 tokens > 100 + msgs = [ + _make_msg("user", "a" * 200), + _make_msg("assistant", "b" * 200), + _make_msg("user", "c" * 200), + ] + result = trim_messages_to_token_limit(msgs) + assert len(result) < 3 + # Last message is always preserved + assert result[-1]["content"] == "c" * 200 + + @patch("chatbot.agent_utils.Config.MAX_CONTENT_LENGTH", 400) + def test_last_message_always_kept(self): + """Even if a single message exceeds the budget, it is kept.""" + msgs = [_make_msg("user", "x" * 800)] + result = trim_messages_to_token_limit(msgs) + assert len(result) == 1 + assert result[0]["content"] == "x" * 800 + + @patch("chatbot.agent_utils.Config.MAX_CONTENT_LENGTH", 400) + def test_trims_from_front_not_back(self): + """Verify older messages (front) are removed, newer ones (back) stay.""" + # budget = 100 tokens; each msg = 50 tokens + msgs = [ + _make_msg("user", "first-" + "a" * 194), + _make_msg("assistant", "second-" + "b" * 193), + _make_msg("user", "third-" + "c" * 194), + ] + result = trim_messages_to_token_limit(msgs) + assert result[-1]["content"].startswith("third-") + assert not any(m["content"].startswith("first-") for m in result) + + def test_empty_messages(self): + assert trim_messages_to_token_limit([]) == [] + + @patch("chatbot.agent_utils.Config.MAX_CONTENT_LENGTH", MAX_CONTENT_LENGTH) + def test_default_limit_is_derived_from_max_content_length(self): + """Token budget should be MAX_CONTENT_LENGTH // CHARS_PER_TOKEN.""" + expected_token_budget = MAX_CONTENT_LENGTH // CHARS_PER_TOKEN + # Create messages just under the budget -> no trimming + msg_chars = (expected_token_budget - 1) * CHARS_PER_TOKEN + msgs = [_make_msg("user", "a" * msg_chars)] + result = trim_messages_to_token_limit(msgs) + assert len(result) == 1 + + @patch("chatbot.agent_utils.Config.MAX_CONTENT_LENGTH", MAX_CONTENT_LENGTH) + def test_result_fits_within_token_budget(self): + """After trimming, estimated tokens must be <= budget.""" + token_budget = MAX_CONTENT_LENGTH // CHARS_PER_TOKEN + # 20 messages each ~2500 tokens = 50000 tokens, well over 25000 budget + msgs = [_make_msg("user" if i % 2 == 0 else "assistant", "x" * 10000) + for i in range(20)] + result = trim_messages_to_token_limit(msgs) + result_tokens = sum(_estimate_tokens(m["content"]) for m in result) + assert result_tokens <= token_budget + + @patch("chatbot.agent_utils.Config.MAX_CONTENT_LENGTH", 400) + def test_does_not_mutate_original(self): + """The original message list must not be modified.""" + msgs = [ + _make_msg("user", "a" * 200), + _make_msg("assistant", "b" * 200), + _make_msg("user", "c" * 200), + ] + original_len = len(msgs) + trim_messages_to_token_limit(msgs) + assert len(msgs) == original_len