diff --git a/codegen/src/main/java/software/amazon/awssdk/codegen/emitters/tasks/AuthSchemeGeneratorTasks.java b/codegen/src/main/java/software/amazon/awssdk/codegen/emitters/tasks/AuthSchemeGeneratorTasks.java index 4b6719cc709c..0794a165181c 100644 --- a/codegen/src/main/java/software/amazon/awssdk/codegen/emitters/tasks/AuthSchemeGeneratorTasks.java +++ b/codegen/src/main/java/software/amazon/awssdk/codegen/emitters/tasks/AuthSchemeGeneratorTasks.java @@ -20,7 +20,6 @@ import software.amazon.awssdk.codegen.emitters.GeneratorTask; import software.amazon.awssdk.codegen.emitters.GeneratorTaskParams; import software.amazon.awssdk.codegen.emitters.PoetGeneratorTask; -import software.amazon.awssdk.codegen.poet.auth.scheme.AuthSchemeInterceptorSpec; import software.amazon.awssdk.codegen.poet.auth.scheme.AuthSchemeParamsSpec; import software.amazon.awssdk.codegen.poet.auth.scheme.AuthSchemeProviderSpec; import software.amazon.awssdk.codegen.poet.auth.scheme.AuthSchemeSpecUtils; @@ -47,7 +46,6 @@ protected List createTasks() { tasks.add(generateDefaultParamsImpl()); tasks.add(generateModelBasedProvider()); tasks.add(generatePreferenceProvider()); - tasks.add(generateAuthSchemeInterceptor()); if (authSchemeSpecUtils.useEndpointBasedAuthProvider()) { tasks.add(generateEndpointBasedProvider()); tasks.add(generateEndpointAwareAuthSchemeParams()); @@ -86,10 +84,6 @@ private GeneratorTask generateEndpointAwareAuthSchemeParams() { } - private GeneratorTask generateAuthSchemeInterceptor() { - return new PoetGeneratorTask(authSchemeInternalDir(), model.getFileHeader(), new AuthSchemeInterceptorSpec(model)); - } - private String authSchemeDir() { return generatorTaskParams.getPathProvider().getAuthSchemeDirectory(); } diff --git a/codegen/src/main/java/software/amazon/awssdk/codegen/emitters/tasks/EndpointProviderTasks.java b/codegen/src/main/java/software/amazon/awssdk/codegen/emitters/tasks/EndpointProviderTasks.java index 7f5d64da8e79..b72a33263c5b 100644 --- a/codegen/src/main/java/software/amazon/awssdk/codegen/emitters/tasks/EndpointProviderTasks.java +++ b/codegen/src/main/java/software/amazon/awssdk/codegen/emitters/tasks/EndpointProviderTasks.java @@ -32,9 +32,8 @@ import software.amazon.awssdk.codegen.poet.rules.EndpointProviderInterfaceSpec; import software.amazon.awssdk.codegen.poet.rules.EndpointProviderSpec; import software.amazon.awssdk.codegen.poet.rules.EndpointProviderTestSpec; -import software.amazon.awssdk.codegen.poet.rules.EndpointResolverInterceptorSpec; +import software.amazon.awssdk.codegen.poet.rules.EndpointResolverUtilsSpec; import software.amazon.awssdk.codegen.poet.rules.EndpointRulesClientTestSpec; -import software.amazon.awssdk.codegen.poet.rules.RequestEndpointInterceptorSpec; import software.amazon.awssdk.codegen.poet.rules2.EndpointProviderSpec2; public final class EndpointProviderTasks extends BaseGeneratorTasks { @@ -103,8 +102,7 @@ private boolean shouldGenerateCompiledEndpointRules() { private Collection generateInterceptors() { return Arrays.asList( - new PoetGeneratorTask(endpointRulesInternalDir(), model.getFileHeader(), new EndpointResolverInterceptorSpec(model)), - new PoetGeneratorTask(endpointRulesInternalDir(), model.getFileHeader(), new RequestEndpointInterceptorSpec(model))); + new PoetGeneratorTask(endpointRulesInternalDir(), model.getFileHeader(), new EndpointResolverUtilsSpec(model))); } private GeneratorTask generateClientTests() { diff --git a/codegen/src/main/java/software/amazon/awssdk/codegen/poet/auth/scheme/AuthSchemeInterceptorSpec.java b/codegen/src/main/java/software/amazon/awssdk/codegen/poet/auth/scheme/AuthSchemeInterceptorSpec.java deleted file mode 100644 index c4f9e768eaa3..000000000000 --- a/codegen/src/main/java/software/amazon/awssdk/codegen/poet/auth/scheme/AuthSchemeInterceptorSpec.java +++ /dev/null @@ -1,524 +0,0 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). - * You may not use this file except in compliance with the License. - * A copy of the License is located at - * - * http://aws.amazon.com/apache2.0 - * - * or in the "license" file accompanying this file. This file is distributed - * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either - * express or implied. See the License for the specific language governing - * permissions and limitations under the License. - */ - -package software.amazon.awssdk.codegen.poet.auth.scheme; - -import com.squareup.javapoet.ClassName; -import com.squareup.javapoet.FieldSpec; -import com.squareup.javapoet.MethodSpec; -import com.squareup.javapoet.ParameterSpec; -import com.squareup.javapoet.ParameterizedTypeName; -import com.squareup.javapoet.TypeName; -import com.squareup.javapoet.TypeSpec; -import com.squareup.javapoet.TypeVariableName; -import com.squareup.javapoet.WildcardTypeName; -import java.time.Duration; -import java.util.ArrayList; -import java.util.List; -import java.util.Map; -import java.util.concurrent.CompletableFuture; -import java.util.function.Supplier; -import java.util.stream.Collectors; -import javax.lang.model.element.Modifier; -import software.amazon.awssdk.annotations.SdkInternalApi; -import software.amazon.awssdk.awscore.AwsExecutionAttribute; -import software.amazon.awssdk.codegen.model.intermediate.IntermediateModel; -import software.amazon.awssdk.codegen.poet.ClassSpec; -import software.amazon.awssdk.codegen.poet.PoetUtils; -import software.amazon.awssdk.codegen.poet.rules.EndpointRulesSpecUtils; -import software.amazon.awssdk.core.SdkRequest; -import software.amazon.awssdk.core.SelectedAuthScheme; -import software.amazon.awssdk.core.exception.SdkException; -import software.amazon.awssdk.core.identity.SdkIdentityProperty; -import software.amazon.awssdk.core.interceptor.Context; -import software.amazon.awssdk.core.interceptor.ExecutionAttributes; -import software.amazon.awssdk.core.interceptor.ExecutionInterceptor; -import software.amazon.awssdk.core.interceptor.SdkExecutionAttribute; -import software.amazon.awssdk.core.interceptor.SdkInternalExecutionAttribute; -import software.amazon.awssdk.core.internal.util.MetricUtils; -import software.amazon.awssdk.core.metrics.CoreMetric; -import software.amazon.awssdk.core.useragent.BusinessMetricFeatureId; -import software.amazon.awssdk.endpoints.EndpointProvider; -import software.amazon.awssdk.http.auth.aws.scheme.AwsV4aAuthScheme; -import software.amazon.awssdk.http.auth.aws.signer.RegionSet; -import software.amazon.awssdk.http.auth.scheme.BearerAuthScheme; -import software.amazon.awssdk.http.auth.spi.scheme.AuthScheme; -import software.amazon.awssdk.http.auth.spi.scheme.AuthSchemeOption; -import software.amazon.awssdk.http.auth.spi.signer.HttpSigner; -import software.amazon.awssdk.identity.spi.AwsCredentialsIdentity; -import software.amazon.awssdk.identity.spi.Identity; -import software.amazon.awssdk.identity.spi.IdentityProvider; -import software.amazon.awssdk.identity.spi.IdentityProviders; -import software.amazon.awssdk.identity.spi.ResolveIdentityRequest; -import software.amazon.awssdk.identity.spi.TokenIdentity; -import software.amazon.awssdk.metrics.MetricCollector; -import software.amazon.awssdk.metrics.SdkMetric; -import software.amazon.awssdk.regions.Region; -import software.amazon.awssdk.utils.CollectionUtils; -import software.amazon.awssdk.utils.Logger; -import software.amazon.awssdk.utils.Validate; - -public final class AuthSchemeInterceptorSpec implements ClassSpec { - private final AuthSchemeSpecUtils authSchemeSpecUtils; - private final EndpointRulesSpecUtils endpointRulesSpecUtils; - private final IntermediateModel intermediateModel; - - public AuthSchemeInterceptorSpec(IntermediateModel intermediateModel) { - this.intermediateModel = intermediateModel; - this.authSchemeSpecUtils = new AuthSchemeSpecUtils(intermediateModel); - this.endpointRulesSpecUtils = new EndpointRulesSpecUtils(intermediateModel); - } - - @Override - public ClassName className() { - return authSchemeSpecUtils.authSchemeInterceptor(); - } - - @Override - public TypeSpec poetSpec() { - TypeSpec.Builder builder = PoetUtils.createClassBuilder(className()) - .addSuperinterface(ExecutionInterceptor.class) - .addModifiers(Modifier.PUBLIC, Modifier.FINAL) - .addAnnotation(SdkInternalApi.class); - - builder.addField(FieldSpec.builder(Logger.class, "LOG", Modifier.PRIVATE, Modifier.STATIC) - .initializer("$T.loggerFor($T.class)", Logger.class, className()) - .build()); - - builder.addMethod(generateBeforeExecution()) - .addMethod(generateResolveAuthOptions()) - .addMethod(generateSelectAuthScheme()) - .addMethod(generateAuthSchemeParams()) - .addMethod(generateTrySelectAuthScheme()) - .addMethod(generateGetIdentityMetric()) - .addMethod(putSelectedAuthSchemeMethodSpec()); - if (intermediateModel.getCustomizationConfig().isEnableEnvironmentBearerToken()) { - builder.addMethod(generateEnvironmentTokenMetric()); - } - return builder.build(); - } - - private MethodSpec generateEnvironmentTokenMetric() { - return MethodSpec - .methodBuilder("recordEnvironmentTokenBusinessMetric") - .addModifiers(Modifier.PRIVATE) - .addTypeVariable(TypeVariableName.get("T", Identity.class)) - .addParameter(ParameterSpec.builder( - ParameterizedTypeName.get(ClassName.get(SelectedAuthScheme.class), - TypeVariableName.get("T")), - "selectedAuthScheme").build()) - .addParameter(ExecutionAttributes.class, "executionAttributes") - .addStatement("$T tokenFromEnv = executionAttributes.getAttribute($T.TOKEN_CONFIGURED_FROM_ENV)", - String.class, SdkInternalExecutionAttribute.class) - .beginControlFlow("if (selectedAuthScheme != null && selectedAuthScheme.authSchemeOption().schemeId().equals($T" - + ".SCHEME_ID) && selectedAuthScheme.identity().isDone())", BearerAuthScheme.class) - .beginControlFlow("if (selectedAuthScheme.identity().getNow(null) instanceof $T)", TokenIdentity.class) - - .addStatement("$T configuredToken = ($T) selectedAuthScheme.identity().getNow(null)", - TokenIdentity.class, TokenIdentity.class) - .beginControlFlow("if (configuredToken.token().equals(tokenFromEnv))") - .addStatement("executionAttributes.getAttribute($T.BUSINESS_METRICS)" - + ".addMetric($T.BEARER_SERVICE_ENV_VARS.value())", - SdkInternalExecutionAttribute.class, BusinessMetricFeatureId.class) - .endControlFlow() - .endControlFlow() - .endControlFlow() - .build(); - - - } - - private MethodSpec generateBeforeExecution() { - MethodSpec.Builder builder = MethodSpec.methodBuilder("beforeExecution") - .addAnnotation(Override.class) - .addModifiers(Modifier.PUBLIC) - .addParameter(Context.BeforeExecution.class, - "context") - .addParameter(ExecutionAttributes.class, - "executionAttributes"); - - builder.addStatement("$T authOptions = resolveAuthOptions(context, executionAttributes)", - listOf(AuthSchemeOption.class)) - .addStatement("$T selectedAuthScheme = selectAuthScheme(authOptions, executionAttributes)", - wildcardSelectedAuthScheme()) - .addStatement("putSelectedAuthScheme(executionAttributes, selectedAuthScheme)"); - - if (intermediateModel.getCustomizationConfig().isEnableEnvironmentBearerToken()) { - builder.addStatement("recordEnvironmentTokenBusinessMetric(selectedAuthScheme, " - + "executionAttributes)"); - } - - if (authSchemeSpecUtils.hasSigV4aSupport()) { - builder.beginControlFlow("if (selectedAuthScheme != null && " - + "selectedAuthScheme.authSchemeOption().schemeId().equals($T.SCHEME_ID) && " - + "!$T.isSignerOverridden(context.request(), executionAttributes))", - AwsV4aAuthScheme.class, - ClassName.get("software.amazon.awssdk.awscore.util", "SignerOverrideUtils")) - .addStatement("$T businessMetrics = executionAttributes.getAttribute($T.BUSINESS_METRICS)", - ClassName.get("software.amazon.awssdk.core.useragent", "BusinessMetricCollection"), - SdkInternalExecutionAttribute.class) - .beginControlFlow("if (businessMetrics != null)") - .addStatement("businessMetrics.addMetric($T.SIGV4A_SIGNING.value())", - BusinessMetricFeatureId.class) - .endControlFlow() - .endControlFlow(); - } - return builder.build(); - } - - private MethodSpec generateResolveAuthOptions() { - MethodSpec.Builder builder = MethodSpec.methodBuilder("resolveAuthOptions") - .addModifiers(Modifier.PRIVATE) - .returns(listOf(AuthSchemeOption.class)) - .addParameter(Context.BeforeExecution.class, - "context") - .addParameter(ExecutionAttributes.class, - "executionAttributes"); - - builder.addStatement("$1T authSchemeProvider = $2T.isInstanceOf($1T.class, executionAttributes" - + ".getAttribute($3T.AUTH_SCHEME_RESOLVER), $4S)", - authSchemeSpecUtils.providerInterfaceName(), - Validate.class, - SdkInternalExecutionAttribute.class, - "Expected an instance of " + authSchemeSpecUtils.providerInterfaceName().simpleName()); - builder.addStatement("$T params = authSchemeParams(context.request(), executionAttributes)", - authSchemeSpecUtils.parametersInterfaceName()); - builder.addStatement("return authSchemeProvider.resolveAuthScheme(params)"); - return builder.build(); - } - - private MethodSpec generateAuthSchemeParams() { - MethodSpec.Builder builder = MethodSpec.methodBuilder("authSchemeParams") - .addModifiers(Modifier.PRIVATE) - .returns(authSchemeSpecUtils.parametersInterfaceName()) - .addParameter(SdkRequest.class, "request") - .addParameter(ExecutionAttributes.class, "executionAttributes"); - - if (!authSchemeSpecUtils.useEndpointBasedAuthProvider()) { - builder.addStatement("$T operation = executionAttributes.getAttribute($T.OPERATION_NAME)", String.class, - SdkExecutionAttribute.class); - builder.addStatement("$T.Builder builder = $T.builder().operation(operation)", - authSchemeSpecUtils.parametersInterfaceName(), - authSchemeSpecUtils.parametersInterfaceName()); - - if (authSchemeSpecUtils.usesSigV4()) { - builder.addStatement("$T region = executionAttributes.getAttribute($T.AWS_REGION)", Region.class, - AwsExecutionAttribute.class); - builder.addStatement("builder.region(region)"); - } - generateSigv4aSigningRegionSet(builder); - builder.addStatement("return builder.build()"); - return builder.build(); - } - - builder.addStatement("$T endpointParams = $T.ruleParams(request, executionAttributes)", - endpointRulesSpecUtils.parametersClassName(), - endpointRulesSpecUtils.resolverInterceptorName()); - builder.addStatement("$1T.Builder builder = $1T.builder()", authSchemeSpecUtils.parametersInterfaceName()); - boolean regionIncluded = false; - for (String paramName : endpointRulesSpecUtils.parameters().keySet()) { - if (!authSchemeSpecUtils.includeParamForProvider(paramName)) { - continue; - } - regionIncluded = regionIncluded || paramName.equalsIgnoreCase("region"); - String methodName = endpointRulesSpecUtils.paramMethodName(paramName); - builder.addStatement("builder.$1N(endpointParams.$1N())", methodName); - } - - builder.addStatement("$T operation = executionAttributes.getAttribute($T.OPERATION_NAME)", String.class, - SdkExecutionAttribute.class); - builder.addStatement("builder.operation(operation)"); - if (authSchemeSpecUtils.usesSigV4() && !regionIncluded) { - builder.addStatement("$T region = executionAttributes.getAttribute($T.AWS_REGION)", Region.class, - AwsExecutionAttribute.class); - builder.addStatement("builder.region(region)"); - } - generateSigv4aSigningRegionSet(builder); - ClassName paramsBuilderClass = authSchemeSpecUtils.parametersEndpointAwareDefaultImplName().nestedClass("Builder"); - builder.beginControlFlow("if (builder instanceof $T)", - paramsBuilderClass); - ClassName endpointProviderClass = endpointRulesSpecUtils.providerInterfaceName(); - builder.addStatement("$T endpointProvider = executionAttributes.getAttribute($T.ENDPOINT_PROVIDER)", - EndpointProvider.class, - SdkInternalExecutionAttribute.class); - builder.beginControlFlow("if (endpointProvider instanceof $T)", endpointProviderClass); - builder.addStatement("(($T)builder).endpointProvider(($T)endpointProvider)", paramsBuilderClass, endpointProviderClass); - builder.endControlFlow(); - builder.endControlFlow(); - builder.addStatement("return builder.build()"); - return builder.build(); - } - - private MethodSpec generateSelectAuthScheme() { - MethodSpec.Builder builder = MethodSpec.methodBuilder("selectAuthScheme") - .addModifiers(Modifier.PRIVATE) - .returns(wildcardSelectedAuthScheme()) - .addParameter(listOf(AuthSchemeOption.class), "authOptions") - .addParameter(ExecutionAttributes.class, "executionAttributes"); - - builder.addStatement("$T metricCollector = executionAttributes.getAttribute($T.API_CALL_METRIC_COLLECTOR)", - MetricCollector.class, SdkExecutionAttribute.class) - .addStatement("$T authSchemes = executionAttributes.getAttribute($T.AUTH_SCHEMES)", - mapOf(String.class, wildcardAuthScheme()), - SdkInternalExecutionAttribute.class) - .addStatement("$T identityProviders = executionAttributes.getAttribute($T.IDENTITY_PROVIDERS)", - IdentityProviders.class, SdkInternalExecutionAttribute.class) - .addStatement("$T discardedReasons = new $T<>()", - listOfStringSuppliers(), ArrayList.class); - - builder.beginControlFlow("for ($T authOption : authOptions)", AuthSchemeOption.class); - { - builder.addStatement("$T authScheme = authSchemes.get(authOption.schemeId())", wildcardAuthScheme()) - .addStatement("$T selectedAuthScheme = trySelectAuthScheme(authOption, authScheme, identityProviders, " - + "discardedReasons, metricCollector, executionAttributes)", - wildcardSelectedAuthScheme()); - builder.beginControlFlow("if (selectedAuthScheme != null)"); - { - addLogDebugDiscardedOptions(builder); - builder.addStatement("return selectedAuthScheme") - .endControlFlow(); - } - // end foreach - builder.endControlFlow(); - } - builder.addStatement("throw $T.builder()" - + ".message($S + discardedReasons.stream().map($T::get).collect($T.joining(\", \")))" - + ".build()", - SdkException.class, - "Failed to determine how to authenticate the user: ", - Supplier.class, - Collectors.class); - return builder.build(); - } - - //TODO (s3express) Review "general" identity properties and their propagation - private MethodSpec generateTrySelectAuthScheme() { - MethodSpec.Builder builder = MethodSpec.methodBuilder("trySelectAuthScheme") - .addModifiers(Modifier.PRIVATE) - .returns(namedSelectedAuthScheme()) - .addParameter(AuthSchemeOption.class, "authOption") - .addParameter(namedAuthScheme(), "authScheme") - .addParameter(IdentityProviders.class, "identityProviders") - .addParameter(listOfStringSuppliers(), "discardedReasons") - .addParameter(MetricCollector.class, "metricCollector") - .addParameter(ExecutionAttributes.class, "executionAttributes") - .addTypeVariable(TypeVariableName.get("T", Identity.class)); - - builder.beginControlFlow("if (authScheme == null)"); - { - builder.addStatement("discardedReasons.add(() -> String.format($S, authOption.schemeId()))", - "'%s' is not enabled for this request.") - .addStatement("return null") - .endControlFlow(); - } - builder.addStatement("$T identityProvider = authScheme.identityProvider(identityProviders)", - namedIdentityProvider()); - - builder.beginControlFlow("if (identityProvider == null)"); - { - builder.addStatement("discardedReasons.add(() -> String.format($S, authOption.schemeId()))", - "'%s' does not have an identity provider configured.") - .addStatement("return null") - .endControlFlow(); - } - - builder.addStatement("$T signer", - ParameterizedTypeName.get(ClassName.get(HttpSigner.class), TypeVariableName.get("T"))); - builder.beginControlFlow("try"); - { - builder.addStatement("signer = authScheme.signer()"); - builder.nextControlFlow("catch (RuntimeException e)"); - builder.addStatement("discardedReasons.add(() -> String.format($S, authOption.schemeId(), e.getMessage()))", - "'%s' signer could not be retrieved: %s") - .addStatement("return null") - .endControlFlow(); - } - - - builder.addStatement("$T.Builder identityRequestBuilder = $T.builder()", - ResolveIdentityRequest.class, - ResolveIdentityRequest.class); - builder.addStatement("authOption.forEachIdentityProperty(identityRequestBuilder::putProperty)"); - if (endpointRulesSpecUtils.isS3()) { - builder.addStatement("identityRequestBuilder.putProperty($T.SDK_CLIENT, " - + "executionAttributes.getAttribute($T.SDK_CLIENT))", - SdkIdentityProperty.class, - SdkInternalExecutionAttribute.class); - } - builder.addStatement("$T identity", namedIdentityFuture()); - builder.addStatement("$T metric = getIdentityMetric(identityProvider)", durationSdkMetric()); - builder.beginControlFlow("if (metric == null)") - .addStatement("identity = identityProvider.resolveIdentity(identityRequestBuilder.build())") - .nextControlFlow("else") - .addStatement("identity = $T.reportDuration(" - + "() -> identityProvider.resolveIdentity(identityRequestBuilder.build()), metricCollector, metric)", - MetricUtils.class) - .endControlFlow(); - - builder.addStatement("return new $T<>(identity, signer, authOption)", SelectedAuthScheme.class); - return builder.build(); - } - - private MethodSpec generateGetIdentityMetric() { - MethodSpec.Builder builder = MethodSpec.methodBuilder("getIdentityMetric") - .addModifiers(Modifier.PRIVATE) - .returns(durationSdkMetric()) - .addParameter(wildcardIdentityProvider(), "identityProvider"); - - builder.addStatement("Class identityType = identityProvider.identityType()") - .beginControlFlow("if (identityType == $T.class)", AwsCredentialsIdentity.class) - .addStatement("return $T.CREDENTIALS_FETCH_DURATION", CoreMetric.class) - .endControlFlow() - .beginControlFlow("if (identityType == $T.class)", TokenIdentity.class) - .addStatement("return $T.TOKEN_FETCH_DURATION", CoreMetric.class) - .endControlFlow() - .addStatement("return null"); - - return builder.build(); - } - - private MethodSpec putSelectedAuthSchemeMethodSpec() { - String attributeParamName = "attributes"; - String selectedAuthSchemeParamName = "selectedAuthScheme"; - MethodSpec.Builder builder = MethodSpec.methodBuilder("putSelectedAuthScheme") - .addModifiers(Modifier.PRIVATE) - .addTypeVariable(TypeVariableName.get("T", Identity.class)) - .addParameter(ExecutionAttributes.class, attributeParamName) - .addParameter(ParameterSpec.builder( - ParameterizedTypeName.get(ClassName.get(SelectedAuthScheme.class), - TypeVariableName.get("T")), - selectedAuthSchemeParamName).build()); - builder.addStatement("$T existingAuthScheme = $N.getAttribute($T.SELECTED_AUTH_SCHEME)", - ParameterizedTypeName.get(ClassName.get(SelectedAuthScheme.class), - WildcardTypeName.subtypeOf(Object.class)), - attributeParamName, - SdkInternalExecutionAttribute.class); - - builder.beginControlFlow("if (existingAuthScheme != null)") - .addStatement("$T selectedOption = $N.authSchemeOption().toBuilder()", - AuthSchemeOption.Builder.class, selectedAuthSchemeParamName) - .addStatement("existingAuthScheme.authSchemeOption().forEachIdentityProperty" - + "(selectedOption::putIdentityPropertyIfAbsent)") - .addStatement("existingAuthScheme.authSchemeOption().forEachSignerProperty" - + "(selectedOption::putSignerPropertyIfAbsent)") - .addStatement("$N = new $T<>($N.identity(), $N.signer(), selectedOption.build())", - selectedAuthSchemeParamName, - SelectedAuthScheme.class, - selectedAuthSchemeParamName, - selectedAuthSchemeParamName); - builder.endControlFlow(); - - builder.addStatement("$N.putAttribute($T.SELECTED_AUTH_SCHEME, $N)", - attributeParamName, SdkInternalExecutionAttribute.class, selectedAuthSchemeParamName); - - return builder.build(); - } - - private void addLogDebugDiscardedOptions(MethodSpec.Builder builder) { - builder.beginControlFlow("if (!discardedReasons.isEmpty())"); - { - builder.addStatement("LOG.debug(() -> String.format(\"%s auth will be used, discarded: '%s'\", " - + "authOption.schemeId(), " - + "discardedReasons.stream().map($T::get).collect($T.joining(\", \"))))", - Supplier.class, Collectors.class) - .endControlFlow(); - } - } - - // IdentityProvider - private TypeName namedIdentityProvider() { - return ParameterizedTypeName.get(ClassName.get(IdentityProvider.class), TypeVariableName.get("T")); - } - - // IdentityProvider - private TypeName wildcardIdentityProvider() { - return ParameterizedTypeName.get(ClassName.get(IdentityProvider.class), WildcardTypeName.subtypeOf(Object.class)); - } - - // CompletableFuture - private TypeName namedIdentityFuture() { - return ParameterizedTypeName.get(ClassName.get(CompletableFuture.class), - WildcardTypeName.subtypeOf(TypeVariableName.get("T"))); - } - - // AuthScheme - private TypeName namedAuthScheme() { - return ParameterizedTypeName.get(ClassName.get(AuthScheme.class), - TypeVariableName.get("T", Identity.class)); - } - - // AuthScheme - private TypeName wildcardAuthScheme() { - return ParameterizedTypeName.get(ClassName.get(AuthScheme.class), - WildcardTypeName.subtypeOf(Object.class)); - } - - // SelectedAuthScheme - private TypeName namedSelectedAuthScheme() { - return ParameterizedTypeName.get(ClassName.get(SelectedAuthScheme.class), - TypeVariableName.get("T", Identity.class)); - } - - // SelectedAuthScheme - private TypeName wildcardSelectedAuthScheme() { - return ParameterizedTypeName.get(ClassName.get(SelectedAuthScheme.class), - WildcardTypeName.subtypeOf(Identity.class)); - } - - // List> - private TypeName listOfStringSuppliers() { - return listOf(ParameterizedTypeName.get(Supplier.class, String.class)); - } - - // Map - private TypeName mapOf(Object keyType, Object valueType) { - return ParameterizedTypeName.get(ClassName.get(Map.class), toTypeName(keyType), toTypeName(valueType)); - } - - // List - private TypeName listOf(Object valueType) { - return ParameterizedTypeName.get(ClassName.get(List.class), toTypeName(valueType)); - } - - // SdkMetric - private ParameterizedTypeName durationSdkMetric() { - return ParameterizedTypeName.get(ClassName.get(SdkMetric.class), toTypeName(Duration.class)); - } - - private TypeName toTypeName(Object valueType) { - TypeName result; - if (valueType instanceof Class) { - result = ClassName.get((Class) valueType); - } else if (valueType instanceof TypeName) { - result = (TypeName) valueType; - } else { - throw new IllegalArgumentException("Don't know how to convert " + valueType + " to TypeName"); - } - return result; - } - - private void generateSigv4aSigningRegionSet(MethodSpec.Builder builder) { - if (authSchemeSpecUtils.hasSigV4aSupport()) { - builder.addStatement( - "executionAttributes.getOptionalAttribute($T.AWS_SIGV4A_SIGNING_REGION_SET)\n" + - " .filter(regionSet -> !$T.isNullOrEmpty(regionSet))\n" + - " .ifPresent(nonEmptyRegionSet -> builder.regionSet($T.create(nonEmptyRegionSet)))", - AwsExecutionAttribute.class, - CollectionUtils.class, - RegionSet.class - ); - } - } -} diff --git a/codegen/src/main/java/software/amazon/awssdk/codegen/poet/auth/scheme/AuthSchemeParamsSpec.java b/codegen/src/main/java/software/amazon/awssdk/codegen/poet/auth/scheme/AuthSchemeParamsSpec.java index 848182945524..86af5542e886 100644 --- a/codegen/src/main/java/software/amazon/awssdk/codegen/poet/auth/scheme/AuthSchemeParamsSpec.java +++ b/codegen/src/main/java/software/amazon/awssdk/codegen/poet/auth/scheme/AuthSchemeParamsSpec.java @@ -60,14 +60,42 @@ public TypeSpec poetSpec() { .addModifiers(Modifier.PUBLIC) .addAnnotation(SdkPublicApi.class) .addJavadoc(interfaceJavadoc()) - .addMethod(builderMethod()) - .addType(builderInterfaceSpec()); + .addMethod(builderMethod()); + + if (authSchemeSpecUtils.generateEndpointBasedParams()) { + b.addMethod(fromEndpointParamsMethod()); + } + + b.addType(builderInterfaceSpec()); addAccessorMethods(b); addToBuilder(b); return b.build(); } + private MethodSpec fromEndpointParamsMethod() { + ClassName endpointParamsClass = endpointRulesSpecUtils.parametersClassName(); + MethodSpec.Builder builder = MethodSpec.methodBuilder("fromEndpointParams") + .addModifiers(Modifier.PUBLIC, Modifier.STATIC) + .addParameter(endpointParamsClass, "endpointParams") + .returns(authSchemeSpecUtils.parametersInterfaceBuilderInterfaceName()) + .addJavadoc("Create a builder pre-populated with endpoint parameters.\n" + + "@param endpointParams the endpoint parameters to copy\n" + + "@return a builder with values from the endpoint parameters"); + + builder.addStatement("$T builder = builder()", authSchemeSpecUtils.parametersInterfaceBuilderInterfaceName()); + + parameters().forEach((name, model) -> { + if (authSchemeSpecUtils.includeParamForProvider(name)) { + String methodName = endpointRulesSpecUtils.paramMethodName(name); + builder.addStatement("builder.$1N(endpointParams.$1N())", methodName); + } + }); + + builder.addStatement("return builder"); + return builder.build(); + } + private CodeBlock interfaceJavadoc() { CodeBlock.Builder b = CodeBlock.builder(); diff --git a/codegen/src/main/java/software/amazon/awssdk/codegen/poet/builder/BaseClientBuilderClass.java b/codegen/src/main/java/software/amazon/awssdk/codegen/poet/builder/BaseClientBuilderClass.java index d3c37f92a367..cb0d9b103f1d 100644 --- a/codegen/src/main/java/software/amazon/awssdk/codegen/poet/builder/BaseClientBuilderClass.java +++ b/codegen/src/main/java/software/amazon/awssdk/codegen/poet/builder/BaseClientBuilderClass.java @@ -375,10 +375,6 @@ private MethodSpec finalizeServiceConfigurationMethod() { List builtInInterceptors = new ArrayList<>(); - builtInInterceptors.add(authSchemeSpecUtils.authSchemeInterceptor()); - builtInInterceptors.add(endpointRulesSpecUtils.resolverInterceptorName()); - builtInInterceptors.add(endpointRulesSpecUtils.requestModifierInterceptorName()); - for (String interceptor : model.getCustomizationConfig().getInterceptors()) { builtInInterceptors.add(ClassName.bestGuess(interceptor)); } diff --git a/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/AsyncClientClass.java b/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/AsyncClientClass.java index 539d5df35a43..09bbaa7caf80 100644 --- a/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/AsyncClientClass.java +++ b/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/AsyncClientClass.java @@ -69,10 +69,12 @@ import software.amazon.awssdk.codegen.poet.PoetExtension; import software.amazon.awssdk.codegen.poet.PoetUtils; import software.amazon.awssdk.codegen.poet.StaticImport; +import software.amazon.awssdk.codegen.poet.auth.scheme.AuthSchemeSpecUtils; import software.amazon.awssdk.codegen.poet.client.specs.ProtocolSpec; import software.amazon.awssdk.codegen.poet.eventstream.EventStreamUtils; import software.amazon.awssdk.codegen.poet.model.EventStreamSpecHelper; import software.amazon.awssdk.codegen.poet.model.ServiceClientConfigurationUtils; +import software.amazon.awssdk.codegen.poet.rules.EndpointRulesSpecUtils; import software.amazon.awssdk.core.RequestOverrideConfiguration; import software.amazon.awssdk.core.async.AsyncResponseTransformer; import software.amazon.awssdk.core.async.AsyncResponseTransformerUtils; @@ -100,6 +102,8 @@ public final class AsyncClientClass extends AsyncClientInterface { private final ProtocolSpec protocolSpec; private final ClassName serviceClientConfigurationClassName; private final ServiceClientConfigurationUtils configurationUtils; + private final AuthSchemeSpecUtils authSchemeSpecUtils; + private final EndpointRulesSpecUtils endpointRulesSpecUtils; private boolean hasScheduledExecutor; public AsyncClientClass(GeneratorTaskParams dependencies) { @@ -110,6 +114,8 @@ public AsyncClientClass(GeneratorTaskParams dependencies) { this.protocolSpec = getProtocolSpecs(poetExtensions, model); this.serviceClientConfigurationClassName = new PoetExtension(model).getServiceConfigClass(); this.configurationUtils = new ServiceClientConfigurationUtils(model); + this.authSchemeSpecUtils = new AuthSchemeSpecUtils(model); + this.endpointRulesSpecUtils = new EndpointRulesSpecUtils(model); } @Override @@ -165,7 +171,9 @@ protected void addAdditionalMethods(TypeSpec.Builder type) { .addMethod(nameMethod()) .addMethods(protocolSpec.additionalMethods()) .addMethod(protocolSpec.initProtocolFactory(model)) - .addMethod(resolveMetricPublishersMethod()); + .addMethod(resolveMetricPublishersMethod()) + .addMethod(ClientClassUtils.resolveAuthSchemeOptionsMethod(authSchemeSpecUtils, endpointRulesSpecUtils)) + .addMethod(ClientClassUtils.resolveEndpointMethod(authSchemeSpecUtils, endpointRulesSpecUtils)); type.addMethod(ClientClassUtils.updateRetryStrategyClientConfigurationMethod()); type.addMethod(updateSdkClientConfigurationMethod(configurationUtils.serviceClientConfigurationBuilderClassName(), diff --git a/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/ClientClassUtils.java b/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/ClientClassUtils.java index 9d4e15b3dd86..9684be34b02e 100644 --- a/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/ClientClassUtils.java +++ b/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/ClientClassUtils.java @@ -25,16 +25,21 @@ import com.squareup.javapoet.ParameterizedTypeName; import com.squareup.javapoet.TypeName; import com.squareup.javapoet.TypeVariableName; +import com.squareup.javapoet.WildcardTypeName; import java.util.List; import java.util.Map; import java.util.Objects; import java.util.Optional; +import java.util.Set; +import java.util.concurrent.CompletionException; import java.util.function.Consumer; import java.util.stream.Collectors; import javax.lang.model.element.Modifier; import software.amazon.awssdk.arns.Arn; import software.amazon.awssdk.auth.signer.EventStreamAws4Signer; import software.amazon.awssdk.awscore.AwsRequestOverrideConfiguration; +import software.amazon.awssdk.awscore.endpoints.AwsEndpointAttribute; +import software.amazon.awssdk.awscore.endpoints.authscheme.EndpointAuthScheme; import software.amazon.awssdk.awscore.retry.AwsRetryStrategy; import software.amazon.awssdk.codegen.model.config.customization.S3ArnableFieldConfig; import software.amazon.awssdk.codegen.model.intermediate.IntermediateModel; @@ -45,14 +50,22 @@ import software.amazon.awssdk.codegen.model.service.HostPrefixProcessor; import software.amazon.awssdk.codegen.poet.PoetExtension; import software.amazon.awssdk.codegen.poet.PoetUtils; +import software.amazon.awssdk.codegen.poet.auth.scheme.AuthSchemeSpecUtils; import software.amazon.awssdk.codegen.poet.rules.EndpointRulesSpecUtils; +import software.amazon.awssdk.core.SdkClient; import software.amazon.awssdk.core.SdkPlugin; import software.amazon.awssdk.core.SdkRequest; +import software.amazon.awssdk.core.SelectedAuthScheme; import software.amazon.awssdk.core.client.config.ClientOverrideConfiguration; import software.amazon.awssdk.core.client.config.SdkClientConfiguration; import software.amazon.awssdk.core.client.config.SdkClientOption; +import software.amazon.awssdk.core.exception.SdkClientException; +import software.amazon.awssdk.core.interceptor.ExecutionAttributes; +import software.amazon.awssdk.core.interceptor.SdkInternalExecutionAttribute; import software.amazon.awssdk.core.retry.RetryMode; import software.amazon.awssdk.core.signer.Signer; +import software.amazon.awssdk.endpoints.Endpoint; +import software.amazon.awssdk.http.auth.spi.scheme.AuthSchemeOption; import software.amazon.awssdk.retries.api.RetryStrategy; import software.amazon.awssdk.utils.AttributeMap; import software.amazon.awssdk.utils.CollectionUtils; @@ -346,4 +359,233 @@ public static MethodSpec updateRetryStrategyClientConfigurationMethod() { static String transformServiceId(String serviceId) { return serviceId.replace(" ", "_"); } + + static MethodSpec resolveAuthSchemeOptionsMethod(AuthSchemeSpecUtils authSchemeSpecUtils, + EndpointRulesSpecUtils endpointRulesSpecUtils) { + MethodSpec.Builder builder = MethodSpec.methodBuilder("resolveAuthSchemeOptions") + .addModifiers(PRIVATE) + .returns(ParameterizedTypeName.get(ClassName.get(List.class), ClassName.get(AuthSchemeOption.class))) + .addParameter(SdkRequest.class, "request") + .addParameter(String.class, "operationName") + .addParameter(SdkClientConfiguration.class, "clientConfiguration"); + + ClassName providerInterface = authSchemeSpecUtils.providerInterfaceName(); + + // Check for request-level authSchemeProvider override + builder.addStatement("$T requestAuthSchemeProvider = request.overrideConfiguration()" + + ".flatMap(c -> c.authSchemeProvider())" + + ".map(p -> $T.isInstanceOf($T.class, p, $S))" + + ".orElse(null)", + providerInterface, Validate.class, providerInterface, + "Expected an instance of " + authSchemeSpecUtils.providerInterfaceName().simpleName()); + builder.addStatement("$T authSchemeProvider = requestAuthSchemeProvider != null " + + "? requestAuthSchemeProvider " + + ": $T.isInstanceOf($T.class, " + + "clientConfiguration.option($T.AUTH_SCHEME_PROVIDER), $S)", + providerInterface, Validate.class, providerInterface, SdkClientOption.class, + "Expected an instance of " + authSchemeSpecUtils.providerInterfaceName().simpleName()); + + if (authSchemeSpecUtils.useEndpointBasedAuthProvider()) { + addEndpointBasedAuthSchemeResolution(builder, authSchemeSpecUtils, endpointRulesSpecUtils); + } else { + addSimpleAuthSchemeResolution(builder, authSchemeSpecUtils); + } + + if (endpointRulesSpecUtils.isS3()) { + ClassName sdkIdentityProperty = ClassName.get("software.amazon.awssdk.core.identity", "SdkIdentityProperty"); + builder.addStatement("$T sdkClient = clientConfiguration.option($T.SDK_CLIENT)", + SdkClient.class, SdkClientOption.class); + builder.addStatement("return options.stream().map(o -> o.toBuilder()" + + ".putIdentityProperty($T.SDK_CLIENT, sdkClient).build())" + + ".collect($T.toList())", + sdkIdentityProperty, Collectors.class); + } else { + builder.addStatement("return options"); + } + + return builder.build(); + } + + private static void addSimpleAuthSchemeResolution(MethodSpec.Builder builder, + AuthSchemeSpecUtils authSchemeSpecUtils) { + ClassName paramsInterface = authSchemeSpecUtils.parametersInterfaceName(); + ClassName awsClientOption = ClassName.get("software.amazon.awssdk.awscore.client.config", "AwsClientOption"); + + builder.addStatement("$T.Builder paramsBuilder = $T.builder().operation(operationName)", + paramsInterface, paramsInterface); + + if (authSchemeSpecUtils.usesSigV4()) { + builder.addStatement("paramsBuilder.region(clientConfiguration.option($T.AWS_REGION))", awsClientOption); + } + + if (authSchemeSpecUtils.hasSigV4aSupport()) { + ClassName regionSet = ClassName.get("software.amazon.awssdk.http.auth.aws.signer", "RegionSet"); + builder.addStatement("$T sigv4aRegionSet = clientConfiguration.option($T.AWS_SIGV4A_SIGNING_REGION_SET)", + ClassName.get(Set.class), awsClientOption); + builder.beginControlFlow("if (!$T.isNullOrEmpty(sigv4aRegionSet))", CollectionUtils.class); + builder.addStatement("paramsBuilder.regionSet($T.create(sigv4aRegionSet))", regionSet); + builder.endControlFlow(); + } + + builder.addStatement("$T<$T> options = authSchemeProvider.resolveAuthScheme(paramsBuilder.build())", + List.class, AuthSchemeOption.class); + } + + private static void addEndpointBasedAuthSchemeResolution(MethodSpec.Builder builder, + AuthSchemeSpecUtils authSchemeSpecUtils, + EndpointRulesSpecUtils endpointRulesSpecUtils) { + ClassName paramsInterface = authSchemeSpecUtils.parametersInterfaceName(); + ClassName awsClientOption = ClassName.get("software.amazon.awssdk.awscore.client.config", "AwsClientOption"); + ClassName endpointParamsClass = endpointRulesSpecUtils.parametersClassName(); + ClassName endpointResolverUtils = endpointRulesSpecUtils.endpointResolverUtilsName(); + ClassName executionAttributesClass = ClassName.get("software.amazon.awssdk.core.interceptor", "ExecutionAttributes"); + ClassName awsExecutionAttribute = ClassName.get("software.amazon.awssdk.awscore", "AwsExecutionAttribute"); + ClassName sdkExecutionAttribute = ClassName.get("software.amazon.awssdk.core.interceptor", "SdkExecutionAttribute"); + ClassName sdkInternalExecutionAttribute = ClassName.get("software.amazon.awssdk.core.interceptor", + "SdkInternalExecutionAttribute"); + + builder.addStatement("$T executionAttributes = new $T()", executionAttributesClass, executionAttributesClass); + builder.addStatement("executionAttributes.putAttribute($T.AWS_REGION, clientConfiguration.option($T.AWS_REGION))", + awsExecutionAttribute, awsClientOption); + builder.addStatement("executionAttributes.putAttribute($T.DUALSTACK_ENDPOINT_ENABLED, " + + "clientConfiguration.option($T.DUALSTACK_ENDPOINT_ENABLED))", + awsExecutionAttribute, awsClientOption); + builder.addStatement("executionAttributes.putAttribute($T.FIPS_ENDPOINT_ENABLED, " + + "clientConfiguration.option($T.FIPS_ENDPOINT_ENABLED))", + awsExecutionAttribute, awsClientOption); + builder.addStatement("executionAttributes.putAttribute($T.OPERATION_NAME, operationName)", sdkExecutionAttribute); + builder.addStatement("executionAttributes.putAttribute($T.CLIENT_ENDPOINT_PROVIDER, " + + "clientConfiguration.option($T.CLIENT_ENDPOINT_PROVIDER))", + sdkInternalExecutionAttribute, SdkClientOption.class); + builder.addStatement("executionAttributes.putAttribute($T.CLIENT_CONTEXT_PARAMS, " + + "clientConfiguration.option($T.CLIENT_CONTEXT_PARAMS))", + sdkInternalExecutionAttribute, SdkClientOption.class); + + builder.addStatement("$T endpointParams = $T.ruleParams(request, executionAttributes)", + endpointParamsClass, endpointResolverUtils); + + builder.addStatement("$T.Builder paramsBuilder = $T.builder()", paramsInterface, paramsInterface); + + boolean regionIncluded = false; + for (String paramName : endpointRulesSpecUtils.parameters().keySet()) { + if (!authSchemeSpecUtils.includeParamForProvider(paramName)) { + continue; + } + regionIncluded = regionIncluded || paramName.equalsIgnoreCase("region"); + String methodName = endpointRulesSpecUtils.paramMethodName(paramName); + builder.addStatement("paramsBuilder.$1N(endpointParams.$1N())", methodName); + } + + builder.addStatement("paramsBuilder.operation(operationName)"); + + if (authSchemeSpecUtils.usesSigV4() && !regionIncluded) { + builder.addStatement("paramsBuilder.region(clientConfiguration.option($T.AWS_REGION))", awsClientOption); + } + + if (authSchemeSpecUtils.hasSigV4aSupport()) { + ClassName regionSet = ClassName.get("software.amazon.awssdk.http.auth.aws.signer", "RegionSet"); + builder.addStatement("$T sigv4aRegionSet = clientConfiguration.option($T.AWS_SIGV4A_SIGNING_REGION_SET)", + ClassName.get(Set.class), awsClientOption); + builder.beginControlFlow("if (!$T.isNullOrEmpty(sigv4aRegionSet))", CollectionUtils.class); + builder.addStatement("paramsBuilder.regionSet($T.create(sigv4aRegionSet))", regionSet); + builder.endControlFlow(); + } + + ClassName paramsBuilderClass = authSchemeSpecUtils.parametersEndpointAwareDefaultImplName().nestedClass("Builder"); + ClassName endpointProviderInterface = endpointRulesSpecUtils.providerInterfaceName(); + + builder.beginControlFlow("if (paramsBuilder instanceof $T)", paramsBuilderClass); + builder.addStatement("$T endpointProvider = clientConfiguration.option($T.ENDPOINT_PROVIDER)", + ClassName.get("software.amazon.awssdk.endpoints", "EndpointProvider"), SdkClientOption.class); + builder.beginControlFlow("if (endpointProvider instanceof $T)", endpointProviderInterface); + builder.addStatement("(($T) paramsBuilder).endpointProvider(($T) endpointProvider)", + paramsBuilderClass, endpointProviderInterface); + builder.endControlFlow(); + builder.endControlFlow(); + + builder.addStatement("$T<$T> options = authSchemeProvider.resolveAuthScheme(paramsBuilder.build())", + List.class, AuthSchemeOption.class); + } + + static MethodSpec resolveEndpointMethod(AuthSchemeSpecUtils authSchemeSpecUtils, + EndpointRulesSpecUtils endpointRulesSpecUtils) { + ClassName utilsClass = endpointRulesSpecUtils.endpointResolverUtilsName(); + ClassName endpointParamsClass = endpointRulesSpecUtils.parametersClassName(); + ClassName providerInterface = endpointRulesSpecUtils.providerInterfaceName(); + ClassName awsEndpointProviderUtils = endpointRulesSpecUtils.sharedAwsEndpointProviderUtilsName(); + ClassName awsEndpointAttribute = ClassName.get(AwsEndpointAttribute.class); + ClassName endpointAuthScheme = ClassName.get(EndpointAuthScheme.class); + + MethodSpec.Builder b = MethodSpec.methodBuilder("resolveEndpoint") + .addModifiers(PRIVATE) + .returns(Endpoint.class) + .addParameter(SdkRequest.class, "request") + .addParameter(ExecutionAttributes.class, "executionAttributes") + .addParameter(String.class, "operationName"); + + b.addStatement("$1T provider = ($1T) executionAttributes.getAttribute($2T.ENDPOINT_PROVIDER)", + providerInterface, SdkInternalExecutionAttribute.class); + + b.beginControlFlow("try"); + b.addStatement("$T endpointParams = $T.ruleParams(request, executionAttributes)", + endpointParamsClass, utilsClass); + b.addStatement("$T endpoint = provider.resolveEndpoint(endpointParams).join()", Endpoint.class); + + b.beginControlFlow("if (!$T.disableHostPrefixInjection(executionAttributes))", awsEndpointProviderUtils); + b.addStatement("$T hostPrefix = $T.hostPrefix(operationName, request)", + ParameterizedTypeName.get(Optional.class, String.class), utilsClass); + b.beginControlFlow("if (hostPrefix.isPresent())"); + b.addStatement("endpoint = $T.addHostPrefix(endpoint, hostPrefix.get())", awsEndpointProviderUtils); + b.endControlFlow(); + b.endControlFlow(); + + b.addStatement("$T endpointAuthSchemes = endpoint.attribute($T.AUTH_SCHEMES)", + ParameterizedTypeName.get(ClassName.get(List.class), endpointAuthScheme), + awsEndpointAttribute); + b.addStatement("$T selectedAuthScheme = executionAttributes.getAttribute($T.SELECTED_AUTH_SCHEME)", + ParameterizedTypeName.get(ClassName.get(SelectedAuthScheme.class), + WildcardTypeName.subtypeOf(Object.class)), + SdkInternalExecutionAttribute.class); + b.beginControlFlow("if (endpointAuthSchemes != null && selectedAuthScheme != null)"); + b.addStatement("selectedAuthScheme = $T.authSchemeWithEndpointSignerProperties(endpointAuthSchemes, selectedAuthScheme)", + utilsClass); + + if (authSchemeSpecUtils.usesSigV4a() || authSchemeSpecUtils.generateEndpointBasedParams()) { + ClassName awsV4aAuthScheme = ClassName.get("software.amazon.awssdk.http.auth.aws.scheme", "AwsV4aAuthScheme"); + ClassName awsV4aHttpSigner = ClassName.get("software.amazon.awssdk.http.auth.aws.signer", "AwsV4aHttpSigner"); + ClassName regionSet = ClassName.get("software.amazon.awssdk.http.auth.aws.signer", "RegionSet"); + ClassName authSchemeOption = ClassName.get("software.amazon.awssdk.http.auth.spi.scheme", "AuthSchemeOption"); + + b.addComment("Precedence of SigV4a RegionSet is set according to multi-auth SigV4a specifications"); + b.beginControlFlow("if (selectedAuthScheme.authSchemeOption().schemeId().equals($T.SCHEME_ID) " + + "&& selectedAuthScheme.authSchemeOption().signerProperty($T.REGION_SET) == null)", + awsV4aAuthScheme, awsV4aHttpSigner); + b.addStatement("$T optionBuilder = selectedAuthScheme.authSchemeOption().toBuilder()", + authSchemeOption.nestedClass("Builder")); + b.addStatement("$1T rs = $1T.create(endpointParams.region().id())", regionSet); + b.addStatement("optionBuilder.putSignerProperty($T.REGION_SET, rs)", awsV4aHttpSigner); + b.addStatement("selectedAuthScheme = new $T(selectedAuthScheme.identity(), selectedAuthScheme.signer(), " + + "optionBuilder.build())", SelectedAuthScheme.class); + b.endControlFlow(); + } + + b.addStatement("executionAttributes.putAttribute($T.SELECTED_AUTH_SCHEME, selectedAuthScheme)", + SdkInternalExecutionAttribute.class); + b.endControlFlow(); + + b.addStatement("$T.setMetricValues(endpoint, executionAttributes)", utilsClass); + + b.addStatement("return endpoint"); + + b.nextControlFlow("catch ($T e)", CompletionException.class); + b.addStatement("$T cause = e.getCause()", Throwable.class); + b.beginControlFlow("if (cause instanceof $T)", SdkClientException.class); + b.addStatement("throw ($T) cause", SdkClientException.class); + b.endControlFlow(); + b.addStatement("throw $T.create($S + cause.getMessage(), cause)", + SdkClientException.class, "Endpoint resolution failed: "); + b.endControlFlow(); + + return b.build(); + } } diff --git a/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/SyncClientClass.java b/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/SyncClientClass.java index 5736ddbecaf5..0d33e441bc40 100644 --- a/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/SyncClientClass.java +++ b/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/SyncClientClass.java @@ -54,12 +54,14 @@ import software.amazon.awssdk.codegen.model.service.PreClientExecutionRequestCustomizer; import software.amazon.awssdk.codegen.poet.PoetExtension; import software.amazon.awssdk.codegen.poet.PoetUtils; +import software.amazon.awssdk.codegen.poet.auth.scheme.AuthSchemeSpecUtils; import software.amazon.awssdk.codegen.poet.client.specs.Ec2ProtocolSpec; import software.amazon.awssdk.codegen.poet.client.specs.JsonProtocolSpec; import software.amazon.awssdk.codegen.poet.client.specs.ProtocolSpec; import software.amazon.awssdk.codegen.poet.client.specs.QueryProtocolSpec; import software.amazon.awssdk.codegen.poet.client.specs.XmlProtocolSpec; import software.amazon.awssdk.codegen.poet.model.ServiceClientConfigurationUtils; +import software.amazon.awssdk.codegen.poet.rules.EndpointRulesSpecUtils; import software.amazon.awssdk.core.RequestOverrideConfiguration; import software.amazon.awssdk.core.client.config.SdkClientConfiguration; import software.amazon.awssdk.core.client.config.SdkClientOption; @@ -83,6 +85,8 @@ public class SyncClientClass extends SyncClientInterface { private final ProtocolSpec protocolSpec; private final ClassName serviceClientConfigurationClassName; private final ServiceClientConfigurationUtils configurationUtils; + private final AuthSchemeSpecUtils authSchemeSpecUtils; + private final EndpointRulesSpecUtils endpointRulesSpecUtils; public SyncClientClass(GeneratorTaskParams taskParams) { super(taskParams.getModel()); @@ -92,6 +96,8 @@ public SyncClientClass(GeneratorTaskParams taskParams) { this.protocolSpec = getProtocolSpecs(poetExtensions, model); this.serviceClientConfigurationClassName = new PoetExtension(model).getServiceConfigClass(); this.configurationUtils = new ServiceClientConfigurationUtils(model); + this.authSchemeSpecUtils = new AuthSchemeSpecUtils(model); + this.endpointRulesSpecUtils = new EndpointRulesSpecUtils(model); } @Override @@ -133,7 +139,9 @@ protected void addAdditionalMethods(TypeSpec.Builder type) { type.addMethod(constructor()) .addMethod(nameMethod()) .addMethods(protocolSpec.additionalMethods()) - .addMethod(resolveMetricPublishersMethod()); + .addMethod(resolveMetricPublishersMethod()) + .addMethod(ClientClassUtils.resolveAuthSchemeOptionsMethod(authSchemeSpecUtils, endpointRulesSpecUtils)) + .addMethod(ClientClassUtils.resolveEndpointMethod(authSchemeSpecUtils, endpointRulesSpecUtils)); protocolSpec.createErrorResponseHandler().ifPresent(type::addMethod); type.addMethod(ClientClassUtils.updateRetryStrategyClientConfigurationMethod()); diff --git a/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/specs/JsonProtocolSpec.java b/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/specs/JsonProtocolSpec.java index 894055c2db7c..09984c4554af 100644 --- a/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/specs/JsonProtocolSpec.java +++ b/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/specs/JsonProtocolSpec.java @@ -223,7 +223,11 @@ public CodeBlock executionHandler(OperationModel opModel) { .add(LongPollTrait.executionParamSetter(opModel)) .add(".withRequestConfiguration(clientConfiguration)") .add(".withInput($L)\n", opModel.getInput().getVariableName()) - .add(".withMetricCollector(apiCallMetricCollector)") + .add(".withMetricCollector(apiCallMetricCollector)\n") + .add(".withAuthSchemeOptionsResolver(r -> resolveAuthSchemeOptions(r, $S, clientConfiguration))\n", + opModel.getOperationName()) + .add(".withEndpointResolver((r, a) -> resolveEndpoint(r, a, $S))\n", + opModel.getOperationName()) .add(HttpChecksumRequiredTrait.putHttpChecksumAttribute(opModel)) .add(HttpChecksumTrait.create(opModel)); @@ -298,6 +302,10 @@ public CodeBlock asyncExecutionHandler(IntermediateModel intermediateModel, Oper .add(".withErrorResponseHandler(errorResponseHandler)\n") .add(".withRequestConfiguration(clientConfiguration)") .add(".withMetricCollector(apiCallMetricCollector)\n") + .add(".withAuthSchemeOptionsResolver(r -> resolveAuthSchemeOptions(r, $S, clientConfiguration))\n", + opModel.getOperationName()) + .add(".withEndpointResolver((r, a) -> resolveEndpoint(r, a, $S))\n", + opModel.getOperationName()) .add(hostPrefixExpression(opModel)) .add(discoveredEndpoint(opModel)) .add(credentialType(opModel, model)) diff --git a/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/specs/QueryProtocolSpec.java b/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/specs/QueryProtocolSpec.java index bab1f29e7281..1132bb911435 100644 --- a/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/specs/QueryProtocolSpec.java +++ b/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/specs/QueryProtocolSpec.java @@ -116,6 +116,10 @@ public CodeBlock executionHandler(OperationModel opModel) { .add(".withRequestConfiguration(clientConfiguration)") .add(".withInput($L)", opModel.getInput().getVariableName()) .add(".withMetricCollector(apiCallMetricCollector)") + .add(".withAuthSchemeOptionsResolver(r -> resolveAuthSchemeOptions(r, $S, clientConfiguration))\n", + opModel.getOperationName()) + .add(".withEndpointResolver((r, a) -> resolveEndpoint(r, a, $S))\n", + opModel.getOperationName()) .add(HttpChecksumRequiredTrait.putHttpChecksumAttribute(opModel)) .add(HttpChecksumTrait.create(opModel)); @@ -155,6 +159,10 @@ public CodeBlock asyncExecutionHandler(IntermediateModel intermediateModel, Oper .add(credentialType(opModel, intermediateModel)) .add(".withRequestConfiguration(clientConfiguration)") .add(".withMetricCollector(apiCallMetricCollector)\n") + .add(".withAuthSchemeOptionsResolver(r -> resolveAuthSchemeOptions(r, $S, clientConfiguration))\n", + opModel.getOperationName()) + .add(".withEndpointResolver((r, a) -> resolveEndpoint(r, a, $S))\n", + opModel.getOperationName()) .add(HttpChecksumRequiredTrait.putHttpChecksumAttribute(opModel)) .add(HttpChecksumTrait.create(opModel)); diff --git a/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/specs/XmlProtocolSpec.java b/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/specs/XmlProtocolSpec.java index dbbd33f4276f..207e38a49371 100644 --- a/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/specs/XmlProtocolSpec.java +++ b/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/specs/XmlProtocolSpec.java @@ -134,7 +134,11 @@ public CodeBlock executionHandler(OperationModel opModel) { discoveredEndpoint(opModel)) .add(credentialType(opModel, model)) .add(".withRequestConfiguration(clientConfiguration)") - .add(".withInput($L)", opModel.getInput().getVariableName()) + .add(".withInput($L)", opModel.getInput().getVariableName()) + .add(".withAuthSchemeOptionsResolver(r -> resolveAuthSchemeOptions(r, $S, clientConfiguration))\n", + opModel.getOperationName()) + .add(".withEndpointResolver((r, a) -> resolveEndpoint(r, a, $S))\n", + opModel.getOperationName()) .add(HttpChecksumRequiredTrait.putHttpChecksumAttribute(opModel)) .add(HttpChecksumTrait.create(opModel)); @@ -212,7 +216,11 @@ public CodeBlock asyncExecutionHandler(IntermediateModel intermediateModel, Oper builder.add(hostPrefixExpression(opModel)) .add(credentialType(opModel, model)) - .add(".withMetricCollector(apiCallMetricCollector)\n") + .add(".withMetricCollector(apiCallMetricCollector)\n") + .add(".withAuthSchemeOptionsResolver(r -> resolveAuthSchemeOptions(r, $S, clientConfiguration))\n", + opModel.getOperationName()) + .add(".withEndpointResolver((r, a) -> resolveEndpoint(r, a, $S))\n", + opModel.getOperationName()) .add(asyncRequestBody(opModel)) .add(HttpChecksumRequiredTrait.putHttpChecksumAttribute(opModel)) .add(HttpChecksumTrait.create(opModel)); diff --git a/codegen/src/main/java/software/amazon/awssdk/codegen/poet/rules/EndpointResolverInterceptorSpec.java b/codegen/src/main/java/software/amazon/awssdk/codegen/poet/rules/EndpointResolverUtilsSpec.java similarity index 76% rename from codegen/src/main/java/software/amazon/awssdk/codegen/poet/rules/EndpointResolverInterceptorSpec.java rename to codegen/src/main/java/software/amazon/awssdk/codegen/poet/rules/EndpointResolverUtilsSpec.java index 726d61cdb99e..84677fcae0cd 100644 --- a/codegen/src/main/java/software/amazon/awssdk/codegen/poet/rules/EndpointResolverInterceptorSpec.java +++ b/codegen/src/main/java/software/amazon/awssdk/codegen/poet/rules/EndpointResolverUtilsSpec.java @@ -22,13 +22,11 @@ import com.fasterxml.jackson.jr.stree.JrsString; import com.squareup.javapoet.ClassName; import com.squareup.javapoet.CodeBlock; -import com.squareup.javapoet.FieldSpec; import com.squareup.javapoet.MethodSpec; import com.squareup.javapoet.ParameterizedTypeName; import com.squareup.javapoet.TypeName; import com.squareup.javapoet.TypeSpec; import com.squareup.javapoet.TypeVariableName; -import java.time.Duration; import java.util.Iterator; import java.util.List; import java.util.Locale; @@ -36,7 +34,6 @@ import java.util.Objects; import java.util.Optional; import java.util.Set; -import java.util.concurrent.CompletionException; import javax.lang.model.element.Modifier; import software.amazon.awssdk.annotations.SdkInternalApi; import software.amazon.awssdk.awscore.AwsExecutionAttribute; @@ -61,15 +58,9 @@ import software.amazon.awssdk.codegen.poet.waiters.JmesPathAcceptorGenerator; import software.amazon.awssdk.core.SdkRequest; import software.amazon.awssdk.core.SelectedAuthScheme; -import software.amazon.awssdk.core.exception.SdkClientException; -import software.amazon.awssdk.core.interceptor.Context; import software.amazon.awssdk.core.interceptor.ExecutionAttributes; -import software.amazon.awssdk.core.interceptor.ExecutionInterceptor; -import software.amazon.awssdk.core.interceptor.SdkExecutionAttribute; import software.amazon.awssdk.core.interceptor.SdkInternalExecutionAttribute; -import software.amazon.awssdk.core.metrics.CoreMetric; import software.amazon.awssdk.endpoints.Endpoint; -import software.amazon.awssdk.http.SdkHttpRequest; import software.amazon.awssdk.http.auth.aws.scheme.AwsV4AuthScheme; import software.amazon.awssdk.http.auth.aws.scheme.AwsV4aAuthScheme; import software.amazon.awssdk.http.auth.aws.signer.AwsV4HttpSigner; @@ -77,14 +68,18 @@ import software.amazon.awssdk.http.auth.aws.signer.RegionSet; import software.amazon.awssdk.http.auth.spi.scheme.AuthSchemeOption; import software.amazon.awssdk.identity.spi.Identity; -import software.amazon.awssdk.metrics.MetricCollector; import software.amazon.awssdk.utils.AttributeMap; import software.amazon.awssdk.utils.CollectionUtils; import software.amazon.awssdk.utils.HostnameValidator; import software.amazon.awssdk.utils.StringUtils; import software.amazon.awssdk.utils.internal.CodegenNamingUtils; -public class EndpointResolverInterceptorSpec implements ClassSpec { +/** + * Generates a per-service endpoint resolver utility class (e.g. {@code DynamoDbEndpointResolverUtils}) + * containing all static helper methods for endpoint resolution: building endpoint params, host prefix + * resolution, auth scheme property copying, and business metrics. + */ +public class EndpointResolverUtilsSpec implements ClassSpec { private final IntermediateModel model; private final EndpointRulesSpecUtils endpointRulesSpecUtils; @@ -95,35 +90,28 @@ public class EndpointResolverInterceptorSpec implements ClassSpec { private final boolean multiAuthSigv4a; private final boolean legacyAuthFromEndpointRulesService; - - public EndpointResolverInterceptorSpec(IntermediateModel model) { + public EndpointResolverUtilsSpec(IntermediateModel model) { this.model = model; this.endpointRulesSpecUtils = new EndpointRulesSpecUtils(model); this.endpointParamsKnowledgeIndex = EndpointParamsKnowledgeIndex.of(model); this.poetExtension = new PoetExtension(model); this.jmesPathGenerator = new JmesPathAcceptorGenerator(poetExtension.jmesPathRuntimeClass()); - // We need to know whether the service has a dependency on the http-auth-aws module. Because we can't check that - // directly, assume that if they're using AwsV4AuthScheme or AwsV4aAuthScheme that it's available. Set> supportedAuthSchemes = ModelAuthSchemeClassesKnowledgeIndex.of(model).serviceConcreteAuthSchemeClasses(); this.dependsOnHttpAuthAws = supportedAuthSchemes.contains(AwsV4AuthScheme.class) || supportedAuthSchemes.contains(AwsV4aAuthScheme.class); - this.multiAuthSigv4a = new AuthSchemeSpecUtils(model).usesSigV4a(); this.legacyAuthFromEndpointRulesService = new AuthSchemeSpecUtils(model).generateEndpointBasedParams(); } @Override public TypeSpec poetSpec() { - FieldSpec endpointAuthSchemeStrategyFieldSpec = endpointAuthSchemeStrategyFieldSpec(); TypeSpec.Builder b = PoetUtils.createClassBuilder(className()) .addModifiers(Modifier.PUBLIC, Modifier.FINAL) .addAnnotation(SdkInternalApi.class) - .addSuperinterface(ExecutionInterceptor.class); + .addMethod(privateConstructor()); - b.addMethod(modifyRequestMethod(endpointAuthSchemeStrategyFieldSpec.name)); - b.addMethod(modifyHttpRequestMethod()); b.addMethod(ruleParams()); b.addMethod(setContextParams()); @@ -146,126 +134,17 @@ public TypeSpec poetSpec() { endpointParamsKnowledgeIndex.addAccountIdMethodsIfPresent(b); b.addMethod(setMetricValuesMethod()); + return b.build(); } @Override public ClassName className() { - return endpointRulesSpecUtils.resolverInterceptorName(); - } - - private FieldSpec endpointAuthSchemeStrategyFieldSpec() { - return FieldSpec.builder(endpointRulesSpecUtils.rulesRuntimeClassName("EndpointAuthSchemeStrategy"), - "endpointAuthSchemeStrategy", Modifier.PRIVATE, Modifier.FINAL) - .build(); - } - - private MethodSpec modifyRequestMethod(String endpointAuthSchemeStrategyFieldName) { - - MethodSpec.Builder b = MethodSpec.methodBuilder("modifyRequest") - .addModifiers(Modifier.PUBLIC) - .addAnnotation(Override.class) - .returns(SdkRequest.class) - .addParameter(Context.ModifyRequest.class, "context") - .addParameter(ExecutionAttributes.class, "executionAttributes"); - - String providerVar = "provider"; - - b.addStatement("$T result = context.request()", SdkRequest.class); - // We skip resolution if the source of the endpoint is the endpoint discovery call - b.beginControlFlow("if ($1T.endpointIsDiscovered(executionAttributes))", - endpointRulesSpecUtils.rulesRuntimeClassName("AwsEndpointProviderUtils")); - b.addStatement("return result"); - b.endControlFlow(); - - b.addStatement("$1T $2N = ($1T) executionAttributes.getAttribute($3T.ENDPOINT_PROVIDER)", - endpointRulesSpecUtils.providerInterfaceName(), providerVar, SdkInternalExecutionAttribute.class); - b.beginControlFlow("try"); - b.addStatement("long resolveEndpointStart = $T.nanoTime()", System.class); - b.addStatement("$T endpointParams = ruleParams(result, executionAttributes)", - endpointRulesSpecUtils.parametersClassName()); - b.addStatement("$T endpoint = $N.resolveEndpoint(endpointParams).join()", - Endpoint.class, providerVar); - b.addStatement("$1T resolveEndpointDuration = $1T.ofNanos($2T.nanoTime() - resolveEndpointStart)", Duration.class, - System.class); - b.addStatement("$T metricCollector = executionAttributes.getOptionalAttribute($T.API_CALL_METRIC_COLLECTOR)", - ParameterizedTypeName.get(Optional.class, MetricCollector.class), SdkExecutionAttribute.class); - b.addStatement("metricCollector.ifPresent(mc -> mc.reportMetric($T.ENDPOINT_RESOLVE_DURATION, resolveEndpointDuration))", - CoreMetric.class); - b.beginControlFlow("if (!$T.disableHostPrefixInjection(executionAttributes))", - endpointRulesSpecUtils.rulesRuntimeClassName("AwsEndpointProviderUtils")); - b.addStatement("$T hostPrefix = hostPrefix(executionAttributes.getAttribute($T.OPERATION_NAME), result)", - ParameterizedTypeName.get(Optional.class, String.class), SdkExecutionAttribute.class); - b.beginControlFlow("if (hostPrefix.isPresent())"); - b.addStatement("endpoint = $T.addHostPrefix(endpoint, hostPrefix.get())", - endpointRulesSpecUtils.rulesRuntimeClassName("AwsEndpointProviderUtils")); - b.endControlFlow(); - b.endControlFlow(); - - - // If the endpoint resolver returns auth settings, use them as signer properties. - // This effectively works to set the preSRA Signer ExecutionAttributes, so it is not conditional on useSraAuth. - b.addStatement("$T<$T> endpointAuthSchemes = endpoint.attribute($T.AUTH_SCHEMES)", - List.class, EndpointAuthScheme.class, AwsEndpointAttribute.class); - b.addStatement("$T selectedAuthScheme = executionAttributes.getAttribute($T.SELECTED_AUTH_SCHEME)", - SelectedAuthScheme.class, SdkInternalExecutionAttribute.class); - b.beginControlFlow("if (endpointAuthSchemes != null && selectedAuthScheme != null)"); - b.addStatement("selectedAuthScheme = authSchemeWithEndpointSignerProperties(endpointAuthSchemes, selectedAuthScheme)"); - if (multiAuthSigv4a || legacyAuthFromEndpointRulesService) { - b.addComment("Precedence of SigV4a RegionSet is set according to multi-auth SigV4a specifications"); - b.beginControlFlow("if(selectedAuthScheme.authSchemeOption().schemeId().equals($T.SCHEME_ID) " - + "&& selectedAuthScheme.authSchemeOption().signerProperty($T.REGION_SET) == null)", - AwsV4aAuthScheme.class, AwsV4aHttpSigner.class); - b.addStatement("$T optionBuilder = selectedAuthScheme.authSchemeOption().toBuilder()", - AuthSchemeOption.Builder.class); - b.addStatement("$T regionSet = $T.create(endpointParams.region().id())", - RegionSet.class, RegionSet.class); - b.addStatement("optionBuilder.putSignerProperty($T.REGION_SET, regionSet)", AwsV4aHttpSigner.class); - b.addStatement("selectedAuthScheme = new $T(selectedAuthScheme.identity(), selectedAuthScheme.signer(), " - + "optionBuilder.build())", SelectedAuthScheme.class); - b.endControlFlow(); - } - b.addStatement("executionAttributes.putAttribute($T.SELECTED_AUTH_SCHEME, selectedAuthScheme)", - SdkInternalExecutionAttribute.class); - b.endControlFlow(); - - b.addStatement("executionAttributes.putAttribute(SdkInternalExecutionAttribute.RESOLVED_ENDPOINT, endpoint)"); - b.addStatement("setMetricValues(endpoint, executionAttributes)"); - b.addStatement("return result"); - b.endControlFlow(); - b.beginControlFlow("catch ($T e)", CompletionException.class); - b.addStatement("$T cause = e.getCause()", Throwable.class); - b.beginControlFlow("if (cause instanceof $T)", SdkClientException.class); - b.addStatement("throw ($T) cause", SdkClientException.class); - b.endControlFlow(); - b.beginControlFlow("else"); - b.addStatement("throw $T.create($S, cause)", SdkClientException.class, "Endpoint resolution failed"); - b.endControlFlow(); - b.endControlFlow(); - return b.build(); + return endpointRulesSpecUtils.endpointResolverUtilsName(); } - private MethodSpec modifyHttpRequestMethod() { - MethodSpec.Builder b = MethodSpec.methodBuilder("modifyHttpRequest") - .addModifiers(Modifier.PUBLIC) - .addAnnotation(Override.class) - .returns(SdkHttpRequest.class) - .addParameter(Context.ModifyHttpRequest.class, "context") - .addParameter(ExecutionAttributes.class, "executionAttributes"); - - b.addStatement("$T resolvedEndpoint = executionAttributes.getAttribute($T.RESOLVED_ENDPOINT)", - Endpoint.class, SdkInternalExecutionAttribute.class); - b.beginControlFlow("if (resolvedEndpoint.headers().isEmpty())"); - b.addStatement("return context.httpRequest()"); - b.endControlFlow(); - - b.addStatement("$T httpRequestBuilder = context.httpRequest().toBuilder()", SdkHttpRequest.Builder.class); - b.addCode("resolvedEndpoint.headers().forEach((name, values) -> {"); - b.addStatement("values.forEach(v -> httpRequestBuilder.appendHeader(name, v))"); - b.addCode("});"); - b.addStatement("return httpRequestBuilder.build()"); - - return b.build(); + private static MethodSpec privateConstructor() { + return MethodSpec.constructorBuilder().addModifiers(Modifier.PRIVATE).build(); } private MethodSpec ruleParams() { @@ -278,7 +157,6 @@ private MethodSpec ruleParams() { b.addStatement("$T builder = $T.builder()", paramsBuilderClass(), endpointRulesSpecUtils.parametersClassName()); Map parameters = model.getEndpointRuleSetModel().getParameters(); - parameters.forEach((n, m) -> { if (m.getBuiltInEnum() == null) { return; @@ -307,15 +185,11 @@ private MethodSpec ruleParams() { b.addStatement("builder.$N(executionAttributes.getAttribute($T.$N))", setter, AwsExecutionAttribute.class, model.getNamingStrategy().getEnumValueName(n)); break; - // The S3 specific built-ins are set through the existing S3Configuration and set through client context params case AWS_S3_ACCELERATE: case AWS_S3_DISABLE_MULTI_REGION_ACCESS_POINTS: case AWS_S3_FORCE_PATH_STYLE: case AWS_S3_USE_ARN_REGION: case AWS_S3_CONTROL_USE_ARN_REGION: - // end of S3 specific builtins - - // V2 doesn't support this, only regional endpoints case AWS_STS_USE_GLOBAL_ENDPOINT: return; default: @@ -339,85 +213,77 @@ private MethodSpec ruleParams() { private CodeBlock endpointProviderUtilsSetter(String builtInFn, String setterName) { return CodeBlock.of("builder.$N($T.$N(executionAttributes))", setterName, - endpointRulesSpecUtils.rulesRuntimeClassName("AwsEndpointProviderUtils"), builtInFn); + endpointRulesSpecUtils.sharedAwsEndpointProviderUtilsName(), builtInFn); } private ClassName paramsBuilderClass() { return endpointRulesSpecUtils.parametersClassName().nestedClass("Builder"); } - private MethodSpec addStaticContextParamsMethod(OperationModel opModel) { - String methodName = staticContextParamsMethodName(opModel); + private MethodSpec setContextParams() { + Map operations = model.getOperations(); - MethodSpec.Builder b = MethodSpec.methodBuilder(methodName) + MethodSpec.Builder b = MethodSpec.methodBuilder("setContextParams") .addModifiers(Modifier.PRIVATE, Modifier.STATIC) - .returns(void.class) - .addParameter(paramsBuilderClass(), "params"); + .addParameter(paramsBuilderClass(), "params") + .addParameter(String.class, "operationName") + .addParameter(SdkRequest.class, "request") + .returns(void.class); - opModel.getStaticContextParams().forEach((n, m) -> { - String setterName = endpointRulesSpecUtils.paramMethodName(n); - TreeNode value = m.getValue(); - switch (value.asToken()) { - case VALUE_STRING: - b.addStatement("params.$N($S)", setterName, ((JrsString) value).getValue()); - break; - case VALUE_TRUE: - case VALUE_FALSE: - b.addStatement("params.$N($L)", setterName, ((JrsBoolean) value).booleanValue()); - break; - case START_ARRAY: - JrsArray arrayValue = (JrsArray) value; - CodeBlock arrayCode = endpointRulesSpecUtils.treeNodeToLiteral(arrayValue); - b.addStatement("params.$N($L)", setterName, arrayCode); - break; - default: - throw new RuntimeException("Don't know how to set parameter of type " + value.asToken()); - } - }); + boolean generateSwitch = operations.values().stream().anyMatch(this::hasContextParams); + if (generateSwitch) { + b.beginControlFlow("switch (operationName)"); + operations.forEach((n, m) -> { + if (!hasContextParams(m)) { + return; + } + String requestClassName = model.getNamingStrategy().getRequestClassName(m.getOperationName()); + ClassName requestClass = poetExtension.getModelClass(requestClassName); + b.addCode("case $S:", n); + b.addStatement("setContextParams(params, ($T) request)", requestClass); + b.addStatement("break"); + }); + b.addCode("default:"); + b.addStatement("break"); + b.endControlFlow(); + } return b.build(); } - private String staticContextParamsMethodName(OperationModel opModel) { - return opModel.getMethodName() + "StaticContextParams"; - } - - private boolean hasStaticContextParams(OperationModel opModel) { - Map staticContextParams = opModel.getStaticContextParams(); - return staticContextParams != null && !staticContextParams.isEmpty(); - } - - private boolean hasOperationContextParams(OperationModel opModel) { - return CollectionUtils.isNotEmpty(opModel.getOperationContextParams()); - } - - private void addStaticContextParamMethods(TypeSpec.Builder classBuilder) { - Map operations = model.getOperations(); - - operations.forEach((n, m) -> { - if (hasStaticContextParams(m)) { - classBuilder.addMethod(addStaticContextParamsMethod(m)); - } - }); - } - private void addContextParamMethods(TypeSpec.Builder classBuilder) { - Map operations = model.getOperations(); - - operations.forEach((n, m) -> { + model.getOperations().forEach((n, m) -> { if (hasContextParams(m)) { classBuilder.addMethod(setContextParamsMethod(m)); } }); } - private void addOperationContextParamMethods(TypeSpec.Builder classBuilder) { - Map operations = model.getOperations(); - operations.forEach((n, m) -> { - if (hasOperationContextParams(m)) { - classBuilder.addMethod(setOperationContextParamsMethod(m)); + private MethodSpec setContextParamsMethod(OperationModel opModel) { + String requestClassName = model.getNamingStrategy().getRequestClassName(opModel.getOperationName()); + ClassName requestClass = poetExtension.getModelClass(requestClassName); + + MethodSpec.Builder b = MethodSpec.methodBuilder("setContextParams") + .addModifiers(Modifier.PRIVATE, Modifier.STATIC) + .addParameter(paramsBuilderClass(), "params") + .addParameter(requestClass, "request") + .returns(void.class); + + opModel.getInputShape().getMembers().forEach(m -> { + ContextParam param = m.getContextParam(); + if (param == null) { + return; } + String setterName = endpointRulesSpecUtils.paramMethodName(param.getName()); + b.addStatement("params.$N(request.$N())", setterName, m.getFluentGetterMethodName()); }); + + return b.build(); + } + + private boolean hasContextParams(OperationModel opModel) { + return opModel.getInputShape().getMembers().stream() + .anyMatch(m -> m.getContextParam() != null); } private MethodSpec setStaticContextParamsMethod() { @@ -432,12 +298,10 @@ private MethodSpec setStaticContextParamsMethod() { boolean generateSwitch = operations.values().stream().anyMatch(this::hasStaticContextParams); if (generateSwitch) { b.beginControlFlow("switch (operationName)"); - operations.forEach((n, m) -> { if (!hasStaticContextParams(m)) { return; } - b.addCode("case $S:", n); b.addStatement("$N(params)", staticContextParamsMethodName(m)); b.addStatement("break"); @@ -450,38 +314,53 @@ private MethodSpec setStaticContextParamsMethod() { return b.build(); } - private MethodSpec setContextParams() { - Map operations = model.getOperations(); + private void addStaticContextParamMethods(TypeSpec.Builder classBuilder) { + model.getOperations().forEach((n, m) -> { + if (hasStaticContextParams(m)) { + classBuilder.addMethod(addStaticContextParamsMethod(m)); + } + }); + } - MethodSpec.Builder b = MethodSpec.methodBuilder("setContextParams") - .addModifiers(Modifier.PRIVATE, Modifier.STATIC) - .addParameter(paramsBuilderClass(), "params") - .addParameter(String.class, "operationName") - .addParameter(SdkRequest.class, "request") - .returns(void.class); + private MethodSpec addStaticContextParamsMethod(OperationModel opModel) { + String methodName = staticContextParamsMethodName(opModel); - boolean generateSwitch = operations.values().stream().anyMatch(this::hasContextParams); - if (generateSwitch) { - b.beginControlFlow("switch (operationName)"); + MethodSpec.Builder b = MethodSpec.methodBuilder(methodName) + .addModifiers(Modifier.PRIVATE, Modifier.STATIC) + .returns(void.class) + .addParameter(paramsBuilderClass(), "params"); - operations.forEach((n, m) -> { - if (!hasContextParams(m)) { - return; - } + opModel.getStaticContextParams().forEach((n, m) -> { + String setterName = endpointRulesSpecUtils.paramMethodName(n); + TreeNode value = m.getValue(); + switch (value.asToken()) { + case VALUE_STRING: + b.addStatement("params.$N($S)", setterName, ((JrsString) value).getValue()); + break; + case VALUE_TRUE: + case VALUE_FALSE: + b.addStatement("params.$N($L)", setterName, ((JrsBoolean) value).booleanValue()); + break; + case START_ARRAY: + JrsArray arrayValue = (JrsArray) value; + CodeBlock arrayCode = endpointRulesSpecUtils.treeNodeToLiteral(arrayValue); + b.addStatement("params.$N($L)", setterName, arrayCode); + break; + default: + throw new RuntimeException("Don't know how to set parameter of type " + value.asToken()); + } + }); - String requestClassName = model.getNamingStrategy().getRequestClassName(m.getOperationName()); - ClassName requestClass = poetExtension.getModelClass(requestClassName); + return b.build(); + } - b.addCode("case $S:", n); - b.addStatement("setContextParams(params, ($T) request)", requestClass); - b.addStatement("break"); - }); - b.addCode("default:"); - b.addStatement("break"); - b.endControlFlow(); - } + private String staticContextParamsMethodName(OperationModel opModel) { + return opModel.getMethodName() + "StaticContextParams"; + } - return b.build(); + private boolean hasStaticContextParams(OperationModel opModel) { + Map staticContextParams = opModel.getStaticContextParams(); + return staticContextParams != null && !staticContextParams.isEmpty(); } private MethodSpec setOperationContextParams() { @@ -497,15 +376,12 @@ private MethodSpec setOperationContextParams() { boolean generateSwitch = operations.values().stream().anyMatch(this::hasOperationContextParams); if (generateSwitch) { b.beginControlFlow("switch (operationName)"); - operations.forEach((n, m) -> { if (!hasOperationContextParams(m)) { return; } - String requestClassName = model.getNamingStrategy().getRequestClassName(m.getOperationName()); ClassName requestClass = poetExtension.getModelClass(requestClassName); - b.addCode("case $S:", n); b.addStatement("setOperationContextParams(params, ($T) request)", requestClass); b.addStatement("break"); @@ -518,28 +394,12 @@ private MethodSpec setOperationContextParams() { return b.build(); } - private MethodSpec setContextParamsMethod(OperationModel opModel) { - String requestClassName = model.getNamingStrategy().getRequestClassName(opModel.getOperationName()); - ClassName requestClass = poetExtension.getModelClass(requestClassName); - - MethodSpec.Builder b = MethodSpec.methodBuilder("setContextParams") - .addModifiers(Modifier.PRIVATE, Modifier.STATIC) - .addParameter(paramsBuilderClass(), "params") - .addParameter(requestClass, "request") - .returns(void.class); - - opModel.getInputShape().getMembers().forEach(m -> { - ContextParam param = m.getContextParam(); - if (param == null) { - return; + private void addOperationContextParamMethods(TypeSpec.Builder classBuilder) { + model.getOperations().forEach((n, m) -> { + if (hasOperationContextParams(m)) { + classBuilder.addMethod(setOperationContextParamsMethod(m)); } - - String setterName = endpointRulesSpecUtils.paramMethodName(param.getName()); - - b.addStatement("params.$N(request.$N())", setterName, m.getFluentGetterMethodName()); }); - - return b.build(); } private MethodSpec setOperationContextParamsMethod(OperationModel opModel) { @@ -557,7 +417,6 @@ private MethodSpec setOperationContextParamsMethod(OperationModel opModel) { opModel.getOperationContextParams().forEach((key, value) -> { if (Objects.requireNonNull(value.getPath().asToken()) == JsonToken.VALUE_STRING) { String setterName = endpointRulesSpecUtils.paramMethodName(key); - String jmesPathString = ((JrsString) value.getPath()).getValue(); CodeBlock addParam = CodeBlock.builder() .add("params.$N(", setterName) @@ -565,18 +424,20 @@ private MethodSpec setOperationContextParamsMethod(OperationModel opModel) { .add(matchToParameterType(key)) .add(")") .build(); - b.addStatement(addParam); } else { throw new RuntimeException("Invalid operation context parameter path for " + opModel.getOperationName() + ". Expected VALUE_STRING, but got " + value.getPath().asToken()); } - }); return b.build(); } + private boolean hasOperationContextParams(OperationModel opModel) { + return CollectionUtils.isNotEmpty(opModel.getOperationContextParams()); + } + private CodeBlock matchToParameterType(String paramName) { Map parameters = model.getEndpointRuleSetModel().getParameters(); Optional endpointParameter = parameters.entrySet().stream() @@ -601,11 +462,6 @@ private CodeBlock convertValueToParameterType(ParameterModel parameterModel) { } } - private boolean hasContextParams(OperationModel opModel) { - return opModel.getInputShape().getMembers().stream() - .anyMatch(m -> m.getContextParam() != null); - } - private boolean hasClientContextParams() { Map clientContextParams = model.getClientContextParams(); return clientContextParams != null && !clientContextParams.isEmpty(); @@ -627,20 +483,18 @@ private MethodSpec setClientContextParamsMethod() { params.forEach((n, m) -> { String attrName = endpointRulesSpecUtils.clientContextParamName(n); b.addStatement("$T.ofNullable(clientContextParams.get($T.$N)).ifPresent(params::$N)", Optional.class, paramsClass, - attrName, - endpointRulesSpecUtils.paramMethodName(n)); + attrName, endpointRulesSpecUtils.paramMethodName(n)); }); return b.build(); } - private MethodSpec hostPrefixMethod() { MethodSpec.Builder builder = MethodSpec.methodBuilder("hostPrefix") .returns(ParameterizedTypeName.get(Optional.class, String.class)) .addParameter(String.class, "operationName") .addParameter(SdkRequest.class, "request") - .addModifiers(Modifier.PRIVATE, Modifier.STATIC); + .addModifiers(Modifier.PUBLIC, Modifier.STATIC); boolean generateSwitch = model.getOperations().values().stream().anyMatch(opModel -> StringUtils.isNotBlank(getHostPrefix(opModel))); @@ -649,16 +503,13 @@ private MethodSpec hostPrefixMethod() { builder.addStatement("return $T.empty()", Optional.class); } else { builder.beginControlFlow("switch (operationName)"); - model.getOperations().forEach((name, opModel) -> { String hostPrefix = getHostPrefix(opModel); if (StringUtils.isBlank(hostPrefix)) { return; } - builder.beginControlFlow("case $S:", name); HostPrefixProcessor processor = new HostPrefixProcessor(hostPrefix); - if (processor.c2jNames().isEmpty()) { builder.addStatement("return $T.of($S)", Optional.class, processor.hostWithStringSpecifier()); } else { @@ -666,12 +517,8 @@ private MethodSpec hostPrefixMethod() { processor.c2jNames().forEach(c2jName -> { builder.addStatement("$1T.validateHostnameCompliant(request.getValueForField($2S, $3T.class)" + ".orElse(null), $2S, $4S)", - HostnameValidator.class, - c2jName, - String.class, - requestVar); + HostnameValidator.class, c2jName, String.class, requestVar); }); - builder.addCode("return $T.of($T.format($S, ", Optional.class, String.class, processor.hostWithStringSpecifier()); Iterator c2jNamesIter = processor.c2jNames().listIterator(); @@ -685,7 +532,6 @@ private MethodSpec hostPrefixMethod() { } builder.endControlFlow(); }); - builder.addCode("default:"); builder.addStatement("return $T.empty()", Optional.class); builder.endControlFlow(); @@ -699,7 +545,6 @@ private String getHostPrefix(OperationModel opModel) { if (endpointTrait == null) { return null; } - return endpointTrait.getHostPrefix(); } @@ -711,15 +556,13 @@ private MethodSpec authSchemeWithEndpointSignerPropertiesMethod() { MethodSpec.Builder method = MethodSpec.methodBuilder("authSchemeWithEndpointSignerProperties") - .addModifiers(Modifier.PRIVATE) + .addModifiers(Modifier.PUBLIC, Modifier.STATIC) .addTypeVariable(tExtendsIdentity) .returns(selectedAuthSchemeOfT) .addParameter(listOfEndpointAuthScheme, "endpointAuthSchemes") .addParameter(selectedAuthSchemeOfT, "selectedAuthScheme"); method.beginControlFlow("for ($T endpointAuthScheme : endpointAuthSchemes)", EndpointAuthScheme.class); - - // Don't include signer properties for auth options that don't match our selected auth scheme method.beginControlFlow("if (!endpointAuthScheme.schemeId()" + ".equals(selectedAuthScheme.authSchemeOption().schemeId()))"); method.addStatement("continue"); @@ -738,9 +581,7 @@ private MethodSpec authSchemeWithEndpointSignerPropertiesMethod() { method.addStatement("throw new $T(\"Endpoint auth scheme '\" + endpointAuthScheme.name() + \"' cannot be mapped to the " + "SDK auth scheme. Was it declared in the service's model?\")", IllegalArgumentException.class); - method.endControlFlow(); - method.addStatement("return selectedAuthScheme"); return method.build(); @@ -750,34 +591,27 @@ private static CodeBlock copyV4EndpointSignerPropertiesToAuth() { CodeBlock.Builder code = CodeBlock.builder(); code.beginControlFlow("if (endpointAuthScheme instanceof $T)", SigV4AuthScheme.class); code.addStatement("$1T v4AuthScheme = ($1T) endpointAuthScheme", SigV4AuthScheme.class); - code.beginControlFlow("if (v4AuthScheme.isDisableDoubleEncodingSet())"); code.addStatement("option.putSignerProperty($T.DOUBLE_URL_ENCODE, !v4AuthScheme.disableDoubleEncoding())", AwsV4HttpSigner.class); code.endControlFlow(); - code.beginControlFlow("if (v4AuthScheme.signingRegion() != null)"); - code.addStatement("option.putSignerProperty($T.REGION_NAME, v4AuthScheme.signingRegion())", - AwsV4HttpSigner.class); + code.addStatement("option.putSignerProperty($T.REGION_NAME, v4AuthScheme.signingRegion())", AwsV4HttpSigner.class); code.endControlFlow(); - code.beginControlFlow("if (v4AuthScheme.signingName() != null)"); code.addStatement("option.putSignerProperty($T.SERVICE_SIGNING_NAME, v4AuthScheme.signingName())", AwsV4HttpSigner.class); code.endControlFlow(); - code.addStatement("return new $T<>(selectedAuthScheme.identity(), selectedAuthScheme.signer(), option.build())", SelectedAuthScheme.class); code.endControlFlow(); return code.build(); } - private CodeBlock copyV4aEndpointSignerPropertiesToAuth() { + private CodeBlock copyV4aEndpointSignerPropertiesToAuth() { CodeBlock.Builder code = CodeBlock.builder(); - code.beginControlFlow("if (endpointAuthScheme instanceof $T)", SigV4aAuthScheme.class); code.addStatement("$1T v4aAuthScheme = ($1T) endpointAuthScheme", SigV4aAuthScheme.class); - code.beginControlFlow("if (v4aAuthScheme.isDisableDoubleEncodingSet())"); code.addStatement("option.putSignerProperty($T.DOUBLE_URL_ENCODE, !v4aAuthScheme.disableDoubleEncoding())", AwsV4aHttpSigner.class); @@ -793,12 +627,10 @@ private CodeBlock copyV4aEndpointSignerPropertiesToAuth() { code.addStatement("$1T regionSet = $1T.create(v4aAuthScheme.signingRegionSet())", RegionSet.class); code.addStatement("option.putSignerProperty($T.REGION_SET, regionSet)", AwsV4aHttpSigner.class); code.endControlFlow(); - code.beginControlFlow("if (v4aAuthScheme.signingName() != null)"); code.addStatement("option.putSignerProperty($T.SERVICE_SIGNING_NAME, v4aAuthScheme.signingName())", AwsV4aHttpSigner.class); code.endControlFlow(); - code.addStatement("return new $T<>(selectedAuthScheme.identity(), selectedAuthScheme.signer(), option.build())", SelectedAuthScheme.class); code.endControlFlow(); @@ -812,22 +644,18 @@ private CodeBlock copyS3ExpressEndpointSignerPropertiesToAuth() { "S3ExpressEndpointAuthScheme"); code.beginControlFlow("if (endpointAuthScheme instanceof $T)", s3ExpressEndpointAuthSchemeClassName); code.addStatement("$1T s3ExpressAuthScheme = ($1T) endpointAuthScheme", s3ExpressEndpointAuthSchemeClassName); - code.beginControlFlow("if (s3ExpressAuthScheme.isDisableDoubleEncodingSet())"); code.addStatement("option.putSignerProperty($T.DOUBLE_URL_ENCODE, !s3ExpressAuthScheme.disableDoubleEncoding())", AwsV4HttpSigner.class); code.endControlFlow(); - code.beginControlFlow("if (s3ExpressAuthScheme.signingRegion() != null)"); code.addStatement("option.putSignerProperty($T.REGION_NAME, s3ExpressAuthScheme.signingRegion())", AwsV4HttpSigner.class); code.endControlFlow(); - code.beginControlFlow("if (s3ExpressAuthScheme.signingName() != null)"); code.addStatement("option.putSignerProperty($T.SERVICE_SIGNING_NAME, s3ExpressAuthScheme.signingName())", AwsV4HttpSigner.class); code.endControlFlow(); - code.addStatement("return new $T<>(selectedAuthScheme.identity(), selectedAuthScheme.signer(), option.build())", SelectedAuthScheme.class); code.endControlFlow(); @@ -836,7 +664,7 @@ private CodeBlock copyS3ExpressEndpointSignerPropertiesToAuth() { private MethodSpec setMetricValuesMethod() { MethodSpec.Builder b = MethodSpec.methodBuilder("setMetricValues") - .addModifiers(Modifier.PRIVATE) + .addModifiers(Modifier.PUBLIC, Modifier.STATIC) .addParameter(Endpoint.class, "endpoint") .addParameter(ExecutionAttributes.class, "executionAttributes") .returns(void.class); @@ -848,10 +676,10 @@ private MethodSpec setMetricValuesMethod() { b.endControlFlow(); if (endpointRulesSpecUtils.isS3()) { - b.addStatement("$T.addS3ExpressBusinessMetricIfApplicable(executionAttributes)", + b.addStatement("$T.addS3ExpressBusinessMetricIfApplicable(endpoint, executionAttributes)", ClassName.get("software.amazon.awssdk.services.s3.internal.s3express", "S3ExpressUtils")); } - + return b.build(); } } diff --git a/codegen/src/main/java/software/amazon/awssdk/codegen/poet/rules/EndpointRulesSpecUtils.java b/codegen/src/main/java/software/amazon/awssdk/codegen/poet/rules/EndpointRulesSpecUtils.java index dfcf68b056fd..852b1693c5af 100644 --- a/codegen/src/main/java/software/amazon/awssdk/codegen/poet/rules/EndpointRulesSpecUtils.java +++ b/codegen/src/main/java/software/amazon/awssdk/codegen/poet/rules/EndpointRulesSpecUtils.java @@ -93,6 +93,16 @@ public ClassName resolverInterceptorName() { md.getServiceName() + "ResolveEndpointInterceptor"); } + public ClassName endpointResolverUtilsName() { + Metadata md = intermediateModel.getMetadata(); + return ClassName.get(md.getFullInternalEndpointRulesPackageName(), + md.getServiceName() + "EndpointResolverUtils"); + } + + public ClassName sharedAwsEndpointProviderUtilsName() { + return ClassName.get("software.amazon.awssdk.awscore.endpoints", "AwsEndpointProviderUtils"); + } + public ClassName requestModifierInterceptorName() { Metadata md = intermediateModel.getMetadata(); return ClassName.get(md.getFullInternalEndpointRulesPackageName(), diff --git a/codegen/src/main/java/software/amazon/awssdk/codegen/poet/rules/RequestEndpointInterceptorSpec.java b/codegen/src/main/java/software/amazon/awssdk/codegen/poet/rules/RequestEndpointInterceptorSpec.java deleted file mode 100644 index 63b5133f180e..000000000000 --- a/codegen/src/main/java/software/amazon/awssdk/codegen/poet/rules/RequestEndpointInterceptorSpec.java +++ /dev/null @@ -1,84 +0,0 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). - * You may not use this file except in compliance with the License. - * A copy of the License is located at - * - * http://aws.amazon.com/apache2.0 - * - * or in the "license" file accompanying this file. This file is distributed - * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either - * express or implied. See the License for the specific language governing - * permissions and limitations under the License. - */ - -package software.amazon.awssdk.codegen.poet.rules; - -import com.squareup.javapoet.ClassName; -import com.squareup.javapoet.MethodSpec; -import com.squareup.javapoet.TypeSpec; -import javax.lang.model.element.Modifier; -import software.amazon.awssdk.annotations.SdkInternalApi; -import software.amazon.awssdk.codegen.model.intermediate.IntermediateModel; -import software.amazon.awssdk.codegen.poet.ClassSpec; -import software.amazon.awssdk.codegen.poet.PoetUtils; -import software.amazon.awssdk.core.interceptor.Context; -import software.amazon.awssdk.core.interceptor.ExecutionAttributes; -import software.amazon.awssdk.core.interceptor.ExecutionInterceptor; -import software.amazon.awssdk.core.interceptor.SdkInternalExecutionAttribute; -import software.amazon.awssdk.endpoints.Endpoint; -import software.amazon.awssdk.http.SdkHttpRequest; - -public class RequestEndpointInterceptorSpec implements ClassSpec { - private final EndpointRulesSpecUtils endpointRulesSpecUtils; - - public RequestEndpointInterceptorSpec(IntermediateModel model) { - this.endpointRulesSpecUtils = new EndpointRulesSpecUtils(model); - } - - @Override - public TypeSpec poetSpec() { - TypeSpec.Builder b = PoetUtils.createClassBuilder(className()) - .addModifiers(Modifier.PUBLIC, Modifier.FINAL) - .addAnnotation(SdkInternalApi.class) - .addSuperinterface(ExecutionInterceptor.class); - - b.addMethod(modifyHttpRequestMethod()); - - return b.build(); - } - - @Override - public ClassName className() { - return endpointRulesSpecUtils.requestModifierInterceptorName(); - } - - private MethodSpec modifyHttpRequestMethod() { - - MethodSpec.Builder b = MethodSpec.methodBuilder("modifyHttpRequest") - .addModifiers(Modifier.PUBLIC) - .addAnnotation(Override.class) - .returns(SdkHttpRequest.class) - .addParameter(Context.ModifyHttpRequest.class, "context") - .addParameter(ExecutionAttributes.class, "executionAttributes"); - - - // We skip setting the endpoint here if the source of the endpoint is the endpoint discovery call - b.beginControlFlow("if ($1T.endpointIsDiscovered(executionAttributes))", - endpointRulesSpecUtils.rulesRuntimeClassName("AwsEndpointProviderUtils")) - .addStatement("return context.httpRequest()") - .endControlFlow().build(); - - b.addStatement("$T endpoint = executionAttributes.getAttribute($T.RESOLVED_ENDPOINT)", - Endpoint.class, - SdkInternalExecutionAttribute.class); - b.addStatement("return $T.setUri(context.httpRequest()," - + "executionAttributes.getAttribute($T.CLIENT_ENDPOINT_PROVIDER).clientEndpoint()," - + "endpoint.url())", - endpointRulesSpecUtils.rulesRuntimeClassName("AwsEndpointProviderUtils"), - SdkInternalExecutionAttribute.class); - return b.build(); - } - -} diff --git a/codegen/src/test/java/software/amazon/awssdk/codegen/poet/auth/scheme/AuthSchemeSpecTest.java b/codegen/src/test/java/software/amazon/awssdk/codegen/poet/auth/scheme/AuthSchemeSpecTest.java index 6b095f01ddc6..067d5d1673ed 100644 --- a/codegen/src/test/java/software/amazon/awssdk/codegen/poet/auth/scheme/AuthSchemeSpecTest.java +++ b/codegen/src/test/java/software/amazon/awssdk/codegen/poet/auth/scheme/AuthSchemeSpecTest.java @@ -187,28 +187,6 @@ static List parameters() { .caseName("auth-with-legacy-trait") .outputFileSuffix("default-provider") .build(), - // Interceptors - // - Normal case - TestCase.builder() - .modelProvider(ClientTestModels::queryServiceModels) - .classSpecProvider(AuthSchemeInterceptorSpec::new) - .caseName("query") - .outputFileSuffix("interceptor") - .build(), - // - Endpoints based params with allow list - TestCase.builder() - .modelProvider(ClientTestModels::queryServiceModelsEndpointAuthParamsWithAllowList) - .classSpecProvider(AuthSchemeInterceptorSpec::new) - .caseName("query-endpoint-auth-params-with-allowlist") - .outputFileSuffix("interceptor") - .build(), - // - Endpoints based params without allow list - TestCase.builder() - .modelProvider(ClientTestModels::queryServiceModelsEndpointAuthParamsWithoutAllowList) - .classSpecProvider(AuthSchemeInterceptorSpec::new) - .caseName("query-endpoint-auth-params-without-allowlist") - .outputFileSuffix("interceptor") - .build(), // Service with auth trait with Sigv4a TestCase.builder() .modelProvider(ClientTestModels::opsWithSigv4a) @@ -228,19 +206,6 @@ static List parameters() { .caseName("ops-auth-sigv4a-value") .outputFileSuffix("default-params") .build(), - TestCase.builder() - .modelProvider(ClientTestModels::opsWithSigv4a) - .classSpecProvider(AuthSchemeInterceptorSpec::new) - .caseName("ops-auth-sigv4a-value") - .outputFileSuffix("interceptor") - .build(), - // service with environment bearer token enabled - TestCase.builder() - .modelProvider(ClientTestModels::envBearerTokenServiceModels) - .classSpecProvider(AuthSchemeInterceptorSpec::new) - .caseName("env-bearer-token") - .outputFileSuffix("interceptor") - .build(), // Rest Json service with checksum TestCase.builder() .modelProvider(ClientTestModels::restJsonServiceModels) diff --git a/codegen/src/test/java/software/amazon/awssdk/codegen/poet/rules/EndpointResolverInterceptorSpecTest.java b/codegen/src/test/java/software/amazon/awssdk/codegen/poet/rules/EndpointResolverInterceptorSpecTest.java deleted file mode 100644 index 89f87ff41952..000000000000 --- a/codegen/src/test/java/software/amazon/awssdk/codegen/poet/rules/EndpointResolverInterceptorSpecTest.java +++ /dev/null @@ -1,55 +0,0 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). - * You may not use this file except in compliance with the License. - * A copy of the License is located at - * - * http://aws.amazon.com/apache2.0 - * - * or in the "license" file accompanying this file. This file is distributed - * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either - * express or implied. See the License for the specific language governing - * permissions and limitations under the License. - */ - -package software.amazon.awssdk.codegen.poet.rules; - -import static org.hamcrest.MatcherAssert.assertThat; -import static software.amazon.awssdk.codegen.poet.PoetMatchers.generatesTo; - -import org.junit.jupiter.api.Test; -import software.amazon.awssdk.codegen.model.intermediate.IntermediateModel; -import software.amazon.awssdk.codegen.poet.ClassSpec; -import software.amazon.awssdk.codegen.poet.ClientTestModels; - -public class EndpointResolverInterceptorSpecTest { - @Test - public void endpointResolverInterceptorClass() { - ClassSpec endpointProviderInterceptor = new EndpointResolverInterceptorSpec(getModel()); - assertThat(endpointProviderInterceptor, generatesTo("endpoint-resolve-interceptor.java")); - } - - private static IntermediateModel getModel() { - IntermediateModel model = ClientTestModels.queryServiceModels(); - return model; - } - - @Test - void endpointResolverInterceptorClassWithSigv4aMultiAuth() { - ClassSpec endpointProviderInterceptor = new EndpointResolverInterceptorSpec(ClientTestModels.opsWithSigv4a()); - assertThat(endpointProviderInterceptor, generatesTo("endpoint-resolve-interceptor-with-multiauthsigv4a.java")); - } - - @Test - void endpointResolverInterceptorClassWithEndpointBasedAuth() { - ClassSpec endpointProviderInterceptor = new EndpointResolverInterceptorSpec(ClientTestModels.queryServiceModelsEndpointAuthParamsWithoutAllowList()); - assertThat(endpointProviderInterceptor, generatesTo("endpoint-resolve-interceptor-with-endpointsbasedauth.java")); - } - - @Test - public void endpointProviderTestClassWithStringArray() { - ClassSpec endpointProviderInterceptor = new EndpointResolverInterceptorSpec(ClientTestModels.stringArrayServiceModels()); - assertThat(endpointProviderInterceptor, generatesTo("endpoint-resolve-interceptor-with-stringarray.java")); - } -} diff --git a/codegen/src/test/java/software/amazon/awssdk/codegen/poet/rules/EndpointResolverUtilsSpecTest.java b/codegen/src/test/java/software/amazon/awssdk/codegen/poet/rules/EndpointResolverUtilsSpecTest.java new file mode 100644 index 000000000000..ba70047a40db --- /dev/null +++ b/codegen/src/test/java/software/amazon/awssdk/codegen/poet/rules/EndpointResolverUtilsSpecTest.java @@ -0,0 +1,51 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.codegen.poet.rules; + +import static org.hamcrest.MatcherAssert.assertThat; +import static software.amazon.awssdk.codegen.poet.PoetMatchers.generatesTo; + +import org.junit.jupiter.api.Test; +import software.amazon.awssdk.codegen.poet.ClassSpec; +import software.amazon.awssdk.codegen.poet.ClientTestModels; + +public class EndpointResolverUtilsSpecTest { + + @Test + void endpointResolverUtilsClass() { + ClassSpec spec = new EndpointResolverUtilsSpec(ClientTestModels.queryServiceModels()); + assertThat(spec, generatesTo("endpoint-resolver-utils.java")); + } + + @Test + void endpointResolverUtilsClassWithSigv4aMultiAuth() { + ClassSpec spec = new EndpointResolverUtilsSpec(ClientTestModels.opsWithSigv4a()); + assertThat(spec, generatesTo("endpoint-resolver-utils-with-multiauthsigv4a.java")); + } + + @Test + void endpointResolverUtilsClassWithEndpointBasedAuth() { + ClassSpec spec = new EndpointResolverUtilsSpec( + ClientTestModels.queryServiceModelsEndpointAuthParamsWithoutAllowList()); + assertThat(spec, generatesTo("endpoint-resolver-utils-with-endpointsbasedauth.java")); + } + + @Test + void endpointResolverUtilsClassWithStringArray() { + ClassSpec spec = new EndpointResolverUtilsSpec(ClientTestModels.stringArrayServiceModels()); + assertThat(spec, generatesTo("endpoint-resolver-utils-with-stringarray.java")); + } +} diff --git a/codegen/src/test/java/software/amazon/awssdk/codegen/poet/rules/RequestEndpointProviderInterceptorSpecTest.java b/codegen/src/test/java/software/amazon/awssdk/codegen/poet/rules/RequestEndpointProviderInterceptorSpecTest.java deleted file mode 100644 index f4f2434340c9..000000000000 --- a/codegen/src/test/java/software/amazon/awssdk/codegen/poet/rules/RequestEndpointProviderInterceptorSpecTest.java +++ /dev/null @@ -1,31 +0,0 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). - * You may not use this file except in compliance with the License. - * A copy of the License is located at - * - * http://aws.amazon.com/apache2.0 - * - * or in the "license" file accompanying this file. This file is distributed - * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either - * express or implied. See the License for the specific language governing - * permissions and limitations under the License. - */ - -package software.amazon.awssdk.codegen.poet.rules; - -import static org.hamcrest.MatcherAssert.assertThat; -import static software.amazon.awssdk.codegen.poet.PoetMatchers.generatesTo; - -import org.junit.jupiter.api.Test; -import software.amazon.awssdk.codegen.poet.ClassSpec; -import software.amazon.awssdk.codegen.poet.ClientTestModels; - -public class RequestEndpointProviderInterceptorSpecTest { - @Test - public void requestSetEndpointInterceptorClass() { - ClassSpec endpointProviderInterceptor = new RequestEndpointInterceptorSpec(ClientTestModels.queryServiceModels()); - assertThat(endpointProviderInterceptor, generatesTo("request-set-endpoint-interceptor.java")); - } -} diff --git a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/auth/scheme/query-endpoint-auth-params-auth-scheme-params-with-allowlist.java b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/auth/scheme/query-endpoint-auth-params-auth-scheme-params-with-allowlist.java index 7fb0ebe41072..4ac1f03ed9e3 100644 --- a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/auth/scheme/query-endpoint-auth-params-auth-scheme-params-with-allowlist.java +++ b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/auth/scheme/query-endpoint-auth-params-auth-scheme-params-with-allowlist.java @@ -20,6 +20,7 @@ import software.amazon.awssdk.http.auth.aws.signer.RegionSet; import software.amazon.awssdk.regions.Region; import software.amazon.awssdk.services.query.auth.scheme.internal.DefaultQueryAuthSchemeParams; +import software.amazon.awssdk.services.query.endpoints.QueryEndpointParams; import software.amazon.awssdk.utils.builder.CopyableBuilder; import software.amazon.awssdk.utils.builder.ToCopyableBuilder; @@ -36,6 +37,25 @@ static Builder builder() { return DefaultQueryAuthSchemeParams.builder(); } + /** + * Create a builder pre-populated with endpoint parameters. + * + * @param endpointParams + * the endpoint parameters to copy + * @return a builder with values from the endpoint parameters + */ + static Builder fromEndpointParams(QueryEndpointParams endpointParams) { + Builder builder = builder(); + builder.region(endpointParams.region()); + builder.defaultTrueParam(endpointParams.defaultTrueParam()); + builder.defaultStringParam(endpointParams.defaultStringParam()); + builder.deprecatedParam(endpointParams.deprecatedParam()); + builder.booleanContextParam(endpointParams.booleanContextParam()); + builder.stringContextParam(endpointParams.stringContextParam()); + builder.operationContextParam(endpointParams.operationContextParam()); + return builder; + } + /** * Returns the operation for which to resolve the auth scheme. */ diff --git a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/auth/scheme/query-endpoint-auth-params-auth-scheme-params-without-allowlist.java b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/auth/scheme/query-endpoint-auth-params-auth-scheme-params-without-allowlist.java index e5dd1d988daf..b583bd4c120d 100644 --- a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/auth/scheme/query-endpoint-auth-params-auth-scheme-params-without-allowlist.java +++ b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/auth/scheme/query-endpoint-auth-params-auth-scheme-params-without-allowlist.java @@ -6,6 +6,7 @@ import software.amazon.awssdk.http.auth.aws.signer.RegionSet; import software.amazon.awssdk.regions.Region; import software.amazon.awssdk.services.query.auth.scheme.internal.DefaultQueryAuthSchemeParams; +import software.amazon.awssdk.services.query.endpoints.QueryEndpointParams; import software.amazon.awssdk.utils.builder.CopyableBuilder; import software.amazon.awssdk.utils.builder.ToCopyableBuilder; @@ -22,6 +23,32 @@ static Builder builder() { return DefaultQueryAuthSchemeParams.builder(); } + /** + * Create a builder pre-populated with endpoint parameters. + * + * @param endpointParams + * the endpoint parameters to copy + * @return a builder with values from the endpoint parameters + */ + static Builder fromEndpointParams(QueryEndpointParams endpointParams) { + Builder builder = builder(); + builder.region(endpointParams.region()); + builder.useDualStackEndpoint(endpointParams.useDualStackEndpoint()); + builder.useFipsEndpoint(endpointParams.useFipsEndpoint()); + builder.accountId(endpointParams.accountId()); + builder.accountIdEndpointMode(endpointParams.accountIdEndpointMode()); + builder.listOfStrings(endpointParams.listOfStrings()); + builder.defaultListOfStrings(endpointParams.defaultListOfStrings()); + builder.endpointId(endpointParams.endpointId()); + builder.defaultTrueParam(endpointParams.defaultTrueParam()); + builder.defaultStringParam(endpointParams.defaultStringParam()); + builder.deprecatedParam(endpointParams.deprecatedParam()); + builder.booleanContextParam(endpointParams.booleanContextParam()); + builder.stringContextParam(endpointParams.stringContextParam()); + builder.operationContextParam(endpointParams.operationContextParam()); + return builder; + } + /** * Returns the operation for which to resolve the auth scheme. */ diff --git a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/builder/test-bearer-auth-client-builder-class.java b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/builder/test-bearer-auth-client-builder-class.java index ee8f9a73d3e5..1b7d802857d9 100644 --- a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/builder/test-bearer-auth-client-builder-class.java +++ b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/builder/test-bearer-auth-client-builder-class.java @@ -32,10 +32,7 @@ import software.amazon.awssdk.regions.ServiceMetadataAdvancedOption; import software.amazon.awssdk.retries.api.RetryStrategy; import software.amazon.awssdk.services.json.auth.scheme.JsonAuthSchemeProvider; -import software.amazon.awssdk.services.json.auth.scheme.internal.JsonAuthSchemeInterceptor; import software.amazon.awssdk.services.json.endpoints.JsonEndpointProvider; -import software.amazon.awssdk.services.json.endpoints.internal.JsonRequestSetEndpointInterceptor; -import software.amazon.awssdk.services.json.endpoints.internal.JsonResolveEndpointInterceptor; import software.amazon.awssdk.services.json.internal.JsonServiceClientConfigurationBuilder; import software.amazon.awssdk.utils.CollectionUtils; import software.amazon.awssdk.utils.Validate; @@ -74,9 +71,6 @@ protected final SdkClientConfiguration mergeServiceDefaults(SdkClientConfigurati @Override protected final SdkClientConfiguration finalizeServiceConfiguration(SdkClientConfiguration config) { List endpointInterceptors = new ArrayList<>(); - endpointInterceptors.add(new JsonAuthSchemeInterceptor()); - endpointInterceptors.add(new JsonResolveEndpointInterceptor()); - endpointInterceptors.add(new JsonRequestSetEndpointInterceptor()); ClasspathInterceptorChainFactory interceptorFactory = new ClasspathInterceptorChainFactory(); List interceptors = interceptorFactory .getInterceptors("software/amazon/awssdk/services/json/execution.interceptors"); diff --git a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/builder/test-client-builder-class.java b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/builder/test-client-builder-class.java index a0bdac67d04d..cdba2ca0888c 100644 --- a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/builder/test-client-builder-class.java +++ b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/builder/test-client-builder-class.java @@ -41,11 +41,8 @@ import software.amazon.awssdk.regions.ServiceMetadataAdvancedOption; import software.amazon.awssdk.retries.api.RetryStrategy; import software.amazon.awssdk.services.json.auth.scheme.JsonAuthSchemeProvider; -import software.amazon.awssdk.services.json.auth.scheme.internal.JsonAuthSchemeInterceptor; import software.amazon.awssdk.services.json.endpoints.JsonClientContextParams; import software.amazon.awssdk.services.json.endpoints.JsonEndpointProvider; -import software.amazon.awssdk.services.json.endpoints.internal.JsonRequestSetEndpointInterceptor; -import software.amazon.awssdk.services.json.endpoints.internal.JsonResolveEndpointInterceptor; import software.amazon.awssdk.services.json.internal.JsonServiceClientConfigurationBuilder; import software.amazon.awssdk.utils.AttributeMap; import software.amazon.awssdk.utils.CollectionUtils; @@ -86,9 +83,6 @@ protected final SdkClientConfiguration mergeServiceDefaults(SdkClientConfigurati @Override protected final SdkClientConfiguration finalizeServiceConfiguration(SdkClientConfiguration config) { List endpointInterceptors = new ArrayList<>(); - endpointInterceptors.add(new JsonAuthSchemeInterceptor()); - endpointInterceptors.add(new JsonResolveEndpointInterceptor()); - endpointInterceptors.add(new JsonRequestSetEndpointInterceptor()); ClasspathInterceptorChainFactory interceptorFactory = new ClasspathInterceptorChainFactory(); List interceptors = interceptorFactory .getInterceptors("software/amazon/awssdk/services/json/execution.interceptors"); diff --git a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/builder/test-client-builder-endpoints-auth-params.java b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/builder/test-client-builder-endpoints-auth-params.java index 360d3664eaad..0fe341a73bca 100644 --- a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/builder/test-client-builder-endpoints-auth-params.java +++ b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/builder/test-client-builder-endpoints-auth-params.java @@ -40,11 +40,8 @@ import software.amazon.awssdk.regions.ServiceMetadataAdvancedOption; import software.amazon.awssdk.retries.api.RetryStrategy; import software.amazon.awssdk.services.query.auth.scheme.QueryAuthSchemeProvider; -import software.amazon.awssdk.services.query.auth.scheme.internal.QueryAuthSchemeInterceptor; import software.amazon.awssdk.services.query.endpoints.QueryClientContextParams; import software.amazon.awssdk.services.query.endpoints.QueryEndpointProvider; -import software.amazon.awssdk.services.query.endpoints.internal.QueryRequestSetEndpointInterceptor; -import software.amazon.awssdk.services.query.endpoints.internal.QueryResolveEndpointInterceptor; import software.amazon.awssdk.services.query.internal.QueryServiceClientConfigurationBuilder; import software.amazon.awssdk.utils.CollectionUtils; import software.amazon.awssdk.utils.Validate; @@ -83,9 +80,6 @@ protected final SdkClientConfiguration mergeServiceDefaults(SdkClientConfigurati @Override protected final SdkClientConfiguration finalizeServiceConfiguration(SdkClientConfiguration config) { List endpointInterceptors = new ArrayList<>(); - endpointInterceptors.add(new QueryAuthSchemeInterceptor()); - endpointInterceptors.add(new QueryResolveEndpointInterceptor()); - endpointInterceptors.add(new QueryRequestSetEndpointInterceptor()); ClasspathInterceptorChainFactory interceptorFactory = new ClasspathInterceptorChainFactory(); List interceptors = interceptorFactory .getInterceptors("software/amazon/awssdk/services/query/execution.interceptors"); diff --git a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/builder/test-client-builder-internal-defaults-class.java b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/builder/test-client-builder-internal-defaults-class.java index 5ecb09b82593..e5d1cf583c58 100644 --- a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/builder/test-client-builder-internal-defaults-class.java +++ b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/builder/test-client-builder-internal-defaults-class.java @@ -29,10 +29,7 @@ import software.amazon.awssdk.regions.ServiceMetadataAdvancedOption; import software.amazon.awssdk.retries.api.RetryStrategy; import software.amazon.awssdk.services.json.auth.scheme.JsonAuthSchemeProvider; -import software.amazon.awssdk.services.json.auth.scheme.internal.JsonAuthSchemeInterceptor; import software.amazon.awssdk.services.json.endpoints.JsonEndpointProvider; -import software.amazon.awssdk.services.json.endpoints.internal.JsonRequestSetEndpointInterceptor; -import software.amazon.awssdk.services.json.endpoints.internal.JsonResolveEndpointInterceptor; import software.amazon.awssdk.services.json.internal.JsonServiceClientConfigurationBuilder; import software.amazon.awssdk.utils.CollectionUtils; @@ -76,9 +73,6 @@ protected final SdkClientConfiguration mergeInternalDefaults(SdkClientConfigurat @Override protected final SdkClientConfiguration finalizeServiceConfiguration(SdkClientConfiguration config) { List endpointInterceptors = new ArrayList<>(); - endpointInterceptors.add(new JsonAuthSchemeInterceptor()); - endpointInterceptors.add(new JsonResolveEndpointInterceptor()); - endpointInterceptors.add(new JsonRequestSetEndpointInterceptor()); ClasspathInterceptorChainFactory interceptorFactory = new ClasspathInterceptorChainFactory(); List interceptors = interceptorFactory .getInterceptors("software/amazon/awssdk/services/json/execution.interceptors"); diff --git a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/builder/test-composed-sync-default-client-builder.java b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/builder/test-composed-sync-default-client-builder.java index 117e19038881..168c15a64e31 100644 --- a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/builder/test-composed-sync-default-client-builder.java +++ b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/builder/test-composed-sync-default-client-builder.java @@ -37,11 +37,8 @@ import software.amazon.awssdk.regions.ServiceMetadataAdvancedOption; import software.amazon.awssdk.retries.api.RetryStrategy; import software.amazon.awssdk.services.json.auth.scheme.JsonAuthSchemeProvider; -import software.amazon.awssdk.services.json.auth.scheme.internal.JsonAuthSchemeInterceptor; import software.amazon.awssdk.services.json.endpoints.JsonClientContextParams; import software.amazon.awssdk.services.json.endpoints.JsonEndpointProvider; -import software.amazon.awssdk.services.json.endpoints.internal.JsonRequestSetEndpointInterceptor; -import software.amazon.awssdk.services.json.endpoints.internal.JsonResolveEndpointInterceptor; import software.amazon.awssdk.services.json.internal.JsonServiceClientConfigurationBuilder; import software.amazon.awssdk.utils.CollectionUtils; import software.amazon.awssdk.utils.Validate; @@ -81,9 +78,6 @@ protected final SdkClientConfiguration mergeServiceDefaults(SdkClientConfigurati @Override protected final SdkClientConfiguration finalizeServiceConfiguration(SdkClientConfiguration config) { List endpointInterceptors = new ArrayList<>(); - endpointInterceptors.add(new JsonAuthSchemeInterceptor()); - endpointInterceptors.add(new JsonResolveEndpointInterceptor()); - endpointInterceptors.add(new JsonRequestSetEndpointInterceptor()); ClasspathInterceptorChainFactory interceptorFactory = new ClasspathInterceptorChainFactory(); List interceptors = interceptorFactory .getInterceptors("software/amazon/awssdk/services/json/execution.interceptors"); diff --git a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/builder/test-env-bearer-token-client-builder-class.java b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/builder/test-env-bearer-token-client-builder-class.java index 48ecf08535fa..615f6a76d1a5 100644 --- a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/builder/test-env-bearer-token-client-builder-class.java +++ b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/builder/test-env-bearer-token-client-builder-class.java @@ -36,10 +36,7 @@ import software.amazon.awssdk.regions.ServiceMetadataAdvancedOption; import software.amazon.awssdk.retries.api.RetryStrategy; import software.amazon.awssdk.services.json.auth.scheme.JsonAuthSchemeProvider; -import software.amazon.awssdk.services.json.auth.scheme.internal.JsonAuthSchemeInterceptor; import software.amazon.awssdk.services.json.endpoints.JsonEndpointProvider; -import software.amazon.awssdk.services.json.endpoints.internal.JsonRequestSetEndpointInterceptor; -import software.amazon.awssdk.services.json.endpoints.internal.JsonResolveEndpointInterceptor; import software.amazon.awssdk.services.json.internal.EnvironmentTokenSystemSettings; import software.amazon.awssdk.services.json.internal.JsonServiceClientConfigurationBuilder; import software.amazon.awssdk.utils.CollectionUtils; @@ -91,9 +88,6 @@ protected final SdkClientConfiguration mergeServiceDefaults(SdkClientConfigurati @Override protected final SdkClientConfiguration finalizeServiceConfiguration(SdkClientConfiguration config) { List endpointInterceptors = new ArrayList<>(); - endpointInterceptors.add(new JsonAuthSchemeInterceptor()); - endpointInterceptors.add(new JsonResolveEndpointInterceptor()); - endpointInterceptors.add(new JsonRequestSetEndpointInterceptor()); ClasspathInterceptorChainFactory interceptorFactory = new ClasspathInterceptorChainFactory(); List interceptors = interceptorFactory .getInterceptors("software/amazon/awssdk/services/json/execution.interceptors"); diff --git a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/builder/test-h2-service-client-builder-class.java b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/builder/test-h2-service-client-builder-class.java index 76c1cd2fc7eb..2e5198dac4b6 100644 --- a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/builder/test-h2-service-client-builder-class.java +++ b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/builder/test-h2-service-client-builder-class.java @@ -32,10 +32,7 @@ import software.amazon.awssdk.regions.ServiceMetadataAdvancedOption; import software.amazon.awssdk.retries.api.RetryStrategy; import software.amazon.awssdk.services.h2.auth.scheme.H2AuthSchemeProvider; -import software.amazon.awssdk.services.h2.auth.scheme.internal.H2AuthSchemeInterceptor; import software.amazon.awssdk.services.h2.endpoints.H2EndpointProvider; -import software.amazon.awssdk.services.h2.endpoints.internal.H2RequestSetEndpointInterceptor; -import software.amazon.awssdk.services.h2.endpoints.internal.H2ResolveEndpointInterceptor; import software.amazon.awssdk.services.h2.internal.H2ServiceClientConfigurationBuilder; import software.amazon.awssdk.utils.AttributeMap; import software.amazon.awssdk.utils.CollectionUtils; @@ -71,9 +68,6 @@ protected final SdkClientConfiguration mergeServiceDefaults(SdkClientConfigurati @Override protected final SdkClientConfiguration finalizeServiceConfiguration(SdkClientConfiguration config) { List endpointInterceptors = new ArrayList<>(); - endpointInterceptors.add(new H2AuthSchemeInterceptor()); - endpointInterceptors.add(new H2ResolveEndpointInterceptor()); - endpointInterceptors.add(new H2RequestSetEndpointInterceptor()); ClasspathInterceptorChainFactory interceptorFactory = new ClasspathInterceptorChainFactory(); List interceptors = interceptorFactory .getInterceptors("software/amazon/awssdk/services/h2/execution.interceptors"); diff --git a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/builder/test-h2-usePriorKnowledgeForH2-service-client-builder-class.java b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/builder/test-h2-usePriorKnowledgeForH2-service-client-builder-class.java index d3fc4996de98..80d554e50d1f 100644 --- a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/builder/test-h2-usePriorKnowledgeForH2-service-client-builder-class.java +++ b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/builder/test-h2-usePriorKnowledgeForH2-service-client-builder-class.java @@ -31,10 +31,7 @@ import software.amazon.awssdk.regions.ServiceMetadataAdvancedOption; import software.amazon.awssdk.retries.api.RetryStrategy; import software.amazon.awssdk.services.h2.auth.scheme.H2AuthSchemeProvider; -import software.amazon.awssdk.services.h2.auth.scheme.internal.H2AuthSchemeInterceptor; import software.amazon.awssdk.services.h2.endpoints.H2EndpointProvider; -import software.amazon.awssdk.services.h2.endpoints.internal.H2RequestSetEndpointInterceptor; -import software.amazon.awssdk.services.h2.endpoints.internal.H2ResolveEndpointInterceptor; import software.amazon.awssdk.services.h2.internal.H2ServiceClientConfigurationBuilder; import software.amazon.awssdk.utils.AttributeMap; import software.amazon.awssdk.utils.CollectionUtils; @@ -70,9 +67,6 @@ protected final SdkClientConfiguration mergeServiceDefaults(SdkClientConfigurati @Override protected final SdkClientConfiguration finalizeServiceConfiguration(SdkClientConfiguration config) { List endpointInterceptors = new ArrayList<>(); - endpointInterceptors.add(new H2AuthSchemeInterceptor()); - endpointInterceptors.add(new H2ResolveEndpointInterceptor()); - endpointInterceptors.add(new H2RequestSetEndpointInterceptor()); ClasspathInterceptorChainFactory interceptorFactory = new ClasspathInterceptorChainFactory(); List interceptors = interceptorFactory .getInterceptors("software/amazon/awssdk/services/h2/execution.interceptors"); diff --git a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/builder/test-multi-auth-sigv4a-client-builder-class.java b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/builder/test-multi-auth-sigv4a-client-builder-class.java index 75faf2cad7a8..fcc5f72b4e25 100644 --- a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/builder/test-multi-auth-sigv4a-client-builder-class.java +++ b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/builder/test-multi-auth-sigv4a-client-builder-class.java @@ -31,10 +31,7 @@ import software.amazon.awssdk.regions.ServiceMetadataAdvancedOption; import software.amazon.awssdk.retries.api.RetryStrategy; import software.amazon.awssdk.services.database.auth.scheme.DatabaseAuthSchemeProvider; -import software.amazon.awssdk.services.database.auth.scheme.internal.DatabaseAuthSchemeInterceptor; import software.amazon.awssdk.services.database.endpoints.DatabaseEndpointProvider; -import software.amazon.awssdk.services.database.endpoints.internal.DatabaseRequestSetEndpointInterceptor; -import software.amazon.awssdk.services.database.endpoints.internal.DatabaseResolveEndpointInterceptor; import software.amazon.awssdk.services.database.internal.DatabaseServiceClientConfigurationBuilder; import software.amazon.awssdk.utils.CollectionUtils; @@ -70,9 +67,6 @@ protected final SdkClientConfiguration mergeServiceDefaults(SdkClientConfigurati @Override protected final SdkClientConfiguration finalizeServiceConfiguration(SdkClientConfiguration config) { List endpointInterceptors = new ArrayList<>(); - endpointInterceptors.add(new DatabaseAuthSchemeInterceptor()); - endpointInterceptors.add(new DatabaseResolveEndpointInterceptor()); - endpointInterceptors.add(new DatabaseRequestSetEndpointInterceptor()); ClasspathInterceptorChainFactory interceptorFactory = new ClasspathInterceptorChainFactory(); List interceptors = interceptorFactory .getInterceptors("software/amazon/awssdk/services/database/execution.interceptors"); diff --git a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/builder/test-no-auth-ops-client-builder-class.java b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/builder/test-no-auth-ops-client-builder-class.java index 72d4f526bfb3..5da93144892e 100644 --- a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/builder/test-no-auth-ops-client-builder-class.java +++ b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/builder/test-no-auth-ops-client-builder-class.java @@ -33,10 +33,7 @@ import software.amazon.awssdk.regions.ServiceMetadataAdvancedOption; import software.amazon.awssdk.retries.api.RetryStrategy; import software.amazon.awssdk.services.database.auth.scheme.DatabaseAuthSchemeProvider; -import software.amazon.awssdk.services.database.auth.scheme.internal.DatabaseAuthSchemeInterceptor; import software.amazon.awssdk.services.database.endpoints.DatabaseEndpointProvider; -import software.amazon.awssdk.services.database.endpoints.internal.DatabaseRequestSetEndpointInterceptor; -import software.amazon.awssdk.services.database.endpoints.internal.DatabaseResolveEndpointInterceptor; import software.amazon.awssdk.services.database.internal.DatabaseServiceClientConfigurationBuilder; import software.amazon.awssdk.utils.CollectionUtils; import software.amazon.awssdk.utils.Validate; @@ -76,9 +73,6 @@ protected final SdkClientConfiguration mergeServiceDefaults(SdkClientConfigurati @Override protected final SdkClientConfiguration finalizeServiceConfiguration(SdkClientConfiguration config) { List endpointInterceptors = new ArrayList<>(); - endpointInterceptors.add(new DatabaseAuthSchemeInterceptor()); - endpointInterceptors.add(new DatabaseResolveEndpointInterceptor()); - endpointInterceptors.add(new DatabaseRequestSetEndpointInterceptor()); ClasspathInterceptorChainFactory interceptorFactory = new ClasspathInterceptorChainFactory(); List interceptors = interceptorFactory .getInterceptors("software/amazon/awssdk/services/database/execution.interceptors"); diff --git a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/builder/test-no-auth-service-client-builder-class.java b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/builder/test-no-auth-service-client-builder-class.java index 0be9c031d828..05f0b72afa2b 100644 --- a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/builder/test-no-auth-service-client-builder-class.java +++ b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/builder/test-no-auth-service-client-builder-class.java @@ -27,10 +27,7 @@ import software.amazon.awssdk.regions.ServiceMetadataAdvancedOption; import software.amazon.awssdk.retries.api.RetryStrategy; import software.amazon.awssdk.services.database.auth.scheme.DatabaseAuthSchemeProvider; -import software.amazon.awssdk.services.database.auth.scheme.internal.DatabaseAuthSchemeInterceptor; import software.amazon.awssdk.services.database.endpoints.DatabaseEndpointProvider; -import software.amazon.awssdk.services.database.endpoints.internal.DatabaseRequestSetEndpointInterceptor; -import software.amazon.awssdk.services.database.endpoints.internal.DatabaseResolveEndpointInterceptor; import software.amazon.awssdk.services.database.internal.DatabaseServiceClientConfigurationBuilder; import software.amazon.awssdk.utils.CollectionUtils; @@ -66,9 +63,6 @@ protected final SdkClientConfiguration mergeServiceDefaults(SdkClientConfigurati @Override protected final SdkClientConfiguration finalizeServiceConfiguration(SdkClientConfiguration config) { List endpointInterceptors = new ArrayList<>(); - endpointInterceptors.add(new DatabaseAuthSchemeInterceptor()); - endpointInterceptors.add(new DatabaseResolveEndpointInterceptor()); - endpointInterceptors.add(new DatabaseRequestSetEndpointInterceptor()); ClasspathInterceptorChainFactory interceptorFactory = new ClasspathInterceptorChainFactory(); List interceptors = interceptorFactory .getInterceptors("software/amazon/awssdk/services/database/execution.interceptors"); diff --git a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/builder/test-query-client-builder-class.java b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/builder/test-query-client-builder-class.java index 19b8d5abbae1..db0dae50266e 100644 --- a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/builder/test-query-client-builder-class.java +++ b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/builder/test-query-client-builder-class.java @@ -38,11 +38,8 @@ import software.amazon.awssdk.regions.ServiceMetadataAdvancedOption; import software.amazon.awssdk.retries.api.RetryStrategy; import software.amazon.awssdk.services.query.auth.scheme.QueryAuthSchemeProvider; -import software.amazon.awssdk.services.query.auth.scheme.internal.QueryAuthSchemeInterceptor; import software.amazon.awssdk.services.query.endpoints.QueryClientContextParams; import software.amazon.awssdk.services.query.endpoints.QueryEndpointProvider; -import software.amazon.awssdk.services.query.endpoints.internal.QueryRequestSetEndpointInterceptor; -import software.amazon.awssdk.services.query.endpoints.internal.QueryResolveEndpointInterceptor; import software.amazon.awssdk.services.query.internal.QueryServiceClientConfigurationBuilder; import software.amazon.awssdk.utils.CollectionUtils; import software.amazon.awssdk.utils.Validate; @@ -81,9 +78,6 @@ protected final SdkClientConfiguration mergeServiceDefaults(SdkClientConfigurati @Override protected final SdkClientConfiguration finalizeServiceConfiguration(SdkClientConfiguration config) { List endpointInterceptors = new ArrayList<>(); - endpointInterceptors.add(new QueryAuthSchemeInterceptor()); - endpointInterceptors.add(new QueryResolveEndpointInterceptor()); - endpointInterceptors.add(new QueryRequestSetEndpointInterceptor()); ClasspathInterceptorChainFactory interceptorFactory = new ClasspathInterceptorChainFactory(); List interceptors = interceptorFactory .getInterceptors("software/amazon/awssdk/services/query/execution.interceptors"); diff --git a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-aws-json-async-client-class.java b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-aws-json-async-client-class.java index a7d94b064cd8..361fa3c4eaf2 100644 --- a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-aws-json-async-client-class.java +++ b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-aws-json-async-client-class.java @@ -7,6 +7,7 @@ import java.util.List; import java.util.Optional; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionException; import java.util.concurrent.Executor; import java.util.function.Consumer; import java.util.function.Function; @@ -15,8 +16,12 @@ import org.slf4j.LoggerFactory; import software.amazon.awssdk.annotations.Generated; import software.amazon.awssdk.annotations.SdkInternalApi; +import software.amazon.awssdk.awscore.client.config.AwsClientOption; import software.amazon.awssdk.awscore.client.handler.AwsAsyncClientHandler; import software.amazon.awssdk.awscore.client.handler.AwsClientHandlerUtils; +import software.amazon.awssdk.awscore.endpoints.AwsEndpointAttribute; +import software.amazon.awssdk.awscore.endpoints.AwsEndpointProviderUtils; +import software.amazon.awssdk.awscore.endpoints.authscheme.EndpointAuthScheme; import software.amazon.awssdk.awscore.eventstream.EventStreamAsyncResponseTransformer; import software.amazon.awssdk.awscore.eventstream.EventStreamTaggedUnionJsonMarshaller; import software.amazon.awssdk.awscore.eventstream.EventStreamTaggedUnionPojoSupplier; @@ -29,6 +34,7 @@ import software.amazon.awssdk.core.SdkPojoBuilder; import software.amazon.awssdk.core.SdkRequest; import software.amazon.awssdk.core.SdkResponse; +import software.amazon.awssdk.core.SelectedAuthScheme; import software.amazon.awssdk.core.async.AsyncRequestBody; import software.amazon.awssdk.core.async.AsyncResponseTransformer; import software.amazon.awssdk.core.async.AsyncResponseTransformerUtils; @@ -40,7 +46,9 @@ import software.amazon.awssdk.core.client.handler.AsyncClientHandler; import software.amazon.awssdk.core.client.handler.AttachHttpMetadataResponseHandler; import software.amazon.awssdk.core.client.handler.ClientExecutionParams; +import software.amazon.awssdk.core.exception.SdkClientException; import software.amazon.awssdk.core.http.HttpResponseHandler; +import software.amazon.awssdk.core.interceptor.ExecutionAttributes; import software.amazon.awssdk.core.interceptor.SdkInternalExecutionAttribute; import software.amazon.awssdk.core.interceptor.trait.HttpChecksumRequired; import software.amazon.awssdk.core.internal.interceptor.trait.RequestCompression; @@ -48,6 +56,8 @@ import software.amazon.awssdk.core.protocol.VoidSdkResponse; import software.amazon.awssdk.core.retry.RetryMode; import software.amazon.awssdk.core.runtime.transform.AsyncStreamingRequestMarshaller; +import software.amazon.awssdk.endpoints.Endpoint; +import software.amazon.awssdk.http.auth.spi.scheme.AuthSchemeOption; import software.amazon.awssdk.metrics.MetricCollector; import software.amazon.awssdk.metrics.MetricPublisher; import software.amazon.awssdk.metrics.NoOpMetricCollector; @@ -57,6 +67,11 @@ import software.amazon.awssdk.protocols.json.BaseAwsJsonProtocolFactory; import software.amazon.awssdk.protocols.json.JsonOperationMetadata; import software.amazon.awssdk.retries.api.RetryStrategy; +import software.amazon.awssdk.services.json.auth.scheme.JsonAuthSchemeParams; +import software.amazon.awssdk.services.json.auth.scheme.JsonAuthSchemeProvider; +import software.amazon.awssdk.services.json.endpoints.JsonEndpointParams; +import software.amazon.awssdk.services.json.endpoints.JsonEndpointProvider; +import software.amazon.awssdk.services.json.endpoints.internal.JsonEndpointResolverUtils; import software.amazon.awssdk.services.json.internal.JsonServiceClientConfigurationBuilder; import software.amazon.awssdk.services.json.internal.ServiceVersionInfo; import software.amazon.awssdk.services.json.model.APostOperationRequest; @@ -118,6 +133,7 @@ import software.amazon.awssdk.utils.CompletableFutureUtils; import software.amazon.awssdk.utils.HostnameValidator; import software.amazon.awssdk.utils.Pair; +import software.amazon.awssdk.utils.Validate; /** * Internal implementation of {@link JsonAsyncClient}. @@ -216,10 +232,16 @@ public CompletableFuture aPostOperation(APostOperationRe CompletableFuture executeFuture = clientHandler .execute(new ClientExecutionParams() - .withOperationName("APostOperation").withProtocolMetadata(protocolMetadata) + .withOperationName("APostOperation") + .withProtocolMetadata(protocolMetadata) .withMarshaller(new APostOperationRequestMarshaller(protocolFactory)) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withMetricCollector(apiCallMetricCollector) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "APostOperation", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "APostOperation")) .hostPrefixExpression(resolvedHostExpression).withInput(aPostOperationRequest)); CompletableFuture whenCompleted = executeFuture.whenComplete((r, e) -> { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -294,10 +316,16 @@ public CompletableFuture aPostOperationWithOut CompletableFuture executeFuture = clientHandler .execute(new ClientExecutionParams() - .withOperationName("APostOperationWithOutput").withProtocolMetadata(protocolMetadata) + .withOperationName("APostOperationWithOutput") + .withProtocolMetadata(protocolMetadata) .withMarshaller(new APostOperationWithOutputRequestMarshaller(protocolFactory)) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withMetricCollector(apiCallMetricCollector) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "APostOperationWithOutput", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "APostOperationWithOutput")) .withInput(aPostOperationWithOutputRequest)); CompletableFuture whenCompleted = executeFuture.whenComplete((r, e) -> { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -400,13 +428,20 @@ public CompletableFuture eventStreamOperation(EventStreamOperationRequest CompletableFuture executeFuture = clientHandler.execute( new ClientExecutionParams() - .withOperationName("EventStreamOperation").withProtocolMetadata(protocolMetadata) + .withOperationName("EventStreamOperation") + .withProtocolMetadata(protocolMetadata) .withMarshaller(new EventStreamOperationRequestMarshaller(protocolFactory)) - .withAsyncRequestBody(AsyncRequestBody.fromPublisher(adapted)).withFullDuplex(true) - .withInitialRequestEvent(true).withResponseHandler(voidResponseHandler) - .withErrorResponseHandler(errorResponseHandler).withRequestConfiguration(clientConfiguration) - .withMetricCollector(apiCallMetricCollector).withInput(eventStreamOperationRequest), - asyncResponseTransformer); + .withAsyncRequestBody(AsyncRequestBody.fromPublisher(adapted)) + .withFullDuplex(true) + .withInitialRequestEvent(true) + .withResponseHandler(voidResponseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "EventStreamOperation", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "EventStreamOperation")) + .withInput(eventStreamOperationRequest), asyncResponseTransformer); CompletableFuture whenCompleted = executeFuture.whenComplete((r, e) -> { if (e != null) { try { @@ -492,11 +527,18 @@ public CompletableFuture eventStreamO CompletableFuture executeFuture = clientHandler .execute(new ClientExecutionParams() - .withOperationName("EventStreamOperationWithOnlyInput").withProtocolMetadata(protocolMetadata) + .withOperationName("EventStreamOperationWithOnlyInput") + .withProtocolMetadata(protocolMetadata) .withMarshaller(new EventStreamOperationWithOnlyInputRequestMarshaller(protocolFactory)) - .withAsyncRequestBody(AsyncRequestBody.fromPublisher(adapted)).withInitialRequestEvent(true) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withMetricCollector(apiCallMetricCollector) + .withAsyncRequestBody(AsyncRequestBody.fromPublisher(adapted)) + .withInitialRequestEvent(true) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "EventStreamOperationWithOnlyInput", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "EventStreamOperationWithOnlyInput")) .withInput(eventStreamOperationWithOnlyInputRequest)); CompletableFuture whenCompleted = executeFuture.whenComplete((r, e) -> { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -596,10 +638,16 @@ public CompletableFuture eventStreamOperationWithOnlyOutput( CompletableFuture executeFuture = clientHandler.execute( new ClientExecutionParams() - .withOperationName("EventStreamOperationWithOnlyOutput").withProtocolMetadata(protocolMetadata) + .withOperationName("EventStreamOperationWithOnlyOutput") + .withProtocolMetadata(protocolMetadata) .withMarshaller(new EventStreamOperationWithOnlyOutputRequestMarshaller(protocolFactory)) - .withResponseHandler(voidResponseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withMetricCollector(apiCallMetricCollector) + .withResponseHandler(voidResponseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "EventStreamOperationWithOnlyOutput", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "EventStreamOperationWithOnlyOutput")) .withInput(eventStreamOperationWithOnlyOutputRequest), asyncResponseTransformer); CompletableFuture whenCompleted = executeFuture.whenComplete((r, e) -> { if (e != null) { @@ -682,10 +730,16 @@ public CompletableFuture getWithoutRequiredMe CompletableFuture executeFuture = clientHandler .execute(new ClientExecutionParams() - .withOperationName("GetWithoutRequiredMembers").withProtocolMetadata(protocolMetadata) + .withOperationName("GetWithoutRequiredMembers") + .withProtocolMetadata(protocolMetadata) .withMarshaller(new GetWithoutRequiredMembersRequestMarshaller(protocolFactory)) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withMetricCollector(apiCallMetricCollector) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "GetWithoutRequiredMembers", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "GetWithoutRequiredMembers")) .withInput(getWithoutRequiredMembersRequest)); CompletableFuture whenCompleted = executeFuture.whenComplete((r, e) -> { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -763,6 +817,9 @@ public CompletableFuture operationWithChe .withErrorResponseHandler(errorResponseHandler) .withRequestConfiguration(clientConfiguration) .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "OperationWithChecksumRequired", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "OperationWithChecksumRequired")) .putExecutionAttribute(SdkInternalExecutionAttribute.HTTP_CHECKSUM_REQUIRED, HttpChecksumRequired.create()).withInput(operationWithChecksumRequiredRequest)); CompletableFuture whenCompleted = executeFuture.whenComplete((r, e) -> { @@ -833,10 +890,16 @@ public CompletableFuture operationWithNoneAut CompletableFuture executeFuture = clientHandler .execute(new ClientExecutionParams() - .withOperationName("OperationWithNoneAuthType").withProtocolMetadata(protocolMetadata) + .withOperationName("OperationWithNoneAuthType") + .withProtocolMetadata(protocolMetadata) .withMarshaller(new OperationWithNoneAuthTypeRequestMarshaller(protocolFactory)) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withMetricCollector(apiCallMetricCollector) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "OperationWithNoneAuthType", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "OperationWithNoneAuthType")) .withInput(operationWithNoneAuthTypeRequest)); CompletableFuture whenCompleted = executeFuture.whenComplete((r, e) -> { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -914,6 +977,9 @@ public CompletableFuture operationWithR .withErrorResponseHandler(errorResponseHandler) .withRequestConfiguration(clientConfiguration) .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "OperationWithRequestCompression", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "OperationWithRequestCompression")) .putExecutionAttribute(SdkInternalExecutionAttribute.REQUEST_COMPRESSION, RequestCompression.builder().encodings("gzip").isStreaming(false).build()) .withInput(operationWithRequestCompressionRequest)); @@ -986,10 +1052,16 @@ public CompletableFuture paginatedOpera CompletableFuture executeFuture = clientHandler .execute(new ClientExecutionParams() - .withOperationName("PaginatedOperationWithResultKey").withProtocolMetadata(protocolMetadata) + .withOperationName("PaginatedOperationWithResultKey") + .withProtocolMetadata(protocolMetadata) .withMarshaller(new PaginatedOperationWithResultKeyRequestMarshaller(protocolFactory)) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withMetricCollector(apiCallMetricCollector) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "PaginatedOperationWithResultKey", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "PaginatedOperationWithResultKey")) .withInput(paginatedOperationWithResultKeyRequest)); CompletableFuture whenCompleted = executeFuture.whenComplete((r, e) -> { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -1060,10 +1132,16 @@ public CompletableFuture paginatedOp CompletableFuture executeFuture = clientHandler .execute(new ClientExecutionParams() - .withOperationName("PaginatedOperationWithoutResultKey").withProtocolMetadata(protocolMetadata) + .withOperationName("PaginatedOperationWithoutResultKey") + .withProtocolMetadata(protocolMetadata) .withMarshaller(new PaginatedOperationWithoutResultKeyRequestMarshaller(protocolFactory)) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withMetricCollector(apiCallMetricCollector) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "PaginatedOperationWithoutResultKey", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "PaginatedOperationWithoutResultKey")) .withInput(paginatedOperationWithoutResultKeyRequest)); CompletableFuture whenCompleted = executeFuture.whenComplete((r, e) -> { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -1143,10 +1221,15 @@ public CompletableFuture streamingInputOperatio .withMarshaller( AsyncStreamingRequestMarshaller.builder() .delegateMarshaller(new StreamingInputOperationRequestMarshaller(protocolFactory)) - .asyncRequestBody(requestBody).build()).withResponseHandler(responseHandler) - .withErrorResponseHandler(errorResponseHandler).withRequestConfiguration(clientConfiguration) - .withMetricCollector(apiCallMetricCollector).withAsyncRequestBody(requestBody) - .withInput(streamingInputOperationRequest)); + .asyncRequestBody(requestBody).build()) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "StreamingInputOperation", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "StreamingInputOperation")) + .withAsyncRequestBody(requestBody).withInput(streamingInputOperationRequest)); CompletableFuture whenCompleted = executeFuture.whenComplete((r, e) -> { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); }); @@ -1238,8 +1321,13 @@ public CompletableFuture streamingInputOutputOperation( .delegateMarshaller( new StreamingInputOutputOperationRequestMarshaller(protocolFactory)) .asyncRequestBody(requestBody).transferEncoding(true).build()) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withMetricCollector(apiCallMetricCollector) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "StreamingInputOutputOperation", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "StreamingInputOutputOperation")) .withAsyncRequestBody(requestBody).withAsyncResponseTransformer(asyncResponseTransformer) .withInput(streamingInputOutputOperationRequest), asyncResponseTransformer); AsyncResponseTransformer finalAsyncResponseTransformer = asyncResponseTransformer; @@ -1330,10 +1418,16 @@ public CompletableFuture streamingOutputOperation( CompletableFuture executeFuture = clientHandler.execute( new ClientExecutionParams() - .withOperationName("StreamingOutputOperation").withProtocolMetadata(protocolMetadata) + .withOperationName("StreamingOutputOperation") + .withProtocolMetadata(protocolMetadata) .withMarshaller(new StreamingOutputOperationRequestMarshaller(protocolFactory)) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withMetricCollector(apiCallMetricCollector) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "StreamingOutputOperation", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "StreamingOutputOperation")) .withAsyncResponseTransformer(asyncResponseTransformer).withInput(streamingOutputOperationRequest), asyncResponseTransformer); AsyncResponseTransformer finalAsyncResponseTransformer = asyncResponseTransformer; @@ -1387,6 +1481,54 @@ private static List resolveMetricPublishers(SdkClientConfigurat return publishers; } + private List resolveAuthSchemeOptions(SdkRequest request, String operationName, + SdkClientConfiguration clientConfiguration) { + JsonAuthSchemeProvider requestAuthSchemeProvider = request + .overrideConfiguration() + .flatMap(c -> c.authSchemeProvider()) + .map(p -> Validate + .isInstanceOf(JsonAuthSchemeProvider.class, p, "Expected an instance of JsonAuthSchemeProvider")) + .orElse(null); + JsonAuthSchemeProvider authSchemeProvider = requestAuthSchemeProvider != null ? requestAuthSchemeProvider : Validate + .isInstanceOf(JsonAuthSchemeProvider.class, clientConfiguration.option(SdkClientOption.AUTH_SCHEME_PROVIDER), + "Expected an instance of JsonAuthSchemeProvider"); + JsonAuthSchemeParams.Builder paramsBuilder = JsonAuthSchemeParams.builder().operation(operationName); + paramsBuilder.region(clientConfiguration.option(AwsClientOption.AWS_REGION)); + List options = authSchemeProvider.resolveAuthScheme(paramsBuilder.build()); + return options; + } + + private Endpoint resolveEndpoint(SdkRequest request, ExecutionAttributes executionAttributes, String operationName) { + JsonEndpointProvider provider = (JsonEndpointProvider) executionAttributes + .getAttribute(SdkInternalExecutionAttribute.ENDPOINT_PROVIDER); + try { + JsonEndpointParams endpointParams = JsonEndpointResolverUtils.ruleParams(request, executionAttributes); + Endpoint endpoint = provider.resolveEndpoint(endpointParams).join(); + if (!AwsEndpointProviderUtils.disableHostPrefixInjection(executionAttributes)) { + Optional hostPrefix = JsonEndpointResolverUtils.hostPrefix(operationName, request); + if (hostPrefix.isPresent()) { + endpoint = AwsEndpointProviderUtils.addHostPrefix(endpoint, hostPrefix.get()); + } + } + List endpointAuthSchemes = endpoint.attribute(AwsEndpointAttribute.AUTH_SCHEMES); + SelectedAuthScheme selectedAuthScheme = executionAttributes + .getAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME); + if (endpointAuthSchemes != null && selectedAuthScheme != null) { + selectedAuthScheme = JsonEndpointResolverUtils.authSchemeWithEndpointSignerProperties(endpointAuthSchemes, + selectedAuthScheme); + executionAttributes.putAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME, selectedAuthScheme); + } + JsonEndpointResolverUtils.setMetricValues(endpoint, executionAttributes); + return endpoint; + } catch (CompletionException e) { + Throwable cause = e.getCause(); + if (cause instanceof SdkClientException) { + throw (SdkClientException) cause; + } + throw SdkClientException.create("Endpoint resolution failed: " + cause.getMessage(), cause); + } + } + private void updateRetryStrategyClientConfiguration(SdkClientConfiguration.Builder configuration) { ClientOverrideConfiguration.Builder builder = configuration.asOverrideConfigurationBuilder(); RetryMode retryMode = builder.retryMode(); diff --git a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-aws-query-compatible-json-async-client-class.java b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-aws-query-compatible-json-async-client-class.java index 96f890a789a1..3c18fc5229d3 100644 --- a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-aws-query-compatible-json-async-client-class.java +++ b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-aws-query-compatible-json-async-client-class.java @@ -6,13 +6,18 @@ import java.util.List; import java.util.Optional; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionException; import java.util.function.Consumer; import java.util.function.Function; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import software.amazon.awssdk.annotations.Generated; import software.amazon.awssdk.annotations.SdkInternalApi; +import software.amazon.awssdk.awscore.client.config.AwsClientOption; import software.amazon.awssdk.awscore.client.handler.AwsAsyncClientHandler; +import software.amazon.awssdk.awscore.endpoints.AwsEndpointAttribute; +import software.amazon.awssdk.awscore.endpoints.AwsEndpointProviderUtils; +import software.amazon.awssdk.awscore.endpoints.authscheme.EndpointAuthScheme; import software.amazon.awssdk.awscore.exception.AwsServiceException; import software.amazon.awssdk.awscore.internal.AwsProtocolMetadata; import software.amazon.awssdk.awscore.internal.AwsServiceProtocol; @@ -20,14 +25,20 @@ import software.amazon.awssdk.core.RequestOverrideConfiguration; import software.amazon.awssdk.core.SdkPlugin; import software.amazon.awssdk.core.SdkRequest; +import software.amazon.awssdk.core.SelectedAuthScheme; import software.amazon.awssdk.core.client.config.ClientOverrideConfiguration; import software.amazon.awssdk.core.client.config.SdkClientConfiguration; import software.amazon.awssdk.core.client.config.SdkClientOption; import software.amazon.awssdk.core.client.handler.AsyncClientHandler; import software.amazon.awssdk.core.client.handler.ClientExecutionParams; +import software.amazon.awssdk.core.exception.SdkClientException; import software.amazon.awssdk.core.http.HttpResponseHandler; +import software.amazon.awssdk.core.interceptor.ExecutionAttributes; +import software.amazon.awssdk.core.interceptor.SdkInternalExecutionAttribute; import software.amazon.awssdk.core.metrics.CoreMetric; import software.amazon.awssdk.core.retry.RetryMode; +import software.amazon.awssdk.endpoints.Endpoint; +import software.amazon.awssdk.http.auth.spi.scheme.AuthSchemeOption; import software.amazon.awssdk.metrics.MetricCollector; import software.amazon.awssdk.metrics.MetricPublisher; import software.amazon.awssdk.metrics.NoOpMetricCollector; @@ -37,6 +48,11 @@ import software.amazon.awssdk.protocols.json.BaseAwsJsonProtocolFactory; import software.amazon.awssdk.protocols.json.JsonOperationMetadata; import software.amazon.awssdk.retries.api.RetryStrategy; +import software.amazon.awssdk.services.querytojsoncompatible.auth.scheme.QueryToJsonCompatibleAuthSchemeParams; +import software.amazon.awssdk.services.querytojsoncompatible.auth.scheme.QueryToJsonCompatibleAuthSchemeProvider; +import software.amazon.awssdk.services.querytojsoncompatible.endpoints.QueryToJsonCompatibleEndpointParams; +import software.amazon.awssdk.services.querytojsoncompatible.endpoints.QueryToJsonCompatibleEndpointProvider; +import software.amazon.awssdk.services.querytojsoncompatible.endpoints.internal.QueryToJsonCompatibleEndpointResolverUtils; import software.amazon.awssdk.services.querytojsoncompatible.internal.QueryToJsonCompatibleServiceClientConfigurationBuilder; import software.amazon.awssdk.services.querytojsoncompatible.internal.ServiceVersionInfo; import software.amazon.awssdk.services.querytojsoncompatible.model.APostOperationRequest; @@ -46,6 +62,7 @@ import software.amazon.awssdk.services.querytojsoncompatible.transform.APostOperationRequestMarshaller; import software.amazon.awssdk.utils.CompletableFutureUtils; import software.amazon.awssdk.utils.HostnameValidator; +import software.amazon.awssdk.utils.Validate; /** * Internal implementation of {@link QueryToJsonCompatibleAsyncClient}. @@ -133,10 +150,16 @@ public CompletableFuture aPostOperation(APostOperationRe CompletableFuture executeFuture = clientHandler .execute(new ClientExecutionParams() - .withOperationName("APostOperation").withProtocolMetadata(protocolMetadata) + .withOperationName("APostOperation") + .withProtocolMetadata(protocolMetadata) .withMarshaller(new APostOperationRequestMarshaller(protocolFactory)) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withMetricCollector(apiCallMetricCollector) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "APostOperation", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "APostOperation")) .hostPrefixExpression(resolvedHostExpression).withInput(aPostOperationRequest)); CompletableFuture whenCompleted = executeFuture.whenComplete((r, e) -> { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -180,6 +203,56 @@ private static List resolveMetricPublishers(SdkClientConfigurat return publishers; } + private List resolveAuthSchemeOptions(SdkRequest request, String operationName, + SdkClientConfiguration clientConfiguration) { + QueryToJsonCompatibleAuthSchemeProvider requestAuthSchemeProvider = request + .overrideConfiguration() + .flatMap(c -> c.authSchemeProvider()) + .map(p -> Validate.isInstanceOf(QueryToJsonCompatibleAuthSchemeProvider.class, p, + "Expected an instance of QueryToJsonCompatibleAuthSchemeProvider")).orElse(null); + QueryToJsonCompatibleAuthSchemeProvider authSchemeProvider = requestAuthSchemeProvider != null ? requestAuthSchemeProvider + : Validate.isInstanceOf(QueryToJsonCompatibleAuthSchemeProvider.class, + clientConfiguration.option(SdkClientOption.AUTH_SCHEME_PROVIDER), + "Expected an instance of QueryToJsonCompatibleAuthSchemeProvider"); + QueryToJsonCompatibleAuthSchemeParams.Builder paramsBuilder = QueryToJsonCompatibleAuthSchemeParams.builder().operation( + operationName); + paramsBuilder.region(clientConfiguration.option(AwsClientOption.AWS_REGION)); + List options = authSchemeProvider.resolveAuthScheme(paramsBuilder.build()); + return options; + } + + private Endpoint resolveEndpoint(SdkRequest request, ExecutionAttributes executionAttributes, String operationName) { + QueryToJsonCompatibleEndpointProvider provider = (QueryToJsonCompatibleEndpointProvider) executionAttributes + .getAttribute(SdkInternalExecutionAttribute.ENDPOINT_PROVIDER); + try { + QueryToJsonCompatibleEndpointParams endpointParams = QueryToJsonCompatibleEndpointResolverUtils.ruleParams(request, + executionAttributes); + Endpoint endpoint = provider.resolveEndpoint(endpointParams).join(); + if (!AwsEndpointProviderUtils.disableHostPrefixInjection(executionAttributes)) { + Optional hostPrefix = QueryToJsonCompatibleEndpointResolverUtils.hostPrefix(operationName, request); + if (hostPrefix.isPresent()) { + endpoint = AwsEndpointProviderUtils.addHostPrefix(endpoint, hostPrefix.get()); + } + } + List endpointAuthSchemes = endpoint.attribute(AwsEndpointAttribute.AUTH_SCHEMES); + SelectedAuthScheme selectedAuthScheme = executionAttributes + .getAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME); + if (endpointAuthSchemes != null && selectedAuthScheme != null) { + selectedAuthScheme = QueryToJsonCompatibleEndpointResolverUtils.authSchemeWithEndpointSignerProperties( + endpointAuthSchemes, selectedAuthScheme); + executionAttributes.putAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME, selectedAuthScheme); + } + QueryToJsonCompatibleEndpointResolverUtils.setMetricValues(endpoint, executionAttributes); + return endpoint; + } catch (CompletionException e) { + Throwable cause = e.getCause(); + if (cause instanceof SdkClientException) { + throw (SdkClientException) cause; + } + throw SdkClientException.create("Endpoint resolution failed: " + cause.getMessage(), cause); + } + } + private void updateRetryStrategyClientConfiguration(SdkClientConfiguration.Builder configuration) { ClientOverrideConfiguration.Builder builder = configuration.asOverrideConfigurationBuilder(); RetryMode retryMode = builder.retryMode(); diff --git a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-aws-query-compatible-json-sync-client-class.java b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-aws-query-compatible-json-sync-client-class.java index d4fb640000aa..851eb27f0f35 100644 --- a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-aws-query-compatible-json-sync-client-class.java +++ b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-aws-query-compatible-json-sync-client-class.java @@ -3,11 +3,16 @@ import java.util.Collections; import java.util.List; import java.util.Optional; +import java.util.concurrent.CompletionException; import java.util.function.Consumer; import java.util.function.Function; import software.amazon.awssdk.annotations.Generated; import software.amazon.awssdk.annotations.SdkInternalApi; +import software.amazon.awssdk.awscore.client.config.AwsClientOption; import software.amazon.awssdk.awscore.client.handler.AwsSyncClientHandler; +import software.amazon.awssdk.awscore.endpoints.AwsEndpointAttribute; +import software.amazon.awssdk.awscore.endpoints.AwsEndpointProviderUtils; +import software.amazon.awssdk.awscore.endpoints.authscheme.EndpointAuthScheme; import software.amazon.awssdk.awscore.exception.AwsServiceException; import software.amazon.awssdk.awscore.internal.AwsProtocolMetadata; import software.amazon.awssdk.awscore.internal.AwsServiceProtocol; @@ -15,6 +20,7 @@ import software.amazon.awssdk.core.RequestOverrideConfiguration; import software.amazon.awssdk.core.SdkPlugin; import software.amazon.awssdk.core.SdkRequest; +import software.amazon.awssdk.core.SelectedAuthScheme; import software.amazon.awssdk.core.client.config.ClientOverrideConfiguration; import software.amazon.awssdk.core.client.config.SdkClientConfiguration; import software.amazon.awssdk.core.client.config.SdkClientOption; @@ -22,8 +28,12 @@ import software.amazon.awssdk.core.client.handler.SyncClientHandler; import software.amazon.awssdk.core.exception.SdkClientException; import software.amazon.awssdk.core.http.HttpResponseHandler; +import software.amazon.awssdk.core.interceptor.ExecutionAttributes; +import software.amazon.awssdk.core.interceptor.SdkInternalExecutionAttribute; import software.amazon.awssdk.core.metrics.CoreMetric; import software.amazon.awssdk.core.retry.RetryMode; +import software.amazon.awssdk.endpoints.Endpoint; +import software.amazon.awssdk.http.auth.spi.scheme.AuthSchemeOption; import software.amazon.awssdk.metrics.MetricCollector; import software.amazon.awssdk.metrics.MetricPublisher; import software.amazon.awssdk.metrics.NoOpMetricCollector; @@ -33,6 +43,11 @@ import software.amazon.awssdk.protocols.json.BaseAwsJsonProtocolFactory; import software.amazon.awssdk.protocols.json.JsonOperationMetadata; import software.amazon.awssdk.retries.api.RetryStrategy; +import software.amazon.awssdk.services.querytojsoncompatible.auth.scheme.QueryToJsonCompatibleAuthSchemeParams; +import software.amazon.awssdk.services.querytojsoncompatible.auth.scheme.QueryToJsonCompatibleAuthSchemeProvider; +import software.amazon.awssdk.services.querytojsoncompatible.endpoints.QueryToJsonCompatibleEndpointParams; +import software.amazon.awssdk.services.querytojsoncompatible.endpoints.QueryToJsonCompatibleEndpointProvider; +import software.amazon.awssdk.services.querytojsoncompatible.endpoints.internal.QueryToJsonCompatibleEndpointResolverUtils; import software.amazon.awssdk.services.querytojsoncompatible.internal.QueryToJsonCompatibleServiceClientConfigurationBuilder; import software.amazon.awssdk.services.querytojsoncompatible.internal.ServiceVersionInfo; import software.amazon.awssdk.services.querytojsoncompatible.model.APostOperationRequest; @@ -42,6 +57,7 @@ import software.amazon.awssdk.services.querytojsoncompatible.transform.APostOperationRequestMarshaller; import software.amazon.awssdk.utils.HostnameValidator; import software.amazon.awssdk.utils.Logger; +import software.amazon.awssdk.utils.Validate; /** * Internal implementation of {@link QueryToJsonCompatibleClient}. @@ -129,6 +145,8 @@ public APostOperationResponse aPostOperation(APostOperationRequest aPostOperatio .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) .hostPrefixExpression(resolvedHostExpression).withRequestConfiguration(clientConfiguration) .withInput(aPostOperationRequest).withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver(r -> resolveAuthSchemeOptions(r, "APostOperation", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "APostOperation")) .withMarshaller(new APostOperationRequestMarshaller(protocolFactory))); } finally { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -155,6 +173,56 @@ private static List resolveMetricPublishers(SdkClientConfigurat return publishers; } + private List resolveAuthSchemeOptions(SdkRequest request, String operationName, + SdkClientConfiguration clientConfiguration) { + QueryToJsonCompatibleAuthSchemeProvider requestAuthSchemeProvider = request + .overrideConfiguration() + .flatMap(c -> c.authSchemeProvider()) + .map(p -> Validate.isInstanceOf(QueryToJsonCompatibleAuthSchemeProvider.class, p, + "Expected an instance of QueryToJsonCompatibleAuthSchemeProvider")).orElse(null); + QueryToJsonCompatibleAuthSchemeProvider authSchemeProvider = requestAuthSchemeProvider != null ? requestAuthSchemeProvider + : Validate.isInstanceOf(QueryToJsonCompatibleAuthSchemeProvider.class, + clientConfiguration.option(SdkClientOption.AUTH_SCHEME_PROVIDER), + "Expected an instance of QueryToJsonCompatibleAuthSchemeProvider"); + QueryToJsonCompatibleAuthSchemeParams.Builder paramsBuilder = QueryToJsonCompatibleAuthSchemeParams.builder().operation( + operationName); + paramsBuilder.region(clientConfiguration.option(AwsClientOption.AWS_REGION)); + List options = authSchemeProvider.resolveAuthScheme(paramsBuilder.build()); + return options; + } + + private Endpoint resolveEndpoint(SdkRequest request, ExecutionAttributes executionAttributes, String operationName) { + QueryToJsonCompatibleEndpointProvider provider = (QueryToJsonCompatibleEndpointProvider) executionAttributes + .getAttribute(SdkInternalExecutionAttribute.ENDPOINT_PROVIDER); + try { + QueryToJsonCompatibleEndpointParams endpointParams = QueryToJsonCompatibleEndpointResolverUtils.ruleParams(request, + executionAttributes); + Endpoint endpoint = provider.resolveEndpoint(endpointParams).join(); + if (!AwsEndpointProviderUtils.disableHostPrefixInjection(executionAttributes)) { + Optional hostPrefix = QueryToJsonCompatibleEndpointResolverUtils.hostPrefix(operationName, request); + if (hostPrefix.isPresent()) { + endpoint = AwsEndpointProviderUtils.addHostPrefix(endpoint, hostPrefix.get()); + } + } + List endpointAuthSchemes = endpoint.attribute(AwsEndpointAttribute.AUTH_SCHEMES); + SelectedAuthScheme selectedAuthScheme = executionAttributes + .getAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME); + if (endpointAuthSchemes != null && selectedAuthScheme != null) { + selectedAuthScheme = QueryToJsonCompatibleEndpointResolverUtils.authSchemeWithEndpointSignerProperties( + endpointAuthSchemes, selectedAuthScheme); + executionAttributes.putAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME, selectedAuthScheme); + } + QueryToJsonCompatibleEndpointResolverUtils.setMetricValues(endpoint, executionAttributes); + return endpoint; + } catch (CompletionException e) { + Throwable cause = e.getCause(); + if (cause instanceof SdkClientException) { + throw (SdkClientException) cause; + } + throw SdkClientException.create("Endpoint resolution failed: " + cause.getMessage(), cause); + } + } + private HttpResponseHandler createErrorResponseHandler(BaseAwsJsonProtocolFactory protocolFactory, JsonOperationMetadata operationMetadata, Function> exceptionMetadataMapper) { return protocolFactory.createErrorResponseHandler(operationMetadata, exceptionMetadataMapper); diff --git a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-batchmanager-async.java b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-batchmanager-async.java index 1db4ed5a574f..fd0a2f5e8ed3 100644 --- a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-batchmanager-async.java +++ b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-batchmanager-async.java @@ -6,6 +6,7 @@ import java.util.List; import java.util.Optional; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionException; import java.util.concurrent.ScheduledExecutorService; import java.util.function.Consumer; import java.util.function.Function; @@ -13,7 +14,11 @@ import org.slf4j.LoggerFactory; import software.amazon.awssdk.annotations.Generated; import software.amazon.awssdk.annotations.SdkInternalApi; +import software.amazon.awssdk.awscore.client.config.AwsClientOption; import software.amazon.awssdk.awscore.client.handler.AwsAsyncClientHandler; +import software.amazon.awssdk.awscore.endpoints.AwsEndpointAttribute; +import software.amazon.awssdk.awscore.endpoints.AwsEndpointProviderUtils; +import software.amazon.awssdk.awscore.endpoints.authscheme.EndpointAuthScheme; import software.amazon.awssdk.awscore.exception.AwsServiceException; import software.amazon.awssdk.awscore.internal.AwsProtocolMetadata; import software.amazon.awssdk.awscore.internal.AwsServiceProtocol; @@ -21,14 +26,20 @@ import software.amazon.awssdk.core.RequestOverrideConfiguration; import software.amazon.awssdk.core.SdkPlugin; import software.amazon.awssdk.core.SdkRequest; +import software.amazon.awssdk.core.SelectedAuthScheme; import software.amazon.awssdk.core.client.config.ClientOverrideConfiguration; import software.amazon.awssdk.core.client.config.SdkClientConfiguration; import software.amazon.awssdk.core.client.config.SdkClientOption; import software.amazon.awssdk.core.client.handler.AsyncClientHandler; import software.amazon.awssdk.core.client.handler.ClientExecutionParams; +import software.amazon.awssdk.core.exception.SdkClientException; import software.amazon.awssdk.core.http.HttpResponseHandler; +import software.amazon.awssdk.core.interceptor.ExecutionAttributes; +import software.amazon.awssdk.core.interceptor.SdkInternalExecutionAttribute; import software.amazon.awssdk.core.metrics.CoreMetric; import software.amazon.awssdk.core.retry.RetryMode; +import software.amazon.awssdk.endpoints.Endpoint; +import software.amazon.awssdk.http.auth.spi.scheme.AuthSchemeOption; import software.amazon.awssdk.metrics.MetricCollector; import software.amazon.awssdk.metrics.MetricPublisher; import software.amazon.awssdk.metrics.NoOpMetricCollector; @@ -38,7 +49,12 @@ import software.amazon.awssdk.protocols.json.BaseAwsJsonProtocolFactory; import software.amazon.awssdk.protocols.json.JsonOperationMetadata; import software.amazon.awssdk.retries.api.RetryStrategy; +import software.amazon.awssdk.services.batchmanagertest.auth.scheme.BatchManagerTestAuthSchemeParams; +import software.amazon.awssdk.services.batchmanagertest.auth.scheme.BatchManagerTestAuthSchemeProvider; import software.amazon.awssdk.services.batchmanagertest.batchmanager.BatchManagerTestAsyncBatchManager; +import software.amazon.awssdk.services.batchmanagertest.endpoints.BatchManagerTestEndpointParams; +import software.amazon.awssdk.services.batchmanagertest.endpoints.BatchManagerTestEndpointProvider; +import software.amazon.awssdk.services.batchmanagertest.endpoints.internal.BatchManagerTestEndpointResolverUtils; import software.amazon.awssdk.services.batchmanagertest.internal.BatchManagerTestServiceClientConfigurationBuilder; import software.amazon.awssdk.services.batchmanagertest.internal.ServiceVersionInfo; import software.amazon.awssdk.services.batchmanagertest.model.BatchManagerTestException; @@ -46,6 +62,7 @@ import software.amazon.awssdk.services.batchmanagertest.model.SendRequestResponse; import software.amazon.awssdk.services.batchmanagertest.transform.SendRequestRequestMarshaller; import software.amazon.awssdk.utils.CompletableFutureUtils; +import software.amazon.awssdk.utils.Validate; /** * Internal implementation of {@link BatchManagerTestAsyncClient}. @@ -129,7 +146,8 @@ public CompletableFuture sendRequest(SendRequestRequest sen .withMarshaller(new SendRequestRequestMarshaller(protocolFactory)) .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) .withRequestConfiguration(clientConfiguration).withMetricCollector(apiCallMetricCollector) - .withInput(sendRequestRequest)); + .withAuthSchemeOptionsResolver(r -> resolveAuthSchemeOptions(r, "SendRequest", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "SendRequest")).withInput(sendRequestRequest)); CompletableFuture whenCompleted = executeFuture.whenComplete((r, e) -> { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); }); @@ -177,6 +195,56 @@ private static List resolveMetricPublishers(SdkClientConfigurat return publishers; } + private List resolveAuthSchemeOptions(SdkRequest request, String operationName, + SdkClientConfiguration clientConfiguration) { + BatchManagerTestAuthSchemeProvider requestAuthSchemeProvider = request + .overrideConfiguration() + .flatMap(c -> c.authSchemeProvider()) + .map(p -> Validate.isInstanceOf(BatchManagerTestAuthSchemeProvider.class, p, + "Expected an instance of BatchManagerTestAuthSchemeProvider")).orElse(null); + BatchManagerTestAuthSchemeProvider authSchemeProvider = requestAuthSchemeProvider != null ? requestAuthSchemeProvider + : Validate.isInstanceOf(BatchManagerTestAuthSchemeProvider.class, + clientConfiguration.option(SdkClientOption.AUTH_SCHEME_PROVIDER), + "Expected an instance of BatchManagerTestAuthSchemeProvider"); + BatchManagerTestAuthSchemeParams.Builder paramsBuilder = BatchManagerTestAuthSchemeParams.builder().operation( + operationName); + paramsBuilder.region(clientConfiguration.option(AwsClientOption.AWS_REGION)); + List options = authSchemeProvider.resolveAuthScheme(paramsBuilder.build()); + return options; + } + + private Endpoint resolveEndpoint(SdkRequest request, ExecutionAttributes executionAttributes, String operationName) { + BatchManagerTestEndpointProvider provider = (BatchManagerTestEndpointProvider) executionAttributes + .getAttribute(SdkInternalExecutionAttribute.ENDPOINT_PROVIDER); + try { + BatchManagerTestEndpointParams endpointParams = BatchManagerTestEndpointResolverUtils.ruleParams(request, + executionAttributes); + Endpoint endpoint = provider.resolveEndpoint(endpointParams).join(); + if (!AwsEndpointProviderUtils.disableHostPrefixInjection(executionAttributes)) { + Optional hostPrefix = BatchManagerTestEndpointResolverUtils.hostPrefix(operationName, request); + if (hostPrefix.isPresent()) { + endpoint = AwsEndpointProviderUtils.addHostPrefix(endpoint, hostPrefix.get()); + } + } + List endpointAuthSchemes = endpoint.attribute(AwsEndpointAttribute.AUTH_SCHEMES); + SelectedAuthScheme selectedAuthScheme = executionAttributes + .getAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME); + if (endpointAuthSchemes != null && selectedAuthScheme != null) { + selectedAuthScheme = BatchManagerTestEndpointResolverUtils.authSchemeWithEndpointSignerProperties( + endpointAuthSchemes, selectedAuthScheme); + executionAttributes.putAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME, selectedAuthScheme); + } + BatchManagerTestEndpointResolverUtils.setMetricValues(endpoint, executionAttributes); + return endpoint; + } catch (CompletionException e) { + Throwable cause = e.getCause(); + if (cause instanceof SdkClientException) { + throw (SdkClientException) cause; + } + throw SdkClientException.create("Endpoint resolution failed: " + cause.getMessage(), cause); + } + } + private void updateRetryStrategyClientConfiguration(SdkClientConfiguration.Builder configuration) { ClientOverrideConfiguration.Builder builder = configuration.asOverrideConfigurationBuilder(); RetryMode retryMode = builder.retryMode(); diff --git a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-cbor-async-client-class.java b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-cbor-async-client-class.java index 073ba89daf3e..881fa55f2c88 100644 --- a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-cbor-async-client-class.java +++ b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-cbor-async-client-class.java @@ -7,6 +7,7 @@ import java.util.List; import java.util.Optional; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionException; import java.util.concurrent.Executor; import java.util.function.Consumer; import java.util.function.Function; @@ -15,8 +16,12 @@ import org.slf4j.LoggerFactory; import software.amazon.awssdk.annotations.Generated; import software.amazon.awssdk.annotations.SdkInternalApi; +import software.amazon.awssdk.awscore.client.config.AwsClientOption; import software.amazon.awssdk.awscore.client.handler.AwsAsyncClientHandler; import software.amazon.awssdk.awscore.client.handler.AwsClientHandlerUtils; +import software.amazon.awssdk.awscore.endpoints.AwsEndpointAttribute; +import software.amazon.awssdk.awscore.endpoints.AwsEndpointProviderUtils; +import software.amazon.awssdk.awscore.endpoints.authscheme.EndpointAuthScheme; import software.amazon.awssdk.awscore.eventstream.EventStreamAsyncResponseTransformer; import software.amazon.awssdk.awscore.eventstream.EventStreamTaggedUnionJsonMarshaller; import software.amazon.awssdk.awscore.eventstream.EventStreamTaggedUnionPojoSupplier; @@ -29,6 +34,7 @@ import software.amazon.awssdk.core.SdkPojoBuilder; import software.amazon.awssdk.core.SdkRequest; import software.amazon.awssdk.core.SdkResponse; +import software.amazon.awssdk.core.SelectedAuthScheme; import software.amazon.awssdk.core.async.AsyncRequestBody; import software.amazon.awssdk.core.async.AsyncResponseTransformer; import software.amazon.awssdk.core.async.AsyncResponseTransformerUtils; @@ -40,7 +46,9 @@ import software.amazon.awssdk.core.client.handler.AsyncClientHandler; import software.amazon.awssdk.core.client.handler.AttachHttpMetadataResponseHandler; import software.amazon.awssdk.core.client.handler.ClientExecutionParams; +import software.amazon.awssdk.core.exception.SdkClientException; import software.amazon.awssdk.core.http.HttpResponseHandler; +import software.amazon.awssdk.core.interceptor.ExecutionAttributes; import software.amazon.awssdk.core.interceptor.SdkInternalExecutionAttribute; import software.amazon.awssdk.core.interceptor.trait.HttpChecksumRequired; import software.amazon.awssdk.core.internal.interceptor.trait.RequestCompression; @@ -48,6 +56,8 @@ import software.amazon.awssdk.core.protocol.VoidSdkResponse; import software.amazon.awssdk.core.retry.RetryMode; import software.amazon.awssdk.core.runtime.transform.AsyncStreamingRequestMarshaller; +import software.amazon.awssdk.endpoints.Endpoint; +import software.amazon.awssdk.http.auth.spi.scheme.AuthSchemeOption; import software.amazon.awssdk.metrics.MetricCollector; import software.amazon.awssdk.metrics.MetricPublisher; import software.amazon.awssdk.metrics.NoOpMetricCollector; @@ -58,6 +68,11 @@ import software.amazon.awssdk.protocols.json.BaseAwsJsonProtocolFactory; import software.amazon.awssdk.protocols.json.JsonOperationMetadata; import software.amazon.awssdk.retries.api.RetryStrategy; +import software.amazon.awssdk.services.json.auth.scheme.JsonAuthSchemeParams; +import software.amazon.awssdk.services.json.auth.scheme.JsonAuthSchemeProvider; +import software.amazon.awssdk.services.json.endpoints.JsonEndpointParams; +import software.amazon.awssdk.services.json.endpoints.JsonEndpointProvider; +import software.amazon.awssdk.services.json.endpoints.internal.JsonEndpointResolverUtils; import software.amazon.awssdk.services.json.internal.JsonServiceClientConfigurationBuilder; import software.amazon.awssdk.services.json.internal.ServiceVersionInfo; import software.amazon.awssdk.services.json.model.APostOperationRequest; @@ -119,6 +134,7 @@ import software.amazon.awssdk.utils.CompletableFutureUtils; import software.amazon.awssdk.utils.HostnameValidator; import software.amazon.awssdk.utils.Pair; +import software.amazon.awssdk.utils.Validate; /** * Internal implementation of {@link JsonAsyncClient}. @@ -146,8 +162,7 @@ final class DefaultJsonAsyncClient implements JsonAsyncClient { protected DefaultJsonAsyncClient(SdkClientConfiguration clientConfiguration) { this.clientHandler = new AwsAsyncClientHandler(clientConfiguration); this.clientConfiguration = clientConfiguration.toBuilder().option(SdkClientOption.SDK_CLIENT, this) - .option(SdkClientOption.API_METADATA, - "Json_Service" + "#" + ServiceVersionInfo.VERSION).build(); + .option(SdkClientOption.API_METADATA, "Json_Service" + "#" + ServiceVersionInfo.VERSION).build(); this.protocolFactory = init(AwsCborProtocolFactory.builder()).build(); this.jsonProtocolFactory = init(AwsJsonProtocolFactory.builder()).build(); this.executor = clientConfiguration.option(SdkAdvancedAsyncClientOption.FUTURE_COMPLETION_EXECUTOR); @@ -221,10 +236,16 @@ public CompletableFuture aPostOperation(APostOperationRe CompletableFuture executeFuture = clientHandler .execute(new ClientExecutionParams() - .withOperationName("APostOperation").withProtocolMetadata(protocolMetadata) + .withOperationName("APostOperation") + .withProtocolMetadata(protocolMetadata) .withMarshaller(new APostOperationRequestMarshaller(protocolFactory)) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withMetricCollector(apiCallMetricCollector) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "APostOperation", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "APostOperation")) .hostPrefixExpression(resolvedHostExpression).withInput(aPostOperationRequest)); CompletableFuture whenCompleted = executeFuture.whenComplete((r, e) -> { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -299,10 +320,16 @@ public CompletableFuture aPostOperationWithOut CompletableFuture executeFuture = clientHandler .execute(new ClientExecutionParams() - .withOperationName("APostOperationWithOutput").withProtocolMetadata(protocolMetadata) + .withOperationName("APostOperationWithOutput") + .withProtocolMetadata(protocolMetadata) .withMarshaller(new APostOperationWithOutputRequestMarshaller(protocolFactory)) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withMetricCollector(apiCallMetricCollector) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "APostOperationWithOutput", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "APostOperationWithOutput")) .withInput(aPostOperationWithOutputRequest)); CompletableFuture whenCompleted = executeFuture.whenComplete((r, e) -> { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -405,13 +432,20 @@ public CompletableFuture eventStreamOperation(EventStreamOperationRequest CompletableFuture executeFuture = clientHandler.execute( new ClientExecutionParams() - .withOperationName("EventStreamOperation").withProtocolMetadata(protocolMetadata) + .withOperationName("EventStreamOperation") + .withProtocolMetadata(protocolMetadata) .withMarshaller(new EventStreamOperationRequestMarshaller(protocolFactory)) - .withAsyncRequestBody(AsyncRequestBody.fromPublisher(adapted)).withFullDuplex(true) - .withInitialRequestEvent(true).withResponseHandler(voidResponseHandler) - .withErrorResponseHandler(errorResponseHandler).withRequestConfiguration(clientConfiguration) - .withMetricCollector(apiCallMetricCollector).withInput(eventStreamOperationRequest), - asyncResponseTransformer); + .withAsyncRequestBody(AsyncRequestBody.fromPublisher(adapted)) + .withFullDuplex(true) + .withInitialRequestEvent(true) + .withResponseHandler(voidResponseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "EventStreamOperation", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "EventStreamOperation")) + .withInput(eventStreamOperationRequest), asyncResponseTransformer); CompletableFuture whenCompleted = executeFuture.whenComplete((r, e) -> { if (e != null) { try { @@ -497,11 +531,18 @@ public CompletableFuture eventStreamO CompletableFuture executeFuture = clientHandler .execute(new ClientExecutionParams() - .withOperationName("EventStreamOperationWithOnlyInput").withProtocolMetadata(protocolMetadata) + .withOperationName("EventStreamOperationWithOnlyInput") + .withProtocolMetadata(protocolMetadata) .withMarshaller(new EventStreamOperationWithOnlyInputRequestMarshaller(protocolFactory)) - .withAsyncRequestBody(AsyncRequestBody.fromPublisher(adapted)).withInitialRequestEvent(true) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withMetricCollector(apiCallMetricCollector) + .withAsyncRequestBody(AsyncRequestBody.fromPublisher(adapted)) + .withInitialRequestEvent(true) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "EventStreamOperationWithOnlyInput", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "EventStreamOperationWithOnlyInput")) .withInput(eventStreamOperationWithOnlyInputRequest)); CompletableFuture whenCompleted = executeFuture.whenComplete((r, e) -> { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -601,10 +642,16 @@ public CompletableFuture eventStreamOperationWithOnlyOutput( CompletableFuture executeFuture = clientHandler.execute( new ClientExecutionParams() - .withOperationName("EventStreamOperationWithOnlyOutput").withProtocolMetadata(protocolMetadata) + .withOperationName("EventStreamOperationWithOnlyOutput") + .withProtocolMetadata(protocolMetadata) .withMarshaller(new EventStreamOperationWithOnlyOutputRequestMarshaller(protocolFactory)) - .withResponseHandler(voidResponseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withMetricCollector(apiCallMetricCollector) + .withResponseHandler(voidResponseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "EventStreamOperationWithOnlyOutput", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "EventStreamOperationWithOnlyOutput")) .withInput(eventStreamOperationWithOnlyOutputRequest), asyncResponseTransformer); CompletableFuture whenCompleted = executeFuture.whenComplete((r, e) -> { if (e != null) { @@ -687,10 +734,16 @@ public CompletableFuture getWithoutRequiredMe CompletableFuture executeFuture = clientHandler .execute(new ClientExecutionParams() - .withOperationName("GetWithoutRequiredMembers").withProtocolMetadata(protocolMetadata) + .withOperationName("GetWithoutRequiredMembers") + .withProtocolMetadata(protocolMetadata) .withMarshaller(new GetWithoutRequiredMembersRequestMarshaller(protocolFactory)) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withMetricCollector(apiCallMetricCollector) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "GetWithoutRequiredMembers", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "GetWithoutRequiredMembers")) .withInput(getWithoutRequiredMembersRequest)); CompletableFuture whenCompleted = executeFuture.whenComplete((r, e) -> { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -768,6 +821,9 @@ public CompletableFuture operationWithChe .withErrorResponseHandler(errorResponseHandler) .withRequestConfiguration(clientConfiguration) .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "OperationWithChecksumRequired", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "OperationWithChecksumRequired")) .putExecutionAttribute(SdkInternalExecutionAttribute.HTTP_CHECKSUM_REQUIRED, HttpChecksumRequired.create()).withInput(operationWithChecksumRequiredRequest)); CompletableFuture whenCompleted = executeFuture.whenComplete((r, e) -> { @@ -838,10 +894,16 @@ public CompletableFuture operationWithNoneAut CompletableFuture executeFuture = clientHandler .execute(new ClientExecutionParams() - .withOperationName("OperationWithNoneAuthType").withProtocolMetadata(protocolMetadata) + .withOperationName("OperationWithNoneAuthType") + .withProtocolMetadata(protocolMetadata) .withMarshaller(new OperationWithNoneAuthTypeRequestMarshaller(protocolFactory)) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withMetricCollector(apiCallMetricCollector) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "OperationWithNoneAuthType", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "OperationWithNoneAuthType")) .withInput(operationWithNoneAuthTypeRequest)); CompletableFuture whenCompleted = executeFuture.whenComplete((r, e) -> { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -919,6 +981,9 @@ public CompletableFuture operationWithR .withErrorResponseHandler(errorResponseHandler) .withRequestConfiguration(clientConfiguration) .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "OperationWithRequestCompression", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "OperationWithRequestCompression")) .putExecutionAttribute(SdkInternalExecutionAttribute.REQUEST_COMPRESSION, RequestCompression.builder().encodings("gzip").isStreaming(false).build()) .withInput(operationWithRequestCompressionRequest)); @@ -991,10 +1056,16 @@ public CompletableFuture paginatedOpera CompletableFuture executeFuture = clientHandler .execute(new ClientExecutionParams() - .withOperationName("PaginatedOperationWithResultKey").withProtocolMetadata(protocolMetadata) + .withOperationName("PaginatedOperationWithResultKey") + .withProtocolMetadata(protocolMetadata) .withMarshaller(new PaginatedOperationWithResultKeyRequestMarshaller(protocolFactory)) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withMetricCollector(apiCallMetricCollector) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "PaginatedOperationWithResultKey", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "PaginatedOperationWithResultKey")) .withInput(paginatedOperationWithResultKeyRequest)); CompletableFuture whenCompleted = executeFuture.whenComplete((r, e) -> { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -1065,10 +1136,16 @@ public CompletableFuture paginatedOp CompletableFuture executeFuture = clientHandler .execute(new ClientExecutionParams() - .withOperationName("PaginatedOperationWithoutResultKey").withProtocolMetadata(protocolMetadata) + .withOperationName("PaginatedOperationWithoutResultKey") + .withProtocolMetadata(protocolMetadata) .withMarshaller(new PaginatedOperationWithoutResultKeyRequestMarshaller(protocolFactory)) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withMetricCollector(apiCallMetricCollector) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "PaginatedOperationWithoutResultKey", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "PaginatedOperationWithoutResultKey")) .withInput(paginatedOperationWithoutResultKeyRequest)); CompletableFuture whenCompleted = executeFuture.whenComplete((r, e) -> { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -1148,10 +1225,15 @@ public CompletableFuture streamingInputOperatio .withMarshaller( AsyncStreamingRequestMarshaller.builder() .delegateMarshaller(new StreamingInputOperationRequestMarshaller(protocolFactory)) - .asyncRequestBody(requestBody).build()).withResponseHandler(responseHandler) - .withErrorResponseHandler(errorResponseHandler).withRequestConfiguration(clientConfiguration) - .withMetricCollector(apiCallMetricCollector).withAsyncRequestBody(requestBody) - .withInput(streamingInputOperationRequest)); + .asyncRequestBody(requestBody).build()) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "StreamingInputOperation", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "StreamingInputOperation")) + .withAsyncRequestBody(requestBody).withInput(streamingInputOperationRequest)); CompletableFuture whenCompleted = executeFuture.whenComplete((r, e) -> { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); }); @@ -1243,8 +1325,13 @@ public CompletableFuture streamingInputOutputOperation( .delegateMarshaller( new StreamingInputOutputOperationRequestMarshaller(protocolFactory)) .asyncRequestBody(requestBody).transferEncoding(true).build()) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withMetricCollector(apiCallMetricCollector) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "StreamingInputOutputOperation", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "StreamingInputOutputOperation")) .withAsyncRequestBody(requestBody).withAsyncResponseTransformer(asyncResponseTransformer) .withInput(streamingInputOutputOperationRequest), asyncResponseTransformer); AsyncResponseTransformer finalAsyncResponseTransformer = asyncResponseTransformer; @@ -1335,10 +1422,16 @@ public CompletableFuture streamingOutputOperation( CompletableFuture executeFuture = clientHandler.execute( new ClientExecutionParams() - .withOperationName("StreamingOutputOperation").withProtocolMetadata(protocolMetadata) + .withOperationName("StreamingOutputOperation") + .withProtocolMetadata(protocolMetadata) .withMarshaller(new StreamingOutputOperationRequestMarshaller(protocolFactory)) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withMetricCollector(apiCallMetricCollector) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "StreamingOutputOperation", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "StreamingOutputOperation")) .withAsyncResponseTransformer(asyncResponseTransformer).withInput(streamingOutputOperationRequest), asyncResponseTransformer); AsyncResponseTransformer finalAsyncResponseTransformer = asyncResponseTransformer; @@ -1392,6 +1485,54 @@ private static List resolveMetricPublishers(SdkClientConfigurat return publishers; } + private List resolveAuthSchemeOptions(SdkRequest request, String operationName, + SdkClientConfiguration clientConfiguration) { + JsonAuthSchemeProvider requestAuthSchemeProvider = request + .overrideConfiguration() + .flatMap(c -> c.authSchemeProvider()) + .map(p -> Validate + .isInstanceOf(JsonAuthSchemeProvider.class, p, "Expected an instance of JsonAuthSchemeProvider")) + .orElse(null); + JsonAuthSchemeProvider authSchemeProvider = requestAuthSchemeProvider != null ? requestAuthSchemeProvider : Validate + .isInstanceOf(JsonAuthSchemeProvider.class, clientConfiguration.option(SdkClientOption.AUTH_SCHEME_PROVIDER), + "Expected an instance of JsonAuthSchemeProvider"); + JsonAuthSchemeParams.Builder paramsBuilder = JsonAuthSchemeParams.builder().operation(operationName); + paramsBuilder.region(clientConfiguration.option(AwsClientOption.AWS_REGION)); + List options = authSchemeProvider.resolveAuthScheme(paramsBuilder.build()); + return options; + } + + private Endpoint resolveEndpoint(SdkRequest request, ExecutionAttributes executionAttributes, String operationName) { + JsonEndpointProvider provider = (JsonEndpointProvider) executionAttributes + .getAttribute(SdkInternalExecutionAttribute.ENDPOINT_PROVIDER); + try { + JsonEndpointParams endpointParams = JsonEndpointResolverUtils.ruleParams(request, executionAttributes); + Endpoint endpoint = provider.resolveEndpoint(endpointParams).join(); + if (!AwsEndpointProviderUtils.disableHostPrefixInjection(executionAttributes)) { + Optional hostPrefix = JsonEndpointResolverUtils.hostPrefix(operationName, request); + if (hostPrefix.isPresent()) { + endpoint = AwsEndpointProviderUtils.addHostPrefix(endpoint, hostPrefix.get()); + } + } + List endpointAuthSchemes = endpoint.attribute(AwsEndpointAttribute.AUTH_SCHEMES); + SelectedAuthScheme selectedAuthScheme = executionAttributes + .getAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME); + if (endpointAuthSchemes != null && selectedAuthScheme != null) { + selectedAuthScheme = JsonEndpointResolverUtils.authSchemeWithEndpointSignerProperties(endpointAuthSchemes, + selectedAuthScheme); + executionAttributes.putAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME, selectedAuthScheme); + } + JsonEndpointResolverUtils.setMetricValues(endpoint, executionAttributes); + return endpoint; + } catch (CompletionException e) { + Throwable cause = e.getCause(); + if (cause instanceof SdkClientException) { + throw (SdkClientException) cause; + } + throw SdkClientException.create("Endpoint resolution failed: " + cause.getMessage(), cause); + } + } + private void updateRetryStrategyClientConfiguration(SdkClientConfiguration.Builder configuration) { ClientOverrideConfiguration.Builder builder = configuration.asOverrideConfigurationBuilder(); RetryMode retryMode = builder.retryMode(); diff --git a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-cbor-client-class.java b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-cbor-client-class.java index a69befcfad26..55f7aeaf5c39 100644 --- a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-cbor-client-class.java +++ b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-cbor-client-class.java @@ -3,11 +3,16 @@ import java.util.Collections; import java.util.List; import java.util.Optional; +import java.util.concurrent.CompletionException; import java.util.function.Consumer; import java.util.function.Function; import software.amazon.awssdk.annotations.Generated; import software.amazon.awssdk.annotations.SdkInternalApi; +import software.amazon.awssdk.awscore.client.config.AwsClientOption; import software.amazon.awssdk.awscore.client.handler.AwsSyncClientHandler; +import software.amazon.awssdk.awscore.endpoints.AwsEndpointAttribute; +import software.amazon.awssdk.awscore.endpoints.AwsEndpointProviderUtils; +import software.amazon.awssdk.awscore.endpoints.authscheme.EndpointAuthScheme; import software.amazon.awssdk.awscore.exception.AwsServiceException; import software.amazon.awssdk.awscore.internal.AwsProtocolMetadata; import software.amazon.awssdk.awscore.internal.AwsServiceProtocol; @@ -15,6 +20,7 @@ import software.amazon.awssdk.core.RequestOverrideConfiguration; import software.amazon.awssdk.core.SdkPlugin; import software.amazon.awssdk.core.SdkRequest; +import software.amazon.awssdk.core.SelectedAuthScheme; import software.amazon.awssdk.core.client.config.ClientOverrideConfiguration; import software.amazon.awssdk.core.client.config.SdkClientConfiguration; import software.amazon.awssdk.core.client.config.SdkClientOption; @@ -22,6 +28,7 @@ import software.amazon.awssdk.core.client.handler.SyncClientHandler; import software.amazon.awssdk.core.exception.SdkClientException; import software.amazon.awssdk.core.http.HttpResponseHandler; +import software.amazon.awssdk.core.interceptor.ExecutionAttributes; import software.amazon.awssdk.core.interceptor.SdkInternalExecutionAttribute; import software.amazon.awssdk.core.interceptor.trait.HttpChecksumRequired; import software.amazon.awssdk.core.internal.interceptor.trait.RequestCompression; @@ -30,6 +37,8 @@ import software.amazon.awssdk.core.runtime.transform.StreamingRequestMarshaller; import software.amazon.awssdk.core.sync.RequestBody; import software.amazon.awssdk.core.sync.ResponseTransformer; +import software.amazon.awssdk.endpoints.Endpoint; +import software.amazon.awssdk.http.auth.spi.scheme.AuthSchemeOption; import software.amazon.awssdk.metrics.MetricCollector; import software.amazon.awssdk.metrics.MetricPublisher; import software.amazon.awssdk.metrics.NoOpMetricCollector; @@ -39,6 +48,11 @@ import software.amazon.awssdk.protocols.json.BaseAwsJsonProtocolFactory; import software.amazon.awssdk.protocols.json.JsonOperationMetadata; import software.amazon.awssdk.retries.api.RetryStrategy; +import software.amazon.awssdk.services.json.auth.scheme.JsonAuthSchemeParams; +import software.amazon.awssdk.services.json.auth.scheme.JsonAuthSchemeProvider; +import software.amazon.awssdk.services.json.endpoints.JsonEndpointParams; +import software.amazon.awssdk.services.json.endpoints.JsonEndpointProvider; +import software.amazon.awssdk.services.json.endpoints.internal.JsonEndpointResolverUtils; import software.amazon.awssdk.services.json.internal.JsonServiceClientConfigurationBuilder; import software.amazon.awssdk.services.json.internal.ServiceVersionInfo; import software.amazon.awssdk.services.json.model.APostOperationRequest; @@ -79,6 +93,7 @@ import software.amazon.awssdk.services.json.transform.StreamingOutputOperationRequestMarshaller; import software.amazon.awssdk.utils.HostnameValidator; import software.amazon.awssdk.utils.Logger; +import software.amazon.awssdk.utils.Validate; /** * Internal implementation of {@link JsonClient}. @@ -169,6 +184,8 @@ public APostOperationResponse aPostOperation(APostOperationRequest aPostOperatio .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) .hostPrefixExpression(resolvedHostExpression).withRequestConfiguration(clientConfiguration) .withInput(aPostOperationRequest).withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver(r -> resolveAuthSchemeOptions(r, "APostOperation", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "APostOperation")) .withMarshaller(new APostOperationRequestMarshaller(protocolFactory))); } finally { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -234,10 +251,16 @@ public APostOperationWithOutputResponse aPostOperationWithOutput( return clientHandler .execute(new ClientExecutionParams() - .withOperationName("APostOperationWithOutput").withProtocolMetadata(protocolMetadata) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withInput(aPostOperationWithOutputRequest) + .withOperationName("APostOperationWithOutput") + .withProtocolMetadata(protocolMetadata) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withInput(aPostOperationWithOutputRequest) .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "APostOperationWithOutput", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "APostOperationWithOutput")) .withMarshaller(new APostOperationWithOutputRequestMarshaller(protocolFactory))); } finally { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -302,10 +325,16 @@ public GetWithoutRequiredMembersResponse getWithoutRequiredMembers( return clientHandler .execute(new ClientExecutionParams() - .withOperationName("GetWithoutRequiredMembers").withProtocolMetadata(protocolMetadata) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withInput(getWithoutRequiredMembersRequest) + .withOperationName("GetWithoutRequiredMembers") + .withProtocolMetadata(protocolMetadata) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withInput(getWithoutRequiredMembersRequest) .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "GetWithoutRequiredMembers", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "GetWithoutRequiredMembers")) .withMarshaller(new GetWithoutRequiredMembersRequestMarshaller(protocolFactory))); } finally { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -373,6 +402,9 @@ public OperationWithChecksumRequiredResponse operationWithChecksumRequired( .withRequestConfiguration(clientConfiguration) .withInput(operationWithChecksumRequiredRequest) .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "OperationWithChecksumRequired", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "OperationWithChecksumRequired")) .putExecutionAttribute(SdkInternalExecutionAttribute.HTTP_CHECKSUM_REQUIRED, HttpChecksumRequired.create()) .withMarshaller(new OperationWithChecksumRequiredRequestMarshaller(protocolFactory))); @@ -435,10 +467,16 @@ public OperationWithNoneAuthTypeResponse operationWithNoneAuthType( return clientHandler .execute(new ClientExecutionParams() - .withOperationName("OperationWithNoneAuthType").withProtocolMetadata(protocolMetadata) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withInput(operationWithNoneAuthTypeRequest) + .withOperationName("OperationWithNoneAuthType") + .withProtocolMetadata(protocolMetadata) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withInput(operationWithNoneAuthTypeRequest) .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "OperationWithNoneAuthType", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "OperationWithNoneAuthType")) .withMarshaller(new OperationWithNoneAuthTypeRequestMarshaller(protocolFactory))); } finally { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -506,6 +544,9 @@ public OperationWithRequestCompressionResponse operationWithRequestCompression( .withRequestConfiguration(clientConfiguration) .withInput(operationWithRequestCompressionRequest) .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "OperationWithRequestCompression", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "OperationWithRequestCompression")) .putExecutionAttribute(SdkInternalExecutionAttribute.REQUEST_COMPRESSION, RequestCompression.builder().encodings("gzip").isStreaming(false).build()) .withMarshaller(new OperationWithRequestCompressionRequestMarshaller(protocolFactory))); @@ -568,10 +609,16 @@ public PaginatedOperationWithResultKeyResponse paginatedOperationWithResultKey( return clientHandler .execute(new ClientExecutionParams() - .withOperationName("PaginatedOperationWithResultKey").withProtocolMetadata(protocolMetadata) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withInput(paginatedOperationWithResultKeyRequest) + .withOperationName("PaginatedOperationWithResultKey") + .withProtocolMetadata(protocolMetadata) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withInput(paginatedOperationWithResultKeyRequest) .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "PaginatedOperationWithResultKey", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "PaginatedOperationWithResultKey")) .withMarshaller(new PaginatedOperationWithResultKeyRequestMarshaller(protocolFactory))); } finally { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -632,10 +679,16 @@ public PaginatedOperationWithoutResultKeyResponse paginatedOperationWithoutResul return clientHandler .execute(new ClientExecutionParams() - .withOperationName("PaginatedOperationWithoutResultKey").withProtocolMetadata(protocolMetadata) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withInput(paginatedOperationWithoutResultKeyRequest) + .withOperationName("PaginatedOperationWithoutResultKey") + .withProtocolMetadata(protocolMetadata) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withInput(paginatedOperationWithoutResultKeyRequest) .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "PaginatedOperationWithoutResultKey", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "PaginatedOperationWithoutResultKey")) .withMarshaller(new PaginatedOperationWithoutResultKeyRequestMarshaller(protocolFactory))); } finally { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -713,6 +766,9 @@ public StreamingInputOperationResponse streamingInputOperation(StreamingInputOpe .withRequestConfiguration(clientConfiguration) .withInput(streamingInputOperationRequest) .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "StreamingInputOperation", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "StreamingInputOperation")) .withRequestBody(requestBody) .withMarshaller( StreamingRequestMarshaller.builder() @@ -803,6 +859,9 @@ public ReturnT streamingInputOutputOperation( .withRequestConfiguration(clientConfiguration) .withInput(streamingInputOutputOperationRequest) .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "StreamingInputOutputOperation", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "StreamingInputOutputOperation")) .withResponseTransformer(responseTransformer) .withRequestBody(requestBody) .withMarshaller( @@ -877,10 +936,17 @@ public ReturnT streamingOutputOperation(StreamingOutputOperationReques return clientHandler.execute( new ClientExecutionParams() - .withOperationName("StreamingOutputOperation").withProtocolMetadata(protocolMetadata) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withInput(streamingOutputOperationRequest) - .withMetricCollector(apiCallMetricCollector).withResponseTransformer(responseTransformer) + .withOperationName("StreamingOutputOperation") + .withProtocolMetadata(protocolMetadata) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withInput(streamingOutputOperationRequest) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "StreamingOutputOperation", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "StreamingOutputOperation")) + .withResponseTransformer(responseTransformer) .withMarshaller(new StreamingOutputOperationRequestMarshaller(protocolFactory)), responseTransformer); } finally { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -915,6 +981,54 @@ private static List resolveMetricPublishers(SdkClientConfigurat return publishers; } + private List resolveAuthSchemeOptions(SdkRequest request, String operationName, + SdkClientConfiguration clientConfiguration) { + JsonAuthSchemeProvider requestAuthSchemeProvider = request + .overrideConfiguration() + .flatMap(c -> c.authSchemeProvider()) + .map(p -> Validate + .isInstanceOf(JsonAuthSchemeProvider.class, p, "Expected an instance of JsonAuthSchemeProvider")) + .orElse(null); + JsonAuthSchemeProvider authSchemeProvider = requestAuthSchemeProvider != null ? requestAuthSchemeProvider : Validate + .isInstanceOf(JsonAuthSchemeProvider.class, clientConfiguration.option(SdkClientOption.AUTH_SCHEME_PROVIDER), + "Expected an instance of JsonAuthSchemeProvider"); + JsonAuthSchemeParams.Builder paramsBuilder = JsonAuthSchemeParams.builder().operation(operationName); + paramsBuilder.region(clientConfiguration.option(AwsClientOption.AWS_REGION)); + List options = authSchemeProvider.resolveAuthScheme(paramsBuilder.build()); + return options; + } + + private Endpoint resolveEndpoint(SdkRequest request, ExecutionAttributes executionAttributes, String operationName) { + JsonEndpointProvider provider = (JsonEndpointProvider) executionAttributes + .getAttribute(SdkInternalExecutionAttribute.ENDPOINT_PROVIDER); + try { + JsonEndpointParams endpointParams = JsonEndpointResolverUtils.ruleParams(request, executionAttributes); + Endpoint endpoint = provider.resolveEndpoint(endpointParams).join(); + if (!AwsEndpointProviderUtils.disableHostPrefixInjection(executionAttributes)) { + Optional hostPrefix = JsonEndpointResolverUtils.hostPrefix(operationName, request); + if (hostPrefix.isPresent()) { + endpoint = AwsEndpointProviderUtils.addHostPrefix(endpoint, hostPrefix.get()); + } + } + List endpointAuthSchemes = endpoint.attribute(AwsEndpointAttribute.AUTH_SCHEMES); + SelectedAuthScheme selectedAuthScheme = executionAttributes + .getAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME); + if (endpointAuthSchemes != null && selectedAuthScheme != null) { + selectedAuthScheme = JsonEndpointResolverUtils.authSchemeWithEndpointSignerProperties(endpointAuthSchemes, + selectedAuthScheme); + executionAttributes.putAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME, selectedAuthScheme); + } + JsonEndpointResolverUtils.setMetricValues(endpoint, executionAttributes); + return endpoint; + } catch (CompletionException e) { + Throwable cause = e.getCause(); + if (cause instanceof SdkClientException) { + throw (SdkClientException) cause; + } + throw SdkClientException.create("Endpoint resolution failed: " + cause.getMessage(), cause); + } + } + private HttpResponseHandler createErrorResponseHandler(BaseAwsJsonProtocolFactory protocolFactory, JsonOperationMetadata operationMetadata, Function> exceptionMetadataMapper) { return protocolFactory.createErrorResponseHandler(operationMetadata, exceptionMetadataMapper); diff --git a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-custom-context-params-async-client-class.java b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-custom-context-params-async-client-class.java index d59409dcbcdd..a1701b2a6df6 100644 --- a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-custom-context-params-async-client-class.java +++ b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-custom-context-params-async-client-class.java @@ -7,13 +7,18 @@ import java.util.Objects; import java.util.Optional; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionException; import java.util.function.Consumer; import java.util.function.Function; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import software.amazon.awssdk.annotations.Generated; import software.amazon.awssdk.annotations.SdkInternalApi; +import software.amazon.awssdk.awscore.client.config.AwsClientOption; import software.amazon.awssdk.awscore.client.handler.AwsAsyncClientHandler; +import software.amazon.awssdk.awscore.endpoints.AwsEndpointAttribute; +import software.amazon.awssdk.awscore.endpoints.AwsEndpointProviderUtils; +import software.amazon.awssdk.awscore.endpoints.authscheme.EndpointAuthScheme; import software.amazon.awssdk.awscore.exception.AwsServiceException; import software.amazon.awssdk.awscore.internal.AwsProtocolMetadata; import software.amazon.awssdk.awscore.internal.AwsServiceProtocol; @@ -21,14 +26,20 @@ import software.amazon.awssdk.core.RequestOverrideConfiguration; import software.amazon.awssdk.core.SdkPlugin; import software.amazon.awssdk.core.SdkRequest; +import software.amazon.awssdk.core.SelectedAuthScheme; import software.amazon.awssdk.core.client.config.ClientOverrideConfiguration; import software.amazon.awssdk.core.client.config.SdkClientConfiguration; import software.amazon.awssdk.core.client.config.SdkClientOption; import software.amazon.awssdk.core.client.handler.AsyncClientHandler; import software.amazon.awssdk.core.client.handler.ClientExecutionParams; +import software.amazon.awssdk.core.exception.SdkClientException; import software.amazon.awssdk.core.http.HttpResponseHandler; +import software.amazon.awssdk.core.interceptor.ExecutionAttributes; +import software.amazon.awssdk.core.interceptor.SdkInternalExecutionAttribute; import software.amazon.awssdk.core.metrics.CoreMetric; import software.amazon.awssdk.core.retry.RetryMode; +import software.amazon.awssdk.endpoints.Endpoint; +import software.amazon.awssdk.http.auth.spi.scheme.AuthSchemeOption; import software.amazon.awssdk.metrics.MetricCollector; import software.amazon.awssdk.metrics.MetricPublisher; import software.amazon.awssdk.metrics.NoOpMetricCollector; @@ -38,7 +49,12 @@ import software.amazon.awssdk.protocols.json.BaseAwsJsonProtocolFactory; import software.amazon.awssdk.protocols.json.JsonOperationMetadata; import software.amazon.awssdk.retries.api.RetryStrategy; +import software.amazon.awssdk.services.foobar.auth.scheme.FooBarAuthSchemeParams; +import software.amazon.awssdk.services.foobar.auth.scheme.FooBarAuthSchemeProvider; import software.amazon.awssdk.services.foobar.endpoints.FooBarClientContextParams; +import software.amazon.awssdk.services.foobar.endpoints.FooBarEndpointParams; +import software.amazon.awssdk.services.foobar.endpoints.FooBarEndpointProvider; +import software.amazon.awssdk.services.foobar.endpoints.internal.FooBarEndpointResolverUtils; import software.amazon.awssdk.services.foobar.internal.FooBarServiceClientConfigurationBuilder; import software.amazon.awssdk.services.foobar.internal.ServiceVersionInfo; import software.amazon.awssdk.services.foobar.model.FooBarException; @@ -60,7 +76,7 @@ final class DefaultFooBarAsyncClient implements FooBarAsyncClient { private static final Logger log = LoggerFactory.getLogger(DefaultFooBarAsyncClient.class); private static final AwsProtocolMetadata protocolMetadata = AwsProtocolMetadata.builder() - .serviceProtocol(AwsServiceProtocol.REST_JSON).build(); + .serviceProtocol(AwsServiceProtocol.REST_JSON).build(); private final AsyncClientHandler clientHandler; @@ -71,7 +87,7 @@ final class DefaultFooBarAsyncClient implements FooBarAsyncClient { protected DefaultFooBarAsyncClient(SdkClientConfiguration clientConfiguration) { this.clientHandler = new AwsAsyncClientHandler(clientConfiguration); this.clientConfiguration = clientConfiguration.toBuilder().option(SdkClientOption.SDK_CLIENT, this) - .option(SdkClientOption.API_METADATA, "Foo_Bar" + "#" + ServiceVersionInfo.VERSION).build(); + .option(SdkClientOption.API_METADATA, "Foo_Bar" + "#" + ServiceVersionInfo.VERSION).build(); this.protocolFactory = init(AwsJsonProtocolFactory.builder()).build(); } @@ -100,38 +116,44 @@ protected DefaultFooBarAsyncClient(SdkClientConfiguration clientConfiguration) { @Override public CompletableFuture getDatabaseVersion(GetDatabaseVersionRequest getDatabaseVersionRequest) { SdkClientConfiguration clientConfiguration = updateSdkClientConfiguration(getDatabaseVersionRequest, - this.clientConfiguration); + this.clientConfiguration); List metricPublishers = resolveMetricPublishers(clientConfiguration, getDatabaseVersionRequest - .overrideConfiguration().orElse(null)); + .overrideConfiguration().orElse(null)); MetricCollector apiCallMetricCollector = metricPublishers.isEmpty() ? NoOpMetricCollector.create() : MetricCollector - .create("ApiCall"); + .create("ApiCall"); try { apiCallMetricCollector.reportMetric(CoreMetric.SERVICE_ID, "Foo Bar"); apiCallMetricCollector.reportMetric(CoreMetric.OPERATION_NAME, "GetDatabaseVersion"); JsonOperationMetadata operationMetadata = JsonOperationMetadata.builder().hasStreamingSuccessResponse(false) - .isPayloadJson(true).build(); + .isPayloadJson(true).build(); HttpResponseHandler responseHandler = protocolFactory.createResponseHandler( - operationMetadata, GetDatabaseVersionResponse::builder); + operationMetadata, GetDatabaseVersionResponse::builder); Function> exceptionMetadataMapper = errorCode -> { if (errorCode == null) { return Optional.empty(); } switch (errorCode) { - default: - return Optional.empty(); + default: + return Optional.empty(); } }; HttpResponseHandler errorResponseHandler = createErrorResponseHandler(protocolFactory, - operationMetadata, exceptionMetadataMapper); + operationMetadata, exceptionMetadataMapper); CompletableFuture executeFuture = clientHandler - .execute(new ClientExecutionParams() - .withOperationName("GetDatabaseVersion").withProtocolMetadata(protocolMetadata) - .withMarshaller(new GetDatabaseVersionRequestMarshaller(protocolFactory)) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withMetricCollector(apiCallMetricCollector) - .withInput(getDatabaseVersionRequest)); + .execute(new ClientExecutionParams() + .withOperationName("GetDatabaseVersion") + .withProtocolMetadata(protocolMetadata) + .withMarshaller(new GetDatabaseVersionRequestMarshaller(protocolFactory)) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "GetDatabaseVersion", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "GetDatabaseVersion")) + .withInput(getDatabaseVersionRequest)); CompletableFuture whenCompleted = executeFuture.whenComplete((r, e) -> { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); }); @@ -155,11 +177,11 @@ public final String serviceName() { private > T init(T builder) { return builder.clientConfiguration(clientConfiguration).defaultServiceExceptionSupplier(FooBarException::builder) - .protocol(AwsJsonProtocol.REST_JSON).protocolVersion("1.1"); + .protocol(AwsJsonProtocol.REST_JSON).protocolVersion("1.1"); } private static List resolveMetricPublishers(SdkClientConfiguration clientConfiguration, - RequestOverrideConfiguration requestOverrideConfiguration) { + RequestOverrideConfiguration requestOverrideConfiguration) { List publishers = null; if (requestOverrideConfiguration != null) { publishers = requestOverrideConfiguration.metricPublishers(); @@ -173,6 +195,53 @@ private static List resolveMetricPublishers(SdkClientConfigurat return publishers; } + private List resolveAuthSchemeOptions(SdkRequest request, String operationName, + SdkClientConfiguration clientConfiguration) { + FooBarAuthSchemeProvider requestAuthSchemeProvider = request + .overrideConfiguration() + .flatMap(c -> c.authSchemeProvider()) + .map(p -> Validate.isInstanceOf(FooBarAuthSchemeProvider.class, p, + "Expected an instance of FooBarAuthSchemeProvider")).orElse(null); + FooBarAuthSchemeProvider authSchemeProvider = requestAuthSchemeProvider != null ? requestAuthSchemeProvider : Validate + .isInstanceOf(FooBarAuthSchemeProvider.class, clientConfiguration.option(SdkClientOption.AUTH_SCHEME_PROVIDER), + "Expected an instance of FooBarAuthSchemeProvider"); + FooBarAuthSchemeParams.Builder paramsBuilder = FooBarAuthSchemeParams.builder().operation(operationName); + paramsBuilder.region(clientConfiguration.option(AwsClientOption.AWS_REGION)); + List options = authSchemeProvider.resolveAuthScheme(paramsBuilder.build()); + return options; + } + + private Endpoint resolveEndpoint(SdkRequest request, ExecutionAttributes executionAttributes, String operationName) { + FooBarEndpointProvider provider = (FooBarEndpointProvider) executionAttributes + .getAttribute(SdkInternalExecutionAttribute.ENDPOINT_PROVIDER); + try { + FooBarEndpointParams endpointParams = FooBarEndpointResolverUtils.ruleParams(request, executionAttributes); + Endpoint endpoint = provider.resolveEndpoint(endpointParams).join(); + if (!AwsEndpointProviderUtils.disableHostPrefixInjection(executionAttributes)) { + Optional hostPrefix = FooBarEndpointResolverUtils.hostPrefix(operationName, request); + if (hostPrefix.isPresent()) { + endpoint = AwsEndpointProviderUtils.addHostPrefix(endpoint, hostPrefix.get()); + } + } + List endpointAuthSchemes = endpoint.attribute(AwsEndpointAttribute.AUTH_SCHEMES); + SelectedAuthScheme selectedAuthScheme = executionAttributes + .getAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME); + if (endpointAuthSchemes != null && selectedAuthScheme != null) { + selectedAuthScheme = FooBarEndpointResolverUtils.authSchemeWithEndpointSignerProperties(endpointAuthSchemes, + selectedAuthScheme); + executionAttributes.putAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME, selectedAuthScheme); + } + FooBarEndpointResolverUtils.setMetricValues(endpoint, executionAttributes); + return endpoint; + } catch (CompletionException e) { + Throwable cause = e.getCause(); + if (cause instanceof SdkClientException) { + throw (SdkClientException) cause; + } + throw SdkClientException.create("Endpoint resolution failed: " + cause.getMessage(), cause); + } + } + private void updateRetryStrategyClientConfiguration(SdkClientConfiguration.Builder configuration) { ClientOverrideConfiguration.Builder builder = configuration.asOverrideConfigurationBuilder(); RetryMode retryMode = builder.retryMode(); @@ -211,15 +280,15 @@ private SdkClientConfiguration updateSdkClientConfiguration(SdkRequest request, newContextParams = (newContextParams != null) ? newContextParams : AttributeMap.empty(); originalContextParams = originalContextParams != null ? originalContextParams : AttributeMap.empty(); Validate.validState( - Objects.equals(originalContextParams.get(FooBarClientContextParams.CROSS_REGION_ACCESS_ENABLED), - newContextParams.get(FooBarClientContextParams.CROSS_REGION_ACCESS_ENABLED)), - "CROSS_REGION_ACCESS_ENABLED cannot be modified by request level plugins"); + Objects.equals(originalContextParams.get(FooBarClientContextParams.CROSS_REGION_ACCESS_ENABLED), + newContextParams.get(FooBarClientContextParams.CROSS_REGION_ACCESS_ENABLED)), + "CROSS_REGION_ACCESS_ENABLED cannot be modified by request level plugins"); updateRetryStrategyClientConfiguration(configuration); return configuration.build(); } private HttpResponseHandler createErrorResponseHandler(BaseAwsJsonProtocolFactory protocolFactory, - JsonOperationMetadata operationMetadata, Function> exceptionMetadataMapper) { + JsonOperationMetadata operationMetadata, Function> exceptionMetadataMapper) { return protocolFactory.createErrorResponseHandler(operationMetadata, exceptionMetadataMapper); } diff --git a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-custom-context-params-sync-client-class.java b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-custom-context-params-sync-client-class.java index 7b8faf909d0c..2355b277425d 100644 --- a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-custom-context-params-sync-client-class.java +++ b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-custom-context-params-sync-client-class.java @@ -4,11 +4,16 @@ import java.util.List; import java.util.Objects; import java.util.Optional; +import java.util.concurrent.CompletionException; import java.util.function.Consumer; import java.util.function.Function; import software.amazon.awssdk.annotations.Generated; import software.amazon.awssdk.annotations.SdkInternalApi; +import software.amazon.awssdk.awscore.client.config.AwsClientOption; import software.amazon.awssdk.awscore.client.handler.AwsSyncClientHandler; +import software.amazon.awssdk.awscore.endpoints.AwsEndpointAttribute; +import software.amazon.awssdk.awscore.endpoints.AwsEndpointProviderUtils; +import software.amazon.awssdk.awscore.endpoints.authscheme.EndpointAuthScheme; import software.amazon.awssdk.awscore.exception.AwsServiceException; import software.amazon.awssdk.awscore.internal.AwsProtocolMetadata; import software.amazon.awssdk.awscore.internal.AwsServiceProtocol; @@ -16,6 +21,7 @@ import software.amazon.awssdk.core.RequestOverrideConfiguration; import software.amazon.awssdk.core.SdkPlugin; import software.amazon.awssdk.core.SdkRequest; +import software.amazon.awssdk.core.SelectedAuthScheme; import software.amazon.awssdk.core.client.config.ClientOverrideConfiguration; import software.amazon.awssdk.core.client.config.SdkClientConfiguration; import software.amazon.awssdk.core.client.config.SdkClientOption; @@ -23,8 +29,12 @@ import software.amazon.awssdk.core.client.handler.SyncClientHandler; import software.amazon.awssdk.core.exception.SdkClientException; import software.amazon.awssdk.core.http.HttpResponseHandler; +import software.amazon.awssdk.core.interceptor.ExecutionAttributes; +import software.amazon.awssdk.core.interceptor.SdkInternalExecutionAttribute; import software.amazon.awssdk.core.metrics.CoreMetric; import software.amazon.awssdk.core.retry.RetryMode; +import software.amazon.awssdk.endpoints.Endpoint; +import software.amazon.awssdk.http.auth.spi.scheme.AuthSchemeOption; import software.amazon.awssdk.metrics.MetricCollector; import software.amazon.awssdk.metrics.MetricPublisher; import software.amazon.awssdk.metrics.NoOpMetricCollector; @@ -34,7 +44,12 @@ import software.amazon.awssdk.protocols.json.BaseAwsJsonProtocolFactory; import software.amazon.awssdk.protocols.json.JsonOperationMetadata; import software.amazon.awssdk.retries.api.RetryStrategy; +import software.amazon.awssdk.services.foobar.auth.scheme.FooBarAuthSchemeParams; +import software.amazon.awssdk.services.foobar.auth.scheme.FooBarAuthSchemeProvider; import software.amazon.awssdk.services.foobar.endpoints.FooBarClientContextParams; +import software.amazon.awssdk.services.foobar.endpoints.FooBarEndpointParams; +import software.amazon.awssdk.services.foobar.endpoints.FooBarEndpointProvider; +import software.amazon.awssdk.services.foobar.endpoints.internal.FooBarEndpointResolverUtils; import software.amazon.awssdk.services.foobar.internal.FooBarServiceClientConfigurationBuilder; import software.amazon.awssdk.services.foobar.internal.ServiceVersionInfo; import software.amazon.awssdk.services.foobar.model.FooBarException; @@ -56,7 +71,7 @@ final class DefaultFooBarClient implements FooBarClient { private static final Logger log = Logger.loggerFor(DefaultFooBarClient.class); private static final AwsProtocolMetadata protocolMetadata = AwsProtocolMetadata.builder() - .serviceProtocol(AwsServiceProtocol.REST_JSON).build(); + .serviceProtocol(AwsServiceProtocol.REST_JSON).build(); private final SyncClientHandler clientHandler; @@ -67,7 +82,7 @@ final class DefaultFooBarClient implements FooBarClient { protected DefaultFooBarClient(SdkClientConfiguration clientConfiguration) { this.clientHandler = new AwsSyncClientHandler(clientConfiguration); this.clientConfiguration = clientConfiguration.toBuilder().option(SdkClientOption.SDK_CLIENT, this) - .option(SdkClientOption.API_METADATA, "Foo_Bar" + "#" + ServiceVersionInfo.VERSION).build(); + .option(SdkClientOption.API_METADATA, "Foo_Bar" + "#" + ServiceVersionInfo.VERSION).build(); this.protocolFactory = init(AwsJsonProtocolFactory.builder()).build(); } @@ -91,39 +106,41 @@ protected DefaultFooBarClient(SdkClientConfiguration clientConfiguration) { */ @Override public GetDatabaseVersionResponse getDatabaseVersion(GetDatabaseVersionRequest getDatabaseVersionRequest) - throws AwsServiceException, SdkClientException, FooBarException { + throws AwsServiceException, SdkClientException, FooBarException { JsonOperationMetadata operationMetadata = JsonOperationMetadata.builder().hasStreamingSuccessResponse(false) - .isPayloadJson(true).build(); + .isPayloadJson(true).build(); HttpResponseHandler responseHandler = protocolFactory.createResponseHandler( - operationMetadata, GetDatabaseVersionResponse::builder); + operationMetadata, GetDatabaseVersionResponse::builder); Function> exceptionMetadataMapper = errorCode -> { if (errorCode == null) { return Optional.empty(); } switch (errorCode) { - default: - return Optional.empty(); + default: + return Optional.empty(); } }; HttpResponseHandler errorResponseHandler = createErrorResponseHandler(protocolFactory, - operationMetadata, exceptionMetadataMapper); + operationMetadata, exceptionMetadataMapper); SdkClientConfiguration clientConfiguration = updateSdkClientConfiguration(getDatabaseVersionRequest, - this.clientConfiguration); + this.clientConfiguration); List metricPublishers = resolveMetricPublishers(clientConfiguration, getDatabaseVersionRequest - .overrideConfiguration().orElse(null)); + .overrideConfiguration().orElse(null)); MetricCollector apiCallMetricCollector = metricPublishers.isEmpty() ? NoOpMetricCollector.create() : MetricCollector - .create("ApiCall"); + .create("ApiCall"); try { apiCallMetricCollector.reportMetric(CoreMetric.SERVICE_ID, "Foo Bar"); apiCallMetricCollector.reportMetric(CoreMetric.OPERATION_NAME, "GetDatabaseVersion"); return clientHandler.execute(new ClientExecutionParams() - .withOperationName("GetDatabaseVersion").withProtocolMetadata(protocolMetadata) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withInput(getDatabaseVersionRequest) - .withMetricCollector(apiCallMetricCollector) - .withMarshaller(new GetDatabaseVersionRequestMarshaller(protocolFactory))); + .withOperationName("GetDatabaseVersion").withProtocolMetadata(protocolMetadata) + .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration).withInput(getDatabaseVersionRequest) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver(r -> resolveAuthSchemeOptions(r, "GetDatabaseVersion", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "GetDatabaseVersion")) + .withMarshaller(new GetDatabaseVersionRequestMarshaller(protocolFactory))); } finally { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); } @@ -135,7 +152,7 @@ public final String serviceName() { } private static List resolveMetricPublishers(SdkClientConfiguration clientConfiguration, - RequestOverrideConfiguration requestOverrideConfiguration) { + RequestOverrideConfiguration requestOverrideConfiguration) { List publishers = null; if (requestOverrideConfiguration != null) { publishers = requestOverrideConfiguration.metricPublishers(); @@ -149,8 +166,55 @@ private static List resolveMetricPublishers(SdkClientConfigurat return publishers; } + private List resolveAuthSchemeOptions(SdkRequest request, String operationName, + SdkClientConfiguration clientConfiguration) { + FooBarAuthSchemeProvider requestAuthSchemeProvider = request + .overrideConfiguration() + .flatMap(c -> c.authSchemeProvider()) + .map(p -> Validate.isInstanceOf(FooBarAuthSchemeProvider.class, p, + "Expected an instance of FooBarAuthSchemeProvider")).orElse(null); + FooBarAuthSchemeProvider authSchemeProvider = requestAuthSchemeProvider != null ? requestAuthSchemeProvider : Validate + .isInstanceOf(FooBarAuthSchemeProvider.class, clientConfiguration.option(SdkClientOption.AUTH_SCHEME_PROVIDER), + "Expected an instance of FooBarAuthSchemeProvider"); + FooBarAuthSchemeParams.Builder paramsBuilder = FooBarAuthSchemeParams.builder().operation(operationName); + paramsBuilder.region(clientConfiguration.option(AwsClientOption.AWS_REGION)); + List options = authSchemeProvider.resolveAuthScheme(paramsBuilder.build()); + return options; + } + + private Endpoint resolveEndpoint(SdkRequest request, ExecutionAttributes executionAttributes, String operationName) { + FooBarEndpointProvider provider = (FooBarEndpointProvider) executionAttributes + .getAttribute(SdkInternalExecutionAttribute.ENDPOINT_PROVIDER); + try { + FooBarEndpointParams endpointParams = FooBarEndpointResolverUtils.ruleParams(request, executionAttributes); + Endpoint endpoint = provider.resolveEndpoint(endpointParams).join(); + if (!AwsEndpointProviderUtils.disableHostPrefixInjection(executionAttributes)) { + Optional hostPrefix = FooBarEndpointResolverUtils.hostPrefix(operationName, request); + if (hostPrefix.isPresent()) { + endpoint = AwsEndpointProviderUtils.addHostPrefix(endpoint, hostPrefix.get()); + } + } + List endpointAuthSchemes = endpoint.attribute(AwsEndpointAttribute.AUTH_SCHEMES); + SelectedAuthScheme selectedAuthScheme = executionAttributes + .getAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME); + if (endpointAuthSchemes != null && selectedAuthScheme != null) { + selectedAuthScheme = FooBarEndpointResolverUtils.authSchemeWithEndpointSignerProperties(endpointAuthSchemes, + selectedAuthScheme); + executionAttributes.putAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME, selectedAuthScheme); + } + FooBarEndpointResolverUtils.setMetricValues(endpoint, executionAttributes); + return endpoint; + } catch (CompletionException e) { + Throwable cause = e.getCause(); + if (cause instanceof SdkClientException) { + throw (SdkClientException) cause; + } + throw SdkClientException.create("Endpoint resolution failed: " + cause.getMessage(), cause); + } + } + private HttpResponseHandler createErrorResponseHandler(BaseAwsJsonProtocolFactory protocolFactory, - JsonOperationMetadata operationMetadata, Function> exceptionMetadataMapper) { + JsonOperationMetadata operationMetadata, Function> exceptionMetadataMapper) { return protocolFactory.createErrorResponseHandler(operationMetadata, exceptionMetadataMapper); } @@ -192,16 +256,16 @@ private SdkClientConfiguration updateSdkClientConfiguration(SdkRequest request, newContextParams = (newContextParams != null) ? newContextParams : AttributeMap.empty(); originalContextParams = originalContextParams != null ? originalContextParams : AttributeMap.empty(); Validate.validState( - Objects.equals(originalContextParams.get(FooBarClientContextParams.CROSS_REGION_ACCESS_ENABLED), - newContextParams.get(FooBarClientContextParams.CROSS_REGION_ACCESS_ENABLED)), - "CROSS_REGION_ACCESS_ENABLED cannot be modified by request level plugins"); + Objects.equals(originalContextParams.get(FooBarClientContextParams.CROSS_REGION_ACCESS_ENABLED), + newContextParams.get(FooBarClientContextParams.CROSS_REGION_ACCESS_ENABLED)), + "CROSS_REGION_ACCESS_ENABLED cannot be modified by request level plugins"); updateRetryStrategyClientConfiguration(configuration); return configuration.build(); } private > T init(T builder) { return builder.clientConfiguration(clientConfiguration).defaultServiceExceptionSupplier(FooBarException::builder) - .protocol(AwsJsonProtocol.REST_JSON).protocolVersion("1.1"); + .protocol(AwsJsonProtocol.REST_JSON).protocolVersion("1.1"); } @Override diff --git a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-custompackage-async.java b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-custompackage-async.java index bc69607bf5d9..7ef35cc63e3f 100644 --- a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-custompackage-async.java +++ b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-custompackage-async.java @@ -2,6 +2,11 @@ import static software.amazon.awssdk.utils.FunctionalUtils.runAndLogError; +import foo.bar.helloworld.auth.scheme.ProtocolRestJsonWithCustomPackageAuthSchemeParams; +import foo.bar.helloworld.auth.scheme.ProtocolRestJsonWithCustomPackageAuthSchemeProvider; +import foo.bar.helloworld.endpoints.ProtocolRestJsonWithCustomPackageEndpointParams; +import foo.bar.helloworld.endpoints.ProtocolRestJsonWithCustomPackageEndpointProvider; +import foo.bar.helloworld.endpoints.internal.ProtocolRestJsonWithCustomPackageEndpointResolverUtils; import foo.bar.helloworld.internal.ProtocolRestJsonWithCustomPackageServiceClientConfigurationBuilder; import foo.bar.helloworld.internal.ServiceVersionInfo; import foo.bar.helloworld.model.OneOperationRequest; @@ -12,13 +17,18 @@ import java.util.List; import java.util.Optional; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionException; import java.util.function.Consumer; import java.util.function.Function; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import software.amazon.awssdk.annotations.Generated; import software.amazon.awssdk.annotations.SdkInternalApi; +import software.amazon.awssdk.awscore.client.config.AwsClientOption; import software.amazon.awssdk.awscore.client.handler.AwsAsyncClientHandler; +import software.amazon.awssdk.awscore.endpoints.AwsEndpointAttribute; +import software.amazon.awssdk.awscore.endpoints.AwsEndpointProviderUtils; +import software.amazon.awssdk.awscore.endpoints.authscheme.EndpointAuthScheme; import software.amazon.awssdk.awscore.exception.AwsServiceException; import software.amazon.awssdk.awscore.internal.AwsProtocolMetadata; import software.amazon.awssdk.awscore.internal.AwsServiceProtocol; @@ -26,14 +36,20 @@ import software.amazon.awssdk.core.RequestOverrideConfiguration; import software.amazon.awssdk.core.SdkPlugin; import software.amazon.awssdk.core.SdkRequest; +import software.amazon.awssdk.core.SelectedAuthScheme; import software.amazon.awssdk.core.client.config.ClientOverrideConfiguration; import software.amazon.awssdk.core.client.config.SdkClientConfiguration; import software.amazon.awssdk.core.client.config.SdkClientOption; import software.amazon.awssdk.core.client.handler.AsyncClientHandler; import software.amazon.awssdk.core.client.handler.ClientExecutionParams; +import software.amazon.awssdk.core.exception.SdkClientException; import software.amazon.awssdk.core.http.HttpResponseHandler; +import software.amazon.awssdk.core.interceptor.ExecutionAttributes; +import software.amazon.awssdk.core.interceptor.SdkInternalExecutionAttribute; import software.amazon.awssdk.core.metrics.CoreMetric; import software.amazon.awssdk.core.retry.RetryMode; +import software.amazon.awssdk.endpoints.Endpoint; +import software.amazon.awssdk.http.auth.spi.scheme.AuthSchemeOption; import software.amazon.awssdk.metrics.MetricCollector; import software.amazon.awssdk.metrics.MetricPublisher; import software.amazon.awssdk.metrics.NoOpMetricCollector; @@ -44,6 +60,7 @@ import software.amazon.awssdk.protocols.json.JsonOperationMetadata; import software.amazon.awssdk.retries.api.RetryStrategy; import software.amazon.awssdk.utils.CompletableFutureUtils; +import software.amazon.awssdk.utils.Validate; /** * Internal implementation of {@link ProtocolRestJsonWithCustomPackageAsyncClient}. @@ -66,8 +83,11 @@ final class DefaultProtocolRestJsonWithCustomPackageAsyncClient implements Proto protected DefaultProtocolRestJsonWithCustomPackageAsyncClient(SdkClientConfiguration clientConfiguration) { this.clientHandler = new AwsAsyncClientHandler(clientConfiguration); - this.clientConfiguration = clientConfiguration.toBuilder().option(SdkClientOption.SDK_CLIENT, this) - .option(SdkClientOption.API_METADATA, "AmazonProtocolRestJsonWithCustomPackage" + "#" + ServiceVersionInfo.VERSION).build(); + this.clientConfiguration = clientConfiguration + .toBuilder() + .option(SdkClientOption.SDK_CLIENT, this) + .option(SdkClientOption.API_METADATA, + "AmazonProtocolRestJsonWithCustomPackage" + "#" + ServiceVersionInfo.VERSION).build(); this.protocolFactory = init(AwsJsonProtocolFactory.builder()).build(); } @@ -124,7 +144,8 @@ public CompletableFuture oneOperation(OneOperationRequest .withMarshaller(new OneOperationRequestMarshaller(protocolFactory)) .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) .withRequestConfiguration(clientConfiguration).withMetricCollector(apiCallMetricCollector) - .withInput(oneOperationRequest)); + .withAuthSchemeOptionsResolver(r -> resolveAuthSchemeOptions(r, "OneOperation", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "OneOperation")).withInput(oneOperationRequest)); CompletableFuture whenCompleted = executeFuture.whenComplete((r, e) -> { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); }); @@ -168,6 +189,57 @@ private static List resolveMetricPublishers(SdkClientConfigurat return publishers; } + private List resolveAuthSchemeOptions(SdkRequest request, String operationName, + SdkClientConfiguration clientConfiguration) { + ProtocolRestJsonWithCustomPackageAuthSchemeProvider requestAuthSchemeProvider = request + .overrideConfiguration() + .flatMap(c -> c.authSchemeProvider()) + .map(p -> Validate.isInstanceOf(ProtocolRestJsonWithCustomPackageAuthSchemeProvider.class, p, + "Expected an instance of ProtocolRestJsonWithCustomPackageAuthSchemeProvider")).orElse(null); + ProtocolRestJsonWithCustomPackageAuthSchemeProvider authSchemeProvider = requestAuthSchemeProvider != null ? requestAuthSchemeProvider + : Validate.isInstanceOf(ProtocolRestJsonWithCustomPackageAuthSchemeProvider.class, + clientConfiguration.option(SdkClientOption.AUTH_SCHEME_PROVIDER), + "Expected an instance of ProtocolRestJsonWithCustomPackageAuthSchemeProvider"); + ProtocolRestJsonWithCustomPackageAuthSchemeParams.Builder paramsBuilder = ProtocolRestJsonWithCustomPackageAuthSchemeParams + .builder().operation(operationName); + paramsBuilder.region(clientConfiguration.option(AwsClientOption.AWS_REGION)); + List options = authSchemeProvider.resolveAuthScheme(paramsBuilder.build()); + return options; + } + + private Endpoint resolveEndpoint(SdkRequest request, ExecutionAttributes executionAttributes, String operationName) { + ProtocolRestJsonWithCustomPackageEndpointProvider provider = (ProtocolRestJsonWithCustomPackageEndpointProvider) executionAttributes + .getAttribute(SdkInternalExecutionAttribute.ENDPOINT_PROVIDER); + try { + ProtocolRestJsonWithCustomPackageEndpointParams endpointParams = ProtocolRestJsonWithCustomPackageEndpointResolverUtils + .ruleParams(request, executionAttributes); + Endpoint endpoint = provider.resolveEndpoint(endpointParams).join(); + if (!AwsEndpointProviderUtils.disableHostPrefixInjection(executionAttributes)) { + Optional hostPrefix = ProtocolRestJsonWithCustomPackageEndpointResolverUtils.hostPrefix(operationName, + request); + if (hostPrefix.isPresent()) { + endpoint = AwsEndpointProviderUtils.addHostPrefix(endpoint, hostPrefix.get()); + } + } + List endpointAuthSchemes = endpoint.attribute(AwsEndpointAttribute.AUTH_SCHEMES); + SelectedAuthScheme selectedAuthScheme = executionAttributes + .getAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME); + if (endpointAuthSchemes != null && selectedAuthScheme != null) { + selectedAuthScheme = ProtocolRestJsonWithCustomPackageEndpointResolverUtils + .authSchemeWithEndpointSignerProperties(endpointAuthSchemes, selectedAuthScheme); + executionAttributes.putAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME, selectedAuthScheme); + } + ProtocolRestJsonWithCustomPackageEndpointResolverUtils.setMetricValues(endpoint, executionAttributes); + return endpoint; + } catch (CompletionException e) { + Throwable cause = e.getCause(); + if (cause instanceof SdkClientException) { + throw (SdkClientException) cause; + } + throw SdkClientException.create("Endpoint resolution failed: " + cause.getMessage(), cause); + } + } + private void updateRetryStrategyClientConfiguration(SdkClientConfiguration.Builder configuration) { ClientOverrideConfiguration.Builder builder = configuration.asOverrideConfigurationBuilder(); RetryMode retryMode = builder.retryMode(); diff --git a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-custompackage-sync.java b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-custompackage-sync.java index fe80546d6f1a..7762181fb6fc 100644 --- a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-custompackage-sync.java +++ b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-custompackage-sync.java @@ -1,5 +1,10 @@ package foo.bar.helloworld; +import foo.bar.helloworld.auth.scheme.ProtocolRestJsonWithCustomPackageAuthSchemeParams; +import foo.bar.helloworld.auth.scheme.ProtocolRestJsonWithCustomPackageAuthSchemeProvider; +import foo.bar.helloworld.endpoints.ProtocolRestJsonWithCustomPackageEndpointParams; +import foo.bar.helloworld.endpoints.ProtocolRestJsonWithCustomPackageEndpointProvider; +import foo.bar.helloworld.endpoints.internal.ProtocolRestJsonWithCustomPackageEndpointResolverUtils; import foo.bar.helloworld.internal.ProtocolRestJsonWithCustomPackageServiceClientConfigurationBuilder; import foo.bar.helloworld.internal.ServiceVersionInfo; import foo.bar.helloworld.model.OneOperationRequest; @@ -9,11 +14,16 @@ import java.util.Collections; import java.util.List; import java.util.Optional; +import java.util.concurrent.CompletionException; import java.util.function.Consumer; import java.util.function.Function; import software.amazon.awssdk.annotations.Generated; import software.amazon.awssdk.annotations.SdkInternalApi; +import software.amazon.awssdk.awscore.client.config.AwsClientOption; import software.amazon.awssdk.awscore.client.handler.AwsSyncClientHandler; +import software.amazon.awssdk.awscore.endpoints.AwsEndpointAttribute; +import software.amazon.awssdk.awscore.endpoints.AwsEndpointProviderUtils; +import software.amazon.awssdk.awscore.endpoints.authscheme.EndpointAuthScheme; import software.amazon.awssdk.awscore.exception.AwsServiceException; import software.amazon.awssdk.awscore.internal.AwsProtocolMetadata; import software.amazon.awssdk.awscore.internal.AwsServiceProtocol; @@ -21,6 +31,7 @@ import software.amazon.awssdk.core.RequestOverrideConfiguration; import software.amazon.awssdk.core.SdkPlugin; import software.amazon.awssdk.core.SdkRequest; +import software.amazon.awssdk.core.SelectedAuthScheme; import software.amazon.awssdk.core.client.config.ClientOverrideConfiguration; import software.amazon.awssdk.core.client.config.SdkClientConfiguration; import software.amazon.awssdk.core.client.config.SdkClientOption; @@ -28,8 +39,12 @@ import software.amazon.awssdk.core.client.handler.SyncClientHandler; import software.amazon.awssdk.core.exception.SdkClientException; import software.amazon.awssdk.core.http.HttpResponseHandler; +import software.amazon.awssdk.core.interceptor.ExecutionAttributes; +import software.amazon.awssdk.core.interceptor.SdkInternalExecutionAttribute; import software.amazon.awssdk.core.metrics.CoreMetric; import software.amazon.awssdk.core.retry.RetryMode; +import software.amazon.awssdk.endpoints.Endpoint; +import software.amazon.awssdk.http.auth.spi.scheme.AuthSchemeOption; import software.amazon.awssdk.metrics.MetricCollector; import software.amazon.awssdk.metrics.MetricPublisher; import software.amazon.awssdk.metrics.NoOpMetricCollector; @@ -40,6 +55,7 @@ import software.amazon.awssdk.protocols.json.JsonOperationMetadata; import software.amazon.awssdk.retries.api.RetryStrategy; import software.amazon.awssdk.utils.Logger; +import software.amazon.awssdk.utils.Validate; /** * Internal implementation of {@link ProtocolRestJsonWithCustomPackageClient}. @@ -62,8 +78,11 @@ final class DefaultProtocolRestJsonWithCustomPackageClient implements ProtocolRe protected DefaultProtocolRestJsonWithCustomPackageClient(SdkClientConfiguration clientConfiguration) { this.clientHandler = new AwsSyncClientHandler(clientConfiguration); - this.clientConfiguration = clientConfiguration.toBuilder().option(SdkClientOption.SDK_CLIENT, this) - .option(SdkClientOption.API_METADATA, "AmazonProtocolRestJsonWithCustomPackage" + "#" + ServiceVersionInfo.VERSION).build(); + this.clientConfiguration = clientConfiguration + .toBuilder() + .option(SdkClientOption.SDK_CLIENT, this) + .option(SdkClientOption.API_METADATA, + "AmazonProtocolRestJsonWithCustomPackage" + "#" + ServiceVersionInfo.VERSION).build(); this.protocolFactory = init(AwsJsonProtocolFactory.builder()).build(); } @@ -116,6 +135,8 @@ public OneOperationResponse oneOperation(OneOperationRequest oneOperationRequest .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) .withRequestConfiguration(clientConfiguration).withInput(oneOperationRequest) .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver(r -> resolveAuthSchemeOptions(r, "OneOperation", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "OneOperation")) .withMarshaller(new OneOperationRequestMarshaller(protocolFactory))); } finally { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -142,6 +163,57 @@ private static List resolveMetricPublishers(SdkClientConfigurat return publishers; } + private List resolveAuthSchemeOptions(SdkRequest request, String operationName, + SdkClientConfiguration clientConfiguration) { + ProtocolRestJsonWithCustomPackageAuthSchemeProvider requestAuthSchemeProvider = request + .overrideConfiguration() + .flatMap(c -> c.authSchemeProvider()) + .map(p -> Validate.isInstanceOf(ProtocolRestJsonWithCustomPackageAuthSchemeProvider.class, p, + "Expected an instance of ProtocolRestJsonWithCustomPackageAuthSchemeProvider")).orElse(null); + ProtocolRestJsonWithCustomPackageAuthSchemeProvider authSchemeProvider = requestAuthSchemeProvider != null ? requestAuthSchemeProvider + : Validate.isInstanceOf(ProtocolRestJsonWithCustomPackageAuthSchemeProvider.class, + clientConfiguration.option(SdkClientOption.AUTH_SCHEME_PROVIDER), + "Expected an instance of ProtocolRestJsonWithCustomPackageAuthSchemeProvider"); + ProtocolRestJsonWithCustomPackageAuthSchemeParams.Builder paramsBuilder = ProtocolRestJsonWithCustomPackageAuthSchemeParams + .builder().operation(operationName); + paramsBuilder.region(clientConfiguration.option(AwsClientOption.AWS_REGION)); + List options = authSchemeProvider.resolveAuthScheme(paramsBuilder.build()); + return options; + } + + private Endpoint resolveEndpoint(SdkRequest request, ExecutionAttributes executionAttributes, String operationName) { + ProtocolRestJsonWithCustomPackageEndpointProvider provider = (ProtocolRestJsonWithCustomPackageEndpointProvider) executionAttributes + .getAttribute(SdkInternalExecutionAttribute.ENDPOINT_PROVIDER); + try { + ProtocolRestJsonWithCustomPackageEndpointParams endpointParams = ProtocolRestJsonWithCustomPackageEndpointResolverUtils + .ruleParams(request, executionAttributes); + Endpoint endpoint = provider.resolveEndpoint(endpointParams).join(); + if (!AwsEndpointProviderUtils.disableHostPrefixInjection(executionAttributes)) { + Optional hostPrefix = ProtocolRestJsonWithCustomPackageEndpointResolverUtils.hostPrefix(operationName, + request); + if (hostPrefix.isPresent()) { + endpoint = AwsEndpointProviderUtils.addHostPrefix(endpoint, hostPrefix.get()); + } + } + List endpointAuthSchemes = endpoint.attribute(AwsEndpointAttribute.AUTH_SCHEMES); + SelectedAuthScheme selectedAuthScheme = executionAttributes + .getAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME); + if (endpointAuthSchemes != null && selectedAuthScheme != null) { + selectedAuthScheme = ProtocolRestJsonWithCustomPackageEndpointResolverUtils + .authSchemeWithEndpointSignerProperties(endpointAuthSchemes, selectedAuthScheme); + executionAttributes.putAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME, selectedAuthScheme); + } + ProtocolRestJsonWithCustomPackageEndpointResolverUtils.setMetricValues(endpoint, executionAttributes); + return endpoint; + } catch (CompletionException e) { + Throwable cause = e.getCause(); + if (cause instanceof SdkClientException) { + throw (SdkClientException) cause; + } + throw SdkClientException.create("Endpoint resolution failed: " + cause.getMessage(), cause); + } + } + private HttpResponseHandler createErrorResponseHandler(BaseAwsJsonProtocolFactory protocolFactory, JsonOperationMetadata operationMetadata, Function> exceptionMetadataMapper) { return protocolFactory.createErrorResponseHandler(operationMetadata, exceptionMetadataMapper); diff --git a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-customservicemetadata-async.java b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-customservicemetadata-async.java index 96f626d27f13..d20f340c6c8d 100644 --- a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-customservicemetadata-async.java +++ b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-customservicemetadata-async.java @@ -6,13 +6,18 @@ import java.util.List; import java.util.Optional; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionException; import java.util.function.Consumer; import java.util.function.Function; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import software.amazon.awssdk.annotations.Generated; import software.amazon.awssdk.annotations.SdkInternalApi; +import software.amazon.awssdk.awscore.client.config.AwsClientOption; import software.amazon.awssdk.awscore.client.handler.AwsAsyncClientHandler; +import software.amazon.awssdk.awscore.endpoints.AwsEndpointAttribute; +import software.amazon.awssdk.awscore.endpoints.AwsEndpointProviderUtils; +import software.amazon.awssdk.awscore.endpoints.authscheme.EndpointAuthScheme; import software.amazon.awssdk.awscore.exception.AwsServiceException; import software.amazon.awssdk.awscore.internal.AwsProtocolMetadata; import software.amazon.awssdk.awscore.internal.AwsServiceProtocol; @@ -20,14 +25,20 @@ import software.amazon.awssdk.core.RequestOverrideConfiguration; import software.amazon.awssdk.core.SdkPlugin; import software.amazon.awssdk.core.SdkRequest; +import software.amazon.awssdk.core.SelectedAuthScheme; import software.amazon.awssdk.core.client.config.ClientOverrideConfiguration; import software.amazon.awssdk.core.client.config.SdkClientConfiguration; import software.amazon.awssdk.core.client.config.SdkClientOption; import software.amazon.awssdk.core.client.handler.AsyncClientHandler; import software.amazon.awssdk.core.client.handler.ClientExecutionParams; +import software.amazon.awssdk.core.exception.SdkClientException; import software.amazon.awssdk.core.http.HttpResponseHandler; +import software.amazon.awssdk.core.interceptor.ExecutionAttributes; +import software.amazon.awssdk.core.interceptor.SdkInternalExecutionAttribute; import software.amazon.awssdk.core.metrics.CoreMetric; import software.amazon.awssdk.core.retry.RetryMode; +import software.amazon.awssdk.endpoints.Endpoint; +import software.amazon.awssdk.http.auth.spi.scheme.AuthSchemeOption; import software.amazon.awssdk.metrics.MetricCollector; import software.amazon.awssdk.metrics.MetricPublisher; import software.amazon.awssdk.metrics.NoOpMetricCollector; @@ -37,6 +48,11 @@ import software.amazon.awssdk.protocols.json.BaseAwsJsonProtocolFactory; import software.amazon.awssdk.protocols.json.JsonOperationMetadata; import software.amazon.awssdk.retries.api.RetryStrategy; +import software.amazon.awssdk.services.protocolrestjsonwithcustomcontenttype.auth.scheme.ProtocolRestJsonWithCustomContentTypeAuthSchemeParams; +import software.amazon.awssdk.services.protocolrestjsonwithcustomcontenttype.auth.scheme.ProtocolRestJsonWithCustomContentTypeAuthSchemeProvider; +import software.amazon.awssdk.services.protocolrestjsonwithcustomcontenttype.endpoints.ProtocolRestJsonWithCustomContentTypeEndpointParams; +import software.amazon.awssdk.services.protocolrestjsonwithcustomcontenttype.endpoints.ProtocolRestJsonWithCustomContentTypeEndpointProvider; +import software.amazon.awssdk.services.protocolrestjsonwithcustomcontenttype.endpoints.internal.ProtocolRestJsonWithCustomContentTypeEndpointResolverUtils; import software.amazon.awssdk.services.protocolrestjsonwithcustomcontenttype.internal.ProtocolRestJsonWithCustomContentTypeServiceClientConfigurationBuilder; import software.amazon.awssdk.services.protocolrestjsonwithcustomcontenttype.internal.ServiceVersionInfo; import software.amazon.awssdk.services.protocolrestjsonwithcustomcontenttype.model.OneOperationRequest; @@ -44,6 +60,7 @@ import software.amazon.awssdk.services.protocolrestjsonwithcustomcontenttype.model.ProtocolRestJsonWithCustomContentTypeException; import software.amazon.awssdk.services.protocolrestjsonwithcustomcontenttype.transform.OneOperationRequestMarshaller; import software.amazon.awssdk.utils.CompletableFutureUtils; +import software.amazon.awssdk.utils.Validate; /** * Internal implementation of {@link ProtocolRestJsonWithCustomContentTypeAsyncClient}. @@ -66,8 +83,11 @@ final class DefaultProtocolRestJsonWithCustomContentTypeAsyncClient implements P protected DefaultProtocolRestJsonWithCustomContentTypeAsyncClient(SdkClientConfiguration clientConfiguration) { this.clientHandler = new AwsAsyncClientHandler(clientConfiguration); - this.clientConfiguration = clientConfiguration.toBuilder().option(SdkClientOption.SDK_CLIENT, this) - .option(SdkClientOption.API_METADATA, "AmazonProtocolRestJsonWithCustomContentType" + "#" + ServiceVersionInfo.VERSION).build(); + this.clientConfiguration = clientConfiguration + .toBuilder() + .option(SdkClientOption.SDK_CLIENT, this) + .option(SdkClientOption.API_METADATA, + "AmazonProtocolRestJsonWithCustomContentType" + "#" + ServiceVersionInfo.VERSION).build(); this.protocolFactory = init(AwsJsonProtocolFactory.builder()).build(); } @@ -124,7 +144,8 @@ public CompletableFuture oneOperation(OneOperationRequest .withMarshaller(new OneOperationRequestMarshaller(protocolFactory)) .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) .withRequestConfiguration(clientConfiguration).withMetricCollector(apiCallMetricCollector) - .withInput(oneOperationRequest)); + .withAuthSchemeOptionsResolver(r -> resolveAuthSchemeOptions(r, "OneOperation", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "OneOperation")).withInput(oneOperationRequest)); CompletableFuture whenCompleted = executeFuture.whenComplete((r, e) -> { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); }); @@ -168,6 +189,57 @@ private static List resolveMetricPublishers(SdkClientConfigurat return publishers; } + private List resolveAuthSchemeOptions(SdkRequest request, String operationName, + SdkClientConfiguration clientConfiguration) { + ProtocolRestJsonWithCustomContentTypeAuthSchemeProvider requestAuthSchemeProvider = request + .overrideConfiguration() + .flatMap(c -> c.authSchemeProvider()) + .map(p -> Validate.isInstanceOf(ProtocolRestJsonWithCustomContentTypeAuthSchemeProvider.class, p, + "Expected an instance of ProtocolRestJsonWithCustomContentTypeAuthSchemeProvider")).orElse(null); + ProtocolRestJsonWithCustomContentTypeAuthSchemeProvider authSchemeProvider = requestAuthSchemeProvider != null ? requestAuthSchemeProvider + : Validate.isInstanceOf(ProtocolRestJsonWithCustomContentTypeAuthSchemeProvider.class, + clientConfiguration.option(SdkClientOption.AUTH_SCHEME_PROVIDER), + "Expected an instance of ProtocolRestJsonWithCustomContentTypeAuthSchemeProvider"); + ProtocolRestJsonWithCustomContentTypeAuthSchemeParams.Builder paramsBuilder = ProtocolRestJsonWithCustomContentTypeAuthSchemeParams + .builder().operation(operationName); + paramsBuilder.region(clientConfiguration.option(AwsClientOption.AWS_REGION)); + List options = authSchemeProvider.resolveAuthScheme(paramsBuilder.build()); + return options; + } + + private Endpoint resolveEndpoint(SdkRequest request, ExecutionAttributes executionAttributes, String operationName) { + ProtocolRestJsonWithCustomContentTypeEndpointProvider provider = (ProtocolRestJsonWithCustomContentTypeEndpointProvider) executionAttributes + .getAttribute(SdkInternalExecutionAttribute.ENDPOINT_PROVIDER); + try { + ProtocolRestJsonWithCustomContentTypeEndpointParams endpointParams = ProtocolRestJsonWithCustomContentTypeEndpointResolverUtils + .ruleParams(request, executionAttributes); + Endpoint endpoint = provider.resolveEndpoint(endpointParams).join(); + if (!AwsEndpointProviderUtils.disableHostPrefixInjection(executionAttributes)) { + Optional hostPrefix = ProtocolRestJsonWithCustomContentTypeEndpointResolverUtils.hostPrefix( + operationName, request); + if (hostPrefix.isPresent()) { + endpoint = AwsEndpointProviderUtils.addHostPrefix(endpoint, hostPrefix.get()); + } + } + List endpointAuthSchemes = endpoint.attribute(AwsEndpointAttribute.AUTH_SCHEMES); + SelectedAuthScheme selectedAuthScheme = executionAttributes + .getAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME); + if (endpointAuthSchemes != null && selectedAuthScheme != null) { + selectedAuthScheme = ProtocolRestJsonWithCustomContentTypeEndpointResolverUtils + .authSchemeWithEndpointSignerProperties(endpointAuthSchemes, selectedAuthScheme); + executionAttributes.putAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME, selectedAuthScheme); + } + ProtocolRestJsonWithCustomContentTypeEndpointResolverUtils.setMetricValues(endpoint, executionAttributes); + return endpoint; + } catch (CompletionException e) { + Throwable cause = e.getCause(); + if (cause instanceof SdkClientException) { + throw (SdkClientException) cause; + } + throw SdkClientException.create("Endpoint resolution failed: " + cause.getMessage(), cause); + } + } + private void updateRetryStrategyClientConfiguration(SdkClientConfiguration.Builder configuration) { ClientOverrideConfiguration.Builder builder = configuration.asOverrideConfigurationBuilder(); RetryMode retryMode = builder.retryMode(); diff --git a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-customservicemetadata-sync.java b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-customservicemetadata-sync.java index 3852e59b5710..f0998862d033 100644 --- a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-customservicemetadata-sync.java +++ b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-customservicemetadata-sync.java @@ -3,11 +3,16 @@ import java.util.Collections; import java.util.List; import java.util.Optional; +import java.util.concurrent.CompletionException; import java.util.function.Consumer; import java.util.function.Function; import software.amazon.awssdk.annotations.Generated; import software.amazon.awssdk.annotations.SdkInternalApi; +import software.amazon.awssdk.awscore.client.config.AwsClientOption; import software.amazon.awssdk.awscore.client.handler.AwsSyncClientHandler; +import software.amazon.awssdk.awscore.endpoints.AwsEndpointAttribute; +import software.amazon.awssdk.awscore.endpoints.AwsEndpointProviderUtils; +import software.amazon.awssdk.awscore.endpoints.authscheme.EndpointAuthScheme; import software.amazon.awssdk.awscore.exception.AwsServiceException; import software.amazon.awssdk.awscore.internal.AwsProtocolMetadata; import software.amazon.awssdk.awscore.internal.AwsServiceProtocol; @@ -15,6 +20,7 @@ import software.amazon.awssdk.core.RequestOverrideConfiguration; import software.amazon.awssdk.core.SdkPlugin; import software.amazon.awssdk.core.SdkRequest; +import software.amazon.awssdk.core.SelectedAuthScheme; import software.amazon.awssdk.core.client.config.ClientOverrideConfiguration; import software.amazon.awssdk.core.client.config.SdkClientConfiguration; import software.amazon.awssdk.core.client.config.SdkClientOption; @@ -22,8 +28,12 @@ import software.amazon.awssdk.core.client.handler.SyncClientHandler; import software.amazon.awssdk.core.exception.SdkClientException; import software.amazon.awssdk.core.http.HttpResponseHandler; +import software.amazon.awssdk.core.interceptor.ExecutionAttributes; +import software.amazon.awssdk.core.interceptor.SdkInternalExecutionAttribute; import software.amazon.awssdk.core.metrics.CoreMetric; import software.amazon.awssdk.core.retry.RetryMode; +import software.amazon.awssdk.endpoints.Endpoint; +import software.amazon.awssdk.http.auth.spi.scheme.AuthSchemeOption; import software.amazon.awssdk.metrics.MetricCollector; import software.amazon.awssdk.metrics.MetricPublisher; import software.amazon.awssdk.metrics.NoOpMetricCollector; @@ -33,6 +43,11 @@ import software.amazon.awssdk.protocols.json.BaseAwsJsonProtocolFactory; import software.amazon.awssdk.protocols.json.JsonOperationMetadata; import software.amazon.awssdk.retries.api.RetryStrategy; +import software.amazon.awssdk.services.protocolrestjsonwithcustomcontenttype.auth.scheme.ProtocolRestJsonWithCustomContentTypeAuthSchemeParams; +import software.amazon.awssdk.services.protocolrestjsonwithcustomcontenttype.auth.scheme.ProtocolRestJsonWithCustomContentTypeAuthSchemeProvider; +import software.amazon.awssdk.services.protocolrestjsonwithcustomcontenttype.endpoints.ProtocolRestJsonWithCustomContentTypeEndpointParams; +import software.amazon.awssdk.services.protocolrestjsonwithcustomcontenttype.endpoints.ProtocolRestJsonWithCustomContentTypeEndpointProvider; +import software.amazon.awssdk.services.protocolrestjsonwithcustomcontenttype.endpoints.internal.ProtocolRestJsonWithCustomContentTypeEndpointResolverUtils; import software.amazon.awssdk.services.protocolrestjsonwithcustomcontenttype.internal.ProtocolRestJsonWithCustomContentTypeServiceClientConfigurationBuilder; import software.amazon.awssdk.services.protocolrestjsonwithcustomcontenttype.internal.ServiceVersionInfo; import software.amazon.awssdk.services.protocolrestjsonwithcustomcontenttype.model.OneOperationRequest; @@ -40,6 +55,7 @@ import software.amazon.awssdk.services.protocolrestjsonwithcustomcontenttype.model.ProtocolRestJsonWithCustomContentTypeException; import software.amazon.awssdk.services.protocolrestjsonwithcustomcontenttype.transform.OneOperationRequestMarshaller; import software.amazon.awssdk.utils.Logger; +import software.amazon.awssdk.utils.Validate; /** * Internal implementation of {@link ProtocolRestJsonWithCustomContentTypeClient}. @@ -119,6 +135,8 @@ public OneOperationResponse oneOperation(OneOperationRequest oneOperationRequest .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) .withRequestConfiguration(clientConfiguration).withInput(oneOperationRequest) .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver(r -> resolveAuthSchemeOptions(r, "OneOperation", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "OneOperation")) .withMarshaller(new OneOperationRequestMarshaller(protocolFactory))); } finally { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -145,6 +163,57 @@ private static List resolveMetricPublishers(SdkClientConfigurat return publishers; } + private List resolveAuthSchemeOptions(SdkRequest request, String operationName, + SdkClientConfiguration clientConfiguration) { + ProtocolRestJsonWithCustomContentTypeAuthSchemeProvider requestAuthSchemeProvider = request + .overrideConfiguration() + .flatMap(c -> c.authSchemeProvider()) + .map(p -> Validate.isInstanceOf(ProtocolRestJsonWithCustomContentTypeAuthSchemeProvider.class, p, + "Expected an instance of ProtocolRestJsonWithCustomContentTypeAuthSchemeProvider")).orElse(null); + ProtocolRestJsonWithCustomContentTypeAuthSchemeProvider authSchemeProvider = requestAuthSchemeProvider != null ? requestAuthSchemeProvider + : Validate.isInstanceOf(ProtocolRestJsonWithCustomContentTypeAuthSchemeProvider.class, + clientConfiguration.option(SdkClientOption.AUTH_SCHEME_PROVIDER), + "Expected an instance of ProtocolRestJsonWithCustomContentTypeAuthSchemeProvider"); + ProtocolRestJsonWithCustomContentTypeAuthSchemeParams.Builder paramsBuilder = ProtocolRestJsonWithCustomContentTypeAuthSchemeParams + .builder().operation(operationName); + paramsBuilder.region(clientConfiguration.option(AwsClientOption.AWS_REGION)); + List options = authSchemeProvider.resolveAuthScheme(paramsBuilder.build()); + return options; + } + + private Endpoint resolveEndpoint(SdkRequest request, ExecutionAttributes executionAttributes, String operationName) { + ProtocolRestJsonWithCustomContentTypeEndpointProvider provider = (ProtocolRestJsonWithCustomContentTypeEndpointProvider) executionAttributes + .getAttribute(SdkInternalExecutionAttribute.ENDPOINT_PROVIDER); + try { + ProtocolRestJsonWithCustomContentTypeEndpointParams endpointParams = ProtocolRestJsonWithCustomContentTypeEndpointResolverUtils + .ruleParams(request, executionAttributes); + Endpoint endpoint = provider.resolveEndpoint(endpointParams).join(); + if (!AwsEndpointProviderUtils.disableHostPrefixInjection(executionAttributes)) { + Optional hostPrefix = ProtocolRestJsonWithCustomContentTypeEndpointResolverUtils.hostPrefix( + operationName, request); + if (hostPrefix.isPresent()) { + endpoint = AwsEndpointProviderUtils.addHostPrefix(endpoint, hostPrefix.get()); + } + } + List endpointAuthSchemes = endpoint.attribute(AwsEndpointAttribute.AUTH_SCHEMES); + SelectedAuthScheme selectedAuthScheme = executionAttributes + .getAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME); + if (endpointAuthSchemes != null && selectedAuthScheme != null) { + selectedAuthScheme = ProtocolRestJsonWithCustomContentTypeEndpointResolverUtils + .authSchemeWithEndpointSignerProperties(endpointAuthSchemes, selectedAuthScheme); + executionAttributes.putAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME, selectedAuthScheme); + } + ProtocolRestJsonWithCustomContentTypeEndpointResolverUtils.setMetricValues(endpoint, executionAttributes); + return endpoint; + } catch (CompletionException e) { + Throwable cause = e.getCause(); + if (cause instanceof SdkClientException) { + throw (SdkClientException) cause; + } + throw SdkClientException.create("Endpoint resolution failed: " + cause.getMessage(), cause); + } + } + private HttpResponseHandler createErrorResponseHandler(BaseAwsJsonProtocolFactory protocolFactory, JsonOperationMetadata operationMetadata, Function> exceptionMetadataMapper) { return protocolFactory.createErrorResponseHandler(operationMetadata, exceptionMetadataMapper); diff --git a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-endpoint-discovery-async.java b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-endpoint-discovery-async.java index b0426321ea91..38705c3d035c 100644 --- a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-endpoint-discovery-async.java +++ b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-endpoint-discovery-async.java @@ -7,6 +7,7 @@ import java.util.List; import java.util.Optional; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionException; import java.util.function.Consumer; import java.util.function.Function; import org.slf4j.Logger; @@ -16,6 +17,9 @@ import software.amazon.awssdk.awscore.AwsRequestOverrideConfiguration; import software.amazon.awssdk.awscore.client.config.AwsClientOption; import software.amazon.awssdk.awscore.client.handler.AwsAsyncClientHandler; +import software.amazon.awssdk.awscore.endpoints.AwsEndpointAttribute; +import software.amazon.awssdk.awscore.endpoints.AwsEndpointProviderUtils; +import software.amazon.awssdk.awscore.endpoints.authscheme.EndpointAuthScheme; import software.amazon.awssdk.awscore.exception.AwsServiceException; import software.amazon.awssdk.awscore.internal.AwsProtocolMetadata; import software.amazon.awssdk.awscore.internal.AwsServiceProtocol; @@ -23,6 +27,7 @@ import software.amazon.awssdk.core.RequestOverrideConfiguration; import software.amazon.awssdk.core.SdkPlugin; import software.amazon.awssdk.core.SdkRequest; +import software.amazon.awssdk.core.SelectedAuthScheme; import software.amazon.awssdk.core.client.config.ClientOverrideConfiguration; import software.amazon.awssdk.core.client.config.SdkClientConfiguration; import software.amazon.awssdk.core.client.config.SdkClientOption; @@ -30,9 +35,14 @@ import software.amazon.awssdk.core.client.handler.ClientExecutionParams; import software.amazon.awssdk.core.endpointdiscovery.EndpointDiscoveryRefreshCache; import software.amazon.awssdk.core.endpointdiscovery.EndpointDiscoveryRequest; +import software.amazon.awssdk.core.exception.SdkClientException; import software.amazon.awssdk.core.http.HttpResponseHandler; +import software.amazon.awssdk.core.interceptor.ExecutionAttributes; +import software.amazon.awssdk.core.interceptor.SdkInternalExecutionAttribute; import software.amazon.awssdk.core.metrics.CoreMetric; import software.amazon.awssdk.core.retry.RetryMode; +import software.amazon.awssdk.endpoints.Endpoint; +import software.amazon.awssdk.http.auth.spi.scheme.AuthSchemeOption; import software.amazon.awssdk.identity.spi.AwsCredentialsIdentity; import software.amazon.awssdk.metrics.MetricCollector; import software.amazon.awssdk.metrics.MetricPublisher; @@ -43,6 +53,11 @@ import software.amazon.awssdk.protocols.json.BaseAwsJsonProtocolFactory; import software.amazon.awssdk.protocols.json.JsonOperationMetadata; import software.amazon.awssdk.retries.api.RetryStrategy; +import software.amazon.awssdk.services.endpointdiscoverytest.auth.scheme.EndpointDiscoveryTestAuthSchemeParams; +import software.amazon.awssdk.services.endpointdiscoverytest.auth.scheme.EndpointDiscoveryTestAuthSchemeProvider; +import software.amazon.awssdk.services.endpointdiscoverytest.endpoints.EndpointDiscoveryTestEndpointParams; +import software.amazon.awssdk.services.endpointdiscoverytest.endpoints.EndpointDiscoveryTestEndpointProvider; +import software.amazon.awssdk.services.endpointdiscoverytest.endpoints.internal.EndpointDiscoveryTestEndpointResolverUtils; import software.amazon.awssdk.services.endpointdiscoverytest.internal.EndpointDiscoveryTestServiceClientConfigurationBuilder; import software.amazon.awssdk.services.endpointdiscoverytest.internal.ServiceVersionInfo; import software.amazon.awssdk.services.endpointdiscoverytest.model.DescribeEndpointsRequest; @@ -59,6 +74,7 @@ import software.amazon.awssdk.services.endpointdiscoverytest.transform.TestDiscoveryOptionalRequestMarshaller; import software.amazon.awssdk.services.endpointdiscoverytest.transform.TestDiscoveryRequiredRequestMarshaller; import software.amazon.awssdk.utils.CompletableFutureUtils; +import software.amazon.awssdk.utils.Validate; /** * Internal implementation of {@link EndpointDiscoveryTestAsyncClient}. @@ -140,10 +156,16 @@ public CompletableFuture describeEndpoints(DescribeEn CompletableFuture executeFuture = clientHandler .execute(new ClientExecutionParams() - .withOperationName("DescribeEndpoints").withProtocolMetadata(protocolMetadata) + .withOperationName("DescribeEndpoints") + .withProtocolMetadata(protocolMetadata) .withMarshaller(new DescribeEndpointsRequestMarshaller(protocolFactory)) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withMetricCollector(apiCallMetricCollector) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "DescribeEndpoints", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "DescribeEndpoints")) .withInput(describeEndpointsRequest)); CompletableFuture whenCompleted = executeFuture.whenComplete((r, e) -> { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -235,10 +257,17 @@ public CompletableFuture testDiscovery CompletableFuture executeFuture = endpointFuture .thenCompose(cachedEndpoint -> clientHandler .execute(new ClientExecutionParams() - .withOperationName("TestDiscoveryIdentifiersRequired").withProtocolMetadata(protocolMetadata) + .withOperationName("TestDiscoveryIdentifiersRequired") + .withProtocolMetadata(protocolMetadata) .withMarshaller(new TestDiscoveryIdentifiersRequiredRequestMarshaller(protocolFactory)) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withMetricCollector(apiCallMetricCollector) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "TestDiscoveryIdentifiersRequired", + clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "TestDiscoveryIdentifiersRequired")) .discoveredEndpoint(cachedEndpoint).withInput(testDiscoveryIdentifiersRequiredRequest))); CompletableFuture whenCompleted = executeFuture.whenComplete((r, e) -> { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -320,10 +349,16 @@ public CompletableFuture testDiscoveryOptional( CompletableFuture executeFuture = endpointFuture .thenCompose(cachedEndpoint -> clientHandler .execute(new ClientExecutionParams() - .withOperationName("TestDiscoveryOptional").withProtocolMetadata(protocolMetadata) + .withOperationName("TestDiscoveryOptional") + .withProtocolMetadata(protocolMetadata) .withMarshaller(new TestDiscoveryOptionalRequestMarshaller(protocolFactory)) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withMetricCollector(apiCallMetricCollector) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "TestDiscoveryOptional", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "TestDiscoveryOptional")) .discoveredEndpoint(cachedEndpoint).withInput(testDiscoveryOptionalRequest))); CompletableFuture whenCompleted = executeFuture.whenComplete((r, e) -> { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -413,10 +448,16 @@ public CompletableFuture testDiscoveryRequired( CompletableFuture executeFuture = endpointFuture .thenCompose(cachedEndpoint -> clientHandler .execute(new ClientExecutionParams() - .withOperationName("TestDiscoveryRequired").withProtocolMetadata(protocolMetadata) + .withOperationName("TestDiscoveryRequired") + .withProtocolMetadata(protocolMetadata) .withMarshaller(new TestDiscoveryRequiredRequestMarshaller(protocolFactory)) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withMetricCollector(apiCallMetricCollector) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "TestDiscoveryRequired", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "TestDiscoveryRequired")) .discoveredEndpoint(cachedEndpoint).withInput(testDiscoveryRequiredRequest))); CompletableFuture whenCompleted = executeFuture.whenComplete((r, e) -> { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -460,6 +501,56 @@ private static List resolveMetricPublishers(SdkClientConfigurat return publishers; } + private List resolveAuthSchemeOptions(SdkRequest request, String operationName, + SdkClientConfiguration clientConfiguration) { + EndpointDiscoveryTestAuthSchemeProvider requestAuthSchemeProvider = request + .overrideConfiguration() + .flatMap(c -> c.authSchemeProvider()) + .map(p -> Validate.isInstanceOf(EndpointDiscoveryTestAuthSchemeProvider.class, p, + "Expected an instance of EndpointDiscoveryTestAuthSchemeProvider")).orElse(null); + EndpointDiscoveryTestAuthSchemeProvider authSchemeProvider = requestAuthSchemeProvider != null ? requestAuthSchemeProvider + : Validate.isInstanceOf(EndpointDiscoveryTestAuthSchemeProvider.class, + clientConfiguration.option(SdkClientOption.AUTH_SCHEME_PROVIDER), + "Expected an instance of EndpointDiscoveryTestAuthSchemeProvider"); + EndpointDiscoveryTestAuthSchemeParams.Builder paramsBuilder = EndpointDiscoveryTestAuthSchemeParams.builder().operation( + operationName); + paramsBuilder.region(clientConfiguration.option(AwsClientOption.AWS_REGION)); + List options = authSchemeProvider.resolveAuthScheme(paramsBuilder.build()); + return options; + } + + private Endpoint resolveEndpoint(SdkRequest request, ExecutionAttributes executionAttributes, String operationName) { + EndpointDiscoveryTestEndpointProvider provider = (EndpointDiscoveryTestEndpointProvider) executionAttributes + .getAttribute(SdkInternalExecutionAttribute.ENDPOINT_PROVIDER); + try { + EndpointDiscoveryTestEndpointParams endpointParams = EndpointDiscoveryTestEndpointResolverUtils.ruleParams(request, + executionAttributes); + Endpoint endpoint = provider.resolveEndpoint(endpointParams).join(); + if (!AwsEndpointProviderUtils.disableHostPrefixInjection(executionAttributes)) { + Optional hostPrefix = EndpointDiscoveryTestEndpointResolverUtils.hostPrefix(operationName, request); + if (hostPrefix.isPresent()) { + endpoint = AwsEndpointProviderUtils.addHostPrefix(endpoint, hostPrefix.get()); + } + } + List endpointAuthSchemes = endpoint.attribute(AwsEndpointAttribute.AUTH_SCHEMES); + SelectedAuthScheme selectedAuthScheme = executionAttributes + .getAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME); + if (endpointAuthSchemes != null && selectedAuthScheme != null) { + selectedAuthScheme = EndpointDiscoveryTestEndpointResolverUtils.authSchemeWithEndpointSignerProperties( + endpointAuthSchemes, selectedAuthScheme); + executionAttributes.putAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME, selectedAuthScheme); + } + EndpointDiscoveryTestEndpointResolverUtils.setMetricValues(endpoint, executionAttributes); + return endpoint; + } catch (CompletionException e) { + Throwable cause = e.getCause(); + if (cause instanceof SdkClientException) { + throw (SdkClientException) cause; + } + throw SdkClientException.create("Endpoint resolution failed: " + cause.getMessage(), cause); + } + } + private void updateRetryStrategyClientConfiguration(SdkClientConfiguration.Builder configuration) { ClientOverrideConfiguration.Builder builder = configuration.asOverrideConfigurationBuilder(); RetryMode retryMode = builder.retryMode(); diff --git a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-endpoint-discovery-sync.java b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-endpoint-discovery-sync.java index 731f8a467062..cd2feebca170 100644 --- a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-endpoint-discovery-sync.java +++ b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-endpoint-discovery-sync.java @@ -5,6 +5,7 @@ import java.util.List; import java.util.Optional; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionException; import java.util.function.Consumer; import java.util.function.Function; import software.amazon.awssdk.annotations.Generated; @@ -12,6 +13,9 @@ import software.amazon.awssdk.awscore.AwsRequestOverrideConfiguration; import software.amazon.awssdk.awscore.client.config.AwsClientOption; import software.amazon.awssdk.awscore.client.handler.AwsSyncClientHandler; +import software.amazon.awssdk.awscore.endpoints.AwsEndpointAttribute; +import software.amazon.awssdk.awscore.endpoints.AwsEndpointProviderUtils; +import software.amazon.awssdk.awscore.endpoints.authscheme.EndpointAuthScheme; import software.amazon.awssdk.awscore.exception.AwsServiceException; import software.amazon.awssdk.awscore.internal.AwsProtocolMetadata; import software.amazon.awssdk.awscore.internal.AwsServiceProtocol; @@ -19,6 +23,7 @@ import software.amazon.awssdk.core.RequestOverrideConfiguration; import software.amazon.awssdk.core.SdkPlugin; import software.amazon.awssdk.core.SdkRequest; +import software.amazon.awssdk.core.SelectedAuthScheme; import software.amazon.awssdk.core.client.config.ClientOverrideConfiguration; import software.amazon.awssdk.core.client.config.SdkClientConfiguration; import software.amazon.awssdk.core.client.config.SdkClientOption; @@ -28,8 +33,12 @@ import software.amazon.awssdk.core.endpointdiscovery.EndpointDiscoveryRequest; import software.amazon.awssdk.core.exception.SdkClientException; import software.amazon.awssdk.core.http.HttpResponseHandler; +import software.amazon.awssdk.core.interceptor.ExecutionAttributes; +import software.amazon.awssdk.core.interceptor.SdkInternalExecutionAttribute; import software.amazon.awssdk.core.metrics.CoreMetric; import software.amazon.awssdk.core.retry.RetryMode; +import software.amazon.awssdk.endpoints.Endpoint; +import software.amazon.awssdk.http.auth.spi.scheme.AuthSchemeOption; import software.amazon.awssdk.identity.spi.AwsCredentialsIdentity; import software.amazon.awssdk.metrics.MetricCollector; import software.amazon.awssdk.metrics.MetricPublisher; @@ -40,6 +49,11 @@ import software.amazon.awssdk.protocols.json.BaseAwsJsonProtocolFactory; import software.amazon.awssdk.protocols.json.JsonOperationMetadata; import software.amazon.awssdk.retries.api.RetryStrategy; +import software.amazon.awssdk.services.endpointdiscoverytest.auth.scheme.EndpointDiscoveryTestAuthSchemeParams; +import software.amazon.awssdk.services.endpointdiscoverytest.auth.scheme.EndpointDiscoveryTestAuthSchemeProvider; +import software.amazon.awssdk.services.endpointdiscoverytest.endpoints.EndpointDiscoveryTestEndpointParams; +import software.amazon.awssdk.services.endpointdiscoverytest.endpoints.EndpointDiscoveryTestEndpointProvider; +import software.amazon.awssdk.services.endpointdiscoverytest.endpoints.internal.EndpointDiscoveryTestEndpointResolverUtils; import software.amazon.awssdk.services.endpointdiscoverytest.internal.EndpointDiscoveryTestServiceClientConfigurationBuilder; import software.amazon.awssdk.services.endpointdiscoverytest.internal.ServiceVersionInfo; import software.amazon.awssdk.services.endpointdiscoverytest.model.DescribeEndpointsRequest; @@ -57,6 +71,7 @@ import software.amazon.awssdk.services.endpointdiscoverytest.transform.TestDiscoveryRequiredRequestMarshaller; import software.amazon.awssdk.utils.CompletableFutureUtils; import software.amazon.awssdk.utils.Logger; +import software.amazon.awssdk.utils.Validate; /** * Internal implementation of {@link EndpointDiscoveryTestClient}. @@ -138,6 +153,8 @@ public DescribeEndpointsResponse describeEndpoints(DescribeEndpointsRequest desc .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) .withRequestConfiguration(clientConfiguration).withInput(describeEndpointsRequest) .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver(r -> resolveAuthSchemeOptions(r, "DescribeEndpoints", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "DescribeEndpoints")) .withMarshaller(new DescribeEndpointsRequestMarshaller(protocolFactory))); } finally { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -211,10 +228,17 @@ public TestDiscoveryIdentifiersRequiredResponse testDiscoveryIdentifiersRequired return clientHandler .execute(new ClientExecutionParams() - .withOperationName("TestDiscoveryIdentifiersRequired").withProtocolMetadata(protocolMetadata) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .discoveredEndpoint(cachedEndpoint).withRequestConfiguration(clientConfiguration) - .withInput(testDiscoveryIdentifiersRequiredRequest).withMetricCollector(apiCallMetricCollector) + .withOperationName("TestDiscoveryIdentifiersRequired") + .withProtocolMetadata(protocolMetadata) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .discoveredEndpoint(cachedEndpoint) + .withRequestConfiguration(clientConfiguration) + .withInput(testDiscoveryIdentifiersRequiredRequest) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "TestDiscoveryIdentifiersRequired", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "TestDiscoveryIdentifiersRequired")) .withMarshaller(new TestDiscoveryIdentifiersRequiredRequestMarshaller(protocolFactory))); } finally { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -277,12 +301,20 @@ public TestDiscoveryOptionalResponse testDiscoveryOptional(TestDiscoveryOptional apiCallMetricCollector.reportMetric(CoreMetric.SERVICE_ID, "AwsEndpointDiscoveryTest"); apiCallMetricCollector.reportMetric(CoreMetric.OPERATION_NAME, "TestDiscoveryOptional"); - return clientHandler.execute(new ClientExecutionParams() - .withOperationName("TestDiscoveryOptional").withProtocolMetadata(protocolMetadata) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .discoveredEndpoint(cachedEndpoint).withRequestConfiguration(clientConfiguration) - .withInput(testDiscoveryOptionalRequest).withMetricCollector(apiCallMetricCollector) - .withMarshaller(new TestDiscoveryOptionalRequestMarshaller(protocolFactory))); + return clientHandler + .execute(new ClientExecutionParams() + .withOperationName("TestDiscoveryOptional") + .withProtocolMetadata(protocolMetadata) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .discoveredEndpoint(cachedEndpoint) + .withRequestConfiguration(clientConfiguration) + .withInput(testDiscoveryOptionalRequest) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "TestDiscoveryOptional", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "TestDiscoveryOptional")) + .withMarshaller(new TestDiscoveryOptionalRequestMarshaller(protocolFactory))); } finally { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); } @@ -352,12 +384,20 @@ public TestDiscoveryRequiredResponse testDiscoveryRequired(TestDiscoveryRequired apiCallMetricCollector.reportMetric(CoreMetric.SERVICE_ID, "AwsEndpointDiscoveryTest"); apiCallMetricCollector.reportMetric(CoreMetric.OPERATION_NAME, "TestDiscoveryRequired"); - return clientHandler.execute(new ClientExecutionParams() - .withOperationName("TestDiscoveryRequired").withProtocolMetadata(protocolMetadata) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .discoveredEndpoint(cachedEndpoint).withRequestConfiguration(clientConfiguration) - .withInput(testDiscoveryRequiredRequest).withMetricCollector(apiCallMetricCollector) - .withMarshaller(new TestDiscoveryRequiredRequestMarshaller(protocolFactory))); + return clientHandler + .execute(new ClientExecutionParams() + .withOperationName("TestDiscoveryRequired") + .withProtocolMetadata(protocolMetadata) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .discoveredEndpoint(cachedEndpoint) + .withRequestConfiguration(clientConfiguration) + .withInput(testDiscoveryRequiredRequest) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "TestDiscoveryRequired", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "TestDiscoveryRequired")) + .withMarshaller(new TestDiscoveryRequiredRequestMarshaller(protocolFactory))); } finally { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); } @@ -383,6 +423,56 @@ private static List resolveMetricPublishers(SdkClientConfigurat return publishers; } + private List resolveAuthSchemeOptions(SdkRequest request, String operationName, + SdkClientConfiguration clientConfiguration) { + EndpointDiscoveryTestAuthSchemeProvider requestAuthSchemeProvider = request + .overrideConfiguration() + .flatMap(c -> c.authSchemeProvider()) + .map(p -> Validate.isInstanceOf(EndpointDiscoveryTestAuthSchemeProvider.class, p, + "Expected an instance of EndpointDiscoveryTestAuthSchemeProvider")).orElse(null); + EndpointDiscoveryTestAuthSchemeProvider authSchemeProvider = requestAuthSchemeProvider != null ? requestAuthSchemeProvider + : Validate.isInstanceOf(EndpointDiscoveryTestAuthSchemeProvider.class, + clientConfiguration.option(SdkClientOption.AUTH_SCHEME_PROVIDER), + "Expected an instance of EndpointDiscoveryTestAuthSchemeProvider"); + EndpointDiscoveryTestAuthSchemeParams.Builder paramsBuilder = EndpointDiscoveryTestAuthSchemeParams.builder().operation( + operationName); + paramsBuilder.region(clientConfiguration.option(AwsClientOption.AWS_REGION)); + List options = authSchemeProvider.resolveAuthScheme(paramsBuilder.build()); + return options; + } + + private Endpoint resolveEndpoint(SdkRequest request, ExecutionAttributes executionAttributes, String operationName) { + EndpointDiscoveryTestEndpointProvider provider = (EndpointDiscoveryTestEndpointProvider) executionAttributes + .getAttribute(SdkInternalExecutionAttribute.ENDPOINT_PROVIDER); + try { + EndpointDiscoveryTestEndpointParams endpointParams = EndpointDiscoveryTestEndpointResolverUtils.ruleParams(request, + executionAttributes); + Endpoint endpoint = provider.resolveEndpoint(endpointParams).join(); + if (!AwsEndpointProviderUtils.disableHostPrefixInjection(executionAttributes)) { + Optional hostPrefix = EndpointDiscoveryTestEndpointResolverUtils.hostPrefix(operationName, request); + if (hostPrefix.isPresent()) { + endpoint = AwsEndpointProviderUtils.addHostPrefix(endpoint, hostPrefix.get()); + } + } + List endpointAuthSchemes = endpoint.attribute(AwsEndpointAttribute.AUTH_SCHEMES); + SelectedAuthScheme selectedAuthScheme = executionAttributes + .getAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME); + if (endpointAuthSchemes != null && selectedAuthScheme != null) { + selectedAuthScheme = EndpointDiscoveryTestEndpointResolverUtils.authSchemeWithEndpointSignerProperties( + endpointAuthSchemes, selectedAuthScheme); + executionAttributes.putAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME, selectedAuthScheme); + } + EndpointDiscoveryTestEndpointResolverUtils.setMetricValues(endpoint, executionAttributes); + return endpoint; + } catch (CompletionException e) { + Throwable cause = e.getCause(); + if (cause instanceof SdkClientException) { + throw (SdkClientException) cause; + } + throw SdkClientException.create("Endpoint resolution failed: " + cause.getMessage(), cause); + } + } + private HttpResponseHandler createErrorResponseHandler(BaseAwsJsonProtocolFactory protocolFactory, JsonOperationMetadata operationMetadata, Function> exceptionMetadataMapper) { return protocolFactory.createErrorResponseHandler(operationMetadata, exceptionMetadataMapper); diff --git a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-json-async-client-class.java b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-json-async-client-class.java index 4fcb45c2f919..d3d886b2d98f 100644 --- a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-json-async-client-class.java +++ b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-json-async-client-class.java @@ -7,6 +7,7 @@ import java.util.List; import java.util.Optional; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionException; import java.util.concurrent.Executor; import java.util.concurrent.ScheduledExecutorService; import java.util.function.Consumer; @@ -16,8 +17,12 @@ import org.slf4j.LoggerFactory; import software.amazon.awssdk.annotations.Generated; import software.amazon.awssdk.annotations.SdkInternalApi; +import software.amazon.awssdk.awscore.client.config.AwsClientOption; import software.amazon.awssdk.awscore.client.handler.AwsAsyncClientHandler; import software.amazon.awssdk.awscore.client.handler.AwsClientHandlerUtils; +import software.amazon.awssdk.awscore.endpoints.AwsEndpointAttribute; +import software.amazon.awssdk.awscore.endpoints.AwsEndpointProviderUtils; +import software.amazon.awssdk.awscore.endpoints.authscheme.EndpointAuthScheme; import software.amazon.awssdk.awscore.eventstream.EventStreamAsyncResponseTransformer; import software.amazon.awssdk.awscore.eventstream.EventStreamTaggedUnionJsonMarshaller; import software.amazon.awssdk.awscore.eventstream.EventStreamTaggedUnionPojoSupplier; @@ -33,6 +38,7 @@ import software.amazon.awssdk.core.SdkPojoBuilder; import software.amazon.awssdk.core.SdkRequest; import software.amazon.awssdk.core.SdkResponse; +import software.amazon.awssdk.core.SelectedAuthScheme; import software.amazon.awssdk.core.async.AsyncRequestBody; import software.amazon.awssdk.core.async.AsyncResponseTransformer; import software.amazon.awssdk.core.async.AsyncResponseTransformerUtils; @@ -44,7 +50,9 @@ import software.amazon.awssdk.core.client.handler.AsyncClientHandler; import software.amazon.awssdk.core.client.handler.AttachHttpMetadataResponseHandler; import software.amazon.awssdk.core.client.handler.ClientExecutionParams; +import software.amazon.awssdk.core.exception.SdkClientException; import software.amazon.awssdk.core.http.HttpResponseHandler; +import software.amazon.awssdk.core.interceptor.ExecutionAttributes; import software.amazon.awssdk.core.interceptor.SdkInternalExecutionAttribute; import software.amazon.awssdk.core.interceptor.trait.HttpChecksum; import software.amazon.awssdk.core.interceptor.trait.HttpChecksumRequired; @@ -53,6 +61,8 @@ import software.amazon.awssdk.core.protocol.VoidSdkResponse; import software.amazon.awssdk.core.retry.RetryMode; import software.amazon.awssdk.core.runtime.transform.AsyncStreamingRequestMarshaller; +import software.amazon.awssdk.endpoints.Endpoint; +import software.amazon.awssdk.http.auth.spi.scheme.AuthSchemeOption; import software.amazon.awssdk.metrics.MetricCollector; import software.amazon.awssdk.metrics.MetricPublisher; import software.amazon.awssdk.metrics.NoOpMetricCollector; @@ -62,7 +72,12 @@ import software.amazon.awssdk.protocols.json.BaseAwsJsonProtocolFactory; import software.amazon.awssdk.protocols.json.JsonOperationMetadata; import software.amazon.awssdk.retries.api.RetryStrategy; +import software.amazon.awssdk.services.json.auth.scheme.JsonAuthSchemeParams; +import software.amazon.awssdk.services.json.auth.scheme.JsonAuthSchemeProvider; import software.amazon.awssdk.services.json.batchmanager.JsonAsyncBatchManager; +import software.amazon.awssdk.services.json.endpoints.JsonEndpointParams; +import software.amazon.awssdk.services.json.endpoints.JsonEndpointProvider; +import software.amazon.awssdk.services.json.endpoints.internal.JsonEndpointResolverUtils; import software.amazon.awssdk.services.json.internal.JsonServiceClientConfigurationBuilder; import software.amazon.awssdk.services.json.internal.ServiceVersionInfo; import software.amazon.awssdk.services.json.model.APostOperationRequest; @@ -129,6 +144,7 @@ import software.amazon.awssdk.utils.CompletableFutureUtils; import software.amazon.awssdk.utils.HostnameValidator; import software.amazon.awssdk.utils.Pair; +import software.amazon.awssdk.utils.Validate; /** * Internal implementation of {@link JsonAsyncClient}. @@ -227,10 +243,16 @@ public CompletableFuture aPostOperation(APostOperationRe CompletableFuture executeFuture = clientHandler .execute(new ClientExecutionParams() - .withOperationName("APostOperation").withProtocolMetadata(protocolMetadata) + .withOperationName("APostOperation") + .withProtocolMetadata(protocolMetadata) .withMarshaller(new APostOperationRequestMarshaller(protocolFactory)) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withMetricCollector(apiCallMetricCollector) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "APostOperation", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "APostOperation")) .hostPrefixExpression(resolvedHostExpression).withInput(aPostOperationRequest)); CompletableFuture whenCompleted = executeFuture.whenComplete((r, e) -> { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -301,10 +323,16 @@ public CompletableFuture aPostOperationWithOut CompletableFuture executeFuture = clientHandler .execute(new ClientExecutionParams() - .withOperationName("APostOperationWithOutput").withProtocolMetadata(protocolMetadata) + .withOperationName("APostOperationWithOutput") + .withProtocolMetadata(protocolMetadata) .withMarshaller(new APostOperationWithOutputRequestMarshaller(protocolFactory)) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withMetricCollector(apiCallMetricCollector) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "APostOperationWithOutput", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "APostOperationWithOutput")) .withInput(aPostOperationWithOutputRequest)); CompletableFuture whenCompleted = executeFuture.whenComplete((r, e) -> { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -371,10 +399,16 @@ public CompletableFuture bearerAuthOperation( CompletableFuture executeFuture = clientHandler .execute(new ClientExecutionParams() - .withOperationName("BearerAuthOperation").withProtocolMetadata(protocolMetadata) + .withOperationName("BearerAuthOperation") + .withProtocolMetadata(protocolMetadata) .withMarshaller(new BearerAuthOperationRequestMarshaller(protocolFactory)) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withMetricCollector(apiCallMetricCollector) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "BearerAuthOperation", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "BearerAuthOperation")) .credentialType(CredentialType.TOKEN).withInput(bearerAuthOperationRequest)); CompletableFuture whenCompleted = executeFuture.whenComplete((r, e) -> { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -484,11 +518,18 @@ public CompletableFuture eventStreamOperation(EventStreamOperationRequest CompletableFuture executeFuture = clientHandler.execute( new ClientExecutionParams() - .withOperationName("EventStreamOperation").withProtocolMetadata(protocolMetadata) + .withOperationName("EventStreamOperation") + .withProtocolMetadata(protocolMetadata) .withMarshaller(new EventStreamOperationRequestMarshaller(protocolFactory)) - .withAsyncRequestBody(AsyncRequestBody.fromPublisher(adapted)).withFullDuplex(true) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withMetricCollector(apiCallMetricCollector) + .withAsyncRequestBody(AsyncRequestBody.fromPublisher(adapted)) + .withFullDuplex(true) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "EventStreamOperation", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "EventStreamOperation")) .withInput(eventStreamOperationRequest), restAsyncResponseTransformer); CompletableFuture whenCompleted = executeFuture.whenComplete((r, e) -> { if (e != null) { @@ -572,11 +613,18 @@ public CompletableFuture eventStreamO CompletableFuture executeFuture = clientHandler .execute(new ClientExecutionParams() - .withOperationName("EventStreamOperationWithOnlyInput").withProtocolMetadata(protocolMetadata) + .withOperationName("EventStreamOperationWithOnlyInput") + .withProtocolMetadata(protocolMetadata) .withMarshaller(new EventStreamOperationWithOnlyInputRequestMarshaller(protocolFactory)) - .withAsyncRequestBody(AsyncRequestBody.fromPublisher(adapted)).withResponseHandler(responseHandler) - .withErrorResponseHandler(errorResponseHandler).withRequestConfiguration(clientConfiguration) - .withMetricCollector(apiCallMetricCollector).withInput(eventStreamOperationWithOnlyInputRequest)); + .withAsyncRequestBody(AsyncRequestBody.fromPublisher(adapted)) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "EventStreamOperationWithOnlyInput", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "EventStreamOperationWithOnlyInput")) + .withInput(eventStreamOperationWithOnlyInputRequest)); CompletableFuture whenCompleted = executeFuture.whenComplete((r, e) -> { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); }); @@ -686,8 +734,14 @@ public CompletableFuture eventStreamOperationWithOnlyOutput( .withOperationName("EventStreamOperationWithOnlyOutput") .withProtocolMetadata(protocolMetadata) .withMarshaller(new EventStreamOperationWithOnlyOutputRequestMarshaller(protocolFactory)) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withMetricCollector(apiCallMetricCollector) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "EventStreamOperationWithOnlyOutput", + clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "EventStreamOperationWithOnlyOutput")) .withInput(eventStreamOperationWithOnlyOutputRequest), restAsyncResponseTransformer); CompletableFuture whenCompleted = executeFuture.whenComplete((r, e) -> { if (e != null) { @@ -770,6 +824,9 @@ public CompletableFuture getOperationWithCheck .withErrorResponseHandler(errorResponseHandler) .withRequestConfiguration(clientConfiguration) .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "GetOperationWithChecksum", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "GetOperationWithChecksum")) .putExecutionAttribute( SdkInternalExecutionAttribute.HTTP_CHECKSUM, HttpChecksum.builder().requestChecksumRequired(true).isRequestStreaming(false) @@ -845,10 +902,16 @@ public CompletableFuture getWithoutRequiredMe CompletableFuture executeFuture = clientHandler .execute(new ClientExecutionParams() - .withOperationName("GetWithoutRequiredMembers").withProtocolMetadata(protocolMetadata) + .withOperationName("GetWithoutRequiredMembers") + .withProtocolMetadata(protocolMetadata) .withMarshaller(new GetWithoutRequiredMembersRequestMarshaller(protocolFactory)) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withMetricCollector(apiCallMetricCollector) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "GetWithoutRequiredMembers", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "GetWithoutRequiredMembers")) .withInput(getWithoutRequiredMembersRequest)); CompletableFuture whenCompleted = executeFuture.whenComplete((r, e) -> { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -923,6 +986,9 @@ public CompletableFuture operationWithChe .withErrorResponseHandler(errorResponseHandler) .withRequestConfiguration(clientConfiguration) .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "OperationWithChecksumRequired", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "OperationWithChecksumRequired")) .putExecutionAttribute(SdkInternalExecutionAttribute.HTTP_CHECKSUM_REQUIRED, HttpChecksumRequired.create()).withInput(operationWithChecksumRequiredRequest)); CompletableFuture whenCompleted = executeFuture.whenComplete((r, e) -> { @@ -998,6 +1064,9 @@ public CompletableFuture operationWithR .withErrorResponseHandler(errorResponseHandler) .withRequestConfiguration(clientConfiguration) .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "OperationWithRequestCompression", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "OperationWithRequestCompression")) .putExecutionAttribute(SdkInternalExecutionAttribute.REQUEST_COMPRESSION, RequestCompression.builder().encodings("gzip").isStreaming(false).build()) .withInput(operationWithRequestCompressionRequest)); @@ -1067,10 +1136,16 @@ public CompletableFuture paginatedOpera CompletableFuture executeFuture = clientHandler .execute(new ClientExecutionParams() - .withOperationName("PaginatedOperationWithResultKey").withProtocolMetadata(protocolMetadata) + .withOperationName("PaginatedOperationWithResultKey") + .withProtocolMetadata(protocolMetadata) .withMarshaller(new PaginatedOperationWithResultKeyRequestMarshaller(protocolFactory)) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withMetricCollector(apiCallMetricCollector) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "PaginatedOperationWithResultKey", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "PaginatedOperationWithResultKey")) .withInput(paginatedOperationWithResultKeyRequest)); CompletableFuture whenCompleted = executeFuture.whenComplete((r, e) -> { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -1138,10 +1213,16 @@ public CompletableFuture paginatedOp CompletableFuture executeFuture = clientHandler .execute(new ClientExecutionParams() - .withOperationName("PaginatedOperationWithoutResultKey").withProtocolMetadata(protocolMetadata) + .withOperationName("PaginatedOperationWithoutResultKey") + .withProtocolMetadata(protocolMetadata) .withMarshaller(new PaginatedOperationWithoutResultKeyRequestMarshaller(protocolFactory)) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withMetricCollector(apiCallMetricCollector) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "PaginatedOperationWithoutResultKey", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "PaginatedOperationWithoutResultKey")) .withInput(paginatedOperationWithoutResultKeyRequest)); CompletableFuture whenCompleted = executeFuture.whenComplete((r, e) -> { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -1241,6 +1322,9 @@ public CompletableFuture putOperationWithChecksum( .withErrorResponseHandler(errorResponseHandler) .withRequestConfiguration(clientConfiguration) .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "PutOperationWithChecksum", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "PutOperationWithChecksum")) .withAsyncRequestBody(requestBody) .putExecutionAttribute( SdkInternalExecutionAttribute.HTTP_CHECKSUM, @@ -1344,10 +1428,15 @@ public CompletableFuture streamingInputOperatio .withMarshaller( AsyncStreamingRequestMarshaller.builder() .delegateMarshaller(new StreamingInputOperationRequestMarshaller(protocolFactory)) - .asyncRequestBody(requestBody).build()).withResponseHandler(responseHandler) - .withErrorResponseHandler(errorResponseHandler).withRequestConfiguration(clientConfiguration) - .withMetricCollector(apiCallMetricCollector).withAsyncRequestBody(requestBody) - .withInput(streamingInputOperationRequest)); + .asyncRequestBody(requestBody).build()) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "StreamingInputOperation", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "StreamingInputOperation")) + .withAsyncRequestBody(requestBody).withInput(streamingInputOperationRequest)); CompletableFuture whenCompleted = executeFuture.whenComplete((r, e) -> { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); }); @@ -1436,8 +1525,13 @@ public CompletableFuture streamingInputOutputOperation( .delegateMarshaller( new StreamingInputOutputOperationRequestMarshaller(protocolFactory)) .asyncRequestBody(requestBody).transferEncoding(true).build()) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withMetricCollector(apiCallMetricCollector) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "StreamingInputOutputOperation", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "StreamingInputOutputOperation")) .withAsyncRequestBody(requestBody).withAsyncResponseTransformer(asyncResponseTransformer) .withInput(streamingInputOutputOperationRequest), asyncResponseTransformer); AsyncResponseTransformer finalAsyncResponseTransformer = asyncResponseTransformer; @@ -1525,10 +1619,16 @@ public CompletableFuture streamingOutputOperation( CompletableFuture executeFuture = clientHandler.execute( new ClientExecutionParams() - .withOperationName("StreamingOutputOperation").withProtocolMetadata(protocolMetadata) + .withOperationName("StreamingOutputOperation") + .withProtocolMetadata(protocolMetadata) .withMarshaller(new StreamingOutputOperationRequestMarshaller(protocolFactory)) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withMetricCollector(apiCallMetricCollector) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "StreamingOutputOperation", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "StreamingOutputOperation")) .withAsyncResponseTransformer(asyncResponseTransformer).withInput(streamingOutputOperationRequest), asyncResponseTransformer); AsyncResponseTransformer finalAsyncResponseTransformer = asyncResponseTransformer; @@ -1587,6 +1687,54 @@ private static List resolveMetricPublishers(SdkClientConfigurat return publishers; } + private List resolveAuthSchemeOptions(SdkRequest request, String operationName, + SdkClientConfiguration clientConfiguration) { + JsonAuthSchemeProvider requestAuthSchemeProvider = request + .overrideConfiguration() + .flatMap(c -> c.authSchemeProvider()) + .map(p -> Validate + .isInstanceOf(JsonAuthSchemeProvider.class, p, "Expected an instance of JsonAuthSchemeProvider")) + .orElse(null); + JsonAuthSchemeProvider authSchemeProvider = requestAuthSchemeProvider != null ? requestAuthSchemeProvider : Validate + .isInstanceOf(JsonAuthSchemeProvider.class, clientConfiguration.option(SdkClientOption.AUTH_SCHEME_PROVIDER), + "Expected an instance of JsonAuthSchemeProvider"); + JsonAuthSchemeParams.Builder paramsBuilder = JsonAuthSchemeParams.builder().operation(operationName); + paramsBuilder.region(clientConfiguration.option(AwsClientOption.AWS_REGION)); + List options = authSchemeProvider.resolveAuthScheme(paramsBuilder.build()); + return options; + } + + private Endpoint resolveEndpoint(SdkRequest request, ExecutionAttributes executionAttributes, String operationName) { + JsonEndpointProvider provider = (JsonEndpointProvider) executionAttributes + .getAttribute(SdkInternalExecutionAttribute.ENDPOINT_PROVIDER); + try { + JsonEndpointParams endpointParams = JsonEndpointResolverUtils.ruleParams(request, executionAttributes); + Endpoint endpoint = provider.resolveEndpoint(endpointParams).join(); + if (!AwsEndpointProviderUtils.disableHostPrefixInjection(executionAttributes)) { + Optional hostPrefix = JsonEndpointResolverUtils.hostPrefix(operationName, request); + if (hostPrefix.isPresent()) { + endpoint = AwsEndpointProviderUtils.addHostPrefix(endpoint, hostPrefix.get()); + } + } + List endpointAuthSchemes = endpoint.attribute(AwsEndpointAttribute.AUTH_SCHEMES); + SelectedAuthScheme selectedAuthScheme = executionAttributes + .getAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME); + if (endpointAuthSchemes != null && selectedAuthScheme != null) { + selectedAuthScheme = JsonEndpointResolverUtils.authSchemeWithEndpointSignerProperties(endpointAuthSchemes, + selectedAuthScheme); + executionAttributes.putAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME, selectedAuthScheme); + } + JsonEndpointResolverUtils.setMetricValues(endpoint, executionAttributes); + return endpoint; + } catch (CompletionException e) { + Throwable cause = e.getCause(); + if (cause instanceof SdkClientException) { + throw (SdkClientException) cause; + } + throw SdkClientException.create("Endpoint resolution failed: " + cause.getMessage(), cause); + } + } + private void updateRetryStrategyClientConfiguration(SdkClientConfiguration.Builder configuration) { ClientOverrideConfiguration.Builder builder = configuration.asOverrideConfigurationBuilder(); RetryMode retryMode = builder.retryMode(); diff --git a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-json-client-class.java b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-json-client-class.java index 67092aec20a0..457086251a19 100644 --- a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-json-client-class.java +++ b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-json-client-class.java @@ -3,11 +3,16 @@ import java.util.Collections; import java.util.List; import java.util.Optional; +import java.util.concurrent.CompletionException; import java.util.function.Consumer; import java.util.function.Function; import software.amazon.awssdk.annotations.Generated; import software.amazon.awssdk.annotations.SdkInternalApi; +import software.amazon.awssdk.awscore.client.config.AwsClientOption; import software.amazon.awssdk.awscore.client.handler.AwsSyncClientHandler; +import software.amazon.awssdk.awscore.endpoints.AwsEndpointAttribute; +import software.amazon.awssdk.awscore.endpoints.AwsEndpointProviderUtils; +import software.amazon.awssdk.awscore.endpoints.authscheme.EndpointAuthScheme; import software.amazon.awssdk.awscore.exception.AwsServiceException; import software.amazon.awssdk.awscore.internal.AwsProtocolMetadata; import software.amazon.awssdk.awscore.internal.AwsServiceProtocol; @@ -17,6 +22,7 @@ import software.amazon.awssdk.core.RequestOverrideConfiguration; import software.amazon.awssdk.core.SdkPlugin; import software.amazon.awssdk.core.SdkRequest; +import software.amazon.awssdk.core.SelectedAuthScheme; import software.amazon.awssdk.core.client.config.ClientOverrideConfiguration; import software.amazon.awssdk.core.client.config.SdkClientConfiguration; import software.amazon.awssdk.core.client.config.SdkClientOption; @@ -24,6 +30,7 @@ import software.amazon.awssdk.core.client.handler.SyncClientHandler; import software.amazon.awssdk.core.exception.SdkClientException; import software.amazon.awssdk.core.http.HttpResponseHandler; +import software.amazon.awssdk.core.interceptor.ExecutionAttributes; import software.amazon.awssdk.core.interceptor.SdkInternalExecutionAttribute; import software.amazon.awssdk.core.interceptor.trait.HttpChecksum; import software.amazon.awssdk.core.interceptor.trait.HttpChecksumRequired; @@ -33,6 +40,8 @@ import software.amazon.awssdk.core.runtime.transform.StreamingRequestMarshaller; import software.amazon.awssdk.core.sync.RequestBody; import software.amazon.awssdk.core.sync.ResponseTransformer; +import software.amazon.awssdk.endpoints.Endpoint; +import software.amazon.awssdk.http.auth.spi.scheme.AuthSchemeOption; import software.amazon.awssdk.metrics.MetricCollector; import software.amazon.awssdk.metrics.MetricPublisher; import software.amazon.awssdk.metrics.NoOpMetricCollector; @@ -42,6 +51,11 @@ import software.amazon.awssdk.protocols.json.BaseAwsJsonProtocolFactory; import software.amazon.awssdk.protocols.json.JsonOperationMetadata; import software.amazon.awssdk.retries.api.RetryStrategy; +import software.amazon.awssdk.services.json.auth.scheme.JsonAuthSchemeParams; +import software.amazon.awssdk.services.json.auth.scheme.JsonAuthSchemeProvider; +import software.amazon.awssdk.services.json.endpoints.JsonEndpointParams; +import software.amazon.awssdk.services.json.endpoints.JsonEndpointProvider; +import software.amazon.awssdk.services.json.endpoints.internal.JsonEndpointResolverUtils; import software.amazon.awssdk.services.json.internal.JsonServiceClientConfigurationBuilder; import software.amazon.awssdk.services.json.internal.ServiceVersionInfo; import software.amazon.awssdk.services.json.model.APostOperationRequest; @@ -87,6 +101,7 @@ import software.amazon.awssdk.services.json.transform.StreamingOutputOperationRequestMarshaller; import software.amazon.awssdk.utils.HostnameValidator; import software.amazon.awssdk.utils.Logger; +import software.amazon.awssdk.utils.Validate; /** * Internal implementation of {@link JsonClient}. @@ -174,6 +189,8 @@ public APostOperationResponse aPostOperation(APostOperationRequest aPostOperatio .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) .hostPrefixExpression(resolvedHostExpression).withRequestConfiguration(clientConfiguration) .withInput(aPostOperationRequest).withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver(r -> resolveAuthSchemeOptions(r, "APostOperation", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "APostOperation")) .withMarshaller(new APostOperationRequestMarshaller(protocolFactory))); } finally { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -235,10 +252,16 @@ public APostOperationWithOutputResponse aPostOperationWithOutput( return clientHandler .execute(new ClientExecutionParams() - .withOperationName("APostOperationWithOutput").withProtocolMetadata(protocolMetadata) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withInput(aPostOperationWithOutputRequest) + .withOperationName("APostOperationWithOutput") + .withProtocolMetadata(protocolMetadata) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withInput(aPostOperationWithOutputRequest) .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "APostOperationWithOutput", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "APostOperationWithOutput")) .withMarshaller(new APostOperationWithOutputRequestMarshaller(protocolFactory))); } finally { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -298,6 +321,8 @@ public BearerAuthOperationResponse bearerAuthOperation(BearerAuthOperationReques .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) .credentialType(CredentialType.TOKEN).withRequestConfiguration(clientConfiguration) .withInput(bearerAuthOperationRequest).withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver(r -> resolveAuthSchemeOptions(r, "BearerAuthOperation", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "BearerAuthOperation")) .withMarshaller(new BearerAuthOperationRequestMarshaller(protocolFactory))); } finally { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -362,6 +387,9 @@ public GetOperationWithChecksumResponse getOperationWithChecksum( .withRequestConfiguration(clientConfiguration) .withInput(getOperationWithChecksumRequest) .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "GetOperationWithChecksum", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "GetOperationWithChecksum")) .putExecutionAttribute( SdkInternalExecutionAttribute.HTTP_CHECKSUM, HttpChecksum.builder().requestChecksumRequired(true).isRequestStreaming(false) @@ -428,10 +456,16 @@ public GetWithoutRequiredMembersResponse getWithoutRequiredMembers( return clientHandler .execute(new ClientExecutionParams() - .withOperationName("GetWithoutRequiredMembers").withProtocolMetadata(protocolMetadata) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withInput(getWithoutRequiredMembersRequest) + .withOperationName("GetWithoutRequiredMembers") + .withProtocolMetadata(protocolMetadata) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withInput(getWithoutRequiredMembersRequest) .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "GetWithoutRequiredMembers", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "GetWithoutRequiredMembers")) .withMarshaller(new GetWithoutRequiredMembersRequestMarshaller(protocolFactory))); } finally { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -496,6 +530,9 @@ public OperationWithChecksumRequiredResponse operationWithChecksumRequired( .withRequestConfiguration(clientConfiguration) .withInput(operationWithChecksumRequiredRequest) .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "OperationWithChecksumRequired", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "OperationWithChecksumRequired")) .putExecutionAttribute(SdkInternalExecutionAttribute.HTTP_CHECKSUM_REQUIRED, HttpChecksumRequired.create()) .withMarshaller(new OperationWithChecksumRequiredRequestMarshaller(protocolFactory))); @@ -562,6 +599,9 @@ public OperationWithRequestCompressionResponse operationWithRequestCompression( .withRequestConfiguration(clientConfiguration) .withInput(operationWithRequestCompressionRequest) .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "OperationWithRequestCompression", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "OperationWithRequestCompression")) .putExecutionAttribute(SdkInternalExecutionAttribute.REQUEST_COMPRESSION, RequestCompression.builder().encodings("gzip").isStreaming(false).build()) .withMarshaller(new OperationWithRequestCompressionRequestMarshaller(protocolFactory))); @@ -621,10 +661,16 @@ public PaginatedOperationWithResultKeyResponse paginatedOperationWithResultKey( return clientHandler .execute(new ClientExecutionParams() - .withOperationName("PaginatedOperationWithResultKey").withProtocolMetadata(protocolMetadata) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withInput(paginatedOperationWithResultKeyRequest) + .withOperationName("PaginatedOperationWithResultKey") + .withProtocolMetadata(protocolMetadata) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withInput(paginatedOperationWithResultKeyRequest) .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "PaginatedOperationWithResultKey", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "PaginatedOperationWithResultKey")) .withMarshaller(new PaginatedOperationWithResultKeyRequestMarshaller(protocolFactory))); } finally { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -682,10 +728,16 @@ public PaginatedOperationWithoutResultKeyResponse paginatedOperationWithoutResul return clientHandler .execute(new ClientExecutionParams() - .withOperationName("PaginatedOperationWithoutResultKey").withProtocolMetadata(protocolMetadata) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withInput(paginatedOperationWithoutResultKeyRequest) + .withOperationName("PaginatedOperationWithoutResultKey") + .withProtocolMetadata(protocolMetadata) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withInput(paginatedOperationWithoutResultKeyRequest) .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "PaginatedOperationWithoutResultKey", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "PaginatedOperationWithoutResultKey")) .withMarshaller(new PaginatedOperationWithoutResultKeyRequestMarshaller(protocolFactory))); } finally { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -776,6 +828,9 @@ public ReturnT putOperationWithChecksum(PutOperationWithChecksumReques .withRequestConfiguration(clientConfiguration) .withInput(putOperationWithChecksumRequest) .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "PutOperationWithChecksum", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "PutOperationWithChecksum")) .putExecutionAttribute( SdkInternalExecutionAttribute.HTTP_CHECKSUM, HttpChecksum @@ -870,6 +925,9 @@ public StreamingInputOperationResponse streamingInputOperation(StreamingInputOpe .withRequestConfiguration(clientConfiguration) .withInput(streamingInputOperationRequest) .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "StreamingInputOperation", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "StreamingInputOperation")) .withRequestBody(requestBody) .withMarshaller( StreamingRequestMarshaller.builder() @@ -957,6 +1015,9 @@ public ReturnT streamingInputOutputOperation( .withRequestConfiguration(clientConfiguration) .withInput(streamingInputOutputOperationRequest) .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "StreamingInputOutputOperation", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "StreamingInputOutputOperation")) .withResponseTransformer(responseTransformer) .withRequestBody(requestBody) .withMarshaller( @@ -1028,10 +1089,17 @@ public ReturnT streamingOutputOperation(StreamingOutputOperationReques return clientHandler.execute( new ClientExecutionParams() - .withOperationName("StreamingOutputOperation").withProtocolMetadata(protocolMetadata) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withInput(streamingOutputOperationRequest) - .withMetricCollector(apiCallMetricCollector).withResponseTransformer(responseTransformer) + .withOperationName("StreamingOutputOperation") + .withProtocolMetadata(protocolMetadata) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withInput(streamingOutputOperationRequest) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "StreamingOutputOperation", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "StreamingOutputOperation")) + .withResponseTransformer(responseTransformer) .withMarshaller(new StreamingOutputOperationRequestMarshaller(protocolFactory)), responseTransformer); } finally { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -1066,6 +1134,54 @@ private static List resolveMetricPublishers(SdkClientConfigurat return publishers; } + private List resolveAuthSchemeOptions(SdkRequest request, String operationName, + SdkClientConfiguration clientConfiguration) { + JsonAuthSchemeProvider requestAuthSchemeProvider = request + .overrideConfiguration() + .flatMap(c -> c.authSchemeProvider()) + .map(p -> Validate + .isInstanceOf(JsonAuthSchemeProvider.class, p, "Expected an instance of JsonAuthSchemeProvider")) + .orElse(null); + JsonAuthSchemeProvider authSchemeProvider = requestAuthSchemeProvider != null ? requestAuthSchemeProvider : Validate + .isInstanceOf(JsonAuthSchemeProvider.class, clientConfiguration.option(SdkClientOption.AUTH_SCHEME_PROVIDER), + "Expected an instance of JsonAuthSchemeProvider"); + JsonAuthSchemeParams.Builder paramsBuilder = JsonAuthSchemeParams.builder().operation(operationName); + paramsBuilder.region(clientConfiguration.option(AwsClientOption.AWS_REGION)); + List options = authSchemeProvider.resolveAuthScheme(paramsBuilder.build()); + return options; + } + + private Endpoint resolveEndpoint(SdkRequest request, ExecutionAttributes executionAttributes, String operationName) { + JsonEndpointProvider provider = (JsonEndpointProvider) executionAttributes + .getAttribute(SdkInternalExecutionAttribute.ENDPOINT_PROVIDER); + try { + JsonEndpointParams endpointParams = JsonEndpointResolverUtils.ruleParams(request, executionAttributes); + Endpoint endpoint = provider.resolveEndpoint(endpointParams).join(); + if (!AwsEndpointProviderUtils.disableHostPrefixInjection(executionAttributes)) { + Optional hostPrefix = JsonEndpointResolverUtils.hostPrefix(operationName, request); + if (hostPrefix.isPresent()) { + endpoint = AwsEndpointProviderUtils.addHostPrefix(endpoint, hostPrefix.get()); + } + } + List endpointAuthSchemes = endpoint.attribute(AwsEndpointAttribute.AUTH_SCHEMES); + SelectedAuthScheme selectedAuthScheme = executionAttributes + .getAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME); + if (endpointAuthSchemes != null && selectedAuthScheme != null) { + selectedAuthScheme = JsonEndpointResolverUtils.authSchemeWithEndpointSignerProperties(endpointAuthSchemes, + selectedAuthScheme); + executionAttributes.putAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME, selectedAuthScheme); + } + JsonEndpointResolverUtils.setMetricValues(endpoint, executionAttributes); + return endpoint; + } catch (CompletionException e) { + Throwable cause = e.getCause(); + if (cause instanceof SdkClientException) { + throw (SdkClientException) cause; + } + throw SdkClientException.create("Endpoint resolution failed: " + cause.getMessage(), cause); + } + } + private HttpResponseHandler createErrorResponseHandler(BaseAwsJsonProtocolFactory protocolFactory, JsonOperationMetadata operationMetadata, Function> exceptionMetadataMapper) { return protocolFactory.createErrorResponseHandler(operationMetadata, exceptionMetadataMapper); diff --git a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-query-async-client-class.java b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-query-async-client-class.java index b43795a31a9b..5935fe03c464 100644 --- a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-query-async-client-class.java +++ b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-query-async-client-class.java @@ -4,14 +4,20 @@ import java.util.Collections; import java.util.List; +import java.util.Optional; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionException; import java.util.concurrent.ScheduledExecutorService; import java.util.function.Consumer; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import software.amazon.awssdk.annotations.Generated; import software.amazon.awssdk.annotations.SdkInternalApi; +import software.amazon.awssdk.awscore.client.config.AwsClientOption; import software.amazon.awssdk.awscore.client.handler.AwsAsyncClientHandler; +import software.amazon.awssdk.awscore.endpoints.AwsEndpointAttribute; +import software.amazon.awssdk.awscore.endpoints.AwsEndpointProviderUtils; +import software.amazon.awssdk.awscore.endpoints.authscheme.EndpointAuthScheme; import software.amazon.awssdk.awscore.exception.AwsServiceException; import software.amazon.awssdk.awscore.internal.AwsProtocolMetadata; import software.amazon.awssdk.awscore.internal.AwsServiceProtocol; @@ -22,6 +28,7 @@ import software.amazon.awssdk.core.RequestOverrideConfiguration; import software.amazon.awssdk.core.SdkPlugin; import software.amazon.awssdk.core.SdkRequest; +import software.amazon.awssdk.core.SelectedAuthScheme; import software.amazon.awssdk.core.async.AsyncRequestBody; import software.amazon.awssdk.core.async.AsyncResponseTransformer; import software.amazon.awssdk.core.async.AsyncResponseTransformerUtils; @@ -30,7 +37,9 @@ import software.amazon.awssdk.core.client.config.SdkClientOption; import software.amazon.awssdk.core.client.handler.AsyncClientHandler; import software.amazon.awssdk.core.client.handler.ClientExecutionParams; +import software.amazon.awssdk.core.exception.SdkClientException; import software.amazon.awssdk.core.http.HttpResponseHandler; +import software.amazon.awssdk.core.interceptor.ExecutionAttributes; import software.amazon.awssdk.core.interceptor.SdkInternalExecutionAttribute; import software.amazon.awssdk.core.interceptor.trait.HttpChecksum; import software.amazon.awssdk.core.interceptor.trait.HttpChecksumRequired; @@ -38,12 +47,19 @@ import software.amazon.awssdk.core.metrics.CoreMetric; import software.amazon.awssdk.core.retry.RetryMode; import software.amazon.awssdk.core.runtime.transform.AsyncStreamingRequestMarshaller; +import software.amazon.awssdk.endpoints.Endpoint; +import software.amazon.awssdk.http.auth.spi.scheme.AuthSchemeOption; import software.amazon.awssdk.metrics.MetricCollector; import software.amazon.awssdk.metrics.MetricPublisher; import software.amazon.awssdk.metrics.NoOpMetricCollector; import software.amazon.awssdk.protocols.core.ExceptionMetadata; import software.amazon.awssdk.protocols.query.AwsQueryProtocolFactory; import software.amazon.awssdk.retries.api.RetryStrategy; +import software.amazon.awssdk.services.query.auth.scheme.QueryAuthSchemeParams; +import software.amazon.awssdk.services.query.auth.scheme.QueryAuthSchemeProvider; +import software.amazon.awssdk.services.query.endpoints.QueryEndpointParams; +import software.amazon.awssdk.services.query.endpoints.QueryEndpointProvider; +import software.amazon.awssdk.services.query.endpoints.internal.QueryEndpointResolverUtils; import software.amazon.awssdk.services.query.internal.QueryServiceClientConfigurationBuilder; import software.amazon.awssdk.services.query.internal.ServiceVersionInfo; import software.amazon.awssdk.services.query.model.APostOperationRequest; @@ -99,6 +115,7 @@ import software.amazon.awssdk.services.query.waiters.QueryAsyncWaiter; import software.amazon.awssdk.utils.CompletableFutureUtils; import software.amazon.awssdk.utils.Pair; +import software.amazon.awssdk.utils.Validate; /** * Internal implementation of {@link QueryAsyncClient}. @@ -173,10 +190,16 @@ public CompletableFuture aPostOperation(APostOperationRe CompletableFuture executeFuture = clientHandler .execute(new ClientExecutionParams() - .withOperationName("APostOperation").withProtocolMetadata(protocolMetadata) + .withOperationName("APostOperation") + .withProtocolMetadata(protocolMetadata) .withMarshaller(new APostOperationRequestMarshaller(protocolFactory)) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withMetricCollector(apiCallMetricCollector) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "APostOperation", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "APostOperation")) .hostPrefixExpression(resolvedHostExpression).withInput(aPostOperationRequest)); CompletableFuture whenCompleteFuture = null; whenCompleteFuture = executeFuture.whenComplete((r, e) -> { @@ -233,10 +256,16 @@ public CompletableFuture aPostOperationWithOut CompletableFuture executeFuture = clientHandler .execute(new ClientExecutionParams() - .withOperationName("APostOperationWithOutput").withProtocolMetadata(protocolMetadata) + .withOperationName("APostOperationWithOutput") + .withProtocolMetadata(protocolMetadata) .withMarshaller(new APostOperationWithOutputRequestMarshaller(protocolFactory)) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withMetricCollector(apiCallMetricCollector) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "APostOperationWithOutput", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "APostOperationWithOutput")) .withInput(aPostOperationWithOutputRequest)); CompletableFuture whenCompleteFuture = null; whenCompleteFuture = executeFuture.whenComplete((r, e) -> { @@ -289,11 +318,18 @@ public CompletableFuture bearerAuthOperation( CompletableFuture executeFuture = clientHandler .execute(new ClientExecutionParams() - .withOperationName("BearerAuthOperation").withProtocolMetadata(protocolMetadata) + .withOperationName("BearerAuthOperation") + .withProtocolMetadata(protocolMetadata) .withMarshaller(new BearerAuthOperationRequestMarshaller(protocolFactory)) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .credentialType(CredentialType.TOKEN).withRequestConfiguration(clientConfiguration) - .withMetricCollector(apiCallMetricCollector).withInput(bearerAuthOperationRequest)); + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .credentialType(CredentialType.TOKEN) + .withRequestConfiguration(clientConfiguration) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "BearerAuthOperation", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "BearerAuthOperation")) + .withInput(bearerAuthOperationRequest)); CompletableFuture whenCompleteFuture = null; whenCompleteFuture = executeFuture.whenComplete((r, e) -> { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -352,6 +388,9 @@ public CompletableFuture getOperationWithCheck .withErrorResponseHandler(errorResponseHandler) .withRequestConfiguration(clientConfiguration) .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "GetOperationWithChecksum", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "GetOperationWithChecksum")) .putExecutionAttribute( SdkInternalExecutionAttribute.HTTP_CHECKSUM, HttpChecksum.builder().requestChecksumRequired(true).isRequestStreaming(false) @@ -417,6 +456,9 @@ public CompletableFuture operationWithChe .withErrorResponseHandler(errorResponseHandler) .withRequestConfiguration(clientConfiguration) .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "OperationWithChecksumRequired", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "OperationWithChecksumRequired")) .putExecutionAttribute(SdkInternalExecutionAttribute.HTTP_CHECKSUM_REQUIRED, HttpChecksumRequired.create()).withInput(operationWithChecksumRequiredRequest)); CompletableFuture whenCompleteFuture = null; @@ -470,10 +512,16 @@ public CompletableFuture operationWithContext CompletableFuture executeFuture = clientHandler .execute(new ClientExecutionParams() - .withOperationName("OperationWithContextParam").withProtocolMetadata(protocolMetadata) + .withOperationName("OperationWithContextParam") + .withProtocolMetadata(protocolMetadata) .withMarshaller(new OperationWithContextParamRequestMarshaller(protocolFactory)) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withMetricCollector(apiCallMetricCollector) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "OperationWithContextParam", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "OperationWithContextParam")) .withInput(operationWithContextParamRequest)); CompletableFuture whenCompleteFuture = null; whenCompleteFuture = executeFuture.whenComplete((r, e) -> { @@ -527,10 +575,16 @@ public CompletableFuture operationWithCustomM CompletableFuture executeFuture = clientHandler .execute(new ClientExecutionParams() - .withOperationName("OperationWithCustomMember").withProtocolMetadata(protocolMetadata) + .withOperationName("OperationWithCustomMember") + .withProtocolMetadata(protocolMetadata) .withMarshaller(new OperationWithCustomMemberRequestMarshaller(protocolFactory)) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withMetricCollector(apiCallMetricCollector) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "OperationWithCustomMember", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "OperationWithCustomMember")) .withInput(operationWithCustomMemberRequest)); CompletableFuture whenCompleteFuture = null; whenCompleteFuture = executeFuture.whenComplete((r, e) -> { @@ -588,8 +642,14 @@ public CompletableFuture o .withOperationName("OperationWithCustomizedOperationContextParam") .withProtocolMetadata(protocolMetadata) .withMarshaller(new OperationWithCustomizedOperationContextParamRequestMarshaller(protocolFactory)) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withMetricCollector(apiCallMetricCollector) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "OperationWithCustomizedOperationContextParam", + clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "OperationWithCustomizedOperationContextParam")) .withInput(operationWithCustomizedOperationContextParamRequest)); CompletableFuture whenCompleteFuture = null; whenCompleteFuture = executeFuture.whenComplete((r, e) -> { @@ -644,10 +704,16 @@ public CompletableFuture operatio CompletableFuture executeFuture = clientHandler .execute(new ClientExecutionParams() - .withOperationName("OperationWithMapOperationContextParam").withProtocolMetadata(protocolMetadata) + .withOperationName("OperationWithMapOperationContextParam") + .withProtocolMetadata(protocolMetadata) .withMarshaller(new OperationWithMapOperationContextParamRequestMarshaller(protocolFactory)) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withMetricCollector(apiCallMetricCollector) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "OperationWithMapOperationContextParam", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "OperationWithMapOperationContextParam")) .withInput(operationWithMapOperationContextParamRequest)); CompletableFuture whenCompleteFuture = null; whenCompleteFuture = executeFuture.whenComplete((r, e) -> { @@ -700,10 +766,16 @@ public CompletableFuture operationWithNoneAut CompletableFuture executeFuture = clientHandler .execute(new ClientExecutionParams() - .withOperationName("OperationWithNoneAuthType").withProtocolMetadata(protocolMetadata) + .withOperationName("OperationWithNoneAuthType") + .withProtocolMetadata(protocolMetadata) .withMarshaller(new OperationWithNoneAuthTypeRequestMarshaller(protocolFactory)) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withMetricCollector(apiCallMetricCollector) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "OperationWithNoneAuthType", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "OperationWithNoneAuthType")) .withInput(operationWithNoneAuthTypeRequest)); CompletableFuture whenCompleteFuture = null; whenCompleteFuture = executeFuture.whenComplete((r, e) -> { @@ -758,10 +830,16 @@ public CompletableFuture operationWi CompletableFuture executeFuture = clientHandler .execute(new ClientExecutionParams() - .withOperationName("OperationWithOperationContextParam").withProtocolMetadata(protocolMetadata) + .withOperationName("OperationWithOperationContextParam") + .withProtocolMetadata(protocolMetadata) .withMarshaller(new OperationWithOperationContextParamRequestMarshaller(protocolFactory)) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withMetricCollector(apiCallMetricCollector) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "OperationWithOperationContextParam", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "OperationWithOperationContextParam")) .withInput(operationWithOperationContextParamRequest)); CompletableFuture whenCompleteFuture = null; whenCompleteFuture = executeFuture.whenComplete((r, e) -> { @@ -822,6 +900,9 @@ public CompletableFuture operationWithR .withErrorResponseHandler(errorResponseHandler) .withRequestConfiguration(clientConfiguration) .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "OperationWithRequestCompression", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "OperationWithRequestCompression")) .putExecutionAttribute(SdkInternalExecutionAttribute.REQUEST_COMPRESSION, RequestCompression.builder().encodings("gzip").isStreaming(false).build()) .withInput(operationWithRequestCompressionRequest)); @@ -877,10 +958,16 @@ public CompletableFuture operationWith CompletableFuture executeFuture = clientHandler .execute(new ClientExecutionParams() - .withOperationName("OperationWithStaticContextParams").withProtocolMetadata(protocolMetadata) + .withOperationName("OperationWithStaticContextParams") + .withProtocolMetadata(protocolMetadata) .withMarshaller(new OperationWithStaticContextParamsRequestMarshaller(protocolFactory)) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withMetricCollector(apiCallMetricCollector) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "OperationWithStaticContextParams", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "OperationWithStaticContextParams")) .withInput(operationWithStaticContextParamsRequest)); CompletableFuture whenCompleteFuture = null; whenCompleteFuture = executeFuture.whenComplete((r, e) -> { @@ -966,6 +1053,9 @@ public CompletableFuture putOperationWithChecksum( .withErrorResponseHandler(errorResponseHandler) .withRequestConfiguration(clientConfiguration) .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "PutOperationWithChecksum", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "PutOperationWithChecksum")) .putExecutionAttribute( SdkInternalExecutionAttribute.HTTP_CHECKSUM, HttpChecksum @@ -1052,10 +1142,15 @@ public CompletableFuture streamingInputOperatio .withMarshaller( AsyncStreamingRequestMarshaller.builder() .delegateMarshaller(new StreamingInputOperationRequestMarshaller(protocolFactory)) - .asyncRequestBody(requestBody).build()).withResponseHandler(responseHandler) - .withErrorResponseHandler(errorResponseHandler).withRequestConfiguration(clientConfiguration) - .withMetricCollector(apiCallMetricCollector).withAsyncRequestBody(requestBody) - .withInput(streamingInputOperationRequest)); + .asyncRequestBody(requestBody).build()) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "StreamingInputOperation", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "StreamingInputOperation")) + .withAsyncRequestBody(requestBody).withInput(streamingInputOperationRequest)); CompletableFuture whenCompleteFuture = null; whenCompleteFuture = executeFuture.whenComplete((r, e) -> { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -1117,10 +1212,16 @@ public CompletableFuture streamingOutputOperation( CompletableFuture executeFuture = clientHandler.execute( new ClientExecutionParams() - .withOperationName("StreamingOutputOperation").withProtocolMetadata(protocolMetadata) + .withOperationName("StreamingOutputOperation") + .withProtocolMetadata(protocolMetadata) .withMarshaller(new StreamingOutputOperationRequestMarshaller(protocolFactory)) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withMetricCollector(apiCallMetricCollector) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "StreamingOutputOperation", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "StreamingOutputOperation")) .withAsyncResponseTransformer(asyncResponseTransformer).withInput(streamingOutputOperationRequest), asyncResponseTransformer); CompletableFuture whenCompleteFuture = null; @@ -1183,6 +1284,53 @@ private static List resolveMetricPublishers(SdkClientConfigurat return publishers; } + private List resolveAuthSchemeOptions(SdkRequest request, String operationName, + SdkClientConfiguration clientConfiguration) { + QueryAuthSchemeProvider requestAuthSchemeProvider = request + .overrideConfiguration() + .flatMap(c -> c.authSchemeProvider()) + .map(p -> Validate.isInstanceOf(QueryAuthSchemeProvider.class, p, + "Expected an instance of QueryAuthSchemeProvider")).orElse(null); + QueryAuthSchemeProvider authSchemeProvider = requestAuthSchemeProvider != null ? requestAuthSchemeProvider : Validate + .isInstanceOf(QueryAuthSchemeProvider.class, clientConfiguration.option(SdkClientOption.AUTH_SCHEME_PROVIDER), + "Expected an instance of QueryAuthSchemeProvider"); + QueryAuthSchemeParams.Builder paramsBuilder = QueryAuthSchemeParams.builder().operation(operationName); + paramsBuilder.region(clientConfiguration.option(AwsClientOption.AWS_REGION)); + List options = authSchemeProvider.resolveAuthScheme(paramsBuilder.build()); + return options; + } + + private Endpoint resolveEndpoint(SdkRequest request, ExecutionAttributes executionAttributes, String operationName) { + QueryEndpointProvider provider = (QueryEndpointProvider) executionAttributes + .getAttribute(SdkInternalExecutionAttribute.ENDPOINT_PROVIDER); + try { + QueryEndpointParams endpointParams = QueryEndpointResolverUtils.ruleParams(request, executionAttributes); + Endpoint endpoint = provider.resolveEndpoint(endpointParams).join(); + if (!AwsEndpointProviderUtils.disableHostPrefixInjection(executionAttributes)) { + Optional hostPrefix = QueryEndpointResolverUtils.hostPrefix(operationName, request); + if (hostPrefix.isPresent()) { + endpoint = AwsEndpointProviderUtils.addHostPrefix(endpoint, hostPrefix.get()); + } + } + List endpointAuthSchemes = endpoint.attribute(AwsEndpointAttribute.AUTH_SCHEMES); + SelectedAuthScheme selectedAuthScheme = executionAttributes + .getAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME); + if (endpointAuthSchemes != null && selectedAuthScheme != null) { + selectedAuthScheme = QueryEndpointResolverUtils.authSchemeWithEndpointSignerProperties(endpointAuthSchemes, + selectedAuthScheme); + executionAttributes.putAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME, selectedAuthScheme); + } + QueryEndpointResolverUtils.setMetricValues(endpoint, executionAttributes); + return endpoint; + } catch (CompletionException e) { + Throwable cause = e.getCause(); + if (cause instanceof SdkClientException) { + throw (SdkClientException) cause; + } + throw SdkClientException.create("Endpoint resolution failed: " + cause.getMessage(), cause); + } + } + private void updateRetryStrategyClientConfiguration(SdkClientConfiguration.Builder configuration) { ClientOverrideConfiguration.Builder builder = configuration.asOverrideConfigurationBuilder(); RetryMode retryMode = builder.retryMode(); diff --git a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-query-client-class.java b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-query-client-class.java index 3c168823a547..98f0bf1b1f82 100644 --- a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-query-client-class.java +++ b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-query-client-class.java @@ -2,10 +2,16 @@ import java.util.Collections; import java.util.List; +import java.util.Optional; +import java.util.concurrent.CompletionException; import java.util.function.Consumer; import software.amazon.awssdk.annotations.Generated; import software.amazon.awssdk.annotations.SdkInternalApi; +import software.amazon.awssdk.awscore.client.config.AwsClientOption; import software.amazon.awssdk.awscore.client.handler.AwsSyncClientHandler; +import software.amazon.awssdk.awscore.endpoints.AwsEndpointAttribute; +import software.amazon.awssdk.awscore.endpoints.AwsEndpointProviderUtils; +import software.amazon.awssdk.awscore.endpoints.authscheme.EndpointAuthScheme; import software.amazon.awssdk.awscore.exception.AwsServiceException; import software.amazon.awssdk.awscore.internal.AwsProtocolMetadata; import software.amazon.awssdk.awscore.internal.AwsServiceProtocol; @@ -16,6 +22,7 @@ import software.amazon.awssdk.core.RequestOverrideConfiguration; import software.amazon.awssdk.core.SdkPlugin; import software.amazon.awssdk.core.SdkRequest; +import software.amazon.awssdk.core.SelectedAuthScheme; import software.amazon.awssdk.core.client.config.ClientOverrideConfiguration; import software.amazon.awssdk.core.client.config.SdkClientConfiguration; import software.amazon.awssdk.core.client.config.SdkClientOption; @@ -23,6 +30,7 @@ import software.amazon.awssdk.core.client.handler.SyncClientHandler; import software.amazon.awssdk.core.exception.SdkClientException; import software.amazon.awssdk.core.http.HttpResponseHandler; +import software.amazon.awssdk.core.interceptor.ExecutionAttributes; import software.amazon.awssdk.core.interceptor.SdkInternalExecutionAttribute; import software.amazon.awssdk.core.interceptor.trait.HttpChecksum; import software.amazon.awssdk.core.interceptor.trait.HttpChecksumRequired; @@ -32,12 +40,19 @@ import software.amazon.awssdk.core.runtime.transform.StreamingRequestMarshaller; import software.amazon.awssdk.core.sync.RequestBody; import software.amazon.awssdk.core.sync.ResponseTransformer; +import software.amazon.awssdk.endpoints.Endpoint; +import software.amazon.awssdk.http.auth.spi.scheme.AuthSchemeOption; import software.amazon.awssdk.metrics.MetricCollector; import software.amazon.awssdk.metrics.MetricPublisher; import software.amazon.awssdk.metrics.NoOpMetricCollector; import software.amazon.awssdk.protocols.core.ExceptionMetadata; import software.amazon.awssdk.protocols.query.AwsQueryProtocolFactory; import software.amazon.awssdk.retries.api.RetryStrategy; +import software.amazon.awssdk.services.query.auth.scheme.QueryAuthSchemeParams; +import software.amazon.awssdk.services.query.auth.scheme.QueryAuthSchemeProvider; +import software.amazon.awssdk.services.query.endpoints.QueryEndpointParams; +import software.amazon.awssdk.services.query.endpoints.QueryEndpointProvider; +import software.amazon.awssdk.services.query.endpoints.internal.QueryEndpointResolverUtils; import software.amazon.awssdk.services.query.internal.QueryServiceClientConfigurationBuilder; import software.amazon.awssdk.services.query.internal.ServiceVersionInfo; import software.amazon.awssdk.services.query.model.APostOperationRequest; @@ -92,6 +107,7 @@ import software.amazon.awssdk.services.query.transform.StreamingOutputOperationRequestMarshaller; import software.amazon.awssdk.services.query.waiters.QueryWaiter; import software.amazon.awssdk.utils.Logger; +import software.amazon.awssdk.utils.Validate; /** * Internal implementation of {@link QueryClient}. @@ -163,6 +179,8 @@ public APostOperationResponse aPostOperation(APostOperationRequest aPostOperatio .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) .hostPrefixExpression(resolvedHostExpression).withRequestConfiguration(clientConfiguration) .withInput(aPostOperationRequest).withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver(r -> resolveAuthSchemeOptions(r, "APostOperation", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "APostOperation")) .withMarshaller(new APostOperationRequestMarshaller(protocolFactory))); } finally { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -210,10 +228,16 @@ public APostOperationWithOutputResponse aPostOperationWithOutput( return clientHandler .execute(new ClientExecutionParams() - .withOperationName("APostOperationWithOutput").withProtocolMetadata(protocolMetadata) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withInput(aPostOperationWithOutputRequest) + .withOperationName("APostOperationWithOutput") + .withProtocolMetadata(protocolMetadata) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withInput(aPostOperationWithOutputRequest) .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "APostOperationWithOutput", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "APostOperationWithOutput")) .withMarshaller(new APostOperationWithOutputRequestMarshaller(protocolFactory))); } finally { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -259,6 +283,8 @@ public BearerAuthOperationResponse bearerAuthOperation(BearerAuthOperationReques .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) .credentialType(CredentialType.TOKEN).withRequestConfiguration(clientConfiguration) .withInput(bearerAuthOperationRequest).withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver(r -> resolveAuthSchemeOptions(r, "BearerAuthOperation", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "BearerAuthOperation")) .withMarshaller(new BearerAuthOperationRequestMarshaller(protocolFactory))); } finally { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -309,6 +335,9 @@ public GetOperationWithChecksumResponse getOperationWithChecksum( .withRequestConfiguration(clientConfiguration) .withInput(getOperationWithChecksumRequest) .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "GetOperationWithChecksum", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "GetOperationWithChecksum")) .putExecutionAttribute( SdkInternalExecutionAttribute.HTTP_CHECKSUM, HttpChecksum.builder().requestChecksumRequired(true).isRequestStreaming(false) @@ -364,6 +393,9 @@ public OperationWithChecksumRequiredResponse operationWithChecksumRequired( .withRequestConfiguration(clientConfiguration) .withInput(operationWithChecksumRequiredRequest) .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "OperationWithChecksumRequired", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "OperationWithChecksumRequired")) .putExecutionAttribute(SdkInternalExecutionAttribute.HTTP_CHECKSUM_REQUIRED, HttpChecksumRequired.create()) .withMarshaller(new OperationWithChecksumRequiredRequestMarshaller(protocolFactory))); @@ -409,10 +441,16 @@ public OperationWithContextParamResponse operationWithContextParam( return clientHandler .execute(new ClientExecutionParams() - .withOperationName("OperationWithContextParam").withProtocolMetadata(protocolMetadata) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withInput(operationWithContextParamRequest) + .withOperationName("OperationWithContextParam") + .withProtocolMetadata(protocolMetadata) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withInput(operationWithContextParamRequest) .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "OperationWithContextParam", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "OperationWithContextParam")) .withMarshaller(new OperationWithContextParamRequestMarshaller(protocolFactory))); } finally { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -457,10 +495,16 @@ public OperationWithCustomMemberResponse operationWithCustomMember( return clientHandler .execute(new ClientExecutionParams() - .withOperationName("OperationWithCustomMember").withProtocolMetadata(protocolMetadata) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withInput(operationWithCustomMemberRequest) + .withOperationName("OperationWithCustomMember") + .withProtocolMetadata(protocolMetadata) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withInput(operationWithCustomMemberRequest) .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "OperationWithCustomMember", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "OperationWithCustomMember")) .withMarshaller(new OperationWithCustomMemberRequestMarshaller(protocolFactory))); } finally { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -506,10 +550,16 @@ public OperationWithCustomizedOperationContextParamResponse operationWithCustomi return clientHandler .execute(new ClientExecutionParams() .withOperationName("OperationWithCustomizedOperationContextParam") - .withProtocolMetadata(protocolMetadata).withResponseHandler(responseHandler) - .withErrorResponseHandler(errorResponseHandler).withRequestConfiguration(clientConfiguration) + .withProtocolMetadata(protocolMetadata) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) .withInput(operationWithCustomizedOperationContextParamRequest) .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "OperationWithCustomizedOperationContextParam", + clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "OperationWithCustomizedOperationContextParam")) .withMarshaller(new OperationWithCustomizedOperationContextParamRequestMarshaller(protocolFactory))); } finally { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -554,10 +604,16 @@ public OperationWithMapOperationContextParamResponse operationWithMapOperationCo return clientHandler .execute(new ClientExecutionParams() - .withOperationName("OperationWithMapOperationContextParam").withProtocolMetadata(protocolMetadata) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) + .withOperationName("OperationWithMapOperationContextParam") + .withProtocolMetadata(protocolMetadata) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) .withRequestConfiguration(clientConfiguration) - .withInput(operationWithMapOperationContextParamRequest).withMetricCollector(apiCallMetricCollector) + .withInput(operationWithMapOperationContextParamRequest) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "OperationWithMapOperationContextParam", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "OperationWithMapOperationContextParam")) .withMarshaller(new OperationWithMapOperationContextParamRequestMarshaller(protocolFactory))); } finally { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -601,10 +657,16 @@ public OperationWithNoneAuthTypeResponse operationWithNoneAuthType( return clientHandler .execute(new ClientExecutionParams() - .withOperationName("OperationWithNoneAuthType").withProtocolMetadata(protocolMetadata) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withInput(operationWithNoneAuthTypeRequest) + .withOperationName("OperationWithNoneAuthType") + .withProtocolMetadata(protocolMetadata) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withInput(operationWithNoneAuthTypeRequest) .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "OperationWithNoneAuthType", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "OperationWithNoneAuthType")) .withMarshaller(new OperationWithNoneAuthTypeRequestMarshaller(protocolFactory))); } finally { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -649,10 +711,16 @@ public OperationWithOperationContextParamResponse operationWithOperationContextP return clientHandler .execute(new ClientExecutionParams() - .withOperationName("OperationWithOperationContextParam").withProtocolMetadata(protocolMetadata) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withInput(operationWithOperationContextParamRequest) + .withOperationName("OperationWithOperationContextParam") + .withProtocolMetadata(protocolMetadata) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withInput(operationWithOperationContextParamRequest) .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "OperationWithOperationContextParam", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "OperationWithOperationContextParam")) .withMarshaller(new OperationWithOperationContextParamRequestMarshaller(protocolFactory))); } finally { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -703,6 +771,9 @@ public OperationWithRequestCompressionResponse operationWithRequestCompression( .withRequestConfiguration(clientConfiguration) .withInput(operationWithRequestCompressionRequest) .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "OperationWithRequestCompression", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "OperationWithRequestCompression")) .putExecutionAttribute(SdkInternalExecutionAttribute.REQUEST_COMPRESSION, RequestCompression.builder().encodings("gzip").isStreaming(false).build()) .withMarshaller(new OperationWithRequestCompressionRequestMarshaller(protocolFactory))); @@ -748,10 +819,16 @@ public OperationWithStaticContextParamsResponse operationWithStaticContextParams return clientHandler .execute(new ClientExecutionParams() - .withOperationName("OperationWithStaticContextParams").withProtocolMetadata(protocolMetadata) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withInput(operationWithStaticContextParamsRequest) + .withOperationName("OperationWithStaticContextParams") + .withProtocolMetadata(protocolMetadata) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withInput(operationWithStaticContextParamsRequest) .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "OperationWithStaticContextParams", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "OperationWithStaticContextParams")) .withMarshaller(new OperationWithStaticContextParamsRequestMarshaller(protocolFactory))); } finally { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -828,6 +905,9 @@ public ReturnT putOperationWithChecksum(PutOperationWithChecksumReques .withRequestConfiguration(clientConfiguration) .withInput(putOperationWithChecksumRequest) .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "PutOperationWithChecksum", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "PutOperationWithChecksum")) .putExecutionAttribute( SdkInternalExecutionAttribute.HTTP_CHECKSUM, HttpChecksum @@ -906,6 +986,9 @@ public StreamingInputOperationResponse streamingInputOperation(StreamingInputOpe .withRequestConfiguration(clientConfiguration) .withInput(streamingInputOperationRequest) .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "StreamingInputOperation", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "StreamingInputOperation")) .withRequestBody(requestBody) .withMarshaller( StreamingRequestMarshaller.builder() @@ -960,10 +1043,17 @@ public ReturnT streamingOutputOperation(StreamingOutputOperationReques return clientHandler.execute( new ClientExecutionParams() - .withOperationName("StreamingOutputOperation").withProtocolMetadata(protocolMetadata) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withInput(streamingOutputOperationRequest) - .withMetricCollector(apiCallMetricCollector).withResponseTransformer(responseTransformer) + .withOperationName("StreamingOutputOperation") + .withProtocolMetadata(protocolMetadata) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withInput(streamingOutputOperationRequest) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "StreamingOutputOperation", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "StreamingOutputOperation")) + .withResponseTransformer(responseTransformer) .withMarshaller(new StreamingOutputOperationRequestMarshaller(protocolFactory)), responseTransformer); } finally { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -1003,6 +1093,53 @@ private static List resolveMetricPublishers(SdkClientConfigurat return publishers; } + private List resolveAuthSchemeOptions(SdkRequest request, String operationName, + SdkClientConfiguration clientConfiguration) { + QueryAuthSchemeProvider requestAuthSchemeProvider = request + .overrideConfiguration() + .flatMap(c -> c.authSchemeProvider()) + .map(p -> Validate.isInstanceOf(QueryAuthSchemeProvider.class, p, + "Expected an instance of QueryAuthSchemeProvider")).orElse(null); + QueryAuthSchemeProvider authSchemeProvider = requestAuthSchemeProvider != null ? requestAuthSchemeProvider : Validate + .isInstanceOf(QueryAuthSchemeProvider.class, clientConfiguration.option(SdkClientOption.AUTH_SCHEME_PROVIDER), + "Expected an instance of QueryAuthSchemeProvider"); + QueryAuthSchemeParams.Builder paramsBuilder = QueryAuthSchemeParams.builder().operation(operationName); + paramsBuilder.region(clientConfiguration.option(AwsClientOption.AWS_REGION)); + List options = authSchemeProvider.resolveAuthScheme(paramsBuilder.build()); + return options; + } + + private Endpoint resolveEndpoint(SdkRequest request, ExecutionAttributes executionAttributes, String operationName) { + QueryEndpointProvider provider = (QueryEndpointProvider) executionAttributes + .getAttribute(SdkInternalExecutionAttribute.ENDPOINT_PROVIDER); + try { + QueryEndpointParams endpointParams = QueryEndpointResolverUtils.ruleParams(request, executionAttributes); + Endpoint endpoint = provider.resolveEndpoint(endpointParams).join(); + if (!AwsEndpointProviderUtils.disableHostPrefixInjection(executionAttributes)) { + Optional hostPrefix = QueryEndpointResolverUtils.hostPrefix(operationName, request); + if (hostPrefix.isPresent()) { + endpoint = AwsEndpointProviderUtils.addHostPrefix(endpoint, hostPrefix.get()); + } + } + List endpointAuthSchemes = endpoint.attribute(AwsEndpointAttribute.AUTH_SCHEMES); + SelectedAuthScheme selectedAuthScheme = executionAttributes + .getAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME); + if (endpointAuthSchemes != null && selectedAuthScheme != null) { + selectedAuthScheme = QueryEndpointResolverUtils.authSchemeWithEndpointSignerProperties(endpointAuthSchemes, + selectedAuthScheme); + executionAttributes.putAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME, selectedAuthScheme); + } + QueryEndpointResolverUtils.setMetricValues(endpoint, executionAttributes); + return endpoint; + } catch (CompletionException e) { + Throwable cause = e.getCause(); + if (cause instanceof SdkClientException) { + throw (SdkClientException) cause; + } + throw SdkClientException.create("Endpoint resolution failed: " + cause.getMessage(), cause); + } + } + private void updateRetryStrategyClientConfiguration(SdkClientConfiguration.Builder configuration) { ClientOverrideConfiguration.Builder builder = configuration.asOverrideConfigurationBuilder(); RetryMode retryMode = builder.retryMode(); diff --git a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-rpcv2-async-client-class.java b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-rpcv2-async-client-class.java index bd3083b4602b..d941e4301efd 100644 --- a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-rpcv2-async-client-class.java +++ b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-rpcv2-async-client-class.java @@ -6,13 +6,18 @@ import java.util.List; import java.util.Optional; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionException; import java.util.function.Consumer; import java.util.function.Function; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import software.amazon.awssdk.annotations.Generated; import software.amazon.awssdk.annotations.SdkInternalApi; +import software.amazon.awssdk.awscore.client.config.AwsClientOption; import software.amazon.awssdk.awscore.client.handler.AwsAsyncClientHandler; +import software.amazon.awssdk.awscore.endpoints.AwsEndpointAttribute; +import software.amazon.awssdk.awscore.endpoints.AwsEndpointProviderUtils; +import software.amazon.awssdk.awscore.endpoints.authscheme.EndpointAuthScheme; import software.amazon.awssdk.awscore.exception.AwsServiceException; import software.amazon.awssdk.awscore.internal.AwsProtocolMetadata; import software.amazon.awssdk.awscore.internal.AwsServiceProtocol; @@ -20,14 +25,20 @@ import software.amazon.awssdk.core.RequestOverrideConfiguration; import software.amazon.awssdk.core.SdkPlugin; import software.amazon.awssdk.core.SdkRequest; +import software.amazon.awssdk.core.SelectedAuthScheme; import software.amazon.awssdk.core.client.config.ClientOverrideConfiguration; import software.amazon.awssdk.core.client.config.SdkClientConfiguration; import software.amazon.awssdk.core.client.config.SdkClientOption; import software.amazon.awssdk.core.client.handler.AsyncClientHandler; import software.amazon.awssdk.core.client.handler.ClientExecutionParams; +import software.amazon.awssdk.core.exception.SdkClientException; import software.amazon.awssdk.core.http.HttpResponseHandler; +import software.amazon.awssdk.core.interceptor.ExecutionAttributes; +import software.amazon.awssdk.core.interceptor.SdkInternalExecutionAttribute; import software.amazon.awssdk.core.metrics.CoreMetric; import software.amazon.awssdk.core.retry.RetryMode; +import software.amazon.awssdk.endpoints.Endpoint; +import software.amazon.awssdk.http.auth.spi.scheme.AuthSchemeOption; import software.amazon.awssdk.metrics.MetricCollector; import software.amazon.awssdk.metrics.MetricPublisher; import software.amazon.awssdk.metrics.NoOpMetricCollector; @@ -37,6 +48,11 @@ import software.amazon.awssdk.protocols.json.JsonOperationMetadata; import software.amazon.awssdk.protocols.rpcv2.SmithyRpcV2CborProtocolFactory; import software.amazon.awssdk.retries.api.RetryStrategy; +import software.amazon.awssdk.services.smithyrpcv2protocol.auth.scheme.SmithyRpcV2ProtocolAuthSchemeParams; +import software.amazon.awssdk.services.smithyrpcv2protocol.auth.scheme.SmithyRpcV2ProtocolAuthSchemeProvider; +import software.amazon.awssdk.services.smithyrpcv2protocol.endpoints.SmithyRpcV2ProtocolEndpointParams; +import software.amazon.awssdk.services.smithyrpcv2protocol.endpoints.SmithyRpcV2ProtocolEndpointProvider; +import software.amazon.awssdk.services.smithyrpcv2protocol.endpoints.internal.SmithyRpcV2ProtocolEndpointResolverUtils; import software.amazon.awssdk.services.smithyrpcv2protocol.internal.ServiceVersionInfo; import software.amazon.awssdk.services.smithyrpcv2protocol.internal.SmithyRpcV2ProtocolServiceClientConfigurationBuilder; import software.amazon.awssdk.services.smithyrpcv2protocol.model.ComplexErrorException; @@ -83,6 +99,7 @@ import software.amazon.awssdk.services.smithyrpcv2protocol.transform.SimpleScalarPropertiesRequestMarshaller; import software.amazon.awssdk.services.smithyrpcv2protocol.transform.SparseNullsOperationRequestMarshaller; import software.amazon.awssdk.utils.CompletableFutureUtils; +import software.amazon.awssdk.utils.Validate; /** * Internal implementation of {@link SmithyRpcV2ProtocolAsyncClient}. @@ -169,10 +186,16 @@ public CompletableFuture emptyInputOutput(EmptyInputOu CompletableFuture executeFuture = clientHandler .execute(new ClientExecutionParams() - .withOperationName("EmptyInputOutput").withProtocolMetadata(protocolMetadata) + .withOperationName("EmptyInputOutput") + .withProtocolMetadata(protocolMetadata) .withMarshaller(new EmptyInputOutputRequestMarshaller(protocolFactory)) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withMetricCollector(apiCallMetricCollector) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "EmptyInputOutput", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "EmptyInputOutput")) .withInput(emptyInputOutputRequest)); CompletableFuture whenCompleted = executeFuture.whenComplete((r, e) -> { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -246,7 +269,8 @@ public CompletableFuture float16(Float16Request float16Request) .withProtocolMetadata(protocolMetadata).withMarshaller(new Float16RequestMarshaller(protocolFactory)) .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) .withRequestConfiguration(clientConfiguration).withMetricCollector(apiCallMetricCollector) - .withInput(float16Request)); + .withAuthSchemeOptionsResolver(r -> resolveAuthSchemeOptions(r, "Float16", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "Float16")).withInput(float16Request)); CompletableFuture whenCompleted = executeFuture.whenComplete((r, e) -> { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); }); @@ -317,10 +341,16 @@ public CompletableFuture fractionalSeconds(Fractional CompletableFuture executeFuture = clientHandler .execute(new ClientExecutionParams() - .withOperationName("FractionalSeconds").withProtocolMetadata(protocolMetadata) + .withOperationName("FractionalSeconds") + .withProtocolMetadata(protocolMetadata) .withMarshaller(new FractionalSecondsRequestMarshaller(protocolFactory)) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withMetricCollector(apiCallMetricCollector) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "FractionalSeconds", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "FractionalSeconds")) .withInput(fractionalSecondsRequest)); CompletableFuture whenCompleted = executeFuture.whenComplete((r, e) -> { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -394,10 +424,16 @@ public CompletableFuture greetingWithErrors(Greeting CompletableFuture executeFuture = clientHandler .execute(new ClientExecutionParams() - .withOperationName("GreetingWithErrors").withProtocolMetadata(protocolMetadata) + .withOperationName("GreetingWithErrors") + .withProtocolMetadata(protocolMetadata) .withMarshaller(new GreetingWithErrorsRequestMarshaller(protocolFactory)) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withMetricCollector(apiCallMetricCollector) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "GreetingWithErrors", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "GreetingWithErrors")) .withInput(greetingWithErrorsRequest)); CompletableFuture whenCompleted = executeFuture.whenComplete((r, e) -> { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -468,10 +504,15 @@ public CompletableFuture noInputOutput(NoInputOutputReque CompletableFuture executeFuture = clientHandler .execute(new ClientExecutionParams() - .withOperationName("NoInputOutput").withProtocolMetadata(protocolMetadata) + .withOperationName("NoInputOutput") + .withProtocolMetadata(protocolMetadata) .withMarshaller(new NoInputOutputRequestMarshaller(protocolFactory)) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withMetricCollector(apiCallMetricCollector) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver(r -> resolveAuthSchemeOptions(r, "NoInputOutput", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "NoInputOutput")) .withInput(noInputOutputRequest)); CompletableFuture whenCompleted = executeFuture.whenComplete((r, e) -> { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -545,10 +586,16 @@ public CompletableFuture operationWithDefaults( CompletableFuture executeFuture = clientHandler .execute(new ClientExecutionParams() - .withOperationName("OperationWithDefaults").withProtocolMetadata(protocolMetadata) + .withOperationName("OperationWithDefaults") + .withProtocolMetadata(protocolMetadata) .withMarshaller(new OperationWithDefaultsRequestMarshaller(protocolFactory)) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withMetricCollector(apiCallMetricCollector) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "OperationWithDefaults", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "OperationWithDefaults")) .withInput(operationWithDefaultsRequest)); CompletableFuture whenCompleted = executeFuture.whenComplete((r, e) -> { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -621,10 +668,16 @@ public CompletableFuture optionalInputOutput( CompletableFuture executeFuture = clientHandler .execute(new ClientExecutionParams() - .withOperationName("OptionalInputOutput").withProtocolMetadata(protocolMetadata) + .withOperationName("OptionalInputOutput") + .withProtocolMetadata(protocolMetadata) .withMarshaller(new OptionalInputOutputRequestMarshaller(protocolFactory)) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withMetricCollector(apiCallMetricCollector) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "OptionalInputOutput", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "OptionalInputOutput")) .withInput(optionalInputOutputRequest)); CompletableFuture whenCompleted = executeFuture.whenComplete((r, e) -> { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -696,10 +749,16 @@ public CompletableFuture recursiveShapes(RecursiveShape CompletableFuture executeFuture = clientHandler .execute(new ClientExecutionParams() - .withOperationName("RecursiveShapes").withProtocolMetadata(protocolMetadata) + .withOperationName("RecursiveShapes") + .withProtocolMetadata(protocolMetadata) .withMarshaller(new RecursiveShapesRequestMarshaller(protocolFactory)) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withMetricCollector(apiCallMetricCollector) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "RecursiveShapes", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "RecursiveShapes")) .withInput(recursiveShapesRequest)); CompletableFuture whenCompleted = executeFuture.whenComplete((r, e) -> { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -772,10 +831,16 @@ public CompletableFuture rpcV2CborDenseMaps(RpcV2Cbo CompletableFuture executeFuture = clientHandler .execute(new ClientExecutionParams() - .withOperationName("RpcV2CborDenseMaps").withProtocolMetadata(protocolMetadata) + .withOperationName("RpcV2CborDenseMaps") + .withProtocolMetadata(protocolMetadata) .withMarshaller(new RpcV2CborDenseMapsRequestMarshaller(protocolFactory)) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withMetricCollector(apiCallMetricCollector) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "RpcV2CborDenseMaps", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "RpcV2CborDenseMaps")) .withInput(rpcV2CborDenseMapsRequest)); CompletableFuture whenCompleted = executeFuture.whenComplete((r, e) -> { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -847,10 +912,16 @@ public CompletableFuture rpcV2CborLists(RpcV2CborListsRe CompletableFuture executeFuture = clientHandler .execute(new ClientExecutionParams() - .withOperationName("RpcV2CborLists").withProtocolMetadata(protocolMetadata) + .withOperationName("RpcV2CborLists") + .withProtocolMetadata(protocolMetadata) .withMarshaller(new RpcV2CborListsRequestMarshaller(protocolFactory)) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withMetricCollector(apiCallMetricCollector) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "RpcV2CborLists", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "RpcV2CborLists")) .withInput(rpcV2CborListsRequest)); CompletableFuture whenCompleted = executeFuture.whenComplete((r, e) -> { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -924,10 +995,16 @@ public CompletableFuture rpcV2CborSparseMaps( CompletableFuture executeFuture = clientHandler .execute(new ClientExecutionParams() - .withOperationName("RpcV2CborSparseMaps").withProtocolMetadata(protocolMetadata) + .withOperationName("RpcV2CborSparseMaps") + .withProtocolMetadata(protocolMetadata) .withMarshaller(new RpcV2CborSparseMapsRequestMarshaller(protocolFactory)) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withMetricCollector(apiCallMetricCollector) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "RpcV2CborSparseMaps", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "RpcV2CborSparseMaps")) .withInput(rpcV2CborSparseMapsRequest)); CompletableFuture whenCompleted = executeFuture.whenComplete((r, e) -> { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -1000,10 +1077,16 @@ public CompletableFuture simpleScalarProperties( CompletableFuture executeFuture = clientHandler .execute(new ClientExecutionParams() - .withOperationName("SimpleScalarProperties").withProtocolMetadata(protocolMetadata) + .withOperationName("SimpleScalarProperties") + .withProtocolMetadata(protocolMetadata) .withMarshaller(new SimpleScalarPropertiesRequestMarshaller(protocolFactory)) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withMetricCollector(apiCallMetricCollector) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "SimpleScalarProperties", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "SimpleScalarProperties")) .withInput(simpleScalarPropertiesRequest)); CompletableFuture whenCompleted = executeFuture.whenComplete((r, e) -> { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -1076,10 +1159,16 @@ public CompletableFuture sparseNullsOperation( CompletableFuture executeFuture = clientHandler .execute(new ClientExecutionParams() - .withOperationName("SparseNullsOperation").withProtocolMetadata(protocolMetadata) + .withOperationName("SparseNullsOperation") + .withProtocolMetadata(protocolMetadata) .withMarshaller(new SparseNullsOperationRequestMarshaller(protocolFactory)) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withMetricCollector(apiCallMetricCollector) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "SparseNullsOperation", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "SparseNullsOperation")) .withInput(sparseNullsOperationRequest)); CompletableFuture whenCompleted = executeFuture.whenComplete((r, e) -> { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -1123,6 +1212,56 @@ private static List resolveMetricPublishers(SdkClientConfigurat return publishers; } + private List resolveAuthSchemeOptions(SdkRequest request, String operationName, + SdkClientConfiguration clientConfiguration) { + SmithyRpcV2ProtocolAuthSchemeProvider requestAuthSchemeProvider = request + .overrideConfiguration() + .flatMap(c -> c.authSchemeProvider()) + .map(p -> Validate.isInstanceOf(SmithyRpcV2ProtocolAuthSchemeProvider.class, p, + "Expected an instance of SmithyRpcV2ProtocolAuthSchemeProvider")).orElse(null); + SmithyRpcV2ProtocolAuthSchemeProvider authSchemeProvider = requestAuthSchemeProvider != null ? requestAuthSchemeProvider + : Validate.isInstanceOf(SmithyRpcV2ProtocolAuthSchemeProvider.class, + clientConfiguration.option(SdkClientOption.AUTH_SCHEME_PROVIDER), + "Expected an instance of SmithyRpcV2ProtocolAuthSchemeProvider"); + SmithyRpcV2ProtocolAuthSchemeParams.Builder paramsBuilder = SmithyRpcV2ProtocolAuthSchemeParams.builder().operation( + operationName); + paramsBuilder.region(clientConfiguration.option(AwsClientOption.AWS_REGION)); + List options = authSchemeProvider.resolveAuthScheme(paramsBuilder.build()); + return options; + } + + private Endpoint resolveEndpoint(SdkRequest request, ExecutionAttributes executionAttributes, String operationName) { + SmithyRpcV2ProtocolEndpointProvider provider = (SmithyRpcV2ProtocolEndpointProvider) executionAttributes + .getAttribute(SdkInternalExecutionAttribute.ENDPOINT_PROVIDER); + try { + SmithyRpcV2ProtocolEndpointParams endpointParams = SmithyRpcV2ProtocolEndpointResolverUtils.ruleParams(request, + executionAttributes); + Endpoint endpoint = provider.resolveEndpoint(endpointParams).join(); + if (!AwsEndpointProviderUtils.disableHostPrefixInjection(executionAttributes)) { + Optional hostPrefix = SmithyRpcV2ProtocolEndpointResolverUtils.hostPrefix(operationName, request); + if (hostPrefix.isPresent()) { + endpoint = AwsEndpointProviderUtils.addHostPrefix(endpoint, hostPrefix.get()); + } + } + List endpointAuthSchemes = endpoint.attribute(AwsEndpointAttribute.AUTH_SCHEMES); + SelectedAuthScheme selectedAuthScheme = executionAttributes + .getAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME); + if (endpointAuthSchemes != null && selectedAuthScheme != null) { + selectedAuthScheme = SmithyRpcV2ProtocolEndpointResolverUtils.authSchemeWithEndpointSignerProperties( + endpointAuthSchemes, selectedAuthScheme); + executionAttributes.putAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME, selectedAuthScheme); + } + SmithyRpcV2ProtocolEndpointResolverUtils.setMetricValues(endpoint, executionAttributes); + return endpoint; + } catch (CompletionException e) { + Throwable cause = e.getCause(); + if (cause instanceof SdkClientException) { + throw (SdkClientException) cause; + } + throw SdkClientException.create("Endpoint resolution failed: " + cause.getMessage(), cause); + } + } + private void updateRetryStrategyClientConfiguration(SdkClientConfiguration.Builder configuration) { ClientOverrideConfiguration.Builder builder = configuration.asOverrideConfigurationBuilder(); RetryMode retryMode = builder.retryMode(); @@ -1170,4 +1309,4 @@ private HttpResponseHandler createErrorResponseHandler(Base public void close() { clientHandler.close(); } -} \ No newline at end of file +} diff --git a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-rpcv2-sync.java b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-rpcv2-sync.java index fcda9aa09cb8..d2b53fe52c54 100644 --- a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-rpcv2-sync.java +++ b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-rpcv2-sync.java @@ -3,11 +3,16 @@ import java.util.Collections; import java.util.List; import java.util.Optional; +import java.util.concurrent.CompletionException; import java.util.function.Consumer; import java.util.function.Function; import software.amazon.awssdk.annotations.Generated; import software.amazon.awssdk.annotations.SdkInternalApi; +import software.amazon.awssdk.awscore.client.config.AwsClientOption; import software.amazon.awssdk.awscore.client.handler.AwsSyncClientHandler; +import software.amazon.awssdk.awscore.endpoints.AwsEndpointAttribute; +import software.amazon.awssdk.awscore.endpoints.AwsEndpointProviderUtils; +import software.amazon.awssdk.awscore.endpoints.authscheme.EndpointAuthScheme; import software.amazon.awssdk.awscore.exception.AwsServiceException; import software.amazon.awssdk.awscore.internal.AwsProtocolMetadata; import software.amazon.awssdk.awscore.internal.AwsServiceProtocol; @@ -15,6 +20,7 @@ import software.amazon.awssdk.core.RequestOverrideConfiguration; import software.amazon.awssdk.core.SdkPlugin; import software.amazon.awssdk.core.SdkRequest; +import software.amazon.awssdk.core.SelectedAuthScheme; import software.amazon.awssdk.core.client.config.ClientOverrideConfiguration; import software.amazon.awssdk.core.client.config.SdkClientConfiguration; import software.amazon.awssdk.core.client.config.SdkClientOption; @@ -22,8 +28,12 @@ import software.amazon.awssdk.core.client.handler.SyncClientHandler; import software.amazon.awssdk.core.exception.SdkClientException; import software.amazon.awssdk.core.http.HttpResponseHandler; +import software.amazon.awssdk.core.interceptor.ExecutionAttributes; +import software.amazon.awssdk.core.interceptor.SdkInternalExecutionAttribute; import software.amazon.awssdk.core.metrics.CoreMetric; import software.amazon.awssdk.core.retry.RetryMode; +import software.amazon.awssdk.endpoints.Endpoint; +import software.amazon.awssdk.http.auth.spi.scheme.AuthSchemeOption; import software.amazon.awssdk.metrics.MetricCollector; import software.amazon.awssdk.metrics.MetricPublisher; import software.amazon.awssdk.metrics.NoOpMetricCollector; @@ -33,6 +43,11 @@ import software.amazon.awssdk.protocols.json.JsonOperationMetadata; import software.amazon.awssdk.protocols.rpcv2.SmithyRpcV2CborProtocolFactory; import software.amazon.awssdk.retries.api.RetryStrategy; +import software.amazon.awssdk.services.smithyrpcv2protocol.auth.scheme.SmithyRpcV2ProtocolAuthSchemeParams; +import software.amazon.awssdk.services.smithyrpcv2protocol.auth.scheme.SmithyRpcV2ProtocolAuthSchemeProvider; +import software.amazon.awssdk.services.smithyrpcv2protocol.endpoints.SmithyRpcV2ProtocolEndpointParams; +import software.amazon.awssdk.services.smithyrpcv2protocol.endpoints.SmithyRpcV2ProtocolEndpointProvider; +import software.amazon.awssdk.services.smithyrpcv2protocol.endpoints.internal.SmithyRpcV2ProtocolEndpointResolverUtils; import software.amazon.awssdk.services.smithyrpcv2protocol.internal.ServiceVersionInfo; import software.amazon.awssdk.services.smithyrpcv2protocol.internal.SmithyRpcV2ProtocolServiceClientConfigurationBuilder; import software.amazon.awssdk.services.smithyrpcv2protocol.model.ComplexErrorException; @@ -79,6 +94,7 @@ import software.amazon.awssdk.services.smithyrpcv2protocol.transform.SimpleScalarPropertiesRequestMarshaller; import software.amazon.awssdk.services.smithyrpcv2protocol.transform.SparseNullsOperationRequestMarshaller; import software.amazon.awssdk.utils.Logger; +import software.amazon.awssdk.utils.Validate; /** * Internal implementation of {@link SmithyRpcV2ProtocolClient}. @@ -165,6 +181,8 @@ public EmptyInputOutputResponse emptyInputOutput(EmptyInputOutputRequest emptyIn .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) .withRequestConfiguration(clientConfiguration).withInput(emptyInputOutputRequest) .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver(r -> resolveAuthSchemeOptions(r, "EmptyInputOutput", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "EmptyInputOutput")) .withMarshaller(new EmptyInputOutputRequestMarshaller(protocolFactory))); } finally { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -228,6 +246,8 @@ public Float16Response float16(Float16Request float16Request) throws AwsServiceE .withOperationName("Float16").withProtocolMetadata(protocolMetadata).withResponseHandler(responseHandler) .withErrorResponseHandler(errorResponseHandler).withRequestConfiguration(clientConfiguration) .withInput(float16Request).withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver(r -> resolveAuthSchemeOptions(r, "Float16", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "Float16")) .withMarshaller(new Float16RequestMarshaller(protocolFactory))); } finally { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -293,6 +313,8 @@ public FractionalSecondsResponse fractionalSeconds(FractionalSecondsRequest frac .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) .withRequestConfiguration(clientConfiguration).withInput(fractionalSecondsRequest) .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver(r -> resolveAuthSchemeOptions(r, "FractionalSeconds", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "FractionalSeconds")) .withMarshaller(new FractionalSecondsRequestMarshaller(protocolFactory))); } finally { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -361,6 +383,8 @@ public GreetingWithErrorsResponse greetingWithErrors(GreetingWithErrorsRequest g .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) .withRequestConfiguration(clientConfiguration).withInput(greetingWithErrorsRequest) .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver(r -> resolveAuthSchemeOptions(r, "GreetingWithErrors", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "GreetingWithErrors")) .withMarshaller(new GreetingWithErrorsRequestMarshaller(protocolFactory))); } finally { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -425,6 +449,8 @@ public NoInputOutputResponse noInputOutput(NoInputOutputRequest noInputOutputReq .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) .withRequestConfiguration(clientConfiguration).withInput(noInputOutputRequest) .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver(r -> resolveAuthSchemeOptions(r, "NoInputOutput", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "NoInputOutput")) .withMarshaller(new NoInputOutputRequestMarshaller(protocolFactory))); } finally { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -486,12 +512,19 @@ public OperationWithDefaultsResponse operationWithDefaults(OperationWithDefaults apiCallMetricCollector.reportMetric(CoreMetric.SERVICE_ID, "SmithyRpcV2Protocol"); apiCallMetricCollector.reportMetric(CoreMetric.OPERATION_NAME, "OperationWithDefaults"); - return clientHandler.execute(new ClientExecutionParams() - .withOperationName("OperationWithDefaults").withProtocolMetadata(protocolMetadata) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withInput(operationWithDefaultsRequest) - .withMetricCollector(apiCallMetricCollector) - .withMarshaller(new OperationWithDefaultsRequestMarshaller(protocolFactory))); + return clientHandler + .execute(new ClientExecutionParams() + .withOperationName("OperationWithDefaults") + .withProtocolMetadata(protocolMetadata) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withInput(operationWithDefaultsRequest) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "OperationWithDefaults", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "OperationWithDefaults")) + .withMarshaller(new OperationWithDefaultsRequestMarshaller(protocolFactory))); } finally { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); } @@ -556,6 +589,8 @@ public OptionalInputOutputResponse optionalInputOutput(OptionalInputOutputReques .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) .withRequestConfiguration(clientConfiguration).withInput(optionalInputOutputRequest) .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver(r -> resolveAuthSchemeOptions(r, "OptionalInputOutput", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "OptionalInputOutput")) .withMarshaller(new OptionalInputOutputRequestMarshaller(protocolFactory))); } finally { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -621,6 +656,8 @@ public RecursiveShapesResponse recursiveShapes(RecursiveShapesRequest recursiveS .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) .withRequestConfiguration(clientConfiguration).withInput(recursiveShapesRequest) .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver(r -> resolveAuthSchemeOptions(r, "RecursiveShapes", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "RecursiveShapes")) .withMarshaller(new RecursiveShapesRequestMarshaller(protocolFactory))); } finally { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -687,6 +724,8 @@ public RpcV2CborDenseMapsResponse rpcV2CborDenseMaps(RpcV2CborDenseMapsRequest r .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) .withRequestConfiguration(clientConfiguration).withInput(rpcV2CborDenseMapsRequest) .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver(r -> resolveAuthSchemeOptions(r, "RpcV2CborDenseMaps", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "RpcV2CborDenseMaps")) .withMarshaller(new RpcV2CborDenseMapsRequestMarshaller(protocolFactory))); } finally { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -752,6 +791,8 @@ public RpcV2CborListsResponse rpcV2CborLists(RpcV2CborListsRequest rpcV2CborList .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) .withRequestConfiguration(clientConfiguration).withInput(rpcV2CborListsRequest) .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver(r -> resolveAuthSchemeOptions(r, "RpcV2CborLists", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "RpcV2CborLists")) .withMarshaller(new RpcV2CborListsRequestMarshaller(protocolFactory))); } finally { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -818,6 +859,8 @@ public RpcV2CborSparseMapsResponse rpcV2CborSparseMaps(RpcV2CborSparseMapsReques .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) .withRequestConfiguration(clientConfiguration).withInput(rpcV2CborSparseMapsRequest) .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver(r -> resolveAuthSchemeOptions(r, "RpcV2CborSparseMaps", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "RpcV2CborSparseMaps")) .withMarshaller(new RpcV2CborSparseMapsRequestMarshaller(protocolFactory))); } finally { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -880,10 +923,16 @@ public SimpleScalarPropertiesResponse simpleScalarProperties(SimpleScalarPropert return clientHandler .execute(new ClientExecutionParams() - .withOperationName("SimpleScalarProperties").withProtocolMetadata(protocolMetadata) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withInput(simpleScalarPropertiesRequest) + .withOperationName("SimpleScalarProperties") + .withProtocolMetadata(protocolMetadata) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withInput(simpleScalarPropertiesRequest) .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "SimpleScalarProperties", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "SimpleScalarProperties")) .withMarshaller(new SimpleScalarPropertiesRequestMarshaller(protocolFactory))); } finally { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -949,6 +998,8 @@ public SparseNullsOperationResponse sparseNullsOperation(SparseNullsOperationReq .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) .withRequestConfiguration(clientConfiguration).withInput(sparseNullsOperationRequest) .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver(r -> resolveAuthSchemeOptions(r, "SparseNullsOperation", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "SparseNullsOperation")) .withMarshaller(new SparseNullsOperationRequestMarshaller(protocolFactory))); } finally { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -975,6 +1026,56 @@ private static List resolveMetricPublishers(SdkClientConfigurat return publishers; } + private List resolveAuthSchemeOptions(SdkRequest request, String operationName, + SdkClientConfiguration clientConfiguration) { + SmithyRpcV2ProtocolAuthSchemeProvider requestAuthSchemeProvider = request + .overrideConfiguration() + .flatMap(c -> c.authSchemeProvider()) + .map(p -> Validate.isInstanceOf(SmithyRpcV2ProtocolAuthSchemeProvider.class, p, + "Expected an instance of SmithyRpcV2ProtocolAuthSchemeProvider")).orElse(null); + SmithyRpcV2ProtocolAuthSchemeProvider authSchemeProvider = requestAuthSchemeProvider != null ? requestAuthSchemeProvider + : Validate.isInstanceOf(SmithyRpcV2ProtocolAuthSchemeProvider.class, + clientConfiguration.option(SdkClientOption.AUTH_SCHEME_PROVIDER), + "Expected an instance of SmithyRpcV2ProtocolAuthSchemeProvider"); + SmithyRpcV2ProtocolAuthSchemeParams.Builder paramsBuilder = SmithyRpcV2ProtocolAuthSchemeParams.builder().operation( + operationName); + paramsBuilder.region(clientConfiguration.option(AwsClientOption.AWS_REGION)); + List options = authSchemeProvider.resolveAuthScheme(paramsBuilder.build()); + return options; + } + + private Endpoint resolveEndpoint(SdkRequest request, ExecutionAttributes executionAttributes, String operationName) { + SmithyRpcV2ProtocolEndpointProvider provider = (SmithyRpcV2ProtocolEndpointProvider) executionAttributes + .getAttribute(SdkInternalExecutionAttribute.ENDPOINT_PROVIDER); + try { + SmithyRpcV2ProtocolEndpointParams endpointParams = SmithyRpcV2ProtocolEndpointResolverUtils.ruleParams(request, + executionAttributes); + Endpoint endpoint = provider.resolveEndpoint(endpointParams).join(); + if (!AwsEndpointProviderUtils.disableHostPrefixInjection(executionAttributes)) { + Optional hostPrefix = SmithyRpcV2ProtocolEndpointResolverUtils.hostPrefix(operationName, request); + if (hostPrefix.isPresent()) { + endpoint = AwsEndpointProviderUtils.addHostPrefix(endpoint, hostPrefix.get()); + } + } + List endpointAuthSchemes = endpoint.attribute(AwsEndpointAttribute.AUTH_SCHEMES); + SelectedAuthScheme selectedAuthScheme = executionAttributes + .getAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME); + if (endpointAuthSchemes != null && selectedAuthScheme != null) { + selectedAuthScheme = SmithyRpcV2ProtocolEndpointResolverUtils.authSchemeWithEndpointSignerProperties( + endpointAuthSchemes, selectedAuthScheme); + executionAttributes.putAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME, selectedAuthScheme); + } + SmithyRpcV2ProtocolEndpointResolverUtils.setMetricValues(endpoint, executionAttributes); + return endpoint; + } catch (CompletionException e) { + Throwable cause = e.getCause(); + if (cause instanceof SdkClientException) { + throw (SdkClientException) cause; + } + throw SdkClientException.create("Endpoint resolution failed: " + cause.getMessage(), cause); + } + } + private HttpResponseHandler createErrorResponseHandler(BaseAwsJsonProtocolFactory protocolFactory, JsonOperationMetadata operationMetadata, Function> exceptionMetadataMapper) { return protocolFactory.createErrorResponseHandler(operationMetadata, exceptionMetadataMapper); diff --git a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-unsigned-payload-trait-async-client-class.java b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-unsigned-payload-trait-async-client-class.java index efa307505463..3e76e9b15269 100644 --- a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-unsigned-payload-trait-async-client-class.java +++ b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-unsigned-payload-trait-async-client-class.java @@ -5,14 +5,20 @@ import java.util.Collections; import java.util.List; import java.util.Optional; +import java.util.Set; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionException; import java.util.function.Consumer; import java.util.function.Function; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import software.amazon.awssdk.annotations.Generated; import software.amazon.awssdk.annotations.SdkInternalApi; +import software.amazon.awssdk.awscore.client.config.AwsClientOption; import software.amazon.awssdk.awscore.client.handler.AwsAsyncClientHandler; +import software.amazon.awssdk.awscore.endpoints.AwsEndpointAttribute; +import software.amazon.awssdk.awscore.endpoints.AwsEndpointProviderUtils; +import software.amazon.awssdk.awscore.endpoints.authscheme.EndpointAuthScheme; import software.amazon.awssdk.awscore.exception.AwsServiceException; import software.amazon.awssdk.awscore.internal.AwsProtocolMetadata; import software.amazon.awssdk.awscore.internal.AwsServiceProtocol; @@ -20,16 +26,25 @@ import software.amazon.awssdk.core.RequestOverrideConfiguration; import software.amazon.awssdk.core.SdkPlugin; import software.amazon.awssdk.core.SdkRequest; +import software.amazon.awssdk.core.SelectedAuthScheme; import software.amazon.awssdk.core.async.AsyncRequestBody; import software.amazon.awssdk.core.client.config.ClientOverrideConfiguration; import software.amazon.awssdk.core.client.config.SdkClientConfiguration; import software.amazon.awssdk.core.client.config.SdkClientOption; import software.amazon.awssdk.core.client.handler.AsyncClientHandler; import software.amazon.awssdk.core.client.handler.ClientExecutionParams; +import software.amazon.awssdk.core.exception.SdkClientException; import software.amazon.awssdk.core.http.HttpResponseHandler; +import software.amazon.awssdk.core.interceptor.ExecutionAttributes; +import software.amazon.awssdk.core.interceptor.SdkInternalExecutionAttribute; import software.amazon.awssdk.core.metrics.CoreMetric; import software.amazon.awssdk.core.retry.RetryMode; import software.amazon.awssdk.core.runtime.transform.AsyncStreamingRequestMarshaller; +import software.amazon.awssdk.endpoints.Endpoint; +import software.amazon.awssdk.http.auth.aws.scheme.AwsV4aAuthScheme; +import software.amazon.awssdk.http.auth.aws.signer.AwsV4aHttpSigner; +import software.amazon.awssdk.http.auth.aws.signer.RegionSet; +import software.amazon.awssdk.http.auth.spi.scheme.AuthSchemeOption; import software.amazon.awssdk.metrics.MetricCollector; import software.amazon.awssdk.metrics.MetricPublisher; import software.amazon.awssdk.metrics.NoOpMetricCollector; @@ -39,6 +54,11 @@ import software.amazon.awssdk.protocols.json.BaseAwsJsonProtocolFactory; import software.amazon.awssdk.protocols.json.JsonOperationMetadata; import software.amazon.awssdk.retries.api.RetryStrategy; +import software.amazon.awssdk.services.database.auth.scheme.DatabaseAuthSchemeParams; +import software.amazon.awssdk.services.database.auth.scheme.DatabaseAuthSchemeProvider; +import software.amazon.awssdk.services.database.endpoints.DatabaseEndpointParams; +import software.amazon.awssdk.services.database.endpoints.DatabaseEndpointProvider; +import software.amazon.awssdk.services.database.endpoints.internal.DatabaseEndpointResolverUtils; import software.amazon.awssdk.services.database.internal.DatabaseServiceClientConfigurationBuilder; import software.amazon.awssdk.services.database.internal.ServiceVersionInfo; import software.amazon.awssdk.services.database.model.DatabaseException; @@ -76,7 +96,9 @@ import software.amazon.awssdk.services.database.transform.OpsWithSigv4AndSigv4ASignedPayloadRequestMarshaller; import software.amazon.awssdk.services.database.transform.PutRowRequestMarshaller; import software.amazon.awssdk.services.database.transform.SecondOpsWithSigv4AndSigv4ASignedPayloadRequestMarshaller; +import software.amazon.awssdk.utils.CollectionUtils; import software.amazon.awssdk.utils.CompletableFutureUtils; +import software.amazon.awssdk.utils.Validate; /** * Internal implementation of {@link DatabaseAsyncClient}. @@ -163,7 +185,9 @@ public CompletableFuture deleteRow(DeleteRowRequest deleteRow .withProtocolMetadata(protocolMetadata) .withMarshaller(new DeleteRowRequestMarshaller(protocolFactory)).withResponseHandler(responseHandler) .withErrorResponseHandler(errorResponseHandler).withRequestConfiguration(clientConfiguration) - .withMetricCollector(apiCallMetricCollector).withInput(deleteRowRequest)); + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver(r -> resolveAuthSchemeOptions(r, "DeleteRow", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "DeleteRow")).withInput(deleteRowRequest)); CompletableFuture whenCompleted = executeFuture.whenComplete((r, e) -> { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); }); @@ -234,7 +258,8 @@ public CompletableFuture getRow(GetRowRequest getRowRequest) { .withProtocolMetadata(protocolMetadata).withMarshaller(new GetRowRequestMarshaller(protocolFactory)) .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) .withRequestConfiguration(clientConfiguration).withMetricCollector(apiCallMetricCollector) - .withInput(getRowRequest)); + .withAuthSchemeOptionsResolver(r -> resolveAuthSchemeOptions(r, "GetRow", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "GetRow")).withInput(getRowRequest)); CompletableFuture whenCompleted = executeFuture.whenComplete((r, e) -> { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); }); @@ -306,10 +331,16 @@ public CompletableFuture opWithSigv CompletableFuture executeFuture = clientHandler .execute(new ClientExecutionParams() - .withOperationName("opWithSigv4AndSigv4aUnSignedPayload").withProtocolMetadata(protocolMetadata) + .withOperationName("opWithSigv4AndSigv4aUnSignedPayload") + .withProtocolMetadata(protocolMetadata) .withMarshaller(new OpWithSigv4AndSigv4AUnSignedPayloadRequestMarshaller(protocolFactory)) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withMetricCollector(apiCallMetricCollector) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "opWithSigv4AndSigv4aUnSignedPayload", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "opWithSigv4AndSigv4aUnSignedPayload")) .withInput(opWithSigv4AndSigv4AUnSignedPayloadRequest)); CompletableFuture whenCompleted = executeFuture.whenComplete((r, e) -> { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -380,10 +411,16 @@ public CompletableFuture opWithSigv4SignedPayl CompletableFuture executeFuture = clientHandler .execute(new ClientExecutionParams() - .withOperationName("opWithSigv4SignedPayload").withProtocolMetadata(protocolMetadata) + .withOperationName("opWithSigv4SignedPayload") + .withProtocolMetadata(protocolMetadata) .withMarshaller(new OpWithSigv4SignedPayloadRequestMarshaller(protocolFactory)) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withMetricCollector(apiCallMetricCollector) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "opWithSigv4SignedPayload", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "opWithSigv4SignedPayload")) .withInput(opWithSigv4SignedPayloadRequest)); CompletableFuture whenCompleted = executeFuture.whenComplete((r, e) -> { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -454,10 +491,16 @@ public CompletableFuture opWithSigv4UnSigned CompletableFuture executeFuture = clientHandler .execute(new ClientExecutionParams() - .withOperationName("opWithSigv4UnSignedPayload").withProtocolMetadata(protocolMetadata) + .withOperationName("opWithSigv4UnSignedPayload") + .withProtocolMetadata(protocolMetadata) .withMarshaller(new OpWithSigv4UnSignedPayloadRequestMarshaller(protocolFactory)) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withMetricCollector(apiCallMetricCollector) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "opWithSigv4UnSignedPayload", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "opWithSigv4UnSignedPayload")) .withInput(opWithSigv4UnSignedPayloadRequest)); CompletableFuture whenCompleted = executeFuture.whenComplete((r, e) -> { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -544,8 +587,14 @@ public CompletableFuture opWithS .delegateMarshaller( new OpWithSigv4UnSignedPayloadAndStreamingRequestMarshaller(protocolFactory)) .asyncRequestBody(requestBody).transferEncoding(true).build()) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withMetricCollector(apiCallMetricCollector) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "opWithSigv4UnSignedPayloadAndStreaming", + clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "opWithSigv4UnSignedPayloadAndStreaming")) .withAsyncRequestBody(requestBody).withInput(opWithSigv4UnSignedPayloadAndStreamingRequest)); CompletableFuture whenCompleted = executeFuture .whenComplete((r, e) -> { @@ -617,10 +666,16 @@ public CompletableFuture opWithSigv4aSignedPa CompletableFuture executeFuture = clientHandler .execute(new ClientExecutionParams() - .withOperationName("opWithSigv4aSignedPayload").withProtocolMetadata(protocolMetadata) + .withOperationName("opWithSigv4aSignedPayload") + .withProtocolMetadata(protocolMetadata) .withMarshaller(new OpWithSigv4ASignedPayloadRequestMarshaller(protocolFactory)) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withMetricCollector(apiCallMetricCollector) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "opWithSigv4aSignedPayload", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "opWithSigv4aSignedPayload")) .withInput(opWithSigv4ASignedPayloadRequest)); CompletableFuture whenCompleted = executeFuture.whenComplete((r, e) -> { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -691,10 +746,16 @@ public CompletableFuture opWithSigv4aUnSign CompletableFuture executeFuture = clientHandler .execute(new ClientExecutionParams() - .withOperationName("opWithSigv4aUnSignedPayload").withProtocolMetadata(protocolMetadata) + .withOperationName("opWithSigv4aUnSignedPayload") + .withProtocolMetadata(protocolMetadata) .withMarshaller(new OpWithSigv4AUnSignedPayloadRequestMarshaller(protocolFactory)) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withMetricCollector(apiCallMetricCollector) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "opWithSigv4aUnSignedPayload", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "opWithSigv4aUnSignedPayload")) .withInput(opWithSigv4AUnSignedPayloadRequest)); CompletableFuture whenCompleted = executeFuture.whenComplete((r, e) -> { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -767,10 +828,16 @@ public CompletableFuture opsWithSigv CompletableFuture executeFuture = clientHandler .execute(new ClientExecutionParams() - .withOperationName("opsWithSigv4andSigv4aSignedPayload").withProtocolMetadata(protocolMetadata) + .withOperationName("opsWithSigv4andSigv4aSignedPayload") + .withProtocolMetadata(protocolMetadata) .withMarshaller(new OpsWithSigv4AndSigv4ASignedPayloadRequestMarshaller(protocolFactory)) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withMetricCollector(apiCallMetricCollector) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "opsWithSigv4andSigv4aSignedPayload", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "opsWithSigv4andSigv4aSignedPayload")) .withInput(opsWithSigv4AndSigv4ASignedPayloadRequest)); CompletableFuture whenCompleted = executeFuture.whenComplete((r, e) -> { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -842,7 +909,8 @@ public CompletableFuture putRow(PutRowRequest putRowRequest) { .withProtocolMetadata(protocolMetadata).withMarshaller(new PutRowRequestMarshaller(protocolFactory)) .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) .withRequestConfiguration(clientConfiguration).withMetricCollector(apiCallMetricCollector) - .withInput(putRowRequest)); + .withAuthSchemeOptionsResolver(r -> resolveAuthSchemeOptions(r, "PutRow", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "PutRow")).withInput(putRowRequest)); CompletableFuture whenCompleted = executeFuture.whenComplete((r, e) -> { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); }); @@ -914,10 +982,17 @@ public CompletableFuture secon CompletableFuture executeFuture = clientHandler .execute(new ClientExecutionParams() - .withOperationName("secondOpsWithSigv4andSigv4aSignedPayload").withProtocolMetadata(protocolMetadata) + .withOperationName("secondOpsWithSigv4andSigv4aSignedPayload") + .withProtocolMetadata(protocolMetadata) .withMarshaller(new SecondOpsWithSigv4AndSigv4ASignedPayloadRequestMarshaller(protocolFactory)) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withMetricCollector(apiCallMetricCollector) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "secondOpsWithSigv4andSigv4aSignedPayload", + clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "secondOpsWithSigv4andSigv4aSignedPayload")) .withInput(secondOpsWithSigv4AndSigv4ASignedPayloadRequest)); CompletableFuture whenCompleted = executeFuture .whenComplete((r, e) -> { @@ -961,6 +1036,66 @@ private static List resolveMetricPublishers(SdkClientConfigurat return publishers; } + private List resolveAuthSchemeOptions(SdkRequest request, String operationName, + SdkClientConfiguration clientConfiguration) { + DatabaseAuthSchemeProvider requestAuthSchemeProvider = request + .overrideConfiguration() + .flatMap(c -> c.authSchemeProvider()) + .map(p -> Validate.isInstanceOf(DatabaseAuthSchemeProvider.class, p, + "Expected an instance of DatabaseAuthSchemeProvider")).orElse(null); + DatabaseAuthSchemeProvider authSchemeProvider = requestAuthSchemeProvider != null ? requestAuthSchemeProvider : Validate + .isInstanceOf(DatabaseAuthSchemeProvider.class, clientConfiguration.option(SdkClientOption.AUTH_SCHEME_PROVIDER), + "Expected an instance of DatabaseAuthSchemeProvider"); + DatabaseAuthSchemeParams.Builder paramsBuilder = DatabaseAuthSchemeParams.builder().operation(operationName); + paramsBuilder.region(clientConfiguration.option(AwsClientOption.AWS_REGION)); + Set sigv4aRegionSet = clientConfiguration.option(AwsClientOption.AWS_SIGV4A_SIGNING_REGION_SET); + if (!CollectionUtils.isNullOrEmpty(sigv4aRegionSet)) { + paramsBuilder.regionSet(RegionSet.create(sigv4aRegionSet)); + } + List options = authSchemeProvider.resolveAuthScheme(paramsBuilder.build()); + return options; + } + + private Endpoint resolveEndpoint(SdkRequest request, ExecutionAttributes executionAttributes, String operationName) { + DatabaseEndpointProvider provider = (DatabaseEndpointProvider) executionAttributes + .getAttribute(SdkInternalExecutionAttribute.ENDPOINT_PROVIDER); + try { + DatabaseEndpointParams endpointParams = DatabaseEndpointResolverUtils.ruleParams(request, executionAttributes); + Endpoint endpoint = provider.resolveEndpoint(endpointParams).join(); + if (!AwsEndpointProviderUtils.disableHostPrefixInjection(executionAttributes)) { + Optional hostPrefix = DatabaseEndpointResolverUtils.hostPrefix(operationName, request); + if (hostPrefix.isPresent()) { + endpoint = AwsEndpointProviderUtils.addHostPrefix(endpoint, hostPrefix.get()); + } + } + List endpointAuthSchemes = endpoint.attribute(AwsEndpointAttribute.AUTH_SCHEMES); + SelectedAuthScheme selectedAuthScheme = executionAttributes + .getAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME); + if (endpointAuthSchemes != null && selectedAuthScheme != null) { + selectedAuthScheme = DatabaseEndpointResolverUtils.authSchemeWithEndpointSignerProperties(endpointAuthSchemes, + selectedAuthScheme); + // Precedence of SigV4a RegionSet is set according to multi-auth SigV4a specifications + if (selectedAuthScheme.authSchemeOption().schemeId().equals(AwsV4aAuthScheme.SCHEME_ID) + && selectedAuthScheme.authSchemeOption().signerProperty(AwsV4aHttpSigner.REGION_SET) == null) { + AuthSchemeOption.Builder optionBuilder = selectedAuthScheme.authSchemeOption().toBuilder(); + RegionSet rs = RegionSet.create(endpointParams.region().id()); + optionBuilder.putSignerProperty(AwsV4aHttpSigner.REGION_SET, rs); + selectedAuthScheme = new SelectedAuthScheme(selectedAuthScheme.identity(), selectedAuthScheme.signer(), + optionBuilder.build()); + } + executionAttributes.putAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME, selectedAuthScheme); + } + DatabaseEndpointResolverUtils.setMetricValues(endpoint, executionAttributes); + return endpoint; + } catch (CompletionException e) { + Throwable cause = e.getCause(); + if (cause instanceof SdkClientException) { + throw (SdkClientException) cause; + } + throw SdkClientException.create("Endpoint resolution failed: " + cause.getMessage(), cause); + } + } + private void updateRetryStrategyClientConfiguration(SdkClientConfiguration.Builder configuration) { ClientOverrideConfiguration.Builder builder = configuration.asOverrideConfigurationBuilder(); RetryMode retryMode = builder.retryMode(); diff --git a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-unsigned-payload-trait-sync-client-class.java b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-unsigned-payload-trait-sync-client-class.java index ff8b3a285415..4d7ce7f0f45c 100644 --- a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-unsigned-payload-trait-sync-client-class.java +++ b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-unsigned-payload-trait-sync-client-class.java @@ -3,11 +3,17 @@ import java.util.Collections; import java.util.List; import java.util.Optional; +import java.util.Set; +import java.util.concurrent.CompletionException; import java.util.function.Consumer; import java.util.function.Function; import software.amazon.awssdk.annotations.Generated; import software.amazon.awssdk.annotations.SdkInternalApi; +import software.amazon.awssdk.awscore.client.config.AwsClientOption; import software.amazon.awssdk.awscore.client.handler.AwsSyncClientHandler; +import software.amazon.awssdk.awscore.endpoints.AwsEndpointAttribute; +import software.amazon.awssdk.awscore.endpoints.AwsEndpointProviderUtils; +import software.amazon.awssdk.awscore.endpoints.authscheme.EndpointAuthScheme; import software.amazon.awssdk.awscore.exception.AwsServiceException; import software.amazon.awssdk.awscore.internal.AwsProtocolMetadata; import software.amazon.awssdk.awscore.internal.AwsServiceProtocol; @@ -15,6 +21,7 @@ import software.amazon.awssdk.core.RequestOverrideConfiguration; import software.amazon.awssdk.core.SdkPlugin; import software.amazon.awssdk.core.SdkRequest; +import software.amazon.awssdk.core.SelectedAuthScheme; import software.amazon.awssdk.core.client.config.ClientOverrideConfiguration; import software.amazon.awssdk.core.client.config.SdkClientConfiguration; import software.amazon.awssdk.core.client.config.SdkClientOption; @@ -22,10 +29,17 @@ import software.amazon.awssdk.core.client.handler.SyncClientHandler; import software.amazon.awssdk.core.exception.SdkClientException; import software.amazon.awssdk.core.http.HttpResponseHandler; +import software.amazon.awssdk.core.interceptor.ExecutionAttributes; +import software.amazon.awssdk.core.interceptor.SdkInternalExecutionAttribute; import software.amazon.awssdk.core.metrics.CoreMetric; import software.amazon.awssdk.core.retry.RetryMode; import software.amazon.awssdk.core.runtime.transform.StreamingRequestMarshaller; import software.amazon.awssdk.core.sync.RequestBody; +import software.amazon.awssdk.endpoints.Endpoint; +import software.amazon.awssdk.http.auth.aws.scheme.AwsV4aAuthScheme; +import software.amazon.awssdk.http.auth.aws.signer.AwsV4aHttpSigner; +import software.amazon.awssdk.http.auth.aws.signer.RegionSet; +import software.amazon.awssdk.http.auth.spi.scheme.AuthSchemeOption; import software.amazon.awssdk.metrics.MetricCollector; import software.amazon.awssdk.metrics.MetricPublisher; import software.amazon.awssdk.metrics.NoOpMetricCollector; @@ -35,6 +49,11 @@ import software.amazon.awssdk.protocols.json.BaseAwsJsonProtocolFactory; import software.amazon.awssdk.protocols.json.JsonOperationMetadata; import software.amazon.awssdk.retries.api.RetryStrategy; +import software.amazon.awssdk.services.database.auth.scheme.DatabaseAuthSchemeParams; +import software.amazon.awssdk.services.database.auth.scheme.DatabaseAuthSchemeProvider; +import software.amazon.awssdk.services.database.endpoints.DatabaseEndpointParams; +import software.amazon.awssdk.services.database.endpoints.DatabaseEndpointProvider; +import software.amazon.awssdk.services.database.endpoints.internal.DatabaseEndpointResolverUtils; import software.amazon.awssdk.services.database.internal.DatabaseServiceClientConfigurationBuilder; import software.amazon.awssdk.services.database.internal.ServiceVersionInfo; import software.amazon.awssdk.services.database.model.DatabaseException; @@ -72,7 +91,9 @@ import software.amazon.awssdk.services.database.transform.OpsWithSigv4AndSigv4ASignedPayloadRequestMarshaller; import software.amazon.awssdk.services.database.transform.PutRowRequestMarshaller; import software.amazon.awssdk.services.database.transform.SecondOpsWithSigv4AndSigv4ASignedPayloadRequestMarshaller; +import software.amazon.awssdk.utils.CollectionUtils; import software.amazon.awssdk.utils.Logger; +import software.amazon.awssdk.utils.Validate; /** * Internal implementation of {@link DatabaseClient}. @@ -155,6 +176,8 @@ public DeleteRowResponse deleteRow(DeleteRowRequest deleteRowRequest) throws Inv .withOperationName("DeleteRow").withProtocolMetadata(protocolMetadata).withResponseHandler(responseHandler) .withErrorResponseHandler(errorResponseHandler).withRequestConfiguration(clientConfiguration) .withInput(deleteRowRequest).withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver(r -> resolveAuthSchemeOptions(r, "DeleteRow", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "DeleteRow")) .withMarshaller(new DeleteRowRequestMarshaller(protocolFactory))); } finally { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -216,6 +239,8 @@ public GetRowResponse getRow(GetRowRequest getRowRequest) throws InvalidInputExc .withProtocolMetadata(protocolMetadata).withResponseHandler(responseHandler) .withErrorResponseHandler(errorResponseHandler).withRequestConfiguration(clientConfiguration) .withInput(getRowRequest).withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver(r -> resolveAuthSchemeOptions(r, "GetRow", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "GetRow")) .withMarshaller(new GetRowRequestMarshaller(protocolFactory))); } finally { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -277,6 +302,8 @@ public PutRowResponse putRow(PutRowRequest putRowRequest) throws InvalidInputExc .withProtocolMetadata(protocolMetadata).withResponseHandler(responseHandler) .withErrorResponseHandler(errorResponseHandler).withRequestConfiguration(clientConfiguration) .withInput(putRowRequest).withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver(r -> resolveAuthSchemeOptions(r, "PutRow", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "PutRow")) .withMarshaller(new PutRowRequestMarshaller(protocolFactory))); } finally { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -339,10 +366,16 @@ public OpWithSigv4AndSigv4AUnSignedPayloadResponse opWithSigv4AndSigv4aUnSignedP return clientHandler .execute(new ClientExecutionParams() - .withOperationName("opWithSigv4AndSigv4aUnSignedPayload").withProtocolMetadata(protocolMetadata) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withInput(opWithSigv4AndSigv4AUnSignedPayloadRequest) + .withOperationName("opWithSigv4AndSigv4aUnSignedPayload") + .withProtocolMetadata(protocolMetadata) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withInput(opWithSigv4AndSigv4AUnSignedPayloadRequest) .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "opWithSigv4AndSigv4aUnSignedPayload", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "opWithSigv4AndSigv4aUnSignedPayload")) .withMarshaller(new OpWithSigv4AndSigv4AUnSignedPayloadRequestMarshaller(protocolFactory))); } finally { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -404,10 +437,16 @@ public OpWithSigv4SignedPayloadResponse opWithSigv4SignedPayload( return clientHandler .execute(new ClientExecutionParams() - .withOperationName("opWithSigv4SignedPayload").withProtocolMetadata(protocolMetadata) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withInput(opWithSigv4SignedPayloadRequest) + .withOperationName("opWithSigv4SignedPayload") + .withProtocolMetadata(protocolMetadata) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withInput(opWithSigv4SignedPayloadRequest) .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "opWithSigv4SignedPayload", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "opWithSigv4SignedPayload")) .withMarshaller(new OpWithSigv4SignedPayloadRequestMarshaller(protocolFactory))); } finally { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -469,10 +508,16 @@ public OpWithSigv4UnSignedPayloadResponse opWithSigv4UnSignedPayload( return clientHandler .execute(new ClientExecutionParams() - .withOperationName("opWithSigv4UnSignedPayload").withProtocolMetadata(protocolMetadata) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withInput(opWithSigv4UnSignedPayloadRequest) + .withOperationName("opWithSigv4UnSignedPayload") + .withProtocolMetadata(protocolMetadata) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withInput(opWithSigv4UnSignedPayloadRequest) .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "opWithSigv4UnSignedPayload", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "opWithSigv4UnSignedPayload")) .withMarshaller(new OpWithSigv4UnSignedPayloadRequestMarshaller(protocolFactory))); } finally { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -553,6 +598,10 @@ public OpWithSigv4UnSignedPayloadAndStreamingResponse opWithSigv4UnSignedPayload .withRequestConfiguration(clientConfiguration) .withInput(opWithSigv4UnSignedPayloadAndStreamingRequest) .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "opWithSigv4UnSignedPayloadAndStreaming", + clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "opWithSigv4UnSignedPayloadAndStreaming")) .withRequestBody(requestBody) .withMarshaller( StreamingRequestMarshaller @@ -620,10 +669,16 @@ public OpWithSigv4ASignedPayloadResponse opWithSigv4aSignedPayload( return clientHandler .execute(new ClientExecutionParams() - .withOperationName("opWithSigv4aSignedPayload").withProtocolMetadata(protocolMetadata) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withInput(opWithSigv4ASignedPayloadRequest) + .withOperationName("opWithSigv4aSignedPayload") + .withProtocolMetadata(protocolMetadata) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withInput(opWithSigv4ASignedPayloadRequest) .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "opWithSigv4aSignedPayload", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "opWithSigv4aSignedPayload")) .withMarshaller(new OpWithSigv4ASignedPayloadRequestMarshaller(protocolFactory))); } finally { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -685,10 +740,16 @@ public OpWithSigv4AUnSignedPayloadResponse opWithSigv4aUnSignedPayload( return clientHandler .execute(new ClientExecutionParams() - .withOperationName("opWithSigv4aUnSignedPayload").withProtocolMetadata(protocolMetadata) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withInput(opWithSigv4AUnSignedPayloadRequest) + .withOperationName("opWithSigv4aUnSignedPayload") + .withProtocolMetadata(protocolMetadata) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withInput(opWithSigv4AUnSignedPayloadRequest) .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "opWithSigv4aUnSignedPayload", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "opWithSigv4aUnSignedPayload")) .withMarshaller(new OpWithSigv4AUnSignedPayloadRequestMarshaller(protocolFactory))); } finally { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -751,10 +812,16 @@ public OpsWithSigv4AndSigv4ASignedPayloadResponse opsWithSigv4andSigv4aSignedPay return clientHandler .execute(new ClientExecutionParams() - .withOperationName("opsWithSigv4andSigv4aSignedPayload").withProtocolMetadata(protocolMetadata) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withInput(opsWithSigv4AndSigv4ASignedPayloadRequest) + .withOperationName("opsWithSigv4andSigv4aSignedPayload") + .withProtocolMetadata(protocolMetadata) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withInput(opsWithSigv4AndSigv4ASignedPayloadRequest) .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "opsWithSigv4andSigv4aSignedPayload", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "opsWithSigv4andSigv4aSignedPayload")) .withMarshaller(new OpsWithSigv4AndSigv4ASignedPayloadRequestMarshaller(protocolFactory))); } finally { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -817,11 +884,17 @@ public SecondOpsWithSigv4AndSigv4ASignedPayloadResponse secondOpsWithSigv4andSig return clientHandler .execute(new ClientExecutionParams() - .withOperationName("secondOpsWithSigv4andSigv4aSignedPayload").withProtocolMetadata(protocolMetadata) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) + .withOperationName("secondOpsWithSigv4andSigv4aSignedPayload") + .withProtocolMetadata(protocolMetadata) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) .withRequestConfiguration(clientConfiguration) .withInput(secondOpsWithSigv4AndSigv4ASignedPayloadRequest) .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "secondOpsWithSigv4andSigv4aSignedPayload", + clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "secondOpsWithSigv4andSigv4aSignedPayload")) .withMarshaller(new SecondOpsWithSigv4AndSigv4ASignedPayloadRequestMarshaller(protocolFactory))); } finally { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -848,6 +921,66 @@ private static List resolveMetricPublishers(SdkClientConfigurat return publishers; } + private List resolveAuthSchemeOptions(SdkRequest request, String operationName, + SdkClientConfiguration clientConfiguration) { + DatabaseAuthSchemeProvider requestAuthSchemeProvider = request + .overrideConfiguration() + .flatMap(c -> c.authSchemeProvider()) + .map(p -> Validate.isInstanceOf(DatabaseAuthSchemeProvider.class, p, + "Expected an instance of DatabaseAuthSchemeProvider")).orElse(null); + DatabaseAuthSchemeProvider authSchemeProvider = requestAuthSchemeProvider != null ? requestAuthSchemeProvider : Validate + .isInstanceOf(DatabaseAuthSchemeProvider.class, clientConfiguration.option(SdkClientOption.AUTH_SCHEME_PROVIDER), + "Expected an instance of DatabaseAuthSchemeProvider"); + DatabaseAuthSchemeParams.Builder paramsBuilder = DatabaseAuthSchemeParams.builder().operation(operationName); + paramsBuilder.region(clientConfiguration.option(AwsClientOption.AWS_REGION)); + Set sigv4aRegionSet = clientConfiguration.option(AwsClientOption.AWS_SIGV4A_SIGNING_REGION_SET); + if (!CollectionUtils.isNullOrEmpty(sigv4aRegionSet)) { + paramsBuilder.regionSet(RegionSet.create(sigv4aRegionSet)); + } + List options = authSchemeProvider.resolveAuthScheme(paramsBuilder.build()); + return options; + } + + private Endpoint resolveEndpoint(SdkRequest request, ExecutionAttributes executionAttributes, String operationName) { + DatabaseEndpointProvider provider = (DatabaseEndpointProvider) executionAttributes + .getAttribute(SdkInternalExecutionAttribute.ENDPOINT_PROVIDER); + try { + DatabaseEndpointParams endpointParams = DatabaseEndpointResolverUtils.ruleParams(request, executionAttributes); + Endpoint endpoint = provider.resolveEndpoint(endpointParams).join(); + if (!AwsEndpointProviderUtils.disableHostPrefixInjection(executionAttributes)) { + Optional hostPrefix = DatabaseEndpointResolverUtils.hostPrefix(operationName, request); + if (hostPrefix.isPresent()) { + endpoint = AwsEndpointProviderUtils.addHostPrefix(endpoint, hostPrefix.get()); + } + } + List endpointAuthSchemes = endpoint.attribute(AwsEndpointAttribute.AUTH_SCHEMES); + SelectedAuthScheme selectedAuthScheme = executionAttributes + .getAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME); + if (endpointAuthSchemes != null && selectedAuthScheme != null) { + selectedAuthScheme = DatabaseEndpointResolverUtils.authSchemeWithEndpointSignerProperties(endpointAuthSchemes, + selectedAuthScheme); + // Precedence of SigV4a RegionSet is set according to multi-auth SigV4a specifications + if (selectedAuthScheme.authSchemeOption().schemeId().equals(AwsV4aAuthScheme.SCHEME_ID) + && selectedAuthScheme.authSchemeOption().signerProperty(AwsV4aHttpSigner.REGION_SET) == null) { + AuthSchemeOption.Builder optionBuilder = selectedAuthScheme.authSchemeOption().toBuilder(); + RegionSet rs = RegionSet.create(endpointParams.region().id()); + optionBuilder.putSignerProperty(AwsV4aHttpSigner.REGION_SET, rs); + selectedAuthScheme = new SelectedAuthScheme(selectedAuthScheme.identity(), selectedAuthScheme.signer(), + optionBuilder.build()); + } + executionAttributes.putAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME, selectedAuthScheme); + } + DatabaseEndpointResolverUtils.setMetricValues(endpoint, executionAttributes); + return endpoint; + } catch (CompletionException e) { + Throwable cause = e.getCause(); + if (cause instanceof SdkClientException) { + throw (SdkClientException) cause; + } + throw SdkClientException.create("Endpoint resolution failed: " + cause.getMessage(), cause); + } + } + private HttpResponseHandler createErrorResponseHandler(BaseAwsJsonProtocolFactory protocolFactory, JsonOperationMetadata operationMetadata, Function> exceptionMetadataMapper) { return protocolFactory.createErrorResponseHandler(operationMetadata, exceptionMetadataMapper); diff --git a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-xml-async-client-class.java b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-xml-async-client-class.java index a7d0c7b4a984..9ccbb39a4456 100644 --- a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-xml-async-client-class.java +++ b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-xml-async-client-class.java @@ -4,14 +4,20 @@ import java.util.Collections; import java.util.List; +import java.util.Optional; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionException; import java.util.concurrent.Executor; import java.util.function.Consumer; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import software.amazon.awssdk.annotations.Generated; import software.amazon.awssdk.annotations.SdkInternalApi; +import software.amazon.awssdk.awscore.client.config.AwsClientOption; import software.amazon.awssdk.awscore.client.handler.AwsAsyncClientHandler; +import software.amazon.awssdk.awscore.endpoints.AwsEndpointAttribute; +import software.amazon.awssdk.awscore.endpoints.AwsEndpointProviderUtils; +import software.amazon.awssdk.awscore.endpoints.authscheme.EndpointAuthScheme; import software.amazon.awssdk.awscore.eventstream.EventStreamAsyncResponseTransformer; import software.amazon.awssdk.awscore.eventstream.EventStreamTaggedUnionPojoSupplier; import software.amazon.awssdk.awscore.eventstream.RestEventStreamAsyncResponseTransformer; @@ -26,6 +32,7 @@ import software.amazon.awssdk.core.SdkPlugin; import software.amazon.awssdk.core.SdkPojoBuilder; import software.amazon.awssdk.core.SdkRequest; +import software.amazon.awssdk.core.SelectedAuthScheme; import software.amazon.awssdk.core.async.AsyncRequestBody; import software.amazon.awssdk.core.async.AsyncResponseTransformer; import software.amazon.awssdk.core.async.AsyncResponseTransformerUtils; @@ -35,7 +42,9 @@ import software.amazon.awssdk.core.client.config.SdkClientOption; import software.amazon.awssdk.core.client.handler.AsyncClientHandler; import software.amazon.awssdk.core.client.handler.ClientExecutionParams; +import software.amazon.awssdk.core.exception.SdkClientException; import software.amazon.awssdk.core.http.HttpResponseHandler; +import software.amazon.awssdk.core.interceptor.ExecutionAttributes; import software.amazon.awssdk.core.interceptor.SdkInternalExecutionAttribute; import software.amazon.awssdk.core.interceptor.trait.HttpChecksum; import software.amazon.awssdk.core.interceptor.trait.HttpChecksumRequired; @@ -43,6 +52,8 @@ import software.amazon.awssdk.core.metrics.CoreMetric; import software.amazon.awssdk.core.retry.RetryMode; import software.amazon.awssdk.core.runtime.transform.AsyncStreamingRequestMarshaller; +import software.amazon.awssdk.endpoints.Endpoint; +import software.amazon.awssdk.http.auth.spi.scheme.AuthSchemeOption; import software.amazon.awssdk.metrics.MetricCollector; import software.amazon.awssdk.metrics.MetricPublisher; import software.amazon.awssdk.metrics.NoOpMetricCollector; @@ -50,6 +61,11 @@ import software.amazon.awssdk.protocols.xml.AwsXmlProtocolFactory; import software.amazon.awssdk.protocols.xml.XmlOperationMetadata; import software.amazon.awssdk.retries.api.RetryStrategy; +import software.amazon.awssdk.services.xml.auth.scheme.XmlAuthSchemeParams; +import software.amazon.awssdk.services.xml.auth.scheme.XmlAuthSchemeProvider; +import software.amazon.awssdk.services.xml.endpoints.XmlEndpointParams; +import software.amazon.awssdk.services.xml.endpoints.XmlEndpointProvider; +import software.amazon.awssdk.services.xml.endpoints.internal.XmlEndpointResolverUtils; import software.amazon.awssdk.services.xml.internal.ServiceVersionInfo; import software.amazon.awssdk.services.xml.internal.XmlServiceClientConfigurationBuilder; import software.amazon.awssdk.services.xml.model.APostOperationRequest; @@ -91,6 +107,7 @@ import software.amazon.awssdk.services.xml.transform.StreamingOutputOperationRequestMarshaller; import software.amazon.awssdk.utils.CompletableFutureUtils; import software.amazon.awssdk.utils.Pair; +import software.amazon.awssdk.utils.Validate; /** * Internal implementation of {@link XmlAsyncClient}. @@ -164,11 +181,17 @@ public CompletableFuture aPostOperation(APostOperationRe CompletableFuture executeFuture = clientHandler .execute(new ClientExecutionParams() - .withOperationName("APostOperation").withRequestConfiguration(clientConfiguration) + .withOperationName("APostOperation") + .withRequestConfiguration(clientConfiguration) .withProtocolMetadata(protocolMetadata) .withMarshaller(new APostOperationRequestMarshaller(protocolFactory)) - .withCombinedResponseHandler(responseHandler).hostPrefixExpression(resolvedHostExpression) - .withMetricCollector(apiCallMetricCollector).withInput(aPostOperationRequest)); + .withCombinedResponseHandler(responseHandler) + .hostPrefixExpression(resolvedHostExpression) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "APostOperation", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "APostOperation")) + .withInput(aPostOperationRequest)); CompletableFuture whenCompleteFuture = null; whenCompleteFuture = executeFuture.whenComplete((r, e) -> { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -224,10 +247,15 @@ public CompletableFuture aPostOperationWithOut CompletableFuture executeFuture = clientHandler .execute(new ClientExecutionParams() - .withOperationName("APostOperationWithOutput").withRequestConfiguration(clientConfiguration) + .withOperationName("APostOperationWithOutput") + .withRequestConfiguration(clientConfiguration) .withProtocolMetadata(protocolMetadata) .withMarshaller(new APostOperationWithOutputRequestMarshaller(protocolFactory)) - .withCombinedResponseHandler(responseHandler).withMetricCollector(apiCallMetricCollector) + .withCombinedResponseHandler(responseHandler) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "APostOperationWithOutput", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "APostOperationWithOutput")) .withInput(aPostOperationWithOutputRequest)); CompletableFuture whenCompleteFuture = null; whenCompleteFuture = executeFuture.whenComplete((r, e) -> { @@ -280,11 +308,17 @@ public CompletableFuture bearerAuthOperation( CompletableFuture executeFuture = clientHandler .execute(new ClientExecutionParams() - .withOperationName("BearerAuthOperation").withRequestConfiguration(clientConfiguration) + .withOperationName("BearerAuthOperation") + .withRequestConfiguration(clientConfiguration) .withProtocolMetadata(protocolMetadata) .withMarshaller(new BearerAuthOperationRequestMarshaller(protocolFactory)) - .withCombinedResponseHandler(responseHandler).credentialType(CredentialType.TOKEN) - .withMetricCollector(apiCallMetricCollector).withInput(bearerAuthOperationRequest)); + .withCombinedResponseHandler(responseHandler) + .credentialType(CredentialType.TOKEN) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "BearerAuthOperation", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "BearerAuthOperation")) + .withInput(bearerAuthOperationRequest)); CompletableFuture whenCompleteFuture = null; whenCompleteFuture = executeFuture.whenComplete((r, e) -> { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -353,12 +387,17 @@ public CompletableFuture eventStreamOperation(EventStreamOperationRequest CompletableFuture executeFuture = clientHandler.execute( new ClientExecutionParams() - .withOperationName("EventStreamOperation").withRequestConfiguration(clientConfiguration) + .withOperationName("EventStreamOperation") + .withRequestConfiguration(clientConfiguration) .withProtocolMetadata(protocolMetadata) .withMarshaller(new EventStreamOperationRequestMarshaller(protocolFactory)) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withMetricCollector(apiCallMetricCollector).withInput(eventStreamOperationRequest), - restAsyncResponseTransformer); + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "EventStreamOperation", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "EventStreamOperation")) + .withInput(eventStreamOperationRequest), restAsyncResponseTransformer); CompletableFuture whenCompleteFuture = null; whenCompleteFuture = executeFuture.whenComplete((r, e) -> { if (e != null) { @@ -423,6 +462,9 @@ public CompletableFuture getOperationWithCheck .withMarshaller(new GetOperationWithChecksumRequestMarshaller(protocolFactory)) .withCombinedResponseHandler(responseHandler) .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "GetOperationWithChecksum", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "GetOperationWithChecksum")) .putExecutionAttribute( SdkInternalExecutionAttribute.HTTP_CHECKSUM, HttpChecksum.builder().requestChecksumRequired(true).isRequestStreaming(false) @@ -487,6 +529,9 @@ public CompletableFuture operationWithChe .withMarshaller(new OperationWithChecksumRequiredRequestMarshaller(protocolFactory)) .withCombinedResponseHandler(responseHandler) .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "OperationWithChecksumRequired", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "OperationWithChecksumRequired")) .putExecutionAttribute(SdkInternalExecutionAttribute.HTTP_CHECKSUM_REQUIRED, HttpChecksumRequired.create()).withInput(operationWithChecksumRequiredRequest)); CompletableFuture whenCompleteFuture = null; @@ -540,10 +585,15 @@ public CompletableFuture operationWithNoneAut CompletableFuture executeFuture = clientHandler .execute(new ClientExecutionParams() - .withOperationName("OperationWithNoneAuthType").withRequestConfiguration(clientConfiguration) + .withOperationName("OperationWithNoneAuthType") + .withRequestConfiguration(clientConfiguration) .withProtocolMetadata(protocolMetadata) .withMarshaller(new OperationWithNoneAuthTypeRequestMarshaller(protocolFactory)) - .withCombinedResponseHandler(responseHandler).withMetricCollector(apiCallMetricCollector) + .withCombinedResponseHandler(responseHandler) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "OperationWithNoneAuthType", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "OperationWithNoneAuthType")) .withInput(operationWithNoneAuthTypeRequest)); CompletableFuture whenCompleteFuture = null; whenCompleteFuture = executeFuture.whenComplete((r, e) -> { @@ -603,6 +653,9 @@ public CompletableFuture operationWithR .withMarshaller(new OperationWithRequestCompressionRequestMarshaller(protocolFactory)) .withCombinedResponseHandler(responseHandler) .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "OperationWithRequestCompression", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "OperationWithRequestCompression")) .putExecutionAttribute(SdkInternalExecutionAttribute.REQUEST_COMPRESSION, RequestCompression.builder().encodings("gzip").isStreaming(false).build()) .withInput(operationWithRequestCompressionRequest)); @@ -691,6 +744,9 @@ public CompletableFuture putOperationWithChecksum( .withErrorResponseHandler(errorResponseHandler) .withRequestConfiguration(clientConfiguration) .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "PutOperationWithChecksum", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "PutOperationWithChecksum")) .putExecutionAttribute( SdkInternalExecutionAttribute.HTTP_CHECKSUM, HttpChecksum @@ -777,9 +833,13 @@ public CompletableFuture streamingInputOperatio .withMarshaller( AsyncStreamingRequestMarshaller.builder() .delegateMarshaller(new StreamingInputOperationRequestMarshaller(protocolFactory)) - .asyncRequestBody(requestBody).build()).withCombinedResponseHandler(responseHandler) - .withMetricCollector(apiCallMetricCollector).withAsyncRequestBody(requestBody) - .withInput(streamingInputOperationRequest)); + .asyncRequestBody(requestBody).build()) + .withCombinedResponseHandler(responseHandler) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "StreamingInputOperation", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "StreamingInputOperation")) + .withAsyncRequestBody(requestBody).withInput(streamingInputOperationRequest)); CompletableFuture whenCompleteFuture = null; whenCompleteFuture = executeFuture.whenComplete((r, e) -> { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -842,10 +902,16 @@ public CompletableFuture streamingOutputOperation( CompletableFuture executeFuture = clientHandler.execute( new ClientExecutionParams() - .withOperationName("StreamingOutputOperation").withProtocolMetadata(protocolMetadata) + .withOperationName("StreamingOutputOperation") + .withProtocolMetadata(protocolMetadata) .withMarshaller(new StreamingOutputOperationRequestMarshaller(protocolFactory)) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withMetricCollector(apiCallMetricCollector) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "StreamingOutputOperation", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "StreamingOutputOperation")) .withAsyncResponseTransformer(asyncResponseTransformer).withInput(streamingOutputOperationRequest), asyncResponseTransformer); CompletableFuture whenCompleteFuture = null; @@ -903,6 +969,51 @@ private static List resolveMetricPublishers(SdkClientConfigurat return publishers; } + private List resolveAuthSchemeOptions(SdkRequest request, String operationName, + SdkClientConfiguration clientConfiguration) { + XmlAuthSchemeProvider requestAuthSchemeProvider = request.overrideConfiguration().flatMap(c -> c.authSchemeProvider()) + .map(p -> Validate.isInstanceOf(XmlAuthSchemeProvider.class, p, "Expected an instance of XmlAuthSchemeProvider")) + .orElse(null); + XmlAuthSchemeProvider authSchemeProvider = requestAuthSchemeProvider != null ? requestAuthSchemeProvider : Validate + .isInstanceOf(XmlAuthSchemeProvider.class, clientConfiguration.option(SdkClientOption.AUTH_SCHEME_PROVIDER), + "Expected an instance of XmlAuthSchemeProvider"); + XmlAuthSchemeParams.Builder paramsBuilder = XmlAuthSchemeParams.builder().operation(operationName); + paramsBuilder.region(clientConfiguration.option(AwsClientOption.AWS_REGION)); + List options = authSchemeProvider.resolveAuthScheme(paramsBuilder.build()); + return options; + } + + private Endpoint resolveEndpoint(SdkRequest request, ExecutionAttributes executionAttributes, String operationName) { + XmlEndpointProvider provider = (XmlEndpointProvider) executionAttributes + .getAttribute(SdkInternalExecutionAttribute.ENDPOINT_PROVIDER); + try { + XmlEndpointParams endpointParams = XmlEndpointResolverUtils.ruleParams(request, executionAttributes); + Endpoint endpoint = provider.resolveEndpoint(endpointParams).join(); + if (!AwsEndpointProviderUtils.disableHostPrefixInjection(executionAttributes)) { + Optional hostPrefix = XmlEndpointResolverUtils.hostPrefix(operationName, request); + if (hostPrefix.isPresent()) { + endpoint = AwsEndpointProviderUtils.addHostPrefix(endpoint, hostPrefix.get()); + } + } + List endpointAuthSchemes = endpoint.attribute(AwsEndpointAttribute.AUTH_SCHEMES); + SelectedAuthScheme selectedAuthScheme = executionAttributes + .getAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME); + if (endpointAuthSchemes != null && selectedAuthScheme != null) { + selectedAuthScheme = XmlEndpointResolverUtils.authSchemeWithEndpointSignerProperties(endpointAuthSchemes, + selectedAuthScheme); + executionAttributes.putAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME, selectedAuthScheme); + } + XmlEndpointResolverUtils.setMetricValues(endpoint, executionAttributes); + return endpoint; + } catch (CompletionException e) { + Throwable cause = e.getCause(); + if (cause instanceof SdkClientException) { + throw (SdkClientException) cause; + } + throw SdkClientException.create("Endpoint resolution failed: " + cause.getMessage(), cause); + } + } + private void updateRetryStrategyClientConfiguration(SdkClientConfiguration.Builder configuration) { ClientOverrideConfiguration.Builder builder = configuration.asOverrideConfigurationBuilder(); RetryMode retryMode = builder.retryMode(); diff --git a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-xml-client-class.java b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-xml-client-class.java index 627684e47ff2..b9b92d71bfcc 100644 --- a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-xml-client-class.java +++ b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-xml-client-class.java @@ -2,10 +2,16 @@ import java.util.Collections; import java.util.List; +import java.util.Optional; +import java.util.concurrent.CompletionException; import java.util.function.Consumer; import software.amazon.awssdk.annotations.Generated; import software.amazon.awssdk.annotations.SdkInternalApi; +import software.amazon.awssdk.awscore.client.config.AwsClientOption; import software.amazon.awssdk.awscore.client.handler.AwsSyncClientHandler; +import software.amazon.awssdk.awscore.endpoints.AwsEndpointAttribute; +import software.amazon.awssdk.awscore.endpoints.AwsEndpointProviderUtils; +import software.amazon.awssdk.awscore.endpoints.authscheme.EndpointAuthScheme; import software.amazon.awssdk.awscore.exception.AwsServiceException; import software.amazon.awssdk.awscore.internal.AwsProtocolMetadata; import software.amazon.awssdk.awscore.internal.AwsServiceProtocol; @@ -16,6 +22,7 @@ import software.amazon.awssdk.core.Response; import software.amazon.awssdk.core.SdkPlugin; import software.amazon.awssdk.core.SdkRequest; +import software.amazon.awssdk.core.SelectedAuthScheme; import software.amazon.awssdk.core.client.config.ClientOverrideConfiguration; import software.amazon.awssdk.core.client.config.SdkClientConfiguration; import software.amazon.awssdk.core.client.config.SdkClientOption; @@ -23,6 +30,7 @@ import software.amazon.awssdk.core.client.handler.SyncClientHandler; import software.amazon.awssdk.core.exception.SdkClientException; import software.amazon.awssdk.core.http.HttpResponseHandler; +import software.amazon.awssdk.core.interceptor.ExecutionAttributes; import software.amazon.awssdk.core.interceptor.SdkInternalExecutionAttribute; import software.amazon.awssdk.core.interceptor.trait.HttpChecksum; import software.amazon.awssdk.core.interceptor.trait.HttpChecksumRequired; @@ -32,6 +40,8 @@ import software.amazon.awssdk.core.runtime.transform.StreamingRequestMarshaller; import software.amazon.awssdk.core.sync.RequestBody; import software.amazon.awssdk.core.sync.ResponseTransformer; +import software.amazon.awssdk.endpoints.Endpoint; +import software.amazon.awssdk.http.auth.spi.scheme.AuthSchemeOption; import software.amazon.awssdk.metrics.MetricCollector; import software.amazon.awssdk.metrics.MetricPublisher; import software.amazon.awssdk.metrics.NoOpMetricCollector; @@ -39,6 +49,11 @@ import software.amazon.awssdk.protocols.xml.AwsXmlProtocolFactory; import software.amazon.awssdk.protocols.xml.XmlOperationMetadata; import software.amazon.awssdk.retries.api.RetryStrategy; +import software.amazon.awssdk.services.xml.auth.scheme.XmlAuthSchemeParams; +import software.amazon.awssdk.services.xml.auth.scheme.XmlAuthSchemeProvider; +import software.amazon.awssdk.services.xml.endpoints.XmlEndpointParams; +import software.amazon.awssdk.services.xml.endpoints.XmlEndpointProvider; +import software.amazon.awssdk.services.xml.endpoints.internal.XmlEndpointResolverUtils; import software.amazon.awssdk.services.xml.internal.ServiceVersionInfo; import software.amazon.awssdk.services.xml.internal.XmlServiceClientConfigurationBuilder; import software.amazon.awssdk.services.xml.model.APostOperationRequest; @@ -74,6 +89,7 @@ import software.amazon.awssdk.services.xml.transform.StreamingInputOperationRequestMarshaller; import software.amazon.awssdk.services.xml.transform.StreamingOutputOperationRequestMarshaller; import software.amazon.awssdk.utils.Logger; +import software.amazon.awssdk.utils.Validate; /** * Internal implementation of {@link XmlClient}. @@ -142,7 +158,10 @@ public APostOperationResponse aPostOperation(APostOperationRequest aPostOperatio .withOperationName("APostOperation").withProtocolMetadata(protocolMetadata) .withCombinedResponseHandler(responseHandler).withMetricCollector(apiCallMetricCollector) .hostPrefixExpression(resolvedHostExpression).withRequestConfiguration(clientConfiguration) - .withInput(aPostOperationRequest).withMarshaller(new APostOperationRequestMarshaller(protocolFactory))); + .withInput(aPostOperationRequest) + .withAuthSchemeOptionsResolver(r -> resolveAuthSchemeOptions(r, "APostOperation", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "APostOperation")) + .withMarshaller(new APostOperationRequestMarshaller(protocolFactory))); } finally { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); } @@ -188,9 +207,15 @@ public APostOperationWithOutputResponse aPostOperationWithOutput( return clientHandler .execute(new ClientExecutionParams() - .withOperationName("APostOperationWithOutput").withProtocolMetadata(protocolMetadata) - .withCombinedResponseHandler(responseHandler).withMetricCollector(apiCallMetricCollector) - .withRequestConfiguration(clientConfiguration).withInput(aPostOperationWithOutputRequest) + .withOperationName("APostOperationWithOutput") + .withProtocolMetadata(protocolMetadata) + .withCombinedResponseHandler(responseHandler) + .withMetricCollector(apiCallMetricCollector) + .withRequestConfiguration(clientConfiguration) + .withInput(aPostOperationWithOutputRequest) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "APostOperationWithOutput", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "APostOperationWithOutput")) .withMarshaller(new APostOperationWithOutputRequestMarshaller(protocolFactory))); } finally { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -235,6 +260,8 @@ public BearerAuthOperationResponse bearerAuthOperation(BearerAuthOperationReques .withCombinedResponseHandler(responseHandler).withMetricCollector(apiCallMetricCollector) .credentialType(CredentialType.TOKEN).withRequestConfiguration(clientConfiguration) .withInput(bearerAuthOperationRequest) + .withAuthSchemeOptionsResolver(r -> resolveAuthSchemeOptions(r, "BearerAuthOperation", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "BearerAuthOperation")) .withMarshaller(new BearerAuthOperationRequestMarshaller(protocolFactory))); } finally { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -283,6 +310,9 @@ public GetOperationWithChecksumResponse getOperationWithChecksum( .withMetricCollector(apiCallMetricCollector) .withRequestConfiguration(clientConfiguration) .withInput(getOperationWithChecksumRequest) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "GetOperationWithChecksum", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "GetOperationWithChecksum")) .putExecutionAttribute( SdkInternalExecutionAttribute.HTTP_CHECKSUM, HttpChecksum.builder().requestChecksumRequired(true).isRequestStreaming(false) @@ -336,6 +366,9 @@ public OperationWithChecksumRequiredResponse operationWithChecksumRequired( .withMetricCollector(apiCallMetricCollector) .withRequestConfiguration(clientConfiguration) .withInput(operationWithChecksumRequiredRequest) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "OperationWithChecksumRequired", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "OperationWithChecksumRequired")) .putExecutionAttribute(SdkInternalExecutionAttribute.HTTP_CHECKSUM_REQUIRED, HttpChecksumRequired.create()) .withMarshaller(new OperationWithChecksumRequiredRequestMarshaller(protocolFactory))); @@ -380,9 +413,15 @@ public OperationWithNoneAuthTypeResponse operationWithNoneAuthType( return clientHandler .execute(new ClientExecutionParams() - .withOperationName("OperationWithNoneAuthType").withProtocolMetadata(protocolMetadata) - .withCombinedResponseHandler(responseHandler).withMetricCollector(apiCallMetricCollector) - .withRequestConfiguration(clientConfiguration).withInput(operationWithNoneAuthTypeRequest) + .withOperationName("OperationWithNoneAuthType") + .withProtocolMetadata(protocolMetadata) + .withCombinedResponseHandler(responseHandler) + .withMetricCollector(apiCallMetricCollector) + .withRequestConfiguration(clientConfiguration) + .withInput(operationWithNoneAuthTypeRequest) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "OperationWithNoneAuthType", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "OperationWithNoneAuthType")) .withMarshaller(new OperationWithNoneAuthTypeRequestMarshaller(protocolFactory))); } finally { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -431,6 +470,9 @@ public OperationWithRequestCompressionResponse operationWithRequestCompression( .withMetricCollector(apiCallMetricCollector) .withRequestConfiguration(clientConfiguration) .withInput(operationWithRequestCompressionRequest) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "OperationWithRequestCompression", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "OperationWithRequestCompression")) .putExecutionAttribute(SdkInternalExecutionAttribute.REQUEST_COMPRESSION, RequestCompression.builder().encodings("gzip").isStreaming(false).build()) .withMarshaller(new OperationWithRequestCompressionRequestMarshaller(protocolFactory))); @@ -509,6 +551,9 @@ public ReturnT putOperationWithChecksum(PutOperationWithChecksumReques .withRequestConfiguration(clientConfiguration) .withInput(putOperationWithChecksumRequest) .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "PutOperationWithChecksum", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "PutOperationWithChecksum")) .putExecutionAttribute( SdkInternalExecutionAttribute.HTTP_CHECKSUM, HttpChecksum @@ -585,6 +630,9 @@ public StreamingInputOperationResponse streamingInputOperation(StreamingInputOpe .withMetricCollector(apiCallMetricCollector) .withRequestConfiguration(clientConfiguration) .withInput(streamingInputOperationRequest) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "StreamingInputOperation", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "StreamingInputOperation")) .withRequestBody(requestBody) .withMarshaller( StreamingRequestMarshaller.builder() @@ -639,10 +687,17 @@ public ReturnT streamingOutputOperation(StreamingOutputOperationReques return clientHandler.execute( new ClientExecutionParams() - .withOperationName("StreamingOutputOperation").withProtocolMetadata(protocolMetadata) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withInput(streamingOutputOperationRequest) - .withMetricCollector(apiCallMetricCollector).withResponseTransformer(responseTransformer) + .withOperationName("StreamingOutputOperation") + .withProtocolMetadata(protocolMetadata) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withInput(streamingOutputOperationRequest) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "StreamingOutputOperation", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "StreamingOutputOperation")) + .withResponseTransformer(responseTransformer) .withMarshaller(new StreamingOutputOperationRequestMarshaller(protocolFactory)), responseTransformer); } finally { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -669,6 +724,51 @@ private static List resolveMetricPublishers(SdkClientConfigurat return publishers; } + private List resolveAuthSchemeOptions(SdkRequest request, String operationName, + SdkClientConfiguration clientConfiguration) { + XmlAuthSchemeProvider requestAuthSchemeProvider = request.overrideConfiguration().flatMap(c -> c.authSchemeProvider()) + .map(p -> Validate.isInstanceOf(XmlAuthSchemeProvider.class, p, "Expected an instance of XmlAuthSchemeProvider")) + .orElse(null); + XmlAuthSchemeProvider authSchemeProvider = requestAuthSchemeProvider != null ? requestAuthSchemeProvider : Validate + .isInstanceOf(XmlAuthSchemeProvider.class, clientConfiguration.option(SdkClientOption.AUTH_SCHEME_PROVIDER), + "Expected an instance of XmlAuthSchemeProvider"); + XmlAuthSchemeParams.Builder paramsBuilder = XmlAuthSchemeParams.builder().operation(operationName); + paramsBuilder.region(clientConfiguration.option(AwsClientOption.AWS_REGION)); + List options = authSchemeProvider.resolveAuthScheme(paramsBuilder.build()); + return options; + } + + private Endpoint resolveEndpoint(SdkRequest request, ExecutionAttributes executionAttributes, String operationName) { + XmlEndpointProvider provider = (XmlEndpointProvider) executionAttributes + .getAttribute(SdkInternalExecutionAttribute.ENDPOINT_PROVIDER); + try { + XmlEndpointParams endpointParams = XmlEndpointResolverUtils.ruleParams(request, executionAttributes); + Endpoint endpoint = provider.resolveEndpoint(endpointParams).join(); + if (!AwsEndpointProviderUtils.disableHostPrefixInjection(executionAttributes)) { + Optional hostPrefix = XmlEndpointResolverUtils.hostPrefix(operationName, request); + if (hostPrefix.isPresent()) { + endpoint = AwsEndpointProviderUtils.addHostPrefix(endpoint, hostPrefix.get()); + } + } + List endpointAuthSchemes = endpoint.attribute(AwsEndpointAttribute.AUTH_SCHEMES); + SelectedAuthScheme selectedAuthScheme = executionAttributes + .getAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME); + if (endpointAuthSchemes != null && selectedAuthScheme != null) { + selectedAuthScheme = XmlEndpointResolverUtils.authSchemeWithEndpointSignerProperties(endpointAuthSchemes, + selectedAuthScheme); + executionAttributes.putAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME, selectedAuthScheme); + } + XmlEndpointResolverUtils.setMetricValues(endpoint, executionAttributes); + return endpoint; + } catch (CompletionException e) { + Throwable cause = e.getCause(); + if (cause instanceof SdkClientException) { + throw (SdkClientException) cause; + } + throw SdkClientException.create("Endpoint resolution failed: " + cause.getMessage(), cause); + } + } + private void updateRetryStrategyClientConfiguration(SdkClientConfiguration.Builder configuration) { ClientOverrideConfiguration.Builder builder = configuration.asOverrideConfigurationBuilder(); RetryMode retryMode = builder.retryMode(); diff --git a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/rules/endpoint-resolver-utils-with-endpointsbasedauth.java b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/rules/endpoint-resolver-utils-with-endpointsbasedauth.java new file mode 100644 index 000000000000..ad3b2e674750 --- /dev/null +++ b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/rules/endpoint-resolver-utils-with-endpointsbasedauth.java @@ -0,0 +1,203 @@ +package software.amazon.awssdk.services.query.endpoints.internal; + +import java.util.List; +import java.util.Optional; +import software.amazon.awssdk.annotations.Generated; +import software.amazon.awssdk.annotations.SdkInternalApi; +import software.amazon.awssdk.awscore.AwsExecutionAttribute; +import software.amazon.awssdk.awscore.endpoints.AccountIdEndpointMode; +import software.amazon.awssdk.awscore.endpoints.AwsEndpointAttribute; +import software.amazon.awssdk.awscore.endpoints.AwsEndpointProviderUtils; +import software.amazon.awssdk.awscore.endpoints.authscheme.EndpointAuthScheme; +import software.amazon.awssdk.awscore.endpoints.authscheme.SigV4AuthScheme; +import software.amazon.awssdk.awscore.endpoints.authscheme.SigV4aAuthScheme; +import software.amazon.awssdk.awscore.internal.useragent.BusinessMetricsUtils; +import software.amazon.awssdk.core.SdkRequest; +import software.amazon.awssdk.core.SelectedAuthScheme; +import software.amazon.awssdk.core.interceptor.ExecutionAttributes; +import software.amazon.awssdk.core.interceptor.SdkInternalExecutionAttribute; +import software.amazon.awssdk.core.useragent.BusinessMetricFeatureId; +import software.amazon.awssdk.endpoints.Endpoint; +import software.amazon.awssdk.http.auth.aws.scheme.AwsV4aAuthScheme; +import software.amazon.awssdk.http.auth.aws.signer.AwsV4HttpSigner; +import software.amazon.awssdk.http.auth.aws.signer.AwsV4aHttpSigner; +import software.amazon.awssdk.http.auth.aws.signer.RegionSet; +import software.amazon.awssdk.http.auth.spi.scheme.AuthSchemeOption; +import software.amazon.awssdk.identity.spi.AwsCredentialsIdentity; +import software.amazon.awssdk.identity.spi.Identity; +import software.amazon.awssdk.services.query.endpoints.QueryClientContextParams; +import software.amazon.awssdk.services.query.endpoints.QueryEndpointParams; +import software.amazon.awssdk.services.query.jmespath.internal.JmesPathRuntime; +import software.amazon.awssdk.services.query.model.OperationWithContextParamRequest; +import software.amazon.awssdk.services.query.model.OperationWithMapOperationContextParamRequest; +import software.amazon.awssdk.services.query.model.OperationWithOperationContextParamRequest; +import software.amazon.awssdk.utils.AttributeMap; +import software.amazon.awssdk.utils.CollectionUtils; +import software.amazon.awssdk.utils.CompletableFutureUtils; + +@Generated("software.amazon.awssdk:codegen") +@SdkInternalApi +public final class QueryEndpointResolverUtils { + private QueryEndpointResolverUtils() { + } + + public static QueryEndpointParams ruleParams(SdkRequest request, ExecutionAttributes executionAttributes) { + QueryEndpointParams.Builder builder = QueryEndpointParams.builder(); + builder.region(AwsEndpointProviderUtils.regionBuiltIn(executionAttributes)); + builder.useDualStackEndpoint(AwsEndpointProviderUtils.dualStackEnabledBuiltIn(executionAttributes)); + builder.useFipsEndpoint(AwsEndpointProviderUtils.fipsEnabledBuiltIn(executionAttributes)); + builder.accountId(resolveAndRecordAccountIdFromIdentity(executionAttributes)); + builder.accountIdEndpointMode(recordAccountIdEndpointMode(executionAttributes)); + setClientContextParams(builder, executionAttributes); + setContextParams(builder, executionAttributes.getAttribute(AwsExecutionAttribute.OPERATION_NAME), request); + setStaticContextParams(builder, executionAttributes.getAttribute(AwsExecutionAttribute.OPERATION_NAME)); + setOperationContextParams(builder, executionAttributes.getAttribute(AwsExecutionAttribute.OPERATION_NAME), request); + return builder.build(); + } + + private static void setContextParams(QueryEndpointParams.Builder params, String operationName, SdkRequest request) { + switch (operationName) { + case "OperationWithContextParam": + setContextParams(params, (OperationWithContextParamRequest) request); + break; + default: + break; + } + } + + private static void setContextParams(QueryEndpointParams.Builder params, OperationWithContextParamRequest request) { + params.operationContextParam(request.stringMember()); + } + + private static void setStaticContextParams(QueryEndpointParams.Builder params, String operationName) { + switch (operationName) { + case "OperationWithStaticContextParams": + operationWithStaticContextParamsStaticContextParams(params); + break; + default: + break; + } + } + + private static void operationWithStaticContextParamsStaticContextParams(QueryEndpointParams.Builder params) { + params.staticStringParam("hello"); + } + + public static SelectedAuthScheme authSchemeWithEndpointSignerProperties( + List endpointAuthSchemes, SelectedAuthScheme selectedAuthScheme) { + for (EndpointAuthScheme endpointAuthScheme : endpointAuthSchemes) { + if (!endpointAuthScheme.schemeId().equals(selectedAuthScheme.authSchemeOption().schemeId())) { + continue; + } + AuthSchemeOption.Builder option = selectedAuthScheme.authSchemeOption().toBuilder(); + if (endpointAuthScheme instanceof SigV4AuthScheme) { + SigV4AuthScheme v4AuthScheme = (SigV4AuthScheme) endpointAuthScheme; + if (v4AuthScheme.isDisableDoubleEncodingSet()) { + option.putSignerProperty(AwsV4HttpSigner.DOUBLE_URL_ENCODE, !v4AuthScheme.disableDoubleEncoding()); + } + if (v4AuthScheme.signingRegion() != null) { + option.putSignerProperty(AwsV4HttpSigner.REGION_NAME, v4AuthScheme.signingRegion()); + } + if (v4AuthScheme.signingName() != null) { + option.putSignerProperty(AwsV4HttpSigner.SERVICE_SIGNING_NAME, v4AuthScheme.signingName()); + } + return new SelectedAuthScheme<>(selectedAuthScheme.identity(), selectedAuthScheme.signer(), option.build()); + } + if (endpointAuthScheme instanceof SigV4aAuthScheme) { + SigV4aAuthScheme v4aAuthScheme = (SigV4aAuthScheme) endpointAuthScheme; + if (v4aAuthScheme.isDisableDoubleEncodingSet()) { + option.putSignerProperty(AwsV4aHttpSigner.DOUBLE_URL_ENCODE, !v4aAuthScheme.disableDoubleEncoding()); + } + if (!(selectedAuthScheme.authSchemeOption().schemeId().equals(AwsV4aAuthScheme.SCHEME_ID) && selectedAuthScheme + .authSchemeOption().signerProperty(AwsV4aHttpSigner.REGION_SET) != null) + && !CollectionUtils.isNullOrEmpty(v4aAuthScheme.signingRegionSet())) { + RegionSet regionSet = RegionSet.create(v4aAuthScheme.signingRegionSet()); + option.putSignerProperty(AwsV4aHttpSigner.REGION_SET, regionSet); + } + if (v4aAuthScheme.signingName() != null) { + option.putSignerProperty(AwsV4aHttpSigner.SERVICE_SIGNING_NAME, v4aAuthScheme.signingName()); + } + return new SelectedAuthScheme<>(selectedAuthScheme.identity(), selectedAuthScheme.signer(), option.build()); + } + throw new IllegalArgumentException("Endpoint auth scheme '" + endpointAuthScheme.name() + + "' cannot be mapped to the SDK auth scheme. Was it declared in the service's model?"); + } + return selectedAuthScheme; + } + + private static void setClientContextParams(QueryEndpointParams.Builder params, ExecutionAttributes executionAttributes) { + AttributeMap clientContextParams = executionAttributes.getAttribute(SdkInternalExecutionAttribute.CLIENT_CONTEXT_PARAMS); + Optional.ofNullable(clientContextParams.get(QueryClientContextParams.BOOLEAN_CONTEXT_PARAM)).ifPresent( + params::booleanContextParam); + Optional.ofNullable(clientContextParams.get(QueryClientContextParams.STRING_CONTEXT_PARAM)).ifPresent( + params::stringContextParam); + } + + private static void setOperationContextParams(QueryEndpointParams.Builder params, String operationName, SdkRequest request) { + switch (operationName) { + case "OperationWithMapOperationContextParam": + setOperationContextParams(params, (OperationWithMapOperationContextParamRequest) request); + break; + case "OperationWithOperationContextParam": + setOperationContextParams(params, (OperationWithOperationContextParamRequest) request); + break; + default: + break; + } + } + + private static void setOperationContextParams(QueryEndpointParams.Builder params, + OperationWithMapOperationContextParamRequest request) { + JmesPathRuntime.Value input = new JmesPathRuntime.Value(request); + params.arnList(input.field("RequestMap").keys()); + } + + private static void setOperationContextParams(QueryEndpointParams.Builder params, + OperationWithOperationContextParamRequest request) { + JmesPathRuntime.Value input = new JmesPathRuntime.Value(request); + params.customEndpointArray(input.field("ListMember").field("StringList").wildcard().field("LeafString")); + } + + public static Optional hostPrefix(String operationName, SdkRequest request) { + switch (operationName) { + case "APostOperation": { + return Optional.of("foo-"); + } + default: + return Optional.empty(); + } + } + + private static String resolveAndRecordAccountIdFromIdentity(ExecutionAttributes executionAttributes) { + String accountId = accountIdFromIdentity(executionAttributes + .getAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME)); + if (accountId != null) { + executionAttributes.getAttribute(SdkInternalExecutionAttribute.BUSINESS_METRICS).addMetric( + BusinessMetricFeatureId.RESOLVED_ACCOUNT_ID.value()); + } + return accountId; + } + + private static String accountIdFromIdentity(SelectedAuthScheme selectedAuthScheme) { + T identity = CompletableFutureUtils.joinLikeSync(selectedAuthScheme.identity()); + String accountId = null; + if (identity instanceof AwsCredentialsIdentity) { + accountId = ((AwsCredentialsIdentity) identity).accountId().orElse(null); + } + return accountId; + } + + private static String recordAccountIdEndpointMode(ExecutionAttributes executionAttributes) { + AccountIdEndpointMode mode = executionAttributes.getAttribute(AwsExecutionAttribute.AWS_AUTH_ACCOUNT_ID_ENDPOINT_MODE); + BusinessMetricsUtils.resolveAccountIdEndpointModeMetric(mode).ifPresent( + m -> executionAttributes.getAttribute(SdkInternalExecutionAttribute.BUSINESS_METRICS).addMetric(m)); + return mode.name().toLowerCase(); + } + + public static void setMetricValues(Endpoint endpoint, ExecutionAttributes executionAttributes) { + if (endpoint.attribute(AwsEndpointAttribute.METRIC_VALUES) != null) { + executionAttributes.getOptionalAttribute(SdkInternalExecutionAttribute.BUSINESS_METRICS).ifPresent( + metrics -> endpoint.attribute(AwsEndpointAttribute.METRIC_VALUES).forEach(v -> metrics.addMetric(v))); + } + } +} diff --git a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/rules/endpoint-resolver-utils-with-multiauthsigv4a.java b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/rules/endpoint-resolver-utils-with-multiauthsigv4a.java new file mode 100644 index 000000000000..f4a83932a826 --- /dev/null +++ b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/rules/endpoint-resolver-utils-with-multiauthsigv4a.java @@ -0,0 +1,104 @@ +package software.amazon.awssdk.services.database.endpoints.internal; + +import java.util.List; +import java.util.Optional; +import software.amazon.awssdk.annotations.Generated; +import software.amazon.awssdk.annotations.SdkInternalApi; +import software.amazon.awssdk.awscore.AwsExecutionAttribute; +import software.amazon.awssdk.awscore.endpoints.AwsEndpointAttribute; +import software.amazon.awssdk.awscore.endpoints.AwsEndpointProviderUtils; +import software.amazon.awssdk.awscore.endpoints.authscheme.EndpointAuthScheme; +import software.amazon.awssdk.awscore.endpoints.authscheme.SigV4AuthScheme; +import software.amazon.awssdk.awscore.endpoints.authscheme.SigV4aAuthScheme; +import software.amazon.awssdk.core.SdkRequest; +import software.amazon.awssdk.core.SelectedAuthScheme; +import software.amazon.awssdk.core.interceptor.ExecutionAttributes; +import software.amazon.awssdk.core.interceptor.SdkInternalExecutionAttribute; +import software.amazon.awssdk.endpoints.Endpoint; +import software.amazon.awssdk.http.auth.aws.scheme.AwsV4aAuthScheme; +import software.amazon.awssdk.http.auth.aws.signer.AwsV4HttpSigner; +import software.amazon.awssdk.http.auth.aws.signer.AwsV4aHttpSigner; +import software.amazon.awssdk.http.auth.aws.signer.RegionSet; +import software.amazon.awssdk.http.auth.spi.scheme.AuthSchemeOption; +import software.amazon.awssdk.identity.spi.Identity; +import software.amazon.awssdk.services.database.endpoints.DatabaseEndpointParams; +import software.amazon.awssdk.utils.CollectionUtils; + +@Generated("software.amazon.awssdk:codegen") +@SdkInternalApi +public final class DatabaseEndpointResolverUtils { + private DatabaseEndpointResolverUtils() { + } + + public static DatabaseEndpointParams ruleParams(SdkRequest request, ExecutionAttributes executionAttributes) { + DatabaseEndpointParams.Builder builder = DatabaseEndpointParams.builder(); + builder.region(AwsEndpointProviderUtils.regionBuiltIn(executionAttributes)); + builder.endpoint(AwsEndpointProviderUtils.endpointBuiltIn(executionAttributes)); + setContextParams(builder, executionAttributes.getAttribute(AwsExecutionAttribute.OPERATION_NAME), request); + setStaticContextParams(builder, executionAttributes.getAttribute(AwsExecutionAttribute.OPERATION_NAME)); + setOperationContextParams(builder, executionAttributes.getAttribute(AwsExecutionAttribute.OPERATION_NAME), request); + return builder.build(); + } + + private static void setContextParams(DatabaseEndpointParams.Builder params, String operationName, SdkRequest request) { + } + + private static void setStaticContextParams(DatabaseEndpointParams.Builder params, String operationName) { + } + + public static SelectedAuthScheme authSchemeWithEndpointSignerProperties( + List endpointAuthSchemes, SelectedAuthScheme selectedAuthScheme) { + for (EndpointAuthScheme endpointAuthScheme : endpointAuthSchemes) { + if (!endpointAuthScheme.schemeId().equals(selectedAuthScheme.authSchemeOption().schemeId())) { + continue; + } + AuthSchemeOption.Builder option = selectedAuthScheme.authSchemeOption().toBuilder(); + if (endpointAuthScheme instanceof SigV4AuthScheme) { + SigV4AuthScheme v4AuthScheme = (SigV4AuthScheme) endpointAuthScheme; + if (v4AuthScheme.isDisableDoubleEncodingSet()) { + option.putSignerProperty(AwsV4HttpSigner.DOUBLE_URL_ENCODE, !v4AuthScheme.disableDoubleEncoding()); + } + if (v4AuthScheme.signingRegion() != null) { + option.putSignerProperty(AwsV4HttpSigner.REGION_NAME, v4AuthScheme.signingRegion()); + } + if (v4AuthScheme.signingName() != null) { + option.putSignerProperty(AwsV4HttpSigner.SERVICE_SIGNING_NAME, v4AuthScheme.signingName()); + } + return new SelectedAuthScheme<>(selectedAuthScheme.identity(), selectedAuthScheme.signer(), option.build()); + } + if (endpointAuthScheme instanceof SigV4aAuthScheme) { + SigV4aAuthScheme v4aAuthScheme = (SigV4aAuthScheme) endpointAuthScheme; + if (v4aAuthScheme.isDisableDoubleEncodingSet()) { + option.putSignerProperty(AwsV4aHttpSigner.DOUBLE_URL_ENCODE, !v4aAuthScheme.disableDoubleEncoding()); + } + if (!(selectedAuthScheme.authSchemeOption().schemeId().equals(AwsV4aAuthScheme.SCHEME_ID) && selectedAuthScheme + .authSchemeOption().signerProperty(AwsV4aHttpSigner.REGION_SET) != null) + && !CollectionUtils.isNullOrEmpty(v4aAuthScheme.signingRegionSet())) { + RegionSet regionSet = RegionSet.create(v4aAuthScheme.signingRegionSet()); + option.putSignerProperty(AwsV4aHttpSigner.REGION_SET, regionSet); + } + if (v4aAuthScheme.signingName() != null) { + option.putSignerProperty(AwsV4aHttpSigner.SERVICE_SIGNING_NAME, v4aAuthScheme.signingName()); + } + return new SelectedAuthScheme<>(selectedAuthScheme.identity(), selectedAuthScheme.signer(), option.build()); + } + throw new IllegalArgumentException("Endpoint auth scheme '" + endpointAuthScheme.name() + + "' cannot be mapped to the SDK auth scheme. Was it declared in the service's model?"); + } + return selectedAuthScheme; + } + + private static void setOperationContextParams(DatabaseEndpointParams.Builder params, String operationName, SdkRequest request) { + } + + public static Optional hostPrefix(String operationName, SdkRequest request) { + return Optional.empty(); + } + + public static void setMetricValues(Endpoint endpoint, ExecutionAttributes executionAttributes) { + if (endpoint.attribute(AwsEndpointAttribute.METRIC_VALUES) != null) { + executionAttributes.getOptionalAttribute(SdkInternalExecutionAttribute.BUSINESS_METRICS).ifPresent( + metrics -> endpoint.attribute(AwsEndpointAttribute.METRIC_VALUES).forEach(v -> metrics.addMetric(v))); + } + } +} diff --git a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/rules/endpoint-resolver-utils-with-stringarray.java b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/rules/endpoint-resolver-utils-with-stringarray.java new file mode 100644 index 000000000000..aaed43624739 --- /dev/null +++ b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/rules/endpoint-resolver-utils-with-stringarray.java @@ -0,0 +1,131 @@ +package software.amazon.awssdk.services.samplesvc.endpoints.internal; + +import java.util.Arrays; +import java.util.List; +import java.util.Optional; +import software.amazon.awssdk.annotations.Generated; +import software.amazon.awssdk.annotations.SdkInternalApi; +import software.amazon.awssdk.awscore.AwsExecutionAttribute; +import software.amazon.awssdk.awscore.endpoints.AwsEndpointAttribute; +import software.amazon.awssdk.awscore.endpoints.authscheme.EndpointAuthScheme; +import software.amazon.awssdk.awscore.endpoints.authscheme.SigV4AuthScheme; +import software.amazon.awssdk.awscore.endpoints.authscheme.SigV4aAuthScheme; +import software.amazon.awssdk.core.SdkRequest; +import software.amazon.awssdk.core.SelectedAuthScheme; +import software.amazon.awssdk.core.interceptor.ExecutionAttributes; +import software.amazon.awssdk.core.interceptor.SdkInternalExecutionAttribute; +import software.amazon.awssdk.endpoints.Endpoint; +import software.amazon.awssdk.http.auth.aws.signer.AwsV4HttpSigner; +import software.amazon.awssdk.http.auth.aws.signer.AwsV4aHttpSigner; +import software.amazon.awssdk.http.auth.aws.signer.RegionSet; +import software.amazon.awssdk.http.auth.spi.scheme.AuthSchemeOption; +import software.amazon.awssdk.identity.spi.Identity; +import software.amazon.awssdk.services.samplesvc.endpoints.SampleSvcEndpointParams; +import software.amazon.awssdk.services.samplesvc.jmespath.internal.JmesPathRuntime; +import software.amazon.awssdk.services.samplesvc.model.ListOfObjectsOperationRequest; +import software.amazon.awssdk.utils.CollectionUtils; + +@Generated("software.amazon.awssdk:codegen") +@SdkInternalApi +public final class SampleSvcEndpointResolverUtils { + private SampleSvcEndpointResolverUtils() { + } + + public static SampleSvcEndpointParams ruleParams(SdkRequest request, + ExecutionAttributes executionAttributes) { + SampleSvcEndpointParams.Builder builder = SampleSvcEndpointParams.builder(); + setContextParams(builder, executionAttributes.getAttribute(AwsExecutionAttribute.OPERATION_NAME), request); + setStaticContextParams(builder, executionAttributes.getAttribute(AwsExecutionAttribute.OPERATION_NAME)); + setOperationContextParams(builder, executionAttributes.getAttribute(AwsExecutionAttribute.OPERATION_NAME), request); + return builder.build(); + } + + private static void setContextParams(SampleSvcEndpointParams.Builder params, String operationName, + SdkRequest request) { + } + + private static void setStaticContextParams(SampleSvcEndpointParams.Builder params, + String operationName) { + switch (operationName) { + case "EmptyStaticContextOperation":emptyStaticContextOperationStaticContextParams(params); + break; + case "StaticContextOperation":staticContextOperationStaticContextParams(params); + break; + default:break; + } + } + + private static void emptyStaticContextOperationStaticContextParams( + SampleSvcEndpointParams.Builder params) { + params.stringArrayParam(Arrays.asList()); + } + + private static void staticContextOperationStaticContextParams( + SampleSvcEndpointParams.Builder params) { + params.stringArrayParam(Arrays.asList("staticValue1")); + } + + public static SelectedAuthScheme authSchemeWithEndpointSignerProperties( + List endpointAuthSchemes, SelectedAuthScheme selectedAuthScheme) { + for (EndpointAuthScheme endpointAuthScheme : endpointAuthSchemes) { + if (!endpointAuthScheme.schemeId().equals(selectedAuthScheme.authSchemeOption().schemeId())) { + continue; + } + AuthSchemeOption.Builder option = selectedAuthScheme.authSchemeOption().toBuilder(); + if (endpointAuthScheme instanceof SigV4AuthScheme) { + SigV4AuthScheme v4AuthScheme = (SigV4AuthScheme) endpointAuthScheme; + if (v4AuthScheme.isDisableDoubleEncodingSet()) { + option.putSignerProperty(AwsV4HttpSigner.DOUBLE_URL_ENCODE, !v4AuthScheme.disableDoubleEncoding()); + } + if (v4AuthScheme.signingRegion() != null) { + option.putSignerProperty(AwsV4HttpSigner.REGION_NAME, v4AuthScheme.signingRegion()); + } + if (v4AuthScheme.signingName() != null) { + option.putSignerProperty(AwsV4HttpSigner.SERVICE_SIGNING_NAME, v4AuthScheme.signingName()); + } + return new SelectedAuthScheme<>(selectedAuthScheme.identity(), selectedAuthScheme.signer(), option.build()); + } + if (endpointAuthScheme instanceof SigV4aAuthScheme) { + SigV4aAuthScheme v4aAuthScheme = (SigV4aAuthScheme) endpointAuthScheme; + if (v4aAuthScheme.isDisableDoubleEncodingSet()) { + option.putSignerProperty(AwsV4aHttpSigner.DOUBLE_URL_ENCODE, !v4aAuthScheme.disableDoubleEncoding()); + } + if (!CollectionUtils.isNullOrEmpty(v4aAuthScheme.signingRegionSet())) { + RegionSet regionSet = RegionSet.create(v4aAuthScheme.signingRegionSet()); + option.putSignerProperty(AwsV4aHttpSigner.REGION_SET, regionSet); + } + if (v4aAuthScheme.signingName() != null) { + option.putSignerProperty(AwsV4aHttpSigner.SERVICE_SIGNING_NAME, v4aAuthScheme.signingName()); + } + return new SelectedAuthScheme<>(selectedAuthScheme.identity(), selectedAuthScheme.signer(), option.build()); + } + throw new IllegalArgumentException("Endpoint auth scheme '" + endpointAuthScheme.name() + "' cannot be mapped to the SDK auth scheme. Was it declared in the service's model?"); + } + return selectedAuthScheme; + } + + private static void setOperationContextParams(SampleSvcEndpointParams.Builder params, + String operationName, SdkRequest request) { + switch (operationName) { + case "ListOfObjectsOperation":setOperationContextParams(params, (ListOfObjectsOperationRequest) request); + break; + default:break; + } + } + + private static void setOperationContextParams(SampleSvcEndpointParams.Builder params, + ListOfObjectsOperationRequest request) { + JmesPathRuntime.Value input = new JmesPathRuntime.Value(request); + params.stringArrayParam(input.field("nested").field("listOfObjects").wildcard().field("key").stringValues()); + } + + public static Optional hostPrefix(String operationName, SdkRequest request) { + return Optional.empty(); + } + + public static void setMetricValues(Endpoint endpoint, ExecutionAttributes executionAttributes) { + if (endpoint.attribute(AwsEndpointAttribute.METRIC_VALUES) != null) { + executionAttributes.getOptionalAttribute(SdkInternalExecutionAttribute.BUSINESS_METRICS).ifPresent(metrics -> endpoint.attribute(AwsEndpointAttribute.METRIC_VALUES).forEach(v -> metrics.addMetric(v))); + } + } +} diff --git a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/rules/endpoint-resolver-utils.java b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/rules/endpoint-resolver-utils.java new file mode 100644 index 000000000000..01c7ba43831f --- /dev/null +++ b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/rules/endpoint-resolver-utils.java @@ -0,0 +1,210 @@ +package software.amazon.awssdk.services.query.endpoints.internal; + +import java.util.List; +import java.util.Optional; +import software.amazon.awssdk.annotations.Generated; +import software.amazon.awssdk.annotations.SdkInternalApi; +import software.amazon.awssdk.awscore.AwsExecutionAttribute; +import software.amazon.awssdk.awscore.endpoints.AccountIdEndpointMode; +import software.amazon.awssdk.awscore.endpoints.AwsEndpointAttribute; +import software.amazon.awssdk.awscore.endpoints.AwsEndpointProviderUtils; +import software.amazon.awssdk.awscore.endpoints.authscheme.EndpointAuthScheme; +import software.amazon.awssdk.awscore.endpoints.authscheme.SigV4AuthScheme; +import software.amazon.awssdk.awscore.endpoints.authscheme.SigV4aAuthScheme; +import software.amazon.awssdk.awscore.internal.useragent.BusinessMetricsUtils; +import software.amazon.awssdk.core.SdkRequest; +import software.amazon.awssdk.core.SelectedAuthScheme; +import software.amazon.awssdk.core.interceptor.ExecutionAttributes; +import software.amazon.awssdk.core.interceptor.SdkInternalExecutionAttribute; +import software.amazon.awssdk.core.useragent.BusinessMetricFeatureId; +import software.amazon.awssdk.endpoints.Endpoint; +import software.amazon.awssdk.http.auth.aws.signer.AwsV4HttpSigner; +import software.amazon.awssdk.http.auth.aws.signer.AwsV4aHttpSigner; +import software.amazon.awssdk.http.auth.aws.signer.RegionSet; +import software.amazon.awssdk.http.auth.spi.scheme.AuthSchemeOption; +import software.amazon.awssdk.identity.spi.AwsCredentialsIdentity; +import software.amazon.awssdk.identity.spi.Identity; +import software.amazon.awssdk.services.query.endpoints.QueryClientContextParams; +import software.amazon.awssdk.services.query.endpoints.QueryEndpointParams; +import software.amazon.awssdk.services.query.jmespath.internal.JmesPathRuntime; +import software.amazon.awssdk.services.query.model.OperationWithContextParamRequest; +import software.amazon.awssdk.services.query.model.OperationWithCustomizedOperationContextParamRequest; +import software.amazon.awssdk.services.query.model.OperationWithMapOperationContextParamRequest; +import software.amazon.awssdk.services.query.model.OperationWithOperationContextParamRequest; +import software.amazon.awssdk.utils.AttributeMap; +import software.amazon.awssdk.utils.CollectionUtils; +import software.amazon.awssdk.utils.CompletableFutureUtils; + +@Generated("software.amazon.awssdk:codegen") +@SdkInternalApi +public final class QueryEndpointResolverUtils { + private QueryEndpointResolverUtils() { + } + + public static QueryEndpointParams ruleParams(SdkRequest request, ExecutionAttributes executionAttributes) { + QueryEndpointParams.Builder builder = QueryEndpointParams.builder(); + builder.region(AwsEndpointProviderUtils.regionBuiltIn(executionAttributes)); + builder.useDualStackEndpoint(AwsEndpointProviderUtils.dualStackEnabledBuiltIn(executionAttributes)); + builder.useFipsEndpoint(AwsEndpointProviderUtils.fipsEnabledBuiltIn(executionAttributes)); + builder.accountId(resolveAndRecordAccountIdFromIdentity(executionAttributes)); + builder.accountIdEndpointMode(recordAccountIdEndpointMode(executionAttributes)); + setClientContextParams(builder, executionAttributes); + setContextParams(builder, executionAttributes.getAttribute(AwsExecutionAttribute.OPERATION_NAME), request); + setStaticContextParams(builder, executionAttributes.getAttribute(AwsExecutionAttribute.OPERATION_NAME)); + setOperationContextParams(builder, executionAttributes.getAttribute(AwsExecutionAttribute.OPERATION_NAME), request); + return builder.build(); + } + + private static void setContextParams(QueryEndpointParams.Builder params, String operationName, SdkRequest request) { + switch (operationName) { + case "OperationWithContextParam": + setContextParams(params, (OperationWithContextParamRequest) request); + break; + default: + break; + } + } + + private static void setContextParams(QueryEndpointParams.Builder params, OperationWithContextParamRequest request) { + params.operationContextParam(request.stringMember()); + } + + private static void setStaticContextParams(QueryEndpointParams.Builder params, String operationName) { + switch (operationName) { + case "OperationWithStaticContextParams": + operationWithStaticContextParamsStaticContextParams(params); + break; + default: + break; + } + } + + private static void operationWithStaticContextParamsStaticContextParams(QueryEndpointParams.Builder params) { + params.staticStringParam("hello"); + } + + public static SelectedAuthScheme authSchemeWithEndpointSignerProperties( + List endpointAuthSchemes, SelectedAuthScheme selectedAuthScheme) { + for (EndpointAuthScheme endpointAuthScheme : endpointAuthSchemes) { + if (!endpointAuthScheme.schemeId().equals(selectedAuthScheme.authSchemeOption().schemeId())) { + continue; + } + AuthSchemeOption.Builder option = selectedAuthScheme.authSchemeOption().toBuilder(); + if (endpointAuthScheme instanceof SigV4AuthScheme) { + SigV4AuthScheme v4AuthScheme = (SigV4AuthScheme) endpointAuthScheme; + if (v4AuthScheme.isDisableDoubleEncodingSet()) { + option.putSignerProperty(AwsV4HttpSigner.DOUBLE_URL_ENCODE, !v4AuthScheme.disableDoubleEncoding()); + } + if (v4AuthScheme.signingRegion() != null) { + option.putSignerProperty(AwsV4HttpSigner.REGION_NAME, v4AuthScheme.signingRegion()); + } + if (v4AuthScheme.signingName() != null) { + option.putSignerProperty(AwsV4HttpSigner.SERVICE_SIGNING_NAME, v4AuthScheme.signingName()); + } + return new SelectedAuthScheme<>(selectedAuthScheme.identity(), selectedAuthScheme.signer(), option.build()); + } + if (endpointAuthScheme instanceof SigV4aAuthScheme) { + SigV4aAuthScheme v4aAuthScheme = (SigV4aAuthScheme) endpointAuthScheme; + if (v4aAuthScheme.isDisableDoubleEncodingSet()) { + option.putSignerProperty(AwsV4aHttpSigner.DOUBLE_URL_ENCODE, !v4aAuthScheme.disableDoubleEncoding()); + } + if (!CollectionUtils.isNullOrEmpty(v4aAuthScheme.signingRegionSet())) { + RegionSet regionSet = RegionSet.create(v4aAuthScheme.signingRegionSet()); + option.putSignerProperty(AwsV4aHttpSigner.REGION_SET, regionSet); + } + if (v4aAuthScheme.signingName() != null) { + option.putSignerProperty(AwsV4aHttpSigner.SERVICE_SIGNING_NAME, v4aAuthScheme.signingName()); + } + return new SelectedAuthScheme<>(selectedAuthScheme.identity(), selectedAuthScheme.signer(), option.build()); + } + throw new IllegalArgumentException("Endpoint auth scheme '" + endpointAuthScheme.name() + + "' cannot be mapped to the SDK auth scheme. Was it declared in the service's model?"); + } + return selectedAuthScheme; + } + + private static void setClientContextParams(QueryEndpointParams.Builder params, ExecutionAttributes executionAttributes) { + AttributeMap clientContextParams = executionAttributes.getAttribute(SdkInternalExecutionAttribute.CLIENT_CONTEXT_PARAMS); + Optional.ofNullable(clientContextParams.get(QueryClientContextParams.BOOLEAN_CONTEXT_PARAM)).ifPresent( + params::booleanContextParam); + Optional.ofNullable(clientContextParams.get(QueryClientContextParams.STRING_CONTEXT_PARAM)).ifPresent( + params::stringContextParam); + } + + private static void setOperationContextParams(QueryEndpointParams.Builder params, String operationName, SdkRequest request) { + switch (operationName) { + case "OperationWithCustomizedOperationContextParam": + setOperationContextParams(params, (OperationWithCustomizedOperationContextParamRequest) request); + break; + case "OperationWithMapOperationContextParam": + setOperationContextParams(params, (OperationWithMapOperationContextParamRequest) request); + break; + case "OperationWithOperationContextParam": + setOperationContextParams(params, (OperationWithOperationContextParamRequest) request); + break; + default: + break; + } + } + + private static void setOperationContextParams(QueryEndpointParams.Builder params, + OperationWithCustomizedOperationContextParamRequest request) { + JmesPathRuntime.Value input = new JmesPathRuntime.Value(request); + params.customEndpointArray(input.field("ListMember").field("StringList").wildcard().field("LeafString").stringValues()); + } + + private static void setOperationContextParams(QueryEndpointParams.Builder params, + OperationWithMapOperationContextParamRequest request) { + JmesPathRuntime.Value input = new JmesPathRuntime.Value(request); + params.arnList(input.field("RequestMap").keys().stringValues()); + } + + private static void setOperationContextParams(QueryEndpointParams.Builder params, + OperationWithOperationContextParamRequest request) { + JmesPathRuntime.Value input = new JmesPathRuntime.Value(request); + params.customEndpointArray(input.field("ListMember").field("StringList").wildcard().field("LeafString").stringValues()); + } + + public static Optional hostPrefix(String operationName, SdkRequest request) { + switch (operationName) { + case "APostOperation": { + return Optional.of("foo-"); + } + default: + return Optional.empty(); + } + } + + private static String resolveAndRecordAccountIdFromIdentity(ExecutionAttributes executionAttributes) { + String accountId = accountIdFromIdentity(executionAttributes + .getAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME)); + if (accountId != null) { + executionAttributes.getAttribute(SdkInternalExecutionAttribute.BUSINESS_METRICS).addMetric( + BusinessMetricFeatureId.RESOLVED_ACCOUNT_ID.value()); + } + return accountId; + } + + private static String accountIdFromIdentity(SelectedAuthScheme selectedAuthScheme) { + T identity = CompletableFutureUtils.joinLikeSync(selectedAuthScheme.identity()); + String accountId = null; + if (identity instanceof AwsCredentialsIdentity) { + accountId = ((AwsCredentialsIdentity) identity).accountId().orElse(null); + } + return accountId; + } + + private static String recordAccountIdEndpointMode(ExecutionAttributes executionAttributes) { + AccountIdEndpointMode mode = executionAttributes.getAttribute(AwsExecutionAttribute.AWS_AUTH_ACCOUNT_ID_ENDPOINT_MODE); + BusinessMetricsUtils.resolveAccountIdEndpointModeMetric(mode).ifPresent( + m -> executionAttributes.getAttribute(SdkInternalExecutionAttribute.BUSINESS_METRICS).addMetric(m)); + return mode.name().toLowerCase(); + } + + public static void setMetricValues(Endpoint endpoint, ExecutionAttributes executionAttributes) { + if (endpoint.attribute(AwsEndpointAttribute.METRIC_VALUES) != null) { + executionAttributes.getOptionalAttribute(SdkInternalExecutionAttribute.BUSINESS_METRICS).ifPresent( + metrics -> endpoint.attribute(AwsEndpointAttribute.METRIC_VALUES).forEach(v -> metrics.addMetric(v))); + } + } +} diff --git a/codegen/src/main/resources/software/amazon/awssdk/codegen/rules/AwsEndpointProviderUtils.java.resource b/core/aws-core/src/main/java/software/amazon/awssdk/awscore/endpoints/AwsEndpointProviderUtils.java similarity index 75% rename from codegen/src/main/resources/software/amazon/awssdk/codegen/rules/AwsEndpointProviderUtils.java.resource rename to core/aws-core/src/main/java/software/amazon/awssdk/awscore/endpoints/AwsEndpointProviderUtils.java index 1d0fedb9f7ff..53555c1adb53 100644 --- a/codegen/src/main/resources/software/amazon/awssdk/codegen/rules/AwsEndpointProviderUtils.java.resource +++ b/core/aws-core/src/main/java/software/amazon/awssdk/awscore/endpoints/AwsEndpointProviderUtils.java @@ -1,27 +1,41 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.awscore.endpoints; + import static software.amazon.awssdk.utils.FunctionalUtils.invokeSafely; import java.net.URI; -import java.util.List; -import java.util.Map; -import software.amazon.awssdk.annotations.SdkInternalApi; +import software.amazon.awssdk.annotations.SdkProtectedApi; import software.amazon.awssdk.awscore.AwsExecutionAttribute; -import software.amazon.awssdk.awscore.endpoints.AwsEndpointAttribute; -import software.amazon.awssdk.core.exception.SdkClientException; import software.amazon.awssdk.core.interceptor.ExecutionAttributes; -import software.amazon.awssdk.core.interceptor.SdkExecutionAttribute; import software.amazon.awssdk.core.interceptor.SdkInternalExecutionAttribute; import software.amazon.awssdk.core.useragent.BusinessMetricFeatureId; import software.amazon.awssdk.endpoints.Endpoint; import software.amazon.awssdk.http.SdkHttpRequest; import software.amazon.awssdk.regions.Region; import software.amazon.awssdk.utils.HostnameValidator; -import software.amazon.awssdk.utils.Logger; import software.amazon.awssdk.utils.StringUtils; import software.amazon.awssdk.utils.http.SdkHttpUtils; -@SdkInternalApi +/** + * Shared utility methods for endpoint resolution across all AWS services. + * Previously this was generated per-service as an identical copy; now de-duplicated here. + */ +@SdkProtectedApi public final class AwsEndpointProviderUtils { - private static final Logger LOG = Logger.loggerFor(AwsEndpointProviderUtils.class); private AwsEndpointProviderUtils() { } @@ -57,14 +71,14 @@ public static String endpointBuiltIn(ExecutionAttributes executionAttributes) { } /** - * Read {@link SdkExecutionAttribute#CLIENT_ENDPOINT_PROVIDER}'s isEndpointOverridden attribute. + * True if the {@link SdkInternalExecutionAttribute#CLIENT_ENDPOINT_PROVIDER}'s endpoint is overridden. */ public static boolean endpointIsOverridden(ExecutionAttributes attrs) { return attrs.getAttribute(SdkInternalExecutionAttribute.CLIENT_ENDPOINT_PROVIDER).isEndpointOverridden(); } /** - * True if the the {@link SdkInternalExecutionAttribute#IS_DISCOVERED_ENDPOINT} attribute is present and its value + * True if the {@link SdkInternalExecutionAttribute#IS_DISCOVERED_ENDPOINT} attribute is present and its value * is {@code true}, {@code false} otherwise. */ public static boolean endpointIsDiscovered(ExecutionAttributes attrs) { @@ -72,7 +86,7 @@ public static boolean endpointIsDiscovered(ExecutionAttributes attrs) { } /** - * True if the the {@link SdkInternalExecutionAttribute#DISABLE_HOST_PREFIX_INJECTION} attribute is present and its + * True if the {@link SdkInternalExecutionAttribute#DISABLE_HOST_PREFIX_INJECTION} attribute is present and its * value is {@code true}, {@code false} otherwise. */ public static boolean disableHostPrefixInjection(ExecutionAttributes attrs) { @@ -86,14 +100,11 @@ public static Endpoint addHostPrefix(Endpoint endpoint, String prefix) { if (StringUtils.isBlank(prefix)) { return endpoint; } - validatePrefixIsHostNameCompliant(prefix); - URI originalUrl = endpoint.url(); String newHost = prefix + endpoint.url().getHost(); URI newUrl = invokeSafely(() -> new URI(originalUrl.getScheme(), null, newHost, originalUrl.getPort(), originalUrl.getPath(), originalUrl.getQuery(), originalUrl.getFragment())); - return endpoint.toBuilder().url(newUrl).build(); } @@ -107,7 +118,7 @@ public static Endpoint addHostPrefix(Endpoint endpoint, String prefix) { * no way to know, just from the HTTP request object, where the override path ends (if it's even there) and where * the request path starts. Additionally, the rule itself may also append other parts to the endpoint override path. *

- * To solve this issue, we pass in the endpoint set on the path, which allows us to the strip the path from the + * To solve this issue, we pass in the endpoint set on the client, which allows us to the strip the path from the * endpoint override from the request path, and then correctly combine the paths. *

* For example, let's suppose the endpoint override on the client is {@code https://example.com/a}. Then we call an @@ -117,19 +128,11 @@ public static Endpoint addHostPrefix(Endpoint endpoint, String prefix) { * path is {@code https://example.com/a/b/c}. */ public static SdkHttpRequest setUri(SdkHttpRequest request, URI clientEndpoint, URI resolvedUri) { - // [client endpoint path] String clientEndpointPath = clientEndpoint.getRawPath(); - - // [client endpoint path]/[request path] String requestPath = request.encodedPath(); - - // [client endpoint path]/[additional path added by resolver] String resolvedUriPath = resolvedUri.getRawPath(); String finalPath = requestPath; - - // If there is an additional path added by resolver, i.e., [additional path added by resolver] not null, - // we need to combine the path if (!resolvedUriPath.equals(clientEndpointPath)) { finalPath = combinePath(clientEndpointPath, requestPath, resolvedUriPath); } @@ -139,26 +142,18 @@ public static SdkHttpRequest setUri(SdkHttpRequest request, URI clientEndpoint, } /** - * Our goal is to construct [client endpoint path]/[additional path added by resolver]/[request path], so we just - * need to strip the client endpoint path from the marshalled request path to isolate just the part added by the - * marshaller. Trailing slash is removed from client endpoint path before stripping because it could cause the - * leading slash to be removed from the request path: e.g., StringUtils.replaceOnce("/", "//test", "") generates - * "/test" and the expected result is "//test" + * Constructs [resolved URI path]/[request path without client endpoint path]. Strips the client endpoint path from + * the marshalled request path to isolate just the part added by the marshaller, then appends it to the resolved path. */ private static String combinePath(String clientEndpointPath, String requestPath, String resolvedUriPath) { String requestPathWithClientPathRemoved = StringUtils.replaceOnce(requestPath, clientEndpointPath, ""); - String finalPath = SdkHttpUtils.appendUri(resolvedUriPath, requestPathWithClientPathRemoved); - return finalPath; + return SdkHttpUtils.appendUri(resolvedUriPath, requestPathWithClientPathRemoved); } private static void validatePrefixIsHostNameCompliant(String prefix) { - String[] components = splitHostLabelOnDots(prefix); + String[] components = prefix.split("\\."); for (String component : components) { HostnameValidator.validateHostnameCompliant(component, component, "request"); } } - - private static String[] splitHostLabelOnDots(String label) { - return label.split("\\."); - } } diff --git a/core/aws-core/src/main/java/software/amazon/awssdk/awscore/internal/AwsExecutionContextBuilder.java b/core/aws-core/src/main/java/software/amazon/awssdk/awscore/internal/AwsExecutionContextBuilder.java index cd82644da16f..254169ea84c3 100644 --- a/core/aws-core/src/main/java/software/amazon/awssdk/awscore/internal/AwsExecutionContextBuilder.java +++ b/core/aws-core/src/main/java/software/amazon/awssdk/awscore/internal/AwsExecutionContextBuilder.java @@ -28,12 +28,13 @@ import java.util.Map; import java.util.Optional; import software.amazon.awssdk.annotations.SdkInternalApi; +import software.amazon.awssdk.auth.credentials.AwsCredentials; import software.amazon.awssdk.auth.signer.AwsSignerExecutionAttribute; import software.amazon.awssdk.awscore.AwsExecutionAttribute; -import software.amazon.awssdk.awscore.AwsRequestOverrideConfiguration; import software.amazon.awssdk.awscore.client.config.AwsClientOption; import software.amazon.awssdk.awscore.internal.authcontext.AuthorizationStrategy; import software.amazon.awssdk.awscore.internal.authcontext.AuthorizationStrategyFactory; +import software.amazon.awssdk.awscore.internal.identity.AwsIdentityProviderUpdater; import software.amazon.awssdk.awscore.util.SignerOverrideUtils; import software.amazon.awssdk.core.HttpChecksumConstant; import software.amazon.awssdk.core.RequestOverrideConfiguration; @@ -147,9 +148,26 @@ private AwsExecutionContextBuilder() { // Auth Scheme resolution related attributes putAuthSchemeResolutionAttributes(executionAttributes, clientConfig, originalRequest); + if (executionParams.authSchemeOptionsResolver() != null) { + executionAttributes.putAttribute(SdkInternalExecutionAttribute.AUTH_SCHEME_OPTIONS_RESOLVER, + executionParams.authSchemeOptionsResolver()); + } + + if (executionParams.endpointResolver() != null) { + executionAttributes.putAttribute(SdkInternalExecutionAttribute.ENDPOINT_RESOLVER, + executionParams.endpointResolver()); + } + + // Set the identity provider updater for the pipeline stage to use + executionAttributes.putAttribute(SdkInternalExecutionAttribute.IDENTITY_PROVIDER_UPDATER, + AwsIdentityProviderUpdater.create()); + ExecutionInterceptorChain executionInterceptorChain = new ExecutionInterceptorChain(clientConfig.option(SdkClientOption.EXECUTION_INTERCEPTORS)); + executionAttributes.putAttribute(SdkInternalExecutionAttribute.AUTH_SCHEME_SNAPSHOT_PRE_INTERCEPTORS, + executionAttributes.getAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME)); + InterceptorContext interceptorContext = InterceptorContext.builder() .request(originalRequest) .asyncRequestBody(executionParams.getAsyncRequestBody()) @@ -173,6 +191,13 @@ private AwsExecutionContextBuilder() { signer, executionAttributes, executionAttributes.getOptionalAttribute( AwsSignerExecutionAttribute.AWS_CREDENTIALS).orElse(null))); + Signer resolvedSigner = signer; + AwsCredentials capturedCredentials = executionAttributes.getOptionalAttribute( + AwsSignerExecutionAttribute.AWS_CREDENTIALS).orElse(null); + executionAttributes.putAttribute(SdkInternalExecutionAttribute.SIGNING_METHOD_UPDATER, attrs -> + attrs.putAttribute(HttpChecksumConstant.SIGNING_METHOD, + resolveSigningMethodUsed(resolvedSigner, attrs, capturedCredentials))); + putStreamingInputOutputTypesMetadata(executionAttributes, executionParams); putHttpClientConfigTypeMetadata(executionAttributes, clientConfig); @@ -282,7 +307,7 @@ private static void putAuthSchemeResolutionAttributes(ExecutionAttributes execut // request preferred over client. Map> authSchemes = clientConfig.option(SdkClientOption.AUTH_SCHEMES); - IdentityProviders identityProviders = resolveIdentityProviders(originalRequest, clientConfig); + IdentityProviders identityProviders = clientConfig.option(SdkClientOption.IDENTITY_PROVIDERS); executionAttributes .putAttribute(SdkInternalExecutionAttribute.AUTH_SCHEME_RESOLVER, authSchemeProvider) @@ -290,31 +315,6 @@ private static void putAuthSchemeResolutionAttributes(ExecutionAttributes execut .putAttribute(SdkInternalExecutionAttribute.IDENTITY_PROVIDERS, identityProviders); } - private static IdentityProviders resolveIdentityProviders(SdkRequest originalRequest, - SdkClientConfiguration clientConfig) { - IdentityProviders identityProviders = - clientConfig.option(SdkClientOption.IDENTITY_PROVIDERS); - - // identityProviders can be null, for new core with old client. In this case, even if AwsRequestOverrideConfiguration - // has credentialsIdentityProvider set (because it is in new core), it is ok to not setup IDENTITY_PROVIDERS, as old - // client won't have AUTH_SCHEME_PROVIDER/AUTH_SCHEMES set either, which are also needed for SRA logic. - if (identityProviders == null) { - return null; - } - - return originalRequest - .overrideConfiguration() - .filter(c -> c instanceof AwsRequestOverrideConfiguration) - .map(c -> (AwsRequestOverrideConfiguration) c) - .map(c -> { - return identityProviders.copy(b -> { - c.credentialsIdentityProvider().ifPresent(b::putIdentityProvider); - c.tokenIdentityProvider().ifPresent(b::putIdentityProvider); - }); - }) - .orElse(identityProviders); - } - /** * Finalize {@link SdkRequest} by running beforeExecution and modifyRequest interceptors. * @@ -377,7 +377,7 @@ private static AuthSchemeProvider resolveAuthSchemeProvider(SdkRequest request, } private static BusinessMetricCollection - resolveUserAgentBusinessMetrics(SdkClientConfiguration clientConfig, + resolveUserAgentBusinessMetrics(SdkClientConfiguration clientConfig, ClientExecutionParams executionParams) { BusinessMetricCollection businessMetrics = new BusinessMetricCollection(); Optional retryModeMetric = resolveRetryMode(clientConfig.option(RETRY_POLICY), @@ -395,4 +395,5 @@ private static boolean isRpcV2CborProtocol(SdkProtocolMetadata protocolMetadata) return protocolMetadata != null && SMITHY_RPC_V2_CBOR.toString().equals(protocolMetadata.serviceProtocol()); } + } diff --git a/core/aws-core/src/main/java/software/amazon/awssdk/awscore/internal/identity/AwsIdentityProviderUpdater.java b/core/aws-core/src/main/java/software/amazon/awssdk/awscore/internal/identity/AwsIdentityProviderUpdater.java new file mode 100644 index 000000000000..c4b291aa48c9 --- /dev/null +++ b/core/aws-core/src/main/java/software/amazon/awssdk/awscore/internal/identity/AwsIdentityProviderUpdater.java @@ -0,0 +1,73 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.awscore.internal.identity; + +import software.amazon.awssdk.annotations.SdkInternalApi; +import software.amazon.awssdk.auth.credentials.AwsCredentials; +import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider; +import software.amazon.awssdk.auth.signer.AwsSignerExecutionAttribute; +import software.amazon.awssdk.awscore.AwsRequestOverrideConfiguration; +import software.amazon.awssdk.core.SdkRequest; +import software.amazon.awssdk.core.interceptor.ExecutionAttributes; +import software.amazon.awssdk.core.spi.identity.IdentityProviderUpdater; +import software.amazon.awssdk.identity.spi.IdentityProviders; + +/** + * AWS implementation of {@link IdentityProviderUpdater} that reads credential overrides + * from {@link AwsRequestOverrideConfiguration} and deprecated {@link AwsSignerExecutionAttribute#AWS_CREDENTIALS}. + */ +@SdkInternalApi +public final class AwsIdentityProviderUpdater implements IdentityProviderUpdater { + + private static final AwsIdentityProviderUpdater INSTANCE = new AwsIdentityProviderUpdater(); + + private AwsIdentityProviderUpdater() { + } + + public static AwsIdentityProviderUpdater create() { + return INSTANCE; + } + + @Override + public IdentityProviders update(SdkRequest request, IdentityProviders base, ExecutionAttributes executionAttributes) { + if (base == null) { + return null; + } + + IdentityProviders updated = request.overrideConfiguration() + .filter(c -> c instanceof AwsRequestOverrideConfiguration) + .map(c -> (AwsRequestOverrideConfiguration) c) + .map(c -> base.copy(b -> { + c.credentialsIdentityProvider().ifPresent(b::putIdentityProvider); + c.tokenIdentityProvider().ifPresent(b::putIdentityProvider); + })) + .orElse(null); + + if (updated != null) { + return updated; + } + + // Support deprecated AWS_CREDENTIALS execution attribute for backwards compatibility + // with interceptors that set credentials via AwsSignerExecutionAttribute.AWS_CREDENTIALS + AwsCredentials credentials = executionAttributes.getOptionalAttribute(AwsSignerExecutionAttribute.AWS_CREDENTIALS) + .orElse(null); + if (credentials != null) { + return base.copy(b -> b.putIdentityProvider(StaticCredentialsProvider.create(credentials))); + } + + return base; + } +} diff --git a/test/codegen-generated-classes-test/src/test/java/software/amazon/awssdk/services/endpointproviders/AwsEndpointProviderUtilsTest.java b/core/aws-core/src/test/java/software/amazon/awssdk/awscore/endpoints/AwsEndpointProviderUtilsTest.java similarity index 69% rename from test/codegen-generated-classes-test/src/test/java/software/amazon/awssdk/services/endpointproviders/AwsEndpointProviderUtilsTest.java rename to core/aws-core/src/test/java/software/amazon/awssdk/awscore/endpoints/AwsEndpointProviderUtilsTest.java index 83d9d1d0fb59..b1fa2b82d6a1 100644 --- a/test/codegen-generated-classes-test/src/test/java/software/amazon/awssdk/services/endpointproviders/AwsEndpointProviderUtilsTest.java +++ b/core/aws-core/src/test/java/software/amazon/awssdk/awscore/endpoints/AwsEndpointProviderUtilsTest.java @@ -13,13 +13,13 @@ * permissions and limitations under the License. */ -package software.amazon.awssdk.services.endpointproviders; +package software.amazon.awssdk.awscore.endpoints; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; import java.net.URI; -import org.junit.Test; +import org.junit.jupiter.api.Test; import software.amazon.awssdk.awscore.AwsExecutionAttribute; import software.amazon.awssdk.core.ClientEndpointProvider; import software.amazon.awssdk.core.interceptor.ExecutionAttributes; @@ -29,103 +29,147 @@ import software.amazon.awssdk.http.SdkHttpMethod; import software.amazon.awssdk.http.SdkHttpRequest; import software.amazon.awssdk.regions.Region; -import software.amazon.awssdk.services.restjsonendpointproviders.endpoints.internal.AwsEndpointProviderUtils; -public class AwsEndpointProviderUtilsTest { +class AwsEndpointProviderUtilsTest { + + @Test + void regionBuiltIn_returnsAttrValue() { + ExecutionAttributes attrs = new ExecutionAttributes(); + attrs.putAttribute(AwsExecutionAttribute.AWS_REGION, Region.US_EAST_1); + assertThat(AwsEndpointProviderUtils.regionBuiltIn(attrs)).isEqualTo(Region.US_EAST_1); + } + @Test - public void endpointOverridden_attrIsFalse_returnsFalse() { + void dualStackEnabledBuiltIn_returnsAttrValue() { + ExecutionAttributes attrs = new ExecutionAttributes(); + attrs.putAttribute(AwsExecutionAttribute.DUALSTACK_ENDPOINT_ENABLED, true); + assertThat(AwsEndpointProviderUtils.dualStackEnabledBuiltIn(attrs)).isEqualTo(true); + } + + @Test + void fipsEnabledBuiltIn_returnsAttrValue() { + ExecutionAttributes attrs = new ExecutionAttributes(); + attrs.putAttribute(AwsExecutionAttribute.FIPS_ENDPOINT_ENABLED, true); + assertThat(AwsEndpointProviderUtils.fipsEnabledBuiltIn(attrs)).isEqualTo(true); + } + + @Test + void endpointBuiltIn_doesNotIncludeQueryParams() { + URI endpoint = URI.create("https://example.com/path?foo=bar"); + ExecutionAttributes attrs = new ExecutionAttributes(); + attrs.putAttribute(SdkInternalExecutionAttribute.CLIENT_ENDPOINT_PROVIDER, + ClientEndpointProvider.forEndpointOverride(endpoint)); + attrs.putAttribute(SdkInternalExecutionAttribute.BUSINESS_METRICS, new BusinessMetricCollection()); + + assertThat(AwsEndpointProviderUtils.endpointBuiltIn(attrs)).isEqualTo("https://example.com/path"); + } + + @Test + void endpointBuiltIn_notOverridden_returnsNull() { + ExecutionAttributes attrs = new ExecutionAttributes(); + attrs.putAttribute(SdkInternalExecutionAttribute.CLIENT_ENDPOINT_PROVIDER, endpointProvider(false)); + assertThat(AwsEndpointProviderUtils.endpointBuiltIn(attrs)).isNull(); + } + + @Test + void endpointOverridden_attrIsFalse_returnsFalse() { ExecutionAttributes attrs = new ExecutionAttributes(); attrs.putAttribute(SdkInternalExecutionAttribute.CLIENT_ENDPOINT_PROVIDER, endpointProvider(false)); assertThat(AwsEndpointProviderUtils.endpointIsOverridden(attrs)).isFalse(); } @Test - public void endpointOverridden_attrIsTrue_returnsTrue() { + void endpointOverridden_attrIsTrue_returnsTrue() { ExecutionAttributes attrs = new ExecutionAttributes(); attrs.putAttribute(SdkInternalExecutionAttribute.CLIENT_ENDPOINT_PROVIDER, endpointProvider(true)); assertThat(AwsEndpointProviderUtils.endpointIsOverridden(attrs)).isTrue(); } @Test - public void endpointIsDiscovered_attrIsFalse_returnsFalse() { + void endpointIsDiscovered_attrIsFalse_returnsFalse() { ExecutionAttributes attrs = new ExecutionAttributes(); attrs.putAttribute(SdkInternalExecutionAttribute.IS_DISCOVERED_ENDPOINT, false); assertThat(AwsEndpointProviderUtils.endpointIsDiscovered(attrs)).isFalse(); } @Test - public void endpointIsDiscovered_attrIsAbsent_returnsFalse() { + void endpointIsDiscovered_attrIsAbsent_returnsFalse() { ExecutionAttributes attrs = new ExecutionAttributes(); assertThat(AwsEndpointProviderUtils.endpointIsDiscovered(attrs)).isFalse(); } @Test - public void endpointIsDiscovered_attrIsTrue_returnsTrue() { + void endpointIsDiscovered_attrIsTrue_returnsTrue() { ExecutionAttributes attrs = new ExecutionAttributes(); attrs.putAttribute(SdkInternalExecutionAttribute.IS_DISCOVERED_ENDPOINT, true); assertThat(AwsEndpointProviderUtils.endpointIsDiscovered(attrs)).isTrue(); } @Test - public void disableHostPrefixInjection_attrIsFalse_returnsFalse() { + void disableHostPrefixInjection_attrIsFalse_returnsFalse() { ExecutionAttributes attrs = new ExecutionAttributes(); attrs.putAttribute(SdkInternalExecutionAttribute.DISABLE_HOST_PREFIX_INJECTION, false); assertThat(AwsEndpointProviderUtils.disableHostPrefixInjection(attrs)).isFalse(); } @Test - public void disableHostPrefixInjection_attrIsAbsent_returnsFalse() { + void disableHostPrefixInjection_attrIsAbsent_returnsFalse() { ExecutionAttributes attrs = new ExecutionAttributes(); assertThat(AwsEndpointProviderUtils.disableHostPrefixInjection(attrs)).isFalse(); } @Test - public void disableHostPrefixInjection_attrIsTrue_returnsTrue() { + void disableHostPrefixInjection_attrIsTrue_returnsTrue() { ExecutionAttributes attrs = new ExecutionAttributes(); attrs.putAttribute(SdkInternalExecutionAttribute.DISABLE_HOST_PREFIX_INJECTION, true); assertThat(AwsEndpointProviderUtils.disableHostPrefixInjection(attrs)).isTrue(); } @Test - public void regionBuiltIn_returnsAttrValue() { - ExecutionAttributes attrs = new ExecutionAttributes(); - attrs.putAttribute(AwsExecutionAttribute.AWS_REGION, Region.US_EAST_1); - assertThat(AwsEndpointProviderUtils.regionBuiltIn(attrs)).isEqualTo(Region.US_EAST_1); + void addHostPrefix_prefixIsNull_returnsUnmodified() { + URI url = URI.create("https://foo.aws"); + Endpoint e = Endpoint.builder().url(url).build(); + assertThat(AwsEndpointProviderUtils.addHostPrefix(e, null).url()).isEqualTo(url); } @Test - public void dualStackEnabledBuiltIn_returnsAttrValue() { - ExecutionAttributes attrs = new ExecutionAttributes(); - attrs.putAttribute(AwsExecutionAttribute.DUALSTACK_ENDPOINT_ENABLED, true); - assertThat(AwsEndpointProviderUtils.dualStackEnabledBuiltIn(attrs)).isEqualTo(true); + void addHostPrefix_prefixIsEmpty_returnsUnmodified() { + URI url = URI.create("https://foo.aws"); + Endpoint e = Endpoint.builder().url(url).build(); + assertThat(AwsEndpointProviderUtils.addHostPrefix(e, "").url()).isEqualTo(url); } @Test - public void fipsEnabledBuiltIn_returnsAttrValue() { - ExecutionAttributes attrs = new ExecutionAttributes(); - attrs.putAttribute(AwsExecutionAttribute.FIPS_ENDPOINT_ENABLED, true); - assertThat(AwsEndpointProviderUtils.fipsEnabledBuiltIn(attrs)).isEqualTo(true); + void addHostPrefix_prefixPresent_returnsPrefixPrepended() { + URI url = URI.create("https://foo.aws"); + Endpoint e = Endpoint.builder().url(url).build(); + assertThat(AwsEndpointProviderUtils.addHostPrefix(e, "api.").url()) + .isEqualTo(URI.create("https://api.foo.aws")); } @Test - public void endpointBuiltIn_doesNotIncludeQueryParams() { - URI endpoint = URI.create("https://example.com/path?foo=bar"); - ExecutionAttributes attrs = new ExecutionAttributes(); - attrs.putAttribute(SdkInternalExecutionAttribute.CLIENT_ENDPOINT_PROVIDER, - ClientEndpointProvider.forEndpointOverride(endpoint)); - attrs.putAttribute(SdkInternalExecutionAttribute.BUSINESS_METRICS, new BusinessMetricCollection()); + void addHostPrefix_prefixPresent_preservesPortPathAndQuery() { + URI url = URI.create("https://foo.aws:1234/a/b/c?queryParam1=val1"); + Endpoint e = Endpoint.builder().url(url).build(); + assertThat(AwsEndpointProviderUtils.addHostPrefix(e, "api.").url()) + .isEqualTo(URI.create("https://api.foo.aws:1234/a/b/c?queryParam1=val1")); + } - assertThat(AwsEndpointProviderUtils.endpointBuiltIn(attrs).toString()).isEqualTo("https://example.com/path"); + @Test + void addHostPrefix_prefixInvalid_throws() { + URI url = URI.create("https://foo.aws"); + Endpoint e = Endpoint.builder().url(url).build(); + assertThatThrownBy(() -> AwsEndpointProviderUtils.addHostPrefix(e, "foo#bar.*baz")) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("component must match the pattern"); } @Test - public void setUri_combinesPathsCorrectly() { + void setUri_combinesPathsCorrectly() { URI clientEndpoint = URI.create("https://override.example.com/a"); - URI requestUri = URI.create("https://override.example.com/a/c"); URI resolvedUri = URI.create("https://override.example.com/a/b"); - SdkHttpRequest request = SdkHttpRequest.builder() - .uri(requestUri) + .uri(URI.create("https://override.example.com/a/c")) .method(SdkHttpMethod.GET) .build(); @@ -134,13 +178,11 @@ public void setUri_combinesPathsCorrectly() { } @Test - public void setUri_doubleSlash_combinesPathsCorrectly() { + void setUri_doubleSlash_combinesPathsCorrectly() { URI clientEndpoint = URI.create("https://override.example.com/a"); - URI requestUri = URI.create("https://override.example.com/a//c"); URI resolvedUri = URI.create("https://override.example.com/a/b"); - SdkHttpRequest request = SdkHttpRequest.builder() - .uri(requestUri) + .uri(URI.create("https://override.example.com/a//c")) .method(SdkHttpMethod.GET) .build(); @@ -149,12 +191,11 @@ public void setUri_doubleSlash_combinesPathsCorrectly() { } @Test - public void setUri_withTrailingSlashNoPath_combinesPathsCorrectly() { + void setUri_withTrailingSlashNoPath_combinesPathsCorrectly() { URI clientEndpoint = URI.create("https://override.example.com/"); - URI requestUri = URI.create("https://override.example.com//a"); URI resolvedUri = URI.create("https://override.example.com/"); SdkHttpRequest request = SdkHttpRequest.builder() - .uri(requestUri) + .uri(URI.create("https://override.example.com//a")) .method(SdkHttpMethod.GET) .build(); @@ -163,57 +204,17 @@ public void setUri_withTrailingSlashNoPath_combinesPathsCorrectly() { } @Test - public void addHostPrefix_prefixIsNull_returnsUnModified() { - URI url = URI.create("https://foo.aws"); - Endpoint e = Endpoint.builder() - .url(url) - .build(); - - assertThat(AwsEndpointProviderUtils.addHostPrefix(e, null).url()).isEqualTo(url); - } - - @Test - public void addHostPrefix_prefixIsEmpty_returnsUnModified() { - URI url = URI.create("https://foo.aws"); - Endpoint e = Endpoint.builder() - .url(url) - .build(); - - assertThat(AwsEndpointProviderUtils.addHostPrefix(e, "").url()).isEqualTo(url); - } - - @Test - public void addHostPrefix_prefixPresent_returnsPrefixPrepended() { - URI url = URI.create("https://foo.aws"); - Endpoint e = Endpoint.builder() - .url(url) - .build(); - - URI expected = URI.create("https://api.foo.aws"); - assertThat(AwsEndpointProviderUtils.addHostPrefix(e, "api.").url()).isEqualTo(expected); - } - - @Test - public void addHostPrefix_prefixPresent_preservesPortPathAndQuery() { - URI url = URI.create("https://foo.aws:1234/a/b/c?queryParam1=val1"); - Endpoint e = Endpoint.builder() - .url(url) - .build(); - - URI expected = URI.create("https://api.foo.aws:1234/a/b/c?queryParam1=val1"); - assertThat(AwsEndpointProviderUtils.addHostPrefix(e, "api.").url()).isEqualTo(expected); - } - - @Test - public void addHostPrefix_prefixInvalid_throws() { - URI url = URI.create("https://foo.aws"); - Endpoint e = Endpoint.builder() - .url(url) - .build(); + void setUri_noPathDifference_keepsRequestPath() { + URI clientEndpoint = URI.create("https://example.com"); + URI resolvedUri = URI.create("https://resolved.example.com"); + SdkHttpRequest request = SdkHttpRequest.builder() + .uri(URI.create("https://example.com/operation")) + .method(SdkHttpMethod.GET) + .build(); - assertThatThrownBy(() -> AwsEndpointProviderUtils.addHostPrefix(e, "foo#bar.*baz")) - .isInstanceOf(IllegalArgumentException.class) - .hasMessageContaining("component must match the pattern"); + SdkHttpRequest result = AwsEndpointProviderUtils.setUri(request, clientEndpoint, resolvedUri); + assertThat(result.host()).isEqualTo("resolved.example.com"); + assertThat(result.encodedPath()).isEqualTo("/operation"); } private static ClientEndpointProvider endpointProvider(boolean isEndpointOverridden) { diff --git a/core/aws-core/src/test/java/software/amazon/awssdk/awscore/internal/AwsExecutionContextBuilderTest.java b/core/aws-core/src/test/java/software/amazon/awssdk/awscore/internal/AwsExecutionContextBuilderTest.java index a3ff2966ec52..bbd3358059a9 100644 --- a/core/aws-core/src/test/java/software/amazon/awssdk/awscore/internal/AwsExecutionContextBuilderTest.java +++ b/core/aws-core/src/test/java/software/amazon/awssdk/awscore/internal/AwsExecutionContextBuilderTest.java @@ -38,7 +38,6 @@ import org.mockito.junit.MockitoJUnitRunner; import software.amazon.awssdk.auth.credentials.AwsBasicCredentials; import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider; -import software.amazon.awssdk.auth.token.credentials.StaticTokenProvider; import software.amazon.awssdk.awscore.AwsRequest; import software.amazon.awssdk.awscore.AwsRequestOverrideConfiguration; import software.amazon.awssdk.awscore.client.config.AwsClientOption; @@ -77,7 +76,6 @@ import software.amazon.awssdk.identity.spi.AwsCredentialsIdentity; import software.amazon.awssdk.identity.spi.IdentityProvider; import software.amazon.awssdk.identity.spi.IdentityProviders; -import software.amazon.awssdk.identity.spi.TokenIdentity; import software.amazon.awssdk.profiles.ProfileFile; @RunWith(MockitoJUnitRunner.class) @@ -402,49 +400,6 @@ public void invokeInterceptorsAndCreateExecutionContext_withoutIdentityProviders assertThat(executionAttributes.getAttribute(SdkInternalExecutionAttribute.IDENTITY_PROVIDERS)).isNull(); } - @Test - public void invokeInterceptorsAndCreateExecutionContext_requestOverrideForIdentityProvider_updatesIdentityProviders() { - IdentityProvider clientCredentialsProvider = - StaticCredentialsProvider.create(AwsBasicCredentials.create("foo", "bar")); - IdentityProvider clientTokenProvider = StaticTokenProvider.create(() -> "client-token"); - IdentityProviders identityProviders = - IdentityProviders.builder() - .putIdentityProvider(clientCredentialsProvider) - .putIdentityProvider(clientTokenProvider) - .build(); - SdkClientConfiguration clientConfig = testClientConfiguration() - .option(SdkClientOption.IDENTITY_PROVIDERS, identityProviders) - .build(); - - IdentityProvider requestCredentialsProvider = - StaticCredentialsProvider.create(AwsBasicCredentials.create("akid", "skid")); - IdentityProvider requestTokenProvider = StaticTokenProvider.create(() -> "request-token"); - Optional overrideConfiguration = - Optional.of(AwsRequestOverrideConfiguration.builder() - .credentialsProvider(requestCredentialsProvider) - .tokenIdentityProvider(requestTokenProvider) - .build()); - when(sdkRequest.overrideConfiguration()).thenReturn(overrideConfiguration); - - ClientExecutionParams executionParams = clientExecutionParams(); - - ExecutionContext executionContext = - AwsExecutionContextBuilder.invokeInterceptorsAndCreateExecutionContext(executionParams, clientConfig); - - IdentityProviders actualIdentityProviders = - executionContext.executionAttributes().getAttribute(SdkInternalExecutionAttribute.IDENTITY_PROVIDERS); - - IdentityProvider actualIdentityProvider = - actualIdentityProviders.identityProvider(AwsCredentialsIdentity.class); - - assertThat(actualIdentityProvider).isSameAs(requestCredentialsProvider); - - IdentityProvider actualTokenProvider = - actualIdentityProviders.identityProvider(TokenIdentity.class); - - assertThat(actualTokenProvider).isSameAs(requestTokenProvider); - } - @Test public void invokeInterceptorsAndCreateExecutionContext_withRequestBody_addsUserAgentMetadata() throws IOException { ClientExecutionParams executionParams = clientExecutionParams(); diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/SelectedAuthScheme.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/SelectedAuthScheme.java index 10cb488792e2..8772260e3042 100644 --- a/core/sdk-core/src/main/java/software/amazon/awssdk/core/SelectedAuthScheme.java +++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/SelectedAuthScheme.java @@ -22,9 +22,27 @@ import software.amazon.awssdk.identity.spi.Identity; import software.amazon.awssdk.utils.Validate; -/** - * A container for the identity resolver, signer and auth option that we selected for use with this service call attempt. - */ + +/// +/// A container for the identity resolver, signer and auth option that we selected for use with this service call attempt. +/// ## The Hierarchy +/// ``` +/// IDENTITY_PROVIDERS (IdentityProviders) +/// └── contains multiple IdentityProvider instances +/// e.g., IdentityProvider for AWS credentials +/// e.g., IdentityProvider for bearer tokens +/// +/// AUTH_SCHEMES (Map>) +/// └── each AuthScheme knows: +/// - which IdentityProvider type it needs +/// - which HttpSigner to use +/// +/// SELECTED_AUTH_SCHEME (SelectedAuthScheme) +/// └── the chosen auth scheme, containing: +/// - identity: CompletableFuture ← the resolved identity! +/// - signer: HttpSigner +/// - authSchemeOption: AuthSchemeOption +/// ``` @SdkProtectedApi public final class SelectedAuthScheme { private final CompletableFuture identity; diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/client/handler/ClientExecutionParams.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/client/handler/ClientExecutionParams.java index ef31aba69ecb..77e0150b1af3 100644 --- a/core/sdk-core/src/main/java/software/amazon/awssdk/core/client/handler/ClientExecutionParams.java +++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/client/handler/ClientExecutionParams.java @@ -29,7 +29,9 @@ import software.amazon.awssdk.core.http.HttpResponseHandler; import software.amazon.awssdk.core.interceptor.ExecutionAttribute; import software.amazon.awssdk.core.interceptor.ExecutionAttributes; +import software.amazon.awssdk.core.internal.endpoint.EndpointResolver; import software.amazon.awssdk.core.runtime.transform.Marshaller; +import software.amazon.awssdk.core.spi.identity.AuthSchemeOptionsResolver; import software.amazon.awssdk.core.sync.RequestBody; import software.amazon.awssdk.core.sync.ResponseTransformer; import software.amazon.awssdk.metrics.MetricCollector; @@ -64,6 +66,8 @@ public final class ClientExecutionParams { private MetricCollector metricCollector; private final ExecutionAttributes attributes = new ExecutionAttributes(); private SdkClientConfiguration requestConfiguration; + private AuthSchemeOptionsResolver authSchemeOptionsResolver; + private EndpointResolver endpointResolver; public Marshaller getMarshaller() { return marshaller; @@ -275,4 +279,23 @@ public ClientExecutionParams withRequestConfiguration(SdkCl this.requestConfiguration = requestConfiguration; return this; } + + public AuthSchemeOptionsResolver authSchemeOptionsResolver() { + return authSchemeOptionsResolver; + } + + public ClientExecutionParams withAuthSchemeOptionsResolver( + AuthSchemeOptionsResolver authSchemeOptionsResolver) { + this.authSchemeOptionsResolver = authSchemeOptionsResolver; + return this; + } + + public EndpointResolver endpointResolver() { + return endpointResolver; + } + + public ClientExecutionParams withEndpointResolver(EndpointResolver endpointResolver) { + this.endpointResolver = endpointResolver; + return this; + } } diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/http/auth/AuthSchemeResolver.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/http/auth/AuthSchemeResolver.java new file mode 100644 index 000000000000..27ec50412508 --- /dev/null +++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/http/auth/AuthSchemeResolver.java @@ -0,0 +1,309 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.core.http.auth; + +import java.time.Duration; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.concurrent.CompletableFuture; +import java.util.function.Supplier; +import java.util.stream.Collectors; +import software.amazon.awssdk.annotations.SdkProtectedApi; +import software.amazon.awssdk.core.SdkRequest; +import software.amazon.awssdk.core.SelectedAuthScheme; +import software.amazon.awssdk.core.exception.SdkException; +import software.amazon.awssdk.core.interceptor.ExecutionAttributes; +import software.amazon.awssdk.core.interceptor.SdkInternalExecutionAttribute; +import software.amazon.awssdk.core.internal.util.MetricUtils; +import software.amazon.awssdk.core.metrics.CoreMetric; +import software.amazon.awssdk.core.spi.identity.AuthSchemeOptionsResolver; +import software.amazon.awssdk.core.spi.identity.IdentityProviderUpdater; +import software.amazon.awssdk.http.auth.spi.scheme.AuthScheme; +import software.amazon.awssdk.http.auth.spi.scheme.AuthSchemeOption; +import software.amazon.awssdk.http.auth.spi.signer.HttpSigner; +import software.amazon.awssdk.http.auth.spi.signer.SignerProperty; +import software.amazon.awssdk.identity.spi.AwsCredentialsIdentity; +import software.amazon.awssdk.identity.spi.Identity; +import software.amazon.awssdk.identity.spi.IdentityProvider; +import software.amazon.awssdk.identity.spi.IdentityProviders; +import software.amazon.awssdk.identity.spi.ResolveIdentityRequest; +import software.amazon.awssdk.identity.spi.TokenIdentity; +import software.amazon.awssdk.metrics.MetricCollector; +import software.amazon.awssdk.metrics.SdkMetric; +import software.amazon.awssdk.utils.Logger; + +/** + * Shared utility for selecting auth schemes from a list of options. + */ +@SdkProtectedApi +public final class AuthSchemeResolver { + + private static final Logger LOG = Logger.loggerFor(AuthSchemeResolver.class); + + private AuthSchemeResolver() { + } + + /** + * Resolve an auth scheme from execution attributes, applying any identity provider overrides. + * This is a convenience method for use by interceptors that need to resolve an auth scheme + * outside of the normal pipeline flow (e.g., presign interceptors). + * + * @param request The SDK request (may contain credential overrides) + * @param executionAttributes The execution attributes containing auth scheme resolution inputs + * @return The selected auth scheme + */ + public static SelectedAuthScheme resolveAuthScheme( + SdkRequest request, ExecutionAttributes executionAttributes) { + AuthSchemeOptionsResolver optionsResolver = + executionAttributes.getAttribute(SdkInternalExecutionAttribute.AUTH_SCHEME_OPTIONS_RESOLVER); + Map> authSchemes = + executionAttributes.getAttribute(SdkInternalExecutionAttribute.AUTH_SCHEMES); + IdentityProviders identityProviders = + executionAttributes.getAttribute(SdkInternalExecutionAttribute.IDENTITY_PROVIDERS); + + IdentityProviderUpdater updater = + executionAttributes.getAttribute(SdkInternalExecutionAttribute.IDENTITY_PROVIDER_UPDATER); + if (updater != null) { + identityProviders = updater.update(request, identityProviders, executionAttributes); + } + + List authOptions = optionsResolver.resolve(request); + return selectAuthScheme(authOptions, authSchemes, identityProviders, null); + } + + /** + * Select an auth scheme from the given options. + * + * @param authOptions List of auth scheme options to try in order + * @param authSchemes Map of available auth schemes + * @param identityProviders Identity providers to use for resolving identity + * @param metricCollector Optional metric collector for recording identity fetch duration + * @return The selected auth scheme + * @throws SdkException if no auth scheme could be selected + */ + public static SelectedAuthScheme selectAuthScheme( + List authOptions, + Map> authSchemes, + IdentityProviders identityProviders, + MetricCollector metricCollector) { + + List> discardedReasons = new ArrayList<>(); + + for (AuthSchemeOption authOption : authOptions) { + AuthScheme authScheme = authSchemes.get(authOption.schemeId()); + SelectedAuthScheme selectedAuthScheme = trySelectAuthScheme( + authOption, authScheme, identityProviders, discardedReasons, metricCollector); + + if (selectedAuthScheme != null) { + if (!discardedReasons.isEmpty()) { + LOG.debug(() -> String.format("%s auth will be used, discarded: '%s'", + authOption.schemeId(), + discardedReasons.stream().map(Supplier::get).collect(Collectors.joining(", ")))); + } + return selectedAuthScheme; + } + } + + throw SdkException.builder() + .message("Failed to determine how to authenticate the user: " + + discardedReasons.stream().map(Supplier::get).collect(Collectors.joining(", "))) + .build(); + } + + /** + * Merge properties from any pre-existing auth scheme into the selected one. + * + * After auth scheme resolution produces a fresh selectedAuthScheme, this method ensures that any signer properties + * explicitly set by interceptors (e.g., signing region override) take priority over the resolved values. + */ + public static SelectedAuthScheme mergePreExistingAuthSchemeProperties( + SelectedAuthScheme selectedAuthScheme, + ExecutionAttributes executionAttributes) { + + // The "existing" auth scheme is what's currently on SELECTED_AUTH_SCHEME - potentially modified by interceptors. + SelectedAuthScheme existingAuthScheme = + executionAttributes.getAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME); + + if (existingAuthScheme == null) { + return selectedAuthScheme; + } + + // Snapshot taken before interceptors ran — used to detect what interceptors changed. + SelectedAuthScheme authSchemeBeforeInterceptors = + executionAttributes.getAttribute(SdkInternalExecutionAttribute.AUTH_SCHEME_SNAPSHOT_PRE_INTERCEPTORS); + + // If no interceptor modified the auth scheme option, skip the diff logic. Still merge existing properties with + // putIfAbsent so that properties from the initial placeholder (e.g., REGION_NAME) carry over to the + // freshly resolved scheme. + if (authSchemeBeforeInterceptors != null && + authSchemeBeforeInterceptors.authSchemeOption() == existingAuthScheme.authSchemeOption()) { + AuthSchemeOption.Builder mergedOption = selectedAuthScheme.authSchemeOption().toBuilder(); + existingAuthScheme.authSchemeOption().forEachSignerProperty(mergedOption::putSignerPropertyIfAbsent); + existingAuthScheme.authSchemeOption().forEachIdentityProperty(mergedOption::putIdentityPropertyIfAbsent); + return new SelectedAuthScheme<>( + selectedAuthScheme.identity(), + selectedAuthScheme.signer(), + mergedOption.build() + ); + } + + // Start with the freshly resolved auth scheme as the base. + AuthSchemeOption.Builder mergedOption = selectedAuthScheme.authSchemeOption().toBuilder(); + + // For each signer property on the interceptor-modified scheme: + // If the interceptor changed it (differs from pre-interceptor snapshot), apply interceptor override + // If unchanged (same as before interceptors) only add if not already on the resolved scheme + existingAuthScheme.authSchemeOption().forEachSignerProperty(new AuthSchemeOption.SignerPropertyConsumer() { + @Override + public void accept(SignerProperty key, S value) { + if (wasModifiedByInterceptor(authSchemeBeforeInterceptors, key, value)) { + mergedOption.putSignerProperty(key, value); + } else { + mergedOption.putSignerPropertyIfAbsent(key, value); + } + } + }); + + existingAuthScheme.authSchemeOption().forEachIdentityProperty(mergedOption::putIdentityPropertyIfAbsent); + + return new SelectedAuthScheme<>( + selectedAuthScheme.identity(), + selectedAuthScheme.signer(), + mergedOption.build() + ); + } + + /** + * Returns true if the given property value differs from what it was before interceptors ran, + * meaning an interceptor explicitly changed it. + */ + private static boolean wasModifiedByInterceptor(SelectedAuthScheme authSchemeBeforeInterceptors, + SignerProperty key, T currentValue) { + if (authSchemeBeforeInterceptors == null) { + return false; + } + T originalValue = authSchemeBeforeInterceptors.authSchemeOption().signerProperty(key); + return !Objects.equals(originalValue, currentValue); + } + + /** + * Re-applies interceptor-modified signer properties onto the current auth scheme. + * Called after endpoint resolution, which may have overwritten properties that interceptors set. + */ + public static void applyInterceptorModifiedProperties(SelectedAuthScheme currentScheme, + SelectedAuthScheme authSchemeBeforeInterceptors, + SelectedAuthScheme afterInterceptors, + ExecutionAttributes attrs) { + if (afterInterceptors == null) { + return; + } + doApplyInterceptorModifiedProperties(currentScheme, authSchemeBeforeInterceptors, afterInterceptors, attrs); + } + + @SuppressWarnings("unchecked") + private static void doApplyInterceptorModifiedProperties( + SelectedAuthScheme currentScheme, + SelectedAuthScheme authSchemeBeforeInterceptors, + SelectedAuthScheme afterInterceptors, + ExecutionAttributes attrs) { + + // Start with the current endpoint resolved auth scheme as the base. + AuthSchemeOption.Builder mergedOption = currentScheme.authSchemeOption().toBuilder(); + boolean[] changed = {false}; + + // For each property on the post-interceptor scheme, check if the interceptor changed it. + // If yes, apply it onto the current scheme. + afterInterceptors.authSchemeOption().forEachSignerProperty(new AuthSchemeOption.SignerPropertyConsumer() { + @Override + public void accept(SignerProperty key, S value) { + if (wasModifiedByInterceptor(authSchemeBeforeInterceptors, key, value)) { + mergedOption.putSignerProperty(key, value); + changed[0] = true; + } + } + }); + + // Only update SELECTED_AUTH_SCHEME if at least one property was re-applied. + if (changed[0]) { + attrs.putAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME, + new SelectedAuthScheme<>(currentScheme.identity(), + currentScheme.signer(), + mergedOption.build())); + } + } + + private static SelectedAuthScheme trySelectAuthScheme( + AuthSchemeOption authOption, + AuthScheme authScheme, + IdentityProviders identityProviders, + List> discardedReasons, + MetricCollector metricCollector) { + + if (authScheme == null) { + discardedReasons.add(() -> String.format("'%s' is not enabled for this request.", authOption.schemeId())); + return null; + } + + IdentityProvider identityProvider = authScheme.identityProvider(identityProviders); + if (identityProvider == null) { + discardedReasons.add(() -> String.format("'%s' does not have an identity provider configured.", + authOption.schemeId())); + return null; + } + + HttpSigner signer; + try { + signer = authScheme.signer(); + } catch (RuntimeException e) { + discardedReasons.add(() -> String.format("'%s' signer could not be retrieved: %s", + authOption.schemeId(), e.getMessage())); + return null; + } + + ResolveIdentityRequest.Builder identityRequestBuilder = ResolveIdentityRequest.builder(); + authOption.forEachIdentityProperty(identityRequestBuilder::putProperty); + + CompletableFuture identity = resolveIdentity( + identityProvider, identityRequestBuilder.build(), metricCollector); + + return new SelectedAuthScheme<>(identity, signer, authOption); + } + + private static CompletableFuture resolveIdentity( + IdentityProvider identityProvider, + ResolveIdentityRequest request, + MetricCollector metricCollector) { + + SdkMetric metric = getIdentityMetric(identityProvider); + if (metric == null || metricCollector == null) { + return identityProvider.resolveIdentity(request); + } + return MetricUtils.reportDuration(() -> identityProvider.resolveIdentity(request), metricCollector, metric); + } + + private static SdkMetric getIdentityMetric(IdentityProvider identityProvider) { + Class identityType = identityProvider.identityType(); + if (identityType == AwsCredentialsIdentity.class) { + return CoreMetric.CREDENTIALS_FETCH_DURATION; + } + if (identityType == TokenIdentity.class) { + return CoreMetric.TOKEN_FETCH_DURATION; + } + return null; + } +} diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/interceptor/SdkInternalExecutionAttribute.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/interceptor/SdkInternalExecutionAttribute.java index 529e129e18cd..77daf8558a6e 100644 --- a/core/sdk-core/src/main/java/software/amazon/awssdk/core/interceptor/SdkInternalExecutionAttribute.java +++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/interceptor/SdkInternalExecutionAttribute.java @@ -15,9 +15,11 @@ package software.amazon.awssdk.core.interceptor; +import java.net.URI; import java.util.List; import java.util.Map; import java.util.concurrent.atomic.AtomicLong; +import java.util.function.Consumer; import software.amazon.awssdk.annotations.SdkProtectedApi; import software.amazon.awssdk.core.ClientEndpointProvider; import software.amazon.awssdk.core.SdkClient; @@ -28,7 +30,10 @@ import software.amazon.awssdk.core.checksums.ResponseChecksumValidation; import software.amazon.awssdk.core.interceptor.trait.HttpChecksum; import software.amazon.awssdk.core.interceptor.trait.HttpChecksumRequired; +import software.amazon.awssdk.core.internal.endpoint.EndpointResolver; import software.amazon.awssdk.core.internal.interceptor.trait.RequestCompression; +import software.amazon.awssdk.core.spi.identity.AuthSchemeOptionsResolver; +import software.amazon.awssdk.core.spi.identity.IdentityProviderUpdater; import software.amazon.awssdk.core.useragent.AdditionalMetadata; import software.amazon.awssdk.core.useragent.BusinessMetricCollection; import software.amazon.awssdk.endpoints.Endpoint; @@ -166,12 +171,63 @@ public final class SdkInternalExecutionAttribute extends SdkExecutionAttribute { */ public static final ExecutionAttribute IDENTITY_PROVIDERS = new ExecutionAttribute<>("IdentityProviders"); + /** + * Callback for updating identity providers based on request-level overrides. + * This allows aws-core to provide AWS-specific logic without sdk-core depending on aws-core. + */ + public static final ExecutionAttribute IDENTITY_PROVIDER_UPDATER = + new ExecutionAttribute<>("IdentityProviderUpdater"); + + /** + * Callback to resolve auth scheme options from the (possibly modified) request. + * Called by AuthSchemeResolutionStage after interceptors have run. + */ + public static final ExecutionAttribute AUTH_SCHEME_OPTIONS_RESOLVER = + new ExecutionAttribute<>("AuthSchemeOptionsResolver"); + + /** + * Callback to recompute {@code SIGNING_METHOD} after auth scheme resolution. + * Set by {@code AwsExecutionContextBuilder} + */ + public static final ExecutionAttribute> SIGNING_METHOD_UPDATER = + new ExecutionAttribute<>("SigningMethodUpdater"); + + /** + * Callback for resolving the endpoint. Generated per-service as a lambda in the client class. + * Called by EndpointResolutionStage after interceptors have run. + */ + public static final ExecutionAttribute ENDPOINT_RESOLVER = + new ExecutionAttribute<>("EndpointResolver"); + + /** + * The HTTP request URI captured before modifyHttpRequest interceptors run. + * Used by EndpointResolutionStage to detect if a customer interceptor modified the URL. + */ + public static final ExecutionAttribute HTTP_REQUEST_URI_BEFORE_MODIFY = + new ExecutionAttribute<>("HttpRequestUriBeforeModify"); + /** * The selected auth scheme for a request. */ public static final ExecutionAttribute> SELECTED_AUTH_SCHEME = new ExecutionAttribute<>("SelectedAuthScheme"); + /** + * Snapshot of {@link #SELECTED_AUTH_SCHEME} taken before execution interceptors run. + * Used by {@code AuthSchemeResolver#mergePreExistingAuthSchemeProperties} to detect which signer properties + * were explicitly modified by interceptors (and should therefore override the freshly-resolved values). + */ + public static final ExecutionAttribute> AUTH_SCHEME_SNAPSHOT_PRE_INTERCEPTORS = + new ExecutionAttribute<>("AuthSchemeSnapshotPreInterceptors"); + + /** + * Snapshot of {@link #SELECTED_AUTH_SCHEME} taken after interceptors run but before auth scheme resolution. + * Together with {@link #AUTH_SCHEME_SNAPSHOT_PRE_INTERCEPTORS}, this allows detecting which signer properties + * were explicitly modified by interceptors so they can be re-applied after endpoint resolution. + */ + public static final ExecutionAttribute> AUTH_SCHEME_SNAPSHOT_POST_INTERCEPTORS = + new ExecutionAttribute<>("AuthSchemeSnapshotPostInterceptors"); + /** * The supported compression algorithms for an operation, and whether the operation is streaming or not. */ diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/endpoint/EndpointResolver.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/endpoint/EndpointResolver.java new file mode 100644 index 000000000000..5eafd02e8c3e --- /dev/null +++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/endpoint/EndpointResolver.java @@ -0,0 +1,37 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.core.internal.endpoint; + +import software.amazon.awssdk.annotations.SdkInternalApi; +import software.amazon.awssdk.core.SdkRequest; +import software.amazon.awssdk.core.interceptor.ExecutionAttributes; +import software.amazon.awssdk.endpoints.Endpoint; + +/** + * Callback interface for resolving endpoints from the request and execution context. + */ +@FunctionalInterface +@SdkInternalApi +public interface EndpointResolver { + /** + * Resolves the endpoint for the given request. + * + * @param request The SDK request (after interceptors have modified it) + * @param executionAttributes The execution attributes containing client config, region, etc. + * @return The resolved endpoint + */ + Endpoint resolve(SdkRequest request, ExecutionAttributes executionAttributes); +} diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/handler/BaseClientHandler.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/handler/BaseClientHandler.java index 9aea1f328993..41cb4f6477dc 100644 --- a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/handler/BaseClientHandler.java +++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/handler/BaseClientHandler.java @@ -80,6 +80,13 @@ static InterceptorContext finalizeSdkHttpFu addHttpRequest(executionContext, request); runAfterMarshallingInterceptors(executionContext); + + // Snapshot the HTTP request URI before modifyHttpRequest interceptors run. + // EndpointResolutionStage uses this to detect if a customer interceptor modified the URL. + executionContext.executionAttributes().putAttribute( + SdkInternalExecutionAttribute.HTTP_REQUEST_URI_BEFORE_MODIFY, + executionContext.interceptorContext().httpRequest().getUri()); + return runModifyHttpRequestAndHttpContentInterceptors(executionContext); } diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/AmazonAsyncHttpClient.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/AmazonAsyncHttpClient.java index 59bd964d6fad..5772db0555a3 100644 --- a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/AmazonAsyncHttpClient.java +++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/AmazonAsyncHttpClient.java @@ -39,7 +39,9 @@ import software.amazon.awssdk.core.internal.http.pipeline.stages.AsyncExecutionFailureExceptionReportingStage; import software.amazon.awssdk.core.internal.http.pipeline.stages.AsyncRetryableStage; import software.amazon.awssdk.core.internal.http.pipeline.stages.AsyncSigningStage; +import software.amazon.awssdk.core.internal.http.pipeline.stages.AuthSchemeResolutionStage; import software.amazon.awssdk.core.internal.http.pipeline.stages.CompressRequestStage; +import software.amazon.awssdk.core.internal.http.pipeline.stages.EndpointResolutionStage; import software.amazon.awssdk.core.internal.http.pipeline.stages.HttpChecksumStage; import software.amazon.awssdk.core.internal.http.pipeline.stages.MakeAsyncHttpRequestStage; import software.amazon.awssdk.core.internal.http.pipeline.stages.MakeRequestImmutableStage; @@ -197,6 +199,8 @@ public CompletableFuture execute( .then(MergeCustomQueryParamsStage::new) .then(QueryParametersToBodyStage::new) .then(() -> new CompressRequestStage(httpClientDependencies)) + .then(AuthSchemeResolutionStage::new) + .then(EndpointResolutionStage::new) .then(() -> new HttpChecksumStage(ClientType.ASYNC)) .then(ApplyUserAgentStage::new) .then(MakeRequestImmutableStage::new) diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/AmazonSyncHttpClient.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/AmazonSyncHttpClient.java index dccaad1e2109..1996766372a4 100644 --- a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/AmazonSyncHttpClient.java +++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/AmazonSyncHttpClient.java @@ -34,9 +34,11 @@ import software.amazon.awssdk.core.internal.http.pipeline.stages.ApiCallTimeoutTrackingStage; import software.amazon.awssdk.core.internal.http.pipeline.stages.ApplyTransactionIdStage; import software.amazon.awssdk.core.internal.http.pipeline.stages.ApplyUserAgentStage; +import software.amazon.awssdk.core.internal.http.pipeline.stages.AuthSchemeResolutionStage; import software.amazon.awssdk.core.internal.http.pipeline.stages.BeforeTransmissionExecutionInterceptorsStage; import software.amazon.awssdk.core.internal.http.pipeline.stages.BeforeUnmarshallingExecutionInterceptorsStage; import software.amazon.awssdk.core.internal.http.pipeline.stages.CompressRequestStage; +import software.amazon.awssdk.core.internal.http.pipeline.stages.EndpointResolutionStage; import software.amazon.awssdk.core.internal.http.pipeline.stages.ExecutionFailureExceptionReportingStage; import software.amazon.awssdk.core.internal.http.pipeline.stages.HandleResponseStage; import software.amazon.awssdk.core.internal.http.pipeline.stages.HttpChecksumStage; @@ -185,6 +187,8 @@ public OutputT execute(HttpResponseHandler> response .then(MergeCustomQueryParamsStage::new) .then(QueryParametersToBodyStage::new) .then(() -> new CompressRequestStage(httpClientDependencies)) + .then(AuthSchemeResolutionStage::new) + .then(EndpointResolutionStage::new) .then(() -> new HttpChecksumStage(ClientType.SYNC)) .then(ApplyUserAgentStage::new) .then(MakeRequestImmutableStage::new) diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/pipeline/stages/ApiCallMetricCollectionStage.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/pipeline/stages/ApiCallMetricCollectionStage.java index 3b78dedaf2ad..c7a6783fa7da 100644 --- a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/pipeline/stages/ApiCallMetricCollectionStage.java +++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/pipeline/stages/ApiCallMetricCollectionStage.java @@ -40,17 +40,23 @@ public ApiCallMetricCollectionStage(RequestPipeline execute(SdkHttpFullRequest input, RequestExecutionContext context) throws Exception { MetricCollector metricCollector = context.executionContext().metricCollector(); - MetricUtils.collectServiceEndpointMetrics(metricCollector, input); // Note: at this point, any exception, even a service exception, will // be thrown from the wrapped pipeline so we can't use // MetricUtil.measureDuration() long callStart = System.nanoTime(); try { - return wrapped.execute(input, context); + Response response = wrapped.execute(input, context); + return response; } finally { long d = System.nanoTime() - callStart; metricCollector.reportMetric(CoreMetric.API_CALL_DURATION, Duration.ofNanos(d)); + // Collect SERVICE_ENDPOINT after pipeline execution so that EndpointResolutionStage + // has applied the resolved URL to the request. + SdkHttpFullRequest finalRequest = context.executionContext().interceptorContext().httpRequest() != null + ? (SdkHttpFullRequest) context.executionContext().interceptorContext().httpRequest() + : input; + MetricUtils.collectServiceEndpointMetrics(metricCollector, finalRequest); } } } diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/pipeline/stages/AsyncApiCallMetricCollectionStage.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/pipeline/stages/AsyncApiCallMetricCollectionStage.java index 1d7d040971cf..e85a6f752aee 100644 --- a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/pipeline/stages/AsyncApiCallMetricCollectionStage.java +++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/pipeline/stages/AsyncApiCallMetricCollectionStage.java @@ -41,7 +41,6 @@ public AsyncApiCallMetricCollectionStage(RequestPipeline execute(SdkHttpFullRequest input, RequestExecutionContext context) throws Exception { MetricCollector metricCollector = context.executionContext().metricCollector(); - MetricUtils.collectServiceEndpointMetrics(metricCollector, input); CompletableFuture future = new CompletableFuture<>(); @@ -52,6 +51,13 @@ public CompletableFuture execute(SdkHttpFullRequest input, RequestExecu long duration = System.nanoTime() - callStart; metricCollector.reportMetric(CoreMetric.API_CALL_DURATION, Duration.ofNanos(duration)); + // Collect SERVICE_ENDPOINT after pipeline execution so that EndpointResolutionStage + // has applied the resolved URL to the request. + SdkHttpFullRequest finalRequest = context.executionContext().interceptorContext().httpRequest() != null + ? (SdkHttpFullRequest) context.executionContext().interceptorContext().httpRequest() + : input; + MetricUtils.collectServiceEndpointMetrics(metricCollector, finalRequest); + if (t != null) { future.completeExceptionally(t); } else { diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/pipeline/stages/AuthSchemeResolutionStage.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/pipeline/stages/AuthSchemeResolutionStage.java new file mode 100644 index 000000000000..5d13830a178d --- /dev/null +++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/pipeline/stages/AuthSchemeResolutionStage.java @@ -0,0 +1,199 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.core.internal.http.pipeline.stages; + +import java.util.List; +import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.function.Consumer; +import software.amazon.awssdk.annotations.SdkInternalApi; +import software.amazon.awssdk.core.RequestOverrideConfiguration; +import software.amazon.awssdk.core.SdkRequest; +import software.amazon.awssdk.core.SelectedAuthScheme; +import software.amazon.awssdk.core.http.auth.AuthSchemeResolver; +import software.amazon.awssdk.core.interceptor.ExecutionAttributes; +import software.amazon.awssdk.core.interceptor.SdkExecutionAttribute; +import software.amazon.awssdk.core.interceptor.SdkInternalExecutionAttribute; +import software.amazon.awssdk.core.internal.http.HttpClientDependencies; +import software.amazon.awssdk.core.internal.http.RequestExecutionContext; +import software.amazon.awssdk.core.internal.http.pipeline.MutableRequestToRequestPipeline; +import software.amazon.awssdk.core.spi.identity.AuthSchemeOptionsResolver; +import software.amazon.awssdk.core.spi.identity.IdentityProviderUpdater; +import software.amazon.awssdk.core.useragent.BusinessMetricCollection; +import software.amazon.awssdk.core.useragent.BusinessMetricFeatureId; +import software.amazon.awssdk.http.SdkHttpFullRequest; +import software.amazon.awssdk.http.auth.spi.scheme.AuthScheme; +import software.amazon.awssdk.http.auth.spi.scheme.AuthSchemeOption; +import software.amazon.awssdk.identity.spi.Identity; +import software.amazon.awssdk.identity.spi.IdentityProvider; +import software.amazon.awssdk.identity.spi.IdentityProviders; +import software.amazon.awssdk.identity.spi.ResolveIdentityRequest; +import software.amazon.awssdk.identity.spi.TokenIdentity; +import software.amazon.awssdk.metrics.MetricCollector; + +/** + * Pipeline stage that resolves the auth scheme and identity for signing. + */ +@SdkInternalApi +public final class AuthSchemeResolutionStage implements MutableRequestToRequestPipeline { + + private static final String SIGV4A_SCHEME_ID = "aws.auth#sigv4a"; + private static final String BEARER_SCHEME_ID = "smithy.api#httpBearerAuth"; + + public AuthSchemeResolutionStage(HttpClientDependencies dependencies) { + } + + @Override + public SdkHttpFullRequest.Builder execute(SdkHttpFullRequest.Builder request, RequestExecutionContext context) + throws Exception { + ExecutionAttributes executionAttributes = context.executionAttributes(); + + Map> authSchemes = executionAttributes.getAttribute(SdkInternalExecutionAttribute.AUTH_SCHEMES); + if (authSchemes == null) { + return request; + } + + SdkRequest sdkRequest = context.executionContext().interceptorContext().request(); + List authOptions = resolveAuthSchemeOptions(executionAttributes, sdkRequest); + if (authOptions == null || authOptions.isEmpty()) { + return request; + } + + IdentityProviders identityProviders = updateIdentityProvidersIfNeeded(executionAttributes, sdkRequest); + + // Skip resolution if auth scheme was already resolved by an old service interceptor + SelectedAuthScheme existing = executionAttributes.getAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME); + if (existing != null && !"unset".equals(existing.authSchemeOption().schemeId())) { + //Updates the identity on an already-resolved auth scheme, ensuring credential overrides via interceptors + //are respected even with old service clients. + updateIdentityOnExistingScheme(existing, identityProviders, executionAttributes); + return request; + } + + MetricCollector metricCollector = + executionAttributes.getAttribute(SdkExecutionAttribute.API_CALL_METRIC_COLLECTOR); + + SelectedAuthScheme selectedAuthScheme = + AuthSchemeResolver.selectAuthScheme(authOptions, authSchemes, identityProviders, metricCollector); + + executionAttributes.putAttribute(SdkInternalExecutionAttribute.AUTH_SCHEME_SNAPSHOT_POST_INTERCEPTORS, + executionAttributes.getAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME)); + + selectedAuthScheme = AuthSchemeResolver.mergePreExistingAuthSchemeProperties(selectedAuthScheme, executionAttributes); + + executionAttributes.putAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME, selectedAuthScheme); + + Consumer signingMethodUpdater = + executionAttributes.getAttribute(SdkInternalExecutionAttribute.SIGNING_METHOD_UPDATER); + if (signingMethodUpdater != null) { + signingMethodUpdater.accept(executionAttributes); + } + + recordBusinessMetrics(selectedAuthScheme, sdkRequest, executionAttributes); + + return request; + } + + private List resolveAuthSchemeOptions(ExecutionAttributes executionAttributes, SdkRequest request) { + AuthSchemeOptionsResolver resolver = + executionAttributes.getAttribute(SdkInternalExecutionAttribute.AUTH_SCHEME_OPTIONS_RESOLVER); + + if (resolver == null) { + return null; + } + return resolver.resolve(request); + } + + /** + * Returns identity providers after applying any request-level overrides. This allows aws-core to inject + * credential overrides from {@code AwsRequestOverrideConfiguration} (e.g., per-request credentials provider) + * without sdk-core depending on aws-core. The updater is set by {@code AwsExecutionContextBuilder} and runs + * after interceptors have modified the request, ensuring user-injected credentials are respected. + */ + private IdentityProviders updateIdentityProvidersIfNeeded(ExecutionAttributes executionAttributes, SdkRequest request) { + IdentityProviders identityProviders = + executionAttributes.getAttribute(SdkInternalExecutionAttribute.IDENTITY_PROVIDERS); + + IdentityProviderUpdater updater = + executionAttributes.getAttribute(SdkInternalExecutionAttribute.IDENTITY_PROVIDER_UPDATER); + if (updater != null) { + identityProviders = updater.update(request, identityProviders, executionAttributes); + } + return identityProviders; + } + + @SuppressWarnings("unchecked") + private void updateIdentityOnExistingScheme(SelectedAuthScheme existing, + IdentityProviders identityProviders, + ExecutionAttributes executionAttributes) { + AuthScheme authScheme = (AuthScheme) executionAttributes + .getAttribute(SdkInternalExecutionAttribute.AUTH_SCHEMES) + .get(existing.authSchemeOption().schemeId()); + if (authScheme == null) { + return; + } + IdentityProvider identityProvider = authScheme.identityProvider(identityProviders); + if (identityProvider == null) { + return; + } + CompletableFuture identity = identityProvider.resolveIdentity(ResolveIdentityRequest.builder().build()); + executionAttributes.putAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME, + new SelectedAuthScheme<>(identity, existing.signer(), existing.authSchemeOption())); + } + + private void recordBusinessMetrics(SelectedAuthScheme selectedAuthScheme, + SdkRequest request, + ExecutionAttributes executionAttributes) { + if (selectedAuthScheme == null) { + return; + } + + BusinessMetricCollection businessMetrics = + executionAttributes.getAttribute(SdkInternalExecutionAttribute.BUSINESS_METRICS); + if (businessMetrics == null) { + return; + } + + String schemeId = selectedAuthScheme.authSchemeOption().schemeId(); + + if (isSignerOverridden(request, executionAttributes)) { + return; + } + + if (SIGV4A_SCHEME_ID.equals(schemeId)) { + businessMetrics.addMetric(BusinessMetricFeatureId.SIGV4A_SIGNING.value()); + } + + if (BEARER_SCHEME_ID.equals(schemeId) && selectedAuthScheme.identity().isDone()) { + Identity identity = selectedAuthScheme.identity().getNow(null); + if (identity instanceof TokenIdentity) { + String tokenFromEnv = executionAttributes.getAttribute(SdkInternalExecutionAttribute.TOKEN_CONFIGURED_FROM_ENV); + if (tokenFromEnv != null && tokenFromEnv.equals(((TokenIdentity) identity).token())) { + businessMetrics.addMetric(BusinessMetricFeatureId.BEARER_SERVICE_ENV_VARS.value()); + } + } + } + } + + private boolean isSignerOverridden(SdkRequest request, ExecutionAttributes executionAttributes) { + boolean isClientSignerOverridden = + Boolean.TRUE.equals(executionAttributes.getAttribute(SdkExecutionAttribute.SIGNER_OVERRIDDEN)); + boolean isRequestSignerOverridden = request.overrideConfiguration() + .flatMap(RequestOverrideConfiguration::signer) + .isPresent(); + return isClientSignerOverridden || isRequestSignerOverridden; + } +} diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/pipeline/stages/EndpointResolutionStage.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/pipeline/stages/EndpointResolutionStage.java new file mode 100644 index 000000000000..0355786afef3 --- /dev/null +++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/pipeline/stages/EndpointResolutionStage.java @@ -0,0 +1,170 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.core.internal.http.pipeline.stages; + +import java.net.URI; +import java.time.Duration; +import java.util.List; +import java.util.Map; +import software.amazon.awssdk.annotations.SdkInternalApi; +import software.amazon.awssdk.core.ClientEndpointProvider; +import software.amazon.awssdk.core.SdkRequest; +import software.amazon.awssdk.core.SelectedAuthScheme; +import software.amazon.awssdk.core.http.auth.AuthSchemeResolver; +import software.amazon.awssdk.core.interceptor.ExecutionAttributes; +import software.amazon.awssdk.core.interceptor.SdkExecutionAttribute; +import software.amazon.awssdk.core.interceptor.SdkInternalExecutionAttribute; +import software.amazon.awssdk.core.internal.endpoint.EndpointResolver; +import software.amazon.awssdk.core.internal.http.HttpClientDependencies; +import software.amazon.awssdk.core.internal.http.RequestExecutionContext; +import software.amazon.awssdk.core.internal.http.pipeline.MutableRequestToRequestPipeline; +import software.amazon.awssdk.core.metrics.CoreMetric; +import software.amazon.awssdk.endpoints.Endpoint; +import software.amazon.awssdk.http.SdkHttpFullRequest; +import software.amazon.awssdk.metrics.MetricCollector; +import software.amazon.awssdk.utils.StringUtils; +import software.amazon.awssdk.utils.http.SdkHttpUtils; + +/** + * Pipeline stage that resolves the endpoint using the service's endpoint rules engine. + *

+ * This stage runs after all interceptors (including {@code modifyHttpRequest}) and after + * {@link AuthSchemeResolutionStage}. It calls a service-specific callback to resolve the endpoint, + * then applies the resolved URL, headers, and metrics. + *

+ */ +@SdkInternalApi +public final class EndpointResolutionStage implements MutableRequestToRequestPipeline { + + public EndpointResolutionStage(HttpClientDependencies dependencies) { + } + + @Override + public SdkHttpFullRequest.Builder execute(SdkHttpFullRequest.Builder request, RequestExecutionContext context) + throws Exception { + ExecutionAttributes attrs = context.executionAttributes(); + + if (Boolean.TRUE.equals(attrs.getAttribute(SdkInternalExecutionAttribute.IS_DISCOVERED_ENDPOINT))) { + return request; + } + + // Skip if endpoint was already resolved by an old service interceptor + Endpoint existingEndpoint = attrs.getAttribute(SdkInternalExecutionAttribute.RESOLVED_ENDPOINT); + if (existingEndpoint != null) { + return request; + } + + EndpointResolver resolver = attrs.getAttribute(SdkInternalExecutionAttribute.ENDPOINT_RESOLVER); + if (resolver == null) { + return request; + } + + SdkRequest sdkRequest = context.executionContext().interceptorContext().request(); + + long resolveEndpointStart = System.nanoTime(); + Endpoint endpoint = resolver.resolve(sdkRequest, attrs); + Duration resolveEndpointDuration = Duration.ofNanos(System.nanoTime() - resolveEndpointStart); + + reapplyInterceptorModifiedAuthProperties(attrs); + + MetricCollector metricCollector = attrs.getAttribute(SdkExecutionAttribute.API_CALL_METRIC_COLLECTOR); + if (metricCollector != null) { + metricCollector.reportMetric(CoreMetric.ENDPOINT_RESOLVE_DURATION, resolveEndpointDuration); + } + + attrs.putAttribute(SdkInternalExecutionAttribute.RESOLVED_ENDPOINT, endpoint); + + // Copy endpoint headers onto HTTP request + Map> headers = endpoint.headers(); + if (headers != null && !headers.isEmpty()) { + headers.forEach((name, values) -> + values.forEach(v -> request.appendHeader(name, v))); + } + + // Apply resolved endpoint URL, unless a customer interceptor modified the URL in modifyHttpRequest() + ClientEndpointProvider clientEndpointProvider = + attrs.getAttribute(SdkInternalExecutionAttribute.CLIENT_ENDPOINT_PROVIDER); + if (interceptorModifiedEndpoint(request, attrs)) { + applyResolvedPath(request, clientEndpointProvider.clientEndpoint(), endpoint.url()); + return request; + } + return setUri(request, clientEndpointProvider.clientEndpoint(), endpoint.url()); + } + + /** + * Detects if an interceptor modified the HTTP request URL in modifyHttpRequest(). + * Compares the current request's host and scheme against the snapshot taken before interceptors ran. + */ + private static boolean interceptorModifiedEndpoint(SdkHttpFullRequest.Builder request, ExecutionAttributes attrs) { + URI preModifyUri = attrs.getAttribute(SdkInternalExecutionAttribute.HTTP_REQUEST_URI_BEFORE_MODIFY); + if (preModifyUri == null) { + return false; + } + String requestHost = request.host(); + Integer requestPort = request.port(); + return requestHost != null + && (!requestHost.equals(preModifyUri.getHost()) + || !String.valueOf(request.protocol()).equals(preModifyUri.getScheme()) + || (requestPort != null && requestPort != preModifyUri.getPort())); + } + + /** + * Applies the resolved endpoint URL to the HTTP request, merging three path components: + * client endpoint path + rules engine additions + marshaller operation path. + */ + private static SdkHttpFullRequest.Builder setUri(SdkHttpFullRequest.Builder request, + URI clientEndpoint, + URI resolvedUri) { + applyResolvedPath(request, clientEndpoint, resolvedUri); + return request.protocol(resolvedUri.getScheme()) + .host(resolvedUri.getHost()) + .port(resolvedUri.getPort()); + } + + private static void applyResolvedPath(SdkHttpFullRequest.Builder request, + URI clientEndpoint, + URI resolvedUri) { + String clientEndpointPath = clientEndpoint.getRawPath(); + String requestPath = request.encodedPath(); + String resolvedUriPath = resolvedUri.getRawPath(); + + if (!resolvedUriPath.equals(clientEndpointPath)) { + request.encodedPath(combinePath(clientEndpointPath, requestPath, resolvedUriPath)); + } + } + + private static String combinePath(String clientEndpointPath, String requestPath, String resolvedUriPath) { + String requestPathWithClientPathRemoved = StringUtils.replaceOnce(requestPath, clientEndpointPath, ""); + return SdkHttpUtils.appendUri(resolvedUriPath, requestPathWithClientPathRemoved); + } + + private static void reapplyInterceptorModifiedAuthProperties(ExecutionAttributes attrs) { + SelectedAuthScheme currentScheme = attrs.getAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME); + if (currentScheme == null) { + return; + } + SelectedAuthScheme beforeInterceptors = + attrs.getAttribute(SdkInternalExecutionAttribute.AUTH_SCHEME_SNAPSHOT_PRE_INTERCEPTORS); + + SelectedAuthScheme afterInterceptors = + attrs.getAttribute(SdkInternalExecutionAttribute.AUTH_SCHEME_SNAPSHOT_POST_INTERCEPTORS); + if (afterInterceptors == null) { + return; + } + + AuthSchemeResolver.applyInterceptorModifiedProperties(currentScheme, beforeInterceptors, afterInterceptors, attrs); + } +} diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/spi/identity/AuthSchemeOptionsResolver.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/spi/identity/AuthSchemeOptionsResolver.java new file mode 100644 index 000000000000..7a37d7c3dd6b --- /dev/null +++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/spi/identity/AuthSchemeOptionsResolver.java @@ -0,0 +1,39 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.core.spi.identity; + +import java.util.List; +import software.amazon.awssdk.annotations.SdkProtectedApi; +import software.amazon.awssdk.core.SdkRequest; +import software.amazon.awssdk.http.auth.spi.scheme.AuthSchemeOption; + +/** + * Callback interface for resolving auth scheme options from the request. + *

+ * This allows auth scheme resolution to happen after interceptors have modified the request, + * ensuring that any request modifications affecting auth scheme selection are respected. + */ +@FunctionalInterface +@SdkProtectedApi +public interface AuthSchemeOptionsResolver { + /** + * Resolves auth scheme options for the given request. + * + * @param request The request (after interceptors have modified it) + * @return List of auth scheme options in priority order + */ + List resolve(SdkRequest request); +} diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/spi/identity/IdentityProviderUpdater.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/spi/identity/IdentityProviderUpdater.java new file mode 100644 index 000000000000..ac5053b4c33a --- /dev/null +++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/spi/identity/IdentityProviderUpdater.java @@ -0,0 +1,42 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.core.spi.identity; + +import software.amazon.awssdk.annotations.SdkProtectedApi; +import software.amazon.awssdk.core.SdkRequest; +import software.amazon.awssdk.core.interceptor.ExecutionAttributes; +import software.amazon.awssdk.identity.spi.IdentityProviders; + +/** + * Callback interface for updating identity providers based on request-level overrides. + *

+ * This allows aws-core to provide AWS-specific logic for reading credential overrides + * from {@code AwsRequestOverrideConfiguration} without sdk-core depending on aws-core. + */ +@FunctionalInterface +@SdkProtectedApi +public interface IdentityProviderUpdater { + /** + * Updates identity providers by applying request-level credential overrides or + * credentials set via {@code AwsSignerExecutionAttribute.AWS_CREDENTIALS} by interceptors. + * + * @param request The request (after interceptors have modified it) + * @param base The base identity providers from client configuration + * @param executionAttributes The execution attributes, checked for interceptor-set AWS_CREDENTIALS + * @return Updated identity providers, or base if no overrides apply + */ + IdentityProviders update(SdkRequest request, IdentityProviders base, ExecutionAttributes executionAttributes); +} diff --git a/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/http/auth/AuthSchemeResolverTest.java b/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/http/auth/AuthSchemeResolverTest.java new file mode 100644 index 000000000000..a3afe853d0e8 --- /dev/null +++ b/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/http/auth/AuthSchemeResolverTest.java @@ -0,0 +1,190 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.core.internal.http.auth; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.CompletableFuture; +import org.junit.jupiter.api.Test; +import software.amazon.awssdk.core.SelectedAuthScheme; +import software.amazon.awssdk.core.exception.SdkException; +import software.amazon.awssdk.core.http.auth.AuthSchemeResolver; +import software.amazon.awssdk.core.interceptor.ExecutionAttributes; +import software.amazon.awssdk.core.interceptor.SdkInternalExecutionAttribute; +import software.amazon.awssdk.http.auth.spi.scheme.AuthScheme; +import software.amazon.awssdk.http.auth.spi.scheme.AuthSchemeOption; +import software.amazon.awssdk.http.auth.spi.signer.HttpSigner; +import software.amazon.awssdk.identity.spi.Identity; +import software.amazon.awssdk.identity.spi.IdentityProvider; +import software.amazon.awssdk.identity.spi.IdentityProviders; +import software.amazon.awssdk.identity.spi.ResolveIdentityRequest; + +class AuthSchemeResolverTest { + + private static final String SCHEME_A = "schemeA"; + private static final String SCHEME_B = "schemeB"; + + @Test + void selectAuthScheme_firstOptionSucceeds_returnsFirstScheme() { + AuthScheme schemeA = createMockAuthScheme(); + Map> authSchemes = new HashMap<>(); + authSchemes.put(SCHEME_A, schemeA); + + List options = Collections.singletonList( + AuthSchemeOption.builder().schemeId(SCHEME_A).build() + ); + + SelectedAuthScheme result = AuthSchemeResolver.selectAuthScheme( + options, authSchemes, mock(IdentityProviders.class), null); + + assertThat(result.authSchemeOption().schemeId()).isEqualTo(SCHEME_A); + } + + @Test + void selectAuthScheme_firstOptionNoScheme_fallsBackToSecond() { + AuthScheme schemeB = createMockAuthScheme(); + Map> authSchemes = new HashMap<>(); + authSchemes.put(SCHEME_B, schemeB); + + List options = Arrays.asList( + AuthSchemeOption.builder().schemeId(SCHEME_A).build(), + AuthSchemeOption.builder().schemeId(SCHEME_B).build() + ); + + SelectedAuthScheme result = AuthSchemeResolver.selectAuthScheme( + options, authSchemes, mock(IdentityProviders.class), null); + + assertThat(result.authSchemeOption().schemeId()).isEqualTo(SCHEME_B); + } + + @Test + void selectAuthScheme_firstOptionNoIdentityProvider_fallsBackToSecond() { + AuthScheme schemeA = createMockAuthScheme(); + when(schemeA.identityProvider(any())).thenReturn(null); + + AuthScheme schemeB = createMockAuthScheme(); + + Map> authSchemes = new HashMap<>(); + authSchemes.put(SCHEME_A, schemeA); + authSchemes.put(SCHEME_B, schemeB); + + List options = Arrays.asList( + AuthSchemeOption.builder().schemeId(SCHEME_A).build(), + AuthSchemeOption.builder().schemeId(SCHEME_B).build() + ); + + SelectedAuthScheme result = AuthSchemeResolver.selectAuthScheme( + options, authSchemes, mock(IdentityProviders.class), null); + + assertThat(result.authSchemeOption().schemeId()).isEqualTo(SCHEME_B); + } + + @Test + void selectAuthScheme_signerThrows_fallsBackToSecond() { + AuthScheme schemeA = createMockAuthScheme(); + when(schemeA.signer()).thenThrow(new RuntimeException("Signer not available")); + + AuthScheme schemeB = createMockAuthScheme(); + + Map> authSchemes = new HashMap<>(); + authSchemes.put(SCHEME_A, schemeA); + authSchemes.put(SCHEME_B, schemeB); + + List options = Arrays.asList( + AuthSchemeOption.builder().schemeId(SCHEME_A).build(), + AuthSchemeOption.builder().schemeId(SCHEME_B).build() + ); + + SelectedAuthScheme result = AuthSchemeResolver.selectAuthScheme( + options, authSchemes, mock(IdentityProviders.class), null); + + assertThat(result.authSchemeOption().schemeId()).isEqualTo(SCHEME_B); + } + + @Test + void selectAuthScheme_allOptionsFail_throwsException() { + Map> authSchemes = new HashMap<>(); + + List options = Collections.singletonList( + AuthSchemeOption.builder().schemeId(SCHEME_A).build() + ); + + assertThatThrownBy(() -> AuthSchemeResolver.selectAuthScheme( + options, authSchemes, mock(IdentityProviders.class), null)) + .isInstanceOf(SdkException.class) + .hasMessageContaining("Failed to determine how to authenticate"); + } + + @Test + void mergeProperties_noExistingScheme_returnsOriginal() { + SelectedAuthScheme selected = createSelectedAuthScheme(SCHEME_A); + ExecutionAttributes attributes = new ExecutionAttributes(); + + SelectedAuthScheme result = AuthSchemeResolver.mergePreExistingAuthSchemeProperties( + selected, attributes); + + assertThat(result).isSameAs(selected); + } + + @Test + @SuppressWarnings("unchecked") + void mergeProperties_withExistingScheme_returnsNewInstance() { + SelectedAuthScheme selected = createSelectedAuthScheme(SCHEME_A); + SelectedAuthScheme existing = createSelectedAuthScheme(SCHEME_B); + + ExecutionAttributes attributes = new ExecutionAttributes(); + attributes.putAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME, existing); + + SelectedAuthScheme result = AuthSchemeResolver.mergePreExistingAuthSchemeProperties( + selected, attributes); + + assertThat(result).isNotSameAs(selected); + assertThat(result.authSchemeOption().schemeId()).isEqualTo(SCHEME_A); + } + + @SuppressWarnings("unchecked") + private AuthScheme createMockAuthScheme() { + AuthScheme scheme = mock(AuthScheme.class); + IdentityProvider identityProvider = mock(IdentityProvider.class); + Identity mockIdentity = mock(Identity.class); + doReturn(CompletableFuture.completedFuture(mockIdentity)) + .when(identityProvider).resolveIdentity(any(ResolveIdentityRequest.class)); + when(scheme.identityProvider(any())).thenReturn(identityProvider); + when(scheme.signer()).thenReturn(mock(HttpSigner.class)); + return scheme; + } + + @SuppressWarnings("unchecked") + private SelectedAuthScheme createSelectedAuthScheme(String schemeId) { + Identity mockIdentity = mock(Identity.class); + HttpSigner mockSigner = mock(HttpSigner.class); + return new SelectedAuthScheme<>( + CompletableFuture.completedFuture(mockIdentity), + mockSigner, + AuthSchemeOption.builder().schemeId(schemeId).build() + ); + } +} diff --git a/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/http/pipeline/stages/AuthSchemeResolutionStageTest.java b/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/http/pipeline/stages/AuthSchemeResolutionStageTest.java new file mode 100644 index 000000000000..177ddc9823ef --- /dev/null +++ b/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/http/pipeline/stages/AuthSchemeResolutionStageTest.java @@ -0,0 +1,274 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.core.internal.http.pipeline.stages; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; + +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.CompletableFuture; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import software.amazon.awssdk.core.SdkRequest; +import software.amazon.awssdk.core.SelectedAuthScheme; +import software.amazon.awssdk.core.http.ExecutionContext; +import software.amazon.awssdk.core.interceptor.ExecutionAttributes; +import software.amazon.awssdk.core.interceptor.InterceptorContext; +import software.amazon.awssdk.core.interceptor.SdkInternalExecutionAttribute; +import software.amazon.awssdk.core.internal.http.RequestExecutionContext; +import software.amazon.awssdk.core.spi.identity.AuthSchemeOptionsResolver; +import software.amazon.awssdk.core.spi.identity.IdentityProviderUpdater; +import software.amazon.awssdk.http.SdkHttpFullRequest; +import software.amazon.awssdk.http.auth.spi.scheme.AuthScheme; +import software.amazon.awssdk.http.auth.spi.scheme.AuthSchemeOption; +import software.amazon.awssdk.http.auth.spi.signer.HttpSigner; +import software.amazon.awssdk.identity.spi.Identity; +import software.amazon.awssdk.identity.spi.IdentityProvider; +import software.amazon.awssdk.identity.spi.IdentityProviders; +import software.amazon.awssdk.identity.spi.ResolveIdentityRequest; + +class AuthSchemeResolutionStageTest { + + private static final String SCHEME_ID = "test.scheme"; + + private AuthSchemeResolutionStage stage; + private SdkHttpFullRequest.Builder httpRequestBuilder; + private RequestExecutionContext context; + private ExecutionAttributes executionAttributes; + private SdkRequest sdkRequest; + + @BeforeEach + void setup() { + stage = new AuthSchemeResolutionStage(null); + httpRequestBuilder = mock(SdkHttpFullRequest.Builder.class); + sdkRequest = mock(SdkRequest.class); + executionAttributes = new ExecutionAttributes(); + + InterceptorContext interceptorContext = InterceptorContext.builder() + .request(sdkRequest) + .build(); + ExecutionContext executionContext = ExecutionContext.builder() + .interceptorContext(interceptorContext) + .executionAttributes(executionAttributes) + .build(); + context = RequestExecutionContext.builder() + .executionContext(executionContext) + .originalRequest(sdkRequest) + .build(); + } + + @Test + void execute_noAuthSchemes_returnsRequestUnchanged() throws Exception { + // AUTH_SCHEMES is null + SdkHttpFullRequest.Builder result = stage.execute(httpRequestBuilder, context); + + assertThat(result).isSameAs(httpRequestBuilder); + assertThat(executionAttributes.getAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME)).isNull(); + } + + @Test + void execute_noResolver_returnsRequestUnchanged() throws Exception { + executionAttributes.putAttribute(SdkInternalExecutionAttribute.AUTH_SCHEMES, createAuthSchemes()); + // AUTH_SCHEME_OPTIONS_RESOLVER is null + + SdkHttpFullRequest.Builder result = stage.execute(httpRequestBuilder, context); + + assertThat(result).isSameAs(httpRequestBuilder); + assertThat(executionAttributes.getAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME)).isNull(); + } + + @Test + void execute_resolverReturnsEmpty_returnsRequestUnchanged() throws Exception { + executionAttributes.putAttribute(SdkInternalExecutionAttribute.AUTH_SCHEMES, createAuthSchemes()); + executionAttributes.putAttribute(SdkInternalExecutionAttribute.AUTH_SCHEME_OPTIONS_RESOLVER, + (AuthSchemeOptionsResolver) req -> Collections.emptyList()); + + SdkHttpFullRequest.Builder result = stage.execute(httpRequestBuilder, context); + + assertThat(result).isSameAs(httpRequestBuilder); + assertThat(executionAttributes.getAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME)).isNull(); + } + + @Test + void execute_resolverReceivesRequestFromInterceptorContext() throws Exception { + SdkRequest modifiedRequest = mock(SdkRequest.class); + + // Setup interceptor context with a DIFFERENT request than originalRequest + InterceptorContext interceptorContext = InterceptorContext.builder() + .request(modifiedRequest) + .build(); + ExecutionContext executionContext = ExecutionContext.builder() + .interceptorContext(interceptorContext) + .executionAttributes(executionAttributes) + .build(); + context = RequestExecutionContext.builder() + .executionContext(executionContext) + .originalRequest(sdkRequest) // Different from modifiedRequest + .build(); + + AuthSchemeOptionsResolver resolver = mock(AuthSchemeOptionsResolver.class); + doReturn(createAuthOptions()).when(resolver).resolve(modifiedRequest); + + executionAttributes.putAttribute(SdkInternalExecutionAttribute.AUTH_SCHEMES, createAuthSchemes()); + executionAttributes.putAttribute(SdkInternalExecutionAttribute.AUTH_SCHEME_OPTIONS_RESOLVER, resolver); + executionAttributes.putAttribute(SdkInternalExecutionAttribute.IDENTITY_PROVIDERS, createIdentityProviders()); + + stage.execute(httpRequestBuilder, context); + + // Verify resolver was called with the MODIFIED request, not originalRequest + verify(resolver).resolve(modifiedRequest); + } + + @Test + void execute_withIdentityProviderUpdater_callsUpdaterWithRequest() throws Exception { + // Create mocks first before any stubbing + IdentityProvider identityProvider = createMockIdentityProvider(); + Map> authSchemes = createAuthSchemes(); + IdentityProviders baseProviders = mock(IdentityProviders.class); + IdentityProviders updatedProviders = mock(IdentityProviders.class); + + IdentityProviderUpdater updater = mock(IdentityProviderUpdater.class); + doReturn(updatedProviders).when(updater).update(sdkRequest, baseProviders, executionAttributes); + + // Setup so that auth scheme uses the updated providers + @SuppressWarnings("unchecked") + AuthScheme scheme = (AuthScheme) authSchemes.get(SCHEME_ID); + doReturn(identityProvider).when(scheme).identityProvider(updatedProviders); + + executionAttributes.putAttribute(SdkInternalExecutionAttribute.AUTH_SCHEMES, authSchemes); + executionAttributes.putAttribute(SdkInternalExecutionAttribute.AUTH_SCHEME_OPTIONS_RESOLVER, + (AuthSchemeOptionsResolver) req -> createAuthOptions()); + executionAttributes.putAttribute(SdkInternalExecutionAttribute.IDENTITY_PROVIDERS, baseProviders); + executionAttributes.putAttribute(SdkInternalExecutionAttribute.IDENTITY_PROVIDER_UPDATER, updater); + + stage.execute(httpRequestBuilder, context); + + verify(updater).update(sdkRequest, baseProviders, executionAttributes); + } + + @Test + void execute_withoutIdentityProviderUpdater_doesNotFail() throws Exception { + executionAttributes.putAttribute(SdkInternalExecutionAttribute.AUTH_SCHEMES, createAuthSchemes()); + executionAttributes.putAttribute(SdkInternalExecutionAttribute.AUTH_SCHEME_OPTIONS_RESOLVER, + (AuthSchemeOptionsResolver) req -> createAuthOptions()); + executionAttributes.putAttribute(SdkInternalExecutionAttribute.IDENTITY_PROVIDERS, createIdentityProviders()); + // No IDENTITY_PROVIDER_UPDATER set + + SdkHttpFullRequest.Builder result = stage.execute(httpRequestBuilder, context); + + assertThat(result).isSameAs(httpRequestBuilder); + assertThat(executionAttributes.getAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME)).isNotNull(); + } + + @Test + void execute_happyPath_setsSelectedAuthScheme() throws Exception { + executionAttributes.putAttribute(SdkInternalExecutionAttribute.AUTH_SCHEMES, createAuthSchemes()); + executionAttributes.putAttribute(SdkInternalExecutionAttribute.AUTH_SCHEME_OPTIONS_RESOLVER, + (AuthSchemeOptionsResolver) req -> createAuthOptions()); + executionAttributes.putAttribute(SdkInternalExecutionAttribute.IDENTITY_PROVIDERS, createIdentityProviders()); + + SdkHttpFullRequest.Builder result = stage.execute(httpRequestBuilder, context); + + assertThat(result).isSameAs(httpRequestBuilder); + SelectedAuthScheme selectedAuthScheme = + executionAttributes.getAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME); + assertThat(selectedAuthScheme).isNotNull(); + assertThat(selectedAuthScheme.authSchemeOption().schemeId()).isEqualTo(SCHEME_ID); + } + + @SuppressWarnings("unchecked") + private Map> createAuthSchemes() { + IdentityProvider identityProvider = createMockIdentityProvider(); + HttpSigner signer = mock(HttpSigner.class); + + AuthScheme scheme = mock(AuthScheme.class); + doReturn(identityProvider).when(scheme).identityProvider(any()); + doReturn(signer).when(scheme).signer(); + + Map> schemes = new HashMap<>(); + schemes.put(SCHEME_ID, scheme); + return schemes; + } + + @SuppressWarnings("unchecked") + private IdentityProvider createMockIdentityProvider() { + IdentityProvider provider = mock(IdentityProvider.class); + Identity mockIdentity = mock(Identity.class); + doReturn(CompletableFuture.completedFuture(mockIdentity)) + .when(provider).resolveIdentity(any(ResolveIdentityRequest.class)); + return provider; + } + + private IdentityProviders createIdentityProviders() { + IdentityProviders providers = mock(IdentityProviders.class); + return providers; + } + + private List createAuthOptions() { + return Collections.singletonList( + AuthSchemeOption.builder().schemeId(SCHEME_ID).build() + ); + } + + @Test + void execute_authSchemeAlreadyResolved_skipsResolution() throws Exception { + // Simulate old service interceptor already resolved auth scheme + SelectedAuthScheme alreadyResolved = new SelectedAuthScheme<>( + CompletableFuture.completedFuture(mock(Identity.class)), + mock(HttpSigner.class), + AuthSchemeOption.builder().schemeId("aws.auth#sigv4").build() + ); + executionAttributes.putAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME, alreadyResolved); + executionAttributes.putAttribute(SdkInternalExecutionAttribute.AUTH_SCHEMES, createAuthSchemes()); + executionAttributes.putAttribute(SdkInternalExecutionAttribute.AUTH_SCHEME_OPTIONS_RESOLVER, + (AuthSchemeOptionsResolver) req -> createAuthOptions()); + executionAttributes.putAttribute(SdkInternalExecutionAttribute.IDENTITY_PROVIDERS, createIdentityProviders()); + + SdkHttpFullRequest.Builder result = stage.execute(httpRequestBuilder, context); + + assertThat(result).isSameAs(httpRequestBuilder); + // Auth scheme should still be the one set by the old interceptor (not re-resolved) + SelectedAuthScheme finalScheme = executionAttributes.getAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME); + assertThat(finalScheme.authSchemeOption().schemeId()).isEqualTo("aws.auth#sigv4"); + } + + @Test + void execute_authSchemeUnset_proceedsWithResolution() throws Exception { + // "unset" placeholder — pipeline should resolve + SelectedAuthScheme unsetPlaceholder = new SelectedAuthScheme<>( + CompletableFuture.completedFuture(mock(Identity.class)), + mock(HttpSigner.class), + AuthSchemeOption.builder().schemeId("unset").build() + ); + executionAttributes.putAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME, unsetPlaceholder); + executionAttributes.putAttribute(SdkInternalExecutionAttribute.AUTH_SCHEMES, createAuthSchemes()); + executionAttributes.putAttribute(SdkInternalExecutionAttribute.AUTH_SCHEME_OPTIONS_RESOLVER, + (AuthSchemeOptionsResolver) req -> createAuthOptions()); + executionAttributes.putAttribute(SdkInternalExecutionAttribute.IDENTITY_PROVIDERS, createIdentityProviders()); + + stage.execute(httpRequestBuilder, context); + + // Should have resolved to the test scheme + SelectedAuthScheme finalScheme = executionAttributes.getAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME); + assertThat(finalScheme.authSchemeOption().schemeId()).isEqualTo(SCHEME_ID); + } +} diff --git a/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/http/pipeline/stages/EndpointResolutionStageTest.java b/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/http/pipeline/stages/EndpointResolutionStageTest.java new file mode 100644 index 000000000000..d13717389c5f --- /dev/null +++ b/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/http/pipeline/stages/EndpointResolutionStageTest.java @@ -0,0 +1,276 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.core.internal.http.pipeline.stages; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; + +import java.net.URI; +import java.time.Duration; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; +import software.amazon.awssdk.core.ClientEndpointProvider; +import software.amazon.awssdk.core.SdkRequest; +import software.amazon.awssdk.core.http.ExecutionContext; +import software.amazon.awssdk.core.interceptor.ExecutionAttributes; +import software.amazon.awssdk.core.interceptor.InterceptorContext; +import software.amazon.awssdk.core.interceptor.SdkExecutionAttribute; +import software.amazon.awssdk.core.interceptor.SdkInternalExecutionAttribute; +import software.amazon.awssdk.core.internal.http.RequestExecutionContext; +import software.amazon.awssdk.core.metrics.CoreMetric; +import software.amazon.awssdk.endpoints.Endpoint; +import software.amazon.awssdk.http.SdkHttpFullRequest; +import software.amazon.awssdk.http.SdkHttpMethod; +import software.amazon.awssdk.metrics.MetricCollector; + +class EndpointResolutionStageTest { + + private static final URI CLIENT_ENDPOINT = URI.create("https://myservice.us-east-1.amazonaws.com"); + private static final URI RESOLVED_ENDPOINT = URI.create("https://resolved.us-west-2.amazonaws.com"); + + private EndpointResolutionStage stage; + private ExecutionAttributes executionAttributes; + private SdkRequest sdkRequest; + + @BeforeEach + void setup() { + stage = new EndpointResolutionStage(null); + sdkRequest = mock(SdkRequest.class); + executionAttributes = new ExecutionAttributes(); + } + + @Test + void execute_noResolver_returnsRequestUnchanged() throws Exception { + SdkHttpFullRequest.Builder request = defaultRequest(); + RequestExecutionContext context = createContext(); + + SdkHttpFullRequest.Builder result = stage.execute(request, context); + + assertThat(result).isSameAs(request); + } + + @Test + void execute_discoveredEndpoint_returnsRequestUnchanged() throws Exception { + SdkHttpFullRequest.Builder request = defaultRequest(); + executionAttributes.putAttribute(SdkInternalExecutionAttribute.IS_DISCOVERED_ENDPOINT, true); + executionAttributes.putAttribute(SdkInternalExecutionAttribute.ENDPOINT_RESOLVER, + (req, attrs) -> Endpoint.builder().url(RESOLVED_ENDPOINT).build()); + RequestExecutionContext context = createContext(); + + SdkHttpFullRequest.Builder result = stage.execute(request, context); + + assertThat(result).isSameAs(request); + } + + @Test + void execute_happyPath_appliesResolvedUrl() throws Exception { + SdkHttpFullRequest.Builder request = defaultRequest(); + Endpoint endpoint = Endpoint.builder().url(RESOLVED_ENDPOINT).build(); + executionAttributes.putAttribute(SdkInternalExecutionAttribute.ENDPOINT_RESOLVER, (req, attrs) -> endpoint); + executionAttributes.putAttribute(SdkInternalExecutionAttribute.CLIENT_ENDPOINT_PROVIDER, + ClientEndpointProvider.forEndpointOverride(CLIENT_ENDPOINT)); + RequestExecutionContext context = createContext(); + + SdkHttpFullRequest.Builder result = stage.execute(request, context); + + assertThat(result.host()).isEqualTo("resolved.us-west-2.amazonaws.com"); + assertThat(result.protocol()).isEqualTo("https"); + } + + @Test + void execute_happyPath_storesResolvedEndpoint() throws Exception { + SdkHttpFullRequest.Builder request = defaultRequest(); + Endpoint endpoint = Endpoint.builder().url(RESOLVED_ENDPOINT).build(); + executionAttributes.putAttribute(SdkInternalExecutionAttribute.ENDPOINT_RESOLVER, (req, attrs) -> endpoint); + executionAttributes.putAttribute(SdkInternalExecutionAttribute.CLIENT_ENDPOINT_PROVIDER, + ClientEndpointProvider.forEndpointOverride(CLIENT_ENDPOINT)); + RequestExecutionContext context = createContext(); + + stage.execute(request, context); + + assertThat(executionAttributes.getAttribute(SdkInternalExecutionAttribute.RESOLVED_ENDPOINT)).isSameAs(endpoint); + } + + @Test + void execute_withHeaders_appendsToRequest() throws Exception { + Endpoint endpoint = Endpoint.builder().url(RESOLVED_ENDPOINT) + .putHeader("x-amz-custom", "val1") + .putHeader("x-amz-custom", "val2") + .build(); + + SdkHttpFullRequest.Builder request = defaultRequest(); + executionAttributes.putAttribute(SdkInternalExecutionAttribute.ENDPOINT_RESOLVER, (req, attrs) -> endpoint); + executionAttributes.putAttribute(SdkInternalExecutionAttribute.CLIENT_ENDPOINT_PROVIDER, + ClientEndpointProvider.forEndpointOverride(CLIENT_ENDPOINT)); + RequestExecutionContext context = createContext(); + + SdkHttpFullRequest.Builder result = stage.execute(request, context); + + assertThat(result.build().matchingHeaders("x-amz-custom")).containsExactly("val1", "val2"); + } + + @Test + void execute_withMetricCollector_reportsEndpointResolveDuration() throws Exception { + Endpoint endpoint = Endpoint.builder().url(RESOLVED_ENDPOINT).build(); + MetricCollector metricCollector = mock(MetricCollector.class); + + SdkHttpFullRequest.Builder request = defaultRequest(); + executionAttributes.putAttribute(SdkInternalExecutionAttribute.ENDPOINT_RESOLVER, (req, attrs) -> endpoint); + executionAttributes.putAttribute(SdkInternalExecutionAttribute.CLIENT_ENDPOINT_PROVIDER, + ClientEndpointProvider.forEndpointOverride(CLIENT_ENDPOINT)); + executionAttributes.putAttribute(SdkExecutionAttribute.API_CALL_METRIC_COLLECTOR, metricCollector); + RequestExecutionContext context = createContext(); + + stage.execute(request, context); + + ArgumentCaptor durationCaptor = ArgumentCaptor.forClass(Duration.class); + verify(metricCollector).reportMetric(org.mockito.ArgumentMatchers.eq(CoreMetric.ENDPOINT_RESOLVE_DURATION), + durationCaptor.capture()); + assertThat(durationCaptor.getValue()).isGreaterThanOrEqualTo(Duration.ZERO); + } + + @Test + void execute_resolvedPathDiffersFromClientPath_combinesPaths() throws Exception { + URI clientEndpoint = URI.create("https://myservice.amazonaws.com"); + URI resolvedUri = URI.create("https://resolved.amazonaws.com/v2"); + + Endpoint endpoint = Endpoint.builder().url(resolvedUri).build(); + SdkHttpFullRequest.Builder request = SdkHttpFullRequest.builder() + .method(SdkHttpMethod.GET) + .protocol("https") + .host("myservice.amazonaws.com") + .encodedPath("/my/operation"); + executionAttributes.putAttribute(SdkInternalExecutionAttribute.ENDPOINT_RESOLVER, (req, attrs) -> endpoint); + executionAttributes.putAttribute(SdkInternalExecutionAttribute.CLIENT_ENDPOINT_PROVIDER, + ClientEndpointProvider.forEndpointOverride(clientEndpoint)); + RequestExecutionContext context = createContext(); + + SdkHttpFullRequest.Builder result = stage.execute(request, context); + + assertThat(result.encodedPath()).isEqualTo("/v2/my/operation"); + } + + @Test + void execute_resolverReceivesRequestFromInterceptorContext() throws Exception { + SdkRequest modifiedRequest = mock(SdkRequest.class); + SdkRequest[] capturedRequest = new SdkRequest[1]; + + Endpoint endpoint = Endpoint.builder().url(RESOLVED_ENDPOINT).build(); + SdkHttpFullRequest.Builder request = defaultRequest(); + executionAttributes.putAttribute(SdkInternalExecutionAttribute.ENDPOINT_RESOLVER, (req, attrs) -> { + capturedRequest[0] = req; + return endpoint; + }); + executionAttributes.putAttribute(SdkInternalExecutionAttribute.CLIENT_ENDPOINT_PROVIDER, + ClientEndpointProvider.forEndpointOverride(CLIENT_ENDPOINT)); + + // Use modifiedRequest in interceptor context (different from sdkRequest) + RequestExecutionContext context = createContext(modifiedRequest); + + stage.execute(request, context); + + assertThat(capturedRequest[0]).isSameAs(modifiedRequest); + } + + @Test + void execute_interceptorModifiedHost_preservesCustomHostAndScheme() throws Exception { + Endpoint endpoint = Endpoint.builder().url(URI.create("https://resolved.amazonaws.com")).build(); + SdkHttpFullRequest.Builder request = SdkHttpFullRequest.builder() + .method(SdkHttpMethod.GET) + .protocol("http") + .host("custom.example.com") + .port(8080) + .encodedPath("/my-operation"); + executionAttributes.putAttribute(SdkInternalExecutionAttribute.HTTP_REQUEST_URI_BEFORE_MODIFY, + URI.create("https://myservice.amazonaws.com/my-operation")); + executionAttributes.putAttribute(SdkInternalExecutionAttribute.ENDPOINT_RESOLVER, (req, attrs) -> endpoint); + executionAttributes.putAttribute(SdkInternalExecutionAttribute.CLIENT_ENDPOINT_PROVIDER, + ClientEndpointProvider.forEndpointOverride(CLIENT_ENDPOINT)); + RequestExecutionContext context = createContext(); + + SdkHttpFullRequest.Builder result = stage.execute(request, context); + + assertThat(result.host()).isEqualTo("custom.example.com"); + assertThat(result.protocol()).isEqualTo("http"); + assertThat(result.port()).isEqualTo(8080); + } + + @Test + void execute_interceptorModifiedHost_stillAppliesResolvedPath() throws Exception { + URI clientEndpoint = URI.create("https://s3.us-west-2.amazonaws.com"); + URI resolvedUri = URI.create("https://s3.us-west-2.amazonaws.com/my-bucket"); + + Endpoint endpoint = Endpoint.builder().url(resolvedUri).build(); + SdkHttpFullRequest.Builder request = SdkHttpFullRequest.builder() + .method(SdkHttpMethod.GET) + .protocol("https") + .host("example.com") + .encodedPath("/my-key"); + + executionAttributes.putAttribute(SdkInternalExecutionAttribute.HTTP_REQUEST_URI_BEFORE_MODIFY, + URI.create("https://s3.us-west-2.amazonaws.com/my-key")); + executionAttributes.putAttribute(SdkInternalExecutionAttribute.ENDPOINT_RESOLVER, (req, attrs) -> endpoint); + executionAttributes.putAttribute(SdkInternalExecutionAttribute.CLIENT_ENDPOINT_PROVIDER, + ClientEndpointProvider.forEndpointOverride(clientEndpoint)); + RequestExecutionContext context = createContext(); + + SdkHttpFullRequest.Builder result = stage.execute(request, context); + + assertThat(result.host()).isEqualTo("example.com"); + assertThat(result.encodedPath()).isEqualTo("/my-bucket/my-key"); + } + + private SdkHttpFullRequest.Builder defaultRequest() { + return SdkHttpFullRequest.builder() + .method(SdkHttpMethod.GET) + .protocol("https") + .host(CLIENT_ENDPOINT.getHost()) + .encodedPath("/"); + } + + private RequestExecutionContext createContext() { + return createContext(sdkRequest); + } + + private RequestExecutionContext createContext(SdkRequest request) { + InterceptorContext interceptorContext = InterceptorContext.builder() + .request(request) + .build(); + ExecutionContext executionContext = ExecutionContext.builder() + .interceptorContext(interceptorContext) + .executionAttributes(executionAttributes) + .build(); + return RequestExecutionContext.builder() + .executionContext(executionContext) + .originalRequest(request) + .build(); + } + + @Test + void execute_endpointAlreadyResolved_skipsResolution() throws Exception { + Endpoint preResolved = Endpoint.builder().url(URI.create("https://pre-resolved.example.com")).build(); + executionAttributes.putAttribute(SdkInternalExecutionAttribute.RESOLVED_ENDPOINT, preResolved); + + SdkHttpFullRequest.Builder request = defaultRequest(); + SdkHttpFullRequest.Builder result = stage.execute(request, createContext()); + + + assertThat(result).isSameAs(request); + assertThat(executionAttributes.getAttribute(SdkInternalExecutionAttribute.RESOLVED_ENDPOINT)).isSameAs(preResolved); + } +} diff --git a/services/docdb/src/main/java/software/amazon/awssdk/services/docdb/internal/RdsPresignInterceptor.java b/services/docdb/src/main/java/software/amazon/awssdk/services/docdb/internal/RdsPresignInterceptor.java index 55b0fd574e64..cd2c0d8c4de1 100644 --- a/services/docdb/src/main/java/software/amazon/awssdk/services/docdb/internal/RdsPresignInterceptor.java +++ b/services/docdb/src/main/java/software/amazon/awssdk/services/docdb/internal/RdsPresignInterceptor.java @@ -15,7 +15,6 @@ package software.amazon.awssdk.services.docdb.internal; -import static software.amazon.awssdk.core.interceptor.SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME; import java.net.URI; import java.time.Clock; @@ -32,6 +31,7 @@ import software.amazon.awssdk.core.SelectedAuthScheme; import software.amazon.awssdk.core.client.config.SdkClientConfiguration; import software.amazon.awssdk.core.client.config.SdkClientOption; +import software.amazon.awssdk.core.http.auth.AuthSchemeResolver; import software.amazon.awssdk.core.interceptor.Context; import software.amazon.awssdk.core.interceptor.ExecutionAttributes; import software.amazon.awssdk.core.interceptor.ExecutionInterceptor; @@ -106,7 +106,7 @@ public final SdkHttpRequest modifyHttpRequest(Context.ModifyHttpRequest context, return request.toBuilder().removeQueryParameter(PARAM_SOURCE_REGION).build(); } - SelectedAuthScheme selectedAuthScheme = executionAttributes.getAttribute(SELECTED_AUTH_SCHEME); + SelectedAuthScheme selectedAuthScheme = resolveAuthScheme(context.request(), executionAttributes); String sourceRegion = presignableRequest.getSourceRegion(); String destinationRegion = selectedAuthScheme.authSchemeOption().signerProperty(AwsV4HttpSigner.REGION_NAME); URI endpoint = createEndpoint(sourceRegion, SERVICE_NAME, executionAttributes); @@ -118,7 +118,7 @@ public final SdkHttpRequest modifyHttpRequest(Context.ModifyHttpRequest context, .removeQueryParameter(PARAM_SOURCE_REGION) .build(); - requestToPresign = sraPresignRequest(executionAttributes, requestToPresign, sourceRegion); + requestToPresign = sraPresignRequest(selectedAuthScheme, requestToPresign, sourceRegion); String presignedUrl = requestToPresign.getUri().toString(); @@ -129,6 +129,14 @@ public final SdkHttpRequest modifyHttpRequest(Context.ModifyHttpRequest context, .build(); } + /** + * Resolves the auth scheme from execution attributes, applying any request-level credential overrides. + */ + private SelectedAuthScheme resolveAuthScheme(SdkRequest request, + ExecutionAttributes executionAttributes) { + return AuthSchemeResolver.resolveAuthScheme(request, executionAttributes); + } + /** * Adapts the request to the {@link PresignableRequest}. * @@ -159,11 +167,8 @@ private PresignableRequest toPresignableRequest(SdkHttpRequest request, Context. /** * Presign the provided HTTP request using SRA HttpSigner */ - private SdkHttpFullRequest sraPresignRequest(ExecutionAttributes executionAttributes, SdkHttpFullRequest request, + private SdkHttpFullRequest sraPresignRequest(SelectedAuthScheme selectedAuthScheme, SdkHttpFullRequest request, String signingRegion) { - SelectedAuthScheme selectedAuthScheme = executionAttributes.getAttribute(SELECTED_AUTH_SCHEME); - - Instant signingInstant; if (signingClockOverride != null) { signingInstant = signingClockOverride.instant(); diff --git a/services/dynamodb/src/test/java/software/amazon/awssdk/services/dynamodb/PaginatorInUserAgentTest.java b/services/dynamodb/src/test/java/software/amazon/awssdk/services/dynamodb/PaginatorInUserAgentTest.java index b4bc3a504391..6ade0f58c699 100644 --- a/services/dynamodb/src/test/java/software/amazon/awssdk/services/dynamodb/PaginatorInUserAgentTest.java +++ b/services/dynamodb/src/test/java/software/amazon/awssdk/services/dynamodb/PaginatorInUserAgentTest.java @@ -39,7 +39,7 @@ import software.amazon.awssdk.core.useragent.BusinessMetricCollection; import software.amazon.awssdk.core.useragent.BusinessMetricFeatureId; import software.amazon.awssdk.regions.Region; -import software.amazon.awssdk.services.dynamodb.endpoints.internal.AwsEndpointProviderUtils; +import software.amazon.awssdk.awscore.endpoints.AwsEndpointProviderUtils; import software.amazon.awssdk.services.dynamodb.paginators.QueryPublisher; public class PaginatorInUserAgentTest { diff --git a/services/ec2/src/main/java/software/amazon/awssdk/services/ec2/transform/internal/GeneratePreSignUrlInterceptor.java b/services/ec2/src/main/java/software/amazon/awssdk/services/ec2/transform/internal/GeneratePreSignUrlInterceptor.java index a4ebd7265773..acfc4546e5ac 100644 --- a/services/ec2/src/main/java/software/amazon/awssdk/services/ec2/transform/internal/GeneratePreSignUrlInterceptor.java +++ b/services/ec2/src/main/java/software/amazon/awssdk/services/ec2/transform/internal/GeneratePreSignUrlInterceptor.java @@ -15,7 +15,6 @@ package software.amazon.awssdk.services.ec2.transform.internal; -import static software.amazon.awssdk.core.interceptor.SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME; import java.net.URI; import java.time.Clock; @@ -32,6 +31,7 @@ import software.amazon.awssdk.core.client.config.SdkClientConfiguration; import software.amazon.awssdk.core.client.config.SdkClientOption; import software.amazon.awssdk.core.exception.SdkClientException; +import software.amazon.awssdk.core.http.auth.AuthSchemeResolver; import software.amazon.awssdk.core.interceptor.Context; import software.amazon.awssdk.core.interceptor.ExecutionAttributes; import software.amazon.awssdk.core.interceptor.ExecutionInterceptor; @@ -131,7 +131,7 @@ public SdkHttpRequest modifyHttpRequest(Context.ModifyHttpRequest context, Execu .build(); URI presignedUrl = - sraPresignRequest(executionAttributes, requestForPresigning, sourceRegion); + sraPresignRequest(context.request(), executionAttributes, requestForPresigning, sourceRegion); return request.toBuilder() .putRawQueryParameter("DestinationRegion", destinationRegion) @@ -142,9 +142,14 @@ public SdkHttpRequest modifyHttpRequest(Context.ModifyHttpRequest context, Execu return request; } - private URI sraPresignRequest(ExecutionAttributes executionAttributes, SdkHttpFullRequest request, - String signingRegion) { - SelectedAuthScheme selectedAuthScheme = executionAttributes.getAttribute(SELECTED_AUTH_SCHEME); + private SelectedAuthScheme resolveAuthScheme(SdkRequest request, + ExecutionAttributes executionAttributes) { + return AuthSchemeResolver.resolveAuthScheme(request, executionAttributes); + } + + private URI sraPresignRequest(SdkRequest sdkRequest, ExecutionAttributes executionAttributes, + SdkHttpFullRequest request, String signingRegion) { + SelectedAuthScheme selectedAuthScheme = resolveAuthScheme(sdkRequest, executionAttributes); Instant signingInstant; if (testClock != null) { signingInstant = testClock.instant(); diff --git a/services/ec2/src/test/java/software/amazon/awssdk/services/ec2/transform/internal/GeneratePreSignUrlInterceptorTest.java b/services/ec2/src/test/java/software/amazon/awssdk/services/ec2/transform/internal/GeneratePreSignUrlInterceptorTest.java index f7e03816e6b9..3ae2c4dc5c6e 100644 --- a/services/ec2/src/test/java/software/amazon/awssdk/services/ec2/transform/internal/GeneratePreSignUrlInterceptorTest.java +++ b/services/ec2/src/test/java/software/amazon/awssdk/services/ec2/transform/internal/GeneratePreSignUrlInterceptorTest.java @@ -25,17 +25,24 @@ import java.time.Instant; import java.time.ZoneId; import java.time.ZonedDateTime; +import java.util.Collections; +import java.util.Map; import org.junit.Test; import org.junit.runner.RunWith; import org.mockito.Mock; import org.mockito.junit.MockitoJUnitRunner; import software.amazon.awssdk.auth.credentials.AwsBasicCredentials; +import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider; import software.amazon.awssdk.auth.signer.AwsSignerExecutionAttribute; import software.amazon.awssdk.core.interceptor.Context; import software.amazon.awssdk.core.interceptor.ExecutionAttributes; +import software.amazon.awssdk.core.interceptor.SdkInternalExecutionAttribute; import software.amazon.awssdk.http.SdkHttpFullRequest; import software.amazon.awssdk.http.SdkHttpMethod; import software.amazon.awssdk.http.SdkHttpRequest; +import software.amazon.awssdk.http.auth.aws.scheme.AwsV4AuthScheme; +import software.amazon.awssdk.http.auth.spi.scheme.AuthScheme; +import software.amazon.awssdk.identity.spi.IdentityProviders; import software.amazon.awssdk.services.ec2.model.CopySnapshotRequest; @RunWith(MockitoJUnitRunner.class) @@ -63,6 +70,7 @@ public void copySnapshotRequest_httpsProtocolAddedToEndpoint() { ExecutionAttributes attrs = new ExecutionAttributes(); attrs.putAttribute(AWS_CREDENTIALS, AwsBasicCredentials.create("foo", "bar")); attrs.putAttribute(AwsSignerExecutionAttribute.SERVICE_SIGNING_NAME, "ec2"); + addSraAttributes(attrs, AwsBasicCredentials.create("foo", "bar")); SdkHttpRequest modifiedRequest = INTERCEPTOR.modifyHttpRequest(mockContext, attrs); @@ -116,6 +124,7 @@ public void copySnapshotRequest_generatesCorrectPresignedUrl() { ExecutionAttributes attrs = new ExecutionAttributes(); attrs.putAttribute(AWS_CREDENTIALS, AwsBasicCredentials.create("akid", "skid")); attrs.putAttribute(AwsSignerExecutionAttribute.SERVICE_SIGNING_NAME, "ec2"); + addSraAttributes(attrs, AwsBasicCredentials.create("akid", "skid")); SdkHttpRequest modifiedRequest = interceptor.modifyHttpRequest(mockContext, attrs); @@ -123,4 +132,23 @@ public void copySnapshotRequest_generatesCorrectPresignedUrl() { assertThat(generatedPresignedUrl).isEqualTo(expectedPresignedUrl); } + + private static void addSraAttributes(ExecutionAttributes attrs, AwsBasicCredentials credentials) { + AwsV4AuthScheme authScheme = AwsV4AuthScheme.create(); + Map> authSchemes = Collections.singletonMap(authScheme.schemeId(), authScheme); + IdentityProviders identityProviders = IdentityProviders.builder() + .putIdentityProvider(StaticCredentialsProvider.create(credentials)) + .build(); + attrs.putAttribute(SdkInternalExecutionAttribute.AUTH_SCHEMES, authSchemes); + attrs.putAttribute(SdkInternalExecutionAttribute.IDENTITY_PROVIDERS, identityProviders); + attrs.putAttribute(SdkInternalExecutionAttribute.AUTH_SCHEME_OPTIONS_RESOLVER, + request -> Collections.singletonList( + software.amazon.awssdk.http.auth.spi.scheme.AuthSchemeOption.builder() + .schemeId(authScheme.schemeId()) + .putSignerProperty(software.amazon.awssdk.http.auth.aws.signer.AwsV4FamilyHttpSigner + .SERVICE_SIGNING_NAME, "ec2") + .putSignerProperty(software.amazon.awssdk.http.auth.aws.signer.AwsV4HttpSigner + .REGION_NAME, "us-west-2") + .build())); + } } diff --git a/services/neptune/src/main/java/software/amazon/awssdk/services/neptune/internal/RdsPresignInterceptor.java b/services/neptune/src/main/java/software/amazon/awssdk/services/neptune/internal/RdsPresignInterceptor.java index 34f8fe30247c..1ff7adf4221d 100644 --- a/services/neptune/src/main/java/software/amazon/awssdk/services/neptune/internal/RdsPresignInterceptor.java +++ b/services/neptune/src/main/java/software/amazon/awssdk/services/neptune/internal/RdsPresignInterceptor.java @@ -15,7 +15,6 @@ package software.amazon.awssdk.services.neptune.internal; -import static software.amazon.awssdk.core.interceptor.SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME; import java.net.URI; import java.time.Clock; @@ -32,6 +31,7 @@ import software.amazon.awssdk.core.SelectedAuthScheme; import software.amazon.awssdk.core.client.config.SdkClientConfiguration; import software.amazon.awssdk.core.client.config.SdkClientOption; +import software.amazon.awssdk.core.http.auth.AuthSchemeResolver; import software.amazon.awssdk.core.interceptor.Context; import software.amazon.awssdk.core.interceptor.ExecutionAttributes; import software.amazon.awssdk.core.interceptor.ExecutionInterceptor; @@ -51,7 +51,6 @@ import software.amazon.awssdk.services.neptune.model.NeptuneRequest; import software.amazon.awssdk.utils.CompletableFutureUtils; - /** * Abstract pre-sign handler that follows the pre-signing scheme outlined in the 'RDS Presigned URL for Cross-Region Copying' * SEP. @@ -107,7 +106,7 @@ public final SdkHttpRequest modifyHttpRequest(Context.ModifyHttpRequest context, return request.toBuilder().removeQueryParameter(PARAM_SOURCE_REGION).build(); } - SelectedAuthScheme selectedAuthScheme = executionAttributes.getAttribute(SELECTED_AUTH_SCHEME); + SelectedAuthScheme selectedAuthScheme = resolveAuthScheme(context.request(), executionAttributes); String sourceRegion = presignableRequest.getSourceRegion(); String destinationRegion = selectedAuthScheme.authSchemeOption().signerProperty(AwsV4HttpSigner.REGION_NAME); URI endpoint = createEndpoint(sourceRegion, SERVICE_NAME, executionAttributes); @@ -119,7 +118,7 @@ public final SdkHttpRequest modifyHttpRequest(Context.ModifyHttpRequest context, .removeQueryParameter(PARAM_SOURCE_REGION) .build(); - requestToPresign = sraPresignRequest(executionAttributes, requestToPresign, sourceRegion); + requestToPresign = sraPresignRequest(selectedAuthScheme, requestToPresign, sourceRegion); String presignedUrl = requestToPresign.getUri().toString(); @@ -130,6 +129,14 @@ public final SdkHttpRequest modifyHttpRequest(Context.ModifyHttpRequest context, .build(); } + /** + * Resolves the auth scheme from execution attributes, applying any request-level credential overrides. + */ + private SelectedAuthScheme resolveAuthScheme(SdkRequest request, + ExecutionAttributes executionAttributes) { + return AuthSchemeResolver.resolveAuthScheme(request, executionAttributes); + } + /** * Adapts the request to the {@link PresignableRequest}. * @@ -160,11 +167,8 @@ private PresignableRequest toPresignableRequest(SdkHttpRequest request, Context. /** * Presign the provided HTTP request using SRA HttpSigner */ - private SdkHttpFullRequest sraPresignRequest(ExecutionAttributes executionAttributes, SdkHttpFullRequest request, + private SdkHttpFullRequest sraPresignRequest(SelectedAuthScheme selectedAuthScheme, SdkHttpFullRequest request, String signingRegion) { - SelectedAuthScheme selectedAuthScheme = executionAttributes.getAttribute(SELECTED_AUTH_SCHEME); - - Instant signingInstant; if (signingClockOverride != null) { signingInstant = signingClockOverride.instant(); diff --git a/services/rds/src/main/java/software/amazon/awssdk/services/rds/internal/RdsPresignInterceptor.java b/services/rds/src/main/java/software/amazon/awssdk/services/rds/internal/RdsPresignInterceptor.java index 3e55f5ae9f24..f3b9cb723277 100644 --- a/services/rds/src/main/java/software/amazon/awssdk/services/rds/internal/RdsPresignInterceptor.java +++ b/services/rds/src/main/java/software/amazon/awssdk/services/rds/internal/RdsPresignInterceptor.java @@ -15,7 +15,6 @@ package software.amazon.awssdk.services.rds.internal; -import static software.amazon.awssdk.core.interceptor.SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME; import java.net.URI; import java.time.Clock; @@ -32,6 +31,7 @@ import software.amazon.awssdk.core.SelectedAuthScheme; import software.amazon.awssdk.core.client.config.SdkClientConfiguration; import software.amazon.awssdk.core.client.config.SdkClientOption; +import software.amazon.awssdk.core.http.auth.AuthSchemeResolver; import software.amazon.awssdk.core.interceptor.Context; import software.amazon.awssdk.core.interceptor.ExecutionAttributes; import software.amazon.awssdk.core.interceptor.ExecutionInterceptor; @@ -106,7 +106,7 @@ public final SdkHttpRequest modifyHttpRequest(Context.ModifyHttpRequest context, return request.toBuilder().removeQueryParameter(PARAM_SOURCE_REGION).build(); } - SelectedAuthScheme selectedAuthScheme = executionAttributes.getAttribute(SELECTED_AUTH_SCHEME); + SelectedAuthScheme selectedAuthScheme = resolveAuthScheme(context.request(), executionAttributes); String sourceRegion = presignableRequest.getSourceRegion(); String destinationRegion = selectedAuthScheme.authSchemeOption().signerProperty(AwsV4HttpSigner.REGION_NAME); URI endpoint = createEndpoint(sourceRegion, SERVICE_NAME, executionAttributes); @@ -118,7 +118,7 @@ public final SdkHttpRequest modifyHttpRequest(Context.ModifyHttpRequest context, .removeQueryParameter(PARAM_SOURCE_REGION) .build(); - requestToPresign = sraPresignRequest(executionAttributes, requestToPresign, sourceRegion); + requestToPresign = sraPresignRequest(selectedAuthScheme, requestToPresign, sourceRegion); String presignedUrl = requestToPresign.getUri().toString(); @@ -129,6 +129,14 @@ public final SdkHttpRequest modifyHttpRequest(Context.ModifyHttpRequest context, .build(); } + /** + * Resolves the auth scheme from execution attributes, applying any request-level credential overrides. + */ + private SelectedAuthScheme resolveAuthScheme(SdkRequest request, + ExecutionAttributes executionAttributes) { + return AuthSchemeResolver.resolveAuthScheme(request, executionAttributes); + } + /** * Adapts the request to the {@link PresignableRequest}. * @@ -148,7 +156,6 @@ private PresignableRequest toPresignableRequest(SdkHttpRequest request, Context. if (request.firstMatchingRawQueryParameter(PARAM_PRESIGNED_URL).isPresent()) { return null; } - PresignableRequest presignableRequest = adaptRequest(requestClassToPreSign.cast(originalRequest)); String sourceRegion = presignableRequest.getSourceRegion(); if (sourceRegion == null) { @@ -160,9 +167,8 @@ private PresignableRequest toPresignableRequest(SdkHttpRequest request, Context. /** * Presign the provided HTTP request using SRA HttpSigner */ - private SdkHttpFullRequest sraPresignRequest(ExecutionAttributes executionAttributes, SdkHttpFullRequest request, + private SdkHttpFullRequest sraPresignRequest(SelectedAuthScheme selectedAuthScheme, SdkHttpFullRequest request, String signingRegion) { - SelectedAuthScheme selectedAuthScheme = executionAttributes.getAttribute(SELECTED_AUTH_SCHEME); Instant signingInstant; if (signingClockOverride != null) { signingInstant = signingClockOverride.instant(); diff --git a/services/s3/src/main/java/software/amazon/awssdk/services/s3/S3Utilities.java b/services/s3/src/main/java/software/amazon/awssdk/services/s3/S3Utilities.java index cb3064798a99..e5f4d07df7e3 100644 --- a/services/s3/src/main/java/software/amazon/awssdk/services/s3/S3Utilities.java +++ b/services/s3/src/main/java/software/amazon/awssdk/services/s3/S3Utilities.java @@ -19,11 +19,11 @@ import java.net.MalformedURLException; import java.net.URI; import java.net.URL; -import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Optional; +import java.util.concurrent.CompletionException; import java.util.function.Consumer; import java.util.function.Supplier; import java.util.regex.Matcher; @@ -38,19 +38,19 @@ import software.amazon.awssdk.awscore.endpoint.AwsClientEndpointProvider; import software.amazon.awssdk.awscore.endpoint.DualstackEnabledProvider; import software.amazon.awssdk.awscore.endpoint.FipsEnabledProvider; +import software.amazon.awssdk.awscore.endpoints.AwsEndpointProviderUtils; import software.amazon.awssdk.awscore.internal.defaultsmode.DefaultsModeConfiguration; import software.amazon.awssdk.core.ClientEndpointProvider; import software.amazon.awssdk.core.ClientType; import software.amazon.awssdk.core.client.config.ClientOverrideConfiguration; import software.amazon.awssdk.core.client.config.SdkClientConfiguration; import software.amazon.awssdk.core.client.config.SdkClientOption; +import software.amazon.awssdk.core.exception.SdkClientException; import software.amazon.awssdk.core.exception.SdkException; import software.amazon.awssdk.core.interceptor.ExecutionAttributes; -import software.amazon.awssdk.core.interceptor.ExecutionInterceptor; -import software.amazon.awssdk.core.interceptor.ExecutionInterceptorChain; -import software.amazon.awssdk.core.interceptor.InterceptorContext; import software.amazon.awssdk.core.interceptor.SdkExecutionAttribute; import software.amazon.awssdk.core.interceptor.SdkInternalExecutionAttribute; +import software.amazon.awssdk.endpoints.Endpoint; import software.amazon.awssdk.http.SdkHttpFullRequest; import software.amazon.awssdk.http.SdkHttpMethod; import software.amazon.awssdk.http.SdkHttpRequest; @@ -62,9 +62,9 @@ import software.amazon.awssdk.regions.Region; import software.amazon.awssdk.regions.ServiceMetadataAdvancedOption; import software.amazon.awssdk.services.s3.endpoints.S3ClientContextParams; +import software.amazon.awssdk.services.s3.endpoints.S3EndpointParams; import software.amazon.awssdk.services.s3.endpoints.S3EndpointProvider; -import software.amazon.awssdk.services.s3.endpoints.internal.S3RequestSetEndpointInterceptor; -import software.amazon.awssdk.services.s3.endpoints.internal.S3ResolveEndpointInterceptor; +import software.amazon.awssdk.services.s3.endpoints.internal.S3EndpointResolverUtils; import software.amazon.awssdk.services.s3.internal.endpoints.UseGlobalEndpointResolver; import software.amazon.awssdk.services.s3.model.GetObjectRequest; import software.amazon.awssdk.services.s3.model.GetUrlRequest; @@ -110,7 +110,6 @@ public final class S3Utilities { private final Supplier profileFile; private final String profileName; private final boolean fipsEnabled; - private final ExecutionInterceptorChain interceptorChain; private final UseGlobalEndpointResolver useGlobalEndpointResolver; /** @@ -140,8 +139,6 @@ private S3Utilities(Builder builder) { .isFipsEnabled() .orElse(false); - this.interceptorChain = createEndpointInterceptorChain(); - this.useGlobalEndpointResolver = createUseGlobalEndpointResolver(); } @@ -243,14 +240,9 @@ public URL getUrl(GetUrlRequest getUrlRequest) { .versionId(getUrlRequest.versionId()) .build(); - InterceptorContext interceptorContext = InterceptorContext.builder() - .httpRequest(marshalledRequest) - .request(getObjectRequest) - .build(); - ExecutionAttributes executionAttributes = createExecutionAttributes(clientEndpoint, resolvedRegion); - SdkHttpRequest modifiedRequest = runInterceptors(interceptorContext, executionAttributes).httpRequest(); + SdkHttpRequest modifiedRequest = resolveEndpointAndApply(getObjectRequest, marshalledRequest, executionAttributes); try { return modifiedRequest.getUri().toURL(); } catch (MalformedURLException exception) { @@ -506,16 +498,36 @@ private AttributeMap createClientContextParams() { return params.build(); } - private InterceptorContext runInterceptors(InterceptorContext context, ExecutionAttributes executionAttributes) { - context = interceptorChain.modifyRequest(context, executionAttributes); - return interceptorChain.modifyHttpRequestAndHttpContent(context, executionAttributes); - } + private SdkHttpRequest resolveEndpointAndApply(GetObjectRequest request, SdkHttpRequest httpRequest, + ExecutionAttributes executionAttributes) { + S3EndpointParams endpointParams = S3EndpointResolverUtils.ruleParams(request, executionAttributes); + S3EndpointProvider provider = (S3EndpointProvider) executionAttributes + .getAttribute(SdkInternalExecutionAttribute.ENDPOINT_PROVIDER); + + Endpoint endpoint; + try { + endpoint = provider.resolveEndpoint(endpointParams).join(); + } catch (CompletionException e) { + Throwable cause = e.getCause(); + if (cause instanceof SdkClientException) { + throw (SdkClientException) cause; + } + throw SdkClientException.create("Endpoint resolution failed: " + cause.getMessage(), cause); + } + + ClientEndpointProvider clientEndpointProvider = + executionAttributes.getAttribute(SdkInternalExecutionAttribute.CLIENT_ENDPOINT_PROVIDER); + SdkHttpRequest result = AwsEndpointProviderUtils.setUri(httpRequest, + clientEndpointProvider.clientEndpoint(), endpoint.url()); + + if (!endpoint.headers().isEmpty()) { + SdkHttpRequest.Builder builder = result.toBuilder(); + endpoint.headers().forEach((name, values) -> + values.forEach(v -> builder.appendHeader(name, v))); + result = builder.build(); + } - private ExecutionInterceptorChain createEndpointInterceptorChain() { - List interceptors = new ArrayList<>(); - interceptors.add(new S3ResolveEndpointInterceptor()); - interceptors.add(new S3RequestSetEndpointInterceptor()); - return new ExecutionInterceptorChain(interceptors); + return result; } private UseGlobalEndpointResolver createUseGlobalEndpointResolver() { diff --git a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/crt/DefaultS3CrtAsyncClient.java b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/crt/DefaultS3CrtAsyncClient.java index e49abb7126e5..8a7bfebbff43 100644 --- a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/crt/DefaultS3CrtAsyncClient.java +++ b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/crt/DefaultS3CrtAsyncClient.java @@ -429,7 +429,7 @@ public void afterMarshalling(Context.AfterMarshalling context, .put(SIGNING_REGION, executionAttributes.getAttribute(AwsSignerExecutionAttribute.SIGNING_REGION)) .put(S3InternalSdkHttpExecutionAttribute.OBJECT_FILE_PATH, executionAttributes.getAttribute(OBJECT_FILE_PATH)) - .put(USE_S3_EXPRESS_AUTH, S3ExpressUtils.useS3ExpressAuthScheme(executionAttributes)) + .put(USE_S3_EXPRESS_AUTH, S3ExpressUtils.isS3ExpressAuthRequest(context.request(), executionAttributes)) .put(SIGNING_NAME, executionAttributes.getAttribute(SERVICE_SIGNING_NAME)) .put(REQUEST_CHECKSUM_CALCULATION, executionAttributes.getAttribute(SdkInternalExecutionAttribute.REQUEST_CHECKSUM_CALCULATION)) @@ -462,6 +462,20 @@ public void afterMarshalling(Context.AfterMarshalling context, executionAttributes.putAttribute(SDK_HTTP_EXECUTION_ATTRIBUTES, attributes); } + + @Override + public void beforeTransmission(Context.BeforeTransmission context, + ExecutionAttributes executionAttributes) { + SdkHttpExecutionAttributes httpAttributes = executionAttributes.getAttribute(SDK_HTTP_EXECUTION_ATTRIBUTES); + if (httpAttributes != null) { + executionAttributes.putAttribute(SDK_HTTP_EXECUTION_ATTRIBUTES, + httpAttributes.toBuilder() + .put(SIGNING_REGION, executionAttributes.getAttribute( + AwsSignerExecutionAttribute.SIGNING_REGION)) + .put(SIGNING_NAME, executionAttributes.getAttribute(SERVICE_SIGNING_NAME)) + .build()); + } + } } private static final class ValidateRequestInterceptor implements ExecutionInterceptor { diff --git a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/handlers/EnableTrailingChecksumInterceptor.java b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/handlers/EnableTrailingChecksumInterceptor.java index 098ecff45f5f..e38cf40b71ff 100644 --- a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/handlers/EnableTrailingChecksumInterceptor.java +++ b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/handlers/EnableTrailingChecksumInterceptor.java @@ -48,7 +48,7 @@ public SdkRequest modifyRequest(Context.ModifyRequest context, ExecutionAttribut SdkRequest request = context.request(); if (getObjectChecksumEnabledPerRequest(request, executionAttributes) - && S3ExpressUtils.useS3Express(executionAttributes)) { + && S3ExpressUtils.isS3ExpressBucket(request)) { return ((GetObjectRequest) request).toBuilder().checksumMode(ChecksumMode.ENABLED).build(); } return request; @@ -63,7 +63,7 @@ public SdkHttpRequest modifyHttpRequest(Context.ModifyHttpRequest context, ExecutionAttributes executionAttributes) { if (getObjectChecksumEnabledPerRequest(context.request(), executionAttributes) - && !S3ExpressUtils.useS3Express(executionAttributes)) { + && !S3ExpressUtils.isS3ExpressBucket(context.request())) { return context.httpRequest() .toBuilder() .putHeader(ENABLE_CHECKSUM_REQUEST_HEADER, ENABLE_MD5_CHECKSUM_HEADER_VALUE) diff --git a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/s3express/S3ExpressUtils.java b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/s3express/S3ExpressUtils.java index 441e575687e4..42c7928db434 100644 --- a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/s3express/S3ExpressUtils.java +++ b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/s3express/S3ExpressUtils.java @@ -17,10 +17,13 @@ import static software.amazon.awssdk.core.interceptor.SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME; +import java.util.List; import software.amazon.awssdk.annotations.SdkInternalApi; +import software.amazon.awssdk.core.SdkRequest; import software.amazon.awssdk.core.SelectedAuthScheme; import software.amazon.awssdk.core.interceptor.ExecutionAttributes; import software.amazon.awssdk.core.interceptor.SdkInternalExecutionAttribute; +import software.amazon.awssdk.core.spi.identity.AuthSchemeOptionsResolver; import software.amazon.awssdk.core.useragent.BusinessMetricFeatureId; import software.amazon.awssdk.endpoints.Endpoint; import software.amazon.awssdk.http.auth.spi.scheme.AuthSchemeOption; @@ -31,26 +34,37 @@ public final class S3ExpressUtils { public static final String S3_EXPRESS = "S3Express"; + private static final String S3_EXPRESS_BUCKET_SUFFIX = "--x-s3"; private S3ExpressUtils() { } /** - * Returns true if the resolved endpoint contains S3Express, else false. + * Determines if this request targets an S3Express bucket by checking the bucket name suffix. */ - public static boolean useS3Express(ExecutionAttributes executionAttributes) { - Endpoint endpoint = executionAttributes.getAttribute(SdkInternalExecutionAttribute.RESOLVED_ENDPOINT); - if (endpoint != null) { - String useS3Express = endpoint.attribute(KnownS3ExpressEndpointProperty.BACKEND); - return S3_EXPRESS.equals(useS3Express); + public static boolean isS3ExpressBucket(SdkRequest request) { + return request.getValueForField("Bucket", String.class) + .map(b -> b.endsWith(S3_EXPRESS_BUCKET_SUFFIX)) + .orElse(false); + } + + /** + * Determines if this request uses S3Express auth by checking the auth scheme options. + */ + public static boolean isS3ExpressAuthRequest(SdkRequest request, ExecutionAttributes executionAttributes) { + AuthSchemeOptionsResolver resolver = + executionAttributes.getAttribute(SdkInternalExecutionAttribute.AUTH_SCHEME_OPTIONS_RESOLVER); + if (resolver != null) { + List options = resolver.resolve(request); + return options.stream().anyMatch(o -> S3ExpressAuthScheme.SCHEME_ID.equals(o.schemeId())); } return false; } /** - * Whether aws.auth#sigv4-s3express is used or not + * Whether aws.auth#sigv4-s3express is the selected auth scheme. */ - public static boolean useS3ExpressAuthScheme(ExecutionAttributes executionAttributes) { + private static boolean useS3ExpressAuthScheme(ExecutionAttributes executionAttributes) { SelectedAuthScheme selectedAuthScheme = executionAttributes.getAttribute(SELECTED_AUTH_SCHEME); if (selectedAuthScheme != null) { AuthSchemeOption authSchemeOption = selectedAuthScheme.authSchemeOption(); @@ -62,8 +76,10 @@ public static boolean useS3ExpressAuthScheme(ExecutionAttributes executionAttrib /** * Adds S3 Express business metric if applicable for the current operation. */ - public static void addS3ExpressBusinessMetricIfApplicable(ExecutionAttributes executionAttributes) { - if (executionAttributes != null && useS3Express(executionAttributes) && useS3ExpressAuthScheme(executionAttributes)) { + public static void addS3ExpressBusinessMetricIfApplicable(Endpoint endpoint, ExecutionAttributes executionAttributes) { + if (endpoint != null && executionAttributes != null + && S3_EXPRESS.equals(endpoint.attribute(KnownS3ExpressEndpointProperty.BACKEND)) + && useS3ExpressAuthScheme(executionAttributes)) { executionAttributes.getOptionalAttribute(SdkInternalExecutionAttribute.BUSINESS_METRICS) .ifPresent(businessMetrics -> businessMetrics.addMetric(BusinessMetricFeatureId.S3_EXPRESS_BUCKET.value())); diff --git a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/signing/DefaultS3Presigner.java b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/signing/DefaultS3Presigner.java index 48e09c49a6ba..d40ad32c01c2 100644 --- a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/signing/DefaultS3Presigner.java +++ b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/signing/DefaultS3Presigner.java @@ -32,7 +32,9 @@ import java.util.Map; import java.util.Optional; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionException; import java.util.function.Function; +import java.util.stream.Collectors; import java.util.stream.Stream; import software.amazon.awssdk.annotations.SdkInternalApi; import software.amazon.awssdk.auth.signer.AwsSignerExecutionAttribute; @@ -41,19 +43,27 @@ import software.amazon.awssdk.awscore.client.builder.AwsDefaultClientBuilder; import software.amazon.awssdk.awscore.defaultsmode.DefaultsMode; import software.amazon.awssdk.awscore.endpoint.AwsClientEndpointProvider; +import software.amazon.awssdk.awscore.endpoints.AwsEndpointAttribute; +import software.amazon.awssdk.awscore.endpoints.AwsEndpointProviderUtils; +import software.amazon.awssdk.awscore.endpoints.authscheme.EndpointAuthScheme; import software.amazon.awssdk.awscore.internal.AwsExecutionContextBuilder; import software.amazon.awssdk.awscore.internal.defaultsmode.DefaultsModeConfiguration; import software.amazon.awssdk.awscore.presigner.PresignRequest; import software.amazon.awssdk.awscore.presigner.PresignedRequest; +import software.amazon.awssdk.core.ClientEndpointProvider; import software.amazon.awssdk.core.ClientType; import software.amazon.awssdk.core.RequestOverrideConfiguration; import software.amazon.awssdk.core.SdkBytes; +import software.amazon.awssdk.core.SdkClient; import software.amazon.awssdk.core.SdkRequest; import software.amazon.awssdk.core.SelectedAuthScheme; import software.amazon.awssdk.core.client.builder.SdkDefaultClientBuilder; import software.amazon.awssdk.core.client.config.SdkClientConfiguration; import software.amazon.awssdk.core.client.config.SdkClientOption; +import software.amazon.awssdk.core.exception.SdkClientException; import software.amazon.awssdk.core.http.ExecutionContext; +import software.amazon.awssdk.core.http.auth.AuthSchemeResolver; +import software.amazon.awssdk.core.identity.SdkIdentityProperty; import software.amazon.awssdk.core.interceptor.ClasspathInterceptorChainFactory; import software.amazon.awssdk.core.interceptor.ExecutionAttributes; import software.amazon.awssdk.core.interceptor.ExecutionInterceptor; @@ -64,6 +74,8 @@ import software.amazon.awssdk.core.signer.Presigner; import software.amazon.awssdk.core.signer.Signer; import software.amazon.awssdk.core.sync.RequestBody; +import software.amazon.awssdk.endpoints.Endpoint; +import software.amazon.awssdk.endpoints.EndpointProvider; import software.amazon.awssdk.http.ContentStreamProvider; import software.amazon.awssdk.http.SdkHttpFullRequest; import software.amazon.awssdk.http.SdkHttpMethod; @@ -71,6 +83,7 @@ import software.amazon.awssdk.http.auth.aws.scheme.AwsV4AuthScheme; import software.amazon.awssdk.http.auth.aws.scheme.AwsV4aAuthScheme; import software.amazon.awssdk.http.auth.aws.signer.AwsV4FamilyHttpSigner; +import software.amazon.awssdk.http.auth.aws.signer.RegionSet; import software.amazon.awssdk.http.auth.spi.scheme.AuthScheme; import software.amazon.awssdk.http.auth.spi.scheme.AuthSchemeOption; import software.amazon.awssdk.http.auth.spi.signer.HttpSigner; @@ -84,12 +97,13 @@ import software.amazon.awssdk.regions.ServiceMetadataAdvancedOption; import software.amazon.awssdk.services.s3.S3Client; import software.amazon.awssdk.services.s3.S3Configuration; +import software.amazon.awssdk.services.s3.auth.scheme.S3AuthSchemeParams; import software.amazon.awssdk.services.s3.auth.scheme.S3AuthSchemeProvider; -import software.amazon.awssdk.services.s3.auth.scheme.internal.S3AuthSchemeInterceptor; +import software.amazon.awssdk.services.s3.auth.scheme.internal.S3EndpointResolverAware; import software.amazon.awssdk.services.s3.endpoints.S3ClientContextParams; +import software.amazon.awssdk.services.s3.endpoints.S3EndpointParams; import software.amazon.awssdk.services.s3.endpoints.S3EndpointProvider; -import software.amazon.awssdk.services.s3.endpoints.internal.S3RequestSetEndpointInterceptor; -import software.amazon.awssdk.services.s3.endpoints.internal.S3ResolveEndpointInterceptor; +import software.amazon.awssdk.services.s3.endpoints.internal.S3EndpointResolverUtils; import software.amazon.awssdk.services.s3.internal.endpoints.UseGlobalEndpointResolver; import software.amazon.awssdk.services.s3.internal.s3express.S3ExpressAuthSchemeProvider; import software.amazon.awssdk.services.s3.model.AbortMultipartUploadRequest; @@ -131,6 +145,7 @@ import software.amazon.awssdk.services.s3.transform.PutObjectRequestMarshaller; import software.amazon.awssdk.services.s3.transform.UploadPartRequestMarshaller; import software.amazon.awssdk.utils.AttributeMap; +import software.amazon.awssdk.utils.CollectionUtils; import software.amazon.awssdk.utils.CompletableFutureUtils; import software.amazon.awssdk.utils.IoUtils; import software.amazon.awssdk.utils.Logger; @@ -239,11 +254,6 @@ private List initializeInterceptors() { ClasspathInterceptorChainFactory interceptorFactory = new ClasspathInterceptorChainFactory(); List s3Interceptors = interceptorFactory.getInterceptors("software/amazon/awssdk/services/s3/execution.interceptors"); - List additionalInterceptors = new ArrayList<>(); - additionalInterceptors.add(new S3AuthSchemeInterceptor()); - additionalInterceptors.add(new S3ResolveEndpointInterceptor()); - additionalInterceptors.add(new S3RequestSetEndpointInterceptor()); - s3Interceptors = mergeLists(s3Interceptors, additionalInterceptors); return mergeLists(interceptorFactory.getGlobalInterceptors(), s3Interceptors); } @@ -405,6 +415,12 @@ private T presign(T presignedRequest, addRequestLevelHeadersAndQueryParameters(execCtx); callModifyHttpRequestHooksAndUpdateContext(execCtx); + // Resolve auth scheme after interceptors complete + resolveAndSelectAuthScheme(execCtx, requestToPresign, operationName); + + // Resolve endpoint + resolveEndpointAndUpdateContext(execCtx, operationName); + SdkHttpFullRequest httpRequest = getHttpFullRequest(execCtx); SdkHttpFullRequest signedHttpRequest = execCtx.signer() != null @@ -594,6 +610,121 @@ private SdkHttpFullRequest getHttpFullRequest(ExecutionContext execCtx) { .build(); } + /** + * Resolve and select auth scheme for presigning. + */ + private void resolveAndSelectAuthScheme(ExecutionContext execCtx, SdkRequest request, String operationName) { + ExecutionAttributes executionAttributes = execCtx.executionAttributes(); + + Map> authSchemes = executionAttributes.getAttribute(SdkInternalExecutionAttribute.AUTH_SCHEMES); + if (authSchemes == null) { + return; + } + + List authOptions = resolveAuthSchemeOptions(request, operationName, executionAttributes); + if (authOptions == null || authOptions.isEmpty()) { + return; + } + + IdentityProviders identityProviders = executionAttributes.getAttribute(SdkInternalExecutionAttribute.IDENTITY_PROVIDERS); + + SelectedAuthScheme selectedAuthScheme = + AuthSchemeResolver.selectAuthScheme(authOptions, authSchemes, identityProviders, null); + + selectedAuthScheme = AuthSchemeResolver.mergePreExistingAuthSchemeProperties(selectedAuthScheme, executionAttributes); + + executionAttributes.putAttribute(SELECTED_AUTH_SCHEME, selectedAuthScheme); + } + + /** + * Resolve the endpoint using the rules engine and apply it to the HTTP request in the execution context. + */ + private void resolveEndpointAndUpdateContext(ExecutionContext execCtx, String operationName) { + ExecutionAttributes executionAttributes = execCtx.executionAttributes(); + SdkRequest sdkRequest = execCtx.interceptorContext().request(); + + S3EndpointParams endpointParams = S3EndpointResolverUtils.ruleParams(sdkRequest, executionAttributes); + S3EndpointProvider provider = (S3EndpointProvider) executionAttributes + .getAttribute(SdkInternalExecutionAttribute.ENDPOINT_PROVIDER); + + Endpoint endpoint; + try { + endpoint = provider.resolveEndpoint(endpointParams).join(); + } catch (CompletionException e) { + Throwable cause = e.getCause(); + if (cause instanceof SdkClientException) { + throw (SdkClientException) cause; + } + throw SdkClientException.create("Endpoint resolution failed: " + cause.getMessage(), cause); + } + + if (!AwsEndpointProviderUtils.disableHostPrefixInjection(executionAttributes)) { + Optional hostPrefix = S3EndpointResolverUtils.hostPrefix(operationName, sdkRequest); + if (hostPrefix.isPresent()) { + endpoint = AwsEndpointProviderUtils.addHostPrefix(endpoint, hostPrefix.get()); + } + } + + List endpointAuthSchemes = endpoint.attribute(AwsEndpointAttribute.AUTH_SCHEMES); + SelectedAuthScheme selectedAuthScheme = executionAttributes.getAttribute(SELECTED_AUTH_SCHEME); + if (endpointAuthSchemes != null && selectedAuthScheme != null) { + selectedAuthScheme = S3EndpointResolverUtils.authSchemeWithEndpointSignerProperties( + endpointAuthSchemes, selectedAuthScheme); + executionAttributes.putAttribute(SELECTED_AUTH_SCHEME, selectedAuthScheme); + } + + SdkHttpRequest httpRequest = execCtx.interceptorContext().httpRequest(); + ClientEndpointProvider clientEndpointProvider = + executionAttributes.getAttribute(SdkInternalExecutionAttribute.CLIENT_ENDPOINT_PROVIDER); + SdkHttpRequest updatedRequest = AwsEndpointProviderUtils.setUri(httpRequest, + clientEndpointProvider.clientEndpoint(), endpoint.url()); + + if (!endpoint.headers().isEmpty()) { + SdkHttpRequest.Builder requestBuilder = updatedRequest.toBuilder(); + endpoint.headers().forEach((name, values) -> + values.forEach(v -> requestBuilder.appendHeader(name, v))); + updatedRequest = requestBuilder.build(); + } + + SdkHttpRequest finalRequest = updatedRequest; + execCtx.interceptorContext(execCtx.interceptorContext().copy(c -> c.httpRequest(finalRequest))); + } + + /** + * Resolve auth scheme options using full endpoint params. + */ + private List resolveAuthSchemeOptions(SdkRequest request, String operationName, + ExecutionAttributes executionAttributes) { + S3AuthSchemeProvider authSchemeProvider = Validate.isInstanceOf(S3AuthSchemeProvider.class, + executionAttributes.getAttribute(SdkInternalExecutionAttribute.AUTH_SCHEME_RESOLVER), + "Expected an instance of S3AuthSchemeProvider"); + + S3EndpointParams endpointParams = S3EndpointResolverUtils.ruleParams(request, executionAttributes); + S3AuthSchemeParams.Builder paramsBuilder = S3AuthSchemeParams.fromEndpointParams(endpointParams) + .operation(operationName); + + executionAttributes.getOptionalAttribute(AwsExecutionAttribute.AWS_SIGV4A_SIGNING_REGION_SET) + .filter(regionSet -> !CollectionUtils.isNullOrEmpty(regionSet)) + .ifPresent(nonEmptyRegionSet -> paramsBuilder.regionSet(RegionSet.create(nonEmptyRegionSet))); + + if (paramsBuilder instanceof S3EndpointResolverAware.Builder) { + EndpointProvider endpointProvider = + executionAttributes.getAttribute(SdkInternalExecutionAttribute.ENDPOINT_PROVIDER); + if (endpointProvider instanceof S3EndpointProvider) { + ((S3EndpointResolverAware.Builder) paramsBuilder).endpointProvider((S3EndpointProvider) endpointProvider); + } + } + + List options = authSchemeProvider.resolveAuthScheme(paramsBuilder.build()); + SdkClient sdkClient = executionAttributes.getAttribute(SdkInternalExecutionAttribute.SDK_CLIENT); + if (sdkClient != null) { + options = options.stream() + .map(o -> o.toBuilder().putIdentityProperty(SdkIdentityProperty.SDK_CLIENT, sdkClient).build()) + .collect(Collectors.toList()); + } + return options; + } + /** * Presign the provided HTTP request using old Signer */ diff --git a/services/s3/src/main/resources/codegen-resources/customization.config b/services/s3/src/main/resources/codegen-resources/customization.config index 4dee2d9e88d9..6df5a5bf4c49 100644 --- a/services/s3/src/main/resources/codegen-resources/customization.config +++ b/services/s3/src/main/resources/codegen-resources/customization.config @@ -206,6 +206,7 @@ "hasChecksumValidationEnabledProperty":true }, "skipEndpointTests": { + "Access points when Access points explicitly disabled (used for CreateBucket)": "CreateBucketInterceptor validates bucket name before endpoint resolution stage runs", "Invalid access point ARN: Not S3": "Test assumes UseArnRegion is true but SDK defaults to false", "Invalid access point ARN: AccountId is invalid": "Test assumes UseArnRegion is true but SDK defaults to false", "Invalid access point ARN: access point name is invalid": "Test assumes UseArnRegion is true but SDK defaults to false", diff --git a/services/s3/src/test/java/software/amazon/awssdk/services/s3/ExecutionAttributeBackwardsCompatibilityTest.java b/services/s3/src/test/java/software/amazon/awssdk/services/s3/ExecutionAttributeBackwardsCompatibilityTest.java index a4816d22e4d0..22b8c7f4efae 100644 --- a/services/s3/src/test/java/software/amazon/awssdk/services/s3/ExecutionAttributeBackwardsCompatibilityTest.java +++ b/services/s3/src/test/java/software/amazon/awssdk/services/s3/ExecutionAttributeBackwardsCompatibilityTest.java @@ -72,10 +72,7 @@ public void canSetSignerExecutionAttributes_beforeExecution() { public void beforeExecution(Context.BeforeExecution context, ExecutionAttributes executionAttributes) { attributeModifications.accept(executionAttributes); } - }, - AwsSignerExecutionAttribute.SERVICE_SIGNING_NAME, // Endpoint rules override signing name - AwsSignerExecutionAttribute.SIGNING_REGION, // Endpoint rules override signing region - AwsSignerExecutionAttribute.SIGNER_DOUBLE_URL_ENCODE); // Endpoint rules override double-url-encode + }); } @Test diff --git a/services/s3/src/test/java/software/amazon/awssdk/services/s3/functionaltests/AsyncResponseTransformerTest.java b/services/s3/src/test/java/software/amazon/awssdk/services/s3/functionaltests/AsyncResponseTransformerTest.java index 4c2c5e748c0b..155306ab9faf 100644 --- a/services/s3/src/test/java/software/amazon/awssdk/services/s3/functionaltests/AsyncResponseTransformerTest.java +++ b/services/s3/src/test/java/software/amazon/awssdk/services/s3/functionaltests/AsyncResponseTransformerTest.java @@ -17,6 +17,7 @@ import static org.assertj.core.api.Assertions.fail; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.atLeast; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; @@ -50,6 +51,6 @@ public void AsyncResponseTransformerPrepareCalled_BeforeCredentialsResolution() } verify(asyncResponseTransformer, times(1)).prepare(); - verify(asyncResponseTransformer, times(1)).exceptionOccurred(any()); + verify(asyncResponseTransformer, atLeast(1)).exceptionOccurred(any()); } } diff --git a/services/s3/src/test/java/software/amazon/awssdk/services/s3/functionaltests/S3ExpressCreateSessionTest.java b/services/s3/src/test/java/software/amazon/awssdk/services/s3/functionaltests/S3ExpressCreateSessionTest.java index 815736f9d1e6..d7e1a31e7286 100644 --- a/services/s3/src/test/java/software/amazon/awssdk/services/s3/functionaltests/S3ExpressCreateSessionTest.java +++ b/services/s3/src/test/java/software/amazon/awssdk/services/s3/functionaltests/S3ExpressCreateSessionTest.java @@ -71,8 +71,10 @@ public class S3ExpressCreateSessionTest extends BaseRuleSetClientTest { private static final Logger log = Logger.loggerFor(S3ExpressCreateSessionTest.class); - private static final Function WM_HTTP_ENDPOINT = wm -> URI.create(wm.getHttpBaseUrl()); - private static final Function WM_HTTPS_ENDPOINT = wm -> URI.create(wm.getHttpsBaseUrl()); + private static final Function WM_HTTP_ENDPOINT = + wm -> URI.create("http://127.0.0.1:" + wm.getHttpPort()); + private static final Function WM_HTTPS_ENDPOINT = + wm -> URI.create("https://127.0.0.1:" + wm.getHttpsPort()); private static final AwsCredentialsProvider CREDENTIALS_PROVIDER = StaticCredentialsProvider.create(AwsBasicCredentials.create("akid", "skid")); private static final PathStyleEnforcingInterceptor PATH_STYLE_INTERCEPTOR = new PathStyleEnforcingInterceptor(); @@ -316,7 +318,11 @@ private static final class PathStyleEnforcingInterceptor implements ExecutionInt public SdkHttpRequest modifyHttpRequest(Context.ModifyHttpRequest context, ExecutionAttributes executionAttributes) { SdkHttpRequest sdkHttpRequest = context.httpRequest(); String host = sdkHttpRequest.host(); - String bucket = host.substring(0, host.indexOf(".localhost")); + int idx = host.indexOf(".localhost"); + if (idx < 0) { + return sdkHttpRequest; + } + String bucket = host.substring(0, idx); return sdkHttpRequest.toBuilder().host("localhost") .encodedPath(SdkHttpUtils.appendUri(bucket, sdkHttpRequest.encodedPath())) diff --git a/services/s3/src/test/java/software/amazon/awssdk/services/s3/functionaltests/S3ExpressTest.java b/services/s3/src/test/java/software/amazon/awssdk/services/s3/functionaltests/S3ExpressTest.java index 10414df6c14f..9d138286ea8f 100644 --- a/services/s3/src/test/java/software/amazon/awssdk/services/s3/functionaltests/S3ExpressTest.java +++ b/services/s3/src/test/java/software/amazon/awssdk/services/s3/functionaltests/S3ExpressTest.java @@ -76,8 +76,10 @@ @WireMockTest(httpsEnabled = true) public class S3ExpressTest extends BaseRuleSetClientTest { private static final Logger log = Logger.loggerFor(S3ExpressTest.class); - private static final Function WM_HTTP_ENDPOINT = wm -> URI.create(wm.getHttpBaseUrl()); - private static final Function WM_HTTPS_ENDPOINT = wm -> URI.create(wm.getHttpsBaseUrl()); + private static final Function WM_HTTP_ENDPOINT = + wm -> URI.create("http://127.0.0.1:" + wm.getHttpPort()); + private static final Function WM_HTTPS_ENDPOINT = + wm -> URI.create("https://127.0.0.1:" + wm.getHttpsPort()); private static final AwsCredentialsProvider CREDENTIALS_PROVIDER = StaticCredentialsProvider.create(AwsBasicCredentials.create("akid", "skid")); private static final PathStyleEnforcingInterceptor PATH_STYLE_INTERCEPTOR = new PathStyleEnforcingInterceptor(); @@ -411,7 +413,11 @@ private static final class PathStyleEnforcingInterceptor implements ExecutionInt public SdkHttpRequest modifyHttpRequest(Context.ModifyHttpRequest context, ExecutionAttributes executionAttributes) { SdkHttpRequest sdkHttpRequest = context.httpRequest(); String host = sdkHttpRequest.host(); - String bucket = host.substring(0, host.indexOf(".localhost")); + int idx = host.indexOf(".localhost"); + if (idx < 0) { + return sdkHttpRequest; + } + String bucket = host.substring(0, idx); return sdkHttpRequest.toBuilder().host("localhost") .encodedPath(SdkHttpUtils.appendUri(bucket, sdkHttpRequest.encodedPath())) diff --git a/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/crossregion/S3CrossRegionAsyncClientTest.java b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/crossregion/S3CrossRegionAsyncClientTest.java index 1a45adee7a98..4d001dcd47c9 100644 --- a/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/crossregion/S3CrossRegionAsyncClientTest.java +++ b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/crossregion/S3CrossRegionAsyncClientTest.java @@ -71,6 +71,7 @@ import software.amazon.awssdk.services.s3.internal.crossregion.endpointprovider.BucketEndpointProvider; import software.amazon.awssdk.services.s3.model.GetObjectRequest; import software.amazon.awssdk.services.s3.model.GetObjectResponse; +import software.amazon.awssdk.services.s3.model.HeadBucketRequest; import software.amazon.awssdk.services.s3.model.ListObjectsV2Response; import software.amazon.awssdk.services.s3.model.S3Exception; import software.amazon.awssdk.services.s3.model.S3Response; @@ -487,10 +488,10 @@ void given_globalRegion_Client_Updates_region_to_useast1_and_useGlobalEndpointFl assertThat(captureInterceptor.endpointProvider).isInstanceOf(BucketEndpointProvider.class); ArgumentCaptor collectionCaptor = ArgumentCaptor.forClass(S3EndpointParams.class); verify(mockEndpointProvider, atLeastOnce()).resolveEndpoint(collectionCaptor.capture()); - collectionCaptor.getAllValues().forEach(resolvedParams -> { - assertThat(resolvedParams.region()).isEqualTo(Region.US_EAST_1); - assertThat(resolvedParams.useGlobalEndpoint()).isFalse(); - }); + S3EndpointParams lastResolvedParams = collectionCaptor.getAllValues() + .get(collectionCaptor.getAllValues().size() - 1); + assertThat(lastResolvedParams.region()).isEqualTo(Region.US_EAST_1); + assertThat(lastResolvedParams.useGlobalEndpoint()).isFalse(); } @@ -529,7 +530,9 @@ private static final class CaptureInterceptor implements ExecutionInterceptor { @Override public void beforeMarshalling(Context.BeforeMarshalling context, ExecutionAttributes executionAttributes) { - endpointProvider = executionAttributes.getAttribute(SdkInternalExecutionAttribute.ENDPOINT_PROVIDER); + if (!(context.request() instanceof HeadBucketRequest)) { + endpointProvider = executionAttributes.getAttribute(SdkInternalExecutionAttribute.ENDPOINT_PROVIDER); + } } } } diff --git a/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/crossregion/S3CrossRegionSyncClientTest.java b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/crossregion/S3CrossRegionSyncClientTest.java index e1187124ecb9..fdec8f328aca 100644 --- a/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/crossregion/S3CrossRegionSyncClientTest.java +++ b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/crossregion/S3CrossRegionSyncClientTest.java @@ -61,6 +61,7 @@ import software.amazon.awssdk.services.s3.endpoints.internal.DefaultS3EndpointProvider; import software.amazon.awssdk.services.s3.internal.crossregion.endpointprovider.BucketEndpointProvider; import software.amazon.awssdk.services.s3.model.GetObjectRequest; +import software.amazon.awssdk.services.s3.model.HeadBucketRequest; import software.amazon.awssdk.services.s3.model.ListObjectsV2Response; import software.amazon.awssdk.services.s3.model.S3Exception; import software.amazon.awssdk.services.s3.paginators.ListObjectsV2Iterable; @@ -303,10 +304,10 @@ void given_globalRegion_Client_Updates_region_to_useast1_and_useGlobalEndpointFl assertThat(captureInterceptor.endpointProvider).isInstanceOf(BucketEndpointProvider.class); ArgumentCaptor collectionCaptor = ArgumentCaptor.forClass(S3EndpointParams.class); verify(mockEndpointProvider, atLeastOnce()).resolveEndpoint(collectionCaptor.capture()); - collectionCaptor.getAllValues().forEach(resolvedParams ->{ - assertThat(resolvedParams.region()).isEqualTo(Region.US_EAST_1); - assertThat(resolvedParams.useGlobalEndpoint()).isFalse(); - }); + S3EndpointParams lastResolvedParams = collectionCaptor.getAllValues() + .get(collectionCaptor.getAllValues().size() - 1); + assertThat(lastResolvedParams.region()).isEqualTo(Region.US_EAST_1); + assertThat(lastResolvedParams.useGlobalEndpoint()).isFalse(); } private static GetObjectRequest.Builder getObjectBuilder() { @@ -509,7 +510,9 @@ private static final class CaptureInterceptor implements ExecutionInterceptor { @Override public void beforeMarshalling(Context.BeforeMarshalling context, ExecutionAttributes executionAttributes) { - endpointProvider = executionAttributes.getAttribute(SdkInternalExecutionAttribute.ENDPOINT_PROVIDER); + if (!(context.request() instanceof HeadBucketRequest)) { + endpointProvider = executionAttributes.getAttribute(SdkInternalExecutionAttribute.ENDPOINT_PROVIDER); + } } } } diff --git a/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/crt/DefaultS3CrtAsyncClientTest.java b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/crt/DefaultS3CrtAsyncClientTest.java index 0276d116b16b..ca34f9c045ca 100644 --- a/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/crt/DefaultS3CrtAsyncClientTest.java +++ b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/crt/DefaultS3CrtAsyncClientTest.java @@ -22,7 +22,9 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.ValueSource; +import software.amazon.awssdk.auth.credentials.AwsBasicCredentials; import software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider; +import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider; import software.amazon.awssdk.auth.signer.AwsS3V4Signer; import software.amazon.awssdk.core.async.AsyncRequestBody; import software.amazon.awssdk.core.async.AsyncResponseTransformer; @@ -31,8 +33,10 @@ import software.amazon.awssdk.core.interceptor.ExecutionAttributes; import software.amazon.awssdk.core.interceptor.ExecutionInterceptor; import software.amazon.awssdk.core.interceptor.SdkInternalExecutionAttribute; +import software.amazon.awssdk.http.SdkHttpExecutionAttributes; import software.amazon.awssdk.identity.spi.AwsCredentialsIdentity; import software.amazon.awssdk.identity.spi.IdentityProvider; +import software.amazon.awssdk.regions.Region; import software.amazon.awssdk.services.s3.DelegatingS3AsyncClient; import software.amazon.awssdk.services.s3.S3AsyncClient; import software.amazon.awssdk.services.s3.endpoints.S3ClientContextParams; @@ -161,4 +165,75 @@ void build_withAdvancedOptions() { assertThat(client).isInstanceOf(DefaultS3CrtAsyncClient.class); } } + + @Test + void s3ExpressBucket_defaultConfig_useS3ExpressAuthIsTrue() { + AtomicReference capturedUseS3ExpressAuth = new AtomicReference<>(); + + ExecutionInterceptor captor = new ExecutionInterceptor() { + @Override + public void beforeTransmission(Context.BeforeTransmission context, ExecutionAttributes executionAttributes) { + SdkHttpExecutionAttributes httpAttrs = + executionAttributes.getAttribute(SdkInternalExecutionAttribute.SDK_HTTP_EXECUTION_ATTRIBUTES); + if (httpAttrs != null) { + capturedUseS3ExpressAuth.set(httpAttrs.getAttribute(S3InternalSdkHttpExecutionAttribute.USE_S3_EXPRESS_AUTH)); + } + throw new RuntimeException("STOP"); + } + }; + + DefaultS3CrtAsyncClient.DefaultS3CrtClientBuilder builder = + (DefaultS3CrtAsyncClient.DefaultS3CrtClientBuilder) S3CrtAsyncClient.builder(); + builder.addExecutionInterceptor(captor); + + try (S3AsyncClient client = builder + .region(Region.US_EAST_1) + .credentialsProvider(StaticCredentialsProvider.create( + AwsBasicCredentials.create("key", "secret"))) + .build()) { + + assertThatThrownBy(() -> client.getObject( + r -> r.bucket("my-bucket--usw2-az1--x-s3").key("key"), + AsyncResponseTransformer.toBytes()).join()) + .hasMessageContaining("STOP"); + } + + assertThat(capturedUseS3ExpressAuth.get()).isTrue(); + } + + @Test + void s3ExpressBucket_disableS3ExpressSessionAuth_useS3ExpressAuthIsFalse() { + AtomicReference capturedUseS3ExpressAuth = new AtomicReference<>(); + + ExecutionInterceptor captor = new ExecutionInterceptor() { + @Override + public void beforeTransmission(Context.BeforeTransmission context, ExecutionAttributes executionAttributes) { + SdkHttpExecutionAttributes httpAttrs = + executionAttributes.getAttribute(SdkInternalExecutionAttribute.SDK_HTTP_EXECUTION_ATTRIBUTES); + if (httpAttrs != null) { + capturedUseS3ExpressAuth.set(httpAttrs.getAttribute(S3InternalSdkHttpExecutionAttribute.USE_S3_EXPRESS_AUTH)); + } + throw new RuntimeException("STOP"); + } + }; + + DefaultS3CrtAsyncClient.DefaultS3CrtClientBuilder builder = + (DefaultS3CrtAsyncClient.DefaultS3CrtClientBuilder) S3CrtAsyncClient.builder(); + builder.addExecutionInterceptor(captor); + + try (S3AsyncClient client = builder + .region(Region.US_EAST_1) + .credentialsProvider(StaticCredentialsProvider.create( + AwsBasicCredentials.create("key", "secret"))) + .disableS3ExpressSessionAuth(true) + .build()) { + + assertThatThrownBy(() -> client.getObject( + r -> r.bucket("my-bucket--usw2-az1--x-s3").key("key"), + AsyncResponseTransformer.toBytes()).join()) + .hasMessageContaining("STOP"); + } + + assertThat(capturedUseS3ExpressAuth.get()).isFalse(); + } } diff --git a/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/handlers/EnableTrailingChecksumInterceptorTest.java b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/handlers/EnableTrailingChecksumInterceptorTest.java index f5e56277dca0..d1875027a32f 100644 --- a/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/handlers/EnableTrailingChecksumInterceptorTest.java +++ b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/handlers/EnableTrailingChecksumInterceptorTest.java @@ -36,11 +36,8 @@ import software.amazon.awssdk.core.checksums.ResponseChecksumValidation; import software.amazon.awssdk.core.interceptor.Context; import software.amazon.awssdk.core.interceptor.ExecutionAttributes; -import software.amazon.awssdk.core.interceptor.SdkInternalExecutionAttribute; -import software.amazon.awssdk.endpoints.Endpoint; import software.amazon.awssdk.http.SdkHttpFullResponse; import software.amazon.awssdk.http.SdkHttpRequest; -import software.amazon.awssdk.services.s3.endpoints.internal.KnownS3ExpressEndpointProperty; import software.amazon.awssdk.services.s3.model.ChecksumMode; import software.amazon.awssdk.services.s3.model.GetObjectAclRequest; import software.amazon.awssdk.services.s3.model.GetObjectAclResponse; @@ -62,8 +59,8 @@ private static Stream getObjectRequestParams() { @ParameterizedTest @MethodSource("getObjectRequestParams") public void testGetObjectModifyRequest(boolean useS3Express, boolean disableChecksumValidation, boolean checksumModeEnabled) { - GetObjectRequest request = createGetObjectRequest(checksumModeEnabled); - ExecutionAttributes executionAttributes = createExecutionAttributes(disableChecksumValidation, useS3Express); + GetObjectRequest request = createGetObjectRequest(checksumModeEnabled, useS3Express); + ExecutionAttributes executionAttributes = createExecutionAttributes(disableChecksumValidation); Context.ModifyRequest modifyRequestContext = () -> request; SdkRequest modifiedRequest = interceptor.modifyRequest(modifyRequestContext, executionAttributes); @@ -106,8 +103,11 @@ private void validateGetObjectValues(SdkHttpRequest sdkHttpRequest, GetObjectReq } } - private GetObjectRequest createGetObjectRequest(boolean checksumModeEnabled) { + private GetObjectRequest createGetObjectRequest(boolean checksumModeEnabled, boolean isS3ExpressBucket) { GetObjectRequest.Builder requestBuilder = GetObjectRequest.builder(); + if (isS3ExpressBucket) { + requestBuilder.bucket("my-bucket--x-s3"); + } if (checksumModeEnabled) { requestBuilder.checksumMode(ChecksumMode.ENABLED); } @@ -120,7 +120,7 @@ public void modifyRequest_nonGetObjectRequest_shouldNotModify() { boolean useS3Express = false; boolean disableChecksumValidation = false; - ExecutionAttributes executionAttributes = createExecutionAttributes(disableChecksumValidation, useS3Express); + ExecutionAttributes executionAttributes = createExecutionAttributes(disableChecksumValidation); Context.ModifyRequest modifyRequestContext = () -> request; SdkRequest modifiedRequest = interceptor.modifyRequest(modifyRequestContext, executionAttributes); @@ -130,7 +130,7 @@ public void modifyRequest_nonGetObjectRequest_shouldNotModify() { assertThat(sdkHttpRequest.headers().get(ENABLE_CHECKSUM_REQUEST_HEADER)).isNull(); } - private ExecutionAttributes createExecutionAttributes(boolean disableChecksumValidation, boolean useS3Express) { + private ExecutionAttributes createExecutionAttributes(boolean disableChecksumValidation) { ExecutionAttributes executionAttributes = new ExecutionAttributes(); if (disableChecksumValidation) { executionAttributes.putAttribute(REQUEST_CHECKSUM_CALCULATION, RequestChecksumCalculation.WHEN_REQUIRED); @@ -139,10 +139,6 @@ private ExecutionAttributes createExecutionAttributes(boolean disableChecksumVal executionAttributes.putAttribute(REQUEST_CHECKSUM_CALCULATION, RequestChecksumCalculation.WHEN_SUPPORTED); executionAttributes.putAttribute(RESPONSE_CHECKSUM_VALIDATION, ResponseChecksumValidation.WHEN_SUPPORTED); } - if (useS3Express) { - Endpoint s3ExpressEndpoint = Endpoint.builder().putAttribute(KnownS3ExpressEndpointProperty.BACKEND, "S3Express").build(); - executionAttributes.putAttribute(SdkInternalExecutionAttribute.RESOLVED_ENDPOINT, s3ExpressEndpoint); - } return executionAttributes; } diff --git a/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/plugins/ListOfStringsAuthParamPluginTest.java b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/plugins/ListOfStringsAuthParamPluginTest.java index 2f3a2b6f2194..12cd20fac3cf 100644 --- a/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/plugins/ListOfStringsAuthParamPluginTest.java +++ b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/plugins/ListOfStringsAuthParamPluginTest.java @@ -18,119 +18,74 @@ import static java.util.Arrays.asList; import static org.assertj.core.api.AssertionsForClassTypes.assertThat; -import com.github.tomakehurst.wiremock.junit5.WireMockRuntimeInfo; -import com.github.tomakehurst.wiremock.junit5.WireMockTest; import java.net.URI; import java.util.Collections; import java.util.List; -import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; -import software.amazon.awssdk.auth.credentials.AwsBasicCredentials; -import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider; -import software.amazon.awssdk.core.SdkPlugin; -import software.amazon.awssdk.core.SdkServiceClientConfiguration; -import software.amazon.awssdk.http.auth.spi.scheme.AuthSchemeOption; +import software.amazon.awssdk.awscore.AwsExecutionAttribute; +import software.amazon.awssdk.core.ClientEndpointProvider; +import software.amazon.awssdk.core.interceptor.ExecutionAttributes; +import software.amazon.awssdk.core.interceptor.SdkExecutionAttribute; +import software.amazon.awssdk.core.interceptor.SdkInternalExecutionAttribute; import software.amazon.awssdk.regions.Region; -import software.amazon.awssdk.services.s3.S3AsyncClient; -import software.amazon.awssdk.services.s3.S3ServiceClientConfiguration; -import software.amazon.awssdk.services.s3.auth.scheme.S3AuthSchemeParams; -import software.amazon.awssdk.services.s3.auth.scheme.S3AuthSchemeProvider; +import software.amazon.awssdk.services.s3.endpoints.S3EndpointParams; +import software.amazon.awssdk.services.s3.endpoints.internal.S3EndpointResolverUtils; import software.amazon.awssdk.services.s3.model.Delete; import software.amazon.awssdk.services.s3.model.DeleteObjectsRequest; import software.amazon.awssdk.services.s3.model.ObjectIdentifier; +import software.amazon.awssdk.utils.AttributeMap; -@WireMockTest class ListOfStringsAuthParamPluginTest { - private static S3AsyncClient s3Client; - private ListOfStringsParamPlugin plugin; - - @BeforeEach - public void init(WireMockRuntimeInfo wm) { - plugin = new ListOfStringsParamPlugin(); - - AwsBasicCredentials credentials = AwsBasicCredentials.create("key", "secret"); - s3Client = S3AsyncClient.builder() - .region(Region.US_EAST_1) - .endpointOverride(URI.create(wm.getHttpBaseUrl())) - .addPlugin(plugin) - .credentialsProvider(StaticCredentialsProvider.create(credentials)) - .build(); - } - @Test void callingDeleteObjects_requestWithCompleteListOfKey_returnsRightValues() { - s3Client.deleteObjects(r -> r.bucket("test").delete(o -> o.objects( + DeleteObjectsRequest request = DeleteObjectsRequest.builder().bucket("test").delete(d -> d.objects( ObjectIdentifier.builder().key("x").versionId("1").build(), ObjectIdentifier.builder().key("y").versionId("2").build() - ))); - assertThat(plugin.storedKeys()).isEqualTo(asList("x", "y")); + )).build(); + S3EndpointParams params = S3EndpointResolverUtils.ruleParams(request, deleteObjectsAttributes()); + assertThat(params.deleteObjectKeys()).isEqualTo(asList("x", "y")); } @Test void callingDeleteObjects_requestWithInCompleteListOfKey_returnsRightValues() { - s3Client.deleteObjects(r -> r.bucket("test").delete(o -> o.objects( + DeleteObjectsRequest request = DeleteObjectsRequest.builder().bucket("test").delete(d -> d.objects( ObjectIdentifier.builder().versionId("1").build(), ObjectIdentifier.builder().key("y").versionId("2").build() - ))); - assertThat(plugin.storedKeys()).isEqualTo(asList("y")); + )).build(); + S3EndpointParams params = S3EndpointResolverUtils.ruleParams(request, deleteObjectsAttributes()); + assertThat(params.deleteObjectKeys()).isEqualTo(asList("y")); } @Test void callingDeleteObjects_requestWithNoList_returnsRightValues() { - s3Client.deleteObjects(DeleteObjectsRequest.builder().delete(Delete.builder().build()).build()); - assertThat(plugin.storedKeys()).asList().isEmpty(); + DeleteObjectsRequest request = DeleteObjectsRequest.builder().delete(Delete.builder().build()).build(); + S3EndpointParams params = S3EndpointResolverUtils.ruleParams(request, deleteObjectsAttributes()); + assertThat(params.deleteObjectKeys()).asList().isEmpty(); } @Test void callingDeleteObjects_requestWithoutEmptyList_returnsRightValues() { - s3Client.deleteObjects(DeleteObjectsRequest.builder().delete(Delete.builder().objects(Collections.emptyList()).build()).build()); - assertThat(plugin.storedKeys()).asList().isEmpty(); + DeleteObjectsRequest request = DeleteObjectsRequest.builder().delete(Delete.builder().objects(Collections.emptyList()).build()).build(); + S3EndpointParams params = S3EndpointResolverUtils.ruleParams(request, deleteObjectsAttributes()); + assertThat(params.deleteObjectKeys()).asList().isEmpty(); } @Test void callingDeleteObjects_requestWithListButNoKey_returnsRightValues() { List objects = asList(ObjectIdentifier.builder().build()); - s3Client.deleteObjects(DeleteObjectsRequest.builder().delete(Delete.builder().objects(objects).build()).build()); - assertThat(plugin.storedKeys()).asList().isEmpty(); + DeleteObjectsRequest request = DeleteObjectsRequest.builder().delete(Delete.builder().objects(objects).build()).build(); + S3EndpointParams params = S3EndpointResolverUtils.ruleParams(request, deleteObjectsAttributes()); + assertThat(params.deleteObjectKeys()).asList().isEmpty(); } - private static class ListOfStringsParamPlugin implements SdkPlugin { - - private ListOfStringsParamAuthSchemeProvider localProvider; - - @Override - public void configureClient(SdkServiceClientConfiguration.Builder config) { - S3ServiceClientConfiguration.Builder serviceClientConfiguration = (S3ServiceClientConfiguration.Builder) config; - S3AuthSchemeProvider defaultProvider = serviceClientConfiguration.authSchemeProvider(); - localProvider = new ListOfStringsParamAuthSchemeProvider(defaultProvider); - serviceClientConfiguration.authSchemeProvider(localProvider); - } - - List storedKeys() { - return localProvider.storedKeys; - } - - private static class ListOfStringsParamAuthSchemeProvider implements S3AuthSchemeProvider { - - private List storedKeys; - S3AuthSchemeProvider delegate; - - public ListOfStringsParamAuthSchemeProvider(S3AuthSchemeProvider authSchemeProvider) { - this.delegate = authSchemeProvider; - } - - @Override - public List resolveAuthScheme(S3AuthSchemeParams authSchemeParams) { - List availableAuthSchemes = delegate.resolveAuthScheme(authSchemeParams); - - List keys = authSchemeParams.deleteObjectKeys(); - if (keys != null) { - storedKeys = keys; - } - return availableAuthSchemes; - } - } + private static ExecutionAttributes deleteObjectsAttributes() { + ExecutionAttributes attrs = new ExecutionAttributes(); + attrs.putAttribute(SdkExecutionAttribute.OPERATION_NAME, "DeleteObjects"); + attrs.putAttribute(AwsExecutionAttribute.AWS_REGION, Region.US_EAST_1); + attrs.putAttribute(SdkInternalExecutionAttribute.CLIENT_ENDPOINT_PROVIDER, + ClientEndpointProvider.forEndpointOverride(URI.create("https://s3.us-east-1.amazonaws.com"))); + attrs.putAttribute(SdkInternalExecutionAttribute.CLIENT_CONTEXT_PARAMS, AttributeMap.empty()); + return attrs; } - } diff --git a/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/s3express/S3ExpressAuthSchemeProviderTest.java b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/s3express/S3ExpressAuthSchemeProviderTest.java index 13c7b37ab958..1a176f1e119b 100644 --- a/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/s3express/S3ExpressAuthSchemeProviderTest.java +++ b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/s3express/S3ExpressAuthSchemeProviderTest.java @@ -17,125 +17,48 @@ import static org.assertj.core.api.AssertionsForClassTypes.assertThat; -import java.net.URI; -import java.util.HashMap; -import java.util.Map; -import java.util.concurrent.CompletableFuture; +import java.util.List; import org.junit.jupiter.api.Test; -import software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider; -import software.amazon.awssdk.awscore.AwsExecutionAttribute; -import software.amazon.awssdk.core.ClientEndpointProvider; -import software.amazon.awssdk.core.SelectedAuthScheme; -import software.amazon.awssdk.core.interceptor.ExecutionAttributes; -import software.amazon.awssdk.core.interceptor.SdkExecutionAttribute; -import software.amazon.awssdk.core.interceptor.SdkInternalExecutionAttribute; -import software.amazon.awssdk.http.auth.aws.scheme.AwsV4AuthScheme; -import software.amazon.awssdk.http.auth.spi.scheme.AuthScheme; -import software.amazon.awssdk.http.auth.spi.signer.HttpSigner; -import software.amazon.awssdk.identity.spi.IdentityProvider; -import software.amazon.awssdk.identity.spi.IdentityProviders; -import software.amazon.awssdk.identity.spi.ResolveIdentityRequest; -import software.amazon.awssdk.identity.spi.internal.DefaultIdentityProviders; +import software.amazon.awssdk.http.auth.spi.scheme.AuthSchemeOption; import software.amazon.awssdk.regions.Region; +import software.amazon.awssdk.services.s3.auth.scheme.S3AuthSchemeParams; +import software.amazon.awssdk.services.s3.auth.scheme.S3AuthSchemeProvider; import software.amazon.awssdk.services.s3.auth.scheme.internal.DefaultS3AuthSchemeProvider; -import software.amazon.awssdk.services.s3.auth.scheme.internal.S3AuthSchemeInterceptor; -import software.amazon.awssdk.services.s3.endpoints.S3ClientContextParams; -import software.amazon.awssdk.services.s3.endpoints.S3EndpointProvider; -import software.amazon.awssdk.services.s3.model.PutObjectRequest; -import software.amazon.awssdk.services.s3.s3express.S3ExpressAuthScheme; -import software.amazon.awssdk.services.s3.s3express.S3ExpressSessionCredentials; -import software.amazon.awssdk.utils.AttributeMap; class S3ExpressAuthSchemeProviderTest { private static final String S3EXPRESS_BUCKET = "s3expressformat--use1-az1--x-s3"; @Test - public void s3express_defaultAuthEnabled_returnspressAuthScheme() { - PutObjectRequest request = PutObjectRequest.builder().bucket(S3EXPRESS_BUCKET).key("k").build(); - AttributeMap clientContextParams = AttributeMap.builder().build(); - ExecutionAttributes executionAttributes = requiredExecutionAttributes(clientContextParams); - - new S3AuthSchemeInterceptor().beforeExecution(() -> request, executionAttributes); - - SelectedAuthScheme attribute = executionAttributes.getAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME); - assertThat(attribute).isNotNull(); - verifyAuthScheme("aws.auth#sigv4-s3express", attribute); - } - - private ExecutionAttributes requiredExecutionAttributes(AttributeMap clientContextParams) { - ExecutionAttributes executionAttributes = new ExecutionAttributes(); - executionAttributes.putAttribute(SdkInternalExecutionAttribute.AUTH_SCHEME_RESOLVER, DefaultS3AuthSchemeProvider.create()); - executionAttributes.putAttribute(SdkExecutionAttribute.OPERATION_NAME, "PutObject"); - executionAttributes.putAttribute(AwsExecutionAttribute.AWS_REGION, Region.US_EAST_1); - executionAttributes.putAttribute(SdkInternalExecutionAttribute.CLIENT_CONTEXT_PARAMS, clientContextParams); - executionAttributes.putAttribute(SdkInternalExecutionAttribute.AUTH_SCHEMES, authSchemesWithS3Express()); - executionAttributes.putAttribute(SdkInternalExecutionAttribute.IDENTITY_PROVIDERS, - DefaultIdentityProviders.builder() - .putIdentityProvider(DefaultCredentialsProvider.create()) - .build()); - executionAttributes.putAttribute(SdkInternalExecutionAttribute.CLIENT_ENDPOINT_PROVIDER, - ClientEndpointProvider.forEndpointOverride(URI.create("https://localhost"))); - executionAttributes.putAttribute(SdkInternalExecutionAttribute.ENDPOINT_PROVIDER, - S3EndpointProvider.defaultProvider()); - return executionAttributes; + public void s3express_defaultAuthEnabled_returnsS3ExpressAuthScheme() { + S3AuthSchemeProvider provider = DefaultS3AuthSchemeProvider.create(); + S3AuthSchemeParams params = S3AuthSchemeParams.builder() + .operation("PutObject") + .region(Region.US_EAST_1) + .bucket(S3EXPRESS_BUCKET) + .build(); + + List authOptions = provider.resolveAuthScheme(params); + + assertThat(authOptions).isNotNull(); + assertThat(authOptions.isEmpty()).isFalse(); + assertThat(authOptions.get(0).schemeId()).isEqualTo("aws.auth#sigv4-s3express"); } @Test public void s3express_authDisabled_returnsV4AuthScheme() { - PutObjectRequest request = PutObjectRequest.builder().bucket(S3EXPRESS_BUCKET).key("k").build(); - AttributeMap clientContextParams = AttributeMap.builder() - .put(S3ClientContextParams.DISABLE_S3_EXPRESS_SESSION_AUTH, true) - .build(); - ExecutionAttributes executionAttributes = requiredExecutionAttributes(clientContextParams); - - new S3AuthSchemeInterceptor().beforeExecution(() -> request, executionAttributes); - - SelectedAuthScheme attribute = executionAttributes.getAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME); - assertThat(attribute).isNotNull(); - verifyAuthScheme("aws.auth#sigv4", attribute); - } - - private void verifyAuthScheme(String expectedAuthSchemeId, SelectedAuthScheme authScheme) { - assertThat(authScheme).isNotNull(); - assertThat(authScheme.authSchemeOption()).isNotNull(); - assertThat(authScheme.authSchemeOption().schemeId()).isEqualTo(expectedAuthSchemeId); - - assertThat(authScheme.identity()).isNotNull(); - assertThat(authScheme.signer()).isNotNull(); - } - - private Map> authSchemesWithS3Express() { - Map> schemes = new HashMap<>(); - AwsV4AuthScheme awsV4AuthScheme = AwsV4AuthScheme.create(); - schemes.put(awsV4AuthScheme.schemeId(), awsV4AuthScheme); - S3ExpressAuthScheme s3ExpressAuthScheme = new S3ExpressAuthScheme() { - @Override - public IdentityProvider identityProvider(IdentityProviders providers) { - return new IdentityProvider() { - @Override - public Class identityType() { - return null; - } - - @Override - public CompletableFuture resolveIdentity(ResolveIdentityRequest request) { - return CompletableFuture.completedFuture(S3ExpressSessionCredentials.create("a","b","c")); - } - }; - } - - @Override - public HttpSigner signer() { - return DefaultS3ExpressHttpSigner.create(); - } - - @Override - public String schemeId() { - return SCHEME_ID; - } - }; - schemes.put(s3ExpressAuthScheme.schemeId(), s3ExpressAuthScheme); - return schemes; + S3AuthSchemeProvider provider = DefaultS3AuthSchemeProvider.create(); + S3AuthSchemeParams params = S3AuthSchemeParams.builder() + .operation("PutObject") + .region(Region.US_EAST_1) + .bucket(S3EXPRESS_BUCKET) + .disableS3ExpressSessionAuth(true) + .build(); + + List authOptions = provider.resolveAuthScheme(params); + + assertThat(authOptions).isNotNull(); + assertThat(authOptions.isEmpty()).isFalse(); + assertThat(authOptions.get(0).schemeId()).isEqualTo("aws.auth#sigv4"); } -} \ No newline at end of file +} diff --git a/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/s3express/S3ExpressCacheFunctionalTest.java b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/s3express/S3ExpressCacheFunctionalTest.java index 22b3e3983b71..32e0a4fdc612 100644 --- a/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/s3express/S3ExpressCacheFunctionalTest.java +++ b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/s3express/S3ExpressCacheFunctionalTest.java @@ -44,6 +44,7 @@ import software.amazon.awssdk.auth.credentials.AwsBasicCredentials; import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider; +import software.amazon.awssdk.awscore.AwsRequestOverrideConfiguration; import software.amazon.awssdk.core.SelectedAuthScheme; import software.amazon.awssdk.core.interceptor.Context; import software.amazon.awssdk.core.interceptor.ExecutionAttributes; @@ -68,7 +69,8 @@ @WireMockTest(httpsEnabled = true) public class S3ExpressCacheFunctionalTest { - private static final Function WM_HTTPS_ENDPOINT = wm -> URI.create(wm.getHttpsBaseUrl()); + private static final Function WM_HTTPS_ENDPOINT = + wm -> URI.create("https://127.0.0.1:" + wm.getHttpsPort()); private static final PathStyleEnforcingInterceptor PATH_STYLE_INTERCEPTOR = new PathStyleEnforcingInterceptor(); private static final String S3EXPRESS_BUCKET_1 = "s3express-cache-1--use1-az1--x-s3"; private static final String S3EXPRESS_BUCKET_2 = "s3express-cache-2--use1-az1--x-s3"; @@ -278,21 +280,29 @@ public List> apiCredentialsProviders() @Override public void beforeExecution(Context.BeforeExecution context, ExecutionAttributes executionAttributes) { - IdentityProviders providers = executionAttributes.getAttribute(SdkInternalExecutionAttribute.IDENTITY_PROVIDERS); - IdentityProvider awsCredentialsIdentityIdentityProvider = - providers.identityProvider(AwsCredentialsIdentity.class); + + IdentityProvider credentialsProvider = context.request() + .overrideConfiguration() + .filter(c -> c instanceof AwsRequestOverrideConfiguration) + .map(c -> (AwsRequestOverrideConfiguration) c) + .flatMap(AwsRequestOverrideConfiguration::credentialsIdentityProvider) + .map(p -> (IdentityProvider) p) + .orElseGet(() -> { + IdentityProviders providers = executionAttributes.getAttribute(SdkInternalExecutionAttribute.IDENTITY_PROVIDERS); + return providers.identityProvider(AwsCredentialsIdentity.class); + }); String operationName = executionAttributes.getAttribute(SdkExecutionAttribute.OPERATION_NAME); if (operationName.equalsIgnoreCase("createsession")) { sessionRequests++; - sessionCredentialsProvider.add(awsCredentialsIdentityIdentityProvider); + sessionCredentialsProvider.add(credentialsProvider); } else { - apiCredentialsProvider.add(awsCredentialsIdentityIdentityProvider); + apiCredentialsProvider.add(credentialsProvider); } } @Override - public void beforeMarshalling(Context.BeforeMarshalling context, ExecutionAttributes executionAttributes) { + public void beforeTransmission(Context.BeforeTransmission context, ExecutionAttributes executionAttributes) { SelectedAuthScheme attribute = executionAttributes.getAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME); CompletableFuture identity = attribute.identity(); @@ -311,7 +321,11 @@ private static final class PathStyleEnforcingInterceptor implements ExecutionInt public SdkHttpRequest modifyHttpRequest(Context.ModifyHttpRequest context, ExecutionAttributes executionAttributes) { SdkHttpRequest sdkHttpRequest = context.httpRequest(); String host = sdkHttpRequest.host(); - String bucket = host.substring(0, host.indexOf(".localhost")); + int idx = host.indexOf(".localhost"); + if (idx < 0) { + return sdkHttpRequest; + } + String bucket = host.substring(0, idx); return sdkHttpRequest.toBuilder().host("localhost") .encodedPath(SdkHttpUtils.appendUri(bucket, sdkHttpRequest.encodedPath())) diff --git a/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/s3express/S3ExpressUtilsTest.java b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/s3express/S3ExpressUtilsTest.java new file mode 100644 index 000000000000..53d84f42ad66 --- /dev/null +++ b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/s3express/S3ExpressUtilsTest.java @@ -0,0 +1,111 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.services.s3.internal.s3express; + +import static org.assertj.core.api.Assertions.assertThat; + +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import java.io.IOException; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.ArrayList; +import java.util.List; +import org.junit.jupiter.api.Test; +import software.amazon.awssdk.services.s3.model.GetObjectRequest; + +class S3ExpressUtilsTest { + + private static final ObjectMapper MAPPER = new ObjectMapper(); + + @Test + void isS3ExpressBucket_bucketWithS3ExpressSuffix_returnsTrue() { + GetObjectRequest request = GetObjectRequest.builder() + .bucket("my-bucket--use1-az1--x-s3") + .key("key") + .build(); + assertThat(S3ExpressUtils.isS3ExpressBucket(request)).isTrue(); + } + + @Test + void isS3ExpressBucket_regularBucket_returnsFalse() { + GetObjectRequest request = GetObjectRequest.builder() + .bucket("my-regular-bucket") + .key("key") + .build(); + assertThat(S3ExpressUtils.isS3ExpressBucket(request)).isFalse(); + } + + @Test + void isS3ExpressBucket_noBucketField_returnsFalse() { + GetObjectRequest request = GetObjectRequest.builder() + .key("key") + .build(); + assertThat(S3ExpressUtils.isS3ExpressBucket(request)).isFalse(); + } + + /** + * Validates that the S3Express bucket suffix used in {@link S3ExpressUtils#isS3ExpressBucket} matches the suffix + * defined in the endpoint ruleset. + */ + @Test + void isS3ExpressBucket_suffixMatchesEndpointRuleset() throws IOException { + String rulesetSuffix = extractBucketSuffixFromRuleset(); + assertThat(rulesetSuffix).isEqualTo("--x-s3"); + GetObjectRequest request = GetObjectRequest.builder() + .bucket("test-bucket" + rulesetSuffix) + .key("key") + .build(); + assertThat(S3ExpressUtils.isS3ExpressBucket(request)) + .as("isS3ExpressBucket should recognize the suffix '%s' from the endpoint ruleset", rulesetSuffix) + .isTrue(); + } + + /** + * Parses the endpoint-rule-set.json and extracts the S3Express bucket suffix. + */ + private String extractBucketSuffixFromRuleset() throws IOException { + Path rulesetPath = Paths.get("src/main/resources/codegen-resources/endpoint-rule-set.json"); + assertThat(rulesetPath.toFile()).as("endpoint-rule-set.json should exist").exists(); + JsonNode root = MAPPER.readTree(rulesetPath.toFile()); + List suffixes = new ArrayList<>(); + findBucketSuffixValues(root, suffixes); + assertThat(suffixes) + .as("Expected exactly one bucketSuffix stringEquals check in the endpoint ruleset") + .hasSize(1); + return suffixes.get(0); + } + + private void findBucketSuffixValues(JsonNode node, List results) { + if (node.isObject() && "stringEquals".equals(node.path("fn").asText(null))) { + JsonNode argv = node.path("argv"); + if (argv.isArray() && argv.size() == 2) { + for (int i = 0; i < 2; i++) { + JsonNode arg = argv.get(i); + JsonNode other = argv.get(1 - i); + if (arg.isObject() && "bucketSuffix".equals(arg.path("ref").asText(null)) + && other.isTextual()) { + results.add(other.asText()); + return; + } + } + } + } + for (JsonNode child : node) { + findBucketSuffixValues(child, results); + } + } +} diff --git a/services/s3control/src/test/java/software/amazon/awssdk/services/s3control/internal/functionaltests/arns/S3ControlWireMockRerouteInterceptor.java b/services/s3control/src/test/java/software/amazon/awssdk/services/s3control/internal/functionaltests/arns/S3ControlWireMockRerouteInterceptor.java index a561c919bab2..fb209138ab2f 100644 --- a/services/s3control/src/test/java/software/amazon/awssdk/services/s3control/internal/functionaltests/arns/S3ControlWireMockRerouteInterceptor.java +++ b/services/s3control/src/test/java/software/amazon/awssdk/services/s3control/internal/functionaltests/arns/S3ControlWireMockRerouteInterceptor.java @@ -6,6 +6,8 @@ import software.amazon.awssdk.core.interceptor.Context; import software.amazon.awssdk.core.interceptor.ExecutionAttributes; import software.amazon.awssdk.core.interceptor.ExecutionInterceptor; +import software.amazon.awssdk.core.interceptor.SdkInternalExecutionAttribute; +import software.amazon.awssdk.endpoints.Endpoint; import software.amazon.awssdk.http.SdkHttpRequest; /** @@ -25,12 +27,21 @@ public class S3ControlWireMockRerouteInterceptor implements ExecutionInterceptor public SdkHttpRequest modifyHttpRequest(Context.ModifyHttpRequest context, ExecutionAttributes executionAttributes) { SdkHttpRequest request = context.httpRequest(); - recordedEndpoints.add(request.getUri()); recordedRequests.add(request); return request.toBuilder().uri(rerouteEndpoint).build(); } + @Override + public void beforeTransmission(Context.BeforeTransmission context, ExecutionAttributes executionAttributes) { + Endpoint resolvedEndpoint = executionAttributes.getAttribute(SdkInternalExecutionAttribute.RESOLVED_ENDPOINT); + if (resolvedEndpoint != null) { + recordedEndpoints.add(resolvedEndpoint.url()); + } else { + recordedEndpoints.add(context.httpRequest().getUri()); + } + } + public List getRecordedRequests() { return recordedRequests; } diff --git a/services/s3control/src/test/java/software/amazon/awssdk/services/s3control/internal/functionaltests/signing/UrlEncodingTest.java b/services/s3control/src/test/java/software/amazon/awssdk/services/s3control/internal/functionaltests/signing/UrlEncodingTest.java index b5688ad41c51..03765e9c8534 100644 --- a/services/s3control/src/test/java/software/amazon/awssdk/services/s3control/internal/functionaltests/signing/UrlEncodingTest.java +++ b/services/s3control/src/test/java/software/amazon/awssdk/services/s3control/internal/functionaltests/signing/UrlEncodingTest.java @@ -84,7 +84,7 @@ private static class ExecutionAttributeInterceptor implements ExecutionIntercept } @Override - public void beforeExecution(Context.BeforeExecution context, ExecutionAttributes executionAttributes) { + public void beforeTransmission(Context.BeforeTransmission context, ExecutionAttributes executionAttributes) { signerDoubleUrlEncode = executionAttributes.getAttribute(AwsSignerExecutionAttribute.SIGNER_DOUBLE_URL_ENCODE); } diff --git a/test/codegen-generated-classes-test/src/test/java/software/amazon/awssdk/services/AuthSchemeInterceptorTest.java b/test/codegen-generated-classes-test/src/test/java/software/amazon/awssdk/services/AuthSchemeInterceptorTest.java deleted file mode 100644 index e187ab4435d5..000000000000 --- a/test/codegen-generated-classes-test/src/test/java/software/amazon/awssdk/services/AuthSchemeInterceptorTest.java +++ /dev/null @@ -1,108 +0,0 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). - * You may not use this file except in compliance with the License. - * A copy of the License is located at - * - * http://aws.amazon.com/apache2.0 - * - * or in the "license" file accompanying this file. This file is distributed - * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either - * express or implied. See the License for the specific language governing - * permissions and limitations under the License. - */ - -package software.amazon.awssdk.services; - -import static org.assertj.core.api.Assertions.assertThat; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.spy; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; - -import java.util.Arrays; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; -import software.amazon.awssdk.auth.credentials.AnonymousCredentialsProvider; -import software.amazon.awssdk.core.SelectedAuthScheme; -import software.amazon.awssdk.core.interceptor.Context; -import software.amazon.awssdk.core.interceptor.ExecutionAttributes; -import software.amazon.awssdk.core.interceptor.SdkExecutionAttribute; -import software.amazon.awssdk.core.interceptor.SdkInternalExecutionAttribute; -import software.amazon.awssdk.http.auth.aws.scheme.AwsV4AuthScheme; -import software.amazon.awssdk.http.auth.spi.scheme.AuthScheme; -import software.amazon.awssdk.http.auth.spi.scheme.AuthSchemeOption; -import software.amazon.awssdk.http.auth.spi.signer.HttpSigner; -import software.amazon.awssdk.identity.spi.AwsCredentialsIdentity; -import software.amazon.awssdk.identity.spi.IdentityProvider; -import software.amazon.awssdk.identity.spi.IdentityProviders; -import software.amazon.awssdk.services.protocolrestjson.auth.scheme.ProtocolRestJsonAuthSchemeParams; -import software.amazon.awssdk.services.protocolrestjson.auth.scheme.ProtocolRestJsonAuthSchemeProvider; -import software.amazon.awssdk.services.protocolrestjson.auth.scheme.internal.ProtocolRestJsonAuthSchemeInterceptor; - -public class AuthSchemeInterceptorTest { - private static final ProtocolRestJsonAuthSchemeInterceptor INTERCEPTOR = new ProtocolRestJsonAuthSchemeInterceptor(); - - private Context.BeforeExecution mockContext; - - @BeforeEach - public void setup() { - mockContext = mock(Context.BeforeExecution.class); - } - - @Test - public void resolveAuthScheme_authSchemeSignerThrows_continuesToNextAuthScheme() { - ProtocolRestJsonAuthSchemeProvider mockAuthSchemeProvider = mock(ProtocolRestJsonAuthSchemeProvider.class); - List authSchemeOptions = Arrays.asList( - AuthSchemeOption.builder().schemeId(TestAuthScheme.SCHEME_ID).build(), - AuthSchemeOption.builder().schemeId(AwsV4AuthScheme.SCHEME_ID).build() - ); - when(mockAuthSchemeProvider.resolveAuthScheme(any(ProtocolRestJsonAuthSchemeParams.class))).thenReturn(authSchemeOptions); - - IdentityProviders mockIdentityProviders = mock(IdentityProviders.class); - when(mockIdentityProviders.identityProvider(any(Class.class))).thenReturn(AnonymousCredentialsProvider.create()); - - Map> authSchemes = new HashMap<>(); - authSchemes.put(AwsV4AuthScheme.SCHEME_ID, AwsV4AuthScheme.create()); - - TestAuthScheme notProvidedAuthScheme = spy(new TestAuthScheme()); - authSchemes.put(TestAuthScheme.SCHEME_ID, notProvidedAuthScheme); - - ExecutionAttributes attributes = new ExecutionAttributes(); - attributes.putAttribute(SdkExecutionAttribute.OPERATION_NAME, "GetFoo"); - attributes.putAttribute(SdkInternalExecutionAttribute.AUTH_SCHEME_RESOLVER, mockAuthSchemeProvider); - attributes.putAttribute(SdkInternalExecutionAttribute.IDENTITY_PROVIDERS, mockIdentityProviders); - attributes.putAttribute(SdkInternalExecutionAttribute.AUTH_SCHEMES, authSchemes); - - INTERCEPTOR.beforeExecution(mockContext, attributes); - - SelectedAuthScheme selectedAuthScheme = attributes.getAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME); - - verify(notProvidedAuthScheme).signer(); - assertThat(selectedAuthScheme.authSchemeOption().schemeId()).isEqualTo(AwsV4AuthScheme.SCHEME_ID); - } - - private static class TestAuthScheme implements AuthScheme { - public static final String SCHEME_ID = "codegen-test-scheme"; - - @Override - public String schemeId() { - return SCHEME_ID; - } - - @Override - public IdentityProvider identityProvider(IdentityProviders providers) { - return providers.identityProvider(AwsCredentialsIdentity.class); - } - - @Override - public HttpSigner signer() { - throw new RuntimeException("Not on classpath"); - } - } -} diff --git a/test/codegen-generated-classes-test/src/test/java/software/amazon/awssdk/services/IdentityResolutionOverrideTest.java b/test/codegen-generated-classes-test/src/test/java/software/amazon/awssdk/services/IdentityResolutionOverrideTest.java index f20c84a25756..daf73cd838e3 100644 --- a/test/codegen-generated-classes-test/src/test/java/software/amazon/awssdk/services/IdentityResolutionOverrideTest.java +++ b/test/codegen-generated-classes-test/src/test/java/software/amazon/awssdk/services/IdentityResolutionOverrideTest.java @@ -71,9 +71,8 @@ void when_apiCall_setsCredentialsProviderInRequestOverride_overrideCredentialsAr assertSelectedAuthSchemeBeforeTransmissionContains(OVERRIDE_CREDENTIALS); } - // Changing the credentials provider in modifyRequest does not work in SRA identity resolution - // Identity is resolved in beforeExecution (and happens before user applied interceptors) and cannot - // be affected by execution interceptors. + // After moving identity resolution to pipeline stage (after interceptors), credentials provider + // set in modifyRequest is now respected (identity resolved after interceptors complete) @Test void when_executionInterceptorModifyRequest_setsCredentialProviderInRequestOverride_clientCredentialsAreUsed() { ExecutionInterceptor overridingInterceptor = @@ -83,7 +82,7 @@ void when_executionInterceptorModifyRequest_setsCredentialProviderInRequestOverr assertThatThrownBy(() -> syncClient.allTypes(r -> {})).hasMessageContaining("stop"); - assertSelectedAuthSchemeBeforeTransmissionContains(CLIENT_CREDENTIALS); + assertSelectedAuthSchemeBeforeTransmissionContains(OVERRIDE_CREDENTIALS); } @Test diff --git a/test/codegen-generated-classes-test/src/test/java/software/amazon/awssdk/services/endpointproviders/EndpointInterceptorTests.java b/test/codegen-generated-classes-test/src/test/java/software/amazon/awssdk/services/endpointproviders/EndpointInterceptorTests.java index d24456368fbd..df32b88193b2 100644 --- a/test/codegen-generated-classes-test/src/test/java/software/amazon/awssdk/services/endpointproviders/EndpointInterceptorTests.java +++ b/test/codegen-generated-classes-test/src/test/java/software/amazon/awssdk/services/endpointproviders/EndpointInterceptorTests.java @@ -177,7 +177,7 @@ public void sync_endpointProviderReturnsHeaders_appendedToExistingRequest() { "TestValue0")))) .hasMessageContaining("stop"); - assertThat(interceptor.context.httpRequest().matchingHeaders("TestHeader")).containsExactly("TestValue", "TestValue0"); + assertThat(interceptor.context.httpRequest().matchingHeaders("TestHeader")).containsExactly("TestValue0", "TestValue"); } @Test @@ -198,7 +198,7 @@ public void async_endpointProviderReturnsHeaders_appendedToExistingRequest() { .join()) .hasMessageContaining("stop"); - assertThat(interceptor.context.httpRequest().matchingHeaders("TestHeader")).containsExactly("TestValue", "TestValue0"); + assertThat(interceptor.context.httpRequest().matchingHeaders("TestHeader")).containsExactly("TestValue0", "TestValue"); } diff --git a/test/codegen-generated-classes-test/src/test/java/software/amazon/awssdk/services/endpointproviders/RequestOverrideEndpointProviderTests.java b/test/codegen-generated-classes-test/src/test/java/software/amazon/awssdk/services/endpointproviders/RequestOverrideEndpointProviderTests.java index 075976d136c8..07dd11000688 100644 --- a/test/codegen-generated-classes-test/src/test/java/software/amazon/awssdk/services/endpointproviders/RequestOverrideEndpointProviderTests.java +++ b/test/codegen-generated-classes-test/src/test/java/software/amazon/awssdk/services/endpointproviders/RequestOverrideEndpointProviderTests.java @@ -18,16 +18,12 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; -import java.net.URI; -import java.util.concurrent.CompletableFuture; -import org.junit.Test; import software.amazon.awssdk.auth.credentials.AwsBasicCredentials; import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider; import software.amazon.awssdk.core.client.config.SdkAdvancedClientOption; import software.amazon.awssdk.core.interceptor.Context; import software.amazon.awssdk.core.interceptor.ExecutionAttributes; import software.amazon.awssdk.core.interceptor.ExecutionInterceptor; -import software.amazon.awssdk.core.interceptor.SdkInternalExecutionAttribute; import software.amazon.awssdk.endpoints.Endpoint; import software.amazon.awssdk.http.SdkHttpRequest; import software.amazon.awssdk.regions.Region; @@ -37,6 +33,7 @@ import software.amazon.awssdk.services.restjsonendpointproviders.RestJsonEndpointProvidersClientBuilder; import software.amazon.awssdk.services.restjsonendpointproviders.endpoints.RestJsonEndpointProvidersEndpointParams; import software.amazon.awssdk.services.restjsonendpointproviders.endpoints.RestJsonEndpointProvidersEndpointProvider; +import org.junit.Test; public class RequestOverrideEndpointProviderTests { @@ -49,34 +46,28 @@ public void sync_endpointOverridden_equals_requestOverride() { .build(); assertThatThrownBy(() -> client.operationWithHostPrefix( - r -> r.overrideConfiguration(o -> o.endpointProvider(new CustomEndpointProvider(Region.AWS_GLOBAL)))) - - ) + r -> r.overrideConfiguration(o -> o.endpointProvider(new CustomEndpointProvider(Region.AWS_GLOBAL))))) .hasMessageContaining("stop"); - Endpoint endpoint = interceptor.executionAttributes().getAttribute(SdkInternalExecutionAttribute.RESOLVED_ENDPOINT); - assertThat(endpoint.url().getHost()).isEqualTo("restjson.aws-global.amazonaws.com"); + + assertThat(interceptor.httpRequest.host()).isEqualTo("restjson.aws-global.amazonaws.com"); } @Test public void sync_endpointOverridden_equals_ClientsWhenNoRequestOverride() { CapturingInterceptor interceptor = new CapturingInterceptor(); - RestJsonEndpointProvidersClient client = syncClientBuilder().endpointProvider(new CustomEndpointProvider(Region.EU_WEST_2)) + RestJsonEndpointProvidersClient client = syncClientBuilder() + .endpointProvider(new CustomEndpointProvider(Region.EU_WEST_2)) .overrideConfiguration(o -> o.addExecutionInterceptor(interceptor) .putAdvancedOption(SdkAdvancedClientOption.DISABLE_HOST_PREFIX_INJECTION, true)) .build(); assertThatThrownBy(() -> client.operationWithHostPrefix(r -> {})).hasMessageContaining("stop"); - Endpoint endpoint = interceptor.executionAttributes().getAttribute(SdkInternalExecutionAttribute.RESOLVED_ENDPOINT); - - assertThat(endpoint.url().getHost()).isEqualTo("restjson.eu-west-2.amazonaws.com"); + assertThat(interceptor.httpRequest.host()).isEqualTo("restjson.eu-west-2.amazonaws.com"); } - @Test public void async_endpointOverridden_equals_requestOverride() { - - CapturingInterceptor interceptor = new CapturingInterceptor(); RestJsonEndpointProvidersAsyncClient client = asyncClientBuilder() .overrideConfiguration(o -> o.addExecutionInterceptor(interceptor) @@ -88,18 +79,14 @@ public void async_endpointOverridden_equals_requestOverride() { ).join()) .hasMessageContaining("stop"); - Endpoint endpoint = interceptor.executionAttributes().getAttribute(SdkInternalExecutionAttribute.RESOLVED_ENDPOINT); - - assertThat(endpoint.url().getHost()).isEqualTo("restjson.aws-global.amazonaws.com"); + assertThat(interceptor.httpRequest.host()).isEqualTo("restjson.aws-global.amazonaws.com"); } - @Test public void async_endpointOverridden_equals_ClientsWhenNoRequestOverride() { - - CapturingInterceptor interceptor = new CapturingInterceptor(); - RestJsonEndpointProvidersAsyncClient client = asyncClientBuilder().endpointProvider(new CustomEndpointProvider(Region.EU_WEST_2)) + RestJsonEndpointProvidersAsyncClient client = asyncClientBuilder() + .endpointProvider(new CustomEndpointProvider(Region.EU_WEST_2)) .overrideConfiguration(o -> o.addExecutionInterceptor(interceptor) .putAdvancedOption(SdkAdvancedClientOption.DISABLE_HOST_PREFIX_INJECTION, true)) .build(); @@ -107,31 +94,25 @@ public void async_endpointOverridden_equals_ClientsWhenNoRequestOverride() { assertThatThrownBy(() -> client.operationWithHostPrefix(r -> {}).join()) .hasMessageContaining("stop"); - Endpoint endpoint = interceptor.executionAttributes().getAttribute(SdkInternalExecutionAttribute.RESOLVED_ENDPOINT); - - assertThat(endpoint.url().getHost()).isEqualTo("restjson.eu-west-2.amazonaws.com"); + assertThat(interceptor.httpRequest.host()).isEqualTo("restjson.eu-west-2.amazonaws.com"); } public static class CapturingInterceptor implements ExecutionInterceptor { - - private ExecutionAttributes executionAttributes; + private SdkHttpRequest httpRequest; @Override - public SdkHttpRequest modifyHttpRequest(Context.ModifyHttpRequest context, ExecutionAttributes executionAttributes) { - this.executionAttributes = executionAttributes; + public void beforeTransmission(Context.BeforeTransmission context, ExecutionAttributes executionAttributes) { + this.httpRequest = context.httpRequest(); throw new CaptureCompletedException("stop"); } - public ExecutionAttributes executionAttributes() { - return executionAttributes; - } - public class CaptureCompletedException extends RuntimeException { CaptureCompletedException(String message) { super(message); } } } + private RestJsonEndpointProvidersClientBuilder syncClientBuilder() { return RestJsonEndpointProvidersClient.builder() .region(Region.US_WEST_2) @@ -142,23 +123,23 @@ private RestJsonEndpointProvidersClientBuilder syncClientBuilder() { private RestJsonEndpointProvidersAsyncClientBuilder asyncClientBuilder() { return RestJsonEndpointProvidersAsyncClient.builder() - .region(Region.US_WEST_2) - .credentialsProvider( - StaticCredentialsProvider.create( - AwsBasicCredentials.create("akid", "skid"))); + .region(Region.US_WEST_2) + .credentialsProvider( + StaticCredentialsProvider.create( + AwsBasicCredentials.create("akid", "skid"))); } + public static class CustomEndpointProvider implements RestJsonEndpointProvidersEndpointProvider { + private final Region region; - class CustomEndpointProvider implements RestJsonEndpointProvidersEndpointProvider{ - final RestJsonEndpointProvidersEndpointProvider endpointProvider; - final Region overridingRegion; - CustomEndpointProvider (Region region){ - this.endpointProvider = RestJsonEndpointProvidersEndpointProvider.defaultProvider(); - this.overridingRegion = region; + CustomEndpointProvider(Region region) { + this.region = region; } + @Override - public CompletableFuture resolveEndpoint(RestJsonEndpointProvidersEndpointParams endpointParams) { - return endpointProvider.resolveEndpoint(endpointParams.toBuilder().region(overridingRegion).build()); + public java.util.concurrent.CompletableFuture resolveEndpoint(RestJsonEndpointProvidersEndpointParams params) { + return RestJsonEndpointProvidersEndpointProvider.defaultProvider() + .resolveEndpoint(params.toBuilder().region(region).build()); } } } diff --git a/test/sdk-standard-benchmarks/src/main/java/software/amazon/awssdk/benchmark/endpoints/LambdaEndpointResolverBenchmark.java b/test/sdk-standard-benchmarks/src/main/java/software/amazon/awssdk/benchmark/endpoints/LambdaEndpointResolverBenchmark.java index 78e5a3fbd1aa..4efad1b6917a 100644 --- a/test/sdk-standard-benchmarks/src/main/java/software/amazon/awssdk/benchmark/endpoints/LambdaEndpointResolverBenchmark.java +++ b/test/sdk-standard-benchmarks/src/main/java/software/amazon/awssdk/benchmark/endpoints/LambdaEndpointResolverBenchmark.java @@ -38,7 +38,7 @@ import software.amazon.awssdk.regions.Region; import software.amazon.awssdk.services.lambda.endpoints.LambdaEndpointParams; import software.amazon.awssdk.services.lambda.endpoints.LambdaEndpointProvider; -import software.amazon.awssdk.services.lambda.endpoints.internal.LambdaResolveEndpointInterceptor; +import software.amazon.awssdk.services.lambda.endpoints.internal.LambdaEndpointResolverUtils; import software.amazon.awssdk.services.lambda.model.ListFunctionsRequest; import software.amazon.awssdk.utils.AttributeMap; @@ -88,7 +88,7 @@ public void setup() { @Benchmark public void resolveEndpoint(Blackhole blackhole) { - LambdaEndpointParams params = LambdaResolveEndpointInterceptor.ruleParams(request, executionAttributes); + LambdaEndpointParams params = LambdaEndpointResolverUtils.ruleParams(request, executionAttributes); blackhole.consume(provider.resolveEndpoint(params).join()); } diff --git a/test/sdk-standard-benchmarks/src/main/java/software/amazon/awssdk/benchmark/endpoints/S3EndpointResolverBenchmark.java b/test/sdk-standard-benchmarks/src/main/java/software/amazon/awssdk/benchmark/endpoints/S3EndpointResolverBenchmark.java index 9571040f3694..240891adcfdb 100644 --- a/test/sdk-standard-benchmarks/src/main/java/software/amazon/awssdk/benchmark/endpoints/S3EndpointResolverBenchmark.java +++ b/test/sdk-standard-benchmarks/src/main/java/software/amazon/awssdk/benchmark/endpoints/S3EndpointResolverBenchmark.java @@ -39,7 +39,7 @@ import software.amazon.awssdk.services.s3.endpoints.S3ClientContextParams; import software.amazon.awssdk.services.s3.endpoints.S3EndpointParams; import software.amazon.awssdk.services.s3.endpoints.S3EndpointProvider; -import software.amazon.awssdk.services.s3.endpoints.internal.S3ResolveEndpointInterceptor; +import software.amazon.awssdk.services.s3.endpoints.internal.S3EndpointResolverUtils; import software.amazon.awssdk.services.s3.model.GetObjectRequest; import software.amazon.awssdk.utils.AttributeMap; @@ -121,7 +121,7 @@ public void setup() { @Benchmark public void resolveEndpoint(Blackhole blackhole) { - S3EndpointParams params = S3ResolveEndpointInterceptor.ruleParams(request, executionAttributes); + S3EndpointParams params = S3EndpointResolverUtils.ruleParams(request, executionAttributes); blackhole.consume(provider.resolveEndpoint(params).join()); }