From 659384e023605d9d261e068cb6ed347beba93088 Mon Sep 17 00:00:00 2001 From: Celyn Raquel Date: Fri, 22 May 2026 10:00:37 +0800 Subject: [PATCH 1/4] fix: Add optional ARN to AWS Bedrock connection to assume role when using models --- api/uv.lock | 51 ++++ .../agenta/sdk/engines/running/handlers.py | 53 ++++ .../unit/test_resolve_aws_credentials.py | 247 ++++++++++++++++++ sdks/python/pyproject.toml | 1 + sdks/python/uv.lock | 2 + services/uv.lock | 2 + .../assets/ConfigureProviderDrawerContent.tsx | 3 +- .../assets/constants.ts | 12 +- .../src/secret/core/transforms.ts | 2 + .../agenta-shared/src/types/llmProvider.ts | 1 + 10 files changed, 371 insertions(+), 3 deletions(-) create mode 100644 sdks/python/oss/tests/pytest/unit/test_resolve_aws_credentials.py diff --git a/api/uv.lock b/api/uv.lock index f6ae4bf04b..2a6ff82c56 100644 --- a/api/uv.lock +++ b/api/uv.lock @@ -12,6 +12,7 @@ version = "0.100.0" source = { editable = "../sdks/python" } dependencies = [ { name = "agenta-client" }, + { name = "boto3" }, { name = "daytona" }, { name = "fastapi" }, { name = "httpx" }, @@ -32,6 +33,7 @@ dependencies = [ [package.metadata] requires-dist = [ { name = "agenta-client", editable = "../clients/python" }, + { name = "boto3", specifier = ">=1,<2" }, { name = "daytona", specifier = ">=0.176,<0.177" }, { name = "fastapi", specifier = ">=0.136,<0.137" }, { name = "httpx", specifier = ">=0.28,<0.29" }, @@ -413,6 +415,34 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/df/73/b6e24bd22e6720ca8ee9a85a0c4a2971af8497d8f3193fa05390cbd46e09/backoff-2.2.1-py3-none-any.whl", hash = "sha256:63579f9a0628e06278f7e47b7d7d5b6ce20dc65c5e96a6f3ca99a6adca0396e8", size = 15148, upload-time = "2022-10-05T19:19:30.546Z" }, ] +[[package]] +name = "boto3" +version = "1.43.12" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "botocore" }, + { name = "jmespath" }, + { name = "s3transfer" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/67/2f/c4159fa45079b41f11ad17d8c5df8e1d10169b94d1e4240df5be116d3f0a/boto3-1.43.12.tar.gz", hash = "sha256:4a60cdf02c52cb0a60f8dbc986142ce2c31e87e3df1438ffe6755b83008f3e4e", size = 113142, upload-time = "2026-05-20T19:38:13.163Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a2/35/b7ab4b6977811f9887405e24460640033c22f4515cf1e904480710bd6296/boto3-1.43.12-py3-none-any.whl", hash = "sha256:685c3e6093455623bfc22dac55b4946ea243095252f7f9c11a99d84b38033bcf", size = 140537, upload-time = "2026-05-20T19:38:09.995Z" }, +] + +[[package]] +name = "botocore" +version = "1.43.12" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "jmespath" }, + { name = "python-dateutil" }, + { name = "urllib3" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/6d/a1/95ec376c2300e605225998619c46f7093c515710b9d6d65f891f126f32b6/botocore-1.43.12.tar.gz", hash = "sha256:7608ecd51687132e22aa8b82acb89a5917b1b68ec0563c25d82c3e16adab9bc0", size = 15366431, upload-time = "2026-05-20T19:37:58.734Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a9/20/b2ef618de8dc634361e32344bdc5139f1ad92968ab6c18cddd6c8c431f67/botocore-1.43.12-py3-none-any.whl", hash = "sha256:75dfb84c6edbb5aaa0314d93776d840d74e26e8d97e0431270a3274d70abeba3", size = 15046449, upload-time = "2026-05-20T19:37:53.723Z" }, +] + [[package]] name = "cachetools" version = "7.1.3" @@ -1203,6 +1233,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/c8/8d/302cb2057b7513327b4d575cff6b1d066ee6431a5357fc3f8867cd684406/jiter-0.15.0-graalpy312-graalpy250_312_native-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:54d5d6090cdc1b7c9e780dfb04949a990adb1e301a2fc0bbcee7de4638d33f9a", size = 344469, upload-time = "2026-05-19T10:09:46.864Z" }, ] +[[package]] +name = "jmespath" +version = "1.1.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/d3/59/322338183ecda247fb5d1763a6cbe46eff7222eaeebafd9fa65d4bf5cb11/jmespath-1.1.0.tar.gz", hash = "sha256:472c87d80f36026ae83c6ddd0f1d05d4e510134ed462851fd5f754c8c3cbb88d", size = 27377, upload-time = "2026-01-22T16:35:26.279Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/14/2f/967ba146e6d58cf6a652da73885f52fc68001525b4197effc174321d70b4/jmespath-1.1.0-py3-none-any.whl", hash = "sha256:a5663118de4908c91729bea0acadca56526eb2698e83de10cd116ae0f4e97c64", size = 20419, upload-time = "2026-01-22T16:35:24.919Z" }, +] + [[package]] name = "jsonschema" version = "4.26.0" @@ -2376,6 +2415,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/d1/b7/b95708304cd49b7b6f82fdd039f1748b66ec2b21d6a45180910802f1abf1/rpds_py-0.30.0-pp311-pypy311_pp73-musllinux_1_2_x86_64.whl", hash = "sha256:ac37f9f516c51e5753f27dfdef11a88330f04de2d564be3991384b2f3535d02e", size = 562191, upload-time = "2025-11-30T20:24:36.853Z" }, ] +[[package]] +name = "s3transfer" +version = "0.17.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "botocore" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/9b/ec/7c692cde9125b77e84b307354d4fb705f98b8ccad59a036d5957ca75bfc3/s3transfer-0.17.0.tar.gz", hash = "sha256:9edeb6d1c3c2f89d6050348548834ad8289610d886e5bf7b7207728bd43ce33a", size = 155337, upload-time = "2026-04-29T22:07:36.33Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/87/72/c6c32d2b657fa3dad1de340254e14390b1e334ce38268b7ad51abda3c8c2/s3transfer-0.17.0-py3-none-any.whl", hash = "sha256:ce3801712acf4ad3e89fb9990df97b4972e93f4b3b0004d214be5bce12814c20", size = 86811, upload-time = "2026-04-29T22:07:34.966Z" }, +] + [[package]] name = "sendgrid" version = "6.12.5" diff --git a/sdks/python/agenta/sdk/engines/running/handlers.py b/sdks/python/agenta/sdk/engines/running/handlers.py index eca3a745f0..aa9565d825 100644 --- a/sdks/python/agenta/sdk/engines/running/handlers.py +++ b/sdks/python/agenta/sdk/engines/running/handlers.py @@ -1041,6 +1041,7 @@ async def auto_ai_critique_v0( ) from e try: + provider_settings = _resolve_aws_credentials(provider_settings) with mockllm.user_aws_credentials_from(_coerce_credentials(provider_settings)): response = await mockllm.acompletion( messages=formatted_prompt_template, @@ -1825,6 +1826,57 @@ def _coerce_credentials(provider_settings: Dict) -> Dict: } +def _resolve_aws_credentials(provider_settings: Dict) -> Dict: + """If aws_role_arn is set, assume that role and return a copy of provider_settings + with all three AWS credentials (key, secret, token) replaced by the temporary + session credentials. This ensures litellm receives a consistent credential set + rather than a mix of base credentials (via kwargs) and role session token (via env).""" + + role_arn = provider_settings.get("aws_role_arn") or provider_settings.get( + "AWS_ROLE_ARN" + ) + if not role_arn: + return provider_settings + + import boto3 + + access_key = provider_settings.get("aws_access_key_id") or provider_settings.get( + "AWS_ACCESS_KEY_ID" + ) + secret_key = provider_settings.get( + "aws_secret_access_key" + ) or provider_settings.get("AWS_SECRET_ACCESS_KEY") + region = ( + provider_settings.get("aws_region_name") + or provider_settings.get("aws_region") + or provider_settings.get("AWS_REGION") + or "us-east-1" + ) + + sts = boto3.client( + "sts", + aws_access_key_id=access_key, + aws_secret_access_key=secret_key, + region_name=region, + ) + resp = sts.assume_role(RoleArn=role_arn, RoleSessionName="agenta-bedrock") + creds = resp["Credentials"] + + updated = dict(provider_settings) + updated["aws_access_key_id"] = creds["AccessKeyId"] + updated["aws_secret_access_key"] = creds["SecretAccessKey"] + updated["aws_session_token"] = creds["SessionToken"] + # Remove the upper-case variants so litellm only sees one consistent set + updated.pop("AWS_ACCESS_KEY_ID", None) + updated.pop("AWS_SECRET_ACCESS_KEY", None) + updated.pop("AWS_SESSION_TOKEN", None) + # Remove role ARN so it isn't forwarded as an unknown kwarg to litellm + updated.pop("aws_role_arn", None) + updated.pop("AWS_ROLE_ARN", None) + + return updated + + def _apply_responses_bridge_if_needed( provider_settings: Dict, llm_config: ModelConfig, @@ -2051,6 +2103,7 @@ async def _run_prompt_llm_config_with_retry( if messages is not None: openai_kwargs["messages"] = [*openai_kwargs["messages"], *messages] + provider_settings = _resolve_aws_credentials(provider_settings) with mockllm.user_aws_credentials_from( _coerce_credentials(provider_settings) ): diff --git a/sdks/python/oss/tests/pytest/unit/test_resolve_aws_credentials.py b/sdks/python/oss/tests/pytest/unit/test_resolve_aws_credentials.py new file mode 100644 index 0000000000..c5d1cc0289 --- /dev/null +++ b/sdks/python/oss/tests/pytest/unit/test_resolve_aws_credentials.py @@ -0,0 +1,247 @@ +"""Unit tests for ``_resolve_aws_credentials`` in handlers.py. + +Covers: +- no role ARN → settings returned unchanged; +- lowercase ``aws_role_arn`` triggers STS assume_role; +- uppercase ``AWS_ROLE_ARN`` triggers STS assume_role; +- role ARN keys removed from result; +- uppercase AWS_* credential keys removed from result; +- session token injected from STS response; +- region defaults to us-east-1 when not supplied; +- region resolved from aws_region_name / aws_region / AWS_REGION; +- base credentials forwarded to STS client constructor. +""" + +from unittest.mock import MagicMock, patch + +import pytest + +from agenta.sdk.engines.running.handlers import _resolve_aws_credentials + +_ROLE_ARN = "arn:aws:iam::123456789012:role/my-role" +_TEMP_ACCESS = "ASIATEMP" +_TEMP_SECRET = "SECRETTEMP" +_TEMP_TOKEN = "TOKENTEMP" + + +def _mock_sts( + access_key=_TEMP_ACCESS, + secret=_TEMP_SECRET, + token=_TEMP_TOKEN, +) -> MagicMock: + sts = MagicMock() + sts.assume_role.return_value = { + "Credentials": { + "AccessKeyId": access_key, + "SecretAccessKey": secret, + "SessionToken": token, + } + } + return sts + + +# --------------------------------------------------------------------------- +# No-op path +# --------------------------------------------------------------------------- + + +def test_no_role_arn_returns_unchanged(): + settings = {"aws_access_key_id": "AKID", "aws_secret_access_key": "SECRET"} + result = _resolve_aws_credentials(settings) + assert result is settings + + +def test_empty_dict_returns_unchanged(): + settings = {} + result = _resolve_aws_credentials(settings) + assert result is settings + + +# --------------------------------------------------------------------------- +# STS is called +# --------------------------------------------------------------------------- + + +def test_lowercase_role_arn_triggers_sts(): + sts = _mock_sts() + settings = {"aws_role_arn": _ROLE_ARN} + + with patch("boto3.client", return_value=sts) as mock_client: + result = _resolve_aws_credentials(settings) + + mock_client.assert_called_once_with( + "sts", + aws_access_key_id=None, + aws_secret_access_key=None, + region_name="us-east-1", + ) + sts.assume_role.assert_called_once_with( + RoleArn=_ROLE_ARN, RoleSessionName="agenta-bedrock" + ) + assert result["aws_access_key_id"] == _TEMP_ACCESS + + +def test_uppercase_role_arn_triggers_sts(): + sts = _mock_sts() + settings = {"AWS_ROLE_ARN": _ROLE_ARN} + + with patch("boto3.client", return_value=sts): + result = _resolve_aws_credentials(settings) + + sts.assume_role.assert_called_once_with( + RoleArn=_ROLE_ARN, RoleSessionName="agenta-bedrock" + ) + assert result["aws_access_key_id"] == _TEMP_ACCESS + + +# --------------------------------------------------------------------------- +# Result shape +# --------------------------------------------------------------------------- + + +def test_role_arn_keys_removed_from_result(): + sts = _mock_sts() + settings = { + "aws_role_arn": _ROLE_ARN, + "AWS_ROLE_ARN": _ROLE_ARN, + "aws_access_key_id": "AKID", + } + + with patch("boto3.client", return_value=sts): + result = _resolve_aws_credentials(settings) + + assert "aws_role_arn" not in result + assert "AWS_ROLE_ARN" not in result + + +def test_uppercase_credential_keys_removed_from_result(): + sts = _mock_sts() + settings = { + "aws_role_arn": _ROLE_ARN, + "AWS_ACCESS_KEY_ID": "AKID", + "AWS_SECRET_ACCESS_KEY": "SECRET", + "AWS_SESSION_TOKEN": "OLD_TOKEN", + } + + with patch("boto3.client", return_value=sts): + result = _resolve_aws_credentials(settings) + + assert "AWS_ACCESS_KEY_ID" not in result + assert "AWS_SECRET_ACCESS_KEY" not in result + assert "AWS_SESSION_TOKEN" not in result + + +def test_session_token_injected(): + sts = _mock_sts(token="FRESH_TOKEN") + settings = {"aws_role_arn": _ROLE_ARN} + + with patch("boto3.client", return_value=sts): + result = _resolve_aws_credentials(settings) + + assert result["aws_session_token"] == "FRESH_TOKEN" + + +def test_lowercase_creds_replaced_with_temp(): + sts = _mock_sts() + settings = { + "aws_role_arn": _ROLE_ARN, + "aws_access_key_id": "ORIGINAL_KEY", + "aws_secret_access_key": "ORIGINAL_SECRET", + } + + with patch("boto3.client", return_value=sts): + result = _resolve_aws_credentials(settings) + + assert result["aws_access_key_id"] == _TEMP_ACCESS + assert result["aws_secret_access_key"] == _TEMP_SECRET + + +# --------------------------------------------------------------------------- +# Region resolution +# --------------------------------------------------------------------------- + + +def test_region_defaults_to_us_east_1(): + sts = _mock_sts() + settings = {"aws_role_arn": _ROLE_ARN} + + with patch("boto3.client", return_value=sts) as mock_client: + _resolve_aws_credentials(settings) + + _, kwargs = mock_client.call_args + assert kwargs["region_name"] == "us-east-1" + + +@pytest.mark.parametrize( + "key", + ["aws_region_name", "aws_region", "AWS_REGION"], +) +def test_region_resolved_from_setting(key): + sts = _mock_sts() + settings = {"aws_role_arn": _ROLE_ARN, key: "eu-west-1"} + + with patch("boto3.client", return_value=sts) as mock_client: + _resolve_aws_credentials(settings) + + _, kwargs = mock_client.call_args + assert kwargs["region_name"] == "eu-west-1" + + +# --------------------------------------------------------------------------- +# Base credentials forwarded to STS +# --------------------------------------------------------------------------- + + +def test_base_credentials_forwarded_to_sts(): + sts = _mock_sts() + settings = { + "aws_role_arn": _ROLE_ARN, + "aws_access_key_id": "BASE_KEY", + "aws_secret_access_key": "BASE_SECRET", + "aws_region_name": "ap-southeast-1", + } + + with patch("boto3.client", return_value=sts) as mock_client: + _resolve_aws_credentials(settings) + + mock_client.assert_called_once_with( + "sts", + aws_access_key_id="BASE_KEY", + aws_secret_access_key="BASE_SECRET", + region_name="ap-southeast-1", + ) + + +def test_uppercase_base_credentials_forwarded_to_sts(): + sts = _mock_sts() + settings = { + "aws_role_arn": _ROLE_ARN, + "AWS_ACCESS_KEY_ID": "UC_KEY", + "AWS_SECRET_ACCESS_KEY": "UC_SECRET", + } + + with patch("boto3.client", return_value=sts) as mock_client: + _resolve_aws_credentials(settings) + + mock_client.assert_called_once_with( + "sts", + aws_access_key_id="UC_KEY", + aws_secret_access_key="UC_SECRET", + region_name="us-east-1", + ) + + +# --------------------------------------------------------------------------- +# Original dict is not mutated +# --------------------------------------------------------------------------- + + +def test_original_dict_not_mutated(): + sts = _mock_sts() + settings = {"aws_role_arn": _ROLE_ARN, "aws_access_key_id": "ORIG"} + original = dict(settings) + + with patch("boto3.client", return_value=sts): + _resolve_aws_credentials(settings) + + assert settings == original diff --git a/sdks/python/pyproject.toml b/sdks/python/pyproject.toml index 14a512a14b..a73ac2c647 100644 --- a/sdks/python/pyproject.toml +++ b/sdks/python/pyproject.toml @@ -29,6 +29,7 @@ dependencies = [ "fastapi>=0.136,<0.137", "pyyaml>=6,<7", "litellm>=1,<2", + "boto3>=1,<2", "openai>=2,<3", "jinja2>=3,<4", "orjson>=3,<4", diff --git a/sdks/python/uv.lock b/sdks/python/uv.lock index 7dbfbe4ce3..4fa95104c5 100644 --- a/sdks/python/uv.lock +++ b/sdks/python/uv.lock @@ -8,6 +8,7 @@ version = "0.100.0" source = { editable = "." } dependencies = [ { name = "agenta-client" }, + { name = "boto3" }, { name = "daytona" }, { name = "fastapi" }, { name = "httpx" }, @@ -45,6 +46,7 @@ dev = [ [package.metadata] requires-dist = [ { name = "agenta-client", editable = "../../clients/python" }, + { name = "boto3", specifier = ">=1,<2" }, { name = "daytona", specifier = ">=0.176,<0.177" }, { name = "fastapi", specifier = ">=0.136,<0.137" }, { name = "httpx", specifier = ">=0.28,<0.29" }, diff --git a/services/uv.lock b/services/uv.lock index b8eeabee5c..cb07407d12 100644 --- a/services/uv.lock +++ b/services/uv.lock @@ -12,6 +12,7 @@ version = "0.100.0" source = { editable = "../sdks/python" } dependencies = [ { name = "agenta-client" }, + { name = "boto3" }, { name = "daytona" }, { name = "fastapi" }, { name = "httpx" }, @@ -32,6 +33,7 @@ dependencies = [ [package.metadata] requires-dist = [ { name = "agenta-client", editable = "../clients/python" }, + { name = "boto3", specifier = ">=1,<2" }, { name = "daytona", specifier = ">=0.176,<0.177" }, { name = "fastapi", specifier = ">=0.136,<0.137" }, { name = "httpx", specifier = ">=0.28,<0.29" }, diff --git a/web/oss/src/components/ModelRegistry/Drawers/ConfigureProviderDrawer/assets/ConfigureProviderDrawerContent.tsx b/web/oss/src/components/ModelRegistry/Drawers/ConfigureProviderDrawer/assets/ConfigureProviderDrawerContent.tsx index 2ef7b96fa5..a13ad33e6c 100644 --- a/web/oss/src/components/ModelRegistry/Drawers/ConfigureProviderDrawer/assets/ConfigureProviderDrawerContent.tsx +++ b/web/oss/src/components/ModelRegistry/Drawers/ConfigureProviderDrawer/assets/ConfigureProviderDrawerContent.tsx @@ -193,6 +193,7 @@ const ConfigureProviderDrawerContent = ({ accessKeyId: "", accessKey: "", sessionToken: "", + roleArn: "", models: [""], }} > @@ -223,7 +224,7 @@ const ConfigureProviderDrawerContent = ({ const field = rawField as FieldWithAttributes const isJson = field.attributes?.kind === "json" const isRequired = - field.key === "apiBaseUrl" + field.key === "apiBaseUrl" || field.required === false ? false : !shouldFilter ? !!field.required diff --git a/web/oss/src/components/ModelRegistry/Drawers/ConfigureProviderDrawer/assets/constants.ts b/web/oss/src/components/ModelRegistry/Drawers/ConfigureProviderDrawer/assets/constants.ts index def625f5c9..70ce79325e 100644 --- a/web/oss/src/components/ModelRegistry/Drawers/ConfigureProviderDrawer/assets/constants.ts +++ b/web/oss/src/components/ModelRegistry/Drawers/ConfigureProviderDrawer/assets/constants.ts @@ -74,7 +74,7 @@ export const PROVIDER_FIELDS: { label: "Access key ID", placeholder: "Enter access key ID", note: "This secret will be encrypted in transit and at rest.", - model: ["bedrock", "sagemaker"], + model: ["bedrock", "bedrock_converse", "sagemaker"], required: false, }, { @@ -82,7 +82,7 @@ export const PROVIDER_FIELDS: { label: "Secret Access Key", placeholder: "Enter secret access key", note: "This secret will be encrypted in transit and at rest.", - model: ["bedrock", "sagemaker"], + model: ["bedrock", "bedrock_converse", "sagemaker"], required: false, }, { @@ -93,4 +93,12 @@ export const PROVIDER_FIELDS: { model: [], required: false, }, + { + key: "roleArn", + label: "Role ARN", + placeholder: "arn:aws:iam::123456789012:role/my-role", + note: "Optional. If set, this role will be assumed before each request.", + model: ["bedrock", "bedrock_converse", "sagemaker"], + required: false, + }, ] diff --git a/web/packages/agenta-entities/src/secret/core/transforms.ts b/web/packages/agenta-entities/src/secret/core/transforms.ts index 3971bfbd5f..eb8a7941e1 100644 --- a/web/packages/agenta-entities/src/secret/core/transforms.ts +++ b/web/packages/agenta-entities/src/secret/core/transforms.ts @@ -103,6 +103,7 @@ export const transformSecret = (secrets: SecretResponseDto[]): LlmProvider[] => accessKeyId: extras.aws_access_key_id || "", accessKey: extras.aws_secret_access_key || "", sessionToken: extras.aws_session_token || "", + roleArn: extras.aws_role_arn || "", models: data.models.map((model) => model.slug), modelKeys: data.model_keys ?? undefined, version: data.provider.version ?? "", @@ -146,6 +147,7 @@ export const transformCustomProviderPayloadData = (values: LlmProvider): CreateS aws_access_key_id: values.accessKeyId, aws_secret_access_key: values.accessKey, aws_session_token: values.sessionToken, + aws_role_arn: values.roleArn, }, }, models: values.models?.map((slug) => ({slug})) ?? [], diff --git a/web/packages/agenta-shared/src/types/llmProvider.ts b/web/packages/agenta-shared/src/types/llmProvider.ts index 9f32cbd975..cd99e825aa 100644 --- a/web/packages/agenta-shared/src/types/llmProvider.ts +++ b/web/packages/agenta-shared/src/types/llmProvider.ts @@ -25,6 +25,7 @@ export interface LlmProvider { accessKeyId?: string accessKey?: string sessionToken?: string + roleArn?: string models?: string[] modelKeys?: string[] id?: string From f1edf37dcaddcfa22b064e3e669dc45eace69c1c Mon Sep 17 00:00:00 2001 From: Celyn Raquel Date: Fri, 22 May 2026 17:40:48 +0800 Subject: [PATCH 2/4] fix: Remove bedrock_converse provider in roleArn model values --- .../Drawers/ConfigureProviderDrawer/assets/constants.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/web/oss/src/components/ModelRegistry/Drawers/ConfigureProviderDrawer/assets/constants.ts b/web/oss/src/components/ModelRegistry/Drawers/ConfigureProviderDrawer/assets/constants.ts index 70ce79325e..d337f465ed 100644 --- a/web/oss/src/components/ModelRegistry/Drawers/ConfigureProviderDrawer/assets/constants.ts +++ b/web/oss/src/components/ModelRegistry/Drawers/ConfigureProviderDrawer/assets/constants.ts @@ -98,7 +98,7 @@ export const PROVIDER_FIELDS: { label: "Role ARN", placeholder: "arn:aws:iam::123456789012:role/my-role", note: "Optional. If set, this role will be assumed before each request.", - model: ["bedrock", "bedrock_converse", "sagemaker"], + model: ["bedrock", "sagemaker"], required: false, }, ] From cb9caa04aba3ca1b56c719b0286c1d5465186483 Mon Sep 17 00:00:00 2001 From: Celyn Raquel Date: Fri, 22 May 2026 17:42:40 +0800 Subject: [PATCH 3/4] fix: Remove bedrock_converse provider in accessKeyId and accessKey constants model values --- .../Drawers/ConfigureProviderDrawer/assets/constants.ts | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/web/oss/src/components/ModelRegistry/Drawers/ConfigureProviderDrawer/assets/constants.ts b/web/oss/src/components/ModelRegistry/Drawers/ConfigureProviderDrawer/assets/constants.ts index d337f465ed..8e06cf3407 100644 --- a/web/oss/src/components/ModelRegistry/Drawers/ConfigureProviderDrawer/assets/constants.ts +++ b/web/oss/src/components/ModelRegistry/Drawers/ConfigureProviderDrawer/assets/constants.ts @@ -74,7 +74,7 @@ export const PROVIDER_FIELDS: { label: "Access key ID", placeholder: "Enter access key ID", note: "This secret will be encrypted in transit and at rest.", - model: ["bedrock", "bedrock_converse", "sagemaker"], + model: ["bedrock", "sagemaker"], required: false, }, { @@ -82,7 +82,7 @@ export const PROVIDER_FIELDS: { label: "Secret Access Key", placeholder: "Enter secret access key", note: "This secret will be encrypted in transit and at rest.", - model: ["bedrock", "bedrock_converse", "sagemaker"], + model: ["bedrock", "sagemaker"], required: false, }, { From 527672ba96b8ec0ac8340cd77d5d118b37ab265e Mon Sep 17 00:00:00 2001 From: Celyn Raquel Date: Fri, 22 May 2026 17:46:58 +0800 Subject: [PATCH 4/4] fix: Forward session token if available to boto3 client for temporary-credential support --- .../agenta/sdk/engines/running/handlers.py | 4 ++ .../unit/test_resolve_aws_credentials.py | 45 +++++++++++++++++++ 2 files changed, 49 insertions(+) diff --git a/sdks/python/agenta/sdk/engines/running/handlers.py b/sdks/python/agenta/sdk/engines/running/handlers.py index aa9565d825..8fa7fe04b3 100644 --- a/sdks/python/agenta/sdk/engines/running/handlers.py +++ b/sdks/python/agenta/sdk/engines/running/handlers.py @@ -1846,6 +1846,9 @@ def _resolve_aws_credentials(provider_settings: Dict) -> Dict: secret_key = provider_settings.get( "aws_secret_access_key" ) or provider_settings.get("AWS_SECRET_ACCESS_KEY") + session_token = provider_settings.get("aws_session_token") or provider_settings.get( + "AWS_SESSION_TOKEN" + ) region = ( provider_settings.get("aws_region_name") or provider_settings.get("aws_region") @@ -1857,6 +1860,7 @@ def _resolve_aws_credentials(provider_settings: Dict) -> Dict: "sts", aws_access_key_id=access_key, aws_secret_access_key=secret_key, + aws_session_token=session_token, region_name=region, ) resp = sts.assume_role(RoleArn=role_arn, RoleSessionName="agenta-bedrock") diff --git a/sdks/python/oss/tests/pytest/unit/test_resolve_aws_credentials.py b/sdks/python/oss/tests/pytest/unit/test_resolve_aws_credentials.py index c5d1cc0289..e5e542de94 100644 --- a/sdks/python/oss/tests/pytest/unit/test_resolve_aws_credentials.py +++ b/sdks/python/oss/tests/pytest/unit/test_resolve_aws_credentials.py @@ -73,6 +73,7 @@ def test_lowercase_role_arn_triggers_sts(): "sts", aws_access_key_id=None, aws_secret_access_key=None, + aws_session_token=None, region_name="us-east-1", ) sts.assume_role.assert_called_once_with( @@ -208,10 +209,53 @@ def test_base_credentials_forwarded_to_sts(): "sts", aws_access_key_id="BASE_KEY", aws_secret_access_key="BASE_SECRET", + aws_session_token=None, region_name="ap-southeast-1", ) +@pytest.mark.parametrize( + ("settings", "expected_kwargs"), + [ + ( + { + "aws_role_arn": _ROLE_ARN, + "aws_access_key_id": "BASE_KEY", + "aws_secret_access_key": "BASE_SECRET", + "aws_session_token": "BASE_TOKEN", + }, + { + "aws_access_key_id": "BASE_KEY", + "aws_secret_access_key": "BASE_SECRET", + "aws_session_token": "BASE_TOKEN", + "region_name": "us-east-1", + }, + ), + ( + { + "aws_role_arn": _ROLE_ARN, + "AWS_ACCESS_KEY_ID": "UC_KEY", + "AWS_SECRET_ACCESS_KEY": "UC_SECRET", + "AWS_SESSION_TOKEN": "UC_TOKEN", + }, + { + "aws_access_key_id": "UC_KEY", + "aws_secret_access_key": "UC_SECRET", + "aws_session_token": "UC_TOKEN", + "region_name": "us-east-1", + }, + ), + ], +) +def test_session_credentials_forwarded_to_sts(settings, expected_kwargs): + sts = _mock_sts() + + with patch("boto3.client", return_value=sts) as mock_client: + _resolve_aws_credentials(settings) + + mock_client.assert_called_once_with("sts", **expected_kwargs) + + def test_uppercase_base_credentials_forwarded_to_sts(): sts = _mock_sts() settings = { @@ -227,6 +271,7 @@ def test_uppercase_base_credentials_forwarded_to_sts(): "sts", aws_access_key_id="UC_KEY", aws_secret_access_key="UC_SECRET", + aws_session_token=None, region_name="us-east-1", )