diff --git a/core/src/main/java/org/testcontainers/containers/GenericContainer.java b/core/src/main/java/org/testcontainers/containers/GenericContainer.java index fa5711807e7..86fbffe721d 100644 --- a/core/src/main/java/org/testcontainers/containers/GenericContainer.java +++ b/core/src/main/java/org/testcontainers/containers/GenericContainer.java @@ -1180,6 +1180,13 @@ public SELF withImagePullPolicy(ImagePullPolicy imagePullPolicy) { this.image = this.image.withImagePullPolicy(imagePullPolicy); return self(); } + /** + * Sets the platform to use when pulling the image, for example linux/amd64. + */ + public SELF withImagePlatform(String imagePlatform) { + this.image = this.image.withImagePlatform(imagePlatform); + return self(); + } /** * {@inheritDoc} diff --git a/core/src/main/java/org/testcontainers/images/RemoteDockerImage.java b/core/src/main/java/org/testcontainers/images/RemoteDockerImage.java index 9d669e46b07..f9eb42f81f5 100644 --- a/core/src/main/java/org/testcontainers/images/RemoteDockerImage.java +++ b/core/src/main/java/org/testcontainers/images/RemoteDockerImage.java @@ -1,9 +1,9 @@ package org.testcontainers.images; +import org.jetbrains.annotations.Nullable; import com.github.dockerjava.api.DockerClient; import com.github.dockerjava.api.command.PullImageCmd; import com.github.dockerjava.api.exception.DockerClientException; -import com.github.dockerjava.api.exception.InternalServerErrorException; import com.github.dockerjava.api.exception.NotFoundException; import com.google.common.util.concurrent.Futures; import lombok.AccessLevel; @@ -46,9 +46,14 @@ public class RemoteDockerImage extends LazyFuture { @With ImagePullPolicy imagePullPolicy = PullPolicy.defaultPolicy(); + @With + @Nullable + private String imagePlatform; + @With private ImageNameSubstitutor imageNameSubstitutor = ImageNameSubstitutor.instance(); + @With @ToString.Exclude private DockerClient dockerClient = DockerClientFactory.lazyClient(); @@ -92,6 +97,9 @@ protected final String resolve() { final PullImageCmd pullImageCmd = dockerClient .pullImageCmd(imageName.getUnversionedPart()) .withTag(imageName.getVersionPart()); + if (imagePlatform != null) { + pullImageCmd.withPlatform(imagePlatform); + } final AtomicReference dockerImageName = new AtomicReference<>(); // The following poll interval in ms: 50, 100, 200, 400, 800.... @@ -142,7 +150,7 @@ private Callable tryImagePullCommand( pullImage(pullImageCmd, logger); dockerImageName.set(imageName.asCanonicalNameString()); return true; - } catch (InterruptedException | InternalServerErrorException e) { + } catch (InterruptedException e) { // these classes of exception often relate to timeout/connection errors so should be retried lastFailure.set(e); logger.warn( diff --git a/core/src/test/java/org/testcontainers/images/RemoteDockerImageTest.java b/core/src/test/java/org/testcontainers/images/RemoteDockerImageTest.java index 156df3ae7d3..6dafd0fef79 100644 --- a/core/src/test/java/org/testcontainers/images/RemoteDockerImageTest.java +++ b/core/src/test/java/org/testcontainers/images/RemoteDockerImageTest.java @@ -1,5 +1,12 @@ package org.testcontainers.images; +import com.github.dockerjava.api.DockerClient; +import com.github.dockerjava.api.command.PullImageCmd; +import org.testcontainers.images.TimeLimitedLoggedPullImageResultCallback; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Timeout; import org.testcontainers.utility.Base58; @@ -76,4 +83,24 @@ protected String resolve() { imageNameFuture.get(); assertThat(remoteDockerImage.toString()).contains("imageName=" + imageName); } + @Test + void passesExplicitPlatformToPullImageCommand() throws Exception { + DockerClient dockerClient = mock(DockerClient.class); + PullImageCmd pullImageCmd = mock(PullImageCmd.class); + + when(dockerClient.pullImageCmd("test/image")).thenReturn(pullImageCmd); + when(pullImageCmd.withTag("latest")).thenReturn(pullImageCmd); + when(pullImageCmd.withPlatform("linux/amd64")).thenReturn(pullImageCmd); + TimeLimitedLoggedPullImageResultCallback callback = mock(TimeLimitedLoggedPullImageResultCallback.class); + when(pullImageCmd.exec(any(TimeLimitedLoggedPullImageResultCallback.class))).thenReturn(callback); + when(callback.awaitCompletion()).thenReturn(callback); + + RemoteDockerImage remoteDockerImage = new RemoteDockerImage(DockerImageName.parse("test/image:latest")) + .withImagePlatform("linux/amd64"); + + remoteDockerImage.withDockerClient(dockerClient).get(); + + verify(pullImageCmd).withPlatform("linux/amd64"); + } + }