Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 87 additions & 8 deletions src/google/adk/flows/llm_flows/_code_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@

logger = logging.getLogger('google_adk.' + __name__)

_AVAILABLE_FILE_PREFIX = 'Available file:'


@dataclasses.dataclass
class DataFileUtil:
Expand Down Expand Up @@ -206,7 +208,7 @@ async def _run_pre_processor(
# memory. Meanwhile, mutate the inline data file to text part in session
# history from all turns.
all_input_files = _extract_and_replace_inline_files(
code_executor_context, llm_request
code_executor_context, llm_request, invocation_context
)

# [Step 2] Run Explore_Df code on the data files from the current turn. We
Expand Down Expand Up @@ -375,20 +377,42 @@ async def _run_post_processor(
def _extract_and_replace_inline_files(
code_executor_context: CodeExecutorContext,
llm_request: LlmRequest,
invocation_context: InvocationContext,
) -> list[File]:
"""Extracts and replaces inline files with file names in the LLM request."""
"""Extracts and replaces inline files with file names in the LLM request.

This function modifies both `llm_request.contents` for the current request
and `invocation_context.session.events` to ensure the replacement of inline
data with placeholders persists across conversation turns.

Args:
code_executor_context: Context containing code executor state.
llm_request: The LLM request to process.
invocation_context: Context containing session information.

Returns:
List of extracted File objects.
"""
all_input_files = code_executor_context.get_input_files()
saved_file_names = set(f.name for f in all_input_files)

# [Step 1] Process input files from LlmRequest and cache them in CodeExecutor.
# Track which session events need to be updated
events_to_update = {}

# Process input files from LlmRequest and cache them in CodeExecutor.
for i in range(len(llm_request.contents)):
content = llm_request.contents[i]
# Only process the user message.
if content.role != 'user' and not content.parts:
if content.role != 'user' or not content.parts:
continue

for j in range(len(content.parts)):
part = content.parts[j]

# Skip if already processed (already a placeholder)
if part.text and _AVAILABLE_FILE_PREFIX in part.text:
continue

# Skip if the inline data is not supported.
if (
not part.inline_data
Expand All @@ -399,21 +423,76 @@ def _extract_and_replace_inline_files(
# Replace the inline data file with a file name placeholder.
mime_type = part.inline_data.mime_type
file_name = f'data_{i+1}_{j+1}' + _DATA_FILE_UTIL_MAP[mime_type].extension
llm_request.contents[i].parts[j] = types.Part(
text='\nAvailable file: `%s`\n' % file_name
)
placeholder_text = f'\n{_AVAILABLE_FILE_PREFIX} `{file_name}`\n'

# Store inline_data before replacing
inline_data_copy = part.inline_data

# Replace in llm_request
llm_request.contents[i].parts[j] = types.Part(text=placeholder_text)

# Find and update the corresponding session event
# to persist the replacement across turns
session = invocation_context.session
for event_idx, event in enumerate(session.events):
if (
event.content
and event.content.role == 'user'
and len(event.content.parts) > j
):
event_part = event.content.parts[j]
# Match by inline_data content (comparing mime_type, length, and data)
# Length check first for performance optimization
if (
event_part.inline_data
and event_part.inline_data.mime_type == mime_type
and len(event_part.inline_data.data) == len(inline_data_copy.data)
and event_part.inline_data.data == inline_data_copy.data
):
# Mark this event/part for update
if event_idx not in events_to_update:
events_to_update[event_idx] = {}
events_to_update[event_idx][j] = placeholder_text
break

# Add the inline data as input file to the code executor context.
file = File(
name=file_name,
content=CodeExecutionUtils.get_encoded_file_content(
part.inline_data.data
inline_data_copy.data
).decode(),
mime_type=mime_type,
)
if file_name not in saved_file_names:
code_executor_context.add_input_files([file])
all_input_files.append(file)
saved_file_names.add(file_name)

# Apply updates to session.events to persist across turns
session = invocation_context.session
for event_idx, parts_to_update in events_to_update.items():
event = session.events[event_idx]
# Create new parts list with replacements
updated_parts = list(event.content.parts)
for part_idx, placeholder_text in parts_to_update.items():
updated_parts[part_idx] = types.Part(text=placeholder_text)

# Create new content with updated parts
updated_content = types.Content(
role=event.content.role, parts=updated_parts
)

# Update the event in session (modify in place)
# Event is a Pydantic model, use model_copy() instead of dataclasses.replace()
session.events[event_idx] = event.model_copy(
update={'content': updated_content}
)

logger.debug(
'Replaced inline_data in session.events[%d] with placeholder: %s',
event_idx,
placeholder_text.strip(),
)

return all_input_files

Expand Down
Loading