Add ToolAwareContextFilterPlugin to preserve tool call sequences #4074
Conversation
- Fixes issue where `ContextFilterPlugin` splits function_call/function_response pairs
- Groups tool call sequences as atomic invocations
- Prevents OpenAI API errors when filtering conversation history
- Adds comprehensive unit tests

Fixes google#4027
Code Review
This pull request introduces `ToolAwareContextFilterPlugin`, a new plugin designed to correctly handle conversation history trimming for agents that use tools. The new plugin groups tool call sequences (function call, response, and final model answer) into atomic "invocations", preventing them from being split, which fixes a crash when using `ContextFilterPlugin` with tools.
The implementation of the plugin is well-structured, and the logic for grouping invocations appears to correctly handle various scenarios, including simple Q&A, full tool call cycles, and consecutive user messages. The accompanying unit tests are comprehensive in their intent, covering many edge cases.
My review includes a critical correction for the unit tests, which currently never execute the asynchronous code they are meant to test. I've also included a couple of suggestions to improve the plugin's implementation for clarity and efficiency. Overall, this is a valuable addition that addresses a significant issue.
```python
class TestToolAwareContextFilterPlugin:
  """Tests for ToolAwareContextFilterPlugin."""
```
The test methods in this class are not correctly invoking the async `before_model_callback` method. `pytest.mark.asyncio` should be used as a decorator on the test methods themselves, and the `await` keyword should be used to call the coroutine. Without this change, the async method is never actually run, and the tests will not work as expected.
For example, `test_no_filtering_when_disabled` should be rewritten as:
```python
@pytest.mark.asyncio
async def test_no_filtering_when_disabled(self):
  """Test that no filtering occurs when num_invocations_to_keep is None."""
  plugin = ToolAwareContextFilterPlugin(num_invocations_to_keep=None)
  contents = [
      Content(role="user", parts=[Part(text="Hello")]),
      Content(role="model", parts=[Part(text="Hi")]),
      Content(role="user", parts=[Part(text="How are you?")]),
      Content(role="model", parts=[Part(text="I'm good")]),
  ]
  request = LlmRequest(model="test", contents=contents)
  context = CallbackContext(invocation_id="test", agent_name="test")

  # Run the plugin
  await plugin.before_model_callback(
      callback_context=context, llm_request=request
  )

  # No filtering should occur
  assert len(request.contents) == 4
```

This correction needs to be applied to all test methods that call before_model_callback.
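The failure mode is easy to demonstrate without ADK or pytest. This standalone sketch (the function name is an illustrative stand-in for the plugin method) shows that calling an async method without `await` never executes its body:

```python
import asyncio

state = {"ran": False}

async def before_model_callback():  # stand-in for the plugin's async method
    state["ran"] = True

# Calling without await only creates a coroutine object; the body never runs.
coro = before_model_callback()
coro.close()  # discard it to avoid a "never awaited" RuntimeWarning
assert state["ran"] is False

# Awaiting it (here via asyncio.run) actually executes the body.
asyncio.run(before_model_callback())
assert state["ran"] is True
```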
Additionally, `test_error_handling` is flawed because `LlmRequest(contents=None)` will raise a `ValidationError` before the plugin is even called. A better way to test this would be to set `request.contents = None` after creating the request object.
```python
def _has_function_call(self, content) -> bool:
  """Check if a content has a function_call part."""
  if not content.parts:
    return False
  return any(
      hasattr(part, "function_call") and part.function_call
      for part in content.parts
  )

def _has_function_response(self, content) -> bool:
  """Check if a content has a function_response part."""
  if not content.parts:
    return False
  return any(
      hasattr(part, "function_response") and part.function_response
      for part in content.parts
  )
```
The helper methods `_has_function_call` and `_has_function_response` do not use any instance attributes (`self`). They can be declared with `@staticmethod` to make it clear that they don't depend on the state of the plugin instance.
Suggested change:

```python
@staticmethod
def _has_function_call(content) -> bool:
  """Check if a content has a function_call part."""
  if not content.parts:
    return False
  return any(
      hasattr(part, "function_call") and part.function_call
      for part in content.parts
  )

@staticmethod
def _has_function_response(content) -> bool:
  """Check if a content has a function_response part."""
  if not content.parts:
    return False
  return any(
      hasattr(part, "function_response") and part.function_response
      for part in content.parts
  )
```
```python
    indices_to_keep.extend(invocation)

# Filter contents based on indices
filtered_contents = [contents[i] for i in sorted(indices_to_keep)]
```
The `indices_to_keep` list is constructed from invocations that are already ordered chronologically. The flattening process preserves this order, so `indices_to_keep` will already be sorted. The call to `sorted()` is redundant and can be removed for a minor performance improvement.
Suggested change:

```python
filtered_contents = [contents[i] for i in indices_to_keep]
```
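The claim can be checked with a standalone sketch (example index values, not real ADK output), mirroring the extend-and-flatten step:

```python
# Invocations are chronological groups of message indices, as produced by
# the grouping step (example values).
invocations = [[0, 1], [2, 3, 4, 5], [6, 7, 8]]

# Keep the last 2 invocations and flatten, as the plugin does.
indices_to_keep = []
for invocation in invocations[-2:]:
    indices_to_keep.extend(invocation)

# The flattened list is already ascending, so sorted() is a no-op here.
assert indices_to_keep == sorted(indices_to_keep) == [2, 3, 4, 5, 6, 7, 8]
```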
- Fix async test execution by adding `@pytest.mark.asyncio` decorators and `await` keywords
- Fix `test_error_handling` to create a valid `LlmRequest` before setting `contents` to `None`
- Convert `_has_function_call()` and `_has_function_response()` to `@staticmethod`
- Remove redundant `sorted()` call in filtering logic for a minor performance improvement
/gemini review
Code Review
This pull request introduces a new `ToolAwareContextFilterPlugin` to address an issue where the existing context filter breaks tool call sequences. The new plugin correctly groups messages into logical invocations, ensuring that tool calls, their responses, and the final model reply are kept together during history trimming. The implementation is sound and accompanied by a comprehensive set of unit tests. I've provided a few suggestions to improve type safety and maintainability.
```python
def __init__(
    self,
    num_invocations_to_keep: Optional[int] = None,
    custom_filter: Optional[Callable[[List[Event]], List[Event]]] = None,
```
The type hint for `custom_filter` is `Optional[Callable[[List[Event]], List[Event]]]`. However, the filter is applied to `llm_request.contents`, which is of type `List[types.Content]`. The type hint should be `Optional[Callable[[List[types.Content]], List[types.Content]]]` to match the actual usage. You will also need to add `from google.genai import types` at the top of the file.
Suggested change:

```python
custom_filter: Optional[Callable[[List[types.Content]], List[types.Content]]] = None,
```
```python
@staticmethod
def _has_function_call(content) -> bool:
  """Check if a content has a function_call part."""
  if not content.parts:
    return False
  return any(
      hasattr(part, "function_call") and part.function_call
      for part in content.parts
  )

@staticmethod
def _has_function_response(content) -> bool:
```
The methods `_has_function_call` and `_has_function_response` are defined as `@staticmethod`s, but are called using `self` within `_group_into_invocations`. While this is valid, it's clearer to call them on the class, e.g., `ToolAwareContextFilterPlugin._has_function_call(...)`, to make it explicit that they don't depend on instance state. Alternatively, you could remove the `@staticmethod` decorator to make them regular instance methods.
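Both call styles behave identically at runtime; this minimal stand-in (not the ADK class, and with a simplified check) only illustrates the readability difference:

```python
class Plugin:  # simplified stand-in for ToolAwareContextFilterPlugin
    @staticmethod
    def _has_function_call(content) -> bool:
        return bool(getattr(content, "function_call", None))

    def via_self(self, content) -> bool:
        return self._has_function_call(content)  # valid, but looks stateful

    def via_class(self, content) -> bool:
        # Explicitly signals that no instance state is involved.
        return Plugin._has_function_call(content)

class FakeContent:
    function_call = {"name": "lookup"}

assert Plugin().via_self(FakeContent()) is True
assert Plugin().via_class(FakeContent()) is True
```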
```python
      for part in content.parts
  )

def _group_into_invocations(self, contents: List) -> List[List[int]]:
```
The type hint for `contents` is `List`, which is too generic. For better type safety and code clarity, it should be `List[types.Content]`. This assumes you've added `from google.genai import types` as suggested in another comment.
Suggested change:

```python
def _group_into_invocations(self, contents: List[types.Content]) -> List[List[int]]:
```
```python
def _group_into_invocations(self, contents: List) -> List[List[int]]:
  """Group message indices into complete invocations.

  An invocation pattern:
  1. One or more user messages (including consecutive user messages)
  2. Model response (possibly with function_call)
  3. If function_call exists: user message(s) with function_response
  4. If function_call exists: model final response

  Example grouping:
    Messages: [user, user, model, user, model+func_call,
               user+func_response, model]
    Groups:   [0, 1, 2]  [3, 4, 5, 6]
              ^^^^^^^^^  ^^^^^^^^^^^^
                Inv 1    Inv 2 (includes tool cycle)

  Args:
    contents: List of message contents to group.

  Returns:
    List of invocations, where each invocation is a list of message indices.
  """
  invocations = []
  current_invocation = []
  i = 0

  while i < len(contents):
    content = contents[i]

    # CASE 1: User message
    if content.role == "user":
      # Check if this is a function_response (part of ongoing tool cycle)
      if self._has_function_response(content):
        # This is a tool response - must be part of current invocation
        current_invocation.append(i)
        i += 1
      else:
        # Regular user message (not a function_response).
        # Only start a NEW invocation if we've completed a previous one.
        if current_invocation:
          # Check if previous invocation has a model response
          has_model = any(
              contents[idx].role == "model" for idx in current_invocation
          )
          if has_model:
            invocations.append(current_invocation)
            current_invocation = []

        # Add this user message to current invocation
        current_invocation.append(i)
        i += 1

    # CASE 2: Model message
    elif content.role == "model":
      current_invocation.append(i)

      # Check if model is making a tool call
      if self._has_function_call(content):
        # Model made a tool call - keep following messages together:
        # 1. This model message (function_call) - already added
        # 2. User message(s) with function_response - collect next
        # 3. Model's final response - collect after tool responses
        i += 1  # Move to next message

        # Collect all function_response messages (usually 1, but could
        # be multiple)
        while (
            i < len(contents)
            and contents[i].role == "user"
            and self._has_function_response(contents[i])
        ):
          current_invocation.append(i)
          i += 1

        # Now collect the model's final response after processing tool
        # results
        if i < len(contents) and contents[i].role == "model":
          current_invocation.append(i)
          i += 1

        # Complete tool cycle collected - this is ONE complete invocation
        invocations.append(current_invocation)
        current_invocation = []
      else:
        # Model response WITHOUT function call - simple case.
        # The invocation is complete (user query -> model answer).
        i += 1
        invocations.append(current_invocation)
        current_invocation = []
    else:
      # Unknown role - just add to current invocation
      current_invocation.append(i)
      i += 1

  # Add any remaining messages as final invocation
  if current_invocation:
    invocations.append(current_invocation)

  return invocations
```
The `_group_into_invocations` method is quite long and contains complex nested logic, making it hard to follow. For improved maintainability and readability, consider refactoring it into smaller, more focused helper methods. For example, you could have separate methods for processing 'user' messages and 'model' messages. This would make the main loop simpler and the logic for each case easier to follow and test in isolation.
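One possible shape for that refactor, sketched standalone with simplified stand-in types (`Content` and `Part` here are plain dataclasses, not `google.genai.types`, and the helper names are illustrative):

```python
from dataclasses import dataclass, field
from typing import List, Optional

@dataclass
class Part:  # simplified stand-in for google.genai types
    text: Optional[str] = None
    function_call: Optional[dict] = None
    function_response: Optional[dict] = None

@dataclass
class Content:
    role: str
    parts: List[Part] = field(default_factory=list)

def has_function_call(content: Content) -> bool:
    return any(p.function_call for p in content.parts)

def has_function_response(content: Content) -> bool:
    return any(p.function_response for p in content.parts)

def consume_tool_cycle(contents: List[Content], i: int,
                       current: List[int]) -> int:
    """Absorb function_response message(s) and the model's final reply."""
    while (i < len(contents) and contents[i].role == "user"
           and has_function_response(contents[i])):
        current.append(i)
        i += 1
    if i < len(contents) and contents[i].role == "model":
        current.append(i)
        i += 1
    return i

def group_into_invocations(contents: List[Content]) -> List[List[int]]:
    """Main loop stays small; per-case logic lives in helpers."""
    invocations: List[List[int]] = []
    current: List[int] = []
    i = 0
    while i < len(contents):
        content = contents[i]
        current.append(i)
        i += 1
        if content.role == "model":
            if has_function_call(content):
                i = consume_tool_cycle(contents, i, current)
            invocations.append(current)  # a model reply closes the invocation
            current = []
    if current:
        invocations.append(current)
    return invocations
```

On a simple Q&A followed by a full tool cycle, this groups the indices as `[[0, 1], [2, 3, 4, 5]]`, keeping the cycle atomic.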
…terPlugin

- Fix type hint for `custom_filter` to use `List[types.Content]` instead of `List[Event]`
- Add type hint for `contents` parameter: `List[types.Content]`
- Refactor `_group_into_invocations` into smaller helper methods for better maintainability:
  - `_finalize_invocation_if_complete`: Finalize current invocation if complete
  - `_process_user_message`: Handle user message processing
  - `_process_model_message`: Handle model message processing
  - `_process_model_message_with_tool_call`: Handle tool call sequences
- Remove unused `Event` import
- Add `google.genai.types` import for proper type annotations
Thanks — this aligns with the failure mode described in #4027 (orphaned tool responses after trimming). Quick question for maintainers: do you see this fitting better as a separate plugin (as proposed here), or would you prefer ContextFilterPlugin itself to become tool-aware when tool messages are present (or at least emit a warning)? Also, do the current tests cover scenarios with multiple tool calls in a single assistant turn or multiple tool responses back-to-back?

Hi @MahorShekhar, thank you for your contribution! We appreciate you taking the time to submit this pull request.

Hi @wuliang229, as @bluemeaford asked, would it be better to have
Add ToolAwareContextFilterPlugin to preserve tool call sequences
Problem Statement
The current `ContextFilterPlugin` causes errors when filtering conversation history for agents that use tools. This is because it treats each model message as a separate "invocation", which breaks tool call sequences.

Root Cause
When an agent uses tools, a single logical invocation creates multiple messages:

1. Model message with `function_call` (initiating tool use)
2. User message with `function_response` (tool execution result)
3. Model's final response

The existing `ContextFilterPlugin` counts these as 3 separate invocations instead of 1 complete invocation, causing it to split related messages apart.

Impact
When using OpenAI-compatible providers (Azure OpenAI, OpenAI, etc.) with `num_invocations_to_keep`, applications crash with an API error about orphaned tool messages.

Example of the Bug
Conversation (roles reconstructed from the grouping shown below):

```
#1 user
#2 model
#3 user
#4 model (function_call)
#5 user (function_response)
#6 model
#7 user
#8 model (function_call)
#9 user (function_response)
```

With `num_invocations_to_keep=2`, the standard `ContextFilterPlugin` keeps `[#5, #6, #7, #8, #9]`: `#5 function_response` is orphaned without `#4 function_call`!

Solution
This PR adds `ToolAwareContextFilterPlugin`, which groups messages into logical invocations, treating tool call sequences as atomic units.

How It Works

The plugin correctly identifies complete invocations:

- `[#1, #2]` - Simple Q&A
- `[#3, #4, #5, #6]` - Q&A with tool call (kept atomic!)
- `[#7, #8, #9]` - Incomplete tool cycle

With `num_invocations_to_keep=2`, it keeps `[#3, #4, #5, #6, #7, #8, #9]`.

Key Features

- Inspects `function_call` and `function_response` parts to group related messages
- Drop-in replacement for `ContextFilterPlugin`

Changes Made
New Files

- `src/google/adk/plugins/tool_aware_context_filter_plugin.py` - The new plugin implementation
- `tests/unittests/plugins/test_tool_aware_context_filter_plugin.py` - Comprehensive unit tests

Testing
All unit tests pass, covering simple Q&A, full tool call cycles, consecutive user messages, and other edge cases.

Usage Example
Migration Path

For existing users of `ContextFilterPlugin` who use tools, simply replace `ContextFilterPlugin(num_invocations_to_keep=N)` with `ToolAwareContextFilterPlugin(num_invocations_to_keep=N)`.
Testing Plan

Local Testing

- `pytest ./tests/unittests/plugins/test_tool_aware_context_filter_plugin.py`
- Manual verification with `num_invocations_to_keep=2`

End-to-End Testing
Setup: Agent with a knowledge_base tool using Azure OpenAI

Test: Have a conversation with 4+ invocations involving tool calls, with `num_invocations_to_keep=2`

Result: No API errors; tool call sequences remain intact after trimming.

Logs showing success:
Documentation Impact

This change adds a new plugin. Documentation updates needed:

- Document `ToolAwareContextFilterPlugin` as the tool-aware alternative to `ContextFilterPlugin` for tool-using agents

Related Issues
Fixes #4027 - `ContextFilterPlugin` breaks tool call sequences
Checklist

- Formatted code with `./autoformat.sh`
Before (with standard `ContextFilterPlugin`):

After (with `ToolAwareContextFilterPlugin`):