from dataclasses import dataclass
from typing import TYPE_CHECKING, Protocol

if TYPE_CHECKING:
    # Heavy third-party imports are only needed for static typing;
    # defer them so importing this module stays cheap.
    from langchain_core.language_models.chat_models import BaseChatModel
    from langchain_core.messages import ToolMessage


class Conversation(Protocol):
    """Message store the agent reads from and appends to."""

    def append_messages(self, messages): ...

    def append_error(self, error): ...

    def get_messages(self): ...


class Toolset(Protocol):
    """Collection of tools offered to the model.

    The tool specs need to be OpenAI compatible.
    """

    @property
    def tools(self): ...

    def execute(self, tool_calls) -> "list[ToolMessage]": ...


class Context(Protocol):
    """Execution context that can signal cancellation."""

    def is_canceled(self) -> bool: ...


@dataclass
class AgentConfig:
    """Tunables for a single agent run."""

    # Maximum number of model round-trips before run() gives up.
    iteration_limit: int = 10


class IterationLimitError(RuntimeError):
    """Raised when the agent exhausts its iteration budget."""


class CancelError(RuntimeError):
    """Raised when the surrounding context cancels execution."""


class Agent:
    """Minimal tool-calling loop around a chat model.

    The agent repeatedly invokes the model on the conversation, executes
    any tool calls the model requests, and stops once the model replies
    without tool calls.
    """

    def __init__(
        self,
        conversation: Conversation,
        llm: "BaseChatModel",
        toolset: Toolset,
        config: AgentConfig,
    ):
        tools = toolset.tools
        # Only bind tools when there are any; some providers reject an
        # empty tool list.
        self._agent = llm.bind_tools(tools) if tools else llm
        self._conversation = conversation
        self._config = config
        self._toolset = toolset

    def run(self):
        """Run the agent's turn in the conversation.

        Model or tool failures are recorded on the conversation via
        ``append_error`` instead of propagating, so a failed turn still
        leaves the conversation in a consistent state.

        Raises:
            IterationLimitError: if the loop does not terminate within
                ``config.iteration_limit`` iterations.
        """
        for _ in range(self._config.iteration_limit):
            try:
                response = self._agent.invoke(self._conversation.get_messages())
            except Exception as error:
                self._conversation.append_error(error)
                return

            self._conversation.append_messages(response)

            # Non-AI message types carry no tool_calls attribute; treat
            # such a response as a final answer instead of crashing.
            tool_calls = getattr(response, "tool_calls", None)
            if not tool_calls:
                return

            try:
                results = self._toolset.execute(tool_calls)
            except Exception as error:
                self._conversation.append_error(error)
                return

            self._conversation.append_messages(results)

        raise IterationLimitError("Reached iteration limit")
str + conversation_column_name: str + recursion_limit_handling: str + show_tool_calls_and_results: bool + reexecution_trigger: str + has_error_column: bool + error_column_name: str class AgentChatWidgetDataService: def __init__( self, ctx, - agent_graph, + chat_model, + conversation: "AgentPrompterConversation", + toolset, + agent_config, data_registry: DataRegistry, - initial_message: str, - conversation_column_name: str, - recursion_limit: int, - recursion_limit_handling: str, - show_tool_calls_and_results: bool, - reexecution_trigger: str, + widget_config: AgentChatWidgetConfig, tool_converter: LangchainToolConverter, combined_tools_workflow_info: dict, ): self._ctx = ctx - self._agent_graph = agent_graph + + self._chat_model = chat_model + self._conversation = FrontendConversation( + conversation, tool_converter, self._check_canceled + ) + self._agent_config = agent_config + self._agent = Agent( + self._conversation, self._chat_model, toolset, self._agent_config + ) + self._data_registry = data_registry - self._tool_converter = tool_converter - self._conversation_column_name = conversation_column_name - self._initial_message = initial_message - self._recursion_limit = recursion_limit - self._recursion_limit_handling = recursion_limit_handling - self._show_tool_calls_and_results = show_tool_calls_and_results - self._reexecution_trigger = reexecution_trigger + self._widget_config = widget_config self._get_combined_tools_workflow_info = combined_tools_workflow_info - self._message_queue = queue.Queue() + self._message_queue = self._conversation.frontend self._thread = None - - @property - def _config(self): - return { - "recursion_limit": self._recursion_limit, - "configurable": {"thread_id": "1"}, - } - - @property - def _messages(self): - return self._agent_graph.get_state(self._config).values.get("messages", []) + self._is_canceled = False # TODO modified by frontend def get_initial_message(self): - if self._initial_message: + if 
self._widget_config.initial_message: return { "type": "ai", - "content": self._initial_message, + "content": self._widget_config.initial_message, } def post_user_message(self, user_message: str): @@ -146,8 +150,8 @@ def is_processing(self): def get_configuration(self): return { - "show_tool_calls_and_results": self._show_tool_calls_and_results, - "reexecution_trigger": self._reexecution_trigger, + "show_tool_calls_and_results": self._widget_config.show_tool_calls_and_results, + "reexecution_trigger": self._widget_config.reexecution_trigger, } def get_combined_tools_workflow_info(self): @@ -155,12 +159,15 @@ def get_combined_tools_workflow_info(self): # called by java, not the frontend def get_view_data(self): - desanitized_messages = [ - self._tool_converter.desanitize_tool_names(msg) for msg in self._messages - ] - message_values = [from_langchain_message(msg) for msg in desanitized_messages] - conversation_df = pd.DataFrame({self._conversation_column_name: message_values}) - conversation_table = knext.Table.from_pandas(conversation_df) + error_column_name = ( + self._widget_config.error_column_name + if self._widget_config.has_error_column + else None + ) + conversation_table = self._conversation.create_output_table( + self._widget_config.conversation_column_name, + error_column_name, + ) meta_data, tables = self._data_registry.dump() view_data = { @@ -171,42 +178,129 @@ def get_view_data(self): return view_data def _post_user_message(self, user_message: str): - from langgraph.errors import GraphRecursionError + from langchain_core.messages import AIMessage + + self._conversation.append_messages_to_backend( + HumanMessage(content=user_message) + ) try: - state_stream = self._agent_graph.stream( - {"messages": [HumanMessage(content=user_message)]}, - self._config, - stream_mode="updates", # streams state by state incrementally + self._agent.run() + + messages = self._conversation.get_messages() + if messages and isinstance(messages[-1], AIMessage): + 
validate_ai_message(messages[-1]) + + except CancelError as e: + raise e + except IterationLimitError: + self._handle_recursion_limit_error() + except Exception as e: + self._conversation.append_error(e) + + def _handle_recursion_limit_error(self): + from langchain_core.messages import AIMessage + + if ( + self._widget_config.recursion_limit_handling + == RecursionLimitModeForView.CONFIRM.name + ): + messages = [AIMessage(RECURSION_CONTINUE_PROMPT)] + self._conversation._append_messages(messages) + else: + content = ( + f"Recursion limit of {self._agent_config.iteration_limit} reached." ) + self._conversation.append_error(Exception(content)) - final_state = None - for state in state_stream: - final_state = state["agent"] if "agent" in state else state["tools"] - new_messages = final_state["messages"] + def _check_canceled(self): + return self._is_canceled - for new_message in new_messages: - # already added - if isinstance(new_message, HumanMessage): - continue - if new_message.content != LANGGRAPH_RECURSION_MESSAGE: - fe_messages = self._to_frontend_messages(new_message) - for fe_msg in fe_messages: - self._message_queue.put(fe_msg) - else: - raise RecursionError("Recursion limit was reached.") +class FrontendConversation: + def __init__( + self, + backend: "AgentPrompterConversation", + tool_converter, + check_canceled, + ): + self._frontend = queue.Queue() + self._backend_messages = backend + self._tool_converter = tool_converter + self._check_canceled = check_canceled - if final_state and final_state["messages"]: - validate_ai_message(self._messages[-1]) + @property + def frontend(self): + return self._frontend - except (GraphRecursionError, RecursionError): - self._handle_recursion_limit_error() - except Exception as e: - content = f"An error occurred: {e}" - error_message = {"type": "error", "content": content} - self._message_queue.put(error_message) - self._append_ai_message_to_memory(content) + def append_messages(self, messages): + """Appends messages to 
both backend and frontend. + Raises a CancelError and sanitizes the final message if the context was canceled.""" + from langchain_core.messages import AIMessage, BaseMessage + + if isinstance(messages, BaseMessage): + messages = [messages] + + if self._check_canceled and self._check_canceled(): + self._append_messages(messages[:-1]) + + # sanitize last message + final_message = messages[-1] + if not (isinstance(final_message, AIMessage) and final_message.tool_calls): + self._append_messages([final_message]) + + error = CancelError("Execution canceled.") + self.append_error_to_frontend(error) + raise error + else: + self._append_messages(messages) + + def _append_messages(self, messages): + """Appends messages to both backend and frontend without checking for cancellation.""" + from langchain_core.messages import HumanMessage + + # will not raise since backend has no context + self._backend_messages.append_messages(messages) + + for new_message in messages: + if isinstance(new_message, HumanMessage): + continue + + fe_messages = self._to_frontend_messages(new_message) + for fe_msg in fe_messages: + self._frontend.put(fe_msg) + + def append_messages_to_backend(self, messages): + """Appends messages only to the backend conversation.""" + # will not raise since backend has no context + self._backend_messages.append_messages(messages) + + def append_error(self, error: Exception): + """Appends an error to both backend and frontend.""" + self._backend_messages.append_error(error) + self.append_error_to_frontend(error) + + def append_error_to_frontend(self, error: Exception): + """Appends an error only to the frontend.""" + if not isinstance(error, CancelError): + content = f"An error occurred: {error}" + else: + content = str(error) + + error_message = {"type": "error", "content": content} + self._frontend.put(error_message) + + def get_messages(self): + return self._backend_messages.get_messages() + + def create_output_table( + self, + output_column_name: str, + 
error_column_name: str = None, + ) -> knext.Table: + return self._backend_messages.create_output_table( + self._tool_converter, output_column_name, error_column_name + ) def _to_frontend_messages(self, message): # split the node-view-ids out into a separate message @@ -258,28 +352,3 @@ def _render_tool_call(self, tool_call): "name": self._tool_converter.desanitize_tool_name(tool_call["name"]), "args": yaml.dump(args, indent=2) if args else None, } - - def _handle_recursion_limit_error(self): - if self._recursion_limit_handling == RecursionLimitModeForView.CONFIRM.name: - message = { - "type": "ai", - "content": RECURSION_CONTINUE_PROMPT, - } - self._message_queue.put(message) - self._append_ai_message_to_memory(RECURSION_CONTINUE_PROMPT) - else: - content = f"Recursion limit of {self._recursion_limit} reached." - error_message = { - "type": "error", - "content": content, - } - self._message_queue.put(error_message) - self._append_ai_message_to_memory(content) - - def _append_ai_message_to_memory(self, message: str): - from langchain_core import messages as lcm - - ai_message = lcm.AIMessage(message) - self._agent_graph.update_state( - self._config, {"messages": [ai_message]}, "agent" - ) diff --git a/src/agents/_error_handler.py b/src/agents/_error_handler.py new file mode 100644 index 00000000..7aef7ded --- /dev/null +++ b/src/agents/_error_handler.py @@ -0,0 +1,156 @@ +# -*- coding: utf-8 -*- +# ------------------------------------------------------------------------ +# Copyright by KNIME AG, Zurich, Switzerland +# Website: http://www.knime.com; Email: contact@knime.com +# +# This program is free software; you can redistribute it and/or modify +# it under the terms of the GNU General Public License, Version 3, as +# published by the Free Software Foundation. +# +# This program is distributed in the hope that it will be useful, but +# WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 
See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program; if not, see . +# +# Additional permission under GNU GPL version 3 section 7: +# +# KNIME interoperates with ECLIPSE solely via ECLIPSE's plug-in APIs. +# Hence, KNIME and ECLIPSE are both independent programs and are not +# derived from each other. Should, however, the interpretation of the +# GNU GPL Version 3 ("License") under any applicable laws result in +# KNIME and ECLIPSE being a combined program, KNIME AG herewith grants +# you the additional permission to use and propagate KNIME together with +# ECLIPSE with only the license terms in place for ECLIPSE applying to +# ECLIPSE and the GNU GPL Version 3 applying for KNIME, provided the +# license terms of ECLIPSE themselves allow for the respective use and +# propagation of ECLIPSE together with KNIME. +# +# Additional permission relating to nodes for KNIME that extend the Node +# Extension (and in particular that are based on subclasses of NodeModel, +# NodeDialog, and NodeView) and that only interoperate with KNIME through +# standard APIs ("Nodes"): +# Nodes are deemed to be separate and independent programs and to not be +# covered works. Notwithstanding anything to the contrary in the +# License, the License does not apply to Nodes, you are not required to +# license Nodes under the License, and you are granted a license to +# prepare and propagate Nodes, in each case even if such Nodes are +# propagated with or for interoperation with KNIME. The owner of a Node +# may freely choose the license terms applicable to such Node, including +# when such Node is propagated with or for interoperation with KNIME. +# ------------------------------------------------------------------------ + +""" +Error handling logic for Agent Prompter node execution. 
+""" + +import knime.extension as knext +from ._parameters import RecursionLimitMode, ErrorHandlingMode +from ._agent import IterationLimitError, CancelError + + +class AgentPrompterErrorHandler: + """ + Handles error conditions during agent execution based on configured modes. + + This class encapsulates the error handling logic for the Agent Prompter node, + including handling of iteration limit errors and general exceptions based on + configured error handling modes. + """ + + def __init__( + self, + conversation, + recursion_limit: int, + recursion_limit_handling: RecursionLimitMode, + error_handling: ErrorHandlingMode, + chat_model, + recursion_limit_prompt: str, + ): + """ + Initialize the error handler. + + Args: + conversation: The AgentPrompterConversation instance + recursion_limit: Maximum number of agent iterations + recursion_limit_handling: How to handle recursion limit errors + error_handling: How to handle general errors + chat_model: The chat model for generating final responses + recursion_limit_prompt: Prompt to use when generating final response + """ + self._conversation = conversation + self._recursion_limit = recursion_limit + self._recursion_limit_handling = recursion_limit_handling + self._error_handling = error_handling + self._chat_model = chat_model + self._recursion_limit_prompt = recursion_limit_prompt + + def handle_error(self, exception: Exception) -> None: + """ + Handle an exception based on its type and configured modes. + + Args: + exception: The exception that was caught + """ + if isinstance(exception, CancelError): + raise exception + elif isinstance(exception, IterationLimitError): + self._handle_iteration_limit_error() + else: + self._handle_general_error(exception) + + def _handle_iteration_limit_error(self) -> None: + """ + Handle an IterationLimitError based on the configured recursion limit mode. + + If mode is FINAL_RESPONSE, generates a final response using the chat model. 
+ Otherwise, creates an error message and delegates to error mode handling. + """ + if self._recursion_limit_handling == RecursionLimitMode.FINAL_RESPONSE: + self._generate_final_response() + else: + error_message = f"""Recursion limit of {self._recursion_limit} reached. + You can increase the limit by setting the `recursion_limit` parameter to a higher value.""" + self._handle_error_by_mode(error_message) + + def _handle_general_error(self, exception: Exception) -> None: + """ + Handle a general exception based on the configured error handling mode. + + Args: + exception: The exception that was caught + """ + error_message = f"An error occurred while executing the agent: {exception}" + self._handle_error_by_mode(error_message) + + def _handle_error_by_mode(self, error_message: str) -> None: + """ + Handle an error based on the configured error handling mode. + + Args: + error_message: The error message to handle + + Raises: + knext.InvalidParametersError: If error_handling mode is FAIL + """ + if self._error_handling == ErrorHandlingMode.FAIL: + raise knext.InvalidParametersError(error_message) + else: + self._conversation.append_error(Exception(error_message)) + + def _generate_final_response(self) -> None: + """ + Generate a final response when the recursion limit is reached. + + This method appends a human message with the recursion limit prompt, + invokes the chat model to generate a final response, and appends + that response to the conversation. 
+ """ + import langchain_core.messages as lcm + + messages = self._conversation.get_messages() + messages = messages + [lcm.HumanMessage(self._recursion_limit_prompt)] + final_response = self._chat_model.invoke(messages) + self._conversation.append_messages(final_response) diff --git a/src/agents/_parameters.py b/src/agents/_parameters.py index 1c239da9..f6a43e40 100644 --- a/src/agents/_parameters.py +++ b/src/agents/_parameters.py @@ -46,6 +46,31 @@ import knime_extension as knext +class RecursionLimitMode(knext.EnumParameterOptions): + FAIL = ( + "Fail", + "Execution fails if the recursion limit is reached.", + ) + FINAL_RESPONSE = ( + "Final response", + "A user message is appended to the conversation that prompts the LLM to generate a final response " + "based on all previously generated messages.", + ) + + +class ErrorHandlingMode(knext.EnumParameterOptions): + FAIL = ( + "Fail", + "Execution fails if there is an error.", + ) + COLUMN = ( + "Error column", + "Execution is stopped if there is an error. The output table will contain an additional string " + "column containing error messages. 
If there was an error, the corresponding row will contain a " + "missing value in the conversation column.", + ) + + class RecursionLimitModeForView(knext.EnumParameterOptions): FAIL = ( "Fail", @@ -60,7 +85,7 @@ class RecursionLimitModeForView(knext.EnumParameterOptions): def recursion_limit_mode_param_for_view(): return knext.EnumParameter( - "Recursion limit handling", + "If recursion limit is reached", "Specify how the agent should behave when the recursion limit is reached.", RecursionLimitModeForView.FAIL.name, RecursionLimitModeForView, diff --git a/src/agents/base.py b/src/agents/base.py index a3308ef1..69143884 100644 --- a/src/agents/base.py +++ b/src/agents/base.py @@ -70,11 +70,17 @@ FilestorePortObject, ) from base import AIPortObjectSpec -from ._parameters import recursion_limit_mode_param_for_view -from ._agent import RecursionError +from ._parameters import ( + recursion_limit_mode_param_for_view, + RecursionLimitMode, + ErrorHandlingMode, +) +from ._agent import CancelError +from ._error_handler import AgentPrompterErrorHandler import os import logging +from langchain_core.messages import BaseMessage _logger = logging.getLogger(__name__) @@ -107,23 +113,6 @@ class ChatMessageSettings: ) -class RecursionLimitMode(knext.EnumParameterOptions): - FAIL = ( - "Fail", - "Execution fails if the recursion limit is reached.", - ) - STOP = ( - "Stop", - "Execution is stopped if the recursion limit is reached. The output will contain the messages " - "generated before reaching the limit.", - ) - FINAL_RESPONSE = ( - "Final response", - "A user message is appended to the conversation that prompts the LLM to generate a final response " - "based on all previously generated messages.", - ) - - # TODO: Add agent type in the future? 
class AgentPortObjectSpec(AIPortObjectSpec): def __init__( @@ -457,6 +446,75 @@ def _data_message_prefix_parameter(): # endregion +@knext.parameter_group( + label="Error Handling Settings", is_advanced=True, since_version="5.10.0" +) +class AgentPrompterErrorSettings: + error_handling = knext.EnumParameter( + "Error handling", + "Specify the behavior of the agent when an error occurs.", + ErrorHandlingMode.FAIL.name, + ErrorHandlingMode, + style=knext.EnumParameter.Style.VALUE_SWITCH, + ) + + use_existing_error_column = knext.BoolParameter( + "Continue existing error column", + "If selected, the output table will continue the error column selected from the input conversation table.", + default_value=False, + ).rule( + knext.And( + knext.DialogContextCondition(has_conversation_table), + knext.OneOf( + error_handling, + [ErrorHandlingMode.COLUMN.name], + ), + ), + knext.Effect.SHOW, + ) + + error_column_name = knext.StringParameter( + "Error column name", + "Name of the newly generated error column.", + default_value="Errors", + ).rule( + knext.Or( + knext.OneOf( + error_handling, + [ErrorHandlingMode.FAIL.name], + ), + knext.And( + knext.OneOf( + use_existing_error_column, + [True], + ), + knext.DialogContextCondition(has_conversation_table), + ), + ), + knext.Effect.HIDE, + ) + + error_column = knext.ColumnParameter( + "Error column", + "The column containing the errors if a conversation history table is connected.", + port_index=2, + column_filter=util.create_type_filter(knext.string()), + ).rule( + knext.And( + knext.DialogContextCondition(has_conversation_table), + knext.OneOf( + error_handling, + [ErrorHandlingMode.COLUMN.name], + ), + knext.OneOf( + use_existing_error_column, + [True], + ), + ), + knext.Effect.SHOW, + ) + + # region Agent Prompter 2.0 @knext.node( "Agent Prompter", @@ -524,7 +582,7 @@ class AgentPrompter2: recursion_limit = _recursion_limit_parameter() recursion_limit_handling = knext.EnumParameter( - "Recursion limit handling", + "When 
recursion limit is reached", "Specify the behavior of the agent when the recursion limit is reached.", RecursionLimitMode.FAIL.name, RecursionLimitMode, @@ -556,6 +614,8 @@ class AgentPrompter2: data_message_prefix = _data_message_prefix_parameter() + errors = AgentPrompterErrorSettings() + def configure( self, ctx: knext.ConfigurationContext, @@ -566,6 +626,7 @@ def configure( ) -> knext.Schema: chat_model_spec.validate_context(ctx) self._configure_tool_tables(tools_schema) + self._check_column_names(history_schema) return self._create_conversation_schema(history_schema), [ None @@ -579,9 +640,7 @@ def _configure_tool_tables(self, tools_schema: knext.Schema): f"Column {self.tool_column} not found in the tools table." ) - def _create_conversation_schema( - self, history_schema: Optional[knext.Schema] - ) -> knext.Schema: + def _check_column_names(self, history_schema: Optional[knext.Schema]): if history_schema is not None: if self.conversation_column is None: self.conversation_column = _last_history_column(history_schema) @@ -589,15 +648,44 @@ def _create_conversation_schema( raise knext.InvalidParametersError( f"Column {self.conversation_column} not found in the conversation history table." ) - return knext.Schema.from_columns( - [ - knext.Column(_message_type(), self.conversation_column), - ] - ) - # Use user-provided column name if no history table is given - return knext.Schema.from_columns( - [knext.Column(_message_type(), self.conversation_column_name)] - ) + + if self.errors.error_handling == ErrorHandlingMode.COLUMN.name: + if history_schema is not None and self.errors.use_existing_error_column: + if self.errors.error_column is None: + self.errors.error_column = util.pick_default_column( + history_schema, knext.string() + ) + if self.errors.error_column not in history_schema.column_names: + raise knext.InvalidParametersError( + f"Column {self.errors.error_column} not found in the conversation history table." 
+ ) + if self.conversation_column == self.errors.error_column: + raise knext.InvalidParametersError( + "The selected conversation and error columns must not be equal." + ) + else: + if self.conversation_column_name == self.errors.error_column_name: + raise knext.InvalidParametersError( + "The specified conversation and error column names must not be equal." + ) + + def _create_conversation_schema( + self, history_schema: Optional[knext.Schema] + ) -> knext.Schema: + if history_schema is not None: + convo_column_name = self.conversation_column + else: + convo_column_name = self.conversation_column_name + columns = [knext.Column(_message_type(), convo_column_name)] + + if self.errors.error_handling == ErrorHandlingMode.COLUMN.name: + if history_schema is not None and self.errors.use_existing_error_column: + error_column_name = self.errors.error_column + else: + error_column_name = self.errors.error_column_name + columns.append(knext.Column(knext.string(), error_column_name)) + + return knext.Schema.from_columns(columns) def execute( self, @@ -608,18 +696,12 @@ def execute( input_tables: list[knext.Table], ): from langchain.chat_models.base import BaseChatModel - import pandas as pd from ._data_service import ( DataRegistry, LangchainToolConverter, ) from ._tool import ExecutionMode - from ._agent import validate_ai_message - from langgraph.prebuilt import create_react_agent - from knime.types.message import to_langchain_message, from_langchain_message - from langgraph.checkpoint.memory import InMemorySaver - from util import check_canceled - import langchain_core.messages as lcm + from ._agent import Agent, AgentConfig data_registry = DataRegistry.create_with_input_tables( input_tables, data_message_prefix=self.data_message_prefix @@ -635,150 +717,88 @@ def execute( ) tool_cells = _extract_tools_from_table(tools_table, self.tool_column) - tools = [tool_converter.to_langchain_tool(tool) for tool in tool_cells] + toolset = AgentPrompterToolset(tools) - messages = [] - 
- if history_table is not None: - if self.conversation_column not in history_table.column_names: - raise knext.InvalidParametersError( - f"Column {self.conversation_column} not found in the conversation history table." - ) - history_df = history_table[self.conversation_column].to_pandas() - messages = [] - for msg in history_df[self.conversation_column]: - lc_msg = to_langchain_message(msg) - # Sanitize tool names so they match the current sanitized mapping - lc_msg = tool_converter.sanitize_tool_names(lc_msg) - messages.append(lc_msg) - - if data_registry.has_data or tool_converter.has_data_tools: - messages.append(data_registry.create_data_message()) - - if self.user_message: - messages.append({"role": "user", "content": self.user_message}) - - num_data_outputs = ctx.get_connected_output_port_numbers()[1] + conversation = self._create_conversation_history( + ctx, data_registry, tool_converter, history_table + ) - interrupted_nodes = ["agent"] - if tools: - interrupted_nodes.append("tools") + config = AgentConfig(self.recursion_limit) + agent = Agent(conversation, chat_model, toolset, config) - memory = InMemorySaver() - graph = create_react_agent( - chat_model, - tools=tools, - prompt=self.developer_message, - checkpointer=memory, - interrupt_before=interrupted_nodes, + error_handler = AgentPrompterErrorHandler( + conversation=conversation, + recursion_limit=self.recursion_limit, + recursion_limit_handling=RecursionLimitMode[self.recursion_limit_handling], + error_handling=ErrorHandlingMode[self.errors.error_handling], + chat_model=chat_model, + recursion_limit_prompt=self.recursion_limit_prompt, ) - inputs = {"messages": messages} - config = { - "recursion_limit": self.recursion_limit, - "configurable": {"thread_id": 1}, - } - try: - final_state = graph.invoke( - inputs, - config=config, - ) - recursion_counter = 1 - while True: - snap = graph.get_state(config) - if not snap.next: # agent finished - break - if recursion_counter >= self.recursion_limit: - if 
new_state := self._check_recursion_limit(chat_model, snap): - final_state = new_state - break - check_canceled(ctx) - final_state = graph.invoke( - None, - config=config, - ) - recursion_counter += 1 - except RecursionError: - raise knext.InvalidParametersError( - f"""Recursion limit of {self.recursion_limit} reached. - You can increase the limit by setting the `recursion_limit` parameter to a higher value.""" - ) + agent.run() except Exception as e: - raise knext.InvalidParametersError( - f"An error occurred while executing the agent: {e}" - ) + error_handler.handle_error(e) - messages = final_state["messages"] + return self._construct_outputs( + history_table, conversation, tool_converter, data_registry, ctx + ) - if isinstance(messages[-1], lcm.AIMessage): - try: - validate_ai_message(messages[-1]) - except Exception as e: - ctx.set_warning(str(e)) + def _create_conversation_history( + self, ctx, data_registry, tool_converter, history_table + ) -> "AgentPrompterConversation": + error_column = ( + self.errors.error_column + if self.errors.error_handling == ErrorHandlingMode.COLUMN.name + and self.errors.use_existing_error_column + else None + ) - desanitized_messages = [ - tool_converter.desanitize_tool_names(msg) for msg in messages - ] + return _create_agent_conversation( + ctx=ctx, + developer_message=self.developer_message, + data_registry=data_registry, + tool_converter=tool_converter, + history_table=history_table, + conversation_column=self.conversation_column, + error_column=error_column, + error_handling=self.errors.error_handling, + user_message=self.user_message, + ) + def _construct_outputs( + self, + history_table, + conversation, + tool_converter, + data_registry, + ctx, + ): output_column_name = ( self.conversation_column_name if history_table is None else self.conversation_column ) + error_column_name = None + if self.errors.error_handling == ErrorHandlingMode.COLUMN.name: + error_column_name = ( + self.errors.error_column + if 
self.errors.use_existing_error_column and history_table is not None + else self.errors.error_column_name + ) - result_df = pd.DataFrame( - { - output_column_name: [ - from_langchain_message(msg) for msg in desanitized_messages - ] - } + conversation_table = conversation.create_output_table( + tool_converter, output_column_name, error_column_name ) - conversation_table = knext.Table.from_pandas(result_df) - + num_data_outputs = ctx.get_connected_output_port_numbers()[1] if num_data_outputs == 0: return conversation_table else: # allow the model to pick which output tables to return return conversation_table, data_registry.get_last_tables(num_data_outputs) - def _check_recursion_limit(self, chat_model, snap): - """Depending on the selected handling option returns a new final state or None, or raises an error - that should be caught.""" - if self.recursion_limit_handling == RecursionLimitMode.FINAL_RESPONSE.name: - import langchain_core.messages as lcm - - messages = snap.values.get("messages") - if hasattr(messages[-1], "tool_calls"): - tool_calls = messages[-1].tool_calls - previous_content = ( - f"The content of the original message was: {messages[-1].content}." - if messages[-1].content - else "" - ) - messages = messages[:-1] + [ - lcm.HumanMessage( - "I deleted the previous AI message from the conversation because the agent reached " - "the recursion limit and tried to make tool calls. " - f"The tools it called were named: {', '.join([x.get('name', '') for x in tool_calls])}. " - + previous_content - + "\n" - + self.recursion_limit_prompt - ) - ] - else: - messages = messages + [lcm.HumanMessage(self.recursion_limit_prompt)] - final_response = chat_model.invoke([self.developer_message] + messages) - return {"messages": messages + [final_response]} - elif self.recursion_limit_handling == RecursionLimitMode.FAIL.name: - raise RecursionError( - "Recursion limit was reached." 
def _populate_conversation_from_history(
    conversation: "AgentPrompterConversation",
    history_table: knext.Table,
    conversation_column: str,
    error_column: Optional[str],
    tool_converter,
) -> bool:
    """Load prior messages (and, optionally, errors) from a history table.

    Each non-missing cell of ``conversation_column`` is converted to a
    LangChain message, its tool names are sanitized via ``tool_converter``,
    and it is appended to ``conversation``. When ``error_column`` is given,
    non-missing error cells are appended as errors instead; a row carrying
    both a message and an error, or neither, raises a ``RuntimeError``.

    Returns:
        True if at least one message was successfully loaded.
    """
    from knime.types.message import to_langchain_message
    import pandas as pd

    def append_sanitized(msg):
        # KNIME message -> LangChain message, then sanitize tool names so
        # they match the names the tool converter registered with the LLM.
        lc_msg = to_langchain_message(msg)
        lc_msg = tool_converter.sanitize_tool_names(lc_msg)
        conversation.append_messages(lc_msg)

    has_history_messages = False

    if error_column is not None:
        history_df = history_table[[conversation_column, error_column]].to_pandas()
        for idx, (msg, err) in enumerate(
            history_df[[conversation_column, error_column]].itertuples(
                index=False, name=None
            )
        ):
            has_msg = pd.notna(msg)
            has_err = pd.notna(err)

            if has_msg and has_err:
                # A row must be either a message or an error, never both.
                row_id = history_df.index[idx]
                raise RuntimeError(
                    f"Conversation table contains row with both message and error. Row ID: {row_id}"
                )
            if has_msg:
                append_sanitized(msg)
                has_history_messages = True
            elif has_err:
                conversation.append_error(Exception(err))
            else:
                # A fully empty row indicates a malformed history table.
                row_id = history_df.index[idx]
                raise RuntimeError(
                    f"Conversation table contains empty row. Row ID: {row_id}"
                )
    else:
        history_df = history_table[conversation_column].to_pandas()
        for msg in history_df[conversation_column]:
            if pd.notna(msg):
                append_sanitized(msg)
                has_history_messages = True

    return has_history_messages


def _create_agent_conversation(
    ctx: Optional[knext.ExecutionContext],
    developer_message: Optional[str],
    data_registry,
    tool_converter,
    history_table: Optional[knext.Table] = None,
    conversation_column: Optional[str] = None,
    error_column: Optional[str] = None,
    error_handling: Optional[str] = None,
    user_message: Optional[str] = None,
    interactive: bool = False,
) -> "AgentPrompterConversation":
    """
    Consolidated orchestrator to initialize an AgentPrompterConversation.
    Used by both AgentPrompter2 and AgentChatWidget.

    Message order: optional system (developer) message, then any history
    loaded from ``history_table``, then the data message (see below), then
    an optional human ``user_message``.
    """
    from langchain_core.messages import SystemMessage, HumanMessage

    conversation = AgentPrompterConversation(error_handling, ctx)
    has_history = False

    if developer_message:
        conversation.append_messages(SystemMessage(developer_message))

    if history_table is not None:
        has_history = _populate_conversation_from_history(
            conversation,
            history_table,
            conversation_column,
            error_column,
            tool_converter,
        )

    # For interactive usage (ChatWidget), we only want to append the data message
    # if it's a new conversation (no history).
    # For one-shot usage (AgentPrompter2), we always append the data message.
    if (not interactive or not has_history) and (
        data_registry.has_data or tool_converter.has_data_tools
    ):
        conversation.append_messages(data_registry.create_data_message())

    if user_message:
        conversation.append_messages(HumanMessage(user_message))

    return conversation


class AgentPrompterConversation:
    """Ordered log of LangChain messages and errors for one agent run.

    Messages and errors share a single list (``_message_and_errors``); the
    parallel boolean list ``_is_message`` records which entries are messages.
    This preserves the interleaving so the output table can show an error in
    the row where it occurred. Implements the ``Conversation`` protocol used
    by ``Agent`` (append_messages / append_error / get_messages).
    """

    def __init__(self, error_handling, ctx: knext.ExecutionContext = None):
        # error_handling: name of an ErrorHandlingMode member (FAIL raises,
        # COLUMN records errors as rows); may be None for interactive use.
        self._error_handling = error_handling
        # Interleaved messages and exceptions, in arrival order.
        self._message_and_errors = []
        # Parallel flags: True where the entry above is a message.
        self._is_message = []
        # Execution context used for cancellation checks and warnings.
        self._ctx = ctx

    def append_messages(self, messages):
        """Append one message or a list of messages, validating AI messages.

        Raises a CancelError if the context was canceled.
        """
        from langchain_core.messages import AIMessage
        from ._agent import validate_ai_message

        # Accept a single message for convenience; normalize to a list.
        # NOTE(review): BaseMessage is assumed to be imported at module
        # level — confirm against the file's import block.
        if isinstance(messages, BaseMessage):
            messages = [messages]

        if self._ctx and self._ctx.is_canceled():
            raise CancelError("Execution canceled.")

        for msg in messages:
            # Validate AI messages as they are added
            if isinstance(msg, AIMessage):
                try:
                    validate_ai_message(msg)
                except Exception as e:
                    if self._error_handling == ErrorHandlingMode.FAIL.name:
                        # For FAIL mode, set warning and don't add the message
                        if self._ctx:
                            self._ctx.set_warning(str(e))
                        continue
                    else:
                        # For COLUMN mode, append the error and skip the message
                        self._append(e)
                        continue
            self._append(msg)

    def append_error(self, error):
        """Record *error*; in FAIL mode re-raise it instead of recording."""
        if not isinstance(error, Exception):
            # Non-Exception values cannot be logged meaningfully; surface
            # them immediately rather than silently storing them.
            raise error

        if self._error_handling == ErrorHandlingMode.FAIL.name:
            raise error
        else:
            self._append(error)

    def get_messages(self) -> list:
        """Return only the messages (errors filtered out), in order."""
        messages = [
            moe
            for is_msg, moe in zip(self._is_message, self._message_and_errors)
            if is_msg
        ]
        return messages

    def _append(self, message_or_error):
        # Keep the two parallel lists in sync; the flag is derived from the
        # entry's type so callers can't desynchronize them.
        self._message_and_errors.append(message_or_error)
        self._is_message.append(isinstance(message_or_error, BaseMessage))

    def _construct_output(self) -> list:
        """Return row dicts with exactly one of 'message'/'error' set."""
        return [
            {
                "message": moe if is_msg else None,
                "error": moe if not is_msg else None,
            }
            for is_msg, moe in zip(self._is_message, self._message_and_errors)
        ]

    def create_output_table(
        self,
        tool_converter,
        output_column_name: str,
        error_column_name: Optional[str] = None,
    ) -> knext.Table:
        """Materialize the conversation as a KNIME table.

        Without ``error_column_name``: one message column (leading system
        message dropped). With it: a message column plus a string error
        column, where each row holds either a message or an error.
        """
        import pandas as pd
        from knime.types.message import from_langchain_message
        from langchain_core.messages import SystemMessage

        def to_knime_message_or_none(msg, tool_converter):
            """Convert a message to KNIME format, or return None if input is None."""
            if msg is None:
                return None
            desanitized = tool_converter.desanitize_tool_names(msg)
            return from_langchain_message(desanitized)

        if error_column_name is None:
            messages = self.get_messages()
            # The system (developer) message is configuration, not
            # conversation output — drop it from the table.
            if messages and isinstance(messages[0], SystemMessage):
                messages = messages[1:]
            result_df = pd.DataFrame(
                {
                    output_column_name: [
                        to_knime_message_or_none(msg, tool_converter)
                        for msg in messages
                    ]
                }
            )
        else:
            messages_and_errors = [
                moe
                for moe in self._construct_output()
                if not isinstance(moe["message"], SystemMessage)
            ]
            messages = [
                to_knime_message_or_none(moe["message"], tool_converter)
                for moe in messages_and_errors
            ]
            errors = [
                str(moe["error"]) if moe["error"] is not None else None
                for moe in messages_and_errors
            ]
            result_df = pd.DataFrame(
                {output_column_name: messages, error_column_name: errors}
            )

            # All-None columns would default to an object dtype; pin the
            # expected column types so downstream schemas stay stable.
            if not any(messages):
                result_df[output_column_name] = result_df[output_column_name].astype(
                    _message_type().to_pandas()
                )
            if not any(errors):
                result_df[error_column_name] = result_df[error_column_name].astype(
                    "string"
                )

        return knext.Table.from_pandas(result_df)


class AgentPrompterToolset:
    """Name-indexed tool collection implementing the ``Toolset`` protocol.

    ``execute`` never raises for a failing or unknown tool: failures are
    reported back to the model as ToolMessages so the agent loop can react.
    """

    def __init__(self, tools):
        # Index tools by name for O(1) dispatch of tool calls.
        self._by_name: dict = {t.name: t for t in tools}

    @property
    def tools(self) -> list:
        """The registered tools, in insertion order."""
        return list(self._by_name.values())

    def execute(self, tool_calls) -> list:
        """Run each tool call and return one ToolMessage per call."""
        from langchain_core.messages import ToolMessage

        results = []
        for tool_call in tool_calls:
            if tool_call["name"] not in self._by_name:
                # Unknown tool: tell the model which tools exist instead of
                # failing the whole run.
                msg = f"Error: Tool '{tool_call['name']}' not found among available tools: {list(self._by_name.keys())}"
                results.append(ToolMessage(msg, tool_call_id=tool_call["id"]))
                continue

            tool = self._by_name[tool_call["name"]]

            try:
                result = tool.invoke(tool_call["args"])
            except Exception as e:
                # Surface tool failures to the model as the tool result.
                result = "Error: " + str(e)

            results.append(
                ToolMessage(
                    result, tool_call_id=tool_call["id"], name=tool_call["name"]
                )
            )
        return results


# endregion
+ ) + columns.append(knext.Column(knext.string(), self.error_column_name)) + return ( None, # combined tools workflow - knext.Schema.from_columns( - [knext.Column(_message_type(), self.conversation_column_name)] - ), + knext.Schema.from_columns(columns), [None] * ctx.get_connected_output_port_numbers()[2], ) @@ -946,6 +1249,7 @@ def execute( ): import pandas as pd from ._data_service import DataRegistry + import pyarrow as pa view_data = ctx._get_view_data() num_data_outputs = ctx.get_connected_output_port_numbers()[2] @@ -963,16 +1267,22 @@ def execute( ) else: message_type = _message_type() - conversation_table = util.create_empty_table( - None, - [ + columns = [ + util.OutputColumn( + self.conversation_column_name, + message_type, + message_type.to_pyarrow(), + ) + ] + if self.has_error_column: + columns.append( util.OutputColumn( - self.conversation_column_name, - message_type, - message_type.to_pyarrow(), + self.error_column_name, + knext.string(), + pa.string(), ) - ], - ) + ) + conversation_table = util.create_empty_table(None, columns) return ( combined_tools_workflow, conversation_table, @@ -989,12 +1299,12 @@ def get_data_service( tools_table: Optional[knext.Table], input_tables: list[knext.Table], ): - from langgraph.prebuilt import create_react_agent - from langgraph.checkpoint.memory import MemorySaver + from ._agent import AgentConfig from ._data_service import ( DataRegistry, LangchainToolConverter, AgentChatWidgetDataService, + AgentChatWidgetConfig, ) from ._tool import ExecutionMode @@ -1032,23 +1342,31 @@ def get_data_service( else: tools = [] - memory = MemorySaver() - agent = create_react_agent( - chat_model, tools=tools, prompt=self.developer_message, checkpointer=memory + conversation = self._create_conversation_history( + view_data, data_registry, tool_converter ) - self._fill_memory_with_messages(agent, view_data, data_registry, tool_converter) + toolset = AgentPrompterToolset(tools) + agent_config = AgentConfig(self.recursion_limit) - 
return AgentChatWidgetDataService( - ctx, - agent, - data_registry, + widget_config = AgentChatWidgetConfig( self.initial_message, self.conversation_column_name, - self.recursion_limit, self.recursion_limit_handling, self.show_tool_calls_and_results, self.reexecution_trigger, + self.has_error_column, + self.error_column_name, + ) + + return AgentChatWidgetDataService( + ctx, + chat_model, + conversation, + toolset, + agent_config, + data_registry, + widget_config, tool_converter, { "project_id": project_id, @@ -1056,41 +1374,20 @@ def get_data_service( }, ) - def _fill_memory_with_messages( - self, agent, view_data, data_registry, tool_converter - ): - config = { - "recursion_limit": self.recursion_limit, - "configurable": {"thread_id": "1"}, - } - previous_messages = [] - - if view_data is not None: - conversation_table = view_data["ports"][0] - if conversation_table is not None: - self._fill_memory_with_previous_messages( - agent, config, conversation_table, tool_converter, previous_messages - ) - - if not previous_messages and ( - data_registry.has_data or tool_converter.has_data_tools - ): - self._fill_memory_with_data_message(agent, config, data_registry) - - def _fill_memory_with_previous_messages( - self, agent, config, conversation_table, tool_converter, previous_messages - ): - from knime.types.message import to_langchain_message - - conversation_df = conversation_table[self.conversation_column_name].to_pandas() - for msg in conversation_df[self.conversation_column_name]: - lc_msg = to_langchain_message(msg) - previous_messages.append(tool_converter.sanitize_tool_names(lc_msg)) - agent.update_state(config, {"messages": previous_messages}, "agent") - - def _fill_memory_with_data_message(self, agent, config, data_registry): - agent.update_state( - config, {"messages": [data_registry.create_data_message()]}, "agent" + def _create_conversation_history(self, view_data, data_registry, tool_converter): + history_table = view_data["ports"][0] if view_data is not 
None else None + error_column = self.error_column_name if self.has_error_column else None + + return _create_agent_conversation( + ctx=None, + developer_message=self.developer_message, + data_registry=data_registry, + tool_converter=tool_converter, + history_table=history_table, + conversation_column=self.conversation_column_name, + error_column=error_column, + error_handling=None, + interactive=True, ) diff --git a/src/agents/base_deprecated.py b/src/agents/base_deprecated.py index 7a3b2c88..b2e15e72 100644 --- a/src/agents/base_deprecated.py +++ b/src/agents/base_deprecated.py @@ -490,7 +490,6 @@ def _post_user_message(self, user_message: str): content = f"An error occurred: {e}" error_message = {"type": "error", "content": content} self._message_queue.put(error_message) - self._append_ai_message_to_memory(content) def _to_frontend_messages(self, message): # split the node-view-ids out into a separate message @@ -558,7 +557,6 @@ def _handle_recursion_limit_error(self): "content": content, } self._message_queue.put(error_message) - self._append_ai_message_to_memory(content) def _append_ai_message_to_memory(self, message: str): from langchain_core import messages as lcm