Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 51 additions & 0 deletions api/uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

57 changes: 57 additions & 0 deletions sdks/python/agenta/sdk/engines/running/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1041,6 +1041,7 @@ async def auto_ai_critique_v0(
) from e

try:
provider_settings = _resolve_aws_credentials(provider_settings)
with mockllm.user_aws_credentials_from(_coerce_credentials(provider_settings)):
response = await mockllm.acompletion(
messages=formatted_prompt_template,
Expand Down Expand Up @@ -1825,6 +1826,61 @@ def _coerce_credentials(provider_settings: Dict) -> Dict:
}


def _resolve_aws_credentials(provider_settings: Dict) -> Dict:
"""If aws_role_arn is set, assume that role and return a copy of provider_settings
with all three AWS credentials (key, secret, token) replaced by the temporary
session credentials. This ensures litellm receives a consistent credential set
rather than a mix of base credentials (via kwargs) and role session token (via env)."""

role_arn = provider_settings.get("aws_role_arn") or provider_settings.get(
"AWS_ROLE_ARN"
)
if not role_arn:
return provider_settings

import boto3

access_key = provider_settings.get("aws_access_key_id") or provider_settings.get(
"AWS_ACCESS_KEY_ID"
)
secret_key = provider_settings.get(
"aws_secret_access_key"
) or provider_settings.get("AWS_SECRET_ACCESS_KEY")
session_token = provider_settings.get("aws_session_token") or provider_settings.get(
"AWS_SESSION_TOKEN"
)
region = (
provider_settings.get("aws_region_name")
or provider_settings.get("aws_region")
or provider_settings.get("AWS_REGION")
or "us-east-1"
)

sts = boto3.client(
"sts",
aws_access_key_id=access_key,
aws_secret_access_key=secret_key,
aws_session_token=session_token,
region_name=region,
)
Comment thread
coderabbitai[bot] marked this conversation as resolved.
resp = sts.assume_role(RoleArn=role_arn, RoleSessionName="agenta-bedrock")
creds = resp["Credentials"]

updated = dict(provider_settings)
updated["aws_access_key_id"] = creds["AccessKeyId"]
updated["aws_secret_access_key"] = creds["SecretAccessKey"]
updated["aws_session_token"] = creds["SessionToken"]
# Remove the upper-case variants so litellm only sees one consistent set
updated.pop("AWS_ACCESS_KEY_ID", None)
updated.pop("AWS_SECRET_ACCESS_KEY", None)
updated.pop("AWS_SESSION_TOKEN", None)
# Remove role ARN so it isn't forwarded as an unknown kwarg to litellm
updated.pop("aws_role_arn", None)
updated.pop("AWS_ROLE_ARN", None)

return updated


def _apply_responses_bridge_if_needed(
provider_settings: Dict,
llm_config: ModelConfig,
Expand Down Expand Up @@ -2051,6 +2107,7 @@ async def _run_prompt_llm_config_with_retry(
if messages is not None:
openai_kwargs["messages"] = [*openai_kwargs["messages"], *messages]

provider_settings = _resolve_aws_credentials(provider_settings)
with mockllm.user_aws_credentials_from(
_coerce_credentials(provider_settings)
):
Expand Down
Loading