Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
import java.util.concurrent.CancellationException;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledFuture;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
Expand All @@ -72,16 +73,23 @@ class ChannelPool extends ManagedChannel {

private final ChannelPoolSettings settings;
private final ChannelFactory channelFactory;
private final ScheduledExecutorService executor;
private final ScheduledExecutorService backgroundExecutor;
private final boolean shouldShutdownExecutor;

private ScheduledFuture<?> refreshFuture = null;
private ScheduledFuture<?> resizeFuture = null;

private final Object entryWriteLock = new Object();
@VisibleForTesting final AtomicReference<ImmutableList<Entry>> entries = new AtomicReference<>();
private final AtomicInteger indexTicker = new AtomicInteger();
private final String authority;

static ChannelPool create(ChannelPoolSettings settings, ChannelFactory channelFactory)
static ChannelPool create(
ChannelPoolSettings settings,
ChannelFactory channelFactory,
@Nullable ScheduledExecutorService backgroundExecutor)
throws IOException {
return new ChannelPool(settings, channelFactory, Executors.newSingleThreadScheduledExecutor());
return new ChannelPool(settings, channelFactory, backgroundExecutor);
}

/**
Expand All @@ -95,7 +103,7 @@ static ChannelPool create(ChannelPoolSettings settings, ChannelFactory channelFa
ChannelPool(
ChannelPoolSettings settings,
ChannelFactory channelFactory,
ScheduledExecutorService executor)
@Nullable ScheduledExecutorService executor)
throws IOException {
this.settings = settings;
this.channelFactory = channelFactory;
Expand All @@ -108,21 +116,30 @@ static ChannelPool create(ChannelPoolSettings settings, ChannelFactory channelFa

entries.set(initialListBuilder.build());
authority = entries.get().get(0).channel.authority();
this.executor = executor;

if (executor == null) {
this.backgroundExecutor = Executors.newSingleThreadScheduledExecutor();
this.shouldShutdownExecutor = true;
} else {
this.backgroundExecutor = executor;
this.shouldShutdownExecutor = false;
}

if (!settings.isStaticSize()) {
executor.scheduleAtFixedRate(
this::resizeSafely,
ChannelPoolSettings.RESIZE_INTERVAL.getSeconds(),
ChannelPoolSettings.RESIZE_INTERVAL.getSeconds(),
TimeUnit.SECONDS);
resizeFuture =
backgroundExecutor.scheduleAtFixedRate(
this::resizeSafely,
ChannelPoolSettings.RESIZE_INTERVAL.getSeconds(),
ChannelPoolSettings.RESIZE_INTERVAL.getSeconds(),
TimeUnit.SECONDS);
}
if (settings.isPreemptiveRefreshEnabled()) {
executor.scheduleAtFixedRate(
this::refreshSafely,
REFRESH_PERIOD.getSeconds(),
REFRESH_PERIOD.getSeconds(),
TimeUnit.SECONDS);
refreshFuture =
backgroundExecutor.scheduleAtFixedRate(
this::refreshSafely,
REFRESH_PERIOD.getSeconds(),
REFRESH_PERIOD.getSeconds(),
TimeUnit.SECONDS);
}
}

Expand Down Expand Up @@ -157,9 +174,16 @@ public ManagedChannel shutdown() {
for (Entry entry : localEntries) {
entry.channel.shutdown();
}
if (executor != null) {
if (shouldShutdownExecutor) {
// shutdownNow will cancel scheduled tasks
executor.shutdownNow();
backgroundExecutor.shutdownNow();
} else {
if (resizeFuture != null) {
resizeFuture.cancel(false);
}
if (refreshFuture != null) {
refreshFuture.cancel(false);
}
}
return this;
}
Expand All @@ -173,7 +197,7 @@ public boolean isShutdown() {
return false;
}
}
return executor == null || executor.isShutdown();
return !shouldShutdownExecutor || backgroundExecutor.isShutdown();
}

/** {@inheritDoc} */
Expand All @@ -186,7 +210,7 @@ public boolean isTerminated() {
}
}

return executor == null || executor.isTerminated();
return !shouldShutdownExecutor || backgroundExecutor.isTerminated();
}

/** {@inheritDoc} */
Expand All @@ -198,8 +222,15 @@ public ManagedChannel shutdownNow() {
for (Entry entry : localEntries) {
entry.channel.shutdownNow();
}
if (executor != null) {
executor.shutdownNow();
if (shouldShutdownExecutor) {
backgroundExecutor.shutdownNow();
} else {
if (resizeFuture != null) {
resizeFuture.cancel(true);
}
if (refreshFuture != null) {
refreshFuture.cancel(true);
}
}
return this;
}
Expand All @@ -216,9 +247,9 @@ public boolean awaitTermination(long timeout, TimeUnit unit) throws InterruptedE
}
entry.channel.awaitTermination(awaitTimeNanos, TimeUnit.NANOSECONDS);
}
if (executor != null) {
if (shouldShutdownExecutor) {
long awaitTimeNanos = endTimeNanos - System.nanoTime();
executor.awaitTermination(awaitTimeNanos, TimeUnit.NANOSECONDS);
backgroundExecutor.awaitTermination(awaitTimeNanos, TimeUnit.NANOSECONDS);
}
return isTerminated();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ public final class InstantiatingGrpcChannelProvider implements TransportChannelP

private final int processorCount;
private final Executor executor;
@Nullable private final ScheduledExecutorService backgroundExecutor;
private final HeaderProvider headerProvider;
private final boolean useS2A;
private final String endpoint;
Expand Down Expand Up @@ -181,6 +182,7 @@ public enum HardBoundTokenTypes {
private InstantiatingGrpcChannelProvider(Builder builder) {
this.processorCount = builder.processorCount;
this.executor = builder.executor;
this.backgroundExecutor = builder.backgroundExecutor;
this.headerProvider = builder.headerProvider;
this.useS2A = builder.useS2A;
this.endpoint = builder.endpoint;
Expand Down Expand Up @@ -356,7 +358,9 @@ private TransportChannel createChannel() throws IOException {
return GrpcTransportChannel.newBuilder()
.setManagedChannel(
ChannelPool.create(
channelPoolSettings, InstantiatingGrpcChannelProvider.this::createSingleChannel))
channelPoolSettings,
InstantiatingGrpcChannelProvider.this::createSingleChannel,
backgroundExecutor))
.setDirectPath(this.canUseDirectPath())
.build();
}
Expand Down Expand Up @@ -839,6 +843,11 @@ public ChannelPoolSettings getChannelPoolSettings() {
return channelPoolSettings;
}

/** Gets the background executor for channel refresh and resize. */
public ScheduledExecutorService getBackgroundExecutor() {
return backgroundExecutor;
}

@Override
public boolean shouldAutoClose() {
return true;
Expand All @@ -855,6 +864,7 @@ public static Builder newBuilder() {
public static final class Builder {
@Deprecated private int processorCount;
private Executor executor;
private ScheduledExecutorService backgroundExecutor;
private HeaderProvider headerProvider;
private String endpoint;
private String mtlsEndpoint;
Expand Down Expand Up @@ -891,6 +901,7 @@ private Builder() {
private Builder(InstantiatingGrpcChannelProvider provider) {
this.processorCount = provider.processorCount;
this.executor = provider.executor;
this.backgroundExecutor = provider.backgroundExecutor;
this.headerProvider = provider.headerProvider;
this.endpoint = provider.endpoint;
this.useS2A = provider.useS2A;
Expand Down Expand Up @@ -950,6 +961,19 @@ public Builder setExecutorProvider(ExecutorProvider executorProvider) {
return setExecutor((Executor) executorProvider.getExecutor());
}

/**
* Sets the background executor for this TransportChannelProvider. The life cycle of the
* executor should be managed by the caller.
*
* <p>This is optional. The background executor is used for channel refresh and channel resize
* on {@link ChannelPool}. This allows us to reuse the same executor for other long running
* operations.
*/
public Builder setBackgroundExecutor(ScheduledExecutorService executor) {
this.backgroundExecutor = executor;
return this;
}

/**
* Sets the HeaderProvider for this TransportChannelProvider.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@
import com.google.api.gax.rpc.UnaryCallable;
import com.google.api.gax.util.FakeLogHandler;
import com.google.auth.Credentials;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
import com.google.type.Color;
Expand Down Expand Up @@ -82,9 +81,10 @@ class ChannelPoolTest {

@AfterEach
void cleanup() throws InterruptedException {
Preconditions.checkNotNull(pool, "Channel pool was never created");
pool.shutdown();
pool.awaitTermination(DEFAULT_AWAIT_TERMINATION_SEC, TimeUnit.SECONDS);
if (pool != null) {
pool.shutdown();
pool.awaitTermination(DEFAULT_AWAIT_TERMINATION_SEC, TimeUnit.SECONDS);
}
}

@Test
Expand All @@ -97,7 +97,8 @@ void testAuthority() throws IOException {
pool =
ChannelPool.create(
ChannelPoolSettings.staticallySized(2),
new FakeChannelFactory(Arrays.asList(sub1, sub2)));
new FakeChannelFactory(Arrays.asList(sub1, sub2)),
null);
assertThat(pool.authority()).isEqualTo("myAuth");
}

Expand All @@ -111,7 +112,9 @@ void testRoundRobin() throws IOException {
ArrayList<ManagedChannel> channels = Lists.newArrayList(sub1, sub2);
pool =
ChannelPool.create(
ChannelPoolSettings.staticallySized(channels.size()), new FakeChannelFactory(channels));
ChannelPoolSettings.staticallySized(channels.size()),
new FakeChannelFactory(channels),
null);

verifyTargetChannel(pool, channels, sub1);
verifyTargetChannel(pool, channels, sub2);
Expand Down Expand Up @@ -168,7 +171,8 @@ void ensureEvenDistribution() throws InterruptedException, IOException {
pool =
ChannelPool.create(
ChannelPoolSettings.staticallySized(numChannels),
new FakeChannelFactory(Arrays.asList(channels)));
new FakeChannelFactory(Arrays.asList(channels)),
null);

int numThreads = 20;
final int numPerThread = 1000;
Expand Down Expand Up @@ -204,7 +208,8 @@ void channelPrimerShouldCallPoolConstruction() throws IOException {
ChannelPoolSettings.staticallySized(2).toBuilder()
.setPreemptiveRefreshEnabled(true)
.build(),
new FakeChannelFactory(Arrays.asList(channel1, channel2), mockChannelPrimer));
new FakeChannelFactory(Arrays.asList(channel1, channel2), mockChannelPrimer),
null);
Mockito.verify(mockChannelPrimer, Mockito.times(2))
.primeChannel(Mockito.any(ManagedChannel.class));
}
Expand Down Expand Up @@ -266,7 +271,7 @@ void callShouldCompleteAfterCreation() throws IOException {
ManagedChannel replacementChannel = Mockito.mock(ManagedChannel.class);
FakeChannelFactory channelFactory =
new FakeChannelFactory(ImmutableList.of(underlyingChannel, replacementChannel));
pool = ChannelPool.create(ChannelPoolSettings.staticallySized(1), channelFactory);
pool = ChannelPool.create(ChannelPoolSettings.staticallySized(1), channelFactory, null);

// create a mock call when new call comes to the underlying channel
MockClientCall<String, Integer> mockClientCall = new MockClientCall<>(1, Status.OK);
Expand Down Expand Up @@ -315,7 +320,7 @@ void callShouldCompleteAfterStarted() throws IOException {

FakeChannelFactory channelFactory =
new FakeChannelFactory(ImmutableList.of(underlyingChannel, replacementChannel));
pool = ChannelPool.create(ChannelPoolSettings.staticallySized(1), channelFactory);
pool = ChannelPool.create(ChannelPoolSettings.staticallySized(1), channelFactory, null);

// create a mock call when new call comes to the underlying channel
MockClientCall<String, Integer> mockClientCall = new MockClientCall<>(1, Status.OK);
Expand Down Expand Up @@ -360,7 +365,7 @@ void channelShouldShutdown() throws IOException {

FakeChannelFactory channelFactory =
new FakeChannelFactory(ImmutableList.of(underlyingChannel, replacementChannel));
pool = ChannelPool.create(ChannelPoolSettings.staticallySized(1), channelFactory);
pool = ChannelPool.create(ChannelPoolSettings.staticallySized(1), channelFactory, null);

// create a mock call when new call comes to the underlying channel
MockClientCall<String, Integer> mockClientCall = new MockClientCall<>(1, Status.OK);
Expand Down Expand Up @@ -623,7 +628,7 @@ void testReleasingClientCallCancelEarly() throws IOException {
Mockito.when(fakeChannel.newCall(Mockito.any(), Mockito.any())).thenReturn(mockClientCall);
ChannelPoolSettings channelPoolSettings = ChannelPoolSettings.staticallySized(1);
ChannelFactory factory = new FakeChannelFactory(ImmutableList.of(fakeChannel));
pool = ChannelPool.create(channelPoolSettings, factory);
pool = ChannelPool.create(channelPoolSettings, factory, null);

EndpointContext endpointContext = Mockito.mock(EndpointContext.class);
Mockito.doNothing()
Expand Down Expand Up @@ -681,7 +686,7 @@ void testDoubleRelease() throws Exception {
ChannelPoolSettings channelPoolSettings = ChannelPoolSettings.staticallySized(1);
ChannelFactory factory = new FakeChannelFactory(ImmutableList.of(fakeChannel));

pool = ChannelPool.create(channelPoolSettings, factory);
pool = ChannelPool.create(channelPoolSettings, factory, null);

EndpointContext endpointContext = Mockito.mock(EndpointContext.class);
Mockito.doNothing()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,8 @@ void testAffinity() throws IOException {
Channel pool =
ChannelPool.create(
ChannelPoolSettings.staticallySized(2),
new FakeChannelFactory(Arrays.asList(channel0, channel1)));
new FakeChannelFactory(Arrays.asList(channel0, channel1)),
null);
GrpcCallContext context = defaultCallContext.withChannel(pool);

ClientCall<Color, Money> gotCallA =
Expand Down
Loading