diff --git a/python-agent-tools/custom-agent-tool/tool.json b/python-agent-tools/custom-agent-tool/tool.json new file mode 100644 index 0000000..a7b2235 --- /dev/null +++ b/python-agent-tools/custom-agent-tool/tool.json @@ -0,0 +1,308 @@ +{ + "meta" : { + "label": "API Connect agent tool", + "description": "Give LLM access to any Rest API", + "icon": "icon-rocket" + }, + "params": [ + { + "type": "SEPARATOR", + "label": "Tool description" + }, + { + "name": "tool_description", + "label": "Descriptors", + "type": "STRING" + }, + { + "name": "tool_parameters", + "label": "Parameters", + "type": "OBJECT_LIST", + "subParams": [ + { + "name": "tool_parameter_name", + "label": "Name", + "type": "STRING" + }, + { + "name": "tool_parameter_type", + "label": "Type", + "type": "SELECT", + "selectChoices":[ + {"value": "string", "label": "String"}, + {"value": "integer", "label": "Int"} + ] + }, + { + "name": "tool_parameter_description", + "label": "Description", + "type": "STRING" + }, + { + "name": "tool_parameter_is_required", + "label": "Required", + "type": "BOOLEAN" + } + ] + }, + { + "type": "SEPARATOR", + "label": "Authentication" + }, + { + "name": "auth_type", + "label": "Authentication type", + "description": "", + "type": "SELECT", + "defaultValue": null, + "selectChoices":[ + {"value": "secure_oauth", "label": "SSO"}, + {"value": "secure_basic", "label": "Secure username / password"}, + {"value": null, "label": "Other"} + ] + }, + { + "name": "credential", + "label": "Credential preset", + "type": "PRESET", + "parameterSetId": "credential", + "visibilityCondition": "model.auth_type == null" + }, + { + "name": "secure_oauth", + "label": "SSO preset", + "type": "PRESET", + "parameterSetId": "secure-oauth", + "visibilityCondition": "model.auth_type == 'secure_oauth'" + }, + { + "name": "secure_basic", + "label": "Credential preset", + "type": "PRESET", + "parameterSetId": "secure-basic", + "visibilityCondition": "model.auth_type == 'secure_basic'" + }, + { + "name": "should_use_user_secrets", + "label": " ", + "description": "Use profile's 'other credentials'", + "type": "BOOLEAN", + "defaultValue": false + }, + { + "type": "SEPARATOR", + "label": "API call parameters" + }, + { + "name": "custom_key_values", + "label": "Custom keys / values", + "description": "Replace {{key}} by value in presets (optional)", + "type": "KEY_VALUE_LIST", + "visibilityCondition": false + }, + + { + "name": "endpoint_url", + "label": "URL template", + "description": "https://{{variable}}.exmpl.com/usr/{{username}}/details", + "type": "TEXTAREA" + }, + { + "name": "http_method", + "label": "HTTP method", + "description": "", + "type": "SELECT", + "defaultValue": "GET", + "selectChoices":[ + {"value": "GET", "label": "GET"}, + {"value": "POST", "label": "POST"}, + {"value": "PUT", "label": "PUT"}, + {"value": "PATCH", "label": "PATCH"}, + {"value": "DELETE", "label": "DELETE"} + ] + }, + { + "name": "endpoint_query_string", + "label": "Query Params", + "description": "Will add ?key1=val1&key2=val2 to the URL", + "type": "KEY_VALUE_LIST" + }, + { + "name": "endpoint_body", + "label": "Body", + "description": "", + "type": "KEY_VALUE_LIST", + "visibilityCondition": false + }, + { + "name": "endpoint_headers", + "label": "Headers", + "description": "", + "type": "KEY_VALUE_LIST", + "defaultValue": [ + { + "from": "Content-Type", + "to": "application/json" + }, + { + "from": "Accept", + "to": "application/json" + } + ] + }, + { + "name": "body_format", + "label": "Body", + "description": "", + "type": "SELECT", + "defaultValue": null, + "selectChoices":[ + {"value": null, "label": "None"}, + {"value": "FORM_DATA", "label": "Form-data"}, + {"value": "RAW", "label": "Raw"} + ] + }, + { + "name": "text_body", + "label": "Request's body", + "description": "", + "type": "TEXTAREA", + "visibilityCondition": "model.body_format=='RAW'" + }, + { + "name": "key_value_body", + "label": "Request's body", + "description": "", + "type": "KEY_VALUE_LIST", + "visibilityCondition": "(['FORM_DATA'].indexOf(model.body_format)>-1)" + }, + { + "type": "SEPARATOR", + "label": "Data extraction" + }, + { + "name": "extraction_key", + "label": "JSON path to data", + "description": "$.path.to.my[*].data (optional)", + "defaultValue": null, + "type": "STRING" + }, + { + "type": "SEPARATOR", + "label": "Pagination" + }, + { + "name": "pagination_type", + "label": "Pagination mechanism", + "description": "Refer to the API's documentation", + "type": "SELECT", + "defaultValue": "na", + "selectChoices":[ + {"value": "na", "label": "No pagination"}, + {"value": "next_page", "label": "Next page URL provided"}, + {"value": "offset", "label": "Offset pagination"}, + {"value": "page", "label": "Per page"} + ] + }, + { + "type": "SEPARATOR", + "description": "⚠ Requires a key to data array", + "visibilityCondition": "model.pagination_type=='page' && !model.extraction_key" + }, + { + "name": "next_page_url_key", + "label": "Key to next request URL", + "description": "Dot separated key path to next request URL", + "type": "STRING", + "defaultValue": null, + "visibilityCondition": "model.pagination_type=='next_page'" + }, + { + "name": "is_next_page_url_relative", + "label": " ", + "description": "Next page URL is relative", + "type": "BOOLEAN", + "defaultValue": false, + "visibilityCondition": "model.pagination_type=='next_page'" + }, + { + "name": "next_page_url_base", + "label": "Base URL to next page", + "description": "https://mysite.com/path/", + "type": "STRING", + "defaultValue": null, + "visibilityCondition": "model.pagination_type=='next_page' && (model.is_next_page_url_relative==true)" + }, + { + "name": "top_key", + "label": "Key limiting elements per page", + "description": "", + "type": "STRING", + "defaultValue": null, + "visibilityCondition": "model.pagination_type == 'offset'" + }, + { + "name": "skip_key", + "label": "Key for element offset", + "description": "", + "type": "STRING", + "defaultValue": null, + "visibilityCondition": "model.pagination_type=='offset'" + }, + { + "name": "skip_key", + "label": "Key for page offset", + "description": "", + "type": "STRING", + "defaultValue": null, + "visibilityCondition": "model.pagination_type=='page'" + }, + { + "type": "SEPARATOR", + "label": "Advanced" + }, + { + "name": "ignore_ssl_check", + "label": " ", + "description": "Ignore SSL check", + "type": "BOOLEAN", + "visibilityCondition": "model.auth_type!='secure_oauth' && model.auth_type!='secure_basic'", + "defaultValue": false + }, + { + "name": "redirect_auth_header", + "label": " ", + "description": "Redirect authorization header", + "type": "BOOLEAN", + "defaultValue": false + }, + { + "name": "display_metadata", + "label": " ", + "description": "Display metadata", + "type": "BOOLEAN", + "defaultValue": false + }, + { + "name": "timeout", + "label": "Timeout (s)", + "description": "-1 for no limit", + "type": "INT", + "defaultValue": 3600 + }, + { + "name": "requests_per_minute", + "label": "Rate limit (requests/m)", + "description": "-1 for no limit", + "type": "INT", + "defaultValue": -1 + }, + { + "name": "maximum_number_rows", + "label": "Maximum number of rows", + "description": "-1 for no limit", + "type": "INT", + "defaultValue": -1 + } + ] +} diff --git a/python-agent-tools/custom-agent-tool/tool.py b/python-agent-tools/custom-agent-tool/tool.py new file mode 100644 index 0000000..4945623 --- /dev/null +++ b/python-agent-tools/custom-agent-tool/tool.py @@ -0,0 +1,106 @@ +from dataiku.llm.agent_tools import BaseAgentTool +from safe_logger import SafeLogger +from dku_constants import DKUConstants +from rest_api_client import RestAPIClient +from dku_utils import ( + get_dku_key_values, get_endpoint_parameters, + get_secure_credentials, + decode_csv_data, decode_bytes, get_user_secrets +) +from jsonpath_ng.ext import parse + + +logger = SafeLogger("api-connect plugin", forbidden_keys=DKUConstants.FORBIDDEN_KEYS) + + +class CustomAgentTool(BaseAgentTool): + def set_config(self, config, plugin_config): + logger.info('API-Connect plugin agent tool v{}'.format(DKUConstants.PLUGIN_VERSION)) + logger.info("config={}, plugin_config={}".format( + logger.filter_secrets(config), + logger.filter_secrets(plugin_config) + ) + ) + self.config = config + tool_parameters = config.get("tool_parameters", []) + self.properties = {} + self.required = [] + self.parameters = [] + self.sample_query = {} + for tool_parameter in tool_parameters: + parameter_name = tool_parameter.get("tool_parameter_name") + parameter_type = tool_parameter.get("tool_parameter_type") + parameter_description = tool_parameter.get("tool_parameter_description") + if parameter_name and parameter_type: + self.sample_query[parameter_name] = parameter_description + self.parameters.append(parameter_name) + self.properties[parameter_name] = { + "type": parameter_type, + "description": parameter_description + } + if tool_parameter.get("tool_parameter_is_required", False): + self.required.append(parameter_name) + + self.endpoint_parameters = get_endpoint_parameters(config) + self.secure_credentials = get_secure_credentials(config) + self.credential = config.get("credential", {}) + self.custom_key_values = get_dku_key_values(config.get("custom_key_values", {})) + user_secrets = get_user_secrets(config) + self.custom_key_values.update(user_secrets) + self.extraction_key = self.endpoint_parameters.get("extraction_key", "") + self.extraction_path = self.extraction_key.split('.') + + def get_descriptor(self, tool): + return { + "description": "{}".format(self.config.get("tool_description", "")), + "inputSchema": { + "$id": "https://example.com/agents/tools/hash/input", + "title": "Input for the hashing tool", + "type": "object", + "properties": self.properties, + "required": self.required + } + } + + def invoke(self, input, trace): + args = input.get("input", {}) + self.custom_key_values.update(args) + client = RestAPIClient(self.credential, self.secure_credentials, self.endpoint_parameters, self.custom_key_values) + rows = [] + while client.has_more_data(): + json_response = client.paginated_api_call() + if self.extraction_key: + matches = parse(self.extraction_key).find(json_response) + data = [] + for counter in range(0, len(matches)): + data.append(matches[counter].value) + else: + data = json_response + for formated_row in format_data(data): + rows.append(formated_row) + return { + "output": "{}".format(rows), + "sources": [{ + "toolCallDescription": "Payload was hashed" + }] + } + + def load_sample_query(self, tool): + return self.sample_query + + +def format_data(data): + if isinstance(data, list): + for row in data: + yield row + elif isinstance(data, dict): + yield data + else: + csv_data = decode_csv_data(data) + if csv_data: + for row in csv_data: + yield row + else: + yield { + DKUConstants.API_RESPONSE_KEY: "{}".format(decode_bytes(data)) + }