From 7b09efebab9d6aca1727a06a5de6eceb07c0fe9b Mon Sep 17 00:00:00 2001 From: Saranya Somepalli Date: Fri, 27 Mar 2026 10:14:41 -0700 Subject: [PATCH] Move endpoint resolution from interceptors to pipeline stage --- .../tasks/AuthSchemeGeneratorTasks.java | 6 - .../emitters/tasks/EndpointProviderTasks.java | 6 +- .../scheme/AuthSchemeInterceptorSpec.java | 524 ------------------ .../auth/scheme/AuthSchemeParamsSpec.java | 32 +- .../poet/builder/BaseClientBuilderClass.java | 4 - .../codegen/poet/client/AsyncClientClass.java | 10 +- .../codegen/poet/client/ClientClassUtils.java | 231 ++++++++ .../codegen/poet/client/SyncClientClass.java | 10 +- .../poet/client/specs/JsonProtocolSpec.java | 10 +- .../poet/client/specs/QueryProtocolSpec.java | 8 + .../poet/client/specs/XmlProtocolSpec.java | 12 +- ...ec.java => EndpointResolverUtilsSpec.java} | 408 ++++---------- .../poet/rules/EndpointRulesSpecUtils.java | 10 + .../rules/RequestEndpointInterceptorSpec.java | 84 --- .../poet/auth/scheme/AuthSchemeSpecTest.java | 35 -- .../EndpointResolverInterceptorSpecTest.java | 55 -- ...stEndpointProviderInterceptorSpecTest.java | 31 -- .../internal/AwsExecutionContextBuilder.java | 46 +- .../endpoints/AwsEndpointProviderUtils.java | 92 +-- .../identity/AwsIdentityProviderUpdater.java | 54 ++ .../AwsExecutionContextBuilderTest.java | 45 -- .../awssdk/core/SelectedAuthScheme.java | 24 +- .../client/handler/ClientExecutionParams.java | 23 + .../SdkInternalExecutionAttribute.java | 24 + .../internal/http/AmazonAsyncHttpClient.java | 4 + .../internal/http/AmazonSyncHttpClient.java | 4 + .../http/auth/AuthSchemeResolver.java | 179 ++++++ .../stages/ApiCallMetricCollectionStage.java | 10 +- .../AsyncApiCallMetricCollectionStage.java | 8 +- .../stages/AuthSchemeResolutionStage.java | 158 ++++++ .../stages/EndpointResolutionStage.java | 119 ++++ .../core/spi/endpoint/EndpointResolver.java | 37 ++ .../identity/AuthSchemeOptionsResolver.java | 39 ++ .../spi/identity/IdentityProviderUpdater.java | 39 ++ .../http/auth/AuthSchemeResolverTest.java | 189 +++++++ .../stages/AuthSchemeResolutionStageTest.java | 231 ++++++++ .../stages/EndpointResolutionStageTest.java | 217 ++++++++ .../dynamodb/PaginatorInUserAgentTest.java | 2 +- .../awssdk/services/s3/S3Utilities.java | 60 +- .../internal/signing/DefaultS3Presigner.java | 147 ++++- .../AsyncResponseTransformerTest.java | 3 +- .../S3ExpressAuthSchemeProviderTest.java | 139 ++--- .../services/AuthSchemeInterceptorTest.java | 108 ---- .../IdentityResolutionOverrideTest.java | 7 +- .../AwsEndpointProviderUtilsTest.java | 2 +- .../EndpointInterceptorTests.java | 4 +- .../RequestOverrideEndpointProviderTests.java | 73 +-- 47 files changed, 2072 insertions(+), 1491 deletions(-) delete mode 100644 codegen/src/main/java/software/amazon/awssdk/codegen/poet/auth/scheme/AuthSchemeInterceptorSpec.java rename codegen/src/main/java/software/amazon/awssdk/codegen/poet/rules/{EndpointResolverInterceptorSpec.java => EndpointResolverUtilsSpec.java} (76%) delete mode 100644 codegen/src/main/java/software/amazon/awssdk/codegen/poet/rules/RequestEndpointInterceptorSpec.java delete mode 100644 codegen/src/test/java/software/amazon/awssdk/codegen/poet/rules/EndpointResolverInterceptorSpecTest.java delete mode 100644 codegen/src/test/java/software/amazon/awssdk/codegen/poet/rules/RequestEndpointProviderInterceptorSpecTest.java rename codegen/src/main/resources/software/amazon/awssdk/codegen/rules/AwsEndpointProviderUtils.java.resource => core/aws-core/src/main/java/software/amazon/awssdk/awscore/internal/endpoints/AwsEndpointProviderUtils.java (54%) create mode 100644 core/aws-core/src/main/java/software/amazon/awssdk/awscore/internal/identity/AwsIdentityProviderUpdater.java create mode 100644 core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/auth/AuthSchemeResolver.java create mode 100644 core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/pipeline/stages/AuthSchemeResolutionStage.java create mode 100644 core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/pipeline/stages/EndpointResolutionStage.java create mode 100644 core/sdk-core/src/main/java/software/amazon/awssdk/core/spi/endpoint/EndpointResolver.java create mode 100644 core/sdk-core/src/main/java/software/amazon/awssdk/core/spi/identity/AuthSchemeOptionsResolver.java create mode 100644 core/sdk-core/src/main/java/software/amazon/awssdk/core/spi/identity/IdentityProviderUpdater.java create mode 100644 core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/http/auth/AuthSchemeResolverTest.java create mode 100644 core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/http/pipeline/stages/AuthSchemeResolutionStageTest.java create mode 100644 core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/http/pipeline/stages/EndpointResolutionStageTest.java delete mode 100644 test/codegen-generated-classes-test/src/test/java/software/amazon/awssdk/services/AuthSchemeInterceptorTest.java diff --git a/codegen/src/main/java/software/amazon/awssdk/codegen/emitters/tasks/AuthSchemeGeneratorTasks.java b/codegen/src/main/java/software/amazon/awssdk/codegen/emitters/tasks/AuthSchemeGeneratorTasks.java index 4b6719cc709c..0794a165181c 100644 --- a/codegen/src/main/java/software/amazon/awssdk/codegen/emitters/tasks/AuthSchemeGeneratorTasks.java +++ b/codegen/src/main/java/software/amazon/awssdk/codegen/emitters/tasks/AuthSchemeGeneratorTasks.java @@ -20,7 +20,6 @@ import software.amazon.awssdk.codegen.emitters.GeneratorTask; import software.amazon.awssdk.codegen.emitters.GeneratorTaskParams; import software.amazon.awssdk.codegen.emitters.PoetGeneratorTask; -import software.amazon.awssdk.codegen.poet.auth.scheme.AuthSchemeInterceptorSpec; import software.amazon.awssdk.codegen.poet.auth.scheme.AuthSchemeParamsSpec; import software.amazon.awssdk.codegen.poet.auth.scheme.AuthSchemeProviderSpec; import software.amazon.awssdk.codegen.poet.auth.scheme.AuthSchemeSpecUtils; @@ -47,7 +46,6 @@ protected List createTasks() { tasks.add(generateDefaultParamsImpl()); tasks.add(generateModelBasedProvider()); tasks.add(generatePreferenceProvider()); - tasks.add(generateAuthSchemeInterceptor()); if (authSchemeSpecUtils.useEndpointBasedAuthProvider()) { tasks.add(generateEndpointBasedProvider()); tasks.add(generateEndpointAwareAuthSchemeParams()); @@ -86,10 +84,6 @@ private GeneratorTask generateEndpointAwareAuthSchemeParams() { } - private GeneratorTask generateAuthSchemeInterceptor() { - return new PoetGeneratorTask(authSchemeInternalDir(), model.getFileHeader(), new AuthSchemeInterceptorSpec(model)); - } - private String authSchemeDir() { return generatorTaskParams.getPathProvider().getAuthSchemeDirectory(); } diff --git a/codegen/src/main/java/software/amazon/awssdk/codegen/emitters/tasks/EndpointProviderTasks.java b/codegen/src/main/java/software/amazon/awssdk/codegen/emitters/tasks/EndpointProviderTasks.java index 7f5d64da8e79..b72a33263c5b 100644 --- a/codegen/src/main/java/software/amazon/awssdk/codegen/emitters/tasks/EndpointProviderTasks.java +++ b/codegen/src/main/java/software/amazon/awssdk/codegen/emitters/tasks/EndpointProviderTasks.java @@ -32,9 +32,8 @@ import software.amazon.awssdk.codegen.poet.rules.EndpointProviderInterfaceSpec; import software.amazon.awssdk.codegen.poet.rules.EndpointProviderSpec; import software.amazon.awssdk.codegen.poet.rules.EndpointProviderTestSpec; -import software.amazon.awssdk.codegen.poet.rules.EndpointResolverInterceptorSpec; +import software.amazon.awssdk.codegen.poet.rules.EndpointResolverUtilsSpec; import software.amazon.awssdk.codegen.poet.rules.EndpointRulesClientTestSpec; -import software.amazon.awssdk.codegen.poet.rules.RequestEndpointInterceptorSpec; import software.amazon.awssdk.codegen.poet.rules2.EndpointProviderSpec2; public final class EndpointProviderTasks extends BaseGeneratorTasks { @@ -103,8 +102,7 @@ private boolean shouldGenerateCompiledEndpointRules() { private Collection generateInterceptors() { return Arrays.asList( - new PoetGeneratorTask(endpointRulesInternalDir(), model.getFileHeader(), new EndpointResolverInterceptorSpec(model)), - new PoetGeneratorTask(endpointRulesInternalDir(), model.getFileHeader(), new RequestEndpointInterceptorSpec(model))); + new PoetGeneratorTask(endpointRulesInternalDir(), model.getFileHeader(), new EndpointResolverUtilsSpec(model))); } private GeneratorTask generateClientTests() { diff --git a/codegen/src/main/java/software/amazon/awssdk/codegen/poet/auth/scheme/AuthSchemeInterceptorSpec.java b/codegen/src/main/java/software/amazon/awssdk/codegen/poet/auth/scheme/AuthSchemeInterceptorSpec.java deleted file mode 100644 index c4f9e768eaa3..000000000000 --- a/codegen/src/main/java/software/amazon/awssdk/codegen/poet/auth/scheme/AuthSchemeInterceptorSpec.java +++ /dev/null @@ -1,524 +0,0 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). - * You may not use this file except in compliance with the License. - * A copy of the License is located at - * - * http://aws.amazon.com/apache2.0 - * - * or in the "license" file accompanying this file. This file is distributed - * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either - * express or implied. See the License for the specific language governing - * permissions and limitations under the License. - */ - -package software.amazon.awssdk.codegen.poet.auth.scheme; - -import com.squareup.javapoet.ClassName; -import com.squareup.javapoet.FieldSpec; -import com.squareup.javapoet.MethodSpec; -import com.squareup.javapoet.ParameterSpec; -import com.squareup.javapoet.ParameterizedTypeName; -import com.squareup.javapoet.TypeName; -import com.squareup.javapoet.TypeSpec; -import com.squareup.javapoet.TypeVariableName; -import com.squareup.javapoet.WildcardTypeName; -import java.time.Duration; -import java.util.ArrayList; -import java.util.List; -import java.util.Map; -import java.util.concurrent.CompletableFuture; -import java.util.function.Supplier; -import java.util.stream.Collectors; -import javax.lang.model.element.Modifier; -import software.amazon.awssdk.annotations.SdkInternalApi; -import software.amazon.awssdk.awscore.AwsExecutionAttribute; -import software.amazon.awssdk.codegen.model.intermediate.IntermediateModel; -import software.amazon.awssdk.codegen.poet.ClassSpec; -import software.amazon.awssdk.codegen.poet.PoetUtils; -import software.amazon.awssdk.codegen.poet.rules.EndpointRulesSpecUtils; -import software.amazon.awssdk.core.SdkRequest; -import software.amazon.awssdk.core.SelectedAuthScheme; -import software.amazon.awssdk.core.exception.SdkException; -import software.amazon.awssdk.core.identity.SdkIdentityProperty; -import software.amazon.awssdk.core.interceptor.Context; -import software.amazon.awssdk.core.interceptor.ExecutionAttributes; -import software.amazon.awssdk.core.interceptor.ExecutionInterceptor; -import software.amazon.awssdk.core.interceptor.SdkExecutionAttribute; -import software.amazon.awssdk.core.interceptor.SdkInternalExecutionAttribute; -import software.amazon.awssdk.core.internal.util.MetricUtils; -import software.amazon.awssdk.core.metrics.CoreMetric; -import software.amazon.awssdk.core.useragent.BusinessMetricFeatureId; -import software.amazon.awssdk.endpoints.EndpointProvider; -import software.amazon.awssdk.http.auth.aws.scheme.AwsV4aAuthScheme; -import software.amazon.awssdk.http.auth.aws.signer.RegionSet; -import software.amazon.awssdk.http.auth.scheme.BearerAuthScheme; -import software.amazon.awssdk.http.auth.spi.scheme.AuthScheme; -import software.amazon.awssdk.http.auth.spi.scheme.AuthSchemeOption; -import software.amazon.awssdk.http.auth.spi.signer.HttpSigner; -import software.amazon.awssdk.identity.spi.AwsCredentialsIdentity; -import software.amazon.awssdk.identity.spi.Identity; -import software.amazon.awssdk.identity.spi.IdentityProvider; -import software.amazon.awssdk.identity.spi.IdentityProviders; -import software.amazon.awssdk.identity.spi.ResolveIdentityRequest; -import software.amazon.awssdk.identity.spi.TokenIdentity; -import software.amazon.awssdk.metrics.MetricCollector; -import software.amazon.awssdk.metrics.SdkMetric; -import software.amazon.awssdk.regions.Region; -import software.amazon.awssdk.utils.CollectionUtils; -import software.amazon.awssdk.utils.Logger; -import software.amazon.awssdk.utils.Validate; - -public final class AuthSchemeInterceptorSpec implements ClassSpec { - private final AuthSchemeSpecUtils authSchemeSpecUtils; - private final EndpointRulesSpecUtils endpointRulesSpecUtils; - private final IntermediateModel intermediateModel; - - public AuthSchemeInterceptorSpec(IntermediateModel intermediateModel) { - this.intermediateModel = intermediateModel; - this.authSchemeSpecUtils = new AuthSchemeSpecUtils(intermediateModel); - this.endpointRulesSpecUtils = new EndpointRulesSpecUtils(intermediateModel); - } - - @Override - public ClassName className() { - return authSchemeSpecUtils.authSchemeInterceptor(); - } - - @Override - public TypeSpec poetSpec() { - TypeSpec.Builder builder = PoetUtils.createClassBuilder(className()) - .addSuperinterface(ExecutionInterceptor.class) - .addModifiers(Modifier.PUBLIC, Modifier.FINAL) - .addAnnotation(SdkInternalApi.class); - - builder.addField(FieldSpec.builder(Logger.class, "LOG", Modifier.PRIVATE, Modifier.STATIC) - .initializer("$T.loggerFor($T.class)", Logger.class, className()) - .build()); - - builder.addMethod(generateBeforeExecution()) - .addMethod(generateResolveAuthOptions()) - .addMethod(generateSelectAuthScheme()) - .addMethod(generateAuthSchemeParams()) - .addMethod(generateTrySelectAuthScheme()) - .addMethod(generateGetIdentityMetric()) - .addMethod(putSelectedAuthSchemeMethodSpec()); - if (intermediateModel.getCustomizationConfig().isEnableEnvironmentBearerToken()) { - builder.addMethod(generateEnvironmentTokenMetric()); - } - return builder.build(); - } - - private MethodSpec generateEnvironmentTokenMetric() { - return MethodSpec - .methodBuilder("recordEnvironmentTokenBusinessMetric") - .addModifiers(Modifier.PRIVATE) - .addTypeVariable(TypeVariableName.get("T", Identity.class)) - .addParameter(ParameterSpec.builder( - ParameterizedTypeName.get(ClassName.get(SelectedAuthScheme.class), - TypeVariableName.get("T")), - "selectedAuthScheme").build()) - .addParameter(ExecutionAttributes.class, "executionAttributes") - .addStatement("$T tokenFromEnv = executionAttributes.getAttribute($T.TOKEN_CONFIGURED_FROM_ENV)", - String.class, SdkInternalExecutionAttribute.class) - .beginControlFlow("if (selectedAuthScheme != null && selectedAuthScheme.authSchemeOption().schemeId().equals($T" - + ".SCHEME_ID) && selectedAuthScheme.identity().isDone())", BearerAuthScheme.class) - .beginControlFlow("if (selectedAuthScheme.identity().getNow(null) instanceof $T)", TokenIdentity.class) - - .addStatement("$T configuredToken = ($T) selectedAuthScheme.identity().getNow(null)", - TokenIdentity.class, TokenIdentity.class) - .beginControlFlow("if (configuredToken.token().equals(tokenFromEnv))") - .addStatement("executionAttributes.getAttribute($T.BUSINESS_METRICS)" - + ".addMetric($T.BEARER_SERVICE_ENV_VARS.value())", - SdkInternalExecutionAttribute.class, BusinessMetricFeatureId.class) - .endControlFlow() - .endControlFlow() - .endControlFlow() - .build(); - - - } - - private MethodSpec generateBeforeExecution() { - MethodSpec.Builder builder = MethodSpec.methodBuilder("beforeExecution") - .addAnnotation(Override.class) - .addModifiers(Modifier.PUBLIC) - .addParameter(Context.BeforeExecution.class, - "context") - .addParameter(ExecutionAttributes.class, - "executionAttributes"); - - builder.addStatement("$T authOptions = resolveAuthOptions(context, executionAttributes)", - listOf(AuthSchemeOption.class)) - .addStatement("$T selectedAuthScheme = selectAuthScheme(authOptions, executionAttributes)", - wildcardSelectedAuthScheme()) - .addStatement("putSelectedAuthScheme(executionAttributes, selectedAuthScheme)"); - - if (intermediateModel.getCustomizationConfig().isEnableEnvironmentBearerToken()) { - builder.addStatement("recordEnvironmentTokenBusinessMetric(selectedAuthScheme, " - + "executionAttributes)"); - } - - if (authSchemeSpecUtils.hasSigV4aSupport()) { - builder.beginControlFlow("if (selectedAuthScheme != null && " - + "selectedAuthScheme.authSchemeOption().schemeId().equals($T.SCHEME_ID) && " - + "!$T.isSignerOverridden(context.request(), executionAttributes))", - AwsV4aAuthScheme.class, - ClassName.get("software.amazon.awssdk.awscore.util", "SignerOverrideUtils")) - .addStatement("$T businessMetrics = executionAttributes.getAttribute($T.BUSINESS_METRICS)", - ClassName.get("software.amazon.awssdk.core.useragent", "BusinessMetricCollection"), - SdkInternalExecutionAttribute.class) - .beginControlFlow("if (businessMetrics != null)") - .addStatement("businessMetrics.addMetric($T.SIGV4A_SIGNING.value())", - BusinessMetricFeatureId.class) - .endControlFlow() - .endControlFlow(); - } - return builder.build(); - } - - private MethodSpec generateResolveAuthOptions() { - MethodSpec.Builder builder = MethodSpec.methodBuilder("resolveAuthOptions") - .addModifiers(Modifier.PRIVATE) - .returns(listOf(AuthSchemeOption.class)) - .addParameter(Context.BeforeExecution.class, - "context") - .addParameter(ExecutionAttributes.class, - "executionAttributes"); - - builder.addStatement("$1T authSchemeProvider = $2T.isInstanceOf($1T.class, executionAttributes" - + ".getAttribute($3T.AUTH_SCHEME_RESOLVER), $4S)", - authSchemeSpecUtils.providerInterfaceName(), - Validate.class, - SdkInternalExecutionAttribute.class, - "Expected an instance of " + authSchemeSpecUtils.providerInterfaceName().simpleName()); - builder.addStatement("$T params = authSchemeParams(context.request(), executionAttributes)", - authSchemeSpecUtils.parametersInterfaceName()); - builder.addStatement("return authSchemeProvider.resolveAuthScheme(params)"); - return builder.build(); - } - - private MethodSpec generateAuthSchemeParams() { - MethodSpec.Builder builder = MethodSpec.methodBuilder("authSchemeParams") - .addModifiers(Modifier.PRIVATE) - .returns(authSchemeSpecUtils.parametersInterfaceName()) - .addParameter(SdkRequest.class, "request") - .addParameter(ExecutionAttributes.class, "executionAttributes"); - - if (!authSchemeSpecUtils.useEndpointBasedAuthProvider()) { - builder.addStatement("$T operation = executionAttributes.getAttribute($T.OPERATION_NAME)", String.class, - SdkExecutionAttribute.class); - builder.addStatement("$T.Builder builder = $T.builder().operation(operation)", - authSchemeSpecUtils.parametersInterfaceName(), - authSchemeSpecUtils.parametersInterfaceName()); - - if (authSchemeSpecUtils.usesSigV4()) { - builder.addStatement("$T region = executionAttributes.getAttribute($T.AWS_REGION)", Region.class, - AwsExecutionAttribute.class); - builder.addStatement("builder.region(region)"); - } - generateSigv4aSigningRegionSet(builder); - builder.addStatement("return builder.build()"); - return builder.build(); - } - - builder.addStatement("$T endpointParams = $T.ruleParams(request, executionAttributes)", - endpointRulesSpecUtils.parametersClassName(), - endpointRulesSpecUtils.resolverInterceptorName()); - builder.addStatement("$1T.Builder builder = $1T.builder()", authSchemeSpecUtils.parametersInterfaceName()); - boolean regionIncluded = false; - for (String paramName : endpointRulesSpecUtils.parameters().keySet()) { - if (!authSchemeSpecUtils.includeParamForProvider(paramName)) { - continue; - } - regionIncluded = regionIncluded || paramName.equalsIgnoreCase("region"); - String methodName = endpointRulesSpecUtils.paramMethodName(paramName); - builder.addStatement("builder.$1N(endpointParams.$1N())", methodName); - } - - builder.addStatement("$T operation = executionAttributes.getAttribute($T.OPERATION_NAME)", String.class, - SdkExecutionAttribute.class); - builder.addStatement("builder.operation(operation)"); - if (authSchemeSpecUtils.usesSigV4() && !regionIncluded) { - builder.addStatement("$T region = executionAttributes.getAttribute($T.AWS_REGION)", Region.class, - AwsExecutionAttribute.class); - builder.addStatement("builder.region(region)"); - } - generateSigv4aSigningRegionSet(builder); - ClassName paramsBuilderClass = authSchemeSpecUtils.parametersEndpointAwareDefaultImplName().nestedClass("Builder"); - builder.beginControlFlow("if (builder instanceof $T)", - paramsBuilderClass); - ClassName endpointProviderClass = endpointRulesSpecUtils.providerInterfaceName(); - builder.addStatement("$T endpointProvider = executionAttributes.getAttribute($T.ENDPOINT_PROVIDER)", - EndpointProvider.class, - SdkInternalExecutionAttribute.class); - builder.beginControlFlow("if (endpointProvider instanceof $T)", endpointProviderClass); - builder.addStatement("(($T)builder).endpointProvider(($T)endpointProvider)", paramsBuilderClass, endpointProviderClass); - builder.endControlFlow(); - builder.endControlFlow(); - builder.addStatement("return builder.build()"); - return builder.build(); - } - - private MethodSpec generateSelectAuthScheme() { - MethodSpec.Builder builder = MethodSpec.methodBuilder("selectAuthScheme") - .addModifiers(Modifier.PRIVATE) - .returns(wildcardSelectedAuthScheme()) - .addParameter(listOf(AuthSchemeOption.class), "authOptions") - .addParameter(ExecutionAttributes.class, "executionAttributes"); - - builder.addStatement("$T metricCollector = executionAttributes.getAttribute($T.API_CALL_METRIC_COLLECTOR)", - MetricCollector.class, SdkExecutionAttribute.class) - .addStatement("$T authSchemes = executionAttributes.getAttribute($T.AUTH_SCHEMES)", - mapOf(String.class, wildcardAuthScheme()), - SdkInternalExecutionAttribute.class) - .addStatement("$T identityProviders = executionAttributes.getAttribute($T.IDENTITY_PROVIDERS)", - IdentityProviders.class, SdkInternalExecutionAttribute.class) - .addStatement("$T discardedReasons = new $T<>()", - listOfStringSuppliers(), ArrayList.class); - - builder.beginControlFlow("for ($T authOption : authOptions)", AuthSchemeOption.class); - { - builder.addStatement("$T authScheme = authSchemes.get(authOption.schemeId())", wildcardAuthScheme()) - .addStatement("$T selectedAuthScheme = trySelectAuthScheme(authOption, authScheme, identityProviders, " - + "discardedReasons, metricCollector, executionAttributes)", - wildcardSelectedAuthScheme()); - builder.beginControlFlow("if (selectedAuthScheme != null)"); - { - addLogDebugDiscardedOptions(builder); - builder.addStatement("return selectedAuthScheme") - .endControlFlow(); - } - // end foreach - builder.endControlFlow(); - } - builder.addStatement("throw $T.builder()" - + ".message($S + discardedReasons.stream().map($T::get).collect($T.joining(\", \")))" - + ".build()", - SdkException.class, - "Failed to determine how to authenticate the user: ", - Supplier.class, - Collectors.class); - return builder.build(); - } - - //TODO (s3express) Review "general" identity properties and their propagation - private MethodSpec generateTrySelectAuthScheme() { - MethodSpec.Builder builder = MethodSpec.methodBuilder("trySelectAuthScheme") - .addModifiers(Modifier.PRIVATE) - .returns(namedSelectedAuthScheme()) - .addParameter(AuthSchemeOption.class, "authOption") - .addParameter(namedAuthScheme(), "authScheme") - .addParameter(IdentityProviders.class, "identityProviders") - .addParameter(listOfStringSuppliers(), "discardedReasons") - .addParameter(MetricCollector.class, "metricCollector") - .addParameter(ExecutionAttributes.class, "executionAttributes") - .addTypeVariable(TypeVariableName.get("T", Identity.class)); - - builder.beginControlFlow("if (authScheme == null)"); - { - builder.addStatement("discardedReasons.add(() -> String.format($S, authOption.schemeId()))", - "'%s' is not enabled for this request.") - .addStatement("return null") - .endControlFlow(); - } - builder.addStatement("$T identityProvider = authScheme.identityProvider(identityProviders)", - namedIdentityProvider()); - - builder.beginControlFlow("if (identityProvider == null)"); - { - builder.addStatement("discardedReasons.add(() -> String.format($S, authOption.schemeId()))", - "'%s' does not have an identity provider configured.") - .addStatement("return null") - .endControlFlow(); - } - - builder.addStatement("$T signer", - ParameterizedTypeName.get(ClassName.get(HttpSigner.class), TypeVariableName.get("T"))); - builder.beginControlFlow("try"); - { - builder.addStatement("signer = authScheme.signer()"); - builder.nextControlFlow("catch (RuntimeException e)"); - builder.addStatement("discardedReasons.add(() -> String.format($S, authOption.schemeId(), e.getMessage()))", - "'%s' signer could not be retrieved: %s") - .addStatement("return null") - .endControlFlow(); - } - - - builder.addStatement("$T.Builder identityRequestBuilder = $T.builder()", - ResolveIdentityRequest.class, - ResolveIdentityRequest.class); - builder.addStatement("authOption.forEachIdentityProperty(identityRequestBuilder::putProperty)"); - if (endpointRulesSpecUtils.isS3()) { - builder.addStatement("identityRequestBuilder.putProperty($T.SDK_CLIENT, " - + "executionAttributes.getAttribute($T.SDK_CLIENT))", - SdkIdentityProperty.class, - SdkInternalExecutionAttribute.class); - } - builder.addStatement("$T identity", namedIdentityFuture()); - builder.addStatement("$T metric = getIdentityMetric(identityProvider)", durationSdkMetric()); - builder.beginControlFlow("if (metric == null)") - .addStatement("identity = identityProvider.resolveIdentity(identityRequestBuilder.build())") - .nextControlFlow("else") - .addStatement("identity = $T.reportDuration(" - + "() -> identityProvider.resolveIdentity(identityRequestBuilder.build()), metricCollector, metric)", - MetricUtils.class) - .endControlFlow(); - - builder.addStatement("return new $T<>(identity, signer, authOption)", SelectedAuthScheme.class); - return builder.build(); - } - - private MethodSpec generateGetIdentityMetric() { - MethodSpec.Builder builder = MethodSpec.methodBuilder("getIdentityMetric") - .addModifiers(Modifier.PRIVATE) - .returns(durationSdkMetric()) - .addParameter(wildcardIdentityProvider(), "identityProvider"); - - builder.addStatement("Class identityType = identityProvider.identityType()") - .beginControlFlow("if (identityType == $T.class)", AwsCredentialsIdentity.class) - .addStatement("return $T.CREDENTIALS_FETCH_DURATION", CoreMetric.class) - .endControlFlow() - .beginControlFlow("if (identityType == $T.class)", TokenIdentity.class) - .addStatement("return $T.TOKEN_FETCH_DURATION", CoreMetric.class) - .endControlFlow() - .addStatement("return null"); - - return builder.build(); - } - - private MethodSpec putSelectedAuthSchemeMethodSpec() { - String attributeParamName = "attributes"; - String selectedAuthSchemeParamName = "selectedAuthScheme"; - MethodSpec.Builder builder = MethodSpec.methodBuilder("putSelectedAuthScheme") - .addModifiers(Modifier.PRIVATE) - .addTypeVariable(TypeVariableName.get("T", Identity.class)) - .addParameter(ExecutionAttributes.class, attributeParamName) - .addParameter(ParameterSpec.builder( - ParameterizedTypeName.get(ClassName.get(SelectedAuthScheme.class), - TypeVariableName.get("T")), - selectedAuthSchemeParamName).build()); - builder.addStatement("$T existingAuthScheme = $N.getAttribute($T.SELECTED_AUTH_SCHEME)", - ParameterizedTypeName.get(ClassName.get(SelectedAuthScheme.class), - WildcardTypeName.subtypeOf(Object.class)), - attributeParamName, - SdkInternalExecutionAttribute.class); - - builder.beginControlFlow("if (existingAuthScheme != null)") - .addStatement("$T selectedOption = $N.authSchemeOption().toBuilder()", - AuthSchemeOption.Builder.class, selectedAuthSchemeParamName) - .addStatement("existingAuthScheme.authSchemeOption().forEachIdentityProperty" - + "(selectedOption::putIdentityPropertyIfAbsent)") - .addStatement("existingAuthScheme.authSchemeOption().forEachSignerProperty" - + "(selectedOption::putSignerPropertyIfAbsent)") - .addStatement("$N = new $T<>($N.identity(), $N.signer(), selectedOption.build())", - selectedAuthSchemeParamName, - SelectedAuthScheme.class, - selectedAuthSchemeParamName, - selectedAuthSchemeParamName); - builder.endControlFlow(); - - builder.addStatement("$N.putAttribute($T.SELECTED_AUTH_SCHEME, $N)", - attributeParamName, SdkInternalExecutionAttribute.class, selectedAuthSchemeParamName); - - return builder.build(); - } - - private void addLogDebugDiscardedOptions(MethodSpec.Builder builder) { - builder.beginControlFlow("if (!discardedReasons.isEmpty())"); - { - builder.addStatement("LOG.debug(() -> String.format(\"%s auth will be used, discarded: '%s'\", " - + "authOption.schemeId(), " - + "discardedReasons.stream().map($T::get).collect($T.joining(\", \"))))", - Supplier.class, Collectors.class) - .endControlFlow(); - } - } - - // IdentityProvider - private TypeName namedIdentityProvider() { - return ParameterizedTypeName.get(ClassName.get(IdentityProvider.class), TypeVariableName.get("T")); - } - - // IdentityProvider - private TypeName wildcardIdentityProvider() { - return ParameterizedTypeName.get(ClassName.get(IdentityProvider.class), WildcardTypeName.subtypeOf(Object.class)); - } - - // CompletableFuture - private TypeName namedIdentityFuture() { - return ParameterizedTypeName.get(ClassName.get(CompletableFuture.class), - WildcardTypeName.subtypeOf(TypeVariableName.get("T"))); - } - - // AuthScheme - private TypeName namedAuthScheme() { - return ParameterizedTypeName.get(ClassName.get(AuthScheme.class), - TypeVariableName.get("T", Identity.class)); - } - - // AuthScheme - private TypeName wildcardAuthScheme() { - return ParameterizedTypeName.get(ClassName.get(AuthScheme.class), - WildcardTypeName.subtypeOf(Object.class)); - } - - // SelectedAuthScheme - private TypeName namedSelectedAuthScheme() { - return ParameterizedTypeName.get(ClassName.get(SelectedAuthScheme.class), - TypeVariableName.get("T", Identity.class)); - } - - // SelectedAuthScheme - private TypeName wildcardSelectedAuthScheme() { - return ParameterizedTypeName.get(ClassName.get(SelectedAuthScheme.class), - WildcardTypeName.subtypeOf(Identity.class)); - } - - // List> - private TypeName listOfStringSuppliers() { - return listOf(ParameterizedTypeName.get(Supplier.class, String.class)); - } - - // Map - private TypeName mapOf(Object keyType, Object valueType) { - return ParameterizedTypeName.get(ClassName.get(Map.class), toTypeName(keyType), toTypeName(valueType)); - } - - // List - private TypeName listOf(Object valueType) { - return ParameterizedTypeName.get(ClassName.get(List.class), toTypeName(valueType)); - } - - // SdkMetric - private ParameterizedTypeName durationSdkMetric() { - return ParameterizedTypeName.get(ClassName.get(SdkMetric.class), toTypeName(Duration.class)); - } - - private TypeName toTypeName(Object valueType) { - TypeName result; - if (valueType instanceof Class) { - result = ClassName.get((Class) valueType); - } else if (valueType instanceof TypeName) { - result = (TypeName) valueType; - } else { - throw new IllegalArgumentException("Don't know how to convert " + valueType + " to TypeName"); - } - return result; - } - - private void generateSigv4aSigningRegionSet(MethodSpec.Builder builder) { - if (authSchemeSpecUtils.hasSigV4aSupport()) { - builder.addStatement( - "executionAttributes.getOptionalAttribute($T.AWS_SIGV4A_SIGNING_REGION_SET)\n" + - " .filter(regionSet -> !$T.isNullOrEmpty(regionSet))\n" + - " .ifPresent(nonEmptyRegionSet -> builder.regionSet($T.create(nonEmptyRegionSet)))", - AwsExecutionAttribute.class, - CollectionUtils.class, - RegionSet.class - ); - } - } -} diff --git a/codegen/src/main/java/software/amazon/awssdk/codegen/poet/auth/scheme/AuthSchemeParamsSpec.java b/codegen/src/main/java/software/amazon/awssdk/codegen/poet/auth/scheme/AuthSchemeParamsSpec.java index 848182945524..86af5542e886 100644 --- a/codegen/src/main/java/software/amazon/awssdk/codegen/poet/auth/scheme/AuthSchemeParamsSpec.java +++ b/codegen/src/main/java/software/amazon/awssdk/codegen/poet/auth/scheme/AuthSchemeParamsSpec.java @@ -60,14 +60,42 @@ public TypeSpec poetSpec() { .addModifiers(Modifier.PUBLIC) .addAnnotation(SdkPublicApi.class) .addJavadoc(interfaceJavadoc()) - .addMethod(builderMethod()) - .addType(builderInterfaceSpec()); + .addMethod(builderMethod()); + + if (authSchemeSpecUtils.generateEndpointBasedParams()) { + b.addMethod(fromEndpointParamsMethod()); + } + + b.addType(builderInterfaceSpec()); addAccessorMethods(b); addToBuilder(b); return b.build(); } + private MethodSpec fromEndpointParamsMethod() { + ClassName endpointParamsClass = endpointRulesSpecUtils.parametersClassName(); + MethodSpec.Builder builder = MethodSpec.methodBuilder("fromEndpointParams") + .addModifiers(Modifier.PUBLIC, Modifier.STATIC) + .addParameter(endpointParamsClass, "endpointParams") + .returns(authSchemeSpecUtils.parametersInterfaceBuilderInterfaceName()) + .addJavadoc("Create a builder pre-populated with endpoint parameters.\n" + + "@param endpointParams the endpoint parameters to copy\n" + + "@return a builder with values from the endpoint parameters"); + + builder.addStatement("$T builder = builder()", authSchemeSpecUtils.parametersInterfaceBuilderInterfaceName()); + + parameters().forEach((name, model) -> { + if (authSchemeSpecUtils.includeParamForProvider(name)) { + String methodName = endpointRulesSpecUtils.paramMethodName(name); + builder.addStatement("builder.$1N(endpointParams.$1N())", methodName); + } + }); + + builder.addStatement("return builder"); + return builder.build(); + } + private CodeBlock interfaceJavadoc() { CodeBlock.Builder b = CodeBlock.builder(); diff --git a/codegen/src/main/java/software/amazon/awssdk/codegen/poet/builder/BaseClientBuilderClass.java b/codegen/src/main/java/software/amazon/awssdk/codegen/poet/builder/BaseClientBuilderClass.java index 44d8ce154e72..7533cf25c0c4 100644 --- a/codegen/src/main/java/software/amazon/awssdk/codegen/poet/builder/BaseClientBuilderClass.java +++ b/codegen/src/main/java/software/amazon/awssdk/codegen/poet/builder/BaseClientBuilderClass.java @@ -370,10 +370,6 @@ private MethodSpec finalizeServiceConfigurationMethod() { List builtInInterceptors = new ArrayList<>(); - builtInInterceptors.add(authSchemeSpecUtils.authSchemeInterceptor()); - builtInInterceptors.add(endpointRulesSpecUtils.resolverInterceptorName()); - builtInInterceptors.add(endpointRulesSpecUtils.requestModifierInterceptorName()); - for (String interceptor : model.getCustomizationConfig().getInterceptors()) { builtInInterceptors.add(ClassName.bestGuess(interceptor)); } diff --git a/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/AsyncClientClass.java b/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/AsyncClientClass.java index 539d5df35a43..09bbaa7caf80 100644 --- a/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/AsyncClientClass.java +++ b/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/AsyncClientClass.java @@ -69,10 +69,12 @@ import software.amazon.awssdk.codegen.poet.PoetExtension; import software.amazon.awssdk.codegen.poet.PoetUtils; import software.amazon.awssdk.codegen.poet.StaticImport; +import software.amazon.awssdk.codegen.poet.auth.scheme.AuthSchemeSpecUtils; import software.amazon.awssdk.codegen.poet.client.specs.ProtocolSpec; import software.amazon.awssdk.codegen.poet.eventstream.EventStreamUtils; import software.amazon.awssdk.codegen.poet.model.EventStreamSpecHelper; import software.amazon.awssdk.codegen.poet.model.ServiceClientConfigurationUtils; +import software.amazon.awssdk.codegen.poet.rules.EndpointRulesSpecUtils; import software.amazon.awssdk.core.RequestOverrideConfiguration; import software.amazon.awssdk.core.async.AsyncResponseTransformer; import software.amazon.awssdk.core.async.AsyncResponseTransformerUtils; @@ -100,6 +102,8 @@ public final class AsyncClientClass extends AsyncClientInterface { private final ProtocolSpec protocolSpec; private final ClassName serviceClientConfigurationClassName; private final ServiceClientConfigurationUtils configurationUtils; + private final AuthSchemeSpecUtils authSchemeSpecUtils; + private final EndpointRulesSpecUtils endpointRulesSpecUtils; private boolean hasScheduledExecutor; public AsyncClientClass(GeneratorTaskParams dependencies) { @@ -110,6 +114,8 @@ public AsyncClientClass(GeneratorTaskParams dependencies) { this.protocolSpec = getProtocolSpecs(poetExtensions, model); this.serviceClientConfigurationClassName = new PoetExtension(model).getServiceConfigClass(); this.configurationUtils = new ServiceClientConfigurationUtils(model); + this.authSchemeSpecUtils = new AuthSchemeSpecUtils(model); + this.endpointRulesSpecUtils = new EndpointRulesSpecUtils(model); } @Override @@ -165,7 +171,9 @@ protected void addAdditionalMethods(TypeSpec.Builder type) { .addMethod(nameMethod()) .addMethods(protocolSpec.additionalMethods()) .addMethod(protocolSpec.initProtocolFactory(model)) - .addMethod(resolveMetricPublishersMethod()); + .addMethod(resolveMetricPublishersMethod()) + .addMethod(ClientClassUtils.resolveAuthSchemeOptionsMethod(authSchemeSpecUtils, endpointRulesSpecUtils)) + .addMethod(ClientClassUtils.resolveEndpointMethod(authSchemeSpecUtils, endpointRulesSpecUtils)); type.addMethod(ClientClassUtils.updateRetryStrategyClientConfigurationMethod()); type.addMethod(updateSdkClientConfigurationMethod(configurationUtils.serviceClientConfigurationBuilderClassName(), diff --git a/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/ClientClassUtils.java b/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/ClientClassUtils.java index 9d4e15b3dd86..0b7452d3111d 100644 --- a/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/ClientClassUtils.java +++ b/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/ClientClassUtils.java @@ -25,10 +25,13 @@ import com.squareup.javapoet.ParameterizedTypeName; import com.squareup.javapoet.TypeName; import com.squareup.javapoet.TypeVariableName; +import com.squareup.javapoet.WildcardTypeName; import java.util.List; import java.util.Map; import java.util.Objects; import java.util.Optional; +import java.util.Set; +import java.util.concurrent.CompletionException; import java.util.function.Consumer; import java.util.stream.Collectors; import javax.lang.model.element.Modifier; @@ -45,14 +48,22 @@ import software.amazon.awssdk.codegen.model.service.HostPrefixProcessor; import software.amazon.awssdk.codegen.poet.PoetExtension; import software.amazon.awssdk.codegen.poet.PoetUtils; +import software.amazon.awssdk.codegen.poet.auth.scheme.AuthSchemeSpecUtils; import software.amazon.awssdk.codegen.poet.rules.EndpointRulesSpecUtils; +import software.amazon.awssdk.core.SdkClient; import software.amazon.awssdk.core.SdkPlugin; import software.amazon.awssdk.core.SdkRequest; +import software.amazon.awssdk.core.SelectedAuthScheme; import software.amazon.awssdk.core.client.config.ClientOverrideConfiguration; import software.amazon.awssdk.core.client.config.SdkClientConfiguration; import software.amazon.awssdk.core.client.config.SdkClientOption; +import software.amazon.awssdk.core.exception.SdkClientException; +import software.amazon.awssdk.core.interceptor.ExecutionAttributes; +import software.amazon.awssdk.core.interceptor.SdkInternalExecutionAttribute; import software.amazon.awssdk.core.retry.RetryMode; import software.amazon.awssdk.core.signer.Signer; +import software.amazon.awssdk.endpoints.Endpoint; +import software.amazon.awssdk.http.auth.spi.scheme.AuthSchemeOption; import software.amazon.awssdk.retries.api.RetryStrategy; import software.amazon.awssdk.utils.AttributeMap; import software.amazon.awssdk.utils.CollectionUtils; @@ -346,4 +357,224 @@ public static MethodSpec updateRetryStrategyClientConfigurationMethod() { static String transformServiceId(String serviceId) { return serviceId.replace(" ", "_"); } + + static MethodSpec resolveAuthSchemeOptionsMethod(AuthSchemeSpecUtils authSchemeSpecUtils, + EndpointRulesSpecUtils endpointRulesSpecUtils) { + MethodSpec.Builder builder = MethodSpec.methodBuilder("resolveAuthSchemeOptions") + .addModifiers(PRIVATE) + .returns(ParameterizedTypeName.get(ClassName.get(List.class), ClassName.get(AuthSchemeOption.class))) + .addParameter(SdkRequest.class, "request") + .addParameter(String.class, "operationName") + .addParameter(SdkClientConfiguration.class, "clientConfiguration"); + + ClassName providerInterface = authSchemeSpecUtils.providerInterfaceName(); + + builder.addStatement("$T authSchemeProvider = $T.isInstanceOf($T.class, " + + "clientConfiguration.option($T.AUTH_SCHEME_PROVIDER), $S)", + providerInterface, Validate.class, providerInterface, SdkClientOption.class, + "Expected an instance of " + authSchemeSpecUtils.providerInterfaceName().simpleName()); + + if (authSchemeSpecUtils.useEndpointBasedAuthProvider()) { + addEndpointBasedAuthSchemeResolution(builder, authSchemeSpecUtils, endpointRulesSpecUtils); + } else { + addSimpleAuthSchemeResolution(builder, authSchemeSpecUtils); + } + + if (endpointRulesSpecUtils.isS3()) { + ClassName sdkIdentityProperty = ClassName.get("software.amazon.awssdk.core.identity", "SdkIdentityProperty"); + builder.addStatement("$T sdkClient = clientConfiguration.option($T.SDK_CLIENT)", + SdkClient.class, SdkClientOption.class); + builder.addStatement("return options.stream().map(o -> o.toBuilder()" + + ".putIdentityProperty($T.SDK_CLIENT, sdkClient).build())" + + ".collect($T.toList())", + sdkIdentityProperty, Collectors.class); + } else { + builder.addStatement("return options"); + } + + return builder.build(); + } + + private static void addSimpleAuthSchemeResolution(MethodSpec.Builder builder, + AuthSchemeSpecUtils authSchemeSpecUtils) { + ClassName paramsInterface = authSchemeSpecUtils.parametersInterfaceName(); + ClassName awsClientOption = ClassName.get("software.amazon.awssdk.awscore.client.config", "AwsClientOption"); + + builder.addStatement("$T.Builder paramsBuilder = $T.builder().operation(operationName)", + paramsInterface, paramsInterface); + + if (authSchemeSpecUtils.usesSigV4()) { + builder.addStatement("paramsBuilder.region(clientConfiguration.option($T.AWS_REGION))", awsClientOption); + } + + if (authSchemeSpecUtils.hasSigV4aSupport()) { + ClassName regionSet = ClassName.get("software.amazon.awssdk.http.auth.aws.signer", "RegionSet"); + builder.addStatement("$T sigv4aRegionSet = clientConfiguration.option($T.AWS_SIGV4A_SIGNING_REGION_SET)", + ClassName.get(Set.class), awsClientOption); + builder.beginControlFlow("if (!$T.isNullOrEmpty(sigv4aRegionSet))", CollectionUtils.class); + builder.addStatement("paramsBuilder.regionSet($T.create(sigv4aRegionSet))", regionSet); + builder.endControlFlow(); + } + + builder.addStatement("$T<$T> options = authSchemeProvider.resolveAuthScheme(paramsBuilder.build())", + List.class, AuthSchemeOption.class); + } + + private static void addEndpointBasedAuthSchemeResolution(MethodSpec.Builder builder, + AuthSchemeSpecUtils authSchemeSpecUtils, + EndpointRulesSpecUtils endpointRulesSpecUtils) { + ClassName paramsInterface = authSchemeSpecUtils.parametersInterfaceName(); + ClassName awsClientOption = ClassName.get("software.amazon.awssdk.awscore.client.config", "AwsClientOption"); + ClassName endpointParamsClass = endpointRulesSpecUtils.parametersClassName(); + ClassName endpointResolverUtils = endpointRulesSpecUtils.endpointResolverUtilsName(); + ClassName executionAttributesClass = ClassName.get("software.amazon.awssdk.core.interceptor", "ExecutionAttributes"); + ClassName awsExecutionAttribute = ClassName.get("software.amazon.awssdk.awscore", "AwsExecutionAttribute"); + ClassName sdkExecutionAttribute = ClassName.get("software.amazon.awssdk.core.interceptor", "SdkExecutionAttribute"); + ClassName sdkInternalExecutionAttribute = ClassName.get("software.amazon.awssdk.core.interceptor", + "SdkInternalExecutionAttribute"); + + builder.addStatement("$T executionAttributes = new $T()", executionAttributesClass, executionAttributesClass); + builder.addStatement("executionAttributes.putAttribute($T.AWS_REGION, clientConfiguration.option($T.AWS_REGION))", + awsExecutionAttribute, awsClientOption); + builder.addStatement("executionAttributes.putAttribute($T.DUALSTACK_ENDPOINT_ENABLED, " + + "clientConfiguration.option($T.DUALSTACK_ENDPOINT_ENABLED))", + awsExecutionAttribute, awsClientOption); + builder.addStatement("executionAttributes.putAttribute($T.FIPS_ENDPOINT_ENABLED, " + + "clientConfiguration.option($T.FIPS_ENDPOINT_ENABLED))", + awsExecutionAttribute, awsClientOption); + builder.addStatement("executionAttributes.putAttribute($T.OPERATION_NAME, operationName)", sdkExecutionAttribute); + builder.addStatement("executionAttributes.putAttribute($T.CLIENT_ENDPOINT_PROVIDER, " + + "clientConfiguration.option($T.CLIENT_ENDPOINT_PROVIDER))", + sdkInternalExecutionAttribute, SdkClientOption.class); + builder.addStatement("executionAttributes.putAttribute($T.CLIENT_CONTEXT_PARAMS, " + + "clientConfiguration.option($T.CLIENT_CONTEXT_PARAMS))", + sdkInternalExecutionAttribute, SdkClientOption.class); + + builder.addStatement("$T endpointParams = $T.ruleParams(request, executionAttributes)", + endpointParamsClass, endpointResolverUtils); + + builder.addStatement("$T.Builder paramsBuilder = $T.builder()", paramsInterface, paramsInterface); + + boolean regionIncluded = false; + for (String paramName : endpointRulesSpecUtils.parameters().keySet()) { + if (!authSchemeSpecUtils.includeParamForProvider(paramName)) { + continue; + } + regionIncluded = regionIncluded || paramName.equalsIgnoreCase("region"); + String methodName = endpointRulesSpecUtils.paramMethodName(paramName); + builder.addStatement("paramsBuilder.$1N(endpointParams.$1N())", methodName); + } + + builder.addStatement("paramsBuilder.operation(operationName)"); + + if (authSchemeSpecUtils.usesSigV4() && !regionIncluded) { + builder.addStatement("paramsBuilder.region(clientConfiguration.option($T.AWS_REGION))", awsClientOption); + } + + if (authSchemeSpecUtils.hasSigV4aSupport()) { + ClassName regionSet = ClassName.get("software.amazon.awssdk.http.auth.aws.signer", "RegionSet"); + builder.addStatement("$T sigv4aRegionSet = clientConfiguration.option($T.AWS_SIGV4A_SIGNING_REGION_SET)", + ClassName.get(Set.class), awsClientOption); + builder.beginControlFlow("if (!$T.isNullOrEmpty(sigv4aRegionSet))", CollectionUtils.class); + builder.addStatement("paramsBuilder.regionSet($T.create(sigv4aRegionSet))", regionSet); + builder.endControlFlow(); + } + + ClassName paramsBuilderClass = authSchemeSpecUtils.parametersEndpointAwareDefaultImplName().nestedClass("Builder"); + ClassName endpointProviderInterface = endpointRulesSpecUtils.providerInterfaceName(); + + builder.beginControlFlow("if (paramsBuilder instanceof $T)", paramsBuilderClass); + builder.addStatement("$T endpointProvider = clientConfiguration.option($T.ENDPOINT_PROVIDER)", + ClassName.get("software.amazon.awssdk.endpoints", "EndpointProvider"), SdkClientOption.class); + builder.beginControlFlow("if (endpointProvider instanceof $T)", endpointProviderInterface); + builder.addStatement("(($T) paramsBuilder).endpointProvider(($T) endpointProvider)", + paramsBuilderClass, endpointProviderInterface); + builder.endControlFlow(); + builder.endControlFlow(); + + builder.addStatement("$T<$T> options = authSchemeProvider.resolveAuthScheme(paramsBuilder.build())", + List.class, AuthSchemeOption.class); + } + + static MethodSpec resolveEndpointMethod(AuthSchemeSpecUtils authSchemeSpecUtils, + EndpointRulesSpecUtils endpointRulesSpecUtils) { + ClassName utilsClass = endpointRulesSpecUtils.endpointResolverUtilsName(); + ClassName endpointParamsClass = endpointRulesSpecUtils.parametersClassName(); + ClassName providerInterface = endpointRulesSpecUtils.providerInterfaceName(); + ClassName awsEndpointProviderUtils = endpointRulesSpecUtils.sharedAwsEndpointProviderUtilsName(); + ClassName awsEndpointAttribute = ClassName.get("software.amazon.awssdk.awscore.endpoints", "AwsEndpointAttribute"); + ClassName endpointAuthScheme = ClassName.get("software.amazon.awssdk.awscore.endpoints.authscheme", "EndpointAuthScheme"); + + MethodSpec.Builder b = MethodSpec.methodBuilder("resolveEndpoint") + .addModifiers(PRIVATE) + .returns(Endpoint.class) + .addParameter(SdkRequest.class, "request") + .addParameter(ExecutionAttributes.class, "executionAttributes") + .addParameter(String.class, "operationName"); + + b.addStatement("$1T provider = ($1T) executionAttributes.getAttribute($2T.ENDPOINT_PROVIDER)", + providerInterface, SdkInternalExecutionAttribute.class); + + b.beginControlFlow("try"); + b.addStatement("$T endpointParams = $T.ruleParams(request, executionAttributes)", + endpointParamsClass, utilsClass); + b.addStatement("$T endpoint = provider.resolveEndpoint(endpointParams).join()", Endpoint.class); + + b.beginControlFlow("if (!$T.disableHostPrefixInjection(executionAttributes))", awsEndpointProviderUtils); + b.addStatement("$T hostPrefix = $T.hostPrefix(operationName, request)", + ParameterizedTypeName.get(Optional.class, String.class), utilsClass); + b.beginControlFlow("if (hostPrefix.isPresent())"); + b.addStatement("endpoint = $T.addHostPrefix(endpoint, hostPrefix.get())", awsEndpointProviderUtils); + b.endControlFlow(); + b.endControlFlow(); + + b.addStatement("$T endpointAuthSchemes = endpoint.attribute($T.AUTH_SCHEMES)", + ParameterizedTypeName.get(ClassName.get(List.class), endpointAuthScheme), + awsEndpointAttribute); + b.addStatement("$T selectedAuthScheme = executionAttributes.getAttribute($T.SELECTED_AUTH_SCHEME)", + ParameterizedTypeName.get(ClassName.get(SelectedAuthScheme.class), + WildcardTypeName.subtypeOf(Object.class)), + SdkInternalExecutionAttribute.class); + b.beginControlFlow("if (endpointAuthSchemes != null && selectedAuthScheme != null)"); + b.addStatement("selectedAuthScheme = $T.authSchemeWithEndpointSignerProperties(endpointAuthSchemes, selectedAuthScheme)", + utilsClass); + + if (authSchemeSpecUtils.usesSigV4a() || authSchemeSpecUtils.generateEndpointBasedParams()) { + ClassName awsV4aAuthScheme = ClassName.get("software.amazon.awssdk.http.auth.aws.scheme", "AwsV4aAuthScheme"); + ClassName awsV4aHttpSigner = ClassName.get("software.amazon.awssdk.http.auth.aws.signer", "AwsV4aHttpSigner"); + ClassName regionSet = ClassName.get("software.amazon.awssdk.http.auth.aws.signer", "RegionSet"); + ClassName authSchemeOption = ClassName.get("software.amazon.awssdk.http.auth.spi.scheme", "AuthSchemeOption"); + + b.addComment("Precedence of SigV4a RegionSet is set according to multi-auth SigV4a specifications"); + b.beginControlFlow("if (selectedAuthScheme.authSchemeOption().schemeId().equals($T.SCHEME_ID) " + + "&& selectedAuthScheme.authSchemeOption().signerProperty($T.REGION_SET) == null)", + awsV4aAuthScheme, awsV4aHttpSigner); + b.addStatement("$T optionBuilder = selectedAuthScheme.authSchemeOption().toBuilder()", + authSchemeOption.nestedClass("Builder")); + b.addStatement("$1T rs = $1T.create(endpointParams.region().id())", regionSet); + b.addStatement("optionBuilder.putSignerProperty($T.REGION_SET, rs)", awsV4aHttpSigner); + b.addStatement("selectedAuthScheme = new $T(selectedAuthScheme.identity(), selectedAuthScheme.signer(), " + + "optionBuilder.build())", SelectedAuthScheme.class); + b.endControlFlow(); + } + + b.addStatement("executionAttributes.putAttribute($T.SELECTED_AUTH_SCHEME, selectedAuthScheme)", + SdkInternalExecutionAttribute.class); + b.endControlFlow(); + + b.addStatement("$T.setMetricValues(endpoint, executionAttributes)", utilsClass); + + b.addStatement("return endpoint"); + + b.nextControlFlow("catch ($T e)", CompletionException.class); + b.addStatement("$T cause = e.getCause()", Throwable.class); + b.beginControlFlow("if (cause instanceof $T)", SdkClientException.class); + b.addStatement("throw ($T) cause", SdkClientException.class); + b.endControlFlow(); + b.addStatement("throw $T.create($S + cause.getMessage(), cause)", + SdkClientException.class, "Endpoint resolution failed: "); + b.endControlFlow(); + + return b.build(); + } } diff --git a/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/SyncClientClass.java b/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/SyncClientClass.java index 5736ddbecaf5..0d33e441bc40 100644 --- a/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/SyncClientClass.java +++ b/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/SyncClientClass.java @@ -54,12 +54,14 @@ import software.amazon.awssdk.codegen.model.service.PreClientExecutionRequestCustomizer; import software.amazon.awssdk.codegen.poet.PoetExtension; import software.amazon.awssdk.codegen.poet.PoetUtils; +import software.amazon.awssdk.codegen.poet.auth.scheme.AuthSchemeSpecUtils; import software.amazon.awssdk.codegen.poet.client.specs.Ec2ProtocolSpec; import software.amazon.awssdk.codegen.poet.client.specs.JsonProtocolSpec; import software.amazon.awssdk.codegen.poet.client.specs.ProtocolSpec; import software.amazon.awssdk.codegen.poet.client.specs.QueryProtocolSpec; import software.amazon.awssdk.codegen.poet.client.specs.XmlProtocolSpec; import software.amazon.awssdk.codegen.poet.model.ServiceClientConfigurationUtils; +import software.amazon.awssdk.codegen.poet.rules.EndpointRulesSpecUtils; import software.amazon.awssdk.core.RequestOverrideConfiguration; import software.amazon.awssdk.core.client.config.SdkClientConfiguration; import software.amazon.awssdk.core.client.config.SdkClientOption; @@ -83,6 +85,8 @@ public class SyncClientClass extends SyncClientInterface { private final ProtocolSpec protocolSpec; private final ClassName serviceClientConfigurationClassName; private final ServiceClientConfigurationUtils configurationUtils; + private final AuthSchemeSpecUtils authSchemeSpecUtils; + private final EndpointRulesSpecUtils endpointRulesSpecUtils; public SyncClientClass(GeneratorTaskParams taskParams) { super(taskParams.getModel()); @@ -92,6 +96,8 @@ public SyncClientClass(GeneratorTaskParams taskParams) { this.protocolSpec = getProtocolSpecs(poetExtensions, model); this.serviceClientConfigurationClassName = new PoetExtension(model).getServiceConfigClass(); this.configurationUtils = new ServiceClientConfigurationUtils(model); + this.authSchemeSpecUtils = new AuthSchemeSpecUtils(model); + this.endpointRulesSpecUtils = new EndpointRulesSpecUtils(model); } @Override @@ -133,7 +139,9 @@ protected void addAdditionalMethods(TypeSpec.Builder type) { type.addMethod(constructor()) .addMethod(nameMethod()) .addMethods(protocolSpec.additionalMethods()) - .addMethod(resolveMetricPublishersMethod()); + .addMethod(resolveMetricPublishersMethod()) + .addMethod(ClientClassUtils.resolveAuthSchemeOptionsMethod(authSchemeSpecUtils, endpointRulesSpecUtils)) + .addMethod(ClientClassUtils.resolveEndpointMethod(authSchemeSpecUtils, endpointRulesSpecUtils)); protocolSpec.createErrorResponseHandler().ifPresent(type::addMethod); type.addMethod(ClientClassUtils.updateRetryStrategyClientConfigurationMethod()); diff --git a/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/specs/JsonProtocolSpec.java b/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/specs/JsonProtocolSpec.java index b9b69d9bd8a0..57790e49d34d 100644 --- a/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/specs/JsonProtocolSpec.java +++ b/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/specs/JsonProtocolSpec.java @@ -221,7 +221,11 @@ public CodeBlock executionHandler(OperationModel opModel) { .add(credentialType(opModel, model)) .add(".withRequestConfiguration(clientConfiguration)") .add(".withInput($L)\n", opModel.getInput().getVariableName()) - .add(".withMetricCollector(apiCallMetricCollector)") + .add(".withMetricCollector(apiCallMetricCollector)\n") + .add(".withAuthSchemeOptionsResolver(r -> resolveAuthSchemeOptions(r, $S, clientConfiguration))\n", + opModel.getOperationName()) + .add(".withEndpointResolver((r, a) -> resolveEndpoint(r, a, $S))\n", + opModel.getOperationName()) .add(HttpChecksumRequiredTrait.putHttpChecksumAttribute(opModel)) .add(HttpChecksumTrait.create(opModel)); @@ -295,6 +299,10 @@ public CodeBlock asyncExecutionHandler(IntermediateModel intermediateModel, Oper .add(".withErrorResponseHandler(errorResponseHandler)\n") .add(".withRequestConfiguration(clientConfiguration)") .add(".withMetricCollector(apiCallMetricCollector)\n") + .add(".withAuthSchemeOptionsResolver(r -> resolveAuthSchemeOptions(r, $S, clientConfiguration))\n", + opModel.getOperationName()) + .add(".withEndpointResolver((r, a) -> resolveEndpoint(r, a, $S))\n", + opModel.getOperationName()) .add(hostPrefixExpression(opModel)) .add(discoveredEndpoint(opModel)) .add(credentialType(opModel, model)) diff --git a/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/specs/QueryProtocolSpec.java b/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/specs/QueryProtocolSpec.java index bab1f29e7281..1132bb911435 100644 --- a/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/specs/QueryProtocolSpec.java +++ b/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/specs/QueryProtocolSpec.java @@ -116,6 +116,10 @@ public CodeBlock executionHandler(OperationModel opModel) { .add(".withRequestConfiguration(clientConfiguration)") .add(".withInput($L)", opModel.getInput().getVariableName()) .add(".withMetricCollector(apiCallMetricCollector)") + .add(".withAuthSchemeOptionsResolver(r -> resolveAuthSchemeOptions(r, $S, clientConfiguration))\n", + opModel.getOperationName()) + .add(".withEndpointResolver((r, a) -> resolveEndpoint(r, a, $S))\n", + opModel.getOperationName()) .add(HttpChecksumRequiredTrait.putHttpChecksumAttribute(opModel)) .add(HttpChecksumTrait.create(opModel)); @@ -155,6 +159,10 @@ public CodeBlock asyncExecutionHandler(IntermediateModel intermediateModel, Oper .add(credentialType(opModel, intermediateModel)) .add(".withRequestConfiguration(clientConfiguration)") .add(".withMetricCollector(apiCallMetricCollector)\n") + .add(".withAuthSchemeOptionsResolver(r -> resolveAuthSchemeOptions(r, $S, clientConfiguration))\n", + opModel.getOperationName()) + .add(".withEndpointResolver((r, a) -> resolveEndpoint(r, a, $S))\n", + opModel.getOperationName()) .add(HttpChecksumRequiredTrait.putHttpChecksumAttribute(opModel)) .add(HttpChecksumTrait.create(opModel)); diff --git a/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/specs/XmlProtocolSpec.java b/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/specs/XmlProtocolSpec.java index dbbd33f4276f..207e38a49371 100644 --- a/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/specs/XmlProtocolSpec.java +++ b/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/specs/XmlProtocolSpec.java @@ -134,7 +134,11 @@ public CodeBlock executionHandler(OperationModel opModel) { discoveredEndpoint(opModel)) .add(credentialType(opModel, model)) .add(".withRequestConfiguration(clientConfiguration)") - .add(".withInput($L)", opModel.getInput().getVariableName()) + .add(".withInput($L)", opModel.getInput().getVariableName()) + .add(".withAuthSchemeOptionsResolver(r -> resolveAuthSchemeOptions(r, $S, clientConfiguration))\n", + opModel.getOperationName()) + .add(".withEndpointResolver((r, a) -> resolveEndpoint(r, a, $S))\n", + opModel.getOperationName()) .add(HttpChecksumRequiredTrait.putHttpChecksumAttribute(opModel)) .add(HttpChecksumTrait.create(opModel)); @@ -212,7 +216,11 @@ public CodeBlock asyncExecutionHandler(IntermediateModel intermediateModel, Oper builder.add(hostPrefixExpression(opModel)) .add(credentialType(opModel, model)) - .add(".withMetricCollector(apiCallMetricCollector)\n") + .add(".withMetricCollector(apiCallMetricCollector)\n") + .add(".withAuthSchemeOptionsResolver(r -> resolveAuthSchemeOptions(r, $S, clientConfiguration))\n", + opModel.getOperationName()) + .add(".withEndpointResolver((r, a) -> resolveEndpoint(r, a, $S))\n", + opModel.getOperationName()) .add(asyncRequestBody(opModel)) .add(HttpChecksumRequiredTrait.putHttpChecksumAttribute(opModel)) .add(HttpChecksumTrait.create(opModel)); diff --git a/codegen/src/main/java/software/amazon/awssdk/codegen/poet/rules/EndpointResolverInterceptorSpec.java b/codegen/src/main/java/software/amazon/awssdk/codegen/poet/rules/EndpointResolverUtilsSpec.java similarity index 76% rename from codegen/src/main/java/software/amazon/awssdk/codegen/poet/rules/EndpointResolverInterceptorSpec.java rename to codegen/src/main/java/software/amazon/awssdk/codegen/poet/rules/EndpointResolverUtilsSpec.java index 726d61cdb99e..485b9c553788 100644 --- a/codegen/src/main/java/software/amazon/awssdk/codegen/poet/rules/EndpointResolverInterceptorSpec.java +++ b/codegen/src/main/java/software/amazon/awssdk/codegen/poet/rules/EndpointResolverUtilsSpec.java @@ -22,13 +22,11 @@ import com.fasterxml.jackson.jr.stree.JrsString; import com.squareup.javapoet.ClassName; import com.squareup.javapoet.CodeBlock; -import com.squareup.javapoet.FieldSpec; import com.squareup.javapoet.MethodSpec; import com.squareup.javapoet.ParameterizedTypeName; import com.squareup.javapoet.TypeName; import com.squareup.javapoet.TypeSpec; import com.squareup.javapoet.TypeVariableName; -import java.time.Duration; import java.util.Iterator; import java.util.List; import java.util.Locale; @@ -36,7 +34,6 @@ import java.util.Objects; import java.util.Optional; import java.util.Set; -import java.util.concurrent.CompletionException; import javax.lang.model.element.Modifier; import software.amazon.awssdk.annotations.SdkInternalApi; import software.amazon.awssdk.awscore.AwsExecutionAttribute; @@ -61,15 +58,9 @@ import software.amazon.awssdk.codegen.poet.waiters.JmesPathAcceptorGenerator; import software.amazon.awssdk.core.SdkRequest; import software.amazon.awssdk.core.SelectedAuthScheme; -import software.amazon.awssdk.core.exception.SdkClientException; -import software.amazon.awssdk.core.interceptor.Context; import software.amazon.awssdk.core.interceptor.ExecutionAttributes; -import software.amazon.awssdk.core.interceptor.ExecutionInterceptor; -import software.amazon.awssdk.core.interceptor.SdkExecutionAttribute; import software.amazon.awssdk.core.interceptor.SdkInternalExecutionAttribute; -import software.amazon.awssdk.core.metrics.CoreMetric; import software.amazon.awssdk.endpoints.Endpoint; -import software.amazon.awssdk.http.SdkHttpRequest; import software.amazon.awssdk.http.auth.aws.scheme.AwsV4AuthScheme; import software.amazon.awssdk.http.auth.aws.scheme.AwsV4aAuthScheme; import software.amazon.awssdk.http.auth.aws.signer.AwsV4HttpSigner; @@ -77,14 +68,18 @@ import software.amazon.awssdk.http.auth.aws.signer.RegionSet; import software.amazon.awssdk.http.auth.spi.scheme.AuthSchemeOption; import software.amazon.awssdk.identity.spi.Identity; -import software.amazon.awssdk.metrics.MetricCollector; import software.amazon.awssdk.utils.AttributeMap; import software.amazon.awssdk.utils.CollectionUtils; import software.amazon.awssdk.utils.HostnameValidator; import software.amazon.awssdk.utils.StringUtils; import software.amazon.awssdk.utils.internal.CodegenNamingUtils; -public class EndpointResolverInterceptorSpec implements ClassSpec { +/** + * Generates a per-service endpoint resolver utility class (e.g. {@code DynamoDbEndpointResolverUtils}) + * containing all static helper methods for endpoint resolution: building endpoint params, host prefix + * resolution, auth scheme property copying, and business metrics. + */ +public class EndpointResolverUtilsSpec implements ClassSpec { private final IntermediateModel model; private final EndpointRulesSpecUtils endpointRulesSpecUtils; @@ -95,35 +90,28 @@ public class EndpointResolverInterceptorSpec implements ClassSpec { private final boolean multiAuthSigv4a; private final boolean legacyAuthFromEndpointRulesService; - - public EndpointResolverInterceptorSpec(IntermediateModel model) { + public EndpointResolverUtilsSpec(IntermediateModel model) { this.model = model; this.endpointRulesSpecUtils = new EndpointRulesSpecUtils(model); this.endpointParamsKnowledgeIndex = EndpointParamsKnowledgeIndex.of(model); this.poetExtension = new PoetExtension(model); this.jmesPathGenerator = new JmesPathAcceptorGenerator(poetExtension.jmesPathRuntimeClass()); - // We need to know whether the service has a dependency on the http-auth-aws module. Because we can't check that - // directly, assume that if they're using AwsV4AuthScheme or AwsV4aAuthScheme that it's available. Set> supportedAuthSchemes = ModelAuthSchemeClassesKnowledgeIndex.of(model).serviceConcreteAuthSchemeClasses(); this.dependsOnHttpAuthAws = supportedAuthSchemes.contains(AwsV4AuthScheme.class) || supportedAuthSchemes.contains(AwsV4aAuthScheme.class); - this.multiAuthSigv4a = new AuthSchemeSpecUtils(model).usesSigV4a(); this.legacyAuthFromEndpointRulesService = new AuthSchemeSpecUtils(model).generateEndpointBasedParams(); } @Override public TypeSpec poetSpec() { - FieldSpec endpointAuthSchemeStrategyFieldSpec = endpointAuthSchemeStrategyFieldSpec(); TypeSpec.Builder b = PoetUtils.createClassBuilder(className()) .addModifiers(Modifier.PUBLIC, Modifier.FINAL) .addAnnotation(SdkInternalApi.class) - .addSuperinterface(ExecutionInterceptor.class); + .addMethod(privateConstructor()); - b.addMethod(modifyRequestMethod(endpointAuthSchemeStrategyFieldSpec.name)); - b.addMethod(modifyHttpRequestMethod()); b.addMethod(ruleParams()); b.addMethod(setContextParams()); @@ -146,126 +134,17 @@ public TypeSpec poetSpec() { endpointParamsKnowledgeIndex.addAccountIdMethodsIfPresent(b); b.addMethod(setMetricValuesMethod()); + return b.build(); } @Override public ClassName className() { - return endpointRulesSpecUtils.resolverInterceptorName(); - } - - private FieldSpec endpointAuthSchemeStrategyFieldSpec() { - return FieldSpec.builder(endpointRulesSpecUtils.rulesRuntimeClassName("EndpointAuthSchemeStrategy"), - "endpointAuthSchemeStrategy", Modifier.PRIVATE, Modifier.FINAL) - .build(); - } - - private MethodSpec modifyRequestMethod(String endpointAuthSchemeStrategyFieldName) { - - MethodSpec.Builder b = MethodSpec.methodBuilder("modifyRequest") - .addModifiers(Modifier.PUBLIC) - .addAnnotation(Override.class) - .returns(SdkRequest.class) - .addParameter(Context.ModifyRequest.class, "context") - .addParameter(ExecutionAttributes.class, "executionAttributes"); - - String providerVar = "provider"; - - b.addStatement("$T result = context.request()", SdkRequest.class); - // We skip resolution if the source of the endpoint is the endpoint discovery call - b.beginControlFlow("if ($1T.endpointIsDiscovered(executionAttributes))", - endpointRulesSpecUtils.rulesRuntimeClassName("AwsEndpointProviderUtils")); - b.addStatement("return result"); - b.endControlFlow(); - - b.addStatement("$1T $2N = ($1T) executionAttributes.getAttribute($3T.ENDPOINT_PROVIDER)", - endpointRulesSpecUtils.providerInterfaceName(), providerVar, SdkInternalExecutionAttribute.class); - b.beginControlFlow("try"); - b.addStatement("long resolveEndpointStart = $T.nanoTime()", System.class); - b.addStatement("$T endpointParams = ruleParams(result, executionAttributes)", - endpointRulesSpecUtils.parametersClassName()); - b.addStatement("$T endpoint = $N.resolveEndpoint(endpointParams).join()", - Endpoint.class, providerVar); - b.addStatement("$1T resolveEndpointDuration = $1T.ofNanos($2T.nanoTime() - resolveEndpointStart)", Duration.class, - System.class); - b.addStatement("$T metricCollector = executionAttributes.getOptionalAttribute($T.API_CALL_METRIC_COLLECTOR)", - ParameterizedTypeName.get(Optional.class, MetricCollector.class), SdkExecutionAttribute.class); - b.addStatement("metricCollector.ifPresent(mc -> mc.reportMetric($T.ENDPOINT_RESOLVE_DURATION, resolveEndpointDuration))", - CoreMetric.class); - b.beginControlFlow("if (!$T.disableHostPrefixInjection(executionAttributes))", - endpointRulesSpecUtils.rulesRuntimeClassName("AwsEndpointProviderUtils")); - b.addStatement("$T hostPrefix = hostPrefix(executionAttributes.getAttribute($T.OPERATION_NAME), result)", - ParameterizedTypeName.get(Optional.class, String.class), SdkExecutionAttribute.class); - b.beginControlFlow("if (hostPrefix.isPresent())"); - b.addStatement("endpoint = $T.addHostPrefix(endpoint, hostPrefix.get())", - endpointRulesSpecUtils.rulesRuntimeClassName("AwsEndpointProviderUtils")); - b.endControlFlow(); - b.endControlFlow(); - - - // If the endpoint resolver returns auth settings, use them as signer properties. - // This effectively works to set the preSRA Signer ExecutionAttributes, so it is not conditional on useSraAuth. - b.addStatement("$T<$T> endpointAuthSchemes = endpoint.attribute($T.AUTH_SCHEMES)", - List.class, EndpointAuthScheme.class, AwsEndpointAttribute.class); - b.addStatement("$T selectedAuthScheme = executionAttributes.getAttribute($T.SELECTED_AUTH_SCHEME)", - SelectedAuthScheme.class, SdkInternalExecutionAttribute.class); - b.beginControlFlow("if (endpointAuthSchemes != null && selectedAuthScheme != null)"); - b.addStatement("selectedAuthScheme = authSchemeWithEndpointSignerProperties(endpointAuthSchemes, selectedAuthScheme)"); - if (multiAuthSigv4a || legacyAuthFromEndpointRulesService) { - b.addComment("Precedence of SigV4a RegionSet is set according to multi-auth SigV4a specifications"); - b.beginControlFlow("if(selectedAuthScheme.authSchemeOption().schemeId().equals($T.SCHEME_ID) " - + "&& selectedAuthScheme.authSchemeOption().signerProperty($T.REGION_SET) == null)", - AwsV4aAuthScheme.class, AwsV4aHttpSigner.class); - b.addStatement("$T optionBuilder = selectedAuthScheme.authSchemeOption().toBuilder()", - AuthSchemeOption.Builder.class); - b.addStatement("$T regionSet = $T.create(endpointParams.region().id())", - RegionSet.class, RegionSet.class); - b.addStatement("optionBuilder.putSignerProperty($T.REGION_SET, regionSet)", AwsV4aHttpSigner.class); - b.addStatement("selectedAuthScheme = new $T(selectedAuthScheme.identity(), selectedAuthScheme.signer(), " - + "optionBuilder.build())", SelectedAuthScheme.class); - b.endControlFlow(); - } - b.addStatement("executionAttributes.putAttribute($T.SELECTED_AUTH_SCHEME, selectedAuthScheme)", - SdkInternalExecutionAttribute.class); - b.endControlFlow(); - - b.addStatement("executionAttributes.putAttribute(SdkInternalExecutionAttribute.RESOLVED_ENDPOINT, endpoint)"); - b.addStatement("setMetricValues(endpoint, executionAttributes)"); - b.addStatement("return result"); - b.endControlFlow(); - b.beginControlFlow("catch ($T e)", CompletionException.class); - b.addStatement("$T cause = e.getCause()", Throwable.class); - b.beginControlFlow("if (cause instanceof $T)", SdkClientException.class); - b.addStatement("throw ($T) cause", SdkClientException.class); - b.endControlFlow(); - b.beginControlFlow("else"); - b.addStatement("throw $T.create($S, cause)", SdkClientException.class, "Endpoint resolution failed"); - b.endControlFlow(); - b.endControlFlow(); - return b.build(); + return endpointRulesSpecUtils.endpointResolverUtilsName(); } - private MethodSpec modifyHttpRequestMethod() { - MethodSpec.Builder b = MethodSpec.methodBuilder("modifyHttpRequest") - .addModifiers(Modifier.PUBLIC) - .addAnnotation(Override.class) - .returns(SdkHttpRequest.class) - .addParameter(Context.ModifyHttpRequest.class, "context") - .addParameter(ExecutionAttributes.class, "executionAttributes"); - - b.addStatement("$T resolvedEndpoint = executionAttributes.getAttribute($T.RESOLVED_ENDPOINT)", - Endpoint.class, SdkInternalExecutionAttribute.class); - b.beginControlFlow("if (resolvedEndpoint.headers().isEmpty())"); - b.addStatement("return context.httpRequest()"); - b.endControlFlow(); - - b.addStatement("$T httpRequestBuilder = context.httpRequest().toBuilder()", SdkHttpRequest.Builder.class); - b.addCode("resolvedEndpoint.headers().forEach((name, values) -> {"); - b.addStatement("values.forEach(v -> httpRequestBuilder.appendHeader(name, v))"); - b.addCode("});"); - b.addStatement("return httpRequestBuilder.build()"); - - return b.build(); + private static MethodSpec privateConstructor() { + return MethodSpec.constructorBuilder().addModifiers(Modifier.PRIVATE).build(); } private MethodSpec ruleParams() { @@ -278,7 +157,6 @@ private MethodSpec ruleParams() { b.addStatement("$T builder = $T.builder()", paramsBuilderClass(), endpointRulesSpecUtils.parametersClassName()); Map parameters = model.getEndpointRuleSetModel().getParameters(); - parameters.forEach((n, m) -> { if (m.getBuiltInEnum() == null) { return; @@ -307,15 +185,11 @@ private MethodSpec ruleParams() { b.addStatement("builder.$N(executionAttributes.getAttribute($T.$N))", setter, AwsExecutionAttribute.class, model.getNamingStrategy().getEnumValueName(n)); break; - // The S3 specific built-ins are set through the existing S3Configuration and set through client context params case AWS_S3_ACCELERATE: case AWS_S3_DISABLE_MULTI_REGION_ACCESS_POINTS: case AWS_S3_FORCE_PATH_STYLE: case AWS_S3_USE_ARN_REGION: case AWS_S3_CONTROL_USE_ARN_REGION: - // end of S3 specific builtins - - // V2 doesn't support this, only regional endpoints case AWS_STS_USE_GLOBAL_ENDPOINT: return; default: @@ -339,85 +213,77 @@ private MethodSpec ruleParams() { private CodeBlock endpointProviderUtilsSetter(String builtInFn, String setterName) { return CodeBlock.of("builder.$N($T.$N(executionAttributes))", setterName, - endpointRulesSpecUtils.rulesRuntimeClassName("AwsEndpointProviderUtils"), builtInFn); + endpointRulesSpecUtils.sharedAwsEndpointProviderUtilsName(), builtInFn); } private ClassName paramsBuilderClass() { return endpointRulesSpecUtils.parametersClassName().nestedClass("Builder"); } - private MethodSpec addStaticContextParamsMethod(OperationModel opModel) { - String methodName = staticContextParamsMethodName(opModel); + private MethodSpec setContextParams() { + Map operations = model.getOperations(); - MethodSpec.Builder b = MethodSpec.methodBuilder(methodName) + MethodSpec.Builder b = MethodSpec.methodBuilder("setContextParams") .addModifiers(Modifier.PRIVATE, Modifier.STATIC) - .returns(void.class) - .addParameter(paramsBuilderClass(), "params"); + .addParameter(paramsBuilderClass(), "params") + .addParameter(String.class, "operationName") + .addParameter(SdkRequest.class, "request") + .returns(void.class); - opModel.getStaticContextParams().forEach((n, m) -> { - String setterName = endpointRulesSpecUtils.paramMethodName(n); - TreeNode value = m.getValue(); - switch (value.asToken()) { - case VALUE_STRING: - b.addStatement("params.$N($S)", setterName, ((JrsString) value).getValue()); - break; - case VALUE_TRUE: - case VALUE_FALSE: - b.addStatement("params.$N($L)", setterName, ((JrsBoolean) value).booleanValue()); - break; - case START_ARRAY: - JrsArray arrayValue = (JrsArray) value; - CodeBlock arrayCode = endpointRulesSpecUtils.treeNodeToLiteral(arrayValue); - b.addStatement("params.$N($L)", setterName, arrayCode); - break; - default: - throw new RuntimeException("Don't know how to set parameter of type " + value.asToken()); - } - }); + boolean generateSwitch = operations.values().stream().anyMatch(this::hasContextParams); + if (generateSwitch) { + b.beginControlFlow("switch (operationName)"); + operations.forEach((n, m) -> { + if (!hasContextParams(m)) { + return; + } + String requestClassName = model.getNamingStrategy().getRequestClassName(m.getOperationName()); + ClassName requestClass = poetExtension.getModelClass(requestClassName); + b.addCode("case $S:", n); + b.addStatement("setContextParams(params, ($T) request)", requestClass); + b.addStatement("break"); + }); + b.addCode("default:"); + b.addStatement("break"); + b.endControlFlow(); + } return b.build(); } - private String staticContextParamsMethodName(OperationModel opModel) { - return opModel.getMethodName() + "StaticContextParams"; - } - - private boolean hasStaticContextParams(OperationModel opModel) { - Map staticContextParams = opModel.getStaticContextParams(); - return staticContextParams != null && !staticContextParams.isEmpty(); - } - - private boolean hasOperationContextParams(OperationModel opModel) { - return CollectionUtils.isNotEmpty(opModel.getOperationContextParams()); - } - - private void addStaticContextParamMethods(TypeSpec.Builder classBuilder) { - Map operations = model.getOperations(); - - operations.forEach((n, m) -> { - if (hasStaticContextParams(m)) { - classBuilder.addMethod(addStaticContextParamsMethod(m)); - } - }); - } - private void addContextParamMethods(TypeSpec.Builder classBuilder) { - Map operations = model.getOperations(); - - operations.forEach((n, m) -> { + model.getOperations().forEach((n, m) -> { if (hasContextParams(m)) { classBuilder.addMethod(setContextParamsMethod(m)); } }); } - private void addOperationContextParamMethods(TypeSpec.Builder classBuilder) { - Map operations = model.getOperations(); - operations.forEach((n, m) -> { - if (hasOperationContextParams(m)) { - classBuilder.addMethod(setOperationContextParamsMethod(m)); + private MethodSpec setContextParamsMethod(OperationModel opModel) { + String requestClassName = model.getNamingStrategy().getRequestClassName(opModel.getOperationName()); + ClassName requestClass = poetExtension.getModelClass(requestClassName); + + MethodSpec.Builder b = MethodSpec.methodBuilder("setContextParams") + .addModifiers(Modifier.PRIVATE, Modifier.STATIC) + .addParameter(paramsBuilderClass(), "params") + .addParameter(requestClass, "request") + .returns(void.class); + + opModel.getInputShape().getMembers().forEach(m -> { + ContextParam param = m.getContextParam(); + if (param == null) { + return; } + String setterName = endpointRulesSpecUtils.paramMethodName(param.getName()); + b.addStatement("params.$N(request.$N())", setterName, m.getFluentGetterMethodName()); }); + + return b.build(); + } + + private boolean hasContextParams(OperationModel opModel) { + return opModel.getInputShape().getMembers().stream() + .anyMatch(m -> m.getContextParam() != null); } private MethodSpec setStaticContextParamsMethod() { @@ -432,12 +298,10 @@ private MethodSpec setStaticContextParamsMethod() { boolean generateSwitch = operations.values().stream().anyMatch(this::hasStaticContextParams); if (generateSwitch) { b.beginControlFlow("switch (operationName)"); - operations.forEach((n, m) -> { if (!hasStaticContextParams(m)) { return; } - b.addCode("case $S:", n); b.addStatement("$N(params)", staticContextParamsMethodName(m)); b.addStatement("break"); @@ -450,38 +314,53 @@ private MethodSpec setStaticContextParamsMethod() { return b.build(); } - private MethodSpec setContextParams() { - Map operations = model.getOperations(); + private void addStaticContextParamMethods(TypeSpec.Builder classBuilder) { + model.getOperations().forEach((n, m) -> { + if (hasStaticContextParams(m)) { + classBuilder.addMethod(addStaticContextParamsMethod(m)); + } + }); + } - MethodSpec.Builder b = MethodSpec.methodBuilder("setContextParams") - .addModifiers(Modifier.PRIVATE, Modifier.STATIC) - .addParameter(paramsBuilderClass(), "params") - .addParameter(String.class, "operationName") - .addParameter(SdkRequest.class, "request") - .returns(void.class); + private MethodSpec addStaticContextParamsMethod(OperationModel opModel) { + String methodName = staticContextParamsMethodName(opModel); - boolean generateSwitch = operations.values().stream().anyMatch(this::hasContextParams); - if (generateSwitch) { - b.beginControlFlow("switch (operationName)"); + MethodSpec.Builder b = MethodSpec.methodBuilder(methodName) + .addModifiers(Modifier.PRIVATE, Modifier.STATIC) + .returns(void.class) + .addParameter(paramsBuilderClass(), "params"); - operations.forEach((n, m) -> { - if (!hasContextParams(m)) { - return; - } + opModel.getStaticContextParams().forEach((n, m) -> { + String setterName = endpointRulesSpecUtils.paramMethodName(n); + TreeNode value = m.getValue(); + switch (value.asToken()) { + case VALUE_STRING: + b.addStatement("params.$N($S)", setterName, ((JrsString) value).getValue()); + break; + case VALUE_TRUE: + case VALUE_FALSE: + b.addStatement("params.$N($L)", setterName, ((JrsBoolean) value).booleanValue()); + break; + case START_ARRAY: + JrsArray arrayValue = (JrsArray) value; + CodeBlock arrayCode = endpointRulesSpecUtils.treeNodeToLiteral(arrayValue); + b.addStatement("params.$N($L)", setterName, arrayCode); + break; + default: + throw new RuntimeException("Don't know how to set parameter of type " + value.asToken()); + } + }); - String requestClassName = model.getNamingStrategy().getRequestClassName(m.getOperationName()); - ClassName requestClass = poetExtension.getModelClass(requestClassName); + return b.build(); + } - b.addCode("case $S:", n); - b.addStatement("setContextParams(params, ($T) request)", requestClass); - b.addStatement("break"); - }); - b.addCode("default:"); - b.addStatement("break"); - b.endControlFlow(); - } + private String staticContextParamsMethodName(OperationModel opModel) { + return opModel.getMethodName() + "StaticContextParams"; + } - return b.build(); + private boolean hasStaticContextParams(OperationModel opModel) { + Map staticContextParams = opModel.getStaticContextParams(); + return staticContextParams != null && !staticContextParams.isEmpty(); } private MethodSpec setOperationContextParams() { @@ -497,15 +376,12 @@ private MethodSpec setOperationContextParams() { boolean generateSwitch = operations.values().stream().anyMatch(this::hasOperationContextParams); if (generateSwitch) { b.beginControlFlow("switch (operationName)"); - operations.forEach((n, m) -> { if (!hasOperationContextParams(m)) { return; } - String requestClassName = model.getNamingStrategy().getRequestClassName(m.getOperationName()); ClassName requestClass = poetExtension.getModelClass(requestClassName); - b.addCode("case $S:", n); b.addStatement("setOperationContextParams(params, ($T) request)", requestClass); b.addStatement("break"); @@ -518,28 +394,12 @@ private MethodSpec setOperationContextParams() { return b.build(); } - private MethodSpec setContextParamsMethod(OperationModel opModel) { - String requestClassName = model.getNamingStrategy().getRequestClassName(opModel.getOperationName()); - ClassName requestClass = poetExtension.getModelClass(requestClassName); - - MethodSpec.Builder b = MethodSpec.methodBuilder("setContextParams") - .addModifiers(Modifier.PRIVATE, Modifier.STATIC) - .addParameter(paramsBuilderClass(), "params") - .addParameter(requestClass, "request") - .returns(void.class); - - opModel.getInputShape().getMembers().forEach(m -> { - ContextParam param = m.getContextParam(); - if (param == null) { - return; + private void addOperationContextParamMethods(TypeSpec.Builder classBuilder) { + model.getOperations().forEach((n, m) -> { + if (hasOperationContextParams(m)) { + classBuilder.addMethod(setOperationContextParamsMethod(m)); } - - String setterName = endpointRulesSpecUtils.paramMethodName(param.getName()); - - b.addStatement("params.$N(request.$N())", setterName, m.getFluentGetterMethodName()); }); - - return b.build(); } private MethodSpec setOperationContextParamsMethod(OperationModel opModel) { @@ -557,7 +417,6 @@ private MethodSpec setOperationContextParamsMethod(OperationModel opModel) { opModel.getOperationContextParams().forEach((key, value) -> { if (Objects.requireNonNull(value.getPath().asToken()) == JsonToken.VALUE_STRING) { String setterName = endpointRulesSpecUtils.paramMethodName(key); - String jmesPathString = ((JrsString) value.getPath()).getValue(); CodeBlock addParam = CodeBlock.builder() .add("params.$N(", setterName) @@ -565,18 +424,20 @@ private MethodSpec setOperationContextParamsMethod(OperationModel opModel) { .add(matchToParameterType(key)) .add(")") .build(); - b.addStatement(addParam); } else { throw new RuntimeException("Invalid operation context parameter path for " + opModel.getOperationName() + ". Expected VALUE_STRING, but got " + value.getPath().asToken()); } - }); return b.build(); } + private boolean hasOperationContextParams(OperationModel opModel) { + return CollectionUtils.isNotEmpty(opModel.getOperationContextParams()); + } + private CodeBlock matchToParameterType(String paramName) { Map parameters = model.getEndpointRuleSetModel().getParameters(); Optional endpointParameter = parameters.entrySet().stream() @@ -601,11 +462,6 @@ private CodeBlock convertValueToParameterType(ParameterModel parameterModel) { } } - private boolean hasContextParams(OperationModel opModel) { - return opModel.getInputShape().getMembers().stream() - .anyMatch(m -> m.getContextParam() != null); - } - private boolean hasClientContextParams() { Map clientContextParams = model.getClientContextParams(); return clientContextParams != null && !clientContextParams.isEmpty(); @@ -627,20 +483,18 @@ private MethodSpec setClientContextParamsMethod() { params.forEach((n, m) -> { String attrName = endpointRulesSpecUtils.clientContextParamName(n); b.addStatement("$T.ofNullable(clientContextParams.get($T.$N)).ifPresent(params::$N)", Optional.class, paramsClass, - attrName, - endpointRulesSpecUtils.paramMethodName(n)); + attrName, endpointRulesSpecUtils.paramMethodName(n)); }); return b.build(); } - private MethodSpec hostPrefixMethod() { MethodSpec.Builder builder = MethodSpec.methodBuilder("hostPrefix") .returns(ParameterizedTypeName.get(Optional.class, String.class)) .addParameter(String.class, "operationName") .addParameter(SdkRequest.class, "request") - .addModifiers(Modifier.PRIVATE, Modifier.STATIC); + .addModifiers(Modifier.PUBLIC, Modifier.STATIC); boolean generateSwitch = model.getOperations().values().stream().anyMatch(opModel -> StringUtils.isNotBlank(getHostPrefix(opModel))); @@ -649,16 +503,13 @@ private MethodSpec hostPrefixMethod() { builder.addStatement("return $T.empty()", Optional.class); } else { builder.beginControlFlow("switch (operationName)"); - model.getOperations().forEach((name, opModel) -> { String hostPrefix = getHostPrefix(opModel); if (StringUtils.isBlank(hostPrefix)) { return; } - builder.beginControlFlow("case $S:", name); HostPrefixProcessor processor = new HostPrefixProcessor(hostPrefix); - if (processor.c2jNames().isEmpty()) { builder.addStatement("return $T.of($S)", Optional.class, processor.hostWithStringSpecifier()); } else { @@ -666,12 +517,8 @@ private MethodSpec hostPrefixMethod() { processor.c2jNames().forEach(c2jName -> { builder.addStatement("$1T.validateHostnameCompliant(request.getValueForField($2S, $3T.class)" + ".orElse(null), $2S, $4S)", - HostnameValidator.class, - c2jName, - String.class, - requestVar); + HostnameValidator.class, c2jName, String.class, requestVar); }); - builder.addCode("return $T.of($T.format($S, ", Optional.class, String.class, processor.hostWithStringSpecifier()); Iterator c2jNamesIter = processor.c2jNames().listIterator(); @@ -685,7 +532,6 @@ private MethodSpec hostPrefixMethod() { } builder.endControlFlow(); }); - builder.addCode("default:"); builder.addStatement("return $T.empty()", Optional.class); builder.endControlFlow(); @@ -699,7 +545,6 @@ private String getHostPrefix(OperationModel opModel) { if (endpointTrait == null) { return null; } - return endpointTrait.getHostPrefix(); } @@ -711,15 +556,13 @@ private MethodSpec authSchemeWithEndpointSignerPropertiesMethod() { MethodSpec.Builder method = MethodSpec.methodBuilder("authSchemeWithEndpointSignerProperties") - .addModifiers(Modifier.PRIVATE) + .addModifiers(Modifier.PUBLIC, Modifier.STATIC) .addTypeVariable(tExtendsIdentity) .returns(selectedAuthSchemeOfT) .addParameter(listOfEndpointAuthScheme, "endpointAuthSchemes") .addParameter(selectedAuthSchemeOfT, "selectedAuthScheme"); method.beginControlFlow("for ($T endpointAuthScheme : endpointAuthSchemes)", EndpointAuthScheme.class); - - // Don't include signer properties for auth options that don't match our selected auth scheme method.beginControlFlow("if (!endpointAuthScheme.schemeId()" + ".equals(selectedAuthScheme.authSchemeOption().schemeId()))"); method.addStatement("continue"); @@ -738,9 +581,7 @@ private MethodSpec authSchemeWithEndpointSignerPropertiesMethod() { method.addStatement("throw new $T(\"Endpoint auth scheme '\" + endpointAuthScheme.name() + \"' cannot be mapped to the " + "SDK auth scheme. Was it declared in the service's model?\")", IllegalArgumentException.class); - method.endControlFlow(); - method.addStatement("return selectedAuthScheme"); return method.build(); @@ -750,34 +591,27 @@ private static CodeBlock copyV4EndpointSignerPropertiesToAuth() { CodeBlock.Builder code = CodeBlock.builder(); code.beginControlFlow("if (endpointAuthScheme instanceof $T)", SigV4AuthScheme.class); code.addStatement("$1T v4AuthScheme = ($1T) endpointAuthScheme", SigV4AuthScheme.class); - code.beginControlFlow("if (v4AuthScheme.isDisableDoubleEncodingSet())"); code.addStatement("option.putSignerProperty($T.DOUBLE_URL_ENCODE, !v4AuthScheme.disableDoubleEncoding())", AwsV4HttpSigner.class); code.endControlFlow(); - code.beginControlFlow("if (v4AuthScheme.signingRegion() != null)"); - code.addStatement("option.putSignerProperty($T.REGION_NAME, v4AuthScheme.signingRegion())", - AwsV4HttpSigner.class); + code.addStatement("option.putSignerProperty($T.REGION_NAME, v4AuthScheme.signingRegion())", AwsV4HttpSigner.class); code.endControlFlow(); - code.beginControlFlow("if (v4AuthScheme.signingName() != null)"); code.addStatement("option.putSignerProperty($T.SERVICE_SIGNING_NAME, v4AuthScheme.signingName())", AwsV4HttpSigner.class); code.endControlFlow(); - code.addStatement("return new $T<>(selectedAuthScheme.identity(), selectedAuthScheme.signer(), option.build())", SelectedAuthScheme.class); code.endControlFlow(); return code.build(); } - private CodeBlock copyV4aEndpointSignerPropertiesToAuth() { + private CodeBlock copyV4aEndpointSignerPropertiesToAuth() { CodeBlock.Builder code = CodeBlock.builder(); - code.beginControlFlow("if (endpointAuthScheme instanceof $T)", SigV4aAuthScheme.class); code.addStatement("$1T v4aAuthScheme = ($1T) endpointAuthScheme", SigV4aAuthScheme.class); - code.beginControlFlow("if (v4aAuthScheme.isDisableDoubleEncodingSet())"); code.addStatement("option.putSignerProperty($T.DOUBLE_URL_ENCODE, !v4aAuthScheme.disableDoubleEncoding())", AwsV4aHttpSigner.class); @@ -793,12 +627,10 @@ private CodeBlock copyV4aEndpointSignerPropertiesToAuth() { code.addStatement("$1T regionSet = $1T.create(v4aAuthScheme.signingRegionSet())", RegionSet.class); code.addStatement("option.putSignerProperty($T.REGION_SET, regionSet)", AwsV4aHttpSigner.class); code.endControlFlow(); - code.beginControlFlow("if (v4aAuthScheme.signingName() != null)"); code.addStatement("option.putSignerProperty($T.SERVICE_SIGNING_NAME, v4aAuthScheme.signingName())", AwsV4aHttpSigner.class); code.endControlFlow(); - code.addStatement("return new $T<>(selectedAuthScheme.identity(), selectedAuthScheme.signer(), option.build())", SelectedAuthScheme.class); code.endControlFlow(); @@ -812,22 +644,18 @@ private CodeBlock copyS3ExpressEndpointSignerPropertiesToAuth() { "S3ExpressEndpointAuthScheme"); code.beginControlFlow("if (endpointAuthScheme instanceof $T)", s3ExpressEndpointAuthSchemeClassName); code.addStatement("$1T s3ExpressAuthScheme = ($1T) endpointAuthScheme", s3ExpressEndpointAuthSchemeClassName); - code.beginControlFlow("if (s3ExpressAuthScheme.isDisableDoubleEncodingSet())"); code.addStatement("option.putSignerProperty($T.DOUBLE_URL_ENCODE, !s3ExpressAuthScheme.disableDoubleEncoding())", AwsV4HttpSigner.class); code.endControlFlow(); - code.beginControlFlow("if (s3ExpressAuthScheme.signingRegion() != null)"); code.addStatement("option.putSignerProperty($T.REGION_NAME, s3ExpressAuthScheme.signingRegion())", AwsV4HttpSigner.class); code.endControlFlow(); - code.beginControlFlow("if (s3ExpressAuthScheme.signingName() != null)"); code.addStatement("option.putSignerProperty($T.SERVICE_SIGNING_NAME, s3ExpressAuthScheme.signingName())", AwsV4HttpSigner.class); code.endControlFlow(); - code.addStatement("return new $T<>(selectedAuthScheme.identity(), selectedAuthScheme.signer(), option.build())", SelectedAuthScheme.class); code.endControlFlow(); @@ -836,7 +664,7 @@ private CodeBlock copyS3ExpressEndpointSignerPropertiesToAuth() { private MethodSpec setMetricValuesMethod() { MethodSpec.Builder b = MethodSpec.methodBuilder("setMetricValues") - .addModifiers(Modifier.PRIVATE) + .addModifiers(Modifier.PUBLIC, Modifier.STATIC) .addParameter(Endpoint.class, "endpoint") .addParameter(ExecutionAttributes.class, "executionAttributes") .returns(void.class); @@ -851,7 +679,7 @@ private MethodSpec setMetricValuesMethod() { b.addStatement("$T.addS3ExpressBusinessMetricIfApplicable(executionAttributes)", ClassName.get("software.amazon.awssdk.services.s3.internal.s3express", "S3ExpressUtils")); } - + return b.build(); } } diff --git a/codegen/src/main/java/software/amazon/awssdk/codegen/poet/rules/EndpointRulesSpecUtils.java b/codegen/src/main/java/software/amazon/awssdk/codegen/poet/rules/EndpointRulesSpecUtils.java index dfcf68b056fd..25f483fa4829 100644 --- a/codegen/src/main/java/software/amazon/awssdk/codegen/poet/rules/EndpointRulesSpecUtils.java +++ b/codegen/src/main/java/software/amazon/awssdk/codegen/poet/rules/EndpointRulesSpecUtils.java @@ -93,6 +93,16 @@ public ClassName resolverInterceptorName() { md.getServiceName() + "ResolveEndpointInterceptor"); } + public ClassName endpointResolverUtilsName() { + Metadata md = intermediateModel.getMetadata(); + return ClassName.get(md.getFullInternalEndpointRulesPackageName(), + md.getServiceName() + "EndpointResolverUtils"); + } + + public ClassName sharedAwsEndpointProviderUtilsName() { + return ClassName.get("software.amazon.awssdk.awscore.internal.endpoints", "AwsEndpointProviderUtils"); + } + public ClassName requestModifierInterceptorName() { Metadata md = intermediateModel.getMetadata(); return ClassName.get(md.getFullInternalEndpointRulesPackageName(), diff --git a/codegen/src/main/java/software/amazon/awssdk/codegen/poet/rules/RequestEndpointInterceptorSpec.java b/codegen/src/main/java/software/amazon/awssdk/codegen/poet/rules/RequestEndpointInterceptorSpec.java deleted file mode 100644 index 63b5133f180e..000000000000 --- a/codegen/src/main/java/software/amazon/awssdk/codegen/poet/rules/RequestEndpointInterceptorSpec.java +++ /dev/null @@ -1,84 +0,0 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). - * You may not use this file except in compliance with the License. - * A copy of the License is located at - * - * http://aws.amazon.com/apache2.0 - * - * or in the "license" file accompanying this file. This file is distributed - * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either - * express or implied. See the License for the specific language governing - * permissions and limitations under the License. - */ - -package software.amazon.awssdk.codegen.poet.rules; - -import com.squareup.javapoet.ClassName; -import com.squareup.javapoet.MethodSpec; -import com.squareup.javapoet.TypeSpec; -import javax.lang.model.element.Modifier; -import software.amazon.awssdk.annotations.SdkInternalApi; -import software.amazon.awssdk.codegen.model.intermediate.IntermediateModel; -import software.amazon.awssdk.codegen.poet.ClassSpec; -import software.amazon.awssdk.codegen.poet.PoetUtils; -import software.amazon.awssdk.core.interceptor.Context; -import software.amazon.awssdk.core.interceptor.ExecutionAttributes; -import software.amazon.awssdk.core.interceptor.ExecutionInterceptor; -import software.amazon.awssdk.core.interceptor.SdkInternalExecutionAttribute; -import software.amazon.awssdk.endpoints.Endpoint; -import software.amazon.awssdk.http.SdkHttpRequest; - -public class RequestEndpointInterceptorSpec implements ClassSpec { - private final EndpointRulesSpecUtils endpointRulesSpecUtils; - - public RequestEndpointInterceptorSpec(IntermediateModel model) { - this.endpointRulesSpecUtils = new EndpointRulesSpecUtils(model); - } - - @Override - public TypeSpec poetSpec() { - TypeSpec.Builder b = PoetUtils.createClassBuilder(className()) - .addModifiers(Modifier.PUBLIC, Modifier.FINAL) - .addAnnotation(SdkInternalApi.class) - .addSuperinterface(ExecutionInterceptor.class); - - b.addMethod(modifyHttpRequestMethod()); - - return b.build(); - } - - @Override - public ClassName className() { - return endpointRulesSpecUtils.requestModifierInterceptorName(); - } - - private MethodSpec modifyHttpRequestMethod() { - - MethodSpec.Builder b = MethodSpec.methodBuilder("modifyHttpRequest") - .addModifiers(Modifier.PUBLIC) - .addAnnotation(Override.class) - .returns(SdkHttpRequest.class) - .addParameter(Context.ModifyHttpRequest.class, "context") - .addParameter(ExecutionAttributes.class, "executionAttributes"); - - - // We skip setting the endpoint here if the source of the endpoint is the endpoint discovery call - b.beginControlFlow("if ($1T.endpointIsDiscovered(executionAttributes))", - endpointRulesSpecUtils.rulesRuntimeClassName("AwsEndpointProviderUtils")) - .addStatement("return context.httpRequest()") - .endControlFlow().build(); - - b.addStatement("$T endpoint = executionAttributes.getAttribute($T.RESOLVED_ENDPOINT)", - Endpoint.class, - SdkInternalExecutionAttribute.class); - b.addStatement("return $T.setUri(context.httpRequest()," - + "executionAttributes.getAttribute($T.CLIENT_ENDPOINT_PROVIDER).clientEndpoint()," - + "endpoint.url())", - endpointRulesSpecUtils.rulesRuntimeClassName("AwsEndpointProviderUtils"), - SdkInternalExecutionAttribute.class); - return b.build(); - } - -} diff --git a/codegen/src/test/java/software/amazon/awssdk/codegen/poet/auth/scheme/AuthSchemeSpecTest.java b/codegen/src/test/java/software/amazon/awssdk/codegen/poet/auth/scheme/AuthSchemeSpecTest.java index 6b095f01ddc6..067d5d1673ed 100644 --- a/codegen/src/test/java/software/amazon/awssdk/codegen/poet/auth/scheme/AuthSchemeSpecTest.java +++ b/codegen/src/test/java/software/amazon/awssdk/codegen/poet/auth/scheme/AuthSchemeSpecTest.java @@ -187,28 +187,6 @@ static List parameters() { .caseName("auth-with-legacy-trait") .outputFileSuffix("default-provider") .build(), - // Interceptors - // - Normal case - TestCase.builder() - .modelProvider(ClientTestModels::queryServiceModels) - .classSpecProvider(AuthSchemeInterceptorSpec::new) - .caseName("query") - .outputFileSuffix("interceptor") - .build(), - // - Endpoints based params with allow list - TestCase.builder() - .modelProvider(ClientTestModels::queryServiceModelsEndpointAuthParamsWithAllowList) - .classSpecProvider(AuthSchemeInterceptorSpec::new) - .caseName("query-endpoint-auth-params-with-allowlist") - .outputFileSuffix("interceptor") - .build(), - // - Endpoints based params without allow list - TestCase.builder() - .modelProvider(ClientTestModels::queryServiceModelsEndpointAuthParamsWithoutAllowList) - .classSpecProvider(AuthSchemeInterceptorSpec::new) - .caseName("query-endpoint-auth-params-without-allowlist") - .outputFileSuffix("interceptor") - .build(), // Service with auth trait with Sigv4a TestCase.builder() .modelProvider(ClientTestModels::opsWithSigv4a) @@ -228,19 +206,6 @@ static List parameters() { .caseName("ops-auth-sigv4a-value") .outputFileSuffix("default-params") .build(), - TestCase.builder() - .modelProvider(ClientTestModels::opsWithSigv4a) - .classSpecProvider(AuthSchemeInterceptorSpec::new) - .caseName("ops-auth-sigv4a-value") - .outputFileSuffix("interceptor") - .build(), - // service with environment bearer token enabled - TestCase.builder() - .modelProvider(ClientTestModels::envBearerTokenServiceModels) - .classSpecProvider(AuthSchemeInterceptorSpec::new) - .caseName("env-bearer-token") - .outputFileSuffix("interceptor") - .build(), // Rest Json service with checksum TestCase.builder() .modelProvider(ClientTestModels::restJsonServiceModels) diff --git a/codegen/src/test/java/software/amazon/awssdk/codegen/poet/rules/EndpointResolverInterceptorSpecTest.java b/codegen/src/test/java/software/amazon/awssdk/codegen/poet/rules/EndpointResolverInterceptorSpecTest.java deleted file mode 100644 index 89f87ff41952..000000000000 --- a/codegen/src/test/java/software/amazon/awssdk/codegen/poet/rules/EndpointResolverInterceptorSpecTest.java +++ /dev/null @@ -1,55 +0,0 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). - * You may not use this file except in compliance with the License. - * A copy of the License is located at - * - * http://aws.amazon.com/apache2.0 - * - * or in the "license" file accompanying this file. This file is distributed - * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either - * express or implied. See the License for the specific language governing - * permissions and limitations under the License. - */ - -package software.amazon.awssdk.codegen.poet.rules; - -import static org.hamcrest.MatcherAssert.assertThat; -import static software.amazon.awssdk.codegen.poet.PoetMatchers.generatesTo; - -import org.junit.jupiter.api.Test; -import software.amazon.awssdk.codegen.model.intermediate.IntermediateModel; -import software.amazon.awssdk.codegen.poet.ClassSpec; -import software.amazon.awssdk.codegen.poet.ClientTestModels; - -public class EndpointResolverInterceptorSpecTest { - @Test - public void endpointResolverInterceptorClass() { - ClassSpec endpointProviderInterceptor = new EndpointResolverInterceptorSpec(getModel()); - assertThat(endpointProviderInterceptor, generatesTo("endpoint-resolve-interceptor.java")); - } - - private static IntermediateModel getModel() { - IntermediateModel model = ClientTestModels.queryServiceModels(); - return model; - } - - @Test - void endpointResolverInterceptorClassWithSigv4aMultiAuth() { - ClassSpec endpointProviderInterceptor = new EndpointResolverInterceptorSpec(ClientTestModels.opsWithSigv4a()); - assertThat(endpointProviderInterceptor, generatesTo("endpoint-resolve-interceptor-with-multiauthsigv4a.java")); - } - - @Test - void endpointResolverInterceptorClassWithEndpointBasedAuth() { - ClassSpec endpointProviderInterceptor = new EndpointResolverInterceptorSpec(ClientTestModels.queryServiceModelsEndpointAuthParamsWithoutAllowList()); - assertThat(endpointProviderInterceptor, generatesTo("endpoint-resolve-interceptor-with-endpointsbasedauth.java")); - } - - @Test - public void endpointProviderTestClassWithStringArray() { - ClassSpec endpointProviderInterceptor = new EndpointResolverInterceptorSpec(ClientTestModels.stringArrayServiceModels()); - assertThat(endpointProviderInterceptor, generatesTo("endpoint-resolve-interceptor-with-stringarray.java")); - } -} diff --git a/codegen/src/test/java/software/amazon/awssdk/codegen/poet/rules/RequestEndpointProviderInterceptorSpecTest.java b/codegen/src/test/java/software/amazon/awssdk/codegen/poet/rules/RequestEndpointProviderInterceptorSpecTest.java deleted file mode 100644 index f4f2434340c9..000000000000 --- a/codegen/src/test/java/software/amazon/awssdk/codegen/poet/rules/RequestEndpointProviderInterceptorSpecTest.java +++ /dev/null @@ -1,31 +0,0 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). - * You may not use this file except in compliance with the License. - * A copy of the License is located at - * - * http://aws.amazon.com/apache2.0 - * - * or in the "license" file accompanying this file. This file is distributed - * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either - * express or implied. See the License for the specific language governing - * permissions and limitations under the License. - */ - -package software.amazon.awssdk.codegen.poet.rules; - -import static org.hamcrest.MatcherAssert.assertThat; -import static software.amazon.awssdk.codegen.poet.PoetMatchers.generatesTo; - -import org.junit.jupiter.api.Test; -import software.amazon.awssdk.codegen.poet.ClassSpec; -import software.amazon.awssdk.codegen.poet.ClientTestModels; - -public class RequestEndpointProviderInterceptorSpecTest { - @Test - public void requestSetEndpointInterceptorClass() { - ClassSpec endpointProviderInterceptor = new RequestEndpointInterceptorSpec(ClientTestModels.queryServiceModels()); - assertThat(endpointProviderInterceptor, generatesTo("request-set-endpoint-interceptor.java")); - } -} diff --git a/core/aws-core/src/main/java/software/amazon/awssdk/awscore/internal/AwsExecutionContextBuilder.java b/core/aws-core/src/main/java/software/amazon/awssdk/awscore/internal/AwsExecutionContextBuilder.java index 32273f16019c..b3ce212efbee 100644 --- a/core/aws-core/src/main/java/software/amazon/awssdk/awscore/internal/AwsExecutionContextBuilder.java +++ b/core/aws-core/src/main/java/software/amazon/awssdk/awscore/internal/AwsExecutionContextBuilder.java @@ -29,10 +29,10 @@ import software.amazon.awssdk.annotations.SdkInternalApi; import software.amazon.awssdk.auth.signer.AwsSignerExecutionAttribute; import software.amazon.awssdk.awscore.AwsExecutionAttribute; -import software.amazon.awssdk.awscore.AwsRequestOverrideConfiguration; import software.amazon.awssdk.awscore.client.config.AwsClientOption; import software.amazon.awssdk.awscore.internal.authcontext.AuthorizationStrategy; import software.amazon.awssdk.awscore.internal.authcontext.AuthorizationStrategyFactory; +import software.amazon.awssdk.awscore.internal.identity.AwsIdentityProviderUpdater; import software.amazon.awssdk.awscore.util.SignerOverrideUtils; import software.amazon.awssdk.core.HttpChecksumConstant; import software.amazon.awssdk.core.RequestOverrideConfiguration; @@ -144,6 +144,20 @@ private AwsExecutionContextBuilder() { // Auth Scheme resolution related attributes putAuthSchemeResolutionAttributes(executionAttributes, clientConfig, originalRequest); + if (executionParams.authSchemeOptionsResolver() != null) { + executionAttributes.putAttribute(SdkInternalExecutionAttribute.AUTH_SCHEME_OPTIONS_RESOLVER, + executionParams.authSchemeOptionsResolver()); + } + + if (executionParams.endpointResolver() != null) { + executionAttributes.putAttribute(SdkInternalExecutionAttribute.ENDPOINT_RESOLVER, + executionParams.endpointResolver()); + } + + // Set the identity provider updater for the pipeline stage to use + executionAttributes.putAttribute(SdkInternalExecutionAttribute.IDENTITY_PROVIDER_UPDATER, + AwsIdentityProviderUpdater.create()); + ExecutionInterceptorChain executionInterceptorChain = new ExecutionInterceptorChain(clientConfig.option(SdkClientOption.EXECUTION_INTERCEPTORS)); @@ -273,7 +287,7 @@ private static void putAuthSchemeResolutionAttributes(ExecutionAttributes execut // request preferred over client. Map> authSchemes = clientConfig.option(SdkClientOption.AUTH_SCHEMES); - IdentityProviders identityProviders = resolveIdentityProviders(originalRequest, clientConfig); + IdentityProviders identityProviders = clientConfig.option(SdkClientOption.IDENTITY_PROVIDERS); executionAttributes .putAttribute(SdkInternalExecutionAttribute.AUTH_SCHEME_RESOLVER, authSchemeProvider) @@ -281,31 +295,6 @@ private static void putAuthSchemeResolutionAttributes(ExecutionAttributes execut .putAttribute(SdkInternalExecutionAttribute.IDENTITY_PROVIDERS, identityProviders); } - private static IdentityProviders resolveIdentityProviders(SdkRequest originalRequest, - SdkClientConfiguration clientConfig) { - IdentityProviders identityProviders = - clientConfig.option(SdkClientOption.IDENTITY_PROVIDERS); - - // identityProviders can be null, for new core with old client. In this case, even if AwsRequestOverrideConfiguration - // has credentialsIdentityProvider set (because it is in new core), it is ok to not setup IDENTITY_PROVIDERS, as old - // client won't have AUTH_SCHEME_PROVIDER/AUTH_SCHEMES set either, which are also needed for SRA logic. - if (identityProviders == null) { - return null; - } - - return originalRequest - .overrideConfiguration() - .filter(c -> c instanceof AwsRequestOverrideConfiguration) - .map(c -> (AwsRequestOverrideConfiguration) c) - .map(c -> { - return identityProviders.copy(b -> { - c.credentialsIdentityProvider().ifPresent(b::putIdentityProvider); - c.tokenIdentityProvider().ifPresent(b::putIdentityProvider); - }); - }) - .orElse(identityProviders); - } - /** * Finalize {@link SdkRequest} by running beforeExecution and modifyRequest interceptors. * @@ -355,7 +344,7 @@ private static EndpointProvider resolveEndpointProvider(SdkRequest request, } private static BusinessMetricCollection - resolveUserAgentBusinessMetrics(SdkClientConfiguration clientConfig, + resolveUserAgentBusinessMetrics(SdkClientConfiguration clientConfig, ClientExecutionParams executionParams) { BusinessMetricCollection businessMetrics = new BusinessMetricCollection(); Optional retryModeMetric = resolveRetryMode(clientConfig.option(RETRY_POLICY), @@ -373,4 +362,5 @@ private static boolean isRpcV2CborProtocol(SdkProtocolMetadata protocolMetadata) return protocolMetadata != null && SMITHY_RPC_V2_CBOR.toString().equals(protocolMetadata.serviceProtocol()); } + } diff --git a/codegen/src/main/resources/software/amazon/awssdk/codegen/rules/AwsEndpointProviderUtils.java.resource b/core/aws-core/src/main/java/software/amazon/awssdk/awscore/internal/endpoints/AwsEndpointProviderUtils.java similarity index 54% rename from codegen/src/main/resources/software/amazon/awssdk/codegen/rules/AwsEndpointProviderUtils.java.resource rename to core/aws-core/src/main/java/software/amazon/awssdk/awscore/internal/endpoints/AwsEndpointProviderUtils.java index 1d0fedb9f7ff..1cd27489455f 100644 --- a/codegen/src/main/resources/software/amazon/awssdk/codegen/rules/AwsEndpointProviderUtils.java.resource +++ b/core/aws-core/src/main/java/software/amazon/awssdk/awscore/internal/endpoints/AwsEndpointProviderUtils.java @@ -1,27 +1,41 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.awscore.internal.endpoints; + import static software.amazon.awssdk.utils.FunctionalUtils.invokeSafely; import java.net.URI; -import java.util.List; -import java.util.Map; import software.amazon.awssdk.annotations.SdkInternalApi; import software.amazon.awssdk.awscore.AwsExecutionAttribute; -import software.amazon.awssdk.awscore.endpoints.AwsEndpointAttribute; -import software.amazon.awssdk.core.exception.SdkClientException; import software.amazon.awssdk.core.interceptor.ExecutionAttributes; -import software.amazon.awssdk.core.interceptor.SdkExecutionAttribute; import software.amazon.awssdk.core.interceptor.SdkInternalExecutionAttribute; import software.amazon.awssdk.core.useragent.BusinessMetricFeatureId; import software.amazon.awssdk.endpoints.Endpoint; import software.amazon.awssdk.http.SdkHttpRequest; import software.amazon.awssdk.regions.Region; import software.amazon.awssdk.utils.HostnameValidator; -import software.amazon.awssdk.utils.Logger; import software.amazon.awssdk.utils.StringUtils; import software.amazon.awssdk.utils.http.SdkHttpUtils; +/** + * Shared utility methods for endpoint resolution across all AWS services. + * Previously this was generated per-service as an identical copy; now de-duplicated here. + */ @SdkInternalApi public final class AwsEndpointProviderUtils { - private static final Logger LOG = Logger.loggerFor(AwsEndpointProviderUtils.class); private AwsEndpointProviderUtils() { } @@ -38,10 +52,6 @@ public static Boolean fipsEnabledBuiltIn(ExecutionAttributes executionAttributes return executionAttributes.getAttribute(AwsExecutionAttribute.FIPS_ENDPOINT_ENABLED); } - /** - * Returns the endpoint set on the client. Note that this strips off the query part of the URI because the endpoint - * rules library, e.g. {@code ParseURL} will return an exception if the URI it parses has query parameters. - */ public static String endpointBuiltIn(ExecutionAttributes executionAttributes) { if (endpointIsOverridden(executionAttributes)) { executionAttributes.getOptionalAttribute(SdkInternalExecutionAttribute.BUSINESS_METRICS).ifPresent( @@ -56,80 +66,36 @@ public static String endpointBuiltIn(ExecutionAttributes executionAttributes) { return null; } - /** - * Read {@link SdkExecutionAttribute#CLIENT_ENDPOINT_PROVIDER}'s isEndpointOverridden attribute. - */ public static boolean endpointIsOverridden(ExecutionAttributes attrs) { return attrs.getAttribute(SdkInternalExecutionAttribute.CLIENT_ENDPOINT_PROVIDER).isEndpointOverridden(); } - /** - * True if the the {@link SdkInternalExecutionAttribute#IS_DISCOVERED_ENDPOINT} attribute is present and its value - * is {@code true}, {@code false} otherwise. - */ public static boolean endpointIsDiscovered(ExecutionAttributes attrs) { return attrs.getOptionalAttribute(SdkInternalExecutionAttribute.IS_DISCOVERED_ENDPOINT).orElse(false); } - /** - * True if the the {@link SdkInternalExecutionAttribute#DISABLE_HOST_PREFIX_INJECTION} attribute is present and its - * value is {@code true}, {@code false} otherwise. - */ public static boolean disableHostPrefixInjection(ExecutionAttributes attrs) { return attrs.getOptionalAttribute(SdkInternalExecutionAttribute.DISABLE_HOST_PREFIX_INJECTION).orElse(false); } - /** - * Apply the given endpoint prefix to the endpoint. - */ public static Endpoint addHostPrefix(Endpoint endpoint, String prefix) { if (StringUtils.isBlank(prefix)) { return endpoint; } - validatePrefixIsHostNameCompliant(prefix); - URI originalUrl = endpoint.url(); String newHost = prefix + endpoint.url().getHost(); URI newUrl = invokeSafely(() -> new URI(originalUrl.getScheme(), null, newHost, originalUrl.getPort(), originalUrl.getPath(), originalUrl.getQuery(), originalUrl.getFragment())); - return endpoint.toBuilder().url(newUrl).build(); } - /** - * This sets the request URI to the resolved URI returned by the endpoint provider. There are some things to be - * careful about to make this work properly: - *

- * If the client endpoint is an endpoint override, it may contain a path. In addition, the request marshaller itself - * may add components to the path if it's modeled for the operation. Unfortunately, - * {@link SdkHttpRequest#encodedPath()} returns the combined path from both the endpoint and the request. There is - * no way to know, just from the HTTP request object, where the override path ends (if it's even there) and where - * the request path starts. Additionally, the rule itself may also append other parts to the endpoint override path. - *

- * To solve this issue, we pass in the endpoint set on the path, which allows us to the strip the path from the - * endpoint override from the request path, and then correctly combine the paths. - *

- * For example, let's suppose the endpoint override on the client is {@code https://example.com/a}. Then we call an - * operation {@code Foo()}, that marshalls {@code /c} to the path. The resulting request path is {@code /a/c}. - * However, we also pass the endpoint to provider as a parameter, and the resolver returns - * {@code https://example.com/a/b}. This method takes care of combining the paths correctly so that the resulting - * path is {@code https://example.com/a/b/c}. - */ public static SdkHttpRequest setUri(SdkHttpRequest request, URI clientEndpoint, URI resolvedUri) { - // [client endpoint path] String clientEndpointPath = clientEndpoint.getRawPath(); - - // [client endpoint path]/[request path] String requestPath = request.encodedPath(); - - // [client endpoint path]/[additional path added by resolver] String resolvedUriPath = resolvedUri.getRawPath(); String finalPath = requestPath; - - // If there is an additional path added by resolver, i.e., [additional path added by resolver] not null, - // we need to combine the path if (!resolvedUriPath.equals(clientEndpointPath)) { finalPath = combinePath(clientEndpointPath, requestPath, resolvedUriPath); } @@ -138,27 +104,15 @@ public static SdkHttpRequest setUri(SdkHttpRequest request, URI clientEndpoint, .encodedPath(finalPath).build(); } - /** - * Our goal is to construct [client endpoint path]/[additional path added by resolver]/[request path], so we just - * need to strip the client endpoint path from the marshalled request path to isolate just the part added by the - * marshaller. Trailing slash is removed from client endpoint path before stripping because it could cause the - * leading slash to be removed from the request path: e.g., StringUtils.replaceOnce("/", "//test", "") generates - * "/test" and the expected result is "//test" - */ private static String combinePath(String clientEndpointPath, String requestPath, String resolvedUriPath) { String requestPathWithClientPathRemoved = StringUtils.replaceOnce(requestPath, clientEndpointPath, ""); - String finalPath = SdkHttpUtils.appendUri(resolvedUriPath, requestPathWithClientPathRemoved); - return finalPath; + return SdkHttpUtils.appendUri(resolvedUriPath, requestPathWithClientPathRemoved); } private static void validatePrefixIsHostNameCompliant(String prefix) { - String[] components = splitHostLabelOnDots(prefix); + String[] components = prefix.split("\\."); for (String component : components) { HostnameValidator.validateHostnameCompliant(component, component, "request"); } } - - private static String[] splitHostLabelOnDots(String label) { - return label.split("\\."); - } } diff --git a/core/aws-core/src/main/java/software/amazon/awssdk/awscore/internal/identity/AwsIdentityProviderUpdater.java b/core/aws-core/src/main/java/software/amazon/awssdk/awscore/internal/identity/AwsIdentityProviderUpdater.java new file mode 100644 index 000000000000..f83b8de84c85 --- /dev/null +++ b/core/aws-core/src/main/java/software/amazon/awssdk/awscore/internal/identity/AwsIdentityProviderUpdater.java @@ -0,0 +1,54 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.awscore.internal.identity; + +import software.amazon.awssdk.annotations.SdkInternalApi; +import software.amazon.awssdk.awscore.AwsRequestOverrideConfiguration; +import software.amazon.awssdk.core.SdkRequest; +import software.amazon.awssdk.core.spi.identity.IdentityProviderUpdater; +import software.amazon.awssdk.identity.spi.IdentityProviders; + +/** + * AWS implementation of {@link IdentityProviderUpdater} that reads credential overrides + * from {@link AwsRequestOverrideConfiguration}. + */ +@SdkInternalApi +public final class AwsIdentityProviderUpdater implements IdentityProviderUpdater { + + private static final AwsIdentityProviderUpdater INSTANCE = new AwsIdentityProviderUpdater(); + + private AwsIdentityProviderUpdater() { + } + + public static AwsIdentityProviderUpdater create() { + return INSTANCE; + } + + @Override + public IdentityProviders update(SdkRequest request, IdentityProviders base) { + if (base == null) { + return null; + } + return request.overrideConfiguration() + .filter(c -> c instanceof AwsRequestOverrideConfiguration) + .map(c -> (AwsRequestOverrideConfiguration) c) + .map(c -> base.copy(b -> { + c.credentialsIdentityProvider().ifPresent(b::putIdentityProvider); + c.tokenIdentityProvider().ifPresent(b::putIdentityProvider); + })) + .orElse(base); + } +} diff --git a/core/aws-core/src/test/java/software/amazon/awssdk/awscore/internal/AwsExecutionContextBuilderTest.java b/core/aws-core/src/test/java/software/amazon/awssdk/awscore/internal/AwsExecutionContextBuilderTest.java index e6ab211de5da..cf2cb59e3d23 100644 --- a/core/aws-core/src/test/java/software/amazon/awssdk/awscore/internal/AwsExecutionContextBuilderTest.java +++ b/core/aws-core/src/test/java/software/amazon/awssdk/awscore/internal/AwsExecutionContextBuilderTest.java @@ -38,7 +38,6 @@ import org.mockito.junit.MockitoJUnitRunner; import software.amazon.awssdk.auth.credentials.AwsBasicCredentials; import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider; -import software.amazon.awssdk.auth.token.credentials.StaticTokenProvider; import software.amazon.awssdk.awscore.AwsRequest; import software.amazon.awssdk.awscore.AwsRequestOverrideConfiguration; import software.amazon.awssdk.awscore.client.config.AwsClientOption; @@ -74,7 +73,6 @@ import software.amazon.awssdk.identity.spi.AwsCredentialsIdentity; import software.amazon.awssdk.identity.spi.IdentityProvider; import software.amazon.awssdk.identity.spi.IdentityProviders; -import software.amazon.awssdk.identity.spi.TokenIdentity; import software.amazon.awssdk.profiles.ProfileFile; @RunWith(MockitoJUnitRunner.class) @@ -399,49 +397,6 @@ public void invokeInterceptorsAndCreateExecutionContext_withoutIdentityProviders assertThat(executionAttributes.getAttribute(SdkInternalExecutionAttribute.IDENTITY_PROVIDERS)).isNull(); } - @Test - public void invokeInterceptorsAndCreateExecutionContext_requestOverrideForIdentityProvider_updatesIdentityProviders() { - IdentityProvider clientCredentialsProvider = - StaticCredentialsProvider.create(AwsBasicCredentials.create("foo", "bar")); - IdentityProvider clientTokenProvider = StaticTokenProvider.create(() -> "client-token"); - IdentityProviders identityProviders = - IdentityProviders.builder() - .putIdentityProvider(clientCredentialsProvider) - .putIdentityProvider(clientTokenProvider) - .build(); - SdkClientConfiguration clientConfig = testClientConfiguration() - .option(SdkClientOption.IDENTITY_PROVIDERS, identityProviders) - .build(); - - IdentityProvider requestCredentialsProvider = - StaticCredentialsProvider.create(AwsBasicCredentials.create("akid", "skid")); - IdentityProvider requestTokenProvider = StaticTokenProvider.create(() -> "request-token"); - Optional overrideConfiguration = - Optional.of(AwsRequestOverrideConfiguration.builder() - .credentialsProvider(requestCredentialsProvider) - .tokenIdentityProvider(requestTokenProvider) - .build()); - when(sdkRequest.overrideConfiguration()).thenReturn(overrideConfiguration); - - ClientExecutionParams executionParams = clientExecutionParams(); - - ExecutionContext executionContext = - AwsExecutionContextBuilder.invokeInterceptorsAndCreateExecutionContext(executionParams, clientConfig); - - IdentityProviders actualIdentityProviders = - executionContext.executionAttributes().getAttribute(SdkInternalExecutionAttribute.IDENTITY_PROVIDERS); - - IdentityProvider actualIdentityProvider = - actualIdentityProviders.identityProvider(AwsCredentialsIdentity.class); - - assertThat(actualIdentityProvider).isSameAs(requestCredentialsProvider); - - IdentityProvider actualTokenProvider = - actualIdentityProviders.identityProvider(TokenIdentity.class); - - assertThat(actualTokenProvider).isSameAs(requestTokenProvider); - } - @Test public void invokeInterceptorsAndCreateExecutionContext_withRequestBody_addsUserAgentMetadata() throws IOException { ClientExecutionParams executionParams = clientExecutionParams(); diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/SelectedAuthScheme.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/SelectedAuthScheme.java index 10cb488792e2..8772260e3042 100644 --- a/core/sdk-core/src/main/java/software/amazon/awssdk/core/SelectedAuthScheme.java +++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/SelectedAuthScheme.java @@ -22,9 +22,27 @@ import software.amazon.awssdk.identity.spi.Identity; import software.amazon.awssdk.utils.Validate; -/** - * A container for the identity resolver, signer and auth option that we selected for use with this service call attempt. - */ + +/// +/// A container for the identity resolver, signer and auth option that we selected for use with this service call attempt. +/// ## The Hierarchy +/// ``` +/// IDENTITY_PROVIDERS (IdentityProviders) +/// └── contains multiple IdentityProvider instances +/// e.g., IdentityProvider for AWS credentials +/// e.g., IdentityProvider for bearer tokens +/// +/// AUTH_SCHEMES (Map>) +/// └── each AuthScheme knows: +/// - which IdentityProvider type it needs +/// - which HttpSigner to use +/// +/// SELECTED_AUTH_SCHEME (SelectedAuthScheme) +/// └── the chosen auth scheme, containing: +/// - identity: CompletableFuture ← the resolved identity! +/// - signer: HttpSigner +/// - authSchemeOption: AuthSchemeOption +/// ``` @SdkProtectedApi public final class SelectedAuthScheme { private final CompletableFuture identity; diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/client/handler/ClientExecutionParams.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/client/handler/ClientExecutionParams.java index e307f5857ce7..6950d6105fa2 100644 --- a/core/sdk-core/src/main/java/software/amazon/awssdk/core/client/handler/ClientExecutionParams.java +++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/client/handler/ClientExecutionParams.java @@ -30,6 +30,8 @@ import software.amazon.awssdk.core.interceptor.ExecutionAttribute; import software.amazon.awssdk.core.interceptor.ExecutionAttributes; import software.amazon.awssdk.core.runtime.transform.Marshaller; +import software.amazon.awssdk.core.spi.endpoint.EndpointResolver; +import software.amazon.awssdk.core.spi.identity.AuthSchemeOptionsResolver; import software.amazon.awssdk.core.sync.RequestBody; import software.amazon.awssdk.core.sync.ResponseTransformer; import software.amazon.awssdk.metrics.MetricCollector; @@ -63,6 +65,8 @@ public final class ClientExecutionParams { private MetricCollector metricCollector; private final ExecutionAttributes attributes = new ExecutionAttributes(); private SdkClientConfiguration requestConfiguration; + private AuthSchemeOptionsResolver authSchemeOptionsResolver; + private EndpointResolver endpointResolver; public Marshaller getMarshaller() { return marshaller; @@ -261,4 +265,23 @@ public ClientExecutionParams withRequestConfiguration(SdkCl this.requestConfiguration = requestConfiguration; return this; } + + public AuthSchemeOptionsResolver authSchemeOptionsResolver() { + return authSchemeOptionsResolver; + } + + public ClientExecutionParams withAuthSchemeOptionsResolver( + AuthSchemeOptionsResolver authSchemeOptionsResolver) { + this.authSchemeOptionsResolver = authSchemeOptionsResolver; + return this; + } + + public EndpointResolver endpointResolver() { + return endpointResolver; + } + + public ClientExecutionParams withEndpointResolver(EndpointResolver endpointResolver) { + this.endpointResolver = endpointResolver; + return this; + } } diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/interceptor/SdkInternalExecutionAttribute.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/interceptor/SdkInternalExecutionAttribute.java index 3fe0f69d3ab8..30ea3e948e7e 100644 --- a/core/sdk-core/src/main/java/software/amazon/awssdk/core/interceptor/SdkInternalExecutionAttribute.java +++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/interceptor/SdkInternalExecutionAttribute.java @@ -29,6 +29,9 @@ import software.amazon.awssdk.core.interceptor.trait.HttpChecksum; import software.amazon.awssdk.core.interceptor.trait.HttpChecksumRequired; import software.amazon.awssdk.core.internal.interceptor.trait.RequestCompression; +import software.amazon.awssdk.core.spi.endpoint.EndpointResolver; +import software.amazon.awssdk.core.spi.identity.AuthSchemeOptionsResolver; +import software.amazon.awssdk.core.spi.identity.IdentityProviderUpdater; import software.amazon.awssdk.core.useragent.AdditionalMetadata; import software.amazon.awssdk.core.useragent.BusinessMetricCollection; import software.amazon.awssdk.endpoints.Endpoint; @@ -166,6 +169,27 @@ public final class SdkInternalExecutionAttribute extends SdkExecutionAttribute { */ public static final ExecutionAttribute IDENTITY_PROVIDERS = new ExecutionAttribute<>("IdentityProviders"); + /** + * Callback for updating identity providers based on request-level overrides. + * This allows aws-core to provide AWS-specific logic without sdk-core depending on aws-core. + */ + public static final ExecutionAttribute IDENTITY_PROVIDER_UPDATER = + new ExecutionAttribute<>("IdentityProviderUpdater"); + + /** + * Callback to resolve auth scheme options from the (possibly modified) request. + * Called by AuthSchemeResolutionStage after interceptors have run. + */ + public static final ExecutionAttribute AUTH_SCHEME_OPTIONS_RESOLVER = + new ExecutionAttribute<>("AuthSchemeOptionsResolver"); + + /** + * Callback for resolving the endpoint. Generated per-service as a lambda in the client class. + * Called by EndpointResolutionStage after interceptors have run. + */ + public static final ExecutionAttribute ENDPOINT_RESOLVER = + new ExecutionAttribute<>("EndpointResolver"); + /** * The selected auth scheme for a request. */ diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/AmazonAsyncHttpClient.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/AmazonAsyncHttpClient.java index 59bd964d6fad..6594db840740 100644 --- a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/AmazonAsyncHttpClient.java +++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/AmazonAsyncHttpClient.java @@ -39,7 +39,9 @@ import software.amazon.awssdk.core.internal.http.pipeline.stages.AsyncExecutionFailureExceptionReportingStage; import software.amazon.awssdk.core.internal.http.pipeline.stages.AsyncRetryableStage; import software.amazon.awssdk.core.internal.http.pipeline.stages.AsyncSigningStage; +import software.amazon.awssdk.core.internal.http.pipeline.stages.AuthSchemeResolutionStage; import software.amazon.awssdk.core.internal.http.pipeline.stages.CompressRequestStage; +import software.amazon.awssdk.core.internal.http.pipeline.stages.EndpointResolutionStage; import software.amazon.awssdk.core.internal.http.pipeline.stages.HttpChecksumStage; import software.amazon.awssdk.core.internal.http.pipeline.stages.MakeAsyncHttpRequestStage; import software.amazon.awssdk.core.internal.http.pipeline.stages.MakeRequestImmutableStage; @@ -198,6 +200,8 @@ public CompletableFuture execute( .then(QueryParametersToBodyStage::new) .then(() -> new CompressRequestStage(httpClientDependencies)) .then(() -> new HttpChecksumStage(ClientType.ASYNC)) + .then(AuthSchemeResolutionStage::new) + .then(EndpointResolutionStage::new) .then(ApplyUserAgentStage::new) .then(MakeRequestImmutableStage::new) .then(RequestPipelineBuilder diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/AmazonSyncHttpClient.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/AmazonSyncHttpClient.java index dccaad1e2109..2298855e98fe 100644 --- a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/AmazonSyncHttpClient.java +++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/AmazonSyncHttpClient.java @@ -34,9 +34,11 @@ import software.amazon.awssdk.core.internal.http.pipeline.stages.ApiCallTimeoutTrackingStage; import software.amazon.awssdk.core.internal.http.pipeline.stages.ApplyTransactionIdStage; import software.amazon.awssdk.core.internal.http.pipeline.stages.ApplyUserAgentStage; +import software.amazon.awssdk.core.internal.http.pipeline.stages.AuthSchemeResolutionStage; import software.amazon.awssdk.core.internal.http.pipeline.stages.BeforeTransmissionExecutionInterceptorsStage; import software.amazon.awssdk.core.internal.http.pipeline.stages.BeforeUnmarshallingExecutionInterceptorsStage; import software.amazon.awssdk.core.internal.http.pipeline.stages.CompressRequestStage; +import software.amazon.awssdk.core.internal.http.pipeline.stages.EndpointResolutionStage; import software.amazon.awssdk.core.internal.http.pipeline.stages.ExecutionFailureExceptionReportingStage; import software.amazon.awssdk.core.internal.http.pipeline.stages.HandleResponseStage; import software.amazon.awssdk.core.internal.http.pipeline.stages.HttpChecksumStage; @@ -186,6 +188,8 @@ public OutputT execute(HttpResponseHandler> response .then(QueryParametersToBodyStage::new) .then(() -> new CompressRequestStage(httpClientDependencies)) .then(() -> new HttpChecksumStage(ClientType.SYNC)) + .then(AuthSchemeResolutionStage::new) + .then(EndpointResolutionStage::new) .then(ApplyUserAgentStage::new) .then(MakeRequestImmutableStage::new) // End of mutating request diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/auth/AuthSchemeResolver.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/auth/AuthSchemeResolver.java new file mode 100644 index 000000000000..6348a85f33b9 --- /dev/null +++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/auth/AuthSchemeResolver.java @@ -0,0 +1,179 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.core.internal.http.auth; + +import java.time.Duration; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.function.Supplier; +import java.util.stream.Collectors; +import software.amazon.awssdk.annotations.SdkInternalApi; +import software.amazon.awssdk.core.SelectedAuthScheme; +import software.amazon.awssdk.core.exception.SdkException; +import software.amazon.awssdk.core.interceptor.ExecutionAttributes; +import software.amazon.awssdk.core.interceptor.SdkInternalExecutionAttribute; +import software.amazon.awssdk.core.internal.util.MetricUtils; +import software.amazon.awssdk.core.metrics.CoreMetric; +import software.amazon.awssdk.http.auth.spi.scheme.AuthScheme; +import software.amazon.awssdk.http.auth.spi.scheme.AuthSchemeOption; +import software.amazon.awssdk.http.auth.spi.signer.HttpSigner; +import software.amazon.awssdk.identity.spi.AwsCredentialsIdentity; +import software.amazon.awssdk.identity.spi.Identity; +import software.amazon.awssdk.identity.spi.IdentityProvider; +import software.amazon.awssdk.identity.spi.IdentityProviders; +import software.amazon.awssdk.identity.spi.ResolveIdentityRequest; +import software.amazon.awssdk.identity.spi.TokenIdentity; +import software.amazon.awssdk.metrics.MetricCollector; +import software.amazon.awssdk.metrics.SdkMetric; +import software.amazon.awssdk.utils.Logger; + +/** + * Shared utility for selecting auth schemes from a list of options. + */ +@SdkInternalApi +public final class AuthSchemeResolver { + + private static final Logger LOG = Logger.loggerFor(AuthSchemeResolver.class); + + private AuthSchemeResolver() { + } + + /** + * Select an auth scheme from the given options. + * + * @param authOptions List of auth scheme options to try in order + * @param authSchemes Map of available auth schemes + * @param identityProviders Identity providers to use for resolving identity + * @param metricCollector Optional metric collector for recording identity fetch duration + * @return The selected auth scheme + * @throws SdkException if no auth scheme could be selected + */ + public static SelectedAuthScheme selectAuthScheme( + List authOptions, + Map> authSchemes, + IdentityProviders identityProviders, + MetricCollector metricCollector) { + + List> discardedReasons = new ArrayList<>(); + + for (AuthSchemeOption authOption : authOptions) { + AuthScheme authScheme = authSchemes.get(authOption.schemeId()); + SelectedAuthScheme selectedAuthScheme = trySelectAuthScheme( + authOption, authScheme, identityProviders, discardedReasons, metricCollector); + + if (selectedAuthScheme != null) { + if (!discardedReasons.isEmpty()) { + LOG.debug(() -> String.format("%s auth will be used, discarded: '%s'", + authOption.schemeId(), + discardedReasons.stream().map(Supplier::get).collect(Collectors.joining(", ")))); + } + return selectedAuthScheme; + } + } + + throw SdkException.builder() + .message("Failed to determine how to authenticate the user: " + + discardedReasons.stream().map(Supplier::get).collect(Collectors.joining(", "))) + .build(); + } + + /** + * Merge properties from any pre-existing auth scheme into the selected one. + */ + public static SelectedAuthScheme mergePreExistingAuthSchemeProperties( + SelectedAuthScheme selectedAuthScheme, + ExecutionAttributes executionAttributes) { + + SelectedAuthScheme existingAuthScheme = + executionAttributes.getAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME); + + if (existingAuthScheme == null) { + return selectedAuthScheme; + } + + AuthSchemeOption.Builder mergedOption = selectedAuthScheme.authSchemeOption().toBuilder(); + existingAuthScheme.authSchemeOption().forEachIdentityProperty(mergedOption::putIdentityPropertyIfAbsent); + existingAuthScheme.authSchemeOption().forEachSignerProperty(mergedOption::putSignerPropertyIfAbsent); + + return new SelectedAuthScheme<>( + selectedAuthScheme.identity(), + selectedAuthScheme.signer(), + mergedOption.build() + ); + } + + private static SelectedAuthScheme trySelectAuthScheme( + AuthSchemeOption authOption, + AuthScheme authScheme, + IdentityProviders identityProviders, + List> discardedReasons, + MetricCollector metricCollector) { + + if (authScheme == null) { + discardedReasons.add(() -> String.format("'%s' is not enabled for this request.", authOption.schemeId())); + return null; + } + + IdentityProvider identityProvider = authScheme.identityProvider(identityProviders); + if (identityProvider == null) { + discardedReasons.add(() -> String.format("'%s' does not have an identity provider configured.", + authOption.schemeId())); + return null; + } + + HttpSigner signer; + try { + signer = authScheme.signer(); + } catch (RuntimeException e) { + discardedReasons.add(() -> String.format("'%s' signer could not be retrieved: %s", + authOption.schemeId(), e.getMessage())); + return null; + } + + ResolveIdentityRequest.Builder identityRequestBuilder = ResolveIdentityRequest.builder(); + authOption.forEachIdentityProperty(identityRequestBuilder::putProperty); + + CompletableFuture identity = resolveIdentity( + identityProvider, identityRequestBuilder.build(), metricCollector); + + return new SelectedAuthScheme<>(identity, signer, authOption); + } + + private static CompletableFuture resolveIdentity( + IdentityProvider identityProvider, + ResolveIdentityRequest request, + MetricCollector metricCollector) { + + SdkMetric metric = getIdentityMetric(identityProvider); + if (metric == null || metricCollector == null) { + return identityProvider.resolveIdentity(request); + } + return MetricUtils.reportDuration(() -> identityProvider.resolveIdentity(request), metricCollector, metric); + } + + private static SdkMetric getIdentityMetric(IdentityProvider identityProvider) { + Class identityType = identityProvider.identityType(); + if (identityType == AwsCredentialsIdentity.class) { + return CoreMetric.CREDENTIALS_FETCH_DURATION; + } + if (identityType == TokenIdentity.class) { + return CoreMetric.TOKEN_FETCH_DURATION; + } + return null; + } +} diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/pipeline/stages/ApiCallMetricCollectionStage.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/pipeline/stages/ApiCallMetricCollectionStage.java index 3b78dedaf2ad..c7a6783fa7da 100644 --- a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/pipeline/stages/ApiCallMetricCollectionStage.java +++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/pipeline/stages/ApiCallMetricCollectionStage.java @@ -40,17 +40,23 @@ public ApiCallMetricCollectionStage(RequestPipeline execute(SdkHttpFullRequest input, RequestExecutionContext context) throws Exception { MetricCollector metricCollector = context.executionContext().metricCollector(); - MetricUtils.collectServiceEndpointMetrics(metricCollector, input); // Note: at this point, any exception, even a service exception, will // be thrown from the wrapped pipeline so we can't use // MetricUtil.measureDuration() long callStart = System.nanoTime(); try { - return wrapped.execute(input, context); + Response response = wrapped.execute(input, context); + return response; } finally { long d = System.nanoTime() - callStart; metricCollector.reportMetric(CoreMetric.API_CALL_DURATION, Duration.ofNanos(d)); + // Collect SERVICE_ENDPOINT after pipeline execution so that EndpointResolutionStage + // has applied the resolved URL to the request. + SdkHttpFullRequest finalRequest = context.executionContext().interceptorContext().httpRequest() != null + ? (SdkHttpFullRequest) context.executionContext().interceptorContext().httpRequest() + : input; + MetricUtils.collectServiceEndpointMetrics(metricCollector, finalRequest); } } } diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/pipeline/stages/AsyncApiCallMetricCollectionStage.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/pipeline/stages/AsyncApiCallMetricCollectionStage.java index 1d7d040971cf..e85a6f752aee 100644 --- a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/pipeline/stages/AsyncApiCallMetricCollectionStage.java +++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/pipeline/stages/AsyncApiCallMetricCollectionStage.java @@ -41,7 +41,6 @@ public AsyncApiCallMetricCollectionStage(RequestPipeline execute(SdkHttpFullRequest input, RequestExecutionContext context) throws Exception { MetricCollector metricCollector = context.executionContext().metricCollector(); - MetricUtils.collectServiceEndpointMetrics(metricCollector, input); CompletableFuture future = new CompletableFuture<>(); @@ -52,6 +51,13 @@ public CompletableFuture execute(SdkHttpFullRequest input, RequestExecu long duration = System.nanoTime() - callStart; metricCollector.reportMetric(CoreMetric.API_CALL_DURATION, Duration.ofNanos(duration)); + // Collect SERVICE_ENDPOINT after pipeline execution so that EndpointResolutionStage + // has applied the resolved URL to the request. + SdkHttpFullRequest finalRequest = context.executionContext().interceptorContext().httpRequest() != null + ? (SdkHttpFullRequest) context.executionContext().interceptorContext().httpRequest() + : input; + MetricUtils.collectServiceEndpointMetrics(metricCollector, finalRequest); + if (t != null) { future.completeExceptionally(t); } else { diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/pipeline/stages/AuthSchemeResolutionStage.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/pipeline/stages/AuthSchemeResolutionStage.java new file mode 100644 index 000000000000..b5f616d54166 --- /dev/null +++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/pipeline/stages/AuthSchemeResolutionStage.java @@ -0,0 +1,158 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.core.internal.http.pipeline.stages; + +import java.util.List; +import java.util.Map; +import software.amazon.awssdk.annotations.SdkInternalApi; +import software.amazon.awssdk.core.RequestOverrideConfiguration; +import software.amazon.awssdk.core.SdkRequest; +import software.amazon.awssdk.core.SelectedAuthScheme; +import software.amazon.awssdk.core.interceptor.ExecutionAttributes; +import software.amazon.awssdk.core.interceptor.SdkExecutionAttribute; +import software.amazon.awssdk.core.interceptor.SdkInternalExecutionAttribute; +import software.amazon.awssdk.core.internal.http.HttpClientDependencies; +import software.amazon.awssdk.core.internal.http.RequestExecutionContext; +import software.amazon.awssdk.core.internal.http.auth.AuthSchemeResolver; +import software.amazon.awssdk.core.internal.http.pipeline.MutableRequestToRequestPipeline; +import software.amazon.awssdk.core.spi.identity.AuthSchemeOptionsResolver; +import software.amazon.awssdk.core.spi.identity.IdentityProviderUpdater; +import software.amazon.awssdk.core.useragent.BusinessMetricCollection; +import software.amazon.awssdk.core.useragent.BusinessMetricFeatureId; +import software.amazon.awssdk.http.SdkHttpFullRequest; +import software.amazon.awssdk.http.auth.spi.scheme.AuthScheme; +import software.amazon.awssdk.http.auth.spi.scheme.AuthSchemeOption; +import software.amazon.awssdk.identity.spi.Identity; +import software.amazon.awssdk.identity.spi.IdentityProviders; +import software.amazon.awssdk.identity.spi.TokenIdentity; +import software.amazon.awssdk.metrics.MetricCollector; + +/** + * Pipeline stage that resolves the auth scheme and identity for signing. + */ +@SdkInternalApi +public final class AuthSchemeResolutionStage implements MutableRequestToRequestPipeline { + + private static final String SIGV4A_SCHEME_ID = "aws.auth#sigv4a"; + private static final String BEARER_SCHEME_ID = "smithy.api#httpBearerAuth"; + + public AuthSchemeResolutionStage(HttpClientDependencies dependencies) { + } + + @Override + public SdkHttpFullRequest.Builder execute(SdkHttpFullRequest.Builder request, RequestExecutionContext context) + throws Exception { + ExecutionAttributes executionAttributes = context.executionAttributes(); + + Map> authSchemes = executionAttributes.getAttribute(SdkInternalExecutionAttribute.AUTH_SCHEMES); + if (authSchemes == null) { + return request; + } + + SdkRequest sdkRequest = context.executionContext().interceptorContext().request(); + List authOptions = resolveAuthSchemeOptions(executionAttributes, sdkRequest); + if (authOptions == null || authOptions.isEmpty()) { + return request; + } + + IdentityProviders identityProviders = updateIdentityProvidersIfNeeded(executionAttributes, sdkRequest); + + MetricCollector metricCollector = + executionAttributes.getAttribute(SdkExecutionAttribute.API_CALL_METRIC_COLLECTOR); + + SelectedAuthScheme selectedAuthScheme = + AuthSchemeResolver.selectAuthScheme(authOptions, authSchemes, identityProviders, metricCollector); + + selectedAuthScheme = AuthSchemeResolver.mergePreExistingAuthSchemeProperties(selectedAuthScheme, executionAttributes); + + executionAttributes.putAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME, selectedAuthScheme); + + recordBusinessMetrics(selectedAuthScheme, sdkRequest, executionAttributes); + + return request; + } + + private List resolveAuthSchemeOptions(ExecutionAttributes executionAttributes, SdkRequest request) { + AuthSchemeOptionsResolver resolver = + executionAttributes.getAttribute(SdkInternalExecutionAttribute.AUTH_SCHEME_OPTIONS_RESOLVER); + + if (resolver == null) { + return null; + } + return resolver.resolve(request); + } + + /** + * Returns identity providers after applying any request-level overrides. This allows aws-core to inject + * credential overrides from {@code AwsRequestOverrideConfiguration} (e.g., per-request credentials provider) + * without sdk-core depending on aws-core. The updater is set by {@code AwsExecutionContextBuilder} and runs + * after interceptors have modified the request, ensuring user-injected credentials are respected. + */ + private IdentityProviders updateIdentityProvidersIfNeeded(ExecutionAttributes executionAttributes, SdkRequest request) { + IdentityProviders identityProviders = + executionAttributes.getAttribute(SdkInternalExecutionAttribute.IDENTITY_PROVIDERS); + + IdentityProviderUpdater updater = + executionAttributes.getAttribute(SdkInternalExecutionAttribute.IDENTITY_PROVIDER_UPDATER); + if (updater != null) { + identityProviders = updater.update(request, identityProviders); + } + return identityProviders; + } + + private void recordBusinessMetrics(SelectedAuthScheme selectedAuthScheme, + SdkRequest request, + ExecutionAttributes executionAttributes) { + if (selectedAuthScheme == null) { + return; + } + + BusinessMetricCollection businessMetrics = + executionAttributes.getAttribute(SdkInternalExecutionAttribute.BUSINESS_METRICS); + if (businessMetrics == null) { + return; + } + + String schemeId = selectedAuthScheme.authSchemeOption().schemeId(); + + if (isSignerOverridden(request, executionAttributes)) { + return; + } + + if (SIGV4A_SCHEME_ID.equals(schemeId)) { + businessMetrics.addMetric(BusinessMetricFeatureId.SIGV4A_SIGNING.value()); + } + + if (BEARER_SCHEME_ID.equals(schemeId) && selectedAuthScheme.identity().isDone()) { + Identity identity = selectedAuthScheme.identity().getNow(null); + if (identity instanceof TokenIdentity) { + String tokenFromEnv = executionAttributes.getAttribute(SdkInternalExecutionAttribute.TOKEN_CONFIGURED_FROM_ENV); + if (tokenFromEnv != null && tokenFromEnv.equals(((TokenIdentity) identity).token())) { + businessMetrics.addMetric(BusinessMetricFeatureId.BEARER_SERVICE_ENV_VARS.value()); + } + } + } + } + + private boolean isSignerOverridden(SdkRequest request, ExecutionAttributes executionAttributes) { + boolean isClientSignerOverridden = + Boolean.TRUE.equals(executionAttributes.getAttribute(SdkExecutionAttribute.SIGNER_OVERRIDDEN)); + boolean isRequestSignerOverridden = request.overrideConfiguration() + .flatMap(RequestOverrideConfiguration::signer) + .isPresent(); + return isClientSignerOverridden || isRequestSignerOverridden; + } +} diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/pipeline/stages/EndpointResolutionStage.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/pipeline/stages/EndpointResolutionStage.java new file mode 100644 index 000000000000..70d2886c2fd1 --- /dev/null +++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/pipeline/stages/EndpointResolutionStage.java @@ -0,0 +1,119 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.core.internal.http.pipeline.stages; + +import java.net.URI; +import java.time.Duration; +import java.util.List; +import java.util.Map; +import software.amazon.awssdk.annotations.SdkInternalApi; +import software.amazon.awssdk.core.ClientEndpointProvider; +import software.amazon.awssdk.core.SdkRequest; +import software.amazon.awssdk.core.interceptor.ExecutionAttributes; +import software.amazon.awssdk.core.interceptor.SdkExecutionAttribute; +import software.amazon.awssdk.core.interceptor.SdkInternalExecutionAttribute; +import software.amazon.awssdk.core.internal.http.HttpClientDependencies; +import software.amazon.awssdk.core.internal.http.RequestExecutionContext; +import software.amazon.awssdk.core.internal.http.pipeline.MutableRequestToRequestPipeline; +import software.amazon.awssdk.core.metrics.CoreMetric; +import software.amazon.awssdk.core.spi.endpoint.EndpointResolver; +import software.amazon.awssdk.endpoints.Endpoint; +import software.amazon.awssdk.http.SdkHttpFullRequest; +import software.amazon.awssdk.metrics.MetricCollector; +import software.amazon.awssdk.utils.StringUtils; +import software.amazon.awssdk.utils.http.SdkHttpUtils; + +/** + * Pipeline stage that resolves the endpoint using the service's endpoint rules engine. + *

+ * This stage runs after all interceptors (including {@code modifyHttpRequest}) and after + * {@link AuthSchemeResolutionStage}. It calls a service-specific callback to resolve the endpoint, + * then applies the resolved URL, headers, and metrics. + *

+ */ +@SdkInternalApi +public final class EndpointResolutionStage implements MutableRequestToRequestPipeline { + + public EndpointResolutionStage(HttpClientDependencies dependencies) { + } + + @Override + public SdkHttpFullRequest.Builder execute(SdkHttpFullRequest.Builder request, RequestExecutionContext context) + throws Exception { + ExecutionAttributes attrs = context.executionAttributes(); + + if (Boolean.TRUE.equals(attrs.getAttribute(SdkInternalExecutionAttribute.IS_DISCOVERED_ENDPOINT))) { + return request; + } + + EndpointResolver resolver = attrs.getAttribute(SdkInternalExecutionAttribute.ENDPOINT_RESOLVER); + if (resolver == null) { + return request; + } + + SdkRequest sdkRequest = context.executionContext().interceptorContext().request(); + + long resolveEndpointStart = System.nanoTime(); + Endpoint endpoint = resolver.resolve(sdkRequest, attrs); + Duration resolveEndpointDuration = Duration.ofNanos(System.nanoTime() - resolveEndpointStart); + + MetricCollector metricCollector = attrs.getAttribute(SdkExecutionAttribute.API_CALL_METRIC_COLLECTOR); + if (metricCollector != null) { + metricCollector.reportMetric(CoreMetric.ENDPOINT_RESOLVE_DURATION, resolveEndpointDuration); + } + + attrs.putAttribute(SdkInternalExecutionAttribute.RESOLVED_ENDPOINT, endpoint); + + // Copy endpoint headers onto HTTP request + Map> headers = endpoint.headers(); + if (headers != null && !headers.isEmpty()) { + headers.forEach((name, values) -> + values.forEach(v -> request.appendHeader(name, v))); + } + + // Apply resolved endpoint URL + ClientEndpointProvider clientEndpointProvider = + attrs.getAttribute(SdkInternalExecutionAttribute.CLIENT_ENDPOINT_PROVIDER); + return setUri(request, clientEndpointProvider.clientEndpoint(), endpoint.url()); + } + + /** + * Applies the resolved endpoint URL to the HTTP request, merging three path components: + * client endpoint path + rules engine additions + marshaller operation path. + */ + private static SdkHttpFullRequest.Builder setUri(SdkHttpFullRequest.Builder request, + URI clientEndpoint, + URI resolvedUri) { + String clientEndpointPath = clientEndpoint.getRawPath(); + String requestPath = request.encodedPath(); + String resolvedUriPath = resolvedUri.getRawPath(); + + String finalPath = requestPath; + if (!resolvedUriPath.equals(clientEndpointPath)) { + finalPath = combinePath(clientEndpointPath, requestPath, resolvedUriPath); + } + + return request.protocol(resolvedUri.getScheme()) + .host(resolvedUri.getHost()) + .port(resolvedUri.getPort()) + .encodedPath(finalPath); + } + + private static String combinePath(String clientEndpointPath, String requestPath, String resolvedUriPath) { + String requestPathWithClientPathRemoved = StringUtils.replaceOnce(requestPath, clientEndpointPath, ""); + return SdkHttpUtils.appendUri(resolvedUriPath, requestPathWithClientPathRemoved); + } +} diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/spi/endpoint/EndpointResolver.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/spi/endpoint/EndpointResolver.java new file mode 100644 index 000000000000..62ba60164d24 --- /dev/null +++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/spi/endpoint/EndpointResolver.java @@ -0,0 +1,37 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.core.spi.endpoint; + +import software.amazon.awssdk.annotations.SdkProtectedApi; +import software.amazon.awssdk.core.SdkRequest; +import software.amazon.awssdk.core.interceptor.ExecutionAttributes; +import software.amazon.awssdk.endpoints.Endpoint; + +/** + * Callback interface for resolving endpoints from the request and execution context. + */ +@FunctionalInterface +@SdkProtectedApi +public interface EndpointResolver { + /** + * Resolves the endpoint for the given request. + * + * @param request The SDK request (after interceptors have modified it) + * @param executionAttributes The execution attributes containing client config, region, etc. + * @return The resolved endpoint + */ + Endpoint resolve(SdkRequest request, ExecutionAttributes executionAttributes); +} diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/spi/identity/AuthSchemeOptionsResolver.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/spi/identity/AuthSchemeOptionsResolver.java new file mode 100644 index 000000000000..7a37d7c3dd6b --- /dev/null +++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/spi/identity/AuthSchemeOptionsResolver.java @@ -0,0 +1,39 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.core.spi.identity; + +import java.util.List; +import software.amazon.awssdk.annotations.SdkProtectedApi; +import software.amazon.awssdk.core.SdkRequest; +import software.amazon.awssdk.http.auth.spi.scheme.AuthSchemeOption; + +/** + * Callback interface for resolving auth scheme options from the request. + *

+ * This allows auth scheme resolution to happen after interceptors have modified the request, + * ensuring that any request modifications affecting auth scheme selection are respected. + */ +@FunctionalInterface +@SdkProtectedApi +public interface AuthSchemeOptionsResolver { + /** + * Resolves auth scheme options for the given request. + * + * @param request The request (after interceptors have modified it) + * @return List of auth scheme options in priority order + */ + List resolve(SdkRequest request); +} diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/spi/identity/IdentityProviderUpdater.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/spi/identity/IdentityProviderUpdater.java new file mode 100644 index 000000000000..af54e5ef3661 --- /dev/null +++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/spi/identity/IdentityProviderUpdater.java @@ -0,0 +1,39 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.core.spi.identity; + +import software.amazon.awssdk.annotations.SdkProtectedApi; +import software.amazon.awssdk.core.SdkRequest; +import software.amazon.awssdk.identity.spi.IdentityProviders; + +/** + * Callback interface for updating identity providers based on request-level overrides. + *

+ * This allows aws-core to provide AWS-specific logic for reading credential overrides + * from {@code AwsRequestOverrideConfiguration} without sdk-core depending on aws-core. + */ +@FunctionalInterface +@SdkProtectedApi +public interface IdentityProviderUpdater { + /** + * Updates identity providers based on request-level overrides. + * + * @param request The request (after interceptors have modified it) + * @param base The base identity providers from client configuration + * @return Updated identity providers, or base if no overrides + */ + IdentityProviders update(SdkRequest request, IdentityProviders base); +} diff --git a/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/http/auth/AuthSchemeResolverTest.java b/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/http/auth/AuthSchemeResolverTest.java new file mode 100644 index 000000000000..90c2cdf12008 --- /dev/null +++ b/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/http/auth/AuthSchemeResolverTest.java @@ -0,0 +1,189 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.core.internal.http.auth; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.CompletableFuture; +import org.junit.jupiter.api.Test; +import software.amazon.awssdk.core.SelectedAuthScheme; +import software.amazon.awssdk.core.exception.SdkException; +import software.amazon.awssdk.core.interceptor.ExecutionAttributes; +import software.amazon.awssdk.core.interceptor.SdkInternalExecutionAttribute; +import software.amazon.awssdk.http.auth.spi.scheme.AuthScheme; +import software.amazon.awssdk.http.auth.spi.scheme.AuthSchemeOption; +import software.amazon.awssdk.http.auth.spi.signer.HttpSigner; +import software.amazon.awssdk.identity.spi.Identity; +import software.amazon.awssdk.identity.spi.IdentityProvider; +import software.amazon.awssdk.identity.spi.IdentityProviders; +import software.amazon.awssdk.identity.spi.ResolveIdentityRequest; + +class AuthSchemeResolverTest { + + private static final String SCHEME_A = "schemeA"; + private static final String SCHEME_B = "schemeB"; + + @Test + void selectAuthScheme_firstOptionSucceeds_returnsFirstScheme() { + AuthScheme schemeA = createMockAuthScheme(); + Map> authSchemes = new HashMap<>(); + authSchemes.put(SCHEME_A, schemeA); + + List options = Collections.singletonList( + AuthSchemeOption.builder().schemeId(SCHEME_A).build() + ); + + SelectedAuthScheme result = AuthSchemeResolver.selectAuthScheme( + options, authSchemes, mock(IdentityProviders.class), null); + + assertThat(result.authSchemeOption().schemeId()).isEqualTo(SCHEME_A); + } + + @Test + void selectAuthScheme_firstOptionNoScheme_fallsBackToSecond() { + AuthScheme schemeB = createMockAuthScheme(); + Map> authSchemes = new HashMap<>(); + authSchemes.put(SCHEME_B, schemeB); + + List options = Arrays.asList( + AuthSchemeOption.builder().schemeId(SCHEME_A).build(), + AuthSchemeOption.builder().schemeId(SCHEME_B).build() + ); + + SelectedAuthScheme result = AuthSchemeResolver.selectAuthScheme( + options, authSchemes, mock(IdentityProviders.class), null); + + assertThat(result.authSchemeOption().schemeId()).isEqualTo(SCHEME_B); + } + + @Test + void selectAuthScheme_firstOptionNoIdentityProvider_fallsBackToSecond() { + AuthScheme schemeA = createMockAuthScheme(); + when(schemeA.identityProvider(any())).thenReturn(null); + + AuthScheme schemeB = createMockAuthScheme(); + + Map> authSchemes = new HashMap<>(); + authSchemes.put(SCHEME_A, schemeA); + authSchemes.put(SCHEME_B, schemeB); + + List options = Arrays.asList( + AuthSchemeOption.builder().schemeId(SCHEME_A).build(), + AuthSchemeOption.builder().schemeId(SCHEME_B).build() + ); + + SelectedAuthScheme result = AuthSchemeResolver.selectAuthScheme( + options, authSchemes, mock(IdentityProviders.class), null); + + assertThat(result.authSchemeOption().schemeId()).isEqualTo(SCHEME_B); + } + + @Test + void selectAuthScheme_signerThrows_fallsBackToSecond() { + AuthScheme schemeA = createMockAuthScheme(); + when(schemeA.signer()).thenThrow(new RuntimeException("Signer not available")); + + AuthScheme schemeB = createMockAuthScheme(); + + Map> authSchemes = new HashMap<>(); + authSchemes.put(SCHEME_A, schemeA); + authSchemes.put(SCHEME_B, schemeB); + + List options = Arrays.asList( + AuthSchemeOption.builder().schemeId(SCHEME_A).build(), + AuthSchemeOption.builder().schemeId(SCHEME_B).build() + ); + + SelectedAuthScheme result = AuthSchemeResolver.selectAuthScheme( + options, authSchemes, mock(IdentityProviders.class), null); + + assertThat(result.authSchemeOption().schemeId()).isEqualTo(SCHEME_B); + } + + @Test + void selectAuthScheme_allOptionsFail_throwsException() { + Map> authSchemes = new HashMap<>(); + + List options = Collections.singletonList( + AuthSchemeOption.builder().schemeId(SCHEME_A).build() + ); + + assertThatThrownBy(() -> AuthSchemeResolver.selectAuthScheme( + options, authSchemes, mock(IdentityProviders.class), null)) + .isInstanceOf(SdkException.class) + .hasMessageContaining("Failed to determine how to authenticate"); + } + + @Test + void mergeProperties_noExistingScheme_returnsOriginal() { + SelectedAuthScheme selected = createSelectedAuthScheme(SCHEME_A); + ExecutionAttributes attributes = new ExecutionAttributes(); + + SelectedAuthScheme result = AuthSchemeResolver.mergePreExistingAuthSchemeProperties( + selected, attributes); + + assertThat(result).isSameAs(selected); + } + + @Test + @SuppressWarnings("unchecked") + void mergeProperties_withExistingScheme_returnsNewInstance() { + SelectedAuthScheme selected = createSelectedAuthScheme(SCHEME_A); + SelectedAuthScheme existing = createSelectedAuthScheme(SCHEME_B); + + ExecutionAttributes attributes = new ExecutionAttributes(); + attributes.putAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME, existing); + + SelectedAuthScheme result = AuthSchemeResolver.mergePreExistingAuthSchemeProperties( + selected, attributes); + + assertThat(result).isNotSameAs(selected); + assertThat(result.authSchemeOption().schemeId()).isEqualTo(SCHEME_A); + } + + @SuppressWarnings("unchecked") + private AuthScheme createMockAuthScheme() { + AuthScheme scheme = mock(AuthScheme.class); + IdentityProvider identityProvider = mock(IdentityProvider.class); + Identity mockIdentity = mock(Identity.class); + doReturn(CompletableFuture.completedFuture(mockIdentity)) + .when(identityProvider).resolveIdentity(any(ResolveIdentityRequest.class)); + when(scheme.identityProvider(any())).thenReturn(identityProvider); + when(scheme.signer()).thenReturn(mock(HttpSigner.class)); + return scheme; + } + + @SuppressWarnings("unchecked") + private SelectedAuthScheme createSelectedAuthScheme(String schemeId) { + Identity mockIdentity = mock(Identity.class); + HttpSigner mockSigner = mock(HttpSigner.class); + return new SelectedAuthScheme<>( + CompletableFuture.completedFuture(mockIdentity), + mockSigner, + AuthSchemeOption.builder().schemeId(schemeId).build() + ); + } +} diff --git a/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/http/pipeline/stages/AuthSchemeResolutionStageTest.java b/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/http/pipeline/stages/AuthSchemeResolutionStageTest.java new file mode 100644 index 000000000000..da86f97f053f --- /dev/null +++ b/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/http/pipeline/stages/AuthSchemeResolutionStageTest.java @@ -0,0 +1,231 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.core.internal.http.pipeline.stages; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; + +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.CompletableFuture; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import software.amazon.awssdk.core.SdkRequest; +import software.amazon.awssdk.core.SelectedAuthScheme; +import software.amazon.awssdk.core.http.ExecutionContext; +import software.amazon.awssdk.core.interceptor.ExecutionAttributes; +import software.amazon.awssdk.core.interceptor.InterceptorContext; +import software.amazon.awssdk.core.interceptor.SdkInternalExecutionAttribute; +import software.amazon.awssdk.core.internal.http.RequestExecutionContext; +import software.amazon.awssdk.core.spi.identity.AuthSchemeOptionsResolver; +import software.amazon.awssdk.core.spi.identity.IdentityProviderUpdater; +import software.amazon.awssdk.http.SdkHttpFullRequest; +import software.amazon.awssdk.http.auth.spi.scheme.AuthScheme; +import software.amazon.awssdk.http.auth.spi.scheme.AuthSchemeOption; +import software.amazon.awssdk.http.auth.spi.signer.HttpSigner; +import software.amazon.awssdk.identity.spi.Identity; +import software.amazon.awssdk.identity.spi.IdentityProvider; +import software.amazon.awssdk.identity.spi.IdentityProviders; +import software.amazon.awssdk.identity.spi.ResolveIdentityRequest; + +class AuthSchemeResolutionStageTest { + + private static final String SCHEME_ID = "test.scheme"; + + private AuthSchemeResolutionStage stage; + private SdkHttpFullRequest.Builder httpRequestBuilder; + private RequestExecutionContext context; + private ExecutionAttributes executionAttributes; + private SdkRequest sdkRequest; + + @BeforeEach + void setup() { + stage = new AuthSchemeResolutionStage(null); + httpRequestBuilder = mock(SdkHttpFullRequest.Builder.class); + sdkRequest = mock(SdkRequest.class); + executionAttributes = new ExecutionAttributes(); + + InterceptorContext interceptorContext = InterceptorContext.builder() + .request(sdkRequest) + .build(); + ExecutionContext executionContext = ExecutionContext.builder() + .interceptorContext(interceptorContext) + .executionAttributes(executionAttributes) + .build(); + context = RequestExecutionContext.builder() + .executionContext(executionContext) + .originalRequest(sdkRequest) + .build(); + } + + @Test + void execute_noAuthSchemes_returnsRequestUnchanged() throws Exception { + // AUTH_SCHEMES is null + SdkHttpFullRequest.Builder result = stage.execute(httpRequestBuilder, context); + + assertThat(result).isSameAs(httpRequestBuilder); + assertThat(executionAttributes.getAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME)).isNull(); + } + + @Test + void execute_noResolver_returnsRequestUnchanged() throws Exception { + executionAttributes.putAttribute(SdkInternalExecutionAttribute.AUTH_SCHEMES, createAuthSchemes()); + // AUTH_SCHEME_OPTIONS_RESOLVER is null + + SdkHttpFullRequest.Builder result = stage.execute(httpRequestBuilder, context); + + assertThat(result).isSameAs(httpRequestBuilder); + assertThat(executionAttributes.getAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME)).isNull(); + } + + @Test + void execute_resolverReturnsEmpty_returnsRequestUnchanged() throws Exception { + executionAttributes.putAttribute(SdkInternalExecutionAttribute.AUTH_SCHEMES, createAuthSchemes()); + executionAttributes.putAttribute(SdkInternalExecutionAttribute.AUTH_SCHEME_OPTIONS_RESOLVER, + (AuthSchemeOptionsResolver) req -> Collections.emptyList()); + + SdkHttpFullRequest.Builder result = stage.execute(httpRequestBuilder, context); + + assertThat(result).isSameAs(httpRequestBuilder); + assertThat(executionAttributes.getAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME)).isNull(); + } + + @Test + void execute_resolverReceivesRequestFromInterceptorContext() throws Exception { + SdkRequest modifiedRequest = mock(SdkRequest.class); + + // Setup interceptor context with a DIFFERENT request than originalRequest + InterceptorContext interceptorContext = InterceptorContext.builder() + .request(modifiedRequest) + .build(); + ExecutionContext executionContext = ExecutionContext.builder() + .interceptorContext(interceptorContext) + .executionAttributes(executionAttributes) + .build(); + context = RequestExecutionContext.builder() + .executionContext(executionContext) + .originalRequest(sdkRequest) // Different from modifiedRequest + .build(); + + AuthSchemeOptionsResolver resolver = mock(AuthSchemeOptionsResolver.class); + doReturn(createAuthOptions()).when(resolver).resolve(modifiedRequest); + + executionAttributes.putAttribute(SdkInternalExecutionAttribute.AUTH_SCHEMES, createAuthSchemes()); + executionAttributes.putAttribute(SdkInternalExecutionAttribute.AUTH_SCHEME_OPTIONS_RESOLVER, resolver); + executionAttributes.putAttribute(SdkInternalExecutionAttribute.IDENTITY_PROVIDERS, createIdentityProviders()); + + stage.execute(httpRequestBuilder, context); + + // Verify resolver was called with the MODIFIED request, not originalRequest + verify(resolver).resolve(modifiedRequest); + } + + @Test + void execute_withIdentityProviderUpdater_callsUpdaterWithRequest() throws Exception { + // Create mocks first before any stubbing + IdentityProvider identityProvider = createMockIdentityProvider(); + Map> authSchemes = createAuthSchemes(); + IdentityProviders baseProviders = mock(IdentityProviders.class); + IdentityProviders updatedProviders = mock(IdentityProviders.class); + + IdentityProviderUpdater updater = mock(IdentityProviderUpdater.class); + doReturn(updatedProviders).when(updater).update(sdkRequest, baseProviders); + + // Setup so that auth scheme uses the updated providers + @SuppressWarnings("unchecked") + AuthScheme scheme = (AuthScheme) authSchemes.get(SCHEME_ID); + doReturn(identityProvider).when(scheme).identityProvider(updatedProviders); + + executionAttributes.putAttribute(SdkInternalExecutionAttribute.AUTH_SCHEMES, authSchemes); + executionAttributes.putAttribute(SdkInternalExecutionAttribute.AUTH_SCHEME_OPTIONS_RESOLVER, + (AuthSchemeOptionsResolver) req -> createAuthOptions()); + executionAttributes.putAttribute(SdkInternalExecutionAttribute.IDENTITY_PROVIDERS, baseProviders); + executionAttributes.putAttribute(SdkInternalExecutionAttribute.IDENTITY_PROVIDER_UPDATER, updater); + + stage.execute(httpRequestBuilder, context); + + verify(updater).update(sdkRequest, baseProviders); + } + + @Test + void execute_withoutIdentityProviderUpdater_doesNotFail() throws Exception { + executionAttributes.putAttribute(SdkInternalExecutionAttribute.AUTH_SCHEMES, createAuthSchemes()); + executionAttributes.putAttribute(SdkInternalExecutionAttribute.AUTH_SCHEME_OPTIONS_RESOLVER, + (AuthSchemeOptionsResolver) req -> createAuthOptions()); + executionAttributes.putAttribute(SdkInternalExecutionAttribute.IDENTITY_PROVIDERS, createIdentityProviders()); + // No IDENTITY_PROVIDER_UPDATER set + + SdkHttpFullRequest.Builder result = stage.execute(httpRequestBuilder, context); + + assertThat(result).isSameAs(httpRequestBuilder); + assertThat(executionAttributes.getAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME)).isNotNull(); + } + + @Test + void execute_happyPath_setsSelectedAuthScheme() throws Exception { + executionAttributes.putAttribute(SdkInternalExecutionAttribute.AUTH_SCHEMES, createAuthSchemes()); + executionAttributes.putAttribute(SdkInternalExecutionAttribute.AUTH_SCHEME_OPTIONS_RESOLVER, + (AuthSchemeOptionsResolver) req -> createAuthOptions()); + executionAttributes.putAttribute(SdkInternalExecutionAttribute.IDENTITY_PROVIDERS, createIdentityProviders()); + + SdkHttpFullRequest.Builder result = stage.execute(httpRequestBuilder, context); + + assertThat(result).isSameAs(httpRequestBuilder); + SelectedAuthScheme selectedAuthScheme = + executionAttributes.getAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME); + assertThat(selectedAuthScheme).isNotNull(); + assertThat(selectedAuthScheme.authSchemeOption().schemeId()).isEqualTo(SCHEME_ID); + } + + @SuppressWarnings("unchecked") + private Map> createAuthSchemes() { + IdentityProvider identityProvider = createMockIdentityProvider(); + HttpSigner signer = mock(HttpSigner.class); + + AuthScheme scheme = mock(AuthScheme.class); + doReturn(identityProvider).when(scheme).identityProvider(any()); + doReturn(signer).when(scheme).signer(); + + Map> schemes = new HashMap<>(); + schemes.put(SCHEME_ID, scheme); + return schemes; + } + + @SuppressWarnings("unchecked") + private IdentityProvider createMockIdentityProvider() { + IdentityProvider provider = mock(IdentityProvider.class); + Identity mockIdentity = mock(Identity.class); + doReturn(CompletableFuture.completedFuture(mockIdentity)) + .when(provider).resolveIdentity(any(ResolveIdentityRequest.class)); + return provider; + } + + private IdentityProviders createIdentityProviders() { + IdentityProviders providers = mock(IdentityProviders.class); + return providers; + } + + private List createAuthOptions() { + return Collections.singletonList( + AuthSchemeOption.builder().schemeId(SCHEME_ID).build() + ); + } +} diff --git a/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/http/pipeline/stages/EndpointResolutionStageTest.java b/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/http/pipeline/stages/EndpointResolutionStageTest.java new file mode 100644 index 000000000000..94ee03b5df50 --- /dev/null +++ b/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/http/pipeline/stages/EndpointResolutionStageTest.java @@ -0,0 +1,217 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.core.internal.http.pipeline.stages; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; + +import java.net.URI; +import java.time.Duration; +import java.util.List; +import java.util.Map; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; +import software.amazon.awssdk.core.ClientEndpointProvider; +import software.amazon.awssdk.core.SdkRequest; +import software.amazon.awssdk.core.http.ExecutionContext; +import software.amazon.awssdk.core.interceptor.ExecutionAttributes; +import software.amazon.awssdk.core.interceptor.InterceptorContext; +import software.amazon.awssdk.core.interceptor.SdkExecutionAttribute; +import software.amazon.awssdk.core.interceptor.SdkInternalExecutionAttribute; +import software.amazon.awssdk.core.internal.http.RequestExecutionContext; +import software.amazon.awssdk.core.metrics.CoreMetric; +import software.amazon.awssdk.endpoints.Endpoint; +import software.amazon.awssdk.http.SdkHttpFullRequest; +import software.amazon.awssdk.http.SdkHttpMethod; +import software.amazon.awssdk.metrics.MetricCollector; + +class EndpointResolutionStageTest { + + private static final URI CLIENT_ENDPOINT = URI.create("https://myservice.us-east-1.amazonaws.com"); + private static final URI RESOLVED_ENDPOINT = URI.create("https://resolved.us-west-2.amazonaws.com"); + + private EndpointResolutionStage stage; + private ExecutionAttributes executionAttributes; + private SdkRequest sdkRequest; + + @BeforeEach + void setup() { + stage = new EndpointResolutionStage(null); + sdkRequest = mock(SdkRequest.class); + executionAttributes = new ExecutionAttributes(); + } + + @Test + void execute_noResolver_returnsRequestUnchanged() throws Exception { + SdkHttpFullRequest.Builder request = defaultRequest(); + RequestExecutionContext context = createContext(); + + SdkHttpFullRequest.Builder result = stage.execute(request, context); + + assertThat(result).isSameAs(request); + } + + @Test + void execute_discoveredEndpoint_returnsRequestUnchanged() throws Exception { + SdkHttpFullRequest.Builder request = defaultRequest(); + executionAttributes.putAttribute(SdkInternalExecutionAttribute.IS_DISCOVERED_ENDPOINT, true); + executionAttributes.putAttribute(SdkInternalExecutionAttribute.ENDPOINT_RESOLVER, + (req, attrs) -> Endpoint.builder().url(RESOLVED_ENDPOINT).build()); + RequestExecutionContext context = createContext(); + + SdkHttpFullRequest.Builder result = stage.execute(request, context); + + assertThat(result).isSameAs(request); + } + + @Test + void execute_happyPath_appliesResolvedUrl() throws Exception { + SdkHttpFullRequest.Builder request = defaultRequest(); + Endpoint endpoint = Endpoint.builder().url(RESOLVED_ENDPOINT).build(); + executionAttributes.putAttribute(SdkInternalExecutionAttribute.ENDPOINT_RESOLVER, (req, attrs) -> endpoint); + executionAttributes.putAttribute(SdkInternalExecutionAttribute.CLIENT_ENDPOINT_PROVIDER, + ClientEndpointProvider.forEndpointOverride(CLIENT_ENDPOINT)); + RequestExecutionContext context = createContext(); + + SdkHttpFullRequest.Builder result = stage.execute(request, context); + + assertThat(result.host()).isEqualTo("resolved.us-west-2.amazonaws.com"); + assertThat(result.protocol()).isEqualTo("https"); + } + + @Test + void execute_happyPath_storesResolvedEndpoint() throws Exception { + SdkHttpFullRequest.Builder request = defaultRequest(); + Endpoint endpoint = Endpoint.builder().url(RESOLVED_ENDPOINT).build(); + executionAttributes.putAttribute(SdkInternalExecutionAttribute.ENDPOINT_RESOLVER, (req, attrs) -> endpoint); + executionAttributes.putAttribute(SdkInternalExecutionAttribute.CLIENT_ENDPOINT_PROVIDER, + ClientEndpointProvider.forEndpointOverride(CLIENT_ENDPOINT)); + RequestExecutionContext context = createContext(); + + stage.execute(request, context); + + assertThat(executionAttributes.getAttribute(SdkInternalExecutionAttribute.RESOLVED_ENDPOINT)).isSameAs(endpoint); + } + + @Test + void execute_withHeaders_appendsToRequest() throws Exception { + Endpoint endpoint = Endpoint.builder().url(RESOLVED_ENDPOINT) + .putHeader("x-amz-custom", "val1") + .putHeader("x-amz-custom", "val2") + .build(); + + SdkHttpFullRequest.Builder request = defaultRequest(); + executionAttributes.putAttribute(SdkInternalExecutionAttribute.ENDPOINT_RESOLVER, (req, attrs) -> endpoint); + executionAttributes.putAttribute(SdkInternalExecutionAttribute.CLIENT_ENDPOINT_PROVIDER, + ClientEndpointProvider.forEndpointOverride(CLIENT_ENDPOINT)); + RequestExecutionContext context = createContext(); + + SdkHttpFullRequest.Builder result = stage.execute(request, context); + + assertThat(result.build().matchingHeaders("x-amz-custom")).containsExactly("val1", "val2"); + } + + @Test + void execute_withMetricCollector_reportsEndpointResolveDuration() throws Exception { + Endpoint endpoint = Endpoint.builder().url(RESOLVED_ENDPOINT).build(); + MetricCollector metricCollector = mock(MetricCollector.class); + + SdkHttpFullRequest.Builder request = defaultRequest(); + executionAttributes.putAttribute(SdkInternalExecutionAttribute.ENDPOINT_RESOLVER, (req, attrs) -> endpoint); + executionAttributes.putAttribute(SdkInternalExecutionAttribute.CLIENT_ENDPOINT_PROVIDER, + ClientEndpointProvider.forEndpointOverride(CLIENT_ENDPOINT)); + executionAttributes.putAttribute(SdkExecutionAttribute.API_CALL_METRIC_COLLECTOR, metricCollector); + RequestExecutionContext context = createContext(); + + stage.execute(request, context); + + ArgumentCaptor durationCaptor = ArgumentCaptor.forClass(Duration.class); + verify(metricCollector).reportMetric(org.mockito.ArgumentMatchers.eq(CoreMetric.ENDPOINT_RESOLVE_DURATION), + durationCaptor.capture()); + assertThat(durationCaptor.getValue()).isGreaterThanOrEqualTo(Duration.ZERO); + } + + @Test + void execute_resolvedPathDiffersFromClientPath_combinesPaths() throws Exception { + URI clientEndpoint = URI.create("https://myservice.amazonaws.com"); + URI resolvedUri = URI.create("https://resolved.amazonaws.com/v2"); + + Endpoint endpoint = Endpoint.builder().url(resolvedUri).build(); + SdkHttpFullRequest.Builder request = SdkHttpFullRequest.builder() + .method(SdkHttpMethod.GET) + .protocol("https") + .host("myservice.amazonaws.com") + .encodedPath("/my/operation"); + executionAttributes.putAttribute(SdkInternalExecutionAttribute.ENDPOINT_RESOLVER, (req, attrs) -> endpoint); + executionAttributes.putAttribute(SdkInternalExecutionAttribute.CLIENT_ENDPOINT_PROVIDER, + ClientEndpointProvider.forEndpointOverride(clientEndpoint)); + RequestExecutionContext context = createContext(); + + SdkHttpFullRequest.Builder result = stage.execute(request, context); + + assertThat(result.encodedPath()).isEqualTo("/v2/my/operation"); + } + + @Test + void execute_resolverReceivesRequestFromInterceptorContext() throws Exception { + SdkRequest modifiedRequest = mock(SdkRequest.class); + SdkRequest[] capturedRequest = new SdkRequest[1]; + + Endpoint endpoint = Endpoint.builder().url(RESOLVED_ENDPOINT).build(); + SdkHttpFullRequest.Builder request = defaultRequest(); + executionAttributes.putAttribute(SdkInternalExecutionAttribute.ENDPOINT_RESOLVER, (req, attrs) -> { + capturedRequest[0] = req; + return endpoint; + }); + executionAttributes.putAttribute(SdkInternalExecutionAttribute.CLIENT_ENDPOINT_PROVIDER, + ClientEndpointProvider.forEndpointOverride(CLIENT_ENDPOINT)); + + // Use modifiedRequest in interceptor context (different from sdkRequest) + RequestExecutionContext context = createContext(modifiedRequest); + + stage.execute(request, context); + + assertThat(capturedRequest[0]).isSameAs(modifiedRequest); + } + + private SdkHttpFullRequest.Builder defaultRequest() { + return SdkHttpFullRequest.builder() + .method(SdkHttpMethod.GET) + .protocol("https") + .host(CLIENT_ENDPOINT.getHost()) + .encodedPath("/"); + } + + private RequestExecutionContext createContext() { + return createContext(sdkRequest); + } + + private RequestExecutionContext createContext(SdkRequest request) { + InterceptorContext interceptorContext = InterceptorContext.builder() + .request(request) + .build(); + ExecutionContext executionContext = ExecutionContext.builder() + .interceptorContext(interceptorContext) + .executionAttributes(executionAttributes) + .build(); + return RequestExecutionContext.builder() + .executionContext(executionContext) + .originalRequest(request) + .build(); + } +} diff --git a/services/dynamodb/src/test/java/software/amazon/awssdk/services/dynamodb/PaginatorInUserAgentTest.java b/services/dynamodb/src/test/java/software/amazon/awssdk/services/dynamodb/PaginatorInUserAgentTest.java index b4bc3a504391..ef4f2b3b9e70 100644 --- a/services/dynamodb/src/test/java/software/amazon/awssdk/services/dynamodb/PaginatorInUserAgentTest.java +++ b/services/dynamodb/src/test/java/software/amazon/awssdk/services/dynamodb/PaginatorInUserAgentTest.java @@ -39,7 +39,7 @@ import software.amazon.awssdk.core.useragent.BusinessMetricCollection; import software.amazon.awssdk.core.useragent.BusinessMetricFeatureId; import software.amazon.awssdk.regions.Region; -import software.amazon.awssdk.services.dynamodb.endpoints.internal.AwsEndpointProviderUtils; +import software.amazon.awssdk.awscore.internal.endpoints.AwsEndpointProviderUtils; import software.amazon.awssdk.services.dynamodb.paginators.QueryPublisher; public class PaginatorInUserAgentTest { diff --git a/services/s3/src/main/java/software/amazon/awssdk/services/s3/S3Utilities.java b/services/s3/src/main/java/software/amazon/awssdk/services/s3/S3Utilities.java index cb3064798a99..a745071d9e62 100644 --- a/services/s3/src/main/java/software/amazon/awssdk/services/s3/S3Utilities.java +++ b/services/s3/src/main/java/software/amazon/awssdk/services/s3/S3Utilities.java @@ -19,11 +19,11 @@ import java.net.MalformedURLException; import java.net.URI; import java.net.URL; -import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Optional; +import java.util.concurrent.CompletionException; import java.util.function.Consumer; import java.util.function.Supplier; import java.util.regex.Matcher; @@ -39,18 +39,18 @@ import software.amazon.awssdk.awscore.endpoint.DualstackEnabledProvider; import software.amazon.awssdk.awscore.endpoint.FipsEnabledProvider; import software.amazon.awssdk.awscore.internal.defaultsmode.DefaultsModeConfiguration; +import software.amazon.awssdk.awscore.internal.endpoints.AwsEndpointProviderUtils; import software.amazon.awssdk.core.ClientEndpointProvider; import software.amazon.awssdk.core.ClientType; import software.amazon.awssdk.core.client.config.ClientOverrideConfiguration; import software.amazon.awssdk.core.client.config.SdkClientConfiguration; import software.amazon.awssdk.core.client.config.SdkClientOption; +import software.amazon.awssdk.core.exception.SdkClientException; import software.amazon.awssdk.core.exception.SdkException; import software.amazon.awssdk.core.interceptor.ExecutionAttributes; -import software.amazon.awssdk.core.interceptor.ExecutionInterceptor; -import software.amazon.awssdk.core.interceptor.ExecutionInterceptorChain; -import software.amazon.awssdk.core.interceptor.InterceptorContext; import software.amazon.awssdk.core.interceptor.SdkExecutionAttribute; import software.amazon.awssdk.core.interceptor.SdkInternalExecutionAttribute; +import software.amazon.awssdk.endpoints.Endpoint; import software.amazon.awssdk.http.SdkHttpFullRequest; import software.amazon.awssdk.http.SdkHttpMethod; import software.amazon.awssdk.http.SdkHttpRequest; @@ -62,9 +62,9 @@ import software.amazon.awssdk.regions.Region; import software.amazon.awssdk.regions.ServiceMetadataAdvancedOption; import software.amazon.awssdk.services.s3.endpoints.S3ClientContextParams; +import software.amazon.awssdk.services.s3.endpoints.S3EndpointParams; import software.amazon.awssdk.services.s3.endpoints.S3EndpointProvider; -import software.amazon.awssdk.services.s3.endpoints.internal.S3RequestSetEndpointInterceptor; -import software.amazon.awssdk.services.s3.endpoints.internal.S3ResolveEndpointInterceptor; +import software.amazon.awssdk.services.s3.endpoints.internal.S3EndpointResolverUtils; import software.amazon.awssdk.services.s3.internal.endpoints.UseGlobalEndpointResolver; import software.amazon.awssdk.services.s3.model.GetObjectRequest; import software.amazon.awssdk.services.s3.model.GetUrlRequest; @@ -110,7 +110,6 @@ public final class S3Utilities { private final Supplier profileFile; private final String profileName; private final boolean fipsEnabled; - private final ExecutionInterceptorChain interceptorChain; private final UseGlobalEndpointResolver useGlobalEndpointResolver; /** @@ -140,8 +139,6 @@ private S3Utilities(Builder builder) { .isFipsEnabled() .orElse(false); - this.interceptorChain = createEndpointInterceptorChain(); - this.useGlobalEndpointResolver = createUseGlobalEndpointResolver(); } @@ -243,14 +240,9 @@ public URL getUrl(GetUrlRequest getUrlRequest) { .versionId(getUrlRequest.versionId()) .build(); - InterceptorContext interceptorContext = InterceptorContext.builder() - .httpRequest(marshalledRequest) - .request(getObjectRequest) - .build(); - ExecutionAttributes executionAttributes = createExecutionAttributes(clientEndpoint, resolvedRegion); - SdkHttpRequest modifiedRequest = runInterceptors(interceptorContext, executionAttributes).httpRequest(); + SdkHttpRequest modifiedRequest = resolveEndpointAndApply(getObjectRequest, marshalledRequest, executionAttributes); try { return modifiedRequest.getUri().toURL(); } catch (MalformedURLException exception) { @@ -506,16 +498,36 @@ private AttributeMap createClientContextParams() { return params.build(); } - private InterceptorContext runInterceptors(InterceptorContext context, ExecutionAttributes executionAttributes) { - context = interceptorChain.modifyRequest(context, executionAttributes); - return interceptorChain.modifyHttpRequestAndHttpContent(context, executionAttributes); - } + private SdkHttpRequest resolveEndpointAndApply(GetObjectRequest request, SdkHttpRequest httpRequest, + ExecutionAttributes executionAttributes) { + S3EndpointParams endpointParams = S3EndpointResolverUtils.ruleParams(request, executionAttributes); + S3EndpointProvider provider = (S3EndpointProvider) executionAttributes + .getAttribute(SdkInternalExecutionAttribute.ENDPOINT_PROVIDER); + + Endpoint endpoint; + try { + endpoint = provider.resolveEndpoint(endpointParams).join(); + } catch (CompletionException e) { + Throwable cause = e.getCause(); + if (cause instanceof SdkClientException) { + throw (SdkClientException) cause; + } + throw SdkClientException.create("Endpoint resolution failed: " + cause.getMessage(), cause); + } + + ClientEndpointProvider clientEndpointProvider = + executionAttributes.getAttribute(SdkInternalExecutionAttribute.CLIENT_ENDPOINT_PROVIDER); + SdkHttpRequest result = AwsEndpointProviderUtils.setUri(httpRequest, + clientEndpointProvider.clientEndpoint(), endpoint.url()); + + if (!endpoint.headers().isEmpty()) { + SdkHttpRequest.Builder builder = result.toBuilder(); + endpoint.headers().forEach((name, values) -> + values.forEach(v -> builder.appendHeader(name, v))); + result = builder.build(); + } - private ExecutionInterceptorChain createEndpointInterceptorChain() { - List interceptors = new ArrayList<>(); - interceptors.add(new S3ResolveEndpointInterceptor()); - interceptors.add(new S3RequestSetEndpointInterceptor()); - return new ExecutionInterceptorChain(interceptors); + return result; } private UseGlobalEndpointResolver createUseGlobalEndpointResolver() { diff --git a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/signing/DefaultS3Presigner.java b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/signing/DefaultS3Presigner.java index 48e09c49a6ba..0aa198a34f48 100644 --- a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/signing/DefaultS3Presigner.java +++ b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/signing/DefaultS3Presigner.java @@ -32,7 +32,9 @@ import java.util.Map; import java.util.Optional; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionException; import java.util.function.Function; +import java.util.stream.Collectors; import java.util.stream.Stream; import software.amazon.awssdk.annotations.SdkInternalApi; import software.amazon.awssdk.auth.signer.AwsSignerExecutionAttribute; @@ -41,19 +43,26 @@ import software.amazon.awssdk.awscore.client.builder.AwsDefaultClientBuilder; import software.amazon.awssdk.awscore.defaultsmode.DefaultsMode; import software.amazon.awssdk.awscore.endpoint.AwsClientEndpointProvider; +import software.amazon.awssdk.awscore.endpoints.AwsEndpointAttribute; +import software.amazon.awssdk.awscore.endpoints.authscheme.EndpointAuthScheme; import software.amazon.awssdk.awscore.internal.AwsExecutionContextBuilder; import software.amazon.awssdk.awscore.internal.defaultsmode.DefaultsModeConfiguration; +import software.amazon.awssdk.awscore.internal.endpoints.AwsEndpointProviderUtils; import software.amazon.awssdk.awscore.presigner.PresignRequest; import software.amazon.awssdk.awscore.presigner.PresignedRequest; +import software.amazon.awssdk.core.ClientEndpointProvider; import software.amazon.awssdk.core.ClientType; import software.amazon.awssdk.core.RequestOverrideConfiguration; import software.amazon.awssdk.core.SdkBytes; +import software.amazon.awssdk.core.SdkClient; import software.amazon.awssdk.core.SdkRequest; import software.amazon.awssdk.core.SelectedAuthScheme; import software.amazon.awssdk.core.client.builder.SdkDefaultClientBuilder; import software.amazon.awssdk.core.client.config.SdkClientConfiguration; import software.amazon.awssdk.core.client.config.SdkClientOption; +import software.amazon.awssdk.core.exception.SdkClientException; import software.amazon.awssdk.core.http.ExecutionContext; +import software.amazon.awssdk.core.identity.SdkIdentityProperty; import software.amazon.awssdk.core.interceptor.ClasspathInterceptorChainFactory; import software.amazon.awssdk.core.interceptor.ExecutionAttributes; import software.amazon.awssdk.core.interceptor.ExecutionInterceptor; @@ -61,9 +70,12 @@ import software.amazon.awssdk.core.interceptor.InterceptorContext; import software.amazon.awssdk.core.interceptor.SdkExecutionAttribute; import software.amazon.awssdk.core.interceptor.SdkInternalExecutionAttribute; +import software.amazon.awssdk.core.internal.http.auth.AuthSchemeResolver; import software.amazon.awssdk.core.signer.Presigner; import software.amazon.awssdk.core.signer.Signer; import software.amazon.awssdk.core.sync.RequestBody; +import software.amazon.awssdk.endpoints.Endpoint; +import software.amazon.awssdk.endpoints.EndpointProvider; import software.amazon.awssdk.http.ContentStreamProvider; import software.amazon.awssdk.http.SdkHttpFullRequest; import software.amazon.awssdk.http.SdkHttpMethod; @@ -71,6 +83,7 @@ import software.amazon.awssdk.http.auth.aws.scheme.AwsV4AuthScheme; import software.amazon.awssdk.http.auth.aws.scheme.AwsV4aAuthScheme; import software.amazon.awssdk.http.auth.aws.signer.AwsV4FamilyHttpSigner; +import software.amazon.awssdk.http.auth.aws.signer.RegionSet; import software.amazon.awssdk.http.auth.spi.scheme.AuthScheme; import software.amazon.awssdk.http.auth.spi.scheme.AuthSchemeOption; import software.amazon.awssdk.http.auth.spi.signer.HttpSigner; @@ -84,12 +97,13 @@ import software.amazon.awssdk.regions.ServiceMetadataAdvancedOption; import software.amazon.awssdk.services.s3.S3Client; import software.amazon.awssdk.services.s3.S3Configuration; +import software.amazon.awssdk.services.s3.auth.scheme.S3AuthSchemeParams; import software.amazon.awssdk.services.s3.auth.scheme.S3AuthSchemeProvider; -import software.amazon.awssdk.services.s3.auth.scheme.internal.S3AuthSchemeInterceptor; +import software.amazon.awssdk.services.s3.auth.scheme.internal.S3EndpointResolverAware; import software.amazon.awssdk.services.s3.endpoints.S3ClientContextParams; +import software.amazon.awssdk.services.s3.endpoints.S3EndpointParams; import software.amazon.awssdk.services.s3.endpoints.S3EndpointProvider; -import software.amazon.awssdk.services.s3.endpoints.internal.S3RequestSetEndpointInterceptor; -import software.amazon.awssdk.services.s3.endpoints.internal.S3ResolveEndpointInterceptor; +import software.amazon.awssdk.services.s3.endpoints.internal.S3EndpointResolverUtils; import software.amazon.awssdk.services.s3.internal.endpoints.UseGlobalEndpointResolver; import software.amazon.awssdk.services.s3.internal.s3express.S3ExpressAuthSchemeProvider; import software.amazon.awssdk.services.s3.model.AbortMultipartUploadRequest; @@ -131,6 +145,7 @@ import software.amazon.awssdk.services.s3.transform.PutObjectRequestMarshaller; import software.amazon.awssdk.services.s3.transform.UploadPartRequestMarshaller; import software.amazon.awssdk.utils.AttributeMap; +import software.amazon.awssdk.utils.CollectionUtils; import software.amazon.awssdk.utils.CompletableFutureUtils; import software.amazon.awssdk.utils.IoUtils; import software.amazon.awssdk.utils.Logger; @@ -239,11 +254,6 @@ private List initializeInterceptors() { ClasspathInterceptorChainFactory interceptorFactory = new ClasspathInterceptorChainFactory(); List s3Interceptors = interceptorFactory.getInterceptors("software/amazon/awssdk/services/s3/execution.interceptors"); - List additionalInterceptors = new ArrayList<>(); - additionalInterceptors.add(new S3AuthSchemeInterceptor()); - additionalInterceptors.add(new S3ResolveEndpointInterceptor()); - additionalInterceptors.add(new S3RequestSetEndpointInterceptor()); - s3Interceptors = mergeLists(s3Interceptors, additionalInterceptors); return mergeLists(interceptorFactory.getGlobalInterceptors(), s3Interceptors); } @@ -405,6 +415,12 @@ private T presign(T presignedRequest, addRequestLevelHeadersAndQueryParameters(execCtx); callModifyHttpRequestHooksAndUpdateContext(execCtx); + // Resolve auth scheme after interceptors complete + resolveAndSelectAuthScheme(execCtx, requestToPresign, operationName); + + // Resolve endpoint + resolveEndpointAndUpdateContext(execCtx, operationName); + SdkHttpFullRequest httpRequest = getHttpFullRequest(execCtx); SdkHttpFullRequest signedHttpRequest = execCtx.signer() != null @@ -594,6 +610,121 @@ private SdkHttpFullRequest getHttpFullRequest(ExecutionContext execCtx) { .build(); } + /** + * Resolve and select auth scheme for presigning. + */ + private void resolveAndSelectAuthScheme(ExecutionContext execCtx, SdkRequest request, String operationName) { + ExecutionAttributes executionAttributes = execCtx.executionAttributes(); + + Map> authSchemes = executionAttributes.getAttribute(SdkInternalExecutionAttribute.AUTH_SCHEMES); + if (authSchemes == null) { + return; + } + + List authOptions = resolveAuthSchemeOptions(request, operationName, executionAttributes); + if (authOptions == null || authOptions.isEmpty()) { + return; + } + + IdentityProviders identityProviders = executionAttributes.getAttribute(SdkInternalExecutionAttribute.IDENTITY_PROVIDERS); + + SelectedAuthScheme selectedAuthScheme = + AuthSchemeResolver.selectAuthScheme(authOptions, authSchemes, identityProviders, null); + + selectedAuthScheme = AuthSchemeResolver.mergePreExistingAuthSchemeProperties(selectedAuthScheme, executionAttributes); + + executionAttributes.putAttribute(SELECTED_AUTH_SCHEME, selectedAuthScheme); + } + + /** + * Resolve the endpoint using the rules engine and apply it to the HTTP request in the execution context. + */ + private void resolveEndpointAndUpdateContext(ExecutionContext execCtx, String operationName) { + ExecutionAttributes executionAttributes = execCtx.executionAttributes(); + SdkRequest sdkRequest = execCtx.interceptorContext().request(); + + S3EndpointParams endpointParams = S3EndpointResolverUtils.ruleParams(sdkRequest, executionAttributes); + S3EndpointProvider provider = (S3EndpointProvider) executionAttributes + .getAttribute(SdkInternalExecutionAttribute.ENDPOINT_PROVIDER); + + Endpoint endpoint; + try { + endpoint = provider.resolveEndpoint(endpointParams).join(); + } catch (CompletionException e) { + Throwable cause = e.getCause(); + if (cause instanceof SdkClientException) { + throw (SdkClientException) cause; + } + throw SdkClientException.create("Endpoint resolution failed: " + cause.getMessage(), cause); + } + + if (!AwsEndpointProviderUtils.disableHostPrefixInjection(executionAttributes)) { + Optional hostPrefix = S3EndpointResolverUtils.hostPrefix(operationName, sdkRequest); + if (hostPrefix.isPresent()) { + endpoint = AwsEndpointProviderUtils.addHostPrefix(endpoint, hostPrefix.get()); + } + } + + List endpointAuthSchemes = endpoint.attribute(AwsEndpointAttribute.AUTH_SCHEMES); + SelectedAuthScheme selectedAuthScheme = executionAttributes.getAttribute(SELECTED_AUTH_SCHEME); + if (endpointAuthSchemes != null && selectedAuthScheme != null) { + selectedAuthScheme = S3EndpointResolverUtils.authSchemeWithEndpointSignerProperties( + endpointAuthSchemes, selectedAuthScheme); + executionAttributes.putAttribute(SELECTED_AUTH_SCHEME, selectedAuthScheme); + } + + SdkHttpRequest httpRequest = execCtx.interceptorContext().httpRequest(); + ClientEndpointProvider clientEndpointProvider = + executionAttributes.getAttribute(SdkInternalExecutionAttribute.CLIENT_ENDPOINT_PROVIDER); + SdkHttpRequest updatedRequest = AwsEndpointProviderUtils.setUri(httpRequest, + clientEndpointProvider.clientEndpoint(), endpoint.url()); + + if (!endpoint.headers().isEmpty()) { + SdkHttpRequest.Builder requestBuilder = updatedRequest.toBuilder(); + endpoint.headers().forEach((name, values) -> + values.forEach(v -> requestBuilder.appendHeader(name, v))); + updatedRequest = requestBuilder.build(); + } + + SdkHttpRequest finalRequest = updatedRequest; + execCtx.interceptorContext(execCtx.interceptorContext().copy(c -> c.httpRequest(finalRequest))); + } + + /** + * Resolve auth scheme options using full endpoint params. + */ + private List resolveAuthSchemeOptions(SdkRequest request, String operationName, + ExecutionAttributes executionAttributes) { + S3AuthSchemeProvider authSchemeProvider = Validate.isInstanceOf(S3AuthSchemeProvider.class, + executionAttributes.getAttribute(SdkInternalExecutionAttribute.AUTH_SCHEME_RESOLVER), + "Expected an instance of S3AuthSchemeProvider"); + + S3EndpointParams endpointParams = S3EndpointResolverUtils.ruleParams(request, executionAttributes); + S3AuthSchemeParams.Builder paramsBuilder = S3AuthSchemeParams.fromEndpointParams(endpointParams) + .operation(operationName); + + executionAttributes.getOptionalAttribute(AwsExecutionAttribute.AWS_SIGV4A_SIGNING_REGION_SET) + .filter(regionSet -> !CollectionUtils.isNullOrEmpty(regionSet)) + .ifPresent(nonEmptyRegionSet -> paramsBuilder.regionSet(RegionSet.create(nonEmptyRegionSet))); + + if (paramsBuilder instanceof S3EndpointResolverAware.Builder) { + EndpointProvider endpointProvider = + executionAttributes.getAttribute(SdkInternalExecutionAttribute.ENDPOINT_PROVIDER); + if (endpointProvider instanceof S3EndpointProvider) { + ((S3EndpointResolverAware.Builder) paramsBuilder).endpointProvider((S3EndpointProvider) endpointProvider); + } + } + + List options = authSchemeProvider.resolveAuthScheme(paramsBuilder.build()); + SdkClient sdkClient = executionAttributes.getAttribute(SdkInternalExecutionAttribute.SDK_CLIENT); + if (sdkClient != null) { + options = options.stream() + .map(o -> o.toBuilder().putIdentityProperty(SdkIdentityProperty.SDK_CLIENT, sdkClient).build()) + .collect(Collectors.toList()); + } + return options; + } + /** * Presign the provided HTTP request using old Signer */ diff --git a/services/s3/src/test/java/software/amazon/awssdk/services/s3/functionaltests/AsyncResponseTransformerTest.java b/services/s3/src/test/java/software/amazon/awssdk/services/s3/functionaltests/AsyncResponseTransformerTest.java index 4c2c5e748c0b..155306ab9faf 100644 --- a/services/s3/src/test/java/software/amazon/awssdk/services/s3/functionaltests/AsyncResponseTransformerTest.java +++ b/services/s3/src/test/java/software/amazon/awssdk/services/s3/functionaltests/AsyncResponseTransformerTest.java @@ -17,6 +17,7 @@ import static org.assertj.core.api.Assertions.fail; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.atLeast; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; @@ -50,6 +51,6 @@ public void AsyncResponseTransformerPrepareCalled_BeforeCredentialsResolution() } verify(asyncResponseTransformer, times(1)).prepare(); - verify(asyncResponseTransformer, times(1)).exceptionOccurred(any()); + verify(asyncResponseTransformer, atLeast(1)).exceptionOccurred(any()); } } diff --git a/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/s3express/S3ExpressAuthSchemeProviderTest.java b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/s3express/S3ExpressAuthSchemeProviderTest.java index 13c7b37ab958..1a176f1e119b 100644 --- a/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/s3express/S3ExpressAuthSchemeProviderTest.java +++ b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/s3express/S3ExpressAuthSchemeProviderTest.java @@ -17,125 +17,48 @@ import static org.assertj.core.api.AssertionsForClassTypes.assertThat; -import java.net.URI; -import java.util.HashMap; -import java.util.Map; -import java.util.concurrent.CompletableFuture; +import java.util.List; import org.junit.jupiter.api.Test; -import software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider; -import software.amazon.awssdk.awscore.AwsExecutionAttribute; -import software.amazon.awssdk.core.ClientEndpointProvider; -import software.amazon.awssdk.core.SelectedAuthScheme; -import software.amazon.awssdk.core.interceptor.ExecutionAttributes; -import software.amazon.awssdk.core.interceptor.SdkExecutionAttribute; -import software.amazon.awssdk.core.interceptor.SdkInternalExecutionAttribute; -import software.amazon.awssdk.http.auth.aws.scheme.AwsV4AuthScheme; -import software.amazon.awssdk.http.auth.spi.scheme.AuthScheme; -import software.amazon.awssdk.http.auth.spi.signer.HttpSigner; -import software.amazon.awssdk.identity.spi.IdentityProvider; -import software.amazon.awssdk.identity.spi.IdentityProviders; -import software.amazon.awssdk.identity.spi.ResolveIdentityRequest; -import software.amazon.awssdk.identity.spi.internal.DefaultIdentityProviders; +import software.amazon.awssdk.http.auth.spi.scheme.AuthSchemeOption; import software.amazon.awssdk.regions.Region; +import software.amazon.awssdk.services.s3.auth.scheme.S3AuthSchemeParams; +import software.amazon.awssdk.services.s3.auth.scheme.S3AuthSchemeProvider; import software.amazon.awssdk.services.s3.auth.scheme.internal.DefaultS3AuthSchemeProvider; -import software.amazon.awssdk.services.s3.auth.scheme.internal.S3AuthSchemeInterceptor; -import software.amazon.awssdk.services.s3.endpoints.S3ClientContextParams; -import software.amazon.awssdk.services.s3.endpoints.S3EndpointProvider; -import software.amazon.awssdk.services.s3.model.PutObjectRequest; -import software.amazon.awssdk.services.s3.s3express.S3ExpressAuthScheme; -import software.amazon.awssdk.services.s3.s3express.S3ExpressSessionCredentials; -import software.amazon.awssdk.utils.AttributeMap; class S3ExpressAuthSchemeProviderTest { private static final String S3EXPRESS_BUCKET = "s3expressformat--use1-az1--x-s3"; @Test - public void s3express_defaultAuthEnabled_returnspressAuthScheme() { - PutObjectRequest request = PutObjectRequest.builder().bucket(S3EXPRESS_BUCKET).key("k").build(); - AttributeMap clientContextParams = AttributeMap.builder().build(); - ExecutionAttributes executionAttributes = requiredExecutionAttributes(clientContextParams); - - new S3AuthSchemeInterceptor().beforeExecution(() -> request, executionAttributes); - - SelectedAuthScheme attribute = executionAttributes.getAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME); - assertThat(attribute).isNotNull(); - verifyAuthScheme("aws.auth#sigv4-s3express", attribute); - } - - private ExecutionAttributes requiredExecutionAttributes(AttributeMap clientContextParams) { - ExecutionAttributes executionAttributes = new ExecutionAttributes(); - executionAttributes.putAttribute(SdkInternalExecutionAttribute.AUTH_SCHEME_RESOLVER, DefaultS3AuthSchemeProvider.create()); - executionAttributes.putAttribute(SdkExecutionAttribute.OPERATION_NAME, "PutObject"); - executionAttributes.putAttribute(AwsExecutionAttribute.AWS_REGION, Region.US_EAST_1); - executionAttributes.putAttribute(SdkInternalExecutionAttribute.CLIENT_CONTEXT_PARAMS, clientContextParams); - executionAttributes.putAttribute(SdkInternalExecutionAttribute.AUTH_SCHEMES, authSchemesWithS3Express()); - executionAttributes.putAttribute(SdkInternalExecutionAttribute.IDENTITY_PROVIDERS, - DefaultIdentityProviders.builder() - .putIdentityProvider(DefaultCredentialsProvider.create()) - .build()); - executionAttributes.putAttribute(SdkInternalExecutionAttribute.CLIENT_ENDPOINT_PROVIDER, - ClientEndpointProvider.forEndpointOverride(URI.create("https://localhost"))); - executionAttributes.putAttribute(SdkInternalExecutionAttribute.ENDPOINT_PROVIDER, - S3EndpointProvider.defaultProvider()); - return executionAttributes; + public void s3express_defaultAuthEnabled_returnsS3ExpressAuthScheme() { + S3AuthSchemeProvider provider = DefaultS3AuthSchemeProvider.create(); + S3AuthSchemeParams params = S3AuthSchemeParams.builder() + .operation("PutObject") + .region(Region.US_EAST_1) + .bucket(S3EXPRESS_BUCKET) + .build(); + + List authOptions = provider.resolveAuthScheme(params); + + assertThat(authOptions).isNotNull(); + assertThat(authOptions.isEmpty()).isFalse(); + assertThat(authOptions.get(0).schemeId()).isEqualTo("aws.auth#sigv4-s3express"); } @Test public void s3express_authDisabled_returnsV4AuthScheme() { - PutObjectRequest request = PutObjectRequest.builder().bucket(S3EXPRESS_BUCKET).key("k").build(); - AttributeMap clientContextParams = AttributeMap.builder() - .put(S3ClientContextParams.DISABLE_S3_EXPRESS_SESSION_AUTH, true) - .build(); - ExecutionAttributes executionAttributes = requiredExecutionAttributes(clientContextParams); - - new S3AuthSchemeInterceptor().beforeExecution(() -> request, executionAttributes); - - SelectedAuthScheme attribute = executionAttributes.getAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME); - assertThat(attribute).isNotNull(); - verifyAuthScheme("aws.auth#sigv4", attribute); - } - - private void verifyAuthScheme(String expectedAuthSchemeId, SelectedAuthScheme authScheme) { - assertThat(authScheme).isNotNull(); - assertThat(authScheme.authSchemeOption()).isNotNull(); - assertThat(authScheme.authSchemeOption().schemeId()).isEqualTo(expectedAuthSchemeId); - - assertThat(authScheme.identity()).isNotNull(); - assertThat(authScheme.signer()).isNotNull(); - } - - private Map> authSchemesWithS3Express() { - Map> schemes = new HashMap<>(); - AwsV4AuthScheme awsV4AuthScheme = AwsV4AuthScheme.create(); - schemes.put(awsV4AuthScheme.schemeId(), awsV4AuthScheme); - S3ExpressAuthScheme s3ExpressAuthScheme = new S3ExpressAuthScheme() { - @Override - public IdentityProvider identityProvider(IdentityProviders providers) { - return new IdentityProvider() { - @Override - public Class identityType() { - return null; - } - - @Override - public CompletableFuture resolveIdentity(ResolveIdentityRequest request) { - return CompletableFuture.completedFuture(S3ExpressSessionCredentials.create("a","b","c")); - } - }; - } - - @Override - public HttpSigner signer() { - return DefaultS3ExpressHttpSigner.create(); - } - - @Override - public String schemeId() { - return SCHEME_ID; - } - }; - schemes.put(s3ExpressAuthScheme.schemeId(), s3ExpressAuthScheme); - return schemes; + S3AuthSchemeProvider provider = DefaultS3AuthSchemeProvider.create(); + S3AuthSchemeParams params = S3AuthSchemeParams.builder() + .operation("PutObject") + .region(Region.US_EAST_1) + .bucket(S3EXPRESS_BUCKET) + .disableS3ExpressSessionAuth(true) + .build(); + + List authOptions = provider.resolveAuthScheme(params); + + assertThat(authOptions).isNotNull(); + assertThat(authOptions.isEmpty()).isFalse(); + assertThat(authOptions.get(0).schemeId()).isEqualTo("aws.auth#sigv4"); } -} \ No newline at end of file +} diff --git a/test/codegen-generated-classes-test/src/test/java/software/amazon/awssdk/services/AuthSchemeInterceptorTest.java b/test/codegen-generated-classes-test/src/test/java/software/amazon/awssdk/services/AuthSchemeInterceptorTest.java deleted file mode 100644 index e187ab4435d5..000000000000 --- a/test/codegen-generated-classes-test/src/test/java/software/amazon/awssdk/services/AuthSchemeInterceptorTest.java +++ /dev/null @@ -1,108 +0,0 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). - * You may not use this file except in compliance with the License. - * A copy of the License is located at - * - * http://aws.amazon.com/apache2.0 - * - * or in the "license" file accompanying this file. This file is distributed - * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either - * express or implied. See the License for the specific language governing - * permissions and limitations under the License. - */ - -package software.amazon.awssdk.services; - -import static org.assertj.core.api.Assertions.assertThat; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.spy; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; - -import java.util.Arrays; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; -import software.amazon.awssdk.auth.credentials.AnonymousCredentialsProvider; -import software.amazon.awssdk.core.SelectedAuthScheme; -import software.amazon.awssdk.core.interceptor.Context; -import software.amazon.awssdk.core.interceptor.ExecutionAttributes; -import software.amazon.awssdk.core.interceptor.SdkExecutionAttribute; -import software.amazon.awssdk.core.interceptor.SdkInternalExecutionAttribute; -import software.amazon.awssdk.http.auth.aws.scheme.AwsV4AuthScheme; -import software.amazon.awssdk.http.auth.spi.scheme.AuthScheme; -import software.amazon.awssdk.http.auth.spi.scheme.AuthSchemeOption; -import software.amazon.awssdk.http.auth.spi.signer.HttpSigner; -import software.amazon.awssdk.identity.spi.AwsCredentialsIdentity; -import software.amazon.awssdk.identity.spi.IdentityProvider; -import software.amazon.awssdk.identity.spi.IdentityProviders; -import software.amazon.awssdk.services.protocolrestjson.auth.scheme.ProtocolRestJsonAuthSchemeParams; -import software.amazon.awssdk.services.protocolrestjson.auth.scheme.ProtocolRestJsonAuthSchemeProvider; -import software.amazon.awssdk.services.protocolrestjson.auth.scheme.internal.ProtocolRestJsonAuthSchemeInterceptor; - -public class AuthSchemeInterceptorTest { - private static final ProtocolRestJsonAuthSchemeInterceptor INTERCEPTOR = new ProtocolRestJsonAuthSchemeInterceptor(); - - private Context.BeforeExecution mockContext; - - @BeforeEach - public void setup() { - mockContext = mock(Context.BeforeExecution.class); - } - - @Test - public void resolveAuthScheme_authSchemeSignerThrows_continuesToNextAuthScheme() { - ProtocolRestJsonAuthSchemeProvider mockAuthSchemeProvider = mock(ProtocolRestJsonAuthSchemeProvider.class); - List authSchemeOptions = Arrays.asList( - AuthSchemeOption.builder().schemeId(TestAuthScheme.SCHEME_ID).build(), - AuthSchemeOption.builder().schemeId(AwsV4AuthScheme.SCHEME_ID).build() - ); - when(mockAuthSchemeProvider.resolveAuthScheme(any(ProtocolRestJsonAuthSchemeParams.class))).thenReturn(authSchemeOptions); - - IdentityProviders mockIdentityProviders = mock(IdentityProviders.class); - when(mockIdentityProviders.identityProvider(any(Class.class))).thenReturn(AnonymousCredentialsProvider.create()); - - Map> authSchemes = new HashMap<>(); - authSchemes.put(AwsV4AuthScheme.SCHEME_ID, AwsV4AuthScheme.create()); - - TestAuthScheme notProvidedAuthScheme = spy(new TestAuthScheme()); - authSchemes.put(TestAuthScheme.SCHEME_ID, notProvidedAuthScheme); - - ExecutionAttributes attributes = new ExecutionAttributes(); - attributes.putAttribute(SdkExecutionAttribute.OPERATION_NAME, "GetFoo"); - attributes.putAttribute(SdkInternalExecutionAttribute.AUTH_SCHEME_RESOLVER, mockAuthSchemeProvider); - attributes.putAttribute(SdkInternalExecutionAttribute.IDENTITY_PROVIDERS, mockIdentityProviders); - attributes.putAttribute(SdkInternalExecutionAttribute.AUTH_SCHEMES, authSchemes); - - INTERCEPTOR.beforeExecution(mockContext, attributes); - - SelectedAuthScheme selectedAuthScheme = attributes.getAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME); - - verify(notProvidedAuthScheme).signer(); - assertThat(selectedAuthScheme.authSchemeOption().schemeId()).isEqualTo(AwsV4AuthScheme.SCHEME_ID); - } - - private static class TestAuthScheme implements AuthScheme { - public static final String SCHEME_ID = "codegen-test-scheme"; - - @Override - public String schemeId() { - return SCHEME_ID; - } - - @Override - public IdentityProvider identityProvider(IdentityProviders providers) { - return providers.identityProvider(AwsCredentialsIdentity.class); - } - - @Override - public HttpSigner signer() { - throw new RuntimeException("Not on classpath"); - } - } -} diff --git a/test/codegen-generated-classes-test/src/test/java/software/amazon/awssdk/services/IdentityResolutionOverrideTest.java b/test/codegen-generated-classes-test/src/test/java/software/amazon/awssdk/services/IdentityResolutionOverrideTest.java index f20c84a25756..daf73cd838e3 100644 --- a/test/codegen-generated-classes-test/src/test/java/software/amazon/awssdk/services/IdentityResolutionOverrideTest.java +++ b/test/codegen-generated-classes-test/src/test/java/software/amazon/awssdk/services/IdentityResolutionOverrideTest.java @@ -71,9 +71,8 @@ void when_apiCall_setsCredentialsProviderInRequestOverride_overrideCredentialsAr assertSelectedAuthSchemeBeforeTransmissionContains(OVERRIDE_CREDENTIALS); } - // Changing the credentials provider in modifyRequest does not work in SRA identity resolution - // Identity is resolved in beforeExecution (and happens before user applied interceptors) and cannot - // be affected by execution interceptors. + // After moving identity resolution to pipeline stage (after interceptors), credentials provider + // set in modifyRequest is now respected (identity resolved after interceptors complete) @Test void when_executionInterceptorModifyRequest_setsCredentialProviderInRequestOverride_clientCredentialsAreUsed() { ExecutionInterceptor overridingInterceptor = @@ -83,7 +82,7 @@ void when_executionInterceptorModifyRequest_setsCredentialProviderInRequestOverr assertThatThrownBy(() -> syncClient.allTypes(r -> {})).hasMessageContaining("stop"); - assertSelectedAuthSchemeBeforeTransmissionContains(CLIENT_CREDENTIALS); + assertSelectedAuthSchemeBeforeTransmissionContains(OVERRIDE_CREDENTIALS); } @Test diff --git a/test/codegen-generated-classes-test/src/test/java/software/amazon/awssdk/services/endpointproviders/AwsEndpointProviderUtilsTest.java b/test/codegen-generated-classes-test/src/test/java/software/amazon/awssdk/services/endpointproviders/AwsEndpointProviderUtilsTest.java index 83d9d1d0fb59..25878e3cd9e2 100644 --- a/test/codegen-generated-classes-test/src/test/java/software/amazon/awssdk/services/endpointproviders/AwsEndpointProviderUtilsTest.java +++ b/test/codegen-generated-classes-test/src/test/java/software/amazon/awssdk/services/endpointproviders/AwsEndpointProviderUtilsTest.java @@ -29,7 +29,7 @@ import software.amazon.awssdk.http.SdkHttpMethod; import software.amazon.awssdk.http.SdkHttpRequest; import software.amazon.awssdk.regions.Region; -import software.amazon.awssdk.services.restjsonendpointproviders.endpoints.internal.AwsEndpointProviderUtils; +import software.amazon.awssdk.awscore.internal.endpoints.AwsEndpointProviderUtils; public class AwsEndpointProviderUtilsTest { @Test diff --git a/test/codegen-generated-classes-test/src/test/java/software/amazon/awssdk/services/endpointproviders/EndpointInterceptorTests.java b/test/codegen-generated-classes-test/src/test/java/software/amazon/awssdk/services/endpointproviders/EndpointInterceptorTests.java index d24456368fbd..df32b88193b2 100644 --- a/test/codegen-generated-classes-test/src/test/java/software/amazon/awssdk/services/endpointproviders/EndpointInterceptorTests.java +++ b/test/codegen-generated-classes-test/src/test/java/software/amazon/awssdk/services/endpointproviders/EndpointInterceptorTests.java @@ -177,7 +177,7 @@ public void sync_endpointProviderReturnsHeaders_appendedToExistingRequest() { "TestValue0")))) .hasMessageContaining("stop"); - assertThat(interceptor.context.httpRequest().matchingHeaders("TestHeader")).containsExactly("TestValue", "TestValue0"); + assertThat(interceptor.context.httpRequest().matchingHeaders("TestHeader")).containsExactly("TestValue0", "TestValue"); } @Test @@ -198,7 +198,7 @@ public void async_endpointProviderReturnsHeaders_appendedToExistingRequest() { .join()) .hasMessageContaining("stop"); - assertThat(interceptor.context.httpRequest().matchingHeaders("TestHeader")).containsExactly("TestValue", "TestValue0"); + assertThat(interceptor.context.httpRequest().matchingHeaders("TestHeader")).containsExactly("TestValue0", "TestValue"); } diff --git a/test/codegen-generated-classes-test/src/test/java/software/amazon/awssdk/services/endpointproviders/RequestOverrideEndpointProviderTests.java b/test/codegen-generated-classes-test/src/test/java/software/amazon/awssdk/services/endpointproviders/RequestOverrideEndpointProviderTests.java index 075976d136c8..07dd11000688 100644 --- a/test/codegen-generated-classes-test/src/test/java/software/amazon/awssdk/services/endpointproviders/RequestOverrideEndpointProviderTests.java +++ b/test/codegen-generated-classes-test/src/test/java/software/amazon/awssdk/services/endpointproviders/RequestOverrideEndpointProviderTests.java @@ -18,16 +18,12 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; -import java.net.URI; -import java.util.concurrent.CompletableFuture; -import org.junit.Test; import software.amazon.awssdk.auth.credentials.AwsBasicCredentials; import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider; import software.amazon.awssdk.core.client.config.SdkAdvancedClientOption; import software.amazon.awssdk.core.interceptor.Context; import software.amazon.awssdk.core.interceptor.ExecutionAttributes; import software.amazon.awssdk.core.interceptor.ExecutionInterceptor; -import software.amazon.awssdk.core.interceptor.SdkInternalExecutionAttribute; import software.amazon.awssdk.endpoints.Endpoint; import software.amazon.awssdk.http.SdkHttpRequest; import software.amazon.awssdk.regions.Region; @@ -37,6 +33,7 @@ import software.amazon.awssdk.services.restjsonendpointproviders.RestJsonEndpointProvidersClientBuilder; import software.amazon.awssdk.services.restjsonendpointproviders.endpoints.RestJsonEndpointProvidersEndpointParams; import software.amazon.awssdk.services.restjsonendpointproviders.endpoints.RestJsonEndpointProvidersEndpointProvider; +import org.junit.Test; public class RequestOverrideEndpointProviderTests { @@ -49,34 +46,28 @@ public void sync_endpointOverridden_equals_requestOverride() { .build(); assertThatThrownBy(() -> client.operationWithHostPrefix( - r -> r.overrideConfiguration(o -> o.endpointProvider(new CustomEndpointProvider(Region.AWS_GLOBAL)))) - - ) + r -> r.overrideConfiguration(o -> o.endpointProvider(new CustomEndpointProvider(Region.AWS_GLOBAL))))) .hasMessageContaining("stop"); - Endpoint endpoint = interceptor.executionAttributes().getAttribute(SdkInternalExecutionAttribute.RESOLVED_ENDPOINT); - assertThat(endpoint.url().getHost()).isEqualTo("restjson.aws-global.amazonaws.com"); + + assertThat(interceptor.httpRequest.host()).isEqualTo("restjson.aws-global.amazonaws.com"); } @Test public void sync_endpointOverridden_equals_ClientsWhenNoRequestOverride() { CapturingInterceptor interceptor = new CapturingInterceptor(); - RestJsonEndpointProvidersClient client = syncClientBuilder().endpointProvider(new CustomEndpointProvider(Region.EU_WEST_2)) + RestJsonEndpointProvidersClient client = syncClientBuilder() + .endpointProvider(new CustomEndpointProvider(Region.EU_WEST_2)) .overrideConfiguration(o -> o.addExecutionInterceptor(interceptor) .putAdvancedOption(SdkAdvancedClientOption.DISABLE_HOST_PREFIX_INJECTION, true)) .build(); assertThatThrownBy(() -> client.operationWithHostPrefix(r -> {})).hasMessageContaining("stop"); - Endpoint endpoint = interceptor.executionAttributes().getAttribute(SdkInternalExecutionAttribute.RESOLVED_ENDPOINT); - - assertThat(endpoint.url().getHost()).isEqualTo("restjson.eu-west-2.amazonaws.com"); + assertThat(interceptor.httpRequest.host()).isEqualTo("restjson.eu-west-2.amazonaws.com"); } - @Test public void async_endpointOverridden_equals_requestOverride() { - - CapturingInterceptor interceptor = new CapturingInterceptor(); RestJsonEndpointProvidersAsyncClient client = asyncClientBuilder() .overrideConfiguration(o -> o.addExecutionInterceptor(interceptor) @@ -88,18 +79,14 @@ public void async_endpointOverridden_equals_requestOverride() { ).join()) .hasMessageContaining("stop"); - Endpoint endpoint = interceptor.executionAttributes().getAttribute(SdkInternalExecutionAttribute.RESOLVED_ENDPOINT); - - assertThat(endpoint.url().getHost()).isEqualTo("restjson.aws-global.amazonaws.com"); + assertThat(interceptor.httpRequest.host()).isEqualTo("restjson.aws-global.amazonaws.com"); } - @Test public void async_endpointOverridden_equals_ClientsWhenNoRequestOverride() { - - CapturingInterceptor interceptor = new CapturingInterceptor(); - RestJsonEndpointProvidersAsyncClient client = asyncClientBuilder().endpointProvider(new CustomEndpointProvider(Region.EU_WEST_2)) + RestJsonEndpointProvidersAsyncClient client = asyncClientBuilder() + .endpointProvider(new CustomEndpointProvider(Region.EU_WEST_2)) .overrideConfiguration(o -> o.addExecutionInterceptor(interceptor) .putAdvancedOption(SdkAdvancedClientOption.DISABLE_HOST_PREFIX_INJECTION, true)) .build(); @@ -107,31 +94,25 @@ public void async_endpointOverridden_equals_ClientsWhenNoRequestOverride() { assertThatThrownBy(() -> client.operationWithHostPrefix(r -> {}).join()) .hasMessageContaining("stop"); - Endpoint endpoint = interceptor.executionAttributes().getAttribute(SdkInternalExecutionAttribute.RESOLVED_ENDPOINT); - - assertThat(endpoint.url().getHost()).isEqualTo("restjson.eu-west-2.amazonaws.com"); + assertThat(interceptor.httpRequest.host()).isEqualTo("restjson.eu-west-2.amazonaws.com"); } public static class CapturingInterceptor implements ExecutionInterceptor { - - private ExecutionAttributes executionAttributes; + private SdkHttpRequest httpRequest; @Override - public SdkHttpRequest modifyHttpRequest(Context.ModifyHttpRequest context, ExecutionAttributes executionAttributes) { - this.executionAttributes = executionAttributes; + public void beforeTransmission(Context.BeforeTransmission context, ExecutionAttributes executionAttributes) { + this.httpRequest = context.httpRequest(); throw new CaptureCompletedException("stop"); } - public ExecutionAttributes executionAttributes() { - return executionAttributes; - } - public class CaptureCompletedException extends RuntimeException { CaptureCompletedException(String message) { super(message); } } } + private RestJsonEndpointProvidersClientBuilder syncClientBuilder() { return RestJsonEndpointProvidersClient.builder() .region(Region.US_WEST_2) @@ -142,23 +123,23 @@ private RestJsonEndpointProvidersClientBuilder syncClientBuilder() { private RestJsonEndpointProvidersAsyncClientBuilder asyncClientBuilder() { return RestJsonEndpointProvidersAsyncClient.builder() - .region(Region.US_WEST_2) - .credentialsProvider( - StaticCredentialsProvider.create( - AwsBasicCredentials.create("akid", "skid"))); + .region(Region.US_WEST_2) + .credentialsProvider( + StaticCredentialsProvider.create( + AwsBasicCredentials.create("akid", "skid"))); } + public static class CustomEndpointProvider implements RestJsonEndpointProvidersEndpointProvider { + private final Region region; - class CustomEndpointProvider implements RestJsonEndpointProvidersEndpointProvider{ - final RestJsonEndpointProvidersEndpointProvider endpointProvider; - final Region overridingRegion; - CustomEndpointProvider (Region region){ - this.endpointProvider = RestJsonEndpointProvidersEndpointProvider.defaultProvider(); - this.overridingRegion = region; + CustomEndpointProvider(Region region) { + this.region = region; } + @Override - public CompletableFuture resolveEndpoint(RestJsonEndpointProvidersEndpointParams endpointParams) { - return endpointProvider.resolveEndpoint(endpointParams.toBuilder().region(overridingRegion).build()); + public java.util.concurrent.CompletableFuture resolveEndpoint(RestJsonEndpointProvidersEndpointParams params) { + return RestJsonEndpointProvidersEndpointProvider.defaultProvider() + .resolveEndpoint(params.toBuilder().region(region).build()); } } }