diff --git a/codegen/src/main/java/software/amazon/awssdk/codegen/emitters/tasks/AuthSchemeGeneratorTasks.java b/codegen/src/main/java/software/amazon/awssdk/codegen/emitters/tasks/AuthSchemeGeneratorTasks.java index 471d9a5ddaab..0794a165181c 100644 --- a/codegen/src/main/java/software/amazon/awssdk/codegen/emitters/tasks/AuthSchemeGeneratorTasks.java +++ b/codegen/src/main/java/software/amazon/awssdk/codegen/emitters/tasks/AuthSchemeGeneratorTasks.java @@ -20,7 +20,6 @@ import software.amazon.awssdk.codegen.emitters.GeneratorTask; import software.amazon.awssdk.codegen.emitters.GeneratorTaskParams; import software.amazon.awssdk.codegen.emitters.PoetGeneratorTask; -import software.amazon.awssdk.codegen.poet.auth.scheme.AuthSchemeInterceptorSpec; import software.amazon.awssdk.codegen.poet.auth.scheme.AuthSchemeParamsSpec; import software.amazon.awssdk.codegen.poet.auth.scheme.AuthSchemeProviderSpec; import software.amazon.awssdk.codegen.poet.auth.scheme.AuthSchemeSpecUtils; @@ -85,10 +84,6 @@ private GeneratorTask generateEndpointAwareAuthSchemeParams() { } - private GeneratorTask generateAuthSchemeInterceptor() { - return new PoetGeneratorTask(authSchemeInternalDir(), model.getFileHeader(), new AuthSchemeInterceptorSpec(model)); - } - private String authSchemeDir() { return generatorTaskParams.getPathProvider().getAuthSchemeDirectory(); } diff --git a/codegen/src/main/java/software/amazon/awssdk/codegen/emitters/tasks/EndpointProviderTasks.java b/codegen/src/main/java/software/amazon/awssdk/codegen/emitters/tasks/EndpointProviderTasks.java index 7f5d64da8e79..b72a33263c5b 100644 --- a/codegen/src/main/java/software/amazon/awssdk/codegen/emitters/tasks/EndpointProviderTasks.java +++ b/codegen/src/main/java/software/amazon/awssdk/codegen/emitters/tasks/EndpointProviderTasks.java @@ -32,9 +32,8 @@ import software.amazon.awssdk.codegen.poet.rules.EndpointProviderInterfaceSpec; import software.amazon.awssdk.codegen.poet.rules.EndpointProviderSpec; import software.amazon.awssdk.codegen.poet.rules.EndpointProviderTestSpec; -import software.amazon.awssdk.codegen.poet.rules.EndpointResolverInterceptorSpec; +import software.amazon.awssdk.codegen.poet.rules.EndpointResolverUtilsSpec; import software.amazon.awssdk.codegen.poet.rules.EndpointRulesClientTestSpec; -import software.amazon.awssdk.codegen.poet.rules.RequestEndpointInterceptorSpec; import software.amazon.awssdk.codegen.poet.rules2.EndpointProviderSpec2; public final class EndpointProviderTasks extends BaseGeneratorTasks { @@ -103,8 +102,7 @@ private boolean shouldGenerateCompiledEndpointRules() { private Collection generateInterceptors() { return Arrays.asList( - new PoetGeneratorTask(endpointRulesInternalDir(), model.getFileHeader(), new EndpointResolverInterceptorSpec(model)), - new PoetGeneratorTask(endpointRulesInternalDir(), model.getFileHeader(), new RequestEndpointInterceptorSpec(model))); + new PoetGeneratorTask(endpointRulesInternalDir(), model.getFileHeader(), new EndpointResolverUtilsSpec(model))); } private GeneratorTask generateClientTests() { diff --git a/codegen/src/main/java/software/amazon/awssdk/codegen/poet/auth/scheme/AuthSchemeInterceptorSpec.java b/codegen/src/main/java/software/amazon/awssdk/codegen/poet/auth/scheme/AuthSchemeInterceptorSpec.java deleted file mode 100644 index c4f9e768eaa3..000000000000 --- a/codegen/src/main/java/software/amazon/awssdk/codegen/poet/auth/scheme/AuthSchemeInterceptorSpec.java +++ /dev/null @@ -1,524 +0,0 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). - * You may not use this file except in compliance with the License. - * A copy of the License is located at - * - * http://aws.amazon.com/apache2.0 - * - * or in the "license" file accompanying this file. This file is distributed - * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either - * express or implied. See the License for the specific language governing - * permissions and limitations under the License. - */ - -package software.amazon.awssdk.codegen.poet.auth.scheme; - -import com.squareup.javapoet.ClassName; -import com.squareup.javapoet.FieldSpec; -import com.squareup.javapoet.MethodSpec; -import com.squareup.javapoet.ParameterSpec; -import com.squareup.javapoet.ParameterizedTypeName; -import com.squareup.javapoet.TypeName; -import com.squareup.javapoet.TypeSpec; -import com.squareup.javapoet.TypeVariableName; -import com.squareup.javapoet.WildcardTypeName; -import java.time.Duration; -import java.util.ArrayList; -import java.util.List; -import java.util.Map; -import java.util.concurrent.CompletableFuture; -import java.util.function.Supplier; -import java.util.stream.Collectors; -import javax.lang.model.element.Modifier; -import software.amazon.awssdk.annotations.SdkInternalApi; -import software.amazon.awssdk.awscore.AwsExecutionAttribute; -import software.amazon.awssdk.codegen.model.intermediate.IntermediateModel; -import software.amazon.awssdk.codegen.poet.ClassSpec; -import software.amazon.awssdk.codegen.poet.PoetUtils; -import software.amazon.awssdk.codegen.poet.rules.EndpointRulesSpecUtils; -import software.amazon.awssdk.core.SdkRequest; -import software.amazon.awssdk.core.SelectedAuthScheme; -import software.amazon.awssdk.core.exception.SdkException; -import software.amazon.awssdk.core.identity.SdkIdentityProperty; -import software.amazon.awssdk.core.interceptor.Context; -import software.amazon.awssdk.core.interceptor.ExecutionAttributes; -import software.amazon.awssdk.core.interceptor.ExecutionInterceptor; -import software.amazon.awssdk.core.interceptor.SdkExecutionAttribute; -import software.amazon.awssdk.core.interceptor.SdkInternalExecutionAttribute; -import software.amazon.awssdk.core.internal.util.MetricUtils; -import software.amazon.awssdk.core.metrics.CoreMetric; -import software.amazon.awssdk.core.useragent.BusinessMetricFeatureId; -import software.amazon.awssdk.endpoints.EndpointProvider; -import software.amazon.awssdk.http.auth.aws.scheme.AwsV4aAuthScheme; -import software.amazon.awssdk.http.auth.aws.signer.RegionSet; -import software.amazon.awssdk.http.auth.scheme.BearerAuthScheme; -import software.amazon.awssdk.http.auth.spi.scheme.AuthScheme; -import software.amazon.awssdk.http.auth.spi.scheme.AuthSchemeOption; -import software.amazon.awssdk.http.auth.spi.signer.HttpSigner; -import software.amazon.awssdk.identity.spi.AwsCredentialsIdentity; -import software.amazon.awssdk.identity.spi.Identity; -import software.amazon.awssdk.identity.spi.IdentityProvider; -import software.amazon.awssdk.identity.spi.IdentityProviders; -import software.amazon.awssdk.identity.spi.ResolveIdentityRequest; -import software.amazon.awssdk.identity.spi.TokenIdentity; -import software.amazon.awssdk.metrics.MetricCollector; -import software.amazon.awssdk.metrics.SdkMetric; -import software.amazon.awssdk.regions.Region; -import software.amazon.awssdk.utils.CollectionUtils; -import software.amazon.awssdk.utils.Logger; -import software.amazon.awssdk.utils.Validate; - -public final class AuthSchemeInterceptorSpec implements ClassSpec { - private final AuthSchemeSpecUtils authSchemeSpecUtils; - private final EndpointRulesSpecUtils endpointRulesSpecUtils; - private final IntermediateModel intermediateModel; - - public AuthSchemeInterceptorSpec(IntermediateModel intermediateModel) { - this.intermediateModel = intermediateModel; - this.authSchemeSpecUtils = new AuthSchemeSpecUtils(intermediateModel); - this.endpointRulesSpecUtils = new EndpointRulesSpecUtils(intermediateModel); - } - - @Override - public ClassName className() { - return authSchemeSpecUtils.authSchemeInterceptor(); - } - - @Override - public TypeSpec poetSpec() { - TypeSpec.Builder builder = PoetUtils.createClassBuilder(className()) - .addSuperinterface(ExecutionInterceptor.class) - .addModifiers(Modifier.PUBLIC, Modifier.FINAL) - .addAnnotation(SdkInternalApi.class); - - builder.addField(FieldSpec.builder(Logger.class, "LOG", Modifier.PRIVATE, Modifier.STATIC) - .initializer("$T.loggerFor($T.class)", Logger.class, className()) - .build()); - - builder.addMethod(generateBeforeExecution()) - .addMethod(generateResolveAuthOptions()) - .addMethod(generateSelectAuthScheme()) - .addMethod(generateAuthSchemeParams()) - .addMethod(generateTrySelectAuthScheme()) - .addMethod(generateGetIdentityMetric()) - .addMethod(putSelectedAuthSchemeMethodSpec()); - if (intermediateModel.getCustomizationConfig().isEnableEnvironmentBearerToken()) { - builder.addMethod(generateEnvironmentTokenMetric()); - } - return builder.build(); - } - - private MethodSpec generateEnvironmentTokenMetric() { - return MethodSpec - .methodBuilder("recordEnvironmentTokenBusinessMetric") - .addModifiers(Modifier.PRIVATE) - .addTypeVariable(TypeVariableName.get("T", Identity.class)) - .addParameter(ParameterSpec.builder( - ParameterizedTypeName.get(ClassName.get(SelectedAuthScheme.class), - TypeVariableName.get("T")), - "selectedAuthScheme").build()) - .addParameter(ExecutionAttributes.class, "executionAttributes") - .addStatement("$T tokenFromEnv = executionAttributes.getAttribute($T.TOKEN_CONFIGURED_FROM_ENV)", - String.class, SdkInternalExecutionAttribute.class) - .beginControlFlow("if (selectedAuthScheme != null && selectedAuthScheme.authSchemeOption().schemeId().equals($T" - + ".SCHEME_ID) && selectedAuthScheme.identity().isDone())", BearerAuthScheme.class) - .beginControlFlow("if (selectedAuthScheme.identity().getNow(null) instanceof $T)", TokenIdentity.class) - - .addStatement("$T configuredToken = ($T) selectedAuthScheme.identity().getNow(null)", - TokenIdentity.class, TokenIdentity.class) - .beginControlFlow("if (configuredToken.token().equals(tokenFromEnv))") - .addStatement("executionAttributes.getAttribute($T.BUSINESS_METRICS)" - + ".addMetric($T.BEARER_SERVICE_ENV_VARS.value())", - SdkInternalExecutionAttribute.class, BusinessMetricFeatureId.class) - .endControlFlow() - .endControlFlow() - .endControlFlow() - .build(); - - - } - - private MethodSpec generateBeforeExecution() { - MethodSpec.Builder builder = MethodSpec.methodBuilder("beforeExecution") - .addAnnotation(Override.class) - .addModifiers(Modifier.PUBLIC) - .addParameter(Context.BeforeExecution.class, - "context") - .addParameter(ExecutionAttributes.class, - "executionAttributes"); - - builder.addStatement("$T authOptions = resolveAuthOptions(context, executionAttributes)", - listOf(AuthSchemeOption.class)) - .addStatement("$T selectedAuthScheme = selectAuthScheme(authOptions, executionAttributes)", - wildcardSelectedAuthScheme()) - .addStatement("putSelectedAuthScheme(executionAttributes, selectedAuthScheme)"); - - if (intermediateModel.getCustomizationConfig().isEnableEnvironmentBearerToken()) { - builder.addStatement("recordEnvironmentTokenBusinessMetric(selectedAuthScheme, " - + "executionAttributes)"); - } - - if (authSchemeSpecUtils.hasSigV4aSupport()) { - builder.beginControlFlow("if (selectedAuthScheme != null && " - + "selectedAuthScheme.authSchemeOption().schemeId().equals($T.SCHEME_ID) && " - + "!$T.isSignerOverridden(context.request(), executionAttributes))", - AwsV4aAuthScheme.class, - ClassName.get("software.amazon.awssdk.awscore.util", "SignerOverrideUtils")) - .addStatement("$T businessMetrics = executionAttributes.getAttribute($T.BUSINESS_METRICS)", - ClassName.get("software.amazon.awssdk.core.useragent", "BusinessMetricCollection"), - SdkInternalExecutionAttribute.class) - .beginControlFlow("if (businessMetrics != null)") - .addStatement("businessMetrics.addMetric($T.SIGV4A_SIGNING.value())", - BusinessMetricFeatureId.class) - .endControlFlow() - .endControlFlow(); - } - return builder.build(); - } - - private MethodSpec generateResolveAuthOptions() { - MethodSpec.Builder builder = MethodSpec.methodBuilder("resolveAuthOptions") - .addModifiers(Modifier.PRIVATE) - .returns(listOf(AuthSchemeOption.class)) - .addParameter(Context.BeforeExecution.class, - "context") - .addParameter(ExecutionAttributes.class, - "executionAttributes"); - - builder.addStatement("$1T authSchemeProvider = $2T.isInstanceOf($1T.class, executionAttributes" - + ".getAttribute($3T.AUTH_SCHEME_RESOLVER), $4S)", - authSchemeSpecUtils.providerInterfaceName(), - Validate.class, - SdkInternalExecutionAttribute.class, - "Expected an instance of " + authSchemeSpecUtils.providerInterfaceName().simpleName()); - builder.addStatement("$T params = authSchemeParams(context.request(), executionAttributes)", - authSchemeSpecUtils.parametersInterfaceName()); - builder.addStatement("return authSchemeProvider.resolveAuthScheme(params)"); - return builder.build(); - } - - private MethodSpec generateAuthSchemeParams() { - MethodSpec.Builder builder = MethodSpec.methodBuilder("authSchemeParams") - .addModifiers(Modifier.PRIVATE) - .returns(authSchemeSpecUtils.parametersInterfaceName()) - .addParameter(SdkRequest.class, "request") - .addParameter(ExecutionAttributes.class, "executionAttributes"); - - if (!authSchemeSpecUtils.useEndpointBasedAuthProvider()) { - builder.addStatement("$T operation = executionAttributes.getAttribute($T.OPERATION_NAME)", String.class, - SdkExecutionAttribute.class); - builder.addStatement("$T.Builder builder = $T.builder().operation(operation)", - authSchemeSpecUtils.parametersInterfaceName(), - authSchemeSpecUtils.parametersInterfaceName()); - - if (authSchemeSpecUtils.usesSigV4()) { - builder.addStatement("$T region = executionAttributes.getAttribute($T.AWS_REGION)", Region.class, - AwsExecutionAttribute.class); - builder.addStatement("builder.region(region)"); - } - generateSigv4aSigningRegionSet(builder); - builder.addStatement("return builder.build()"); - return builder.build(); - } - - builder.addStatement("$T endpointParams = $T.ruleParams(request, executionAttributes)", - endpointRulesSpecUtils.parametersClassName(), - endpointRulesSpecUtils.resolverInterceptorName()); - builder.addStatement("$1T.Builder builder = $1T.builder()", authSchemeSpecUtils.parametersInterfaceName()); - boolean regionIncluded = false; - for (String paramName : endpointRulesSpecUtils.parameters().keySet()) { - if (!authSchemeSpecUtils.includeParamForProvider(paramName)) { - continue; - } - regionIncluded = regionIncluded || paramName.equalsIgnoreCase("region"); - String methodName = endpointRulesSpecUtils.paramMethodName(paramName); - builder.addStatement("builder.$1N(endpointParams.$1N())", methodName); - } - - builder.addStatement("$T operation = executionAttributes.getAttribute($T.OPERATION_NAME)", String.class, - SdkExecutionAttribute.class); - builder.addStatement("builder.operation(operation)"); - if (authSchemeSpecUtils.usesSigV4() && !regionIncluded) { - builder.addStatement("$T region = executionAttributes.getAttribute($T.AWS_REGION)", Region.class, - AwsExecutionAttribute.class); - builder.addStatement("builder.region(region)"); - } - generateSigv4aSigningRegionSet(builder); - ClassName paramsBuilderClass = authSchemeSpecUtils.parametersEndpointAwareDefaultImplName().nestedClass("Builder"); - builder.beginControlFlow("if (builder instanceof $T)", - paramsBuilderClass); - ClassName endpointProviderClass = endpointRulesSpecUtils.providerInterfaceName(); - builder.addStatement("$T endpointProvider = executionAttributes.getAttribute($T.ENDPOINT_PROVIDER)", - EndpointProvider.class, - SdkInternalExecutionAttribute.class); - builder.beginControlFlow("if (endpointProvider instanceof $T)", endpointProviderClass); - builder.addStatement("(($T)builder).endpointProvider(($T)endpointProvider)", paramsBuilderClass, endpointProviderClass); - builder.endControlFlow(); - builder.endControlFlow(); - builder.addStatement("return builder.build()"); - return builder.build(); - } - - private MethodSpec generateSelectAuthScheme() { - MethodSpec.Builder builder = MethodSpec.methodBuilder("selectAuthScheme") - .addModifiers(Modifier.PRIVATE) - .returns(wildcardSelectedAuthScheme()) - .addParameter(listOf(AuthSchemeOption.class), "authOptions") - .addParameter(ExecutionAttributes.class, "executionAttributes"); - - builder.addStatement("$T metricCollector = executionAttributes.getAttribute($T.API_CALL_METRIC_COLLECTOR)", - MetricCollector.class, SdkExecutionAttribute.class) - .addStatement("$T authSchemes = executionAttributes.getAttribute($T.AUTH_SCHEMES)", - mapOf(String.class, wildcardAuthScheme()), - SdkInternalExecutionAttribute.class) - .addStatement("$T identityProviders = executionAttributes.getAttribute($T.IDENTITY_PROVIDERS)", - IdentityProviders.class, SdkInternalExecutionAttribute.class) - .addStatement("$T discardedReasons = new $T<>()", - listOfStringSuppliers(), ArrayList.class); - - builder.beginControlFlow("for ($T authOption : authOptions)", AuthSchemeOption.class); - { - builder.addStatement("$T authScheme = authSchemes.get(authOption.schemeId())", wildcardAuthScheme()) - .addStatement("$T selectedAuthScheme = trySelectAuthScheme(authOption, authScheme, identityProviders, " - + "discardedReasons, metricCollector, executionAttributes)", - wildcardSelectedAuthScheme()); - builder.beginControlFlow("if (selectedAuthScheme != null)"); - { - addLogDebugDiscardedOptions(builder); - builder.addStatement("return selectedAuthScheme") - .endControlFlow(); - } - // end foreach - builder.endControlFlow(); - } - builder.addStatement("throw $T.builder()" - + ".message($S + discardedReasons.stream().map($T::get).collect($T.joining(\", \")))" - + ".build()", - SdkException.class, - "Failed to determine how to authenticate the user: ", - Supplier.class, - Collectors.class); - return builder.build(); - } - - //TODO (s3express) Review "general" identity properties and their propagation - private MethodSpec generateTrySelectAuthScheme() { - MethodSpec.Builder builder = MethodSpec.methodBuilder("trySelectAuthScheme") - .addModifiers(Modifier.PRIVATE) - .returns(namedSelectedAuthScheme()) - .addParameter(AuthSchemeOption.class, "authOption") - .addParameter(namedAuthScheme(), "authScheme") - .addParameter(IdentityProviders.class, "identityProviders") - .addParameter(listOfStringSuppliers(), "discardedReasons") - .addParameter(MetricCollector.class, "metricCollector") - .addParameter(ExecutionAttributes.class, "executionAttributes") - .addTypeVariable(TypeVariableName.get("T", Identity.class)); - - builder.beginControlFlow("if (authScheme == null)"); - { - builder.addStatement("discardedReasons.add(() -> String.format($S, authOption.schemeId()))", - "'%s' is not enabled for this request.") - .addStatement("return null") - .endControlFlow(); - } - builder.addStatement("$T identityProvider = authScheme.identityProvider(identityProviders)", - namedIdentityProvider()); - - builder.beginControlFlow("if (identityProvider == null)"); - { - builder.addStatement("discardedReasons.add(() -> String.format($S, authOption.schemeId()))", - "'%s' does not have an identity provider configured.") - .addStatement("return null") - .endControlFlow(); - } - - builder.addStatement("$T signer", - ParameterizedTypeName.get(ClassName.get(HttpSigner.class), TypeVariableName.get("T"))); - builder.beginControlFlow("try"); - { - builder.addStatement("signer = authScheme.signer()"); - builder.nextControlFlow("catch (RuntimeException e)"); - builder.addStatement("discardedReasons.add(() -> String.format($S, authOption.schemeId(), e.getMessage()))", - "'%s' signer could not be retrieved: %s") - .addStatement("return null") - .endControlFlow(); - } - - - builder.addStatement("$T.Builder identityRequestBuilder = $T.builder()", - ResolveIdentityRequest.class, - ResolveIdentityRequest.class); - builder.addStatement("authOption.forEachIdentityProperty(identityRequestBuilder::putProperty)"); - if (endpointRulesSpecUtils.isS3()) { - builder.addStatement("identityRequestBuilder.putProperty($T.SDK_CLIENT, " - + "executionAttributes.getAttribute($T.SDK_CLIENT))", - SdkIdentityProperty.class, - SdkInternalExecutionAttribute.class); - } - builder.addStatement("$T identity", namedIdentityFuture()); - builder.addStatement("$T metric = getIdentityMetric(identityProvider)", durationSdkMetric()); - builder.beginControlFlow("if (metric == null)") - .addStatement("identity = identityProvider.resolveIdentity(identityRequestBuilder.build())") - .nextControlFlow("else") - .addStatement("identity = $T.reportDuration(" - + "() -> identityProvider.resolveIdentity(identityRequestBuilder.build()), metricCollector, metric)", - MetricUtils.class) - .endControlFlow(); - - builder.addStatement("return new $T<>(identity, signer, authOption)", SelectedAuthScheme.class); - return builder.build(); - } - - private MethodSpec generateGetIdentityMetric() { - MethodSpec.Builder builder = MethodSpec.methodBuilder("getIdentityMetric") - .addModifiers(Modifier.PRIVATE) - .returns(durationSdkMetric()) - .addParameter(wildcardIdentityProvider(), "identityProvider"); - - builder.addStatement("Class identityType = identityProvider.identityType()") - .beginControlFlow("if (identityType == $T.class)", AwsCredentialsIdentity.class) - .addStatement("return $T.CREDENTIALS_FETCH_DURATION", CoreMetric.class) - .endControlFlow() - .beginControlFlow("if (identityType == $T.class)", TokenIdentity.class) - .addStatement("return $T.TOKEN_FETCH_DURATION", CoreMetric.class) - .endControlFlow() - .addStatement("return null"); - - return builder.build(); - } - - private MethodSpec putSelectedAuthSchemeMethodSpec() { - String attributeParamName = "attributes"; - String selectedAuthSchemeParamName = "selectedAuthScheme"; - MethodSpec.Builder builder = MethodSpec.methodBuilder("putSelectedAuthScheme") - .addModifiers(Modifier.PRIVATE) - .addTypeVariable(TypeVariableName.get("T", Identity.class)) - .addParameter(ExecutionAttributes.class, attributeParamName) - .addParameter(ParameterSpec.builder( - ParameterizedTypeName.get(ClassName.get(SelectedAuthScheme.class), - TypeVariableName.get("T")), - selectedAuthSchemeParamName).build()); - builder.addStatement("$T existingAuthScheme = $N.getAttribute($T.SELECTED_AUTH_SCHEME)", - ParameterizedTypeName.get(ClassName.get(SelectedAuthScheme.class), - WildcardTypeName.subtypeOf(Object.class)), - attributeParamName, - SdkInternalExecutionAttribute.class); - - builder.beginControlFlow("if (existingAuthScheme != null)") - .addStatement("$T selectedOption = $N.authSchemeOption().toBuilder()", - AuthSchemeOption.Builder.class, selectedAuthSchemeParamName) - .addStatement("existingAuthScheme.authSchemeOption().forEachIdentityProperty" - + "(selectedOption::putIdentityPropertyIfAbsent)") - .addStatement("existingAuthScheme.authSchemeOption().forEachSignerProperty" - + "(selectedOption::putSignerPropertyIfAbsent)") - .addStatement("$N = new $T<>($N.identity(), $N.signer(), selectedOption.build())", - selectedAuthSchemeParamName, - SelectedAuthScheme.class, - selectedAuthSchemeParamName, - selectedAuthSchemeParamName); - builder.endControlFlow(); - - builder.addStatement("$N.putAttribute($T.SELECTED_AUTH_SCHEME, $N)", - attributeParamName, SdkInternalExecutionAttribute.class, selectedAuthSchemeParamName); - - return builder.build(); - } - - private void addLogDebugDiscardedOptions(MethodSpec.Builder builder) { - builder.beginControlFlow("if (!discardedReasons.isEmpty())"); - { - builder.addStatement("LOG.debug(() -> String.format(\"%s auth will be used, discarded: '%s'\", " - + "authOption.schemeId(), " - + "discardedReasons.stream().map($T::get).collect($T.joining(\", \"))))", - Supplier.class, Collectors.class) - .endControlFlow(); - } - } - - // IdentityProvider - private TypeName namedIdentityProvider() { - return ParameterizedTypeName.get(ClassName.get(IdentityProvider.class), TypeVariableName.get("T")); - } - - // IdentityProvider - private TypeName wildcardIdentityProvider() { - return ParameterizedTypeName.get(ClassName.get(IdentityProvider.class), WildcardTypeName.subtypeOf(Object.class)); - } - - // CompletableFuture - private TypeName namedIdentityFuture() { - return ParameterizedTypeName.get(ClassName.get(CompletableFuture.class), - WildcardTypeName.subtypeOf(TypeVariableName.get("T"))); - } - - // AuthScheme - private TypeName namedAuthScheme() { - return ParameterizedTypeName.get(ClassName.get(AuthScheme.class), - TypeVariableName.get("T", Identity.class)); - } - - // AuthScheme - private TypeName wildcardAuthScheme() { - return ParameterizedTypeName.get(ClassName.get(AuthScheme.class), - WildcardTypeName.subtypeOf(Object.class)); - } - - // SelectedAuthScheme - private TypeName namedSelectedAuthScheme() { - return ParameterizedTypeName.get(ClassName.get(SelectedAuthScheme.class), - TypeVariableName.get("T", Identity.class)); - } - - // SelectedAuthScheme - private TypeName wildcardSelectedAuthScheme() { - return ParameterizedTypeName.get(ClassName.get(SelectedAuthScheme.class), - WildcardTypeName.subtypeOf(Identity.class)); - } - - // List> - private TypeName listOfStringSuppliers() { - return listOf(ParameterizedTypeName.get(Supplier.class, String.class)); - } - - // Map - private TypeName mapOf(Object keyType, Object valueType) { - return ParameterizedTypeName.get(ClassName.get(Map.class), toTypeName(keyType), toTypeName(valueType)); - } - - // List - private TypeName listOf(Object valueType) { - return ParameterizedTypeName.get(ClassName.get(List.class), toTypeName(valueType)); - } - - // SdkMetric - private ParameterizedTypeName durationSdkMetric() { - return ParameterizedTypeName.get(ClassName.get(SdkMetric.class), toTypeName(Duration.class)); - } - - private TypeName toTypeName(Object valueType) { - TypeName result; - if (valueType instanceof Class) { - result = ClassName.get((Class) valueType); - } else if (valueType instanceof TypeName) { - result = (TypeName) valueType; - } else { - throw new IllegalArgumentException("Don't know how to convert " + valueType + " to TypeName"); - } - return result; - } - - private void generateSigv4aSigningRegionSet(MethodSpec.Builder builder) { - if (authSchemeSpecUtils.hasSigV4aSupport()) { - builder.addStatement( - "executionAttributes.getOptionalAttribute($T.AWS_SIGV4A_SIGNING_REGION_SET)\n" + - " .filter(regionSet -> !$T.isNullOrEmpty(regionSet))\n" + - " .ifPresent(nonEmptyRegionSet -> builder.regionSet($T.create(nonEmptyRegionSet)))", - AwsExecutionAttribute.class, - CollectionUtils.class, - RegionSet.class - ); - } - } -} diff --git a/codegen/src/main/java/software/amazon/awssdk/codegen/poet/builder/BaseClientBuilderClass.java b/codegen/src/main/java/software/amazon/awssdk/codegen/poet/builder/BaseClientBuilderClass.java index e342a79b4e25..7533cf25c0c4 100644 --- a/codegen/src/main/java/software/amazon/awssdk/codegen/poet/builder/BaseClientBuilderClass.java +++ b/codegen/src/main/java/software/amazon/awssdk/codegen/poet/builder/BaseClientBuilderClass.java @@ -370,9 +370,6 @@ private MethodSpec finalizeServiceConfigurationMethod() { List builtInInterceptors = new ArrayList<>(); - builtInInterceptors.add(endpointRulesSpecUtils.resolverInterceptorName()); - builtInInterceptors.add(endpointRulesSpecUtils.requestModifierInterceptorName()); - for (String interceptor : model.getCustomizationConfig().getInterceptors()) { builtInInterceptors.add(ClassName.bestGuess(interceptor)); } diff --git a/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/AsyncClientClass.java b/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/AsyncClientClass.java index fc31f3910476..09bbaa7caf80 100644 --- a/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/AsyncClientClass.java +++ b/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/AsyncClientClass.java @@ -172,7 +172,8 @@ protected void addAdditionalMethods(TypeSpec.Builder type) { .addMethods(protocolSpec.additionalMethods()) .addMethod(protocolSpec.initProtocolFactory(model)) .addMethod(resolveMetricPublishersMethod()) - .addMethod(ClientClassUtils.resolveAuthSchemeOptionsMethod(authSchemeSpecUtils, endpointRulesSpecUtils)); + .addMethod(ClientClassUtils.resolveAuthSchemeOptionsMethod(authSchemeSpecUtils, endpointRulesSpecUtils)) + .addMethod(ClientClassUtils.resolveEndpointMethod(authSchemeSpecUtils, endpointRulesSpecUtils)); type.addMethod(ClientClassUtils.updateRetryStrategyClientConfigurationMethod()); type.addMethod(updateSdkClientConfigurationMethod(configurationUtils.serviceClientConfigurationBuilderClassName(), diff --git a/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/ClientClassUtils.java b/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/ClientClassUtils.java index 77cb15c9614e..0b7452d3111d 100644 --- a/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/ClientClassUtils.java +++ b/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/ClientClassUtils.java @@ -25,11 +25,13 @@ import com.squareup.javapoet.ParameterizedTypeName; import com.squareup.javapoet.TypeName; import com.squareup.javapoet.TypeVariableName; +import com.squareup.javapoet.WildcardTypeName; import java.util.List; import java.util.Map; import java.util.Objects; import java.util.Optional; import java.util.Set; +import java.util.concurrent.CompletionException; import java.util.function.Consumer; import java.util.stream.Collectors; import javax.lang.model.element.Modifier; @@ -51,11 +53,16 @@ import software.amazon.awssdk.core.SdkClient; import software.amazon.awssdk.core.SdkPlugin; import software.amazon.awssdk.core.SdkRequest; +import software.amazon.awssdk.core.SelectedAuthScheme; import software.amazon.awssdk.core.client.config.ClientOverrideConfiguration; import software.amazon.awssdk.core.client.config.SdkClientConfiguration; import software.amazon.awssdk.core.client.config.SdkClientOption; +import software.amazon.awssdk.core.exception.SdkClientException; +import software.amazon.awssdk.core.interceptor.ExecutionAttributes; +import software.amazon.awssdk.core.interceptor.SdkInternalExecutionAttribute; import software.amazon.awssdk.core.retry.RetryMode; import software.amazon.awssdk.core.signer.Signer; +import software.amazon.awssdk.endpoints.Endpoint; import software.amazon.awssdk.http.auth.spi.scheme.AuthSchemeOption; import software.amazon.awssdk.retries.api.RetryStrategy; import software.amazon.awssdk.utils.AttributeMap; @@ -406,9 +413,6 @@ private static void addSimpleAuthSchemeResolution(MethodSpec.Builder builder, ClassName.get(Set.class), awsClientOption); builder.beginControlFlow("if (!$T.isNullOrEmpty(sigv4aRegionSet))", CollectionUtils.class); builder.addStatement("paramsBuilder.regionSet($T.create(sigv4aRegionSet))", regionSet); - builder.nextControlFlow("else"); - builder.addStatement("paramsBuilder.regionSet($T.create(clientConfiguration.option($T.AWS_REGION).id()))", - regionSet, awsClientOption); builder.endControlFlow(); } @@ -422,7 +426,7 @@ private static void addEndpointBasedAuthSchemeResolution(MethodSpec.Builder buil ClassName paramsInterface = authSchemeSpecUtils.parametersInterfaceName(); ClassName awsClientOption = ClassName.get("software.amazon.awssdk.awscore.client.config", "AwsClientOption"); ClassName endpointParamsClass = endpointRulesSpecUtils.parametersClassName(); - ClassName resolverInterceptor = endpointRulesSpecUtils.resolverInterceptorName(); + ClassName endpointResolverUtils = endpointRulesSpecUtils.endpointResolverUtilsName(); ClassName executionAttributesClass = ClassName.get("software.amazon.awssdk.core.interceptor", "ExecutionAttributes"); ClassName awsExecutionAttribute = ClassName.get("software.amazon.awssdk.awscore", "AwsExecutionAttribute"); ClassName sdkExecutionAttribute = ClassName.get("software.amazon.awssdk.core.interceptor", "SdkExecutionAttribute"); @@ -447,7 +451,7 @@ private static void addEndpointBasedAuthSchemeResolution(MethodSpec.Builder buil sdkInternalExecutionAttribute, SdkClientOption.class); builder.addStatement("$T endpointParams = $T.ruleParams(request, executionAttributes)", - endpointParamsClass, resolverInterceptor); + endpointParamsClass, endpointResolverUtils); builder.addStatement("$T.Builder paramsBuilder = $T.builder()", paramsInterface, paramsInterface); @@ -467,6 +471,15 @@ private static void addEndpointBasedAuthSchemeResolution(MethodSpec.Builder buil builder.addStatement("paramsBuilder.region(clientConfiguration.option($T.AWS_REGION))", awsClientOption); } + if (authSchemeSpecUtils.hasSigV4aSupport()) { + ClassName regionSet = ClassName.get("software.amazon.awssdk.http.auth.aws.signer", "RegionSet"); + builder.addStatement("$T sigv4aRegionSet = clientConfiguration.option($T.AWS_SIGV4A_SIGNING_REGION_SET)", + ClassName.get(Set.class), awsClientOption); + builder.beginControlFlow("if (!$T.isNullOrEmpty(sigv4aRegionSet))", CollectionUtils.class); + builder.addStatement("paramsBuilder.regionSet($T.create(sigv4aRegionSet))", regionSet); + builder.endControlFlow(); + } + ClassName paramsBuilderClass = authSchemeSpecUtils.parametersEndpointAwareDefaultImplName().nestedClass("Builder"); ClassName endpointProviderInterface = endpointRulesSpecUtils.providerInterfaceName(); @@ -482,4 +495,86 @@ private static void addEndpointBasedAuthSchemeResolution(MethodSpec.Builder buil builder.addStatement("$T<$T> options = authSchemeProvider.resolveAuthScheme(paramsBuilder.build())", List.class, AuthSchemeOption.class); } + + static MethodSpec resolveEndpointMethod(AuthSchemeSpecUtils authSchemeSpecUtils, + EndpointRulesSpecUtils endpointRulesSpecUtils) { + ClassName utilsClass = endpointRulesSpecUtils.endpointResolverUtilsName(); + ClassName endpointParamsClass = endpointRulesSpecUtils.parametersClassName(); + ClassName providerInterface = endpointRulesSpecUtils.providerInterfaceName(); + ClassName awsEndpointProviderUtils = endpointRulesSpecUtils.sharedAwsEndpointProviderUtilsName(); + ClassName awsEndpointAttribute = ClassName.get("software.amazon.awssdk.awscore.endpoints", "AwsEndpointAttribute"); + ClassName endpointAuthScheme = ClassName.get("software.amazon.awssdk.awscore.endpoints.authscheme", "EndpointAuthScheme"); + + MethodSpec.Builder b = MethodSpec.methodBuilder("resolveEndpoint") + .addModifiers(PRIVATE) + .returns(Endpoint.class) + .addParameter(SdkRequest.class, "request") + .addParameter(ExecutionAttributes.class, "executionAttributes") + .addParameter(String.class, "operationName"); + + b.addStatement("$1T provider = ($1T) executionAttributes.getAttribute($2T.ENDPOINT_PROVIDER)", + providerInterface, SdkInternalExecutionAttribute.class); + + b.beginControlFlow("try"); + b.addStatement("$T endpointParams = $T.ruleParams(request, executionAttributes)", + endpointParamsClass, utilsClass); + b.addStatement("$T endpoint = provider.resolveEndpoint(endpointParams).join()", Endpoint.class); + + b.beginControlFlow("if (!$T.disableHostPrefixInjection(executionAttributes))", awsEndpointProviderUtils); + b.addStatement("$T hostPrefix = $T.hostPrefix(operationName, request)", + ParameterizedTypeName.get(Optional.class, String.class), utilsClass); + b.beginControlFlow("if (hostPrefix.isPresent())"); + b.addStatement("endpoint = $T.addHostPrefix(endpoint, hostPrefix.get())", awsEndpointProviderUtils); + b.endControlFlow(); + b.endControlFlow(); + + b.addStatement("$T endpointAuthSchemes = endpoint.attribute($T.AUTH_SCHEMES)", + ParameterizedTypeName.get(ClassName.get(List.class), endpointAuthScheme), + awsEndpointAttribute); + b.addStatement("$T selectedAuthScheme = executionAttributes.getAttribute($T.SELECTED_AUTH_SCHEME)", + ParameterizedTypeName.get(ClassName.get(SelectedAuthScheme.class), + WildcardTypeName.subtypeOf(Object.class)), + SdkInternalExecutionAttribute.class); + b.beginControlFlow("if (endpointAuthSchemes != null && selectedAuthScheme != null)"); + b.addStatement("selectedAuthScheme = $T.authSchemeWithEndpointSignerProperties(endpointAuthSchemes, selectedAuthScheme)", + utilsClass); + + if (authSchemeSpecUtils.usesSigV4a() || authSchemeSpecUtils.generateEndpointBasedParams()) { + ClassName awsV4aAuthScheme = ClassName.get("software.amazon.awssdk.http.auth.aws.scheme", "AwsV4aAuthScheme"); + ClassName awsV4aHttpSigner = ClassName.get("software.amazon.awssdk.http.auth.aws.signer", "AwsV4aHttpSigner"); + ClassName regionSet = ClassName.get("software.amazon.awssdk.http.auth.aws.signer", "RegionSet"); + ClassName authSchemeOption = ClassName.get("software.amazon.awssdk.http.auth.spi.scheme", "AuthSchemeOption"); + + b.addComment("Precedence of SigV4a RegionSet is set according to multi-auth SigV4a specifications"); + b.beginControlFlow("if (selectedAuthScheme.authSchemeOption().schemeId().equals($T.SCHEME_ID) " + + "&& selectedAuthScheme.authSchemeOption().signerProperty($T.REGION_SET) == null)", + awsV4aAuthScheme, awsV4aHttpSigner); + b.addStatement("$T optionBuilder = selectedAuthScheme.authSchemeOption().toBuilder()", + authSchemeOption.nestedClass("Builder")); + b.addStatement("$1T rs = $1T.create(endpointParams.region().id())", regionSet); + b.addStatement("optionBuilder.putSignerProperty($T.REGION_SET, rs)", awsV4aHttpSigner); + b.addStatement("selectedAuthScheme = new $T(selectedAuthScheme.identity(), selectedAuthScheme.signer(), " + + "optionBuilder.build())", SelectedAuthScheme.class); + b.endControlFlow(); + } + + b.addStatement("executionAttributes.putAttribute($T.SELECTED_AUTH_SCHEME, selectedAuthScheme)", + SdkInternalExecutionAttribute.class); + b.endControlFlow(); + + b.addStatement("$T.setMetricValues(endpoint, executionAttributes)", utilsClass); + + b.addStatement("return endpoint"); + + b.nextControlFlow("catch ($T e)", CompletionException.class); + b.addStatement("$T cause = e.getCause()", Throwable.class); + b.beginControlFlow("if (cause instanceof $T)", SdkClientException.class); + b.addStatement("throw ($T) cause", SdkClientException.class); + b.endControlFlow(); + b.addStatement("throw $T.create($S + cause.getMessage(), cause)", + SdkClientException.class, "Endpoint resolution failed: "); + b.endControlFlow(); + + return b.build(); + } } diff --git a/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/SyncClientClass.java b/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/SyncClientClass.java index 6d9bc4e0ef4f..0d33e441bc40 100644 --- a/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/SyncClientClass.java +++ b/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/SyncClientClass.java @@ -140,7 +140,8 @@ protected void addAdditionalMethods(TypeSpec.Builder type) { .addMethod(nameMethod()) .addMethods(protocolSpec.additionalMethods()) .addMethod(resolveMetricPublishersMethod()) - .addMethod(ClientClassUtils.resolveAuthSchemeOptionsMethod(authSchemeSpecUtils, endpointRulesSpecUtils)); + .addMethod(ClientClassUtils.resolveAuthSchemeOptionsMethod(authSchemeSpecUtils, endpointRulesSpecUtils)) + .addMethod(ClientClassUtils.resolveEndpointMethod(authSchemeSpecUtils, endpointRulesSpecUtils)); protocolSpec.createErrorResponseHandler().ifPresent(type::addMethod); type.addMethod(ClientClassUtils.updateRetryStrategyClientConfigurationMethod()); diff --git a/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/specs/JsonProtocolSpec.java b/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/specs/JsonProtocolSpec.java index ad964239ca02..57790e49d34d 100644 --- a/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/specs/JsonProtocolSpec.java +++ b/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/specs/JsonProtocolSpec.java @@ -224,6 +224,8 @@ public CodeBlock executionHandler(OperationModel opModel) { .add(".withMetricCollector(apiCallMetricCollector)\n") .add(".withAuthSchemeOptionsResolver(r -> resolveAuthSchemeOptions(r, $S, clientConfiguration))\n", opModel.getOperationName()) + .add(".withEndpointResolver((r, a) -> resolveEndpoint(r, a, $S))\n", + opModel.getOperationName()) .add(HttpChecksumRequiredTrait.putHttpChecksumAttribute(opModel)) .add(HttpChecksumTrait.create(opModel)); @@ -299,6 +301,8 @@ public CodeBlock asyncExecutionHandler(IntermediateModel intermediateModel, Oper .add(".withMetricCollector(apiCallMetricCollector)\n") .add(".withAuthSchemeOptionsResolver(r -> resolveAuthSchemeOptions(r, $S, clientConfiguration))\n", opModel.getOperationName()) + .add(".withEndpointResolver((r, a) -> resolveEndpoint(r, a, $S))\n", + opModel.getOperationName()) .add(hostPrefixExpression(opModel)) .add(discoveredEndpoint(opModel)) .add(credentialType(opModel, model)) diff --git a/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/specs/QueryProtocolSpec.java b/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/specs/QueryProtocolSpec.java index 624a4813655d..1132bb911435 100644 --- a/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/specs/QueryProtocolSpec.java +++ b/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/specs/QueryProtocolSpec.java @@ -118,6 +118,8 @@ public CodeBlock executionHandler(OperationModel opModel) { .add(".withMetricCollector(apiCallMetricCollector)") .add(".withAuthSchemeOptionsResolver(r -> resolveAuthSchemeOptions(r, $S, clientConfiguration))\n", opModel.getOperationName()) + .add(".withEndpointResolver((r, a) -> resolveEndpoint(r, a, $S))\n", + opModel.getOperationName()) .add(HttpChecksumRequiredTrait.putHttpChecksumAttribute(opModel)) .add(HttpChecksumTrait.create(opModel)); @@ -159,6 +161,8 @@ public CodeBlock asyncExecutionHandler(IntermediateModel intermediateModel, Oper .add(".withMetricCollector(apiCallMetricCollector)\n") .add(".withAuthSchemeOptionsResolver(r -> resolveAuthSchemeOptions(r, $S, clientConfiguration))\n", opModel.getOperationName()) + .add(".withEndpointResolver((r, a) -> resolveEndpoint(r, a, $S))\n", + opModel.getOperationName()) .add(HttpChecksumRequiredTrait.putHttpChecksumAttribute(opModel)) .add(HttpChecksumTrait.create(opModel)); diff --git a/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/specs/XmlProtocolSpec.java b/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/specs/XmlProtocolSpec.java index 4c849e9cee0d..207e38a49371 100644 --- a/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/specs/XmlProtocolSpec.java +++ b/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/specs/XmlProtocolSpec.java @@ -136,6 +136,8 @@ public CodeBlock executionHandler(OperationModel opModel) { .add(".withRequestConfiguration(clientConfiguration)") .add(".withInput($L)", opModel.getInput().getVariableName()) .add(".withAuthSchemeOptionsResolver(r -> resolveAuthSchemeOptions(r, $S, clientConfiguration))\n", + opModel.getOperationName()) + .add(".withEndpointResolver((r, a) -> resolveEndpoint(r, a, $S))\n", opModel.getOperationName()) .add(HttpChecksumRequiredTrait.putHttpChecksumAttribute(opModel)) .add(HttpChecksumTrait.create(opModel)); @@ -217,6 +219,8 @@ public CodeBlock asyncExecutionHandler(IntermediateModel intermediateModel, Oper .add(".withMetricCollector(apiCallMetricCollector)\n") .add(".withAuthSchemeOptionsResolver(r -> resolveAuthSchemeOptions(r, $S, clientConfiguration))\n", opModel.getOperationName()) + .add(".withEndpointResolver((r, a) -> resolveEndpoint(r, a, $S))\n", + opModel.getOperationName()) .add(asyncRequestBody(opModel)) .add(HttpChecksumRequiredTrait.putHttpChecksumAttribute(opModel)) .add(HttpChecksumTrait.create(opModel)); diff --git a/codegen/src/main/java/software/amazon/awssdk/codegen/poet/rules/EndpointResolverInterceptorSpec.java b/codegen/src/main/java/software/amazon/awssdk/codegen/poet/rules/EndpointResolverUtilsSpec.java similarity index 76% rename from codegen/src/main/java/software/amazon/awssdk/codegen/poet/rules/EndpointResolverInterceptorSpec.java rename to codegen/src/main/java/software/amazon/awssdk/codegen/poet/rules/EndpointResolverUtilsSpec.java index 726d61cdb99e..485b9c553788 100644 --- a/codegen/src/main/java/software/amazon/awssdk/codegen/poet/rules/EndpointResolverInterceptorSpec.java +++ b/codegen/src/main/java/software/amazon/awssdk/codegen/poet/rules/EndpointResolverUtilsSpec.java @@ -22,13 +22,11 @@ import com.fasterxml.jackson.jr.stree.JrsString; import com.squareup.javapoet.ClassName; import com.squareup.javapoet.CodeBlock; -import com.squareup.javapoet.FieldSpec; import com.squareup.javapoet.MethodSpec; import com.squareup.javapoet.ParameterizedTypeName; import com.squareup.javapoet.TypeName; import com.squareup.javapoet.TypeSpec; import com.squareup.javapoet.TypeVariableName; -import java.time.Duration; import java.util.Iterator; import java.util.List; import java.util.Locale; @@ -36,7 +34,6 @@ import java.util.Objects; import java.util.Optional; import java.util.Set; -import java.util.concurrent.CompletionException; import javax.lang.model.element.Modifier; import software.amazon.awssdk.annotations.SdkInternalApi; import software.amazon.awssdk.awscore.AwsExecutionAttribute; @@ -61,15 +58,9 @@ import software.amazon.awssdk.codegen.poet.waiters.JmesPathAcceptorGenerator; import software.amazon.awssdk.core.SdkRequest; import software.amazon.awssdk.core.SelectedAuthScheme; -import software.amazon.awssdk.core.exception.SdkClientException; -import software.amazon.awssdk.core.interceptor.Context; import software.amazon.awssdk.core.interceptor.ExecutionAttributes; -import software.amazon.awssdk.core.interceptor.ExecutionInterceptor; -import software.amazon.awssdk.core.interceptor.SdkExecutionAttribute; import software.amazon.awssdk.core.interceptor.SdkInternalExecutionAttribute; -import software.amazon.awssdk.core.metrics.CoreMetric; import software.amazon.awssdk.endpoints.Endpoint; -import software.amazon.awssdk.http.SdkHttpRequest; import software.amazon.awssdk.http.auth.aws.scheme.AwsV4AuthScheme; import software.amazon.awssdk.http.auth.aws.scheme.AwsV4aAuthScheme; import software.amazon.awssdk.http.auth.aws.signer.AwsV4HttpSigner; @@ -77,14 +68,18 @@ import software.amazon.awssdk.http.auth.aws.signer.RegionSet; import software.amazon.awssdk.http.auth.spi.scheme.AuthSchemeOption; import software.amazon.awssdk.identity.spi.Identity; -import software.amazon.awssdk.metrics.MetricCollector; import software.amazon.awssdk.utils.AttributeMap; import software.amazon.awssdk.utils.CollectionUtils; import software.amazon.awssdk.utils.HostnameValidator; import software.amazon.awssdk.utils.StringUtils; import software.amazon.awssdk.utils.internal.CodegenNamingUtils; -public class EndpointResolverInterceptorSpec implements ClassSpec { +/** + * Generates a per-service endpoint resolver utility class (e.g. {@code DynamoDbEndpointResolverUtils}) + * containing all static helper methods for endpoint resolution: building endpoint params, host prefix + * resolution, auth scheme property copying, and business metrics. + */ +public class EndpointResolverUtilsSpec implements ClassSpec { private final IntermediateModel model; private final EndpointRulesSpecUtils endpointRulesSpecUtils; @@ -95,35 +90,28 @@ public class EndpointResolverInterceptorSpec implements ClassSpec { private final boolean multiAuthSigv4a; private final boolean legacyAuthFromEndpointRulesService; - - public EndpointResolverInterceptorSpec(IntermediateModel model) { + public EndpointResolverUtilsSpec(IntermediateModel model) { this.model = model; this.endpointRulesSpecUtils = new EndpointRulesSpecUtils(model); this.endpointParamsKnowledgeIndex = EndpointParamsKnowledgeIndex.of(model); this.poetExtension = new PoetExtension(model); this.jmesPathGenerator = new JmesPathAcceptorGenerator(poetExtension.jmesPathRuntimeClass()); - // We need to know whether the service has a dependency on the http-auth-aws module. Because we can't check that - // directly, assume that if they're using AwsV4AuthScheme or AwsV4aAuthScheme that it's available. Set> supportedAuthSchemes = ModelAuthSchemeClassesKnowledgeIndex.of(model).serviceConcreteAuthSchemeClasses(); this.dependsOnHttpAuthAws = supportedAuthSchemes.contains(AwsV4AuthScheme.class) || supportedAuthSchemes.contains(AwsV4aAuthScheme.class); - this.multiAuthSigv4a = new AuthSchemeSpecUtils(model).usesSigV4a(); this.legacyAuthFromEndpointRulesService = new AuthSchemeSpecUtils(model).generateEndpointBasedParams(); } @Override public TypeSpec poetSpec() { - FieldSpec endpointAuthSchemeStrategyFieldSpec = endpointAuthSchemeStrategyFieldSpec(); TypeSpec.Builder b = PoetUtils.createClassBuilder(className()) .addModifiers(Modifier.PUBLIC, Modifier.FINAL) .addAnnotation(SdkInternalApi.class) - .addSuperinterface(ExecutionInterceptor.class); + .addMethod(privateConstructor()); - b.addMethod(modifyRequestMethod(endpointAuthSchemeStrategyFieldSpec.name)); - b.addMethod(modifyHttpRequestMethod()); b.addMethod(ruleParams()); b.addMethod(setContextParams()); @@ -146,126 +134,17 @@ public TypeSpec poetSpec() { endpointParamsKnowledgeIndex.addAccountIdMethodsIfPresent(b); b.addMethod(setMetricValuesMethod()); + return b.build(); } @Override public ClassName className() { - return endpointRulesSpecUtils.resolverInterceptorName(); - } - - private FieldSpec endpointAuthSchemeStrategyFieldSpec() { - return FieldSpec.builder(endpointRulesSpecUtils.rulesRuntimeClassName("EndpointAuthSchemeStrategy"), - "endpointAuthSchemeStrategy", Modifier.PRIVATE, Modifier.FINAL) - .build(); - } - - private MethodSpec modifyRequestMethod(String endpointAuthSchemeStrategyFieldName) { - - MethodSpec.Builder b = MethodSpec.methodBuilder("modifyRequest") - .addModifiers(Modifier.PUBLIC) - .addAnnotation(Override.class) - .returns(SdkRequest.class) - .addParameter(Context.ModifyRequest.class, "context") - .addParameter(ExecutionAttributes.class, "executionAttributes"); - - String providerVar = "provider"; - - b.addStatement("$T result = context.request()", SdkRequest.class); - // We skip resolution if the source of the endpoint is the endpoint discovery call - b.beginControlFlow("if ($1T.endpointIsDiscovered(executionAttributes))", - endpointRulesSpecUtils.rulesRuntimeClassName("AwsEndpointProviderUtils")); - b.addStatement("return result"); - b.endControlFlow(); - - b.addStatement("$1T $2N = ($1T) executionAttributes.getAttribute($3T.ENDPOINT_PROVIDER)", - endpointRulesSpecUtils.providerInterfaceName(), providerVar, SdkInternalExecutionAttribute.class); - b.beginControlFlow("try"); - b.addStatement("long resolveEndpointStart = $T.nanoTime()", System.class); - b.addStatement("$T endpointParams = ruleParams(result, executionAttributes)", - endpointRulesSpecUtils.parametersClassName()); - b.addStatement("$T endpoint = $N.resolveEndpoint(endpointParams).join()", - Endpoint.class, providerVar); - b.addStatement("$1T resolveEndpointDuration = $1T.ofNanos($2T.nanoTime() - resolveEndpointStart)", Duration.class, - System.class); - b.addStatement("$T metricCollector = executionAttributes.getOptionalAttribute($T.API_CALL_METRIC_COLLECTOR)", - ParameterizedTypeName.get(Optional.class, MetricCollector.class), SdkExecutionAttribute.class); - b.addStatement("metricCollector.ifPresent(mc -> mc.reportMetric($T.ENDPOINT_RESOLVE_DURATION, resolveEndpointDuration))", - CoreMetric.class); - b.beginControlFlow("if (!$T.disableHostPrefixInjection(executionAttributes))", - endpointRulesSpecUtils.rulesRuntimeClassName("AwsEndpointProviderUtils")); - b.addStatement("$T hostPrefix = hostPrefix(executionAttributes.getAttribute($T.OPERATION_NAME), result)", - ParameterizedTypeName.get(Optional.class, String.class), SdkExecutionAttribute.class); - b.beginControlFlow("if (hostPrefix.isPresent())"); - b.addStatement("endpoint = $T.addHostPrefix(endpoint, hostPrefix.get())", - endpointRulesSpecUtils.rulesRuntimeClassName("AwsEndpointProviderUtils")); - b.endControlFlow(); - b.endControlFlow(); - - - // If the endpoint resolver returns auth settings, use them as signer properties. - // This effectively works to set the preSRA Signer ExecutionAttributes, so it is not conditional on useSraAuth. - b.addStatement("$T<$T> endpointAuthSchemes = endpoint.attribute($T.AUTH_SCHEMES)", - List.class, EndpointAuthScheme.class, AwsEndpointAttribute.class); - b.addStatement("$T selectedAuthScheme = executionAttributes.getAttribute($T.SELECTED_AUTH_SCHEME)", - SelectedAuthScheme.class, SdkInternalExecutionAttribute.class); - b.beginControlFlow("if (endpointAuthSchemes != null && selectedAuthScheme != null)"); - b.addStatement("selectedAuthScheme = authSchemeWithEndpointSignerProperties(endpointAuthSchemes, selectedAuthScheme)"); - if (multiAuthSigv4a || legacyAuthFromEndpointRulesService) { - b.addComment("Precedence of SigV4a RegionSet is set according to multi-auth SigV4a specifications"); - b.beginControlFlow("if(selectedAuthScheme.authSchemeOption().schemeId().equals($T.SCHEME_ID) " - + "&& selectedAuthScheme.authSchemeOption().signerProperty($T.REGION_SET) == null)", - AwsV4aAuthScheme.class, AwsV4aHttpSigner.class); - b.addStatement("$T optionBuilder = selectedAuthScheme.authSchemeOption().toBuilder()", - AuthSchemeOption.Builder.class); - b.addStatement("$T regionSet = $T.create(endpointParams.region().id())", - RegionSet.class, RegionSet.class); - b.addStatement("optionBuilder.putSignerProperty($T.REGION_SET, regionSet)", AwsV4aHttpSigner.class); - b.addStatement("selectedAuthScheme = new $T(selectedAuthScheme.identity(), selectedAuthScheme.signer(), " - + "optionBuilder.build())", SelectedAuthScheme.class); - b.endControlFlow(); - } - b.addStatement("executionAttributes.putAttribute($T.SELECTED_AUTH_SCHEME, selectedAuthScheme)", - SdkInternalExecutionAttribute.class); - b.endControlFlow(); - - b.addStatement("executionAttributes.putAttribute(SdkInternalExecutionAttribute.RESOLVED_ENDPOINT, endpoint)"); - b.addStatement("setMetricValues(endpoint, executionAttributes)"); - b.addStatement("return result"); - b.endControlFlow(); - b.beginControlFlow("catch ($T e)", CompletionException.class); - b.addStatement("$T cause = e.getCause()", Throwable.class); - b.beginControlFlow("if (cause instanceof $T)", SdkClientException.class); - b.addStatement("throw ($T) cause", SdkClientException.class); - b.endControlFlow(); - b.beginControlFlow("else"); - b.addStatement("throw $T.create($S, cause)", SdkClientException.class, "Endpoint resolution failed"); - b.endControlFlow(); - b.endControlFlow(); - return b.build(); + return endpointRulesSpecUtils.endpointResolverUtilsName(); } - private MethodSpec modifyHttpRequestMethod() { - MethodSpec.Builder b = MethodSpec.methodBuilder("modifyHttpRequest") - .addModifiers(Modifier.PUBLIC) - .addAnnotation(Override.class) - .returns(SdkHttpRequest.class) - .addParameter(Context.ModifyHttpRequest.class, "context") - .addParameter(ExecutionAttributes.class, "executionAttributes"); - - b.addStatement("$T resolvedEndpoint = executionAttributes.getAttribute($T.RESOLVED_ENDPOINT)", - Endpoint.class, SdkInternalExecutionAttribute.class); - b.beginControlFlow("if (resolvedEndpoint.headers().isEmpty())"); - b.addStatement("return context.httpRequest()"); - b.endControlFlow(); - - b.addStatement("$T httpRequestBuilder = context.httpRequest().toBuilder()", SdkHttpRequest.Builder.class); - b.addCode("resolvedEndpoint.headers().forEach((name, values) -> {"); - b.addStatement("values.forEach(v -> httpRequestBuilder.appendHeader(name, v))"); - b.addCode("});"); - b.addStatement("return httpRequestBuilder.build()"); - - return b.build(); + private static MethodSpec privateConstructor() { + return MethodSpec.constructorBuilder().addModifiers(Modifier.PRIVATE).build(); } private MethodSpec ruleParams() { @@ -278,7 +157,6 @@ private MethodSpec ruleParams() { b.addStatement("$T builder = $T.builder()", paramsBuilderClass(), endpointRulesSpecUtils.parametersClassName()); Map parameters = model.getEndpointRuleSetModel().getParameters(); - parameters.forEach((n, m) -> { if (m.getBuiltInEnum() == null) { return; @@ -307,15 +185,11 @@ private MethodSpec ruleParams() { b.addStatement("builder.$N(executionAttributes.getAttribute($T.$N))", setter, AwsExecutionAttribute.class, model.getNamingStrategy().getEnumValueName(n)); break; - // The S3 specific built-ins are set through the existing S3Configuration and set through client context params case AWS_S3_ACCELERATE: case AWS_S3_DISABLE_MULTI_REGION_ACCESS_POINTS: case AWS_S3_FORCE_PATH_STYLE: case AWS_S3_USE_ARN_REGION: case AWS_S3_CONTROL_USE_ARN_REGION: - // end of S3 specific builtins - - // V2 doesn't support this, only regional endpoints case AWS_STS_USE_GLOBAL_ENDPOINT: return; default: @@ -339,85 +213,77 @@ private MethodSpec ruleParams() { private CodeBlock endpointProviderUtilsSetter(String builtInFn, String setterName) { return CodeBlock.of("builder.$N($T.$N(executionAttributes))", setterName, - endpointRulesSpecUtils.rulesRuntimeClassName("AwsEndpointProviderUtils"), builtInFn); + endpointRulesSpecUtils.sharedAwsEndpointProviderUtilsName(), builtInFn); } private ClassName paramsBuilderClass() { return endpointRulesSpecUtils.parametersClassName().nestedClass("Builder"); } - private MethodSpec addStaticContextParamsMethod(OperationModel opModel) { - String methodName = staticContextParamsMethodName(opModel); + private MethodSpec setContextParams() { + Map operations = model.getOperations(); - MethodSpec.Builder b = MethodSpec.methodBuilder(methodName) + MethodSpec.Builder b = MethodSpec.methodBuilder("setContextParams") .addModifiers(Modifier.PRIVATE, Modifier.STATIC) - .returns(void.class) - .addParameter(paramsBuilderClass(), "params"); + .addParameter(paramsBuilderClass(), "params") + .addParameter(String.class, "operationName") + .addParameter(SdkRequest.class, "request") + .returns(void.class); - opModel.getStaticContextParams().forEach((n, m) -> { - String setterName = endpointRulesSpecUtils.paramMethodName(n); - TreeNode value = m.getValue(); - switch (value.asToken()) { - case VALUE_STRING: - b.addStatement("params.$N($S)", setterName, ((JrsString) value).getValue()); - break; - case VALUE_TRUE: - case VALUE_FALSE: - b.addStatement("params.$N($L)", setterName, ((JrsBoolean) value).booleanValue()); - break; - case START_ARRAY: - JrsArray arrayValue = (JrsArray) value; - CodeBlock arrayCode = endpointRulesSpecUtils.treeNodeToLiteral(arrayValue); - b.addStatement("params.$N($L)", setterName, arrayCode); - break; - default: - throw new RuntimeException("Don't know how to set parameter of type " + value.asToken()); - } - }); + boolean generateSwitch = operations.values().stream().anyMatch(this::hasContextParams); + if (generateSwitch) { + b.beginControlFlow("switch (operationName)"); + operations.forEach((n, m) -> { + if (!hasContextParams(m)) { + return; + } + String requestClassName = model.getNamingStrategy().getRequestClassName(m.getOperationName()); + ClassName requestClass = poetExtension.getModelClass(requestClassName); + b.addCode("case $S:", n); + b.addStatement("setContextParams(params, ($T) request)", requestClass); + b.addStatement("break"); + }); + b.addCode("default:"); + b.addStatement("break"); + b.endControlFlow(); + } return b.build(); } - private String staticContextParamsMethodName(OperationModel opModel) { - return opModel.getMethodName() + "StaticContextParams"; - } - - private boolean hasStaticContextParams(OperationModel opModel) { - Map staticContextParams = opModel.getStaticContextParams(); - return staticContextParams != null && !staticContextParams.isEmpty(); - } - - private boolean hasOperationContextParams(OperationModel opModel) { - return CollectionUtils.isNotEmpty(opModel.getOperationContextParams()); - } - - private void addStaticContextParamMethods(TypeSpec.Builder classBuilder) { - Map operations = model.getOperations(); - - operations.forEach((n, m) -> { - if (hasStaticContextParams(m)) { - classBuilder.addMethod(addStaticContextParamsMethod(m)); - } - }); - } - private void addContextParamMethods(TypeSpec.Builder classBuilder) { - Map operations = model.getOperations(); - - operations.forEach((n, m) -> { + model.getOperations().forEach((n, m) -> { if (hasContextParams(m)) { classBuilder.addMethod(setContextParamsMethod(m)); } }); } - private void addOperationContextParamMethods(TypeSpec.Builder classBuilder) { - Map operations = model.getOperations(); - operations.forEach((n, m) -> { - if (hasOperationContextParams(m)) { - classBuilder.addMethod(setOperationContextParamsMethod(m)); + private MethodSpec setContextParamsMethod(OperationModel opModel) { + String requestClassName = model.getNamingStrategy().getRequestClassName(opModel.getOperationName()); + ClassName requestClass = poetExtension.getModelClass(requestClassName); + + MethodSpec.Builder b = MethodSpec.methodBuilder("setContextParams") + .addModifiers(Modifier.PRIVATE, Modifier.STATIC) + .addParameter(paramsBuilderClass(), "params") + .addParameter(requestClass, "request") + .returns(void.class); + + opModel.getInputShape().getMembers().forEach(m -> { + ContextParam param = m.getContextParam(); + if (param == null) { + return; } + String setterName = endpointRulesSpecUtils.paramMethodName(param.getName()); + b.addStatement("params.$N(request.$N())", setterName, m.getFluentGetterMethodName()); }); + + return b.build(); + } + + private boolean hasContextParams(OperationModel opModel) { + return opModel.getInputShape().getMembers().stream() + .anyMatch(m -> m.getContextParam() != null); } private MethodSpec setStaticContextParamsMethod() { @@ -432,12 +298,10 @@ private MethodSpec setStaticContextParamsMethod() { boolean generateSwitch = operations.values().stream().anyMatch(this::hasStaticContextParams); if (generateSwitch) { b.beginControlFlow("switch (operationName)"); - operations.forEach((n, m) -> { if (!hasStaticContextParams(m)) { return; } - b.addCode("case $S:", n); b.addStatement("$N(params)", staticContextParamsMethodName(m)); b.addStatement("break"); @@ -450,38 +314,53 @@ private MethodSpec setStaticContextParamsMethod() { return b.build(); } - private MethodSpec setContextParams() { - Map operations = model.getOperations(); + private void addStaticContextParamMethods(TypeSpec.Builder classBuilder) { + model.getOperations().forEach((n, m) -> { + if (hasStaticContextParams(m)) { + classBuilder.addMethod(addStaticContextParamsMethod(m)); + } + }); + } - MethodSpec.Builder b = MethodSpec.methodBuilder("setContextParams") - .addModifiers(Modifier.PRIVATE, Modifier.STATIC) - .addParameter(paramsBuilderClass(), "params") - .addParameter(String.class, "operationName") - .addParameter(SdkRequest.class, "request") - .returns(void.class); + private MethodSpec addStaticContextParamsMethod(OperationModel opModel) { + String methodName = staticContextParamsMethodName(opModel); - boolean generateSwitch = operations.values().stream().anyMatch(this::hasContextParams); - if (generateSwitch) { - b.beginControlFlow("switch (operationName)"); + MethodSpec.Builder b = MethodSpec.methodBuilder(methodName) + .addModifiers(Modifier.PRIVATE, Modifier.STATIC) + .returns(void.class) + .addParameter(paramsBuilderClass(), "params"); - operations.forEach((n, m) -> { - if (!hasContextParams(m)) { - return; - } + opModel.getStaticContextParams().forEach((n, m) -> { + String setterName = endpointRulesSpecUtils.paramMethodName(n); + TreeNode value = m.getValue(); + switch (value.asToken()) { + case VALUE_STRING: + b.addStatement("params.$N($S)", setterName, ((JrsString) value).getValue()); + break; + case VALUE_TRUE: + case VALUE_FALSE: + b.addStatement("params.$N($L)", setterName, ((JrsBoolean) value).booleanValue()); + break; + case START_ARRAY: + JrsArray arrayValue = (JrsArray) value; + CodeBlock arrayCode = endpointRulesSpecUtils.treeNodeToLiteral(arrayValue); + b.addStatement("params.$N($L)", setterName, arrayCode); + break; + default: + throw new RuntimeException("Don't know how to set parameter of type " + value.asToken()); + } + }); - String requestClassName = model.getNamingStrategy().getRequestClassName(m.getOperationName()); - ClassName requestClass = poetExtension.getModelClass(requestClassName); + return b.build(); + } - b.addCode("case $S:", n); - b.addStatement("setContextParams(params, ($T) request)", requestClass); - b.addStatement("break"); - }); - b.addCode("default:"); - b.addStatement("break"); - b.endControlFlow(); - } + private String staticContextParamsMethodName(OperationModel opModel) { + return opModel.getMethodName() + "StaticContextParams"; + } - return b.build(); + private boolean hasStaticContextParams(OperationModel opModel) { + Map staticContextParams = opModel.getStaticContextParams(); + return staticContextParams != null && !staticContextParams.isEmpty(); } private MethodSpec setOperationContextParams() { @@ -497,15 +376,12 @@ private MethodSpec setOperationContextParams() { boolean generateSwitch = operations.values().stream().anyMatch(this::hasOperationContextParams); if (generateSwitch) { b.beginControlFlow("switch (operationName)"); - operations.forEach((n, m) -> { if (!hasOperationContextParams(m)) { return; } - String requestClassName = model.getNamingStrategy().getRequestClassName(m.getOperationName()); ClassName requestClass = poetExtension.getModelClass(requestClassName); - b.addCode("case $S:", n); b.addStatement("setOperationContextParams(params, ($T) request)", requestClass); b.addStatement("break"); @@ -518,28 +394,12 @@ private MethodSpec setOperationContextParams() { return b.build(); } - private MethodSpec setContextParamsMethod(OperationModel opModel) { - String requestClassName = model.getNamingStrategy().getRequestClassName(opModel.getOperationName()); - ClassName requestClass = poetExtension.getModelClass(requestClassName); - - MethodSpec.Builder b = MethodSpec.methodBuilder("setContextParams") - .addModifiers(Modifier.PRIVATE, Modifier.STATIC) - .addParameter(paramsBuilderClass(), "params") - .addParameter(requestClass, "request") - .returns(void.class); - - opModel.getInputShape().getMembers().forEach(m -> { - ContextParam param = m.getContextParam(); - if (param == null) { - return; + private void addOperationContextParamMethods(TypeSpec.Builder classBuilder) { + model.getOperations().forEach((n, m) -> { + if (hasOperationContextParams(m)) { + classBuilder.addMethod(setOperationContextParamsMethod(m)); } - - String setterName = endpointRulesSpecUtils.paramMethodName(param.getName()); - - b.addStatement("params.$N(request.$N())", setterName, m.getFluentGetterMethodName()); }); - - return b.build(); } private MethodSpec setOperationContextParamsMethod(OperationModel opModel) { @@ -557,7 +417,6 @@ private MethodSpec setOperationContextParamsMethod(OperationModel opModel) { opModel.getOperationContextParams().forEach((key, value) -> { if (Objects.requireNonNull(value.getPath().asToken()) == JsonToken.VALUE_STRING) { String setterName = endpointRulesSpecUtils.paramMethodName(key); - String jmesPathString = ((JrsString) value.getPath()).getValue(); CodeBlock addParam = CodeBlock.builder() .add("params.$N(", setterName) @@ -565,18 +424,20 @@ private MethodSpec setOperationContextParamsMethod(OperationModel opModel) { .add(matchToParameterType(key)) .add(")") .build(); - b.addStatement(addParam); } else { throw new RuntimeException("Invalid operation context parameter path for " + opModel.getOperationName() + ". Expected VALUE_STRING, but got " + value.getPath().asToken()); } - }); return b.build(); } + private boolean hasOperationContextParams(OperationModel opModel) { + return CollectionUtils.isNotEmpty(opModel.getOperationContextParams()); + } + private CodeBlock matchToParameterType(String paramName) { Map parameters = model.getEndpointRuleSetModel().getParameters(); Optional endpointParameter = parameters.entrySet().stream() @@ -601,11 +462,6 @@ private CodeBlock convertValueToParameterType(ParameterModel parameterModel) { } } - private boolean hasContextParams(OperationModel opModel) { - return opModel.getInputShape().getMembers().stream() - .anyMatch(m -> m.getContextParam() != null); - } - private boolean hasClientContextParams() { Map clientContextParams = model.getClientContextParams(); return clientContextParams != null && !clientContextParams.isEmpty(); @@ -627,20 +483,18 @@ private MethodSpec setClientContextParamsMethod() { params.forEach((n, m) -> { String attrName = endpointRulesSpecUtils.clientContextParamName(n); b.addStatement("$T.ofNullable(clientContextParams.get($T.$N)).ifPresent(params::$N)", Optional.class, paramsClass, - attrName, - endpointRulesSpecUtils.paramMethodName(n)); + attrName, endpointRulesSpecUtils.paramMethodName(n)); }); return b.build(); } - private MethodSpec hostPrefixMethod() { MethodSpec.Builder builder = MethodSpec.methodBuilder("hostPrefix") .returns(ParameterizedTypeName.get(Optional.class, String.class)) .addParameter(String.class, "operationName") .addParameter(SdkRequest.class, "request") - .addModifiers(Modifier.PRIVATE, Modifier.STATIC); + .addModifiers(Modifier.PUBLIC, Modifier.STATIC); boolean generateSwitch = model.getOperations().values().stream().anyMatch(opModel -> StringUtils.isNotBlank(getHostPrefix(opModel))); @@ -649,16 +503,13 @@ private MethodSpec hostPrefixMethod() { builder.addStatement("return $T.empty()", Optional.class); } else { builder.beginControlFlow("switch (operationName)"); - model.getOperations().forEach((name, opModel) -> { String hostPrefix = getHostPrefix(opModel); if (StringUtils.isBlank(hostPrefix)) { return; } - builder.beginControlFlow("case $S:", name); HostPrefixProcessor processor = new HostPrefixProcessor(hostPrefix); - if (processor.c2jNames().isEmpty()) { builder.addStatement("return $T.of($S)", Optional.class, processor.hostWithStringSpecifier()); } else { @@ -666,12 +517,8 @@ private MethodSpec hostPrefixMethod() { processor.c2jNames().forEach(c2jName -> { builder.addStatement("$1T.validateHostnameCompliant(request.getValueForField($2S, $3T.class)" + ".orElse(null), $2S, $4S)", - HostnameValidator.class, - c2jName, - String.class, - requestVar); + HostnameValidator.class, c2jName, String.class, requestVar); }); - builder.addCode("return $T.of($T.format($S, ", Optional.class, String.class, processor.hostWithStringSpecifier()); Iterator c2jNamesIter = processor.c2jNames().listIterator(); @@ -685,7 +532,6 @@ private MethodSpec hostPrefixMethod() { } builder.endControlFlow(); }); - builder.addCode("default:"); builder.addStatement("return $T.empty()", Optional.class); builder.endControlFlow(); @@ -699,7 +545,6 @@ private String getHostPrefix(OperationModel opModel) { if (endpointTrait == null) { return null; } - return endpointTrait.getHostPrefix(); } @@ -711,15 +556,13 @@ private MethodSpec authSchemeWithEndpointSignerPropertiesMethod() { MethodSpec.Builder method = MethodSpec.methodBuilder("authSchemeWithEndpointSignerProperties") - .addModifiers(Modifier.PRIVATE) + .addModifiers(Modifier.PUBLIC, Modifier.STATIC) .addTypeVariable(tExtendsIdentity) .returns(selectedAuthSchemeOfT) .addParameter(listOfEndpointAuthScheme, "endpointAuthSchemes") .addParameter(selectedAuthSchemeOfT, "selectedAuthScheme"); method.beginControlFlow("for ($T endpointAuthScheme : endpointAuthSchemes)", EndpointAuthScheme.class); - - // Don't include signer properties for auth options that don't match our selected auth scheme method.beginControlFlow("if (!endpointAuthScheme.schemeId()" + ".equals(selectedAuthScheme.authSchemeOption().schemeId()))"); method.addStatement("continue"); @@ -738,9 +581,7 @@ private MethodSpec authSchemeWithEndpointSignerPropertiesMethod() { method.addStatement("throw new $T(\"Endpoint auth scheme '\" + endpointAuthScheme.name() + \"' cannot be mapped to the " + "SDK auth scheme. Was it declared in the service's model?\")", IllegalArgumentException.class); - method.endControlFlow(); - method.addStatement("return selectedAuthScheme"); return method.build(); @@ -750,34 +591,27 @@ private static CodeBlock copyV4EndpointSignerPropertiesToAuth() { CodeBlock.Builder code = CodeBlock.builder(); code.beginControlFlow("if (endpointAuthScheme instanceof $T)", SigV4AuthScheme.class); code.addStatement("$1T v4AuthScheme = ($1T) endpointAuthScheme", SigV4AuthScheme.class); - code.beginControlFlow("if (v4AuthScheme.isDisableDoubleEncodingSet())"); code.addStatement("option.putSignerProperty($T.DOUBLE_URL_ENCODE, !v4AuthScheme.disableDoubleEncoding())", AwsV4HttpSigner.class); code.endControlFlow(); - code.beginControlFlow("if (v4AuthScheme.signingRegion() != null)"); - code.addStatement("option.putSignerProperty($T.REGION_NAME, v4AuthScheme.signingRegion())", - AwsV4HttpSigner.class); + code.addStatement("option.putSignerProperty($T.REGION_NAME, v4AuthScheme.signingRegion())", AwsV4HttpSigner.class); code.endControlFlow(); - code.beginControlFlow("if (v4AuthScheme.signingName() != null)"); code.addStatement("option.putSignerProperty($T.SERVICE_SIGNING_NAME, v4AuthScheme.signingName())", AwsV4HttpSigner.class); code.endControlFlow(); - code.addStatement("return new $T<>(selectedAuthScheme.identity(), selectedAuthScheme.signer(), option.build())", SelectedAuthScheme.class); code.endControlFlow(); return code.build(); } - private CodeBlock copyV4aEndpointSignerPropertiesToAuth() { + private CodeBlock copyV4aEndpointSignerPropertiesToAuth() { CodeBlock.Builder code = CodeBlock.builder(); - code.beginControlFlow("if (endpointAuthScheme instanceof $T)", SigV4aAuthScheme.class); code.addStatement("$1T v4aAuthScheme = ($1T) endpointAuthScheme", SigV4aAuthScheme.class); - code.beginControlFlow("if (v4aAuthScheme.isDisableDoubleEncodingSet())"); code.addStatement("option.putSignerProperty($T.DOUBLE_URL_ENCODE, !v4aAuthScheme.disableDoubleEncoding())", AwsV4aHttpSigner.class); @@ -793,12 +627,10 @@ private CodeBlock copyV4aEndpointSignerPropertiesToAuth() { code.addStatement("$1T regionSet = $1T.create(v4aAuthScheme.signingRegionSet())", RegionSet.class); code.addStatement("option.putSignerProperty($T.REGION_SET, regionSet)", AwsV4aHttpSigner.class); code.endControlFlow(); - code.beginControlFlow("if (v4aAuthScheme.signingName() != null)"); code.addStatement("option.putSignerProperty($T.SERVICE_SIGNING_NAME, v4aAuthScheme.signingName())", AwsV4aHttpSigner.class); code.endControlFlow(); - code.addStatement("return new $T<>(selectedAuthScheme.identity(), selectedAuthScheme.signer(), option.build())", SelectedAuthScheme.class); code.endControlFlow(); @@ -812,22 +644,18 @@ private CodeBlock copyS3ExpressEndpointSignerPropertiesToAuth() { "S3ExpressEndpointAuthScheme"); code.beginControlFlow("if (endpointAuthScheme instanceof $T)", s3ExpressEndpointAuthSchemeClassName); code.addStatement("$1T s3ExpressAuthScheme = ($1T) endpointAuthScheme", s3ExpressEndpointAuthSchemeClassName); - code.beginControlFlow("if (s3ExpressAuthScheme.isDisableDoubleEncodingSet())"); code.addStatement("option.putSignerProperty($T.DOUBLE_URL_ENCODE, !s3ExpressAuthScheme.disableDoubleEncoding())", AwsV4HttpSigner.class); code.endControlFlow(); - code.beginControlFlow("if (s3ExpressAuthScheme.signingRegion() != null)"); code.addStatement("option.putSignerProperty($T.REGION_NAME, s3ExpressAuthScheme.signingRegion())", AwsV4HttpSigner.class); code.endControlFlow(); - code.beginControlFlow("if (s3ExpressAuthScheme.signingName() != null)"); code.addStatement("option.putSignerProperty($T.SERVICE_SIGNING_NAME, s3ExpressAuthScheme.signingName())", AwsV4HttpSigner.class); code.endControlFlow(); - code.addStatement("return new $T<>(selectedAuthScheme.identity(), selectedAuthScheme.signer(), option.build())", SelectedAuthScheme.class); code.endControlFlow(); @@ -836,7 +664,7 @@ private CodeBlock copyS3ExpressEndpointSignerPropertiesToAuth() { private MethodSpec setMetricValuesMethod() { MethodSpec.Builder b = MethodSpec.methodBuilder("setMetricValues") - .addModifiers(Modifier.PRIVATE) + .addModifiers(Modifier.PUBLIC, Modifier.STATIC) .addParameter(Endpoint.class, "endpoint") .addParameter(ExecutionAttributes.class, "executionAttributes") .returns(void.class); @@ -851,7 +679,7 @@ private MethodSpec setMetricValuesMethod() { b.addStatement("$T.addS3ExpressBusinessMetricIfApplicable(executionAttributes)", ClassName.get("software.amazon.awssdk.services.s3.internal.s3express", "S3ExpressUtils")); } - + return b.build(); } } diff --git a/codegen/src/main/java/software/amazon/awssdk/codegen/poet/rules/EndpointRulesSpecUtils.java b/codegen/src/main/java/software/amazon/awssdk/codegen/poet/rules/EndpointRulesSpecUtils.java index dfcf68b056fd..25f483fa4829 100644 --- a/codegen/src/main/java/software/amazon/awssdk/codegen/poet/rules/EndpointRulesSpecUtils.java +++ b/codegen/src/main/java/software/amazon/awssdk/codegen/poet/rules/EndpointRulesSpecUtils.java @@ -93,6 +93,16 @@ public ClassName resolverInterceptorName() { md.getServiceName() + "ResolveEndpointInterceptor"); } + public ClassName endpointResolverUtilsName() { + Metadata md = intermediateModel.getMetadata(); + return ClassName.get(md.getFullInternalEndpointRulesPackageName(), + md.getServiceName() + "EndpointResolverUtils"); + } + + public ClassName sharedAwsEndpointProviderUtilsName() { + return ClassName.get("software.amazon.awssdk.awscore.internal.endpoints", "AwsEndpointProviderUtils"); + } + public ClassName requestModifierInterceptorName() { Metadata md = intermediateModel.getMetadata(); return ClassName.get(md.getFullInternalEndpointRulesPackageName(), diff --git a/codegen/src/main/java/software/amazon/awssdk/codegen/poet/rules/RequestEndpointInterceptorSpec.java b/codegen/src/main/java/software/amazon/awssdk/codegen/poet/rules/RequestEndpointInterceptorSpec.java deleted file mode 100644 index 63b5133f180e..000000000000 --- a/codegen/src/main/java/software/amazon/awssdk/codegen/poet/rules/RequestEndpointInterceptorSpec.java +++ /dev/null @@ -1,84 +0,0 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). - * You may not use this file except in compliance with the License. - * A copy of the License is located at - * - * http://aws.amazon.com/apache2.0 - * - * or in the "license" file accompanying this file. This file is distributed - * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either - * express or implied. See the License for the specific language governing - * permissions and limitations under the License. - */ - -package software.amazon.awssdk.codegen.poet.rules; - -import com.squareup.javapoet.ClassName; -import com.squareup.javapoet.MethodSpec; -import com.squareup.javapoet.TypeSpec; -import javax.lang.model.element.Modifier; -import software.amazon.awssdk.annotations.SdkInternalApi; -import software.amazon.awssdk.codegen.model.intermediate.IntermediateModel; -import software.amazon.awssdk.codegen.poet.ClassSpec; -import software.amazon.awssdk.codegen.poet.PoetUtils; -import software.amazon.awssdk.core.interceptor.Context; -import software.amazon.awssdk.core.interceptor.ExecutionAttributes; -import software.amazon.awssdk.core.interceptor.ExecutionInterceptor; -import software.amazon.awssdk.core.interceptor.SdkInternalExecutionAttribute; -import software.amazon.awssdk.endpoints.Endpoint; -import software.amazon.awssdk.http.SdkHttpRequest; - -public class RequestEndpointInterceptorSpec implements ClassSpec { - private final EndpointRulesSpecUtils endpointRulesSpecUtils; - - public RequestEndpointInterceptorSpec(IntermediateModel model) { - this.endpointRulesSpecUtils = new EndpointRulesSpecUtils(model); - } - - @Override - public TypeSpec poetSpec() { - TypeSpec.Builder b = PoetUtils.createClassBuilder(className()) - .addModifiers(Modifier.PUBLIC, Modifier.FINAL) - .addAnnotation(SdkInternalApi.class) - .addSuperinterface(ExecutionInterceptor.class); - - b.addMethod(modifyHttpRequestMethod()); - - return b.build(); - } - - @Override - public ClassName className() { - return endpointRulesSpecUtils.requestModifierInterceptorName(); - } - - private MethodSpec modifyHttpRequestMethod() { - - MethodSpec.Builder b = MethodSpec.methodBuilder("modifyHttpRequest") - .addModifiers(Modifier.PUBLIC) - .addAnnotation(Override.class) - .returns(SdkHttpRequest.class) - .addParameter(Context.ModifyHttpRequest.class, "context") - .addParameter(ExecutionAttributes.class, "executionAttributes"); - - - // We skip setting the endpoint here if the source of the endpoint is the endpoint discovery call - b.beginControlFlow("if ($1T.endpointIsDiscovered(executionAttributes))", - endpointRulesSpecUtils.rulesRuntimeClassName("AwsEndpointProviderUtils")) - .addStatement("return context.httpRequest()") - .endControlFlow().build(); - - b.addStatement("$T endpoint = executionAttributes.getAttribute($T.RESOLVED_ENDPOINT)", - Endpoint.class, - SdkInternalExecutionAttribute.class); - b.addStatement("return $T.setUri(context.httpRequest()," - + "executionAttributes.getAttribute($T.CLIENT_ENDPOINT_PROVIDER).clientEndpoint()," - + "endpoint.url())", - endpointRulesSpecUtils.rulesRuntimeClassName("AwsEndpointProviderUtils"), - SdkInternalExecutionAttribute.class); - return b.build(); - } - -} diff --git a/codegen/src/test/java/software/amazon/awssdk/codegen/poet/auth/scheme/AuthSchemeSpecTest.java b/codegen/src/test/java/software/amazon/awssdk/codegen/poet/auth/scheme/AuthSchemeSpecTest.java index 6b095f01ddc6..067d5d1673ed 100644 --- a/codegen/src/test/java/software/amazon/awssdk/codegen/poet/auth/scheme/AuthSchemeSpecTest.java +++ b/codegen/src/test/java/software/amazon/awssdk/codegen/poet/auth/scheme/AuthSchemeSpecTest.java @@ -187,28 +187,6 @@ static List parameters() { .caseName("auth-with-legacy-trait") .outputFileSuffix("default-provider") .build(), - // Interceptors - // - Normal case - TestCase.builder() - .modelProvider(ClientTestModels::queryServiceModels) - .classSpecProvider(AuthSchemeInterceptorSpec::new) - .caseName("query") - .outputFileSuffix("interceptor") - .build(), - // - Endpoints based params with allow list - TestCase.builder() - .modelProvider(ClientTestModels::queryServiceModelsEndpointAuthParamsWithAllowList) - .classSpecProvider(AuthSchemeInterceptorSpec::new) - .caseName("query-endpoint-auth-params-with-allowlist") - .outputFileSuffix("interceptor") - .build(), - // - Endpoints based params without allow list - TestCase.builder() - .modelProvider(ClientTestModels::queryServiceModelsEndpointAuthParamsWithoutAllowList) - .classSpecProvider(AuthSchemeInterceptorSpec::new) - .caseName("query-endpoint-auth-params-without-allowlist") - .outputFileSuffix("interceptor") - .build(), // Service with auth trait with Sigv4a TestCase.builder() .modelProvider(ClientTestModels::opsWithSigv4a) @@ -228,19 +206,6 @@ static List parameters() { .caseName("ops-auth-sigv4a-value") .outputFileSuffix("default-params") .build(), - TestCase.builder() - .modelProvider(ClientTestModels::opsWithSigv4a) - .classSpecProvider(AuthSchemeInterceptorSpec::new) - .caseName("ops-auth-sigv4a-value") - .outputFileSuffix("interceptor") - .build(), - // service with environment bearer token enabled - TestCase.builder() - .modelProvider(ClientTestModels::envBearerTokenServiceModels) - .classSpecProvider(AuthSchemeInterceptorSpec::new) - .caseName("env-bearer-token") - .outputFileSuffix("interceptor") - .build(), // Rest Json service with checksum TestCase.builder() .modelProvider(ClientTestModels::restJsonServiceModels) diff --git a/codegen/src/test/java/software/amazon/awssdk/codegen/poet/rules/EndpointResolverInterceptorSpecTest.java b/codegen/src/test/java/software/amazon/awssdk/codegen/poet/rules/EndpointResolverInterceptorSpecTest.java deleted file mode 100644 index 89f87ff41952..000000000000 --- a/codegen/src/test/java/software/amazon/awssdk/codegen/poet/rules/EndpointResolverInterceptorSpecTest.java +++ /dev/null @@ -1,55 +0,0 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). - * You may not use this file except in compliance with the License. - * A copy of the License is located at - * - * http://aws.amazon.com/apache2.0 - * - * or in the "license" file accompanying this file. This file is distributed - * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either - * express or implied. See the License for the specific language governing - * permissions and limitations under the License. - */ - -package software.amazon.awssdk.codegen.poet.rules; - -import static org.hamcrest.MatcherAssert.assertThat; -import static software.amazon.awssdk.codegen.poet.PoetMatchers.generatesTo; - -import org.junit.jupiter.api.Test; -import software.amazon.awssdk.codegen.model.intermediate.IntermediateModel; -import software.amazon.awssdk.codegen.poet.ClassSpec; -import software.amazon.awssdk.codegen.poet.ClientTestModels; - -public class EndpointResolverInterceptorSpecTest { - @Test - public void endpointResolverInterceptorClass() { - ClassSpec endpointProviderInterceptor = new EndpointResolverInterceptorSpec(getModel()); - assertThat(endpointProviderInterceptor, generatesTo("endpoint-resolve-interceptor.java")); - } - - private static IntermediateModel getModel() { - IntermediateModel model = ClientTestModels.queryServiceModels(); - return model; - } - - @Test - void endpointResolverInterceptorClassWithSigv4aMultiAuth() { - ClassSpec endpointProviderInterceptor = new EndpointResolverInterceptorSpec(ClientTestModels.opsWithSigv4a()); - assertThat(endpointProviderInterceptor, generatesTo("endpoint-resolve-interceptor-with-multiauthsigv4a.java")); - } - - @Test - void endpointResolverInterceptorClassWithEndpointBasedAuth() { - ClassSpec endpointProviderInterceptor = new EndpointResolverInterceptorSpec(ClientTestModels.queryServiceModelsEndpointAuthParamsWithoutAllowList()); - assertThat(endpointProviderInterceptor, generatesTo("endpoint-resolve-interceptor-with-endpointsbasedauth.java")); - } - - @Test - public void endpointProviderTestClassWithStringArray() { - ClassSpec endpointProviderInterceptor = new EndpointResolverInterceptorSpec(ClientTestModels.stringArrayServiceModels()); - assertThat(endpointProviderInterceptor, generatesTo("endpoint-resolve-interceptor-with-stringarray.java")); - } -} diff --git a/codegen/src/test/java/software/amazon/awssdk/codegen/poet/rules/RequestEndpointProviderInterceptorSpecTest.java b/codegen/src/test/java/software/amazon/awssdk/codegen/poet/rules/RequestEndpointProviderInterceptorSpecTest.java deleted file mode 100644 index f4f2434340c9..000000000000 --- a/codegen/src/test/java/software/amazon/awssdk/codegen/poet/rules/RequestEndpointProviderInterceptorSpecTest.java +++ /dev/null @@ -1,31 +0,0 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). - * You may not use this file except in compliance with the License. - * A copy of the License is located at - * - * http://aws.amazon.com/apache2.0 - * - * or in the "license" file accompanying this file. This file is distributed - * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either - * express or implied. See the License for the specific language governing - * permissions and limitations under the License. - */ - -package software.amazon.awssdk.codegen.poet.rules; - -import static org.hamcrest.MatcherAssert.assertThat; -import static software.amazon.awssdk.codegen.poet.PoetMatchers.generatesTo; - -import org.junit.jupiter.api.Test; -import software.amazon.awssdk.codegen.poet.ClassSpec; -import software.amazon.awssdk.codegen.poet.ClientTestModels; - -public class RequestEndpointProviderInterceptorSpecTest { - @Test - public void requestSetEndpointInterceptorClass() { - ClassSpec endpointProviderInterceptor = new RequestEndpointInterceptorSpec(ClientTestModels.queryServiceModels()); - assertThat(endpointProviderInterceptor, generatesTo("request-set-endpoint-interceptor.java")); - } -} diff --git a/core/aws-core/src/main/java/software/amazon/awssdk/awscore/internal/AwsExecutionContextBuilder.java b/core/aws-core/src/main/java/software/amazon/awssdk/awscore/internal/AwsExecutionContextBuilder.java index add51ac598a8..b3ce212efbee 100644 --- a/core/aws-core/src/main/java/software/amazon/awssdk/awscore/internal/AwsExecutionContextBuilder.java +++ b/core/aws-core/src/main/java/software/amazon/awssdk/awscore/internal/AwsExecutionContextBuilder.java @@ -149,6 +149,11 @@ private AwsExecutionContextBuilder() { executionParams.authSchemeOptionsResolver()); } + if (executionParams.endpointResolver() != null) { + executionAttributes.putAttribute(SdkInternalExecutionAttribute.ENDPOINT_RESOLVER, + executionParams.endpointResolver()); + } + // Set the identity provider updater for the pipeline stage to use executionAttributes.putAttribute(SdkInternalExecutionAttribute.IDENTITY_PROVIDER_UPDATER, AwsIdentityProviderUpdater.create()); diff --git a/codegen/src/main/resources/software/amazon/awssdk/codegen/rules/AwsEndpointProviderUtils.java.resource b/core/aws-core/src/main/java/software/amazon/awssdk/awscore/internal/endpoints/AwsEndpointProviderUtils.java similarity index 54% rename from codegen/src/main/resources/software/amazon/awssdk/codegen/rules/AwsEndpointProviderUtils.java.resource rename to core/aws-core/src/main/java/software/amazon/awssdk/awscore/internal/endpoints/AwsEndpointProviderUtils.java index 1d0fedb9f7ff..1cd27489455f 100644 --- a/codegen/src/main/resources/software/amazon/awssdk/codegen/rules/AwsEndpointProviderUtils.java.resource +++ b/core/aws-core/src/main/java/software/amazon/awssdk/awscore/internal/endpoints/AwsEndpointProviderUtils.java @@ -1,27 +1,41 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.awscore.internal.endpoints; + import static software.amazon.awssdk.utils.FunctionalUtils.invokeSafely; import java.net.URI; -import java.util.List; -import java.util.Map; import software.amazon.awssdk.annotations.SdkInternalApi; import software.amazon.awssdk.awscore.AwsExecutionAttribute; -import software.amazon.awssdk.awscore.endpoints.AwsEndpointAttribute; -import software.amazon.awssdk.core.exception.SdkClientException; import software.amazon.awssdk.core.interceptor.ExecutionAttributes; -import software.amazon.awssdk.core.interceptor.SdkExecutionAttribute; import software.amazon.awssdk.core.interceptor.SdkInternalExecutionAttribute; import software.amazon.awssdk.core.useragent.BusinessMetricFeatureId; import software.amazon.awssdk.endpoints.Endpoint; import software.amazon.awssdk.http.SdkHttpRequest; import software.amazon.awssdk.regions.Region; import software.amazon.awssdk.utils.HostnameValidator; -import software.amazon.awssdk.utils.Logger; import software.amazon.awssdk.utils.StringUtils; import software.amazon.awssdk.utils.http.SdkHttpUtils; +/** + * Shared utility methods for endpoint resolution across all AWS services. + * Previously this was generated per-service as an identical copy; now de-duplicated here. + */ @SdkInternalApi public final class AwsEndpointProviderUtils { - private static final Logger LOG = Logger.loggerFor(AwsEndpointProviderUtils.class); private AwsEndpointProviderUtils() { } @@ -38,10 +52,6 @@ public static Boolean fipsEnabledBuiltIn(ExecutionAttributes executionAttributes return executionAttributes.getAttribute(AwsExecutionAttribute.FIPS_ENDPOINT_ENABLED); } - /** - * Returns the endpoint set on the client. Note that this strips off the query part of the URI because the endpoint - * rules library, e.g. {@code ParseURL} will return an exception if the URI it parses has query parameters. - */ public static String endpointBuiltIn(ExecutionAttributes executionAttributes) { if (endpointIsOverridden(executionAttributes)) { executionAttributes.getOptionalAttribute(SdkInternalExecutionAttribute.BUSINESS_METRICS).ifPresent( @@ -56,80 +66,36 @@ public static String endpointBuiltIn(ExecutionAttributes executionAttributes) { return null; } - /** - * Read {@link SdkExecutionAttribute#CLIENT_ENDPOINT_PROVIDER}'s isEndpointOverridden attribute. - */ public static boolean endpointIsOverridden(ExecutionAttributes attrs) { return attrs.getAttribute(SdkInternalExecutionAttribute.CLIENT_ENDPOINT_PROVIDER).isEndpointOverridden(); } - /** - * True if the the {@link SdkInternalExecutionAttribute#IS_DISCOVERED_ENDPOINT} attribute is present and its value - * is {@code true}, {@code false} otherwise. - */ public static boolean endpointIsDiscovered(ExecutionAttributes attrs) { return attrs.getOptionalAttribute(SdkInternalExecutionAttribute.IS_DISCOVERED_ENDPOINT).orElse(false); } - /** - * True if the the {@link SdkInternalExecutionAttribute#DISABLE_HOST_PREFIX_INJECTION} attribute is present and its - * value is {@code true}, {@code false} otherwise. - */ public static boolean disableHostPrefixInjection(ExecutionAttributes attrs) { return attrs.getOptionalAttribute(SdkInternalExecutionAttribute.DISABLE_HOST_PREFIX_INJECTION).orElse(false); } - /** - * Apply the given endpoint prefix to the endpoint. - */ public static Endpoint addHostPrefix(Endpoint endpoint, String prefix) { if (StringUtils.isBlank(prefix)) { return endpoint; } - validatePrefixIsHostNameCompliant(prefix); - URI originalUrl = endpoint.url(); String newHost = prefix + endpoint.url().getHost(); URI newUrl = invokeSafely(() -> new URI(originalUrl.getScheme(), null, newHost, originalUrl.getPort(), originalUrl.getPath(), originalUrl.getQuery(), originalUrl.getFragment())); - return endpoint.toBuilder().url(newUrl).build(); } - /** - * This sets the request URI to the resolved URI returned by the endpoint provider. There are some things to be - * careful about to make this work properly: - *

- * If the client endpoint is an endpoint override, it may contain a path. In addition, the request marshaller itself - * may add components to the path if it's modeled for the operation. Unfortunately, - * {@link SdkHttpRequest#encodedPath()} returns the combined path from both the endpoint and the request. There is - * no way to know, just from the HTTP request object, where the override path ends (if it's even there) and where - * the request path starts. Additionally, the rule itself may also append other parts to the endpoint override path. - *

- * To solve this issue, we pass in the endpoint set on the path, which allows us to the strip the path from the - * endpoint override from the request path, and then correctly combine the paths. - *

- * For example, let's suppose the endpoint override on the client is {@code https://example.com/a}. Then we call an - * operation {@code Foo()}, that marshalls {@code /c} to the path. The resulting request path is {@code /a/c}. - * However, we also pass the endpoint to provider as a parameter, and the resolver returns - * {@code https://example.com/a/b}. This method takes care of combining the paths correctly so that the resulting - * path is {@code https://example.com/a/b/c}. - */ public static SdkHttpRequest setUri(SdkHttpRequest request, URI clientEndpoint, URI resolvedUri) { - // [client endpoint path] String clientEndpointPath = clientEndpoint.getRawPath(); - - // [client endpoint path]/[request path] String requestPath = request.encodedPath(); - - // [client endpoint path]/[additional path added by resolver] String resolvedUriPath = resolvedUri.getRawPath(); String finalPath = requestPath; - - // If there is an additional path added by resolver, i.e., [additional path added by resolver] not null, - // we need to combine the path if (!resolvedUriPath.equals(clientEndpointPath)) { finalPath = combinePath(clientEndpointPath, requestPath, resolvedUriPath); } @@ -138,27 +104,15 @@ public static SdkHttpRequest setUri(SdkHttpRequest request, URI clientEndpoint, .encodedPath(finalPath).build(); } - /** - * Our goal is to construct [client endpoint path]/[additional path added by resolver]/[request path], so we just - * need to strip the client endpoint path from the marshalled request path to isolate just the part added by the - * marshaller. Trailing slash is removed from client endpoint path before stripping because it could cause the - * leading slash to be removed from the request path: e.g., StringUtils.replaceOnce("/", "//test", "") generates - * "/test" and the expected result is "//test" - */ private static String combinePath(String clientEndpointPath, String requestPath, String resolvedUriPath) { String requestPathWithClientPathRemoved = StringUtils.replaceOnce(requestPath, clientEndpointPath, ""); - String finalPath = SdkHttpUtils.appendUri(resolvedUriPath, requestPathWithClientPathRemoved); - return finalPath; + return SdkHttpUtils.appendUri(resolvedUriPath, requestPathWithClientPathRemoved); } private static void validatePrefixIsHostNameCompliant(String prefix) { - String[] components = splitHostLabelOnDots(prefix); + String[] components = prefix.split("\\."); for (String component : components) { HostnameValidator.validateHostnameCompliant(component, component, "request"); } } - - private static String[] splitHostLabelOnDots(String label) { - return label.split("\\."); - } } diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/client/handler/ClientExecutionParams.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/client/handler/ClientExecutionParams.java index 66b4d9e7bf0f..6950d6105fa2 100644 --- a/core/sdk-core/src/main/java/software/amazon/awssdk/core/client/handler/ClientExecutionParams.java +++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/client/handler/ClientExecutionParams.java @@ -30,6 +30,7 @@ import software.amazon.awssdk.core.interceptor.ExecutionAttribute; import software.amazon.awssdk.core.interceptor.ExecutionAttributes; import software.amazon.awssdk.core.runtime.transform.Marshaller; +import software.amazon.awssdk.core.spi.endpoint.EndpointResolver; import software.amazon.awssdk.core.spi.identity.AuthSchemeOptionsResolver; import software.amazon.awssdk.core.sync.RequestBody; import software.amazon.awssdk.core.sync.ResponseTransformer; @@ -65,6 +66,7 @@ public final class ClientExecutionParams { private final ExecutionAttributes attributes = new ExecutionAttributes(); private SdkClientConfiguration requestConfiguration; private AuthSchemeOptionsResolver authSchemeOptionsResolver; + private EndpointResolver endpointResolver; public Marshaller getMarshaller() { return marshaller; @@ -273,4 +275,13 @@ public ClientExecutionParams withAuthSchemeOptionsResolver( this.authSchemeOptionsResolver = authSchemeOptionsResolver; return this; } + + public EndpointResolver endpointResolver() { + return endpointResolver; + } + + public ClientExecutionParams withEndpointResolver(EndpointResolver endpointResolver) { + this.endpointResolver = endpointResolver; + return this; + } } diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/interceptor/SdkInternalExecutionAttribute.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/interceptor/SdkInternalExecutionAttribute.java index 8c73ce3b70ef..30ea3e948e7e 100644 --- a/core/sdk-core/src/main/java/software/amazon/awssdk/core/interceptor/SdkInternalExecutionAttribute.java +++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/interceptor/SdkInternalExecutionAttribute.java @@ -29,6 +29,7 @@ import software.amazon.awssdk.core.interceptor.trait.HttpChecksum; import software.amazon.awssdk.core.interceptor.trait.HttpChecksumRequired; import software.amazon.awssdk.core.internal.interceptor.trait.RequestCompression; +import software.amazon.awssdk.core.spi.endpoint.EndpointResolver; import software.amazon.awssdk.core.spi.identity.AuthSchemeOptionsResolver; import software.amazon.awssdk.core.spi.identity.IdentityProviderUpdater; import software.amazon.awssdk.core.useragent.AdditionalMetadata; @@ -182,6 +183,13 @@ public final class SdkInternalExecutionAttribute extends SdkExecutionAttribute { public static final ExecutionAttribute AUTH_SCHEME_OPTIONS_RESOLVER = new ExecutionAttribute<>("AuthSchemeOptionsResolver"); + /** + * Callback for resolving the endpoint. Generated per-service as a lambda in the client class. + * Called by EndpointResolutionStage after interceptors have run. + */ + public static final ExecutionAttribute ENDPOINT_RESOLVER = + new ExecutionAttribute<>("EndpointResolver"); + /** * The selected auth scheme for a request. */ diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/AmazonAsyncHttpClient.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/AmazonAsyncHttpClient.java index 8319ee438a26..6594db840740 100644 --- a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/AmazonAsyncHttpClient.java +++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/AmazonAsyncHttpClient.java @@ -41,6 +41,7 @@ import software.amazon.awssdk.core.internal.http.pipeline.stages.AsyncSigningStage; import software.amazon.awssdk.core.internal.http.pipeline.stages.AuthSchemeResolutionStage; import software.amazon.awssdk.core.internal.http.pipeline.stages.CompressRequestStage; +import software.amazon.awssdk.core.internal.http.pipeline.stages.EndpointResolutionStage; import software.amazon.awssdk.core.internal.http.pipeline.stages.HttpChecksumStage; import software.amazon.awssdk.core.internal.http.pipeline.stages.MakeAsyncHttpRequestStage; import software.amazon.awssdk.core.internal.http.pipeline.stages.MakeRequestImmutableStage; @@ -200,6 +201,7 @@ public CompletableFuture execute( .then(() -> new CompressRequestStage(httpClientDependencies)) .then(() -> new HttpChecksumStage(ClientType.ASYNC)) .then(AuthSchemeResolutionStage::new) + .then(EndpointResolutionStage::new) .then(ApplyUserAgentStage::new) .then(MakeRequestImmutableStage::new) .then(RequestPipelineBuilder diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/AmazonSyncHttpClient.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/AmazonSyncHttpClient.java index 13e91eb06c17..2298855e98fe 100644 --- a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/AmazonSyncHttpClient.java +++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/AmazonSyncHttpClient.java @@ -38,6 +38,7 @@ import software.amazon.awssdk.core.internal.http.pipeline.stages.BeforeTransmissionExecutionInterceptorsStage; import software.amazon.awssdk.core.internal.http.pipeline.stages.BeforeUnmarshallingExecutionInterceptorsStage; import software.amazon.awssdk.core.internal.http.pipeline.stages.CompressRequestStage; +import software.amazon.awssdk.core.internal.http.pipeline.stages.EndpointResolutionStage; import software.amazon.awssdk.core.internal.http.pipeline.stages.ExecutionFailureExceptionReportingStage; import software.amazon.awssdk.core.internal.http.pipeline.stages.HandleResponseStage; import software.amazon.awssdk.core.internal.http.pipeline.stages.HttpChecksumStage; @@ -188,6 +189,7 @@ public OutputT execute(HttpResponseHandler> response .then(() -> new CompressRequestStage(httpClientDependencies)) .then(() -> new HttpChecksumStage(ClientType.SYNC)) .then(AuthSchemeResolutionStage::new) + .then(EndpointResolutionStage::new) .then(ApplyUserAgentStage::new) .then(MakeRequestImmutableStage::new) // End of mutating request diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/pipeline/stages/ApiCallMetricCollectionStage.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/pipeline/stages/ApiCallMetricCollectionStage.java index 3b78dedaf2ad..c7a6783fa7da 100644 --- a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/pipeline/stages/ApiCallMetricCollectionStage.java +++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/pipeline/stages/ApiCallMetricCollectionStage.java @@ -40,17 +40,23 @@ public ApiCallMetricCollectionStage(RequestPipeline execute(SdkHttpFullRequest input, RequestExecutionContext context) throws Exception { MetricCollector metricCollector = context.executionContext().metricCollector(); - MetricUtils.collectServiceEndpointMetrics(metricCollector, input); // Note: at this point, any exception, even a service exception, will // be thrown from the wrapped pipeline so we can't use // MetricUtil.measureDuration() long callStart = System.nanoTime(); try { - return wrapped.execute(input, context); + Response response = wrapped.execute(input, context); + return response; } finally { long d = System.nanoTime() - callStart; metricCollector.reportMetric(CoreMetric.API_CALL_DURATION, Duration.ofNanos(d)); + // Collect SERVICE_ENDPOINT after pipeline execution so that EndpointResolutionStage + // has applied the resolved URL to the request. + SdkHttpFullRequest finalRequest = context.executionContext().interceptorContext().httpRequest() != null + ? (SdkHttpFullRequest) context.executionContext().interceptorContext().httpRequest() + : input; + MetricUtils.collectServiceEndpointMetrics(metricCollector, finalRequest); } } } diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/pipeline/stages/AsyncApiCallMetricCollectionStage.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/pipeline/stages/AsyncApiCallMetricCollectionStage.java index 1d7d040971cf..e85a6f752aee 100644 --- a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/pipeline/stages/AsyncApiCallMetricCollectionStage.java +++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/pipeline/stages/AsyncApiCallMetricCollectionStage.java @@ -41,7 +41,6 @@ public AsyncApiCallMetricCollectionStage(RequestPipeline execute(SdkHttpFullRequest input, RequestExecutionContext context) throws Exception { MetricCollector metricCollector = context.executionContext().metricCollector(); - MetricUtils.collectServiceEndpointMetrics(metricCollector, input); CompletableFuture future = new CompletableFuture<>(); @@ -52,6 +51,13 @@ public CompletableFuture execute(SdkHttpFullRequest input, RequestExecu long duration = System.nanoTime() - callStart; metricCollector.reportMetric(CoreMetric.API_CALL_DURATION, Duration.ofNanos(duration)); + // Collect SERVICE_ENDPOINT after pipeline execution so that EndpointResolutionStage + // has applied the resolved URL to the request. + SdkHttpFullRequest finalRequest = context.executionContext().interceptorContext().httpRequest() != null + ? (SdkHttpFullRequest) context.executionContext().interceptorContext().httpRequest() + : input; + MetricUtils.collectServiceEndpointMetrics(metricCollector, finalRequest); + if (t != null) { future.completeExceptionally(t); } else { diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/pipeline/stages/EndpointResolutionStage.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/pipeline/stages/EndpointResolutionStage.java new file mode 100644 index 000000000000..70d2886c2fd1 --- /dev/null +++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/pipeline/stages/EndpointResolutionStage.java @@ -0,0 +1,119 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.core.internal.http.pipeline.stages; + +import java.net.URI; +import java.time.Duration; +import java.util.List; +import java.util.Map; +import software.amazon.awssdk.annotations.SdkInternalApi; +import software.amazon.awssdk.core.ClientEndpointProvider; +import software.amazon.awssdk.core.SdkRequest; +import software.amazon.awssdk.core.interceptor.ExecutionAttributes; +import software.amazon.awssdk.core.interceptor.SdkExecutionAttribute; +import software.amazon.awssdk.core.interceptor.SdkInternalExecutionAttribute; +import software.amazon.awssdk.core.internal.http.HttpClientDependencies; +import software.amazon.awssdk.core.internal.http.RequestExecutionContext; +import software.amazon.awssdk.core.internal.http.pipeline.MutableRequestToRequestPipeline; +import software.amazon.awssdk.core.metrics.CoreMetric; +import software.amazon.awssdk.core.spi.endpoint.EndpointResolver; +import software.amazon.awssdk.endpoints.Endpoint; +import software.amazon.awssdk.http.SdkHttpFullRequest; +import software.amazon.awssdk.metrics.MetricCollector; +import software.amazon.awssdk.utils.StringUtils; +import software.amazon.awssdk.utils.http.SdkHttpUtils; + +/** + * Pipeline stage that resolves the endpoint using the service's endpoint rules engine. + *

+ * This stage runs after all interceptors (including {@code modifyHttpRequest}) and after + * {@link AuthSchemeResolutionStage}. It calls a service-specific callback to resolve the endpoint, + * then applies the resolved URL, headers, and metrics. + *

+ */ +@SdkInternalApi +public final class EndpointResolutionStage implements MutableRequestToRequestPipeline { + + public EndpointResolutionStage(HttpClientDependencies dependencies) { + } + + @Override + public SdkHttpFullRequest.Builder execute(SdkHttpFullRequest.Builder request, RequestExecutionContext context) + throws Exception { + ExecutionAttributes attrs = context.executionAttributes(); + + if (Boolean.TRUE.equals(attrs.getAttribute(SdkInternalExecutionAttribute.IS_DISCOVERED_ENDPOINT))) { + return request; + } + + EndpointResolver resolver = attrs.getAttribute(SdkInternalExecutionAttribute.ENDPOINT_RESOLVER); + if (resolver == null) { + return request; + } + + SdkRequest sdkRequest = context.executionContext().interceptorContext().request(); + + long resolveEndpointStart = System.nanoTime(); + Endpoint endpoint = resolver.resolve(sdkRequest, attrs); + Duration resolveEndpointDuration = Duration.ofNanos(System.nanoTime() - resolveEndpointStart); + + MetricCollector metricCollector = attrs.getAttribute(SdkExecutionAttribute.API_CALL_METRIC_COLLECTOR); + if (metricCollector != null) { + metricCollector.reportMetric(CoreMetric.ENDPOINT_RESOLVE_DURATION, resolveEndpointDuration); + } + + attrs.putAttribute(SdkInternalExecutionAttribute.RESOLVED_ENDPOINT, endpoint); + + // Copy endpoint headers onto HTTP request + Map> headers = endpoint.headers(); + if (headers != null && !headers.isEmpty()) { + headers.forEach((name, values) -> + values.forEach(v -> request.appendHeader(name, v))); + } + + // Apply resolved endpoint URL + ClientEndpointProvider clientEndpointProvider = + attrs.getAttribute(SdkInternalExecutionAttribute.CLIENT_ENDPOINT_PROVIDER); + return setUri(request, clientEndpointProvider.clientEndpoint(), endpoint.url()); + } + + /** + * Applies the resolved endpoint URL to the HTTP request, merging three path components: + * client endpoint path + rules engine additions + marshaller operation path. + */ + private static SdkHttpFullRequest.Builder setUri(SdkHttpFullRequest.Builder request, + URI clientEndpoint, + URI resolvedUri) { + String clientEndpointPath = clientEndpoint.getRawPath(); + String requestPath = request.encodedPath(); + String resolvedUriPath = resolvedUri.getRawPath(); + + String finalPath = requestPath; + if (!resolvedUriPath.equals(clientEndpointPath)) { + finalPath = combinePath(clientEndpointPath, requestPath, resolvedUriPath); + } + + return request.protocol(resolvedUri.getScheme()) + .host(resolvedUri.getHost()) + .port(resolvedUri.getPort()) + .encodedPath(finalPath); + } + + private static String combinePath(String clientEndpointPath, String requestPath, String resolvedUriPath) { + String requestPathWithClientPathRemoved = StringUtils.replaceOnce(requestPath, clientEndpointPath, ""); + return SdkHttpUtils.appendUri(resolvedUriPath, requestPathWithClientPathRemoved); + } +} diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/spi/endpoint/EndpointResolver.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/spi/endpoint/EndpointResolver.java new file mode 100644 index 000000000000..62ba60164d24 --- /dev/null +++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/spi/endpoint/EndpointResolver.java @@ -0,0 +1,37 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.core.spi.endpoint; + +import software.amazon.awssdk.annotations.SdkProtectedApi; +import software.amazon.awssdk.core.SdkRequest; +import software.amazon.awssdk.core.interceptor.ExecutionAttributes; +import software.amazon.awssdk.endpoints.Endpoint; + +/** + * Callback interface for resolving endpoints from the request and execution context. + */ +@FunctionalInterface +@SdkProtectedApi +public interface EndpointResolver { + /** + * Resolves the endpoint for the given request. + * + * @param request The SDK request (after interceptors have modified it) + * @param executionAttributes The execution attributes containing client config, region, etc. + * @return The resolved endpoint + */ + Endpoint resolve(SdkRequest request, ExecutionAttributes executionAttributes); +} diff --git a/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/http/pipeline/stages/EndpointResolutionStageTest.java b/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/http/pipeline/stages/EndpointResolutionStageTest.java new file mode 100644 index 000000000000..94ee03b5df50 --- /dev/null +++ b/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/http/pipeline/stages/EndpointResolutionStageTest.java @@ -0,0 +1,217 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.core.internal.http.pipeline.stages; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; + +import java.net.URI; +import java.time.Duration; +import java.util.List; +import java.util.Map; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; +import software.amazon.awssdk.core.ClientEndpointProvider; +import software.amazon.awssdk.core.SdkRequest; +import software.amazon.awssdk.core.http.ExecutionContext; +import software.amazon.awssdk.core.interceptor.ExecutionAttributes; +import software.amazon.awssdk.core.interceptor.InterceptorContext; +import software.amazon.awssdk.core.interceptor.SdkExecutionAttribute; +import software.amazon.awssdk.core.interceptor.SdkInternalExecutionAttribute; +import software.amazon.awssdk.core.internal.http.RequestExecutionContext; +import software.amazon.awssdk.core.metrics.CoreMetric; +import software.amazon.awssdk.endpoints.Endpoint; +import software.amazon.awssdk.http.SdkHttpFullRequest; +import software.amazon.awssdk.http.SdkHttpMethod; +import software.amazon.awssdk.metrics.MetricCollector; + +class EndpointResolutionStageTest { + + private static final URI CLIENT_ENDPOINT = URI.create("https://myservice.us-east-1.amazonaws.com"); + private static final URI RESOLVED_ENDPOINT = URI.create("https://resolved.us-west-2.amazonaws.com"); + + private EndpointResolutionStage stage; + private ExecutionAttributes executionAttributes; + private SdkRequest sdkRequest; + + @BeforeEach + void setup() { + stage = new EndpointResolutionStage(null); + sdkRequest = mock(SdkRequest.class); + executionAttributes = new ExecutionAttributes(); + } + + @Test + void execute_noResolver_returnsRequestUnchanged() throws Exception { + SdkHttpFullRequest.Builder request = defaultRequest(); + RequestExecutionContext context = createContext(); + + SdkHttpFullRequest.Builder result = stage.execute(request, context); + + assertThat(result).isSameAs(request); + } + + @Test + void execute_discoveredEndpoint_returnsRequestUnchanged() throws Exception { + SdkHttpFullRequest.Builder request = defaultRequest(); + executionAttributes.putAttribute(SdkInternalExecutionAttribute.IS_DISCOVERED_ENDPOINT, true); + executionAttributes.putAttribute(SdkInternalExecutionAttribute.ENDPOINT_RESOLVER, + (req, attrs) -> Endpoint.builder().url(RESOLVED_ENDPOINT).build()); + RequestExecutionContext context = createContext(); + + SdkHttpFullRequest.Builder result = stage.execute(request, context); + + assertThat(result).isSameAs(request); + } + + @Test + void execute_happyPath_appliesResolvedUrl() throws Exception { + SdkHttpFullRequest.Builder request = defaultRequest(); + Endpoint endpoint = Endpoint.builder().url(RESOLVED_ENDPOINT).build(); + executionAttributes.putAttribute(SdkInternalExecutionAttribute.ENDPOINT_RESOLVER, (req, attrs) -> endpoint); + executionAttributes.putAttribute(SdkInternalExecutionAttribute.CLIENT_ENDPOINT_PROVIDER, + ClientEndpointProvider.forEndpointOverride(CLIENT_ENDPOINT)); + RequestExecutionContext context = createContext(); + + SdkHttpFullRequest.Builder result = stage.execute(request, context); + + assertThat(result.host()).isEqualTo("resolved.us-west-2.amazonaws.com"); + assertThat(result.protocol()).isEqualTo("https"); + } + + @Test + void execute_happyPath_storesResolvedEndpoint() throws Exception { + SdkHttpFullRequest.Builder request = defaultRequest(); + Endpoint endpoint = Endpoint.builder().url(RESOLVED_ENDPOINT).build(); + executionAttributes.putAttribute(SdkInternalExecutionAttribute.ENDPOINT_RESOLVER, (req, attrs) -> endpoint); + executionAttributes.putAttribute(SdkInternalExecutionAttribute.CLIENT_ENDPOINT_PROVIDER, + ClientEndpointProvider.forEndpointOverride(CLIENT_ENDPOINT)); + RequestExecutionContext context = createContext(); + + stage.execute(request, context); + + assertThat(executionAttributes.getAttribute(SdkInternalExecutionAttribute.RESOLVED_ENDPOINT)).isSameAs(endpoint); + } + + @Test + void execute_withHeaders_appendsToRequest() throws Exception { + Endpoint endpoint = Endpoint.builder().url(RESOLVED_ENDPOINT) + .putHeader("x-amz-custom", "val1") + .putHeader("x-amz-custom", "val2") + .build(); + + SdkHttpFullRequest.Builder request = defaultRequest(); + executionAttributes.putAttribute(SdkInternalExecutionAttribute.ENDPOINT_RESOLVER, (req, attrs) -> endpoint); + executionAttributes.putAttribute(SdkInternalExecutionAttribute.CLIENT_ENDPOINT_PROVIDER, + ClientEndpointProvider.forEndpointOverride(CLIENT_ENDPOINT)); + RequestExecutionContext context = createContext(); + + SdkHttpFullRequest.Builder result = stage.execute(request, context); + + assertThat(result.build().matchingHeaders("x-amz-custom")).containsExactly("val1", "val2"); + } + + @Test + void execute_withMetricCollector_reportsEndpointResolveDuration() throws Exception { + Endpoint endpoint = Endpoint.builder().url(RESOLVED_ENDPOINT).build(); + MetricCollector metricCollector = mock(MetricCollector.class); + + SdkHttpFullRequest.Builder request = defaultRequest(); + executionAttributes.putAttribute(SdkInternalExecutionAttribute.ENDPOINT_RESOLVER, (req, attrs) -> endpoint); + executionAttributes.putAttribute(SdkInternalExecutionAttribute.CLIENT_ENDPOINT_PROVIDER, + ClientEndpointProvider.forEndpointOverride(CLIENT_ENDPOINT)); + executionAttributes.putAttribute(SdkExecutionAttribute.API_CALL_METRIC_COLLECTOR, metricCollector); + RequestExecutionContext context = createContext(); + + stage.execute(request, context); + + ArgumentCaptor durationCaptor = ArgumentCaptor.forClass(Duration.class); + verify(metricCollector).reportMetric(org.mockito.ArgumentMatchers.eq(CoreMetric.ENDPOINT_RESOLVE_DURATION), + durationCaptor.capture()); + assertThat(durationCaptor.getValue()).isGreaterThanOrEqualTo(Duration.ZERO); + } + + @Test + void execute_resolvedPathDiffersFromClientPath_combinesPaths() throws Exception { + URI clientEndpoint = URI.create("https://myservice.amazonaws.com"); + URI resolvedUri = URI.create("https://resolved.amazonaws.com/v2"); + + Endpoint endpoint = Endpoint.builder().url(resolvedUri).build(); + SdkHttpFullRequest.Builder request = SdkHttpFullRequest.builder() + .method(SdkHttpMethod.GET) + .protocol("https") + .host("myservice.amazonaws.com") + .encodedPath("/my/operation"); + executionAttributes.putAttribute(SdkInternalExecutionAttribute.ENDPOINT_RESOLVER, (req, attrs) -> endpoint); + executionAttributes.putAttribute(SdkInternalExecutionAttribute.CLIENT_ENDPOINT_PROVIDER, + ClientEndpointProvider.forEndpointOverride(clientEndpoint)); + RequestExecutionContext context = createContext(); + + SdkHttpFullRequest.Builder result = stage.execute(request, context); + + assertThat(result.encodedPath()).isEqualTo("/v2/my/operation"); + } + + @Test + void execute_resolverReceivesRequestFromInterceptorContext() throws Exception { + SdkRequest modifiedRequest = mock(SdkRequest.class); + SdkRequest[] capturedRequest = new SdkRequest[1]; + + Endpoint endpoint = Endpoint.builder().url(RESOLVED_ENDPOINT).build(); + SdkHttpFullRequest.Builder request = defaultRequest(); + executionAttributes.putAttribute(SdkInternalExecutionAttribute.ENDPOINT_RESOLVER, (req, attrs) -> { + capturedRequest[0] = req; + return endpoint; + }); + executionAttributes.putAttribute(SdkInternalExecutionAttribute.CLIENT_ENDPOINT_PROVIDER, + ClientEndpointProvider.forEndpointOverride(CLIENT_ENDPOINT)); + + // Use modifiedRequest in interceptor context (different from sdkRequest) + RequestExecutionContext context = createContext(modifiedRequest); + + stage.execute(request, context); + + assertThat(capturedRequest[0]).isSameAs(modifiedRequest); + } + + private SdkHttpFullRequest.Builder defaultRequest() { + return SdkHttpFullRequest.builder() + .method(SdkHttpMethod.GET) + .protocol("https") + .host(CLIENT_ENDPOINT.getHost()) + .encodedPath("/"); + } + + private RequestExecutionContext createContext() { + return createContext(sdkRequest); + } + + private RequestExecutionContext createContext(SdkRequest request) { + InterceptorContext interceptorContext = InterceptorContext.builder() + .request(request) + .build(); + ExecutionContext executionContext = ExecutionContext.builder() + .interceptorContext(interceptorContext) + .executionAttributes(executionAttributes) + .build(); + return RequestExecutionContext.builder() + .executionContext(executionContext) + .originalRequest(request) + .build(); + } +} diff --git a/services/dynamodb/src/test/java/software/amazon/awssdk/services/dynamodb/PaginatorInUserAgentTest.java b/services/dynamodb/src/test/java/software/amazon/awssdk/services/dynamodb/PaginatorInUserAgentTest.java index b4bc3a504391..ef4f2b3b9e70 100644 --- a/services/dynamodb/src/test/java/software/amazon/awssdk/services/dynamodb/PaginatorInUserAgentTest.java +++ b/services/dynamodb/src/test/java/software/amazon/awssdk/services/dynamodb/PaginatorInUserAgentTest.java @@ -39,7 +39,7 @@ import software.amazon.awssdk.core.useragent.BusinessMetricCollection; import software.amazon.awssdk.core.useragent.BusinessMetricFeatureId; import software.amazon.awssdk.regions.Region; -import software.amazon.awssdk.services.dynamodb.endpoints.internal.AwsEndpointProviderUtils; +import software.amazon.awssdk.awscore.internal.endpoints.AwsEndpointProviderUtils; import software.amazon.awssdk.services.dynamodb.paginators.QueryPublisher; public class PaginatorInUserAgentTest { diff --git a/services/s3/src/main/java/software/amazon/awssdk/services/s3/S3Utilities.java b/services/s3/src/main/java/software/amazon/awssdk/services/s3/S3Utilities.java index cb3064798a99..a745071d9e62 100644 --- a/services/s3/src/main/java/software/amazon/awssdk/services/s3/S3Utilities.java +++ b/services/s3/src/main/java/software/amazon/awssdk/services/s3/S3Utilities.java @@ -19,11 +19,11 @@ import java.net.MalformedURLException; import java.net.URI; import java.net.URL; -import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Optional; +import java.util.concurrent.CompletionException; import java.util.function.Consumer; import java.util.function.Supplier; import java.util.regex.Matcher; @@ -39,18 +39,18 @@ import software.amazon.awssdk.awscore.endpoint.DualstackEnabledProvider; import software.amazon.awssdk.awscore.endpoint.FipsEnabledProvider; import software.amazon.awssdk.awscore.internal.defaultsmode.DefaultsModeConfiguration; +import software.amazon.awssdk.awscore.internal.endpoints.AwsEndpointProviderUtils; import software.amazon.awssdk.core.ClientEndpointProvider; import software.amazon.awssdk.core.ClientType; import software.amazon.awssdk.core.client.config.ClientOverrideConfiguration; import software.amazon.awssdk.core.client.config.SdkClientConfiguration; import software.amazon.awssdk.core.client.config.SdkClientOption; +import software.amazon.awssdk.core.exception.SdkClientException; import software.amazon.awssdk.core.exception.SdkException; import software.amazon.awssdk.core.interceptor.ExecutionAttributes; -import software.amazon.awssdk.core.interceptor.ExecutionInterceptor; -import software.amazon.awssdk.core.interceptor.ExecutionInterceptorChain; -import software.amazon.awssdk.core.interceptor.InterceptorContext; import software.amazon.awssdk.core.interceptor.SdkExecutionAttribute; import software.amazon.awssdk.core.interceptor.SdkInternalExecutionAttribute; +import software.amazon.awssdk.endpoints.Endpoint; import software.amazon.awssdk.http.SdkHttpFullRequest; import software.amazon.awssdk.http.SdkHttpMethod; import software.amazon.awssdk.http.SdkHttpRequest; @@ -62,9 +62,9 @@ import software.amazon.awssdk.regions.Region; import software.amazon.awssdk.regions.ServiceMetadataAdvancedOption; import software.amazon.awssdk.services.s3.endpoints.S3ClientContextParams; +import software.amazon.awssdk.services.s3.endpoints.S3EndpointParams; import software.amazon.awssdk.services.s3.endpoints.S3EndpointProvider; -import software.amazon.awssdk.services.s3.endpoints.internal.S3RequestSetEndpointInterceptor; -import software.amazon.awssdk.services.s3.endpoints.internal.S3ResolveEndpointInterceptor; +import software.amazon.awssdk.services.s3.endpoints.internal.S3EndpointResolverUtils; import software.amazon.awssdk.services.s3.internal.endpoints.UseGlobalEndpointResolver; import software.amazon.awssdk.services.s3.model.GetObjectRequest; import software.amazon.awssdk.services.s3.model.GetUrlRequest; @@ -110,7 +110,6 @@ public final class S3Utilities { private final Supplier profileFile; private final String profileName; private final boolean fipsEnabled; - private final ExecutionInterceptorChain interceptorChain; private final UseGlobalEndpointResolver useGlobalEndpointResolver; /** @@ -140,8 +139,6 @@ private S3Utilities(Builder builder) { .isFipsEnabled() .orElse(false); - this.interceptorChain = createEndpointInterceptorChain(); - this.useGlobalEndpointResolver = createUseGlobalEndpointResolver(); } @@ -243,14 +240,9 @@ public URL getUrl(GetUrlRequest getUrlRequest) { .versionId(getUrlRequest.versionId()) .build(); - InterceptorContext interceptorContext = InterceptorContext.builder() - .httpRequest(marshalledRequest) - .request(getObjectRequest) - .build(); - ExecutionAttributes executionAttributes = createExecutionAttributes(clientEndpoint, resolvedRegion); - SdkHttpRequest modifiedRequest = runInterceptors(interceptorContext, executionAttributes).httpRequest(); + SdkHttpRequest modifiedRequest = resolveEndpointAndApply(getObjectRequest, marshalledRequest, executionAttributes); try { return modifiedRequest.getUri().toURL(); } catch (MalformedURLException exception) { @@ -506,16 +498,36 @@ private AttributeMap createClientContextParams() { return params.build(); } - private InterceptorContext runInterceptors(InterceptorContext context, ExecutionAttributes executionAttributes) { - context = interceptorChain.modifyRequest(context, executionAttributes); - return interceptorChain.modifyHttpRequestAndHttpContent(context, executionAttributes); - } + private SdkHttpRequest resolveEndpointAndApply(GetObjectRequest request, SdkHttpRequest httpRequest, + ExecutionAttributes executionAttributes) { + S3EndpointParams endpointParams = S3EndpointResolverUtils.ruleParams(request, executionAttributes); + S3EndpointProvider provider = (S3EndpointProvider) executionAttributes + .getAttribute(SdkInternalExecutionAttribute.ENDPOINT_PROVIDER); + + Endpoint endpoint; + try { + endpoint = provider.resolveEndpoint(endpointParams).join(); + } catch (CompletionException e) { + Throwable cause = e.getCause(); + if (cause instanceof SdkClientException) { + throw (SdkClientException) cause; + } + throw SdkClientException.create("Endpoint resolution failed: " + cause.getMessage(), cause); + } + + ClientEndpointProvider clientEndpointProvider = + executionAttributes.getAttribute(SdkInternalExecutionAttribute.CLIENT_ENDPOINT_PROVIDER); + SdkHttpRequest result = AwsEndpointProviderUtils.setUri(httpRequest, + clientEndpointProvider.clientEndpoint(), endpoint.url()); + + if (!endpoint.headers().isEmpty()) { + SdkHttpRequest.Builder builder = result.toBuilder(); + endpoint.headers().forEach((name, values) -> + values.forEach(v -> builder.appendHeader(name, v))); + result = builder.build(); + } - private ExecutionInterceptorChain createEndpointInterceptorChain() { - List interceptors = new ArrayList<>(); - interceptors.add(new S3ResolveEndpointInterceptor()); - interceptors.add(new S3RequestSetEndpointInterceptor()); - return new ExecutionInterceptorChain(interceptors); + return result; } private UseGlobalEndpointResolver createUseGlobalEndpointResolver() { diff --git a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/signing/DefaultS3Presigner.java b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/signing/DefaultS3Presigner.java index 727b0f2dc7e1..0aa198a34f48 100644 --- a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/signing/DefaultS3Presigner.java +++ b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/signing/DefaultS3Presigner.java @@ -32,6 +32,7 @@ import java.util.Map; import java.util.Optional; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionException; import java.util.function.Function; import java.util.stream.Collectors; import java.util.stream.Stream; @@ -42,10 +43,14 @@ import software.amazon.awssdk.awscore.client.builder.AwsDefaultClientBuilder; import software.amazon.awssdk.awscore.defaultsmode.DefaultsMode; import software.amazon.awssdk.awscore.endpoint.AwsClientEndpointProvider; +import software.amazon.awssdk.awscore.endpoints.AwsEndpointAttribute; +import software.amazon.awssdk.awscore.endpoints.authscheme.EndpointAuthScheme; import software.amazon.awssdk.awscore.internal.AwsExecutionContextBuilder; import software.amazon.awssdk.awscore.internal.defaultsmode.DefaultsModeConfiguration; +import software.amazon.awssdk.awscore.internal.endpoints.AwsEndpointProviderUtils; import software.amazon.awssdk.awscore.presigner.PresignRequest; import software.amazon.awssdk.awscore.presigner.PresignedRequest; +import software.amazon.awssdk.core.ClientEndpointProvider; import software.amazon.awssdk.core.ClientType; import software.amazon.awssdk.core.RequestOverrideConfiguration; import software.amazon.awssdk.core.SdkBytes; @@ -55,6 +60,7 @@ import software.amazon.awssdk.core.client.builder.SdkDefaultClientBuilder; import software.amazon.awssdk.core.client.config.SdkClientConfiguration; import software.amazon.awssdk.core.client.config.SdkClientOption; +import software.amazon.awssdk.core.exception.SdkClientException; import software.amazon.awssdk.core.http.ExecutionContext; import software.amazon.awssdk.core.identity.SdkIdentityProperty; import software.amazon.awssdk.core.interceptor.ClasspathInterceptorChainFactory; @@ -68,6 +74,7 @@ import software.amazon.awssdk.core.signer.Presigner; import software.amazon.awssdk.core.signer.Signer; import software.amazon.awssdk.core.sync.RequestBody; +import software.amazon.awssdk.endpoints.Endpoint; import software.amazon.awssdk.endpoints.EndpointProvider; import software.amazon.awssdk.http.ContentStreamProvider; import software.amazon.awssdk.http.SdkHttpFullRequest; @@ -96,8 +103,7 @@ import software.amazon.awssdk.services.s3.endpoints.S3ClientContextParams; import software.amazon.awssdk.services.s3.endpoints.S3EndpointParams; import software.amazon.awssdk.services.s3.endpoints.S3EndpointProvider; -import software.amazon.awssdk.services.s3.endpoints.internal.S3RequestSetEndpointInterceptor; -import software.amazon.awssdk.services.s3.endpoints.internal.S3ResolveEndpointInterceptor; +import software.amazon.awssdk.services.s3.endpoints.internal.S3EndpointResolverUtils; import software.amazon.awssdk.services.s3.internal.endpoints.UseGlobalEndpointResolver; import software.amazon.awssdk.services.s3.internal.s3express.S3ExpressAuthSchemeProvider; import software.amazon.awssdk.services.s3.model.AbortMultipartUploadRequest; @@ -248,10 +254,6 @@ private List initializeInterceptors() { ClasspathInterceptorChainFactory interceptorFactory = new ClasspathInterceptorChainFactory(); List s3Interceptors = interceptorFactory.getInterceptors("software/amazon/awssdk/services/s3/execution.interceptors"); - List additionalInterceptors = new ArrayList<>(); - additionalInterceptors.add(new S3ResolveEndpointInterceptor()); - additionalInterceptors.add(new S3RequestSetEndpointInterceptor()); - s3Interceptors = mergeLists(s3Interceptors, additionalInterceptors); return mergeLists(interceptorFactory.getGlobalInterceptors(), s3Interceptors); } @@ -416,6 +418,9 @@ private T presign(T presignedRequest, // Resolve auth scheme after interceptors complete resolveAndSelectAuthScheme(execCtx, requestToPresign, operationName); + // Resolve endpoint + resolveEndpointAndUpdateContext(execCtx, operationName); + SdkHttpFullRequest httpRequest = getHttpFullRequest(execCtx); SdkHttpFullRequest signedHttpRequest = execCtx.signer() != null @@ -631,6 +636,60 @@ private void resolveAndSelectAuthScheme(ExecutionContext execCtx, SdkRequest req executionAttributes.putAttribute(SELECTED_AUTH_SCHEME, selectedAuthScheme); } + /** + * Resolve the endpoint using the rules engine and apply it to the HTTP request in the execution context. + */ + private void resolveEndpointAndUpdateContext(ExecutionContext execCtx, String operationName) { + ExecutionAttributes executionAttributes = execCtx.executionAttributes(); + SdkRequest sdkRequest = execCtx.interceptorContext().request(); + + S3EndpointParams endpointParams = S3EndpointResolverUtils.ruleParams(sdkRequest, executionAttributes); + S3EndpointProvider provider = (S3EndpointProvider) executionAttributes + .getAttribute(SdkInternalExecutionAttribute.ENDPOINT_PROVIDER); + + Endpoint endpoint; + try { + endpoint = provider.resolveEndpoint(endpointParams).join(); + } catch (CompletionException e) { + Throwable cause = e.getCause(); + if (cause instanceof SdkClientException) { + throw (SdkClientException) cause; + } + throw SdkClientException.create("Endpoint resolution failed: " + cause.getMessage(), cause); + } + + if (!AwsEndpointProviderUtils.disableHostPrefixInjection(executionAttributes)) { + Optional hostPrefix = S3EndpointResolverUtils.hostPrefix(operationName, sdkRequest); + if (hostPrefix.isPresent()) { + endpoint = AwsEndpointProviderUtils.addHostPrefix(endpoint, hostPrefix.get()); + } + } + + List endpointAuthSchemes = endpoint.attribute(AwsEndpointAttribute.AUTH_SCHEMES); + SelectedAuthScheme selectedAuthScheme = executionAttributes.getAttribute(SELECTED_AUTH_SCHEME); + if (endpointAuthSchemes != null && selectedAuthScheme != null) { + selectedAuthScheme = S3EndpointResolverUtils.authSchemeWithEndpointSignerProperties( + endpointAuthSchemes, selectedAuthScheme); + executionAttributes.putAttribute(SELECTED_AUTH_SCHEME, selectedAuthScheme); + } + + SdkHttpRequest httpRequest = execCtx.interceptorContext().httpRequest(); + ClientEndpointProvider clientEndpointProvider = + executionAttributes.getAttribute(SdkInternalExecutionAttribute.CLIENT_ENDPOINT_PROVIDER); + SdkHttpRequest updatedRequest = AwsEndpointProviderUtils.setUri(httpRequest, + clientEndpointProvider.clientEndpoint(), endpoint.url()); + + if (!endpoint.headers().isEmpty()) { + SdkHttpRequest.Builder requestBuilder = updatedRequest.toBuilder(); + endpoint.headers().forEach((name, values) -> + values.forEach(v -> requestBuilder.appendHeader(name, v))); + updatedRequest = requestBuilder.build(); + } + + SdkHttpRequest finalRequest = updatedRequest; + execCtx.interceptorContext(execCtx.interceptorContext().copy(c -> c.httpRequest(finalRequest))); + } + /** * Resolve auth scheme options using full endpoint params. */ @@ -640,7 +699,7 @@ private List resolveAuthSchemeOptions(SdkRequest request, Stri executionAttributes.getAttribute(SdkInternalExecutionAttribute.AUTH_SCHEME_RESOLVER), "Expected an instance of S3AuthSchemeProvider"); - S3EndpointParams endpointParams = S3ResolveEndpointInterceptor.ruleParams(request, executionAttributes); + S3EndpointParams endpointParams = S3EndpointResolverUtils.ruleParams(request, executionAttributes); S3AuthSchemeParams.Builder paramsBuilder = S3AuthSchemeParams.fromEndpointParams(endpointParams) .operation(operationName); diff --git a/services/s3/src/test/java/software/amazon/awssdk/services/s3/functionaltests/AsyncResponseTransformerTest.java b/services/s3/src/test/java/software/amazon/awssdk/services/s3/functionaltests/AsyncResponseTransformerTest.java index 4c2c5e748c0b..155306ab9faf 100644 --- a/services/s3/src/test/java/software/amazon/awssdk/services/s3/functionaltests/AsyncResponseTransformerTest.java +++ b/services/s3/src/test/java/software/amazon/awssdk/services/s3/functionaltests/AsyncResponseTransformerTest.java @@ -17,6 +17,7 @@ import static org.assertj.core.api.Assertions.fail; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.atLeast; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; @@ -50,6 +51,6 @@ public void AsyncResponseTransformerPrepareCalled_BeforeCredentialsResolution() } verify(asyncResponseTransformer, times(1)).prepare(); - verify(asyncResponseTransformer, times(1)).exceptionOccurred(any()); + verify(asyncResponseTransformer, atLeast(1)).exceptionOccurred(any()); } } diff --git a/test/codegen-generated-classes-test/src/test/java/software/amazon/awssdk/services/endpointproviders/AwsEndpointProviderUtilsTest.java b/test/codegen-generated-classes-test/src/test/java/software/amazon/awssdk/services/endpointproviders/AwsEndpointProviderUtilsTest.java index 83d9d1d0fb59..25878e3cd9e2 100644 --- a/test/codegen-generated-classes-test/src/test/java/software/amazon/awssdk/services/endpointproviders/AwsEndpointProviderUtilsTest.java +++ b/test/codegen-generated-classes-test/src/test/java/software/amazon/awssdk/services/endpointproviders/AwsEndpointProviderUtilsTest.java @@ -29,7 +29,7 @@ import software.amazon.awssdk.http.SdkHttpMethod; import software.amazon.awssdk.http.SdkHttpRequest; import software.amazon.awssdk.regions.Region; -import software.amazon.awssdk.services.restjsonendpointproviders.endpoints.internal.AwsEndpointProviderUtils; +import software.amazon.awssdk.awscore.internal.endpoints.AwsEndpointProviderUtils; public class AwsEndpointProviderUtilsTest { @Test diff --git a/test/codegen-generated-classes-test/src/test/java/software/amazon/awssdk/services/endpointproviders/EndpointInterceptorTests.java b/test/codegen-generated-classes-test/src/test/java/software/amazon/awssdk/services/endpointproviders/EndpointInterceptorTests.java index d24456368fbd..df32b88193b2 100644 --- a/test/codegen-generated-classes-test/src/test/java/software/amazon/awssdk/services/endpointproviders/EndpointInterceptorTests.java +++ b/test/codegen-generated-classes-test/src/test/java/software/amazon/awssdk/services/endpointproviders/EndpointInterceptorTests.java @@ -177,7 +177,7 @@ public void sync_endpointProviderReturnsHeaders_appendedToExistingRequest() { "TestValue0")))) .hasMessageContaining("stop"); - assertThat(interceptor.context.httpRequest().matchingHeaders("TestHeader")).containsExactly("TestValue", "TestValue0"); + assertThat(interceptor.context.httpRequest().matchingHeaders("TestHeader")).containsExactly("TestValue0", "TestValue"); } @Test @@ -198,7 +198,7 @@ public void async_endpointProviderReturnsHeaders_appendedToExistingRequest() { .join()) .hasMessageContaining("stop"); - assertThat(interceptor.context.httpRequest().matchingHeaders("TestHeader")).containsExactly("TestValue", "TestValue0"); + assertThat(interceptor.context.httpRequest().matchingHeaders("TestHeader")).containsExactly("TestValue0", "TestValue"); } diff --git a/test/codegen-generated-classes-test/src/test/java/software/amazon/awssdk/services/endpointproviders/RequestOverrideEndpointProviderTests.java b/test/codegen-generated-classes-test/src/test/java/software/amazon/awssdk/services/endpointproviders/RequestOverrideEndpointProviderTests.java index 075976d136c8..07dd11000688 100644 --- a/test/codegen-generated-classes-test/src/test/java/software/amazon/awssdk/services/endpointproviders/RequestOverrideEndpointProviderTests.java +++ b/test/codegen-generated-classes-test/src/test/java/software/amazon/awssdk/services/endpointproviders/RequestOverrideEndpointProviderTests.java @@ -18,16 +18,12 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; -import java.net.URI; -import java.util.concurrent.CompletableFuture; -import org.junit.Test; import software.amazon.awssdk.auth.credentials.AwsBasicCredentials; import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider; import software.amazon.awssdk.core.client.config.SdkAdvancedClientOption; import software.amazon.awssdk.core.interceptor.Context; import software.amazon.awssdk.core.interceptor.ExecutionAttributes; import software.amazon.awssdk.core.interceptor.ExecutionInterceptor; -import software.amazon.awssdk.core.interceptor.SdkInternalExecutionAttribute; import software.amazon.awssdk.endpoints.Endpoint; import software.amazon.awssdk.http.SdkHttpRequest; import software.amazon.awssdk.regions.Region; @@ -37,6 +33,7 @@ import software.amazon.awssdk.services.restjsonendpointproviders.RestJsonEndpointProvidersClientBuilder; import software.amazon.awssdk.services.restjsonendpointproviders.endpoints.RestJsonEndpointProvidersEndpointParams; import software.amazon.awssdk.services.restjsonendpointproviders.endpoints.RestJsonEndpointProvidersEndpointProvider; +import org.junit.Test; public class RequestOverrideEndpointProviderTests { @@ -49,34 +46,28 @@ public void sync_endpointOverridden_equals_requestOverride() { .build(); assertThatThrownBy(() -> client.operationWithHostPrefix( - r -> r.overrideConfiguration(o -> o.endpointProvider(new CustomEndpointProvider(Region.AWS_GLOBAL)))) - - ) + r -> r.overrideConfiguration(o -> o.endpointProvider(new CustomEndpointProvider(Region.AWS_GLOBAL))))) .hasMessageContaining("stop"); - Endpoint endpoint = interceptor.executionAttributes().getAttribute(SdkInternalExecutionAttribute.RESOLVED_ENDPOINT); - assertThat(endpoint.url().getHost()).isEqualTo("restjson.aws-global.amazonaws.com"); + + assertThat(interceptor.httpRequest.host()).isEqualTo("restjson.aws-global.amazonaws.com"); } @Test public void sync_endpointOverridden_equals_ClientsWhenNoRequestOverride() { CapturingInterceptor interceptor = new CapturingInterceptor(); - RestJsonEndpointProvidersClient client = syncClientBuilder().endpointProvider(new CustomEndpointProvider(Region.EU_WEST_2)) + RestJsonEndpointProvidersClient client = syncClientBuilder() + .endpointProvider(new CustomEndpointProvider(Region.EU_WEST_2)) .overrideConfiguration(o -> o.addExecutionInterceptor(interceptor) .putAdvancedOption(SdkAdvancedClientOption.DISABLE_HOST_PREFIX_INJECTION, true)) .build(); assertThatThrownBy(() -> client.operationWithHostPrefix(r -> {})).hasMessageContaining("stop"); - Endpoint endpoint = interceptor.executionAttributes().getAttribute(SdkInternalExecutionAttribute.RESOLVED_ENDPOINT); - - assertThat(endpoint.url().getHost()).isEqualTo("restjson.eu-west-2.amazonaws.com"); + assertThat(interceptor.httpRequest.host()).isEqualTo("restjson.eu-west-2.amazonaws.com"); } - @Test public void async_endpointOverridden_equals_requestOverride() { - - CapturingInterceptor interceptor = new CapturingInterceptor(); RestJsonEndpointProvidersAsyncClient client = asyncClientBuilder() .overrideConfiguration(o -> o.addExecutionInterceptor(interceptor) @@ -88,18 +79,14 @@ public void async_endpointOverridden_equals_requestOverride() { ).join()) .hasMessageContaining("stop"); - Endpoint endpoint = interceptor.executionAttributes().getAttribute(SdkInternalExecutionAttribute.RESOLVED_ENDPOINT); - - assertThat(endpoint.url().getHost()).isEqualTo("restjson.aws-global.amazonaws.com"); + assertThat(interceptor.httpRequest.host()).isEqualTo("restjson.aws-global.amazonaws.com"); } - @Test public void async_endpointOverridden_equals_ClientsWhenNoRequestOverride() { - - CapturingInterceptor interceptor = new CapturingInterceptor(); - RestJsonEndpointProvidersAsyncClient client = asyncClientBuilder().endpointProvider(new CustomEndpointProvider(Region.EU_WEST_2)) + RestJsonEndpointProvidersAsyncClient client = asyncClientBuilder() + .endpointProvider(new CustomEndpointProvider(Region.EU_WEST_2)) .overrideConfiguration(o -> o.addExecutionInterceptor(interceptor) .putAdvancedOption(SdkAdvancedClientOption.DISABLE_HOST_PREFIX_INJECTION, true)) .build(); @@ -107,31 +94,25 @@ public void async_endpointOverridden_equals_ClientsWhenNoRequestOverride() { assertThatThrownBy(() -> client.operationWithHostPrefix(r -> {}).join()) .hasMessageContaining("stop"); - Endpoint endpoint = interceptor.executionAttributes().getAttribute(SdkInternalExecutionAttribute.RESOLVED_ENDPOINT); - - assertThat(endpoint.url().getHost()).isEqualTo("restjson.eu-west-2.amazonaws.com"); + assertThat(interceptor.httpRequest.host()).isEqualTo("restjson.eu-west-2.amazonaws.com"); } public static class CapturingInterceptor implements ExecutionInterceptor { - - private ExecutionAttributes executionAttributes; + private SdkHttpRequest httpRequest; @Override - public SdkHttpRequest modifyHttpRequest(Context.ModifyHttpRequest context, ExecutionAttributes executionAttributes) { - this.executionAttributes = executionAttributes; + public void beforeTransmission(Context.BeforeTransmission context, ExecutionAttributes executionAttributes) { + this.httpRequest = context.httpRequest(); throw new CaptureCompletedException("stop"); } - public ExecutionAttributes executionAttributes() { - return executionAttributes; - } - public class CaptureCompletedException extends RuntimeException { CaptureCompletedException(String message) { super(message); } } } + private RestJsonEndpointProvidersClientBuilder syncClientBuilder() { return RestJsonEndpointProvidersClient.builder() .region(Region.US_WEST_2) @@ -142,23 +123,23 @@ private RestJsonEndpointProvidersClientBuilder syncClientBuilder() { private RestJsonEndpointProvidersAsyncClientBuilder asyncClientBuilder() { return RestJsonEndpointProvidersAsyncClient.builder() - .region(Region.US_WEST_2) - .credentialsProvider( - StaticCredentialsProvider.create( - AwsBasicCredentials.create("akid", "skid"))); + .region(Region.US_WEST_2) + .credentialsProvider( + StaticCredentialsProvider.create( + AwsBasicCredentials.create("akid", "skid"))); } + public static class CustomEndpointProvider implements RestJsonEndpointProvidersEndpointProvider { + private final Region region; - class CustomEndpointProvider implements RestJsonEndpointProvidersEndpointProvider{ - final RestJsonEndpointProvidersEndpointProvider endpointProvider; - final Region overridingRegion; - CustomEndpointProvider (Region region){ - this.endpointProvider = RestJsonEndpointProvidersEndpointProvider.defaultProvider(); - this.overridingRegion = region; + CustomEndpointProvider(Region region) { + this.region = region; } + @Override - public CompletableFuture resolveEndpoint(RestJsonEndpointProvidersEndpointParams endpointParams) { - return endpointProvider.resolveEndpoint(endpointParams.toBuilder().region(overridingRegion).build()); + public java.util.concurrent.CompletableFuture resolveEndpoint(RestJsonEndpointProvidersEndpointParams params) { + return RestJsonEndpointProvidersEndpointProvider.defaultProvider() + .resolveEndpoint(params.toBuilder().region(region).build()); } } }