Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
153 changes: 138 additions & 15 deletions src/google/adk/agents/config_agent_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,27 +14,34 @@

from __future__ import annotations

import contextvars
import importlib
import inspect
import os
from typing import Any
from typing import Dict
from typing import List

from typing_extensions import deprecated
import yaml

from ..features import experimental
from ..features import FeatureName
from ..workflow._base_node import BaseNode
from .agent_config import AgentConfig
from .base_agent import BaseAgent
from .base_agent_config import BaseAgentConfig
from .common_configs import AgentRefConfig
from .common_configs import CodeConfig

_loaded_nodes_cache: contextvars.ContextVar[Dict[str, BaseNode]] = (
contextvars.ContextVar("loaded_nodes_cache")
)


@deprecated("from_config is deprecated and will be removed in future versions.")
@experimental(FeatureName.AGENT_CONFIG)
def from_config(config_path: str) -> BaseAgent:
def from_config(config_path: str) -> BaseNode:
"""Build agent from a configfile path.

Args:
Expand All @@ -49,36 +56,149 @@ def from_config(config_path: str) -> BaseAgent:
ValueError: If agent type is unsupported.
"""
abs_path = os.path.abspath(config_path)

try:
cache = _loaded_nodes_cache.get()
except LookupError:
cache = None

if cache is None:
cache_dict: dict[str, BaseNode] = {}
token = _loaded_nodes_cache.set(cache_dict)
try:
return _from_config(abs_path)
finally:
_loaded_nodes_cache.reset(token)
else:
if abs_path in cache:
return cache[abs_path]
node = _from_config(abs_path)
cache[abs_path] = node
return node


def _from_config(abs_path: str) -> BaseNode:
config = _load_config_from_path(abs_path)
agent_config = config.root

# pylint: disable=unidiomatic-typecheck Needs exact class matching.
if type(agent_config) is BaseAgentConfig:
# Resolve the concrete agent config for user-defined agent classes.
agent_class = _resolve_agent_class(agent_config.agent_class)
agent_class = _resolve_node_class(agent_config.agent_class)

from ..workflow._function_node import FunctionNode
from ..workflow._join_node import JoinNode
from ..workflow._tool_node import _ToolNode

if issubclass(agent_class, (JoinNode, FunctionNode, _ToolNode)):
if issubclass(agent_class, JoinNode):
return agent_class(name=agent_config.name)

elif issubclass(agent_class, FunctionNode):
func_code = agent_config.model_extra.get("func_code")
if not func_code:
raise ValueError(
f"FunctionNode {agent_config.name} configuration is missing"
" 'func_code'"
)

func = resolve_fully_qualified_name(func_code)

kwargs = {
"func": func,
"name": agent_config.name,
}
if "rerun_on_resume" in agent_config.model_extra:
kwargs["rerun_on_resume"] = agent_config.model_extra[
"rerun_on_resume"
]
if "parameter_binding" in agent_config.model_extra:
kwargs["parameter_binding"] = agent_config.model_extra[
"parameter_binding"
]

return agent_class(**kwargs)

elif issubclass(agent_class, _ToolNode):
tool_code = agent_config.model_extra.get("tool_code")
if not tool_code:
raise ValueError(
f"ToolNode {agent_config.name} configuration is missing"
" 'tool_code'"
)

from ..tools.base_tool import BaseTool
from ..tools.function_tool import FunctionTool
from ..tools.tool_configs import ToolArgsConfig
from ..tools.tool_configs import ToolConfig
from .llm_agent import LlmAgent

args = agent_config.model_extra.get("args")
tool_args = ToolArgsConfig.model_validate(args) if args else None
tool_config = ToolConfig(name=tool_code, args=tool_args)

resolved = LlmAgent._resolve_tools([tool_config], abs_path)
if not resolved:
raise ValueError(
f"Failed to resolve tool {tool_code} for ToolNode"
f" {agent_config.name}"
)

tool = resolved[0]
if not isinstance(tool, BaseTool):
if callable(tool):
tool = FunctionTool(tool)
else:
raise ValueError(
f"Resolved tool {tool_code} is neither a BaseTool nor callable"
)

return agent_class(tool=tool, name=agent_config.name)

agent_config = agent_class.config_type.model_validate(
agent_config.model_dump()
)
return agent_class.from_config(agent_config, abs_path)
from_config_fn = getattr(agent_class, "from_config", None)
if not from_config_fn:
raise ValueError(
f"Agent class {agent_class.__name__} does not implement 'from_config'"
)
node: BaseNode = from_config_fn(agent_config, abs_path)
return node
else:
# For built-in agent classes, no need to re-validate.
agent_class = _resolve_agent_class(agent_config.agent_class)
return agent_class.from_config(agent_config, abs_path)
agent_class = _resolve_node_class(agent_config.agent_class)
from_config_fn = getattr(agent_class, "from_config", None)
if not from_config_fn:
raise ValueError(
f"Agent class {agent_class.__name__} does not implement 'from_config'"
)
built_in_node: BaseNode = from_config_fn(agent_config, abs_path)
return built_in_node


def _resolve_agent_class(agent_class: str) -> type[BaseAgent]:
def _resolve_node_class(agent_class: str) -> type[BaseNode]:
"""Resolve the agent class from its fully qualified name."""
agent_class_name = agent_class or "LlmAgent"
if "." not in agent_class_name:
agent_class_name = f"google.adk.agents.{agent_class_name}"
if agent_class_name == "Workflow":
agent_class_name = "google.adk.workflow.Workflow"
elif agent_class_name == "JoinNode":
agent_class_name = "google.adk.workflow.JoinNode"
elif agent_class_name == "FunctionNode":
agent_class_name = "google.adk.workflow.FunctionNode"
elif agent_class_name == "ToolNode":
agent_class_name = "google.adk.workflow._tool_node._ToolNode"
else:
agent_class_name = f"google.adk.agents.{agent_class_name}"

agent_class = resolve_fully_qualified_name(agent_class_name)
if inspect.isclass(agent_class) and issubclass(agent_class, BaseAgent):
if inspect.isclass(agent_class) and issubclass(agent_class, BaseNode):
return agent_class

raise ValueError(
f"Invalid agent class `{agent_class_name}`. It must be a subclass of"
" BaseAgent."
" BaseNode."
)


Expand All @@ -102,7 +222,8 @@ def _load_config_from_path(config_path: str) -> AgentConfig:
with open(config_path, "r", encoding="utf-8") as f:
config_data = yaml.safe_load(f)

return AgentConfig.model_validate(config_data)
config: AgentConfig = AgentConfig.model_validate(config_data)
return config


@experimental(FeatureName.AGENT_CONFIG)
Expand All @@ -118,7 +239,7 @@ def resolve_fully_qualified_name(name: str) -> Any:
@experimental(FeatureName.AGENT_CONFIG)
def resolve_agent_reference(
ref_config: AgentRefConfig, referencing_agent_config_abs_path: str
) -> BaseAgent:
) -> BaseNode:
"""Build an agent from a reference.

Args:
Expand All @@ -131,21 +252,23 @@ def resolve_agent_reference(
"""
if ref_config.config_path:
if os.path.isabs(ref_config.config_path):
return from_config(ref_config.config_path)
node: BaseNode = from_config(ref_config.config_path)
return node
else:
return from_config(
rel_node: BaseNode = from_config(
os.path.join(
os.path.dirname(referencing_agent_config_abs_path),
ref_config.config_path,
)
)
return rel_node
elif ref_config.code:
return _resolve_agent_code_reference(ref_config.code)
else:
raise ValueError("AgentRefConfig must have either 'code' or 'config_path'")


def _resolve_agent_code_reference(code: str) -> Any:
def _resolve_agent_code_reference(code: str) -> BaseNode:
"""Resolve a code reference to an actual agent instance.

Args:
Expand All @@ -167,7 +290,7 @@ def _resolve_agent_code_reference(code: str) -> Any:
if callable(obj):
raise ValueError(f"Invalid agent reference to a callable: {code}")

if not isinstance(obj, BaseAgent):
if not isinstance(obj, BaseNode):
raise ValueError(f"Invalid agent reference to a non-agent instance: {code}")

return obj
Expand Down
6 changes: 4 additions & 2 deletions src/google/adk/runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -591,8 +591,10 @@ async def _drive_root_node():
async for event in agen:
yield event
finally:
await self._cleanup_root_task(task, self.agent.name)
await ic.plugin_manager.run_after_run_callback(invocation_context=ic)
try:
await self._cleanup_root_task(task, self.agent.name)
finally:
await ic.plugin_manager.run_after_run_callback(invocation_context=ic)
if self.app and self.app.events_compaction_config:
logger.debug('Running event compactor.')
from google.adk.apps.compaction import _run_compaction_for_sliding_window
Expand Down
47 changes: 47 additions & 0 deletions src/google/adk/workflow/_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,12 @@
from dataclasses import field
import logging
from typing import Any
from typing import ClassVar
from typing import TYPE_CHECKING

from pydantic import Field

from ..agents.base_agent_config import BaseAgentConfig
from ._base_node import BaseNode
from ._base_node import START
from ._dynamic_node_scheduler import DynamicNodeScheduler
Expand Down Expand Up @@ -154,6 +156,51 @@ class Workflow(BaseNode):
- FINALIZE: collect terminal outputs
"""

config_type: ClassVar[type[BaseAgentConfig]] = BaseAgentConfig

@classmethod
def from_config(
cls,
config: Any,
config_abs_path: str,
) -> Workflow:
"""Creates a Workflow from a config."""
import os

from ..agents.config_agent_utils import from_config as load_agent_config

kwargs = {
'name': config.name,
'description': config.description,
}

if config.model_extra:
config_dir = os.path.dirname(config_abs_path)

def resolve_node_refs(val: Any) -> Any:
if isinstance(val, str):
if val == 'START':
return 'START'
if val.endswith('.yaml') or val.endswith('.yml'):
return load_agent_config(os.path.join(config_dir, val))
return val
if isinstance(val, list):
# Convert inner edge chains/tuples to tuples as required by EdgeItem
return tuple(resolve_node_refs(x) for x in val)
if isinstance(val, dict):
return {k: resolve_node_refs(v) for k, v in val.items()}
if isinstance(val, tuple):
return tuple(resolve_node_refs(x) for x in val)
return val

for key, value in config.model_extra.items():
if key == 'edges' and isinstance(value, list):
kwargs['edges'] = [resolve_node_refs(item) for item in value]
elif key in cls.model_fields and key not in kwargs:
kwargs[key] = value

return cls(**kwargs)

rerun_on_resume: bool = Field(default=True)

edges: list[EdgeItem] = Field(
Expand Down
Loading
Loading