From 335fbdec99f9c0fe0d8f5c4652a3ec750230a19a Mon Sep 17 00:00:00 2001 From: Sam Whittle Date: Fri, 12 Jun 2026 12:53:54 +0200 Subject: [PATCH] [FnApi Java] Add support for separate named data streams to provide bundle isolation. This is advertised to the runner via a new NAMED_DATA_STREAMS protocol capability. The runner is then free to assign bundles to named data streams as it chooses, in order to isolate the processing of bundles from each other. Instead of a single data stream from the SDK, the SDK will create a data stream for each name. The benefit of doing so is isolation: the multiplexing currently performed on received data stream messages allows a slow bundle to fill up buffers and block the shared stream. With separate named streams, bundles on other data streams have separate gRPC flow control from the blocked stream and are not affected. --- .../model/fn_execution/v1/beam_fn_api.proto | 38 ++++- .../model/pipeline/v1/beam_runner_api.proto | 4 + .../fnexecution/control/SdkHarnessClient.java | 2 +- .../fnexecution/data/FnDataService.java | 3 +- .../fnexecution/data/GrpcDataService.java | 13 +- .../fnexecution/data/GrpcDataServiceTest.java | 2 +- .../fn/data/BeamFnDataGrpcMultiplexer.java | 2 +- .../fn/data/BeamFnDataOutboundAggregator.java | 131 ++++++++++++----- .../sdk/util/construction/Environments.java | 1 + .../BeamFnDataOutboundAggregatorTest.java | 135 ++++++++++-------- .../org/apache/beam/fn/harness/FnHarness.java | 49 ++++--- .../harness/control/ProcessBundleHandler.java | 50 ++++--- .../fn/harness/data/BeamFnDataClient.java | 28 +--- .../fn/harness/data/BeamFnDataGrpcClient.java | 91 ++++++++---- .../fn/harness/BeamFnDataWriteRunnerTest.java | 44 +++--- .../beam/fn/harness/FnApiDoFnRunnerTest.java | 2 +- .../PTransformRunnerFactoryTestContext.java | 12 +- .../control/ProcessBundleHandlerTest.java | 47 +++--- .../data/BeamFnDataGrpcClientTest.java | 41 +++--- 19 files changed, 412 insertions(+), 283 deletions(-) diff --git
a/model/fn-execution/src/main/proto/org/apache/beam/model/fn_execution/v1/beam_fn_api.proto b/model/fn-execution/src/main/proto/org/apache/beam/model/fn_execution/v1/beam_fn_api.proto index ecef3f2e7a94..80a0fa6f7f28 100644 --- a/model/fn-execution/src/main/proto/org/apache/beam/model/fn_execution/v1/beam_fn_api.proto +++ b/model/fn-execution/src/main/proto/org/apache/beam/model/fn_execution/v1/beam_fn_api.proto @@ -120,6 +120,9 @@ message RemoteGrpcPort { service BeamFnControl { // Instructions sent by the runner to the SDK requesting different types // of work. + // + // Header metadata has the specified keys pairs: + // - "worker_id": the id of the sdk rpc Control( // A stream of responses to instructions the SDK was asked to be // performed. @@ -130,6 +133,9 @@ service BeamFnControl { // Used to get the full process bundle descriptors for bundles one // is asked to process. + // + // Header metadata has the specified keys pairs: + // - "worker_id": the id of the sdk rpc GetProcessBundleDescriptor(GetProcessBundleDescriptorRequest) returns ( ProcessBundleDescriptor) {} } @@ -416,14 +422,22 @@ message ProcessBundleRequest { // at https://s.apache.org/beam-fn-api-control-data-embedding. Elements elements = 3; - // indicates that the runner has no stare for the keys in this bundle + // Indicates that the runner has no state for the keys in this bundle // so SDk can safely begin stateful processing with a locally-generated - // initial empty state + // initial empty state. bool has_no_state = 4; - // indicates that the runner will never process another bundle for the keys + // Indicates that the runner will never process another bundle for the keys // in this bundle so state need not be included in the bundle commit. bool only_bundle_for_keys = 5; + + // (Optional) If non-empty, the ID of the data stream to use for all data + // requests related to this bundle. See comments at BeamFnData.Data for + // more details. 
+ // + // The runner should only populate this field if the sdk advertises the + // beam:protocol:named_data_streams:v1 capability. + string data_stream_id = 6; } message ProcessBundleResponse { @@ -834,7 +848,15 @@ message Elements { // Stable service BeamFnData { - // Used to send data between harnesses. + // Used to send data between harnesses. Sdks default to using an unnamed data stream + // (without "data_stream_id" header value) for bundles unless the runner requests another named stream to be + // used for a bundle. SDKs can advertise that they support named data streams with the capability + // `beam:protocol:named_data_streams:v1`. + // + // Header metadata has the specified keys pairs: + // - "worker_id": value is the id of the sdk + // - "data_stream_id": value is the id of the data stream, distinguishing it from other data streams from the same + // sdk. This field should only be populated if requested in a received ProcessBundleRequest from the runner. rpc Data( // A stream of data representing input. stream Elements) @@ -900,6 +922,9 @@ message StateResponse { service BeamFnState { // Used to get/append/clear state stored by the runner on behalf of the SDK. + // + // Header metadata has the specified keys pairs: + // - "worker_id": the id of the sdk rpc State( // A stream of state instructions requested of the runner. stream StateRequest) @@ -1295,6 +1320,9 @@ message LogControl {} service BeamFnLogging { // Allows for the SDK to emit log entries which the runner can // associate with the active job. + // + // Header metadata has the specified keys pairs: + // - "worker_id": the id of the sdk rpc Logging( // A stream of log entries batched into lists emitted by the SDK harness. stream LogEntry.List) @@ -1356,6 +1384,8 @@ message WorkerStatusResponse { // API for SDKs to report debug-related statuses to runner during pipeline execution. 
service BeamFnWorkerStatus { + // Header metadata has the specified keys pairs: + // - "worker_id": the id of the sdk rpc WorkerStatus (stream WorkerStatusResponse) returns (stream WorkerStatusRequest) {} } diff --git a/model/pipeline/src/main/proto/org/apache/beam/model/pipeline/v1/beam_runner_api.proto b/model/pipeline/src/main/proto/org/apache/beam/model/pipeline/v1/beam_runner_api.proto index 67df8b9e8003..5824c9bf4b73 100644 --- a/model/pipeline/src/main/proto/org/apache/beam/model/pipeline/v1/beam_runner_api.proto +++ b/model/pipeline/src/main/proto/org/apache/beam/model/pipeline/v1/beam_runner_api.proto @@ -1689,6 +1689,10 @@ message StandardProtocols { // Indicates whether the SDK supports multimap state. MULTIMAP_STATE = 12 [(beam_urn) = "beam:protocol:multimap_state:v1"]; + + // Indicates whether the SDK supports data stream ids being requested by the runner in + // ProcessBundleRequests. + NAMED_DATA_STREAMS = 13 [(beam_urn) = "beam:protocol:named_data_streams:v1"]; } } diff --git a/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/SdkHarnessClient.java b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/SdkHarnessClient.java index 682c45e30795..704d298a195d 100644 --- a/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/SdkHarnessClient.java +++ b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/SdkHarnessClient.java @@ -298,7 +298,7 @@ public ActiveBundle newBundle( ImmutableMap.Builder> receiverBuilder = ImmutableMap.builder(); BeamFnDataOutboundAggregator beamFnDataOutboundAggregator = - fnApiDataService.createOutboundAggregator(() -> bundleId, false); + fnApiDataService.createOutboundAggregator(bundleId, false); for (RemoteInputDestination remoteInput : remoteInputs) { LogicalEndpoint endpoint = LogicalEndpoint.data(bundleId, remoteInput.getPTransformId()); receiverBuilder.put( diff --git 
a/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/data/FnDataService.java b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/data/FnDataService.java index 7c5f110eab28..657ec74553bc 100644 --- a/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/data/FnDataService.java +++ b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/data/FnDataService.java @@ -17,7 +17,6 @@ */ package org.apache.beam.runners.fnexecution.data; -import java.util.function.Supplier; import org.apache.beam.model.fnexecution.v1.BeamFnApi.Elements; import org.apache.beam.sdk.fn.data.BeamFnDataOutboundAggregator; import org.apache.beam.sdk.fn.data.CloseableFnDataReceiver; @@ -69,5 +68,5 @@ public interface FnDataService { *

The returned aggregator is not thread safe. */ BeamFnDataOutboundAggregator createOutboundAggregator( - Supplier processBundleRequestIdSupplier, boolean collectElementsIfNoFlushes); + String processBundleId, boolean collectElementsIfNoFlushes); } diff --git a/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/data/GrpcDataService.java b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/data/GrpcDataService.java index d4e45c8ccf82..a3a8c3244044 100644 --- a/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/data/GrpcDataService.java +++ b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/data/GrpcDataService.java @@ -23,7 +23,6 @@ import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; -import java.util.function.Supplier; import org.apache.beam.model.fnexecution.v1.BeamFnApi; import org.apache.beam.model.fnexecution.v1.BeamFnApi.Elements; import org.apache.beam.model.fnexecution.v1.BeamFnDataGrpc; @@ -175,13 +174,13 @@ public void unregisterReceiver(String instructionId) { @Override public BeamFnDataOutboundAggregator createOutboundAggregator( - Supplier processBundleRequestIdSupplier, boolean collectElementsIfNoFlushes) { + String instructionId, boolean collectElementsIfNoFlushes) { try { - return new BeamFnDataOutboundAggregator( - options, - processBundleRequestIdSupplier, - connectedClient.get(3, TimeUnit.MINUTES).getOutboundObserver(), - collectElementsIfNoFlushes); + BeamFnDataOutboundAggregator aggregator = + new BeamFnDataOutboundAggregator(options, collectElementsIfNoFlushes); + aggregator.prepareForInstruction( + instructionId, connectedClient.get(3, TimeUnit.MINUTES).getOutboundObserver()); + return aggregator; } catch (InterruptedException e) { Thread.currentThread().interrupt(); throw new RuntimeException(e); diff --git 
a/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/data/GrpcDataServiceTest.java b/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/data/GrpcDataServiceTest.java index f84467077501..363367f1087f 100644 --- a/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/data/GrpcDataServiceTest.java +++ b/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/data/GrpcDataServiceTest.java @@ -102,7 +102,7 @@ public void testMessageReceivedBySingleClientWhenThereAreMultipleClients() throw for (int i = 0; i < 3; ++i) { final String instructionId = Integer.toString(i); BeamFnDataOutboundAggregator aggregator = - service.createOutboundAggregator(() -> instructionId, false); + service.createOutboundAggregator(instructionId, false); aggregator.start(); FnDataReceiver> consumer = aggregator.registerOutputDataLocation(TRANSFORM_ID, CODER); diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/fn/data/BeamFnDataGrpcMultiplexer.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/fn/data/BeamFnDataGrpcMultiplexer.java index 8fec8b455cce..0b9d6adab4f0 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/fn/data/BeamFnDataGrpcMultiplexer.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/fn/data/BeamFnDataGrpcMultiplexer.java @@ -63,7 +63,7 @@ public class BeamFnDataGrpcMultiplexer implements AutoCloseable { private final Cache poisonedInstructionIds; private static class PoisonedException extends RuntimeException { - public PoisonedException() { + private PoisonedException() { super("Instruction poisoned"); } }; diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/fn/data/BeamFnDataOutboundAggregator.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/fn/data/BeamFnDataOutboundAggregator.java index 9b9603706b48..04c7e6ef96b4 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/fn/data/BeamFnDataOutboundAggregator.java +++ 
b/sdks/java/core/src/main/java/org/apache/beam/sdk/fn/data/BeamFnDataOutboundAggregator.java @@ -17,6 +17,9 @@ */ package org.apache.beam.sdk.fn.data; +import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkNotNull; +import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkState; + import java.io.IOException; import java.util.Collections; import java.util.HashMap; @@ -28,7 +31,6 @@ import java.util.concurrent.Executors; import java.util.concurrent.ScheduledFuture; import java.util.concurrent.TimeUnit; -import java.util.function.Supplier; import javax.annotation.Nullable; import javax.annotation.concurrent.NotThreadSafe; import org.apache.beam.model.fnexecution.v1.BeamFnApi.Elements; @@ -53,10 +55,7 @@ *

The default time-based buffer threshold can be overridden by specifying the experiment {@code * data_buffer_time_limit_ms=} */ -@SuppressWarnings({ - "nullness" // TODO(https://github.com/apache/beam/issues/20497) -}) -// The calling thread that invokes sendBufferedDataAndFinishOutboundStreams synchronizes on +// The calling thread that invokes sendOrCollectBufferedDataAndFinishOutboundStreams synchronizes on // flushLock effectively making the periodic flushing no longer read or mutate hasFlushedForBundle // and allowing the calling thread to read and mutate hasFlushedForBundle safely without needing to // create another memory barrier. Also note that flush is always invoked when synchronizing on @@ -72,31 +71,54 @@ public class BeamFnDataOutboundAggregator { private static final Logger LOG = LoggerFactory.getLogger(BeamFnDataOutboundAggregator.class); private final int sizeLimit; private final long timeLimit; - private final Supplier processBundleRequestIdSupplier; - @VisibleForTesting final Map> outputDataReceivers; - @VisibleForTesting final Map> outputTimersReceivers; - private final StreamObserver outboundObserver; + // The instructionId is set between prepareForInstruction and finishInstruction/discard. 
+ private @Nullable String instructionId = null; + @VisibleForTesting final Map> outputDataReceivers = new HashMap<>(); + @VisibleForTesting final Map> outputTimersReceivers = new HashMap<>(); + @Nullable private StreamObserver outboundObserver; @Nullable @VisibleForTesting ScheduledFuture flushFuture; - private long bytesWrittenSinceFlush; - private final Object flushLock; + private long bytesWrittenSinceFlush = 0; + private final Object flushLock = new Object(); private final boolean collectElementsIfNoFlushes; - private boolean hasFlushedForBundle; + private boolean hasFlushedForBundle = false; - public BeamFnDataOutboundAggregator( - PipelineOptions options, - Supplier processBundleRequestIdSupplier, - StreamObserver outboundObserver, - boolean collectElementsIfNoFlushes) { + public BeamFnDataOutboundAggregator(PipelineOptions options, boolean collectElementsIfNoFlushes) { this.sizeLimit = getSizeLimit(options); this.timeLimit = getTimeLimit(options); this.collectElementsIfNoFlushes = collectElementsIfNoFlushes; - this.outputDataReceivers = new HashMap<>(); - this.outputTimersReceivers = new HashMap<>(); - this.outboundObserver = outboundObserver; - this.processBundleRequestIdSupplier = processBundleRequestIdSupplier; - this.bytesWrittenSinceFlush = 0L; - this.flushLock = new Object(); - this.hasFlushedForBundle = false; + } + + public void prepareForInstruction( + String instructionId, StreamObserver outboundObserver) { + if (timeLimit > 0) { + synchronized (flushLock) { + checkState(this.instructionId == null && this.outboundObserver == null); + this.instructionId = instructionId; + this.outboundObserver = outboundObserver; + } + } else { + checkState(this.instructionId == null && this.outboundObserver == null); + this.instructionId = instructionId; + this.outboundObserver = outboundObserver; + } + } + + public void finishInstruction() { + if (flushFuture != null) { + synchronized (flushLock) { + checkState( + this.instructionId != null && 
this.outboundObserver != null, + "instruction was not started or previously completed"); + checkState(bytesWrittenSinceFlush == 0, "bytes were not flushed for instruction"); + this.instructionId = null; + this.outboundObserver = null; + } + } else { + checkState(this.instructionId != null && this.outboundObserver != null); + checkState(bytesWrittenSinceFlush == 0, "bytes were not flushed for instruction"); + this.instructionId = null; + this.outboundObserver = null; + } } /** Starts the flushing daemon thread if data_buffer_time_limit_ms is set. */ @@ -166,7 +188,7 @@ private void flushInternal() { } Elements.Builder elements = convertBufferForTransmission(); if (elements.getDataCount() > 0 || elements.getTimersCount() > 0) { - outboundObserver.onNext(elements.build()); + checkNotNull(outboundObserver).onNext(elements.build()); } hasFlushedForBundle = true; } @@ -177,10 +199,15 @@ private void flushInternal() { * collectElementsIfNoFlushes=true, and there was no previous flush in this bundle, otherwise * returns null. 
*/ + @Nullable public Elements sendOrCollectBufferedDataAndFinishOutboundStreams() { if (outputTimersReceivers.isEmpty() && outputDataReceivers.isEmpty()) { return null; } + String instructionId = + checkNotNull( + this.instructionId, + "This method should only be called between prepareForInstruction and finishInstruction"); Elements.Builder bufferedElements; if (timeLimit > 0) { synchronized (flushLock) { @@ -191,14 +218,14 @@ public Elements sendOrCollectBufferedDataAndFinishOutboundStreams() { } LOG.debug( "Closing streams for instruction {} and outbound data {} and timers {}.", - processBundleRequestIdSupplier.get(), + instructionId, outputDataReceivers, outputTimersReceivers); for (Map.Entry> entry : outputDataReceivers.entrySet()) { String pTransformId = entry.getKey(); bufferedElements .addDataBuilder() - .setInstructionId(processBundleRequestIdSupplier.get()) + .setInstructionId(instructionId) .setTransformId(pTransformId) .setIsLast(true); entry.getValue().resetStats(); @@ -207,35 +234,60 @@ public Elements sendOrCollectBufferedDataAndFinishOutboundStreams() { TimerEndpoint timerKey = entry.getKey(); bufferedElements .addTimersBuilder() - .setInstructionId(processBundleRequestIdSupplier.get()) + .setInstructionId(instructionId) .setTransformId(timerKey.pTransformId) .setTimerFamilyId(timerKey.timerFamilyId) .setIsLast(true); entry.getValue().resetStats(); } + // This is the end of the bundle so we reset state to prepare for future bundles. if (collectElementsIfNoFlushes && !hasFlushedForBundle) { return bufferedElements.build(); } - outboundObserver.onNext(bufferedElements.build()); - // This is now at the end of a bundle, so we reset hasFlushedForBundle to prepare for new - // bundles. + checkNotNull(outboundObserver).onNext(bufferedElements.build()); hasFlushedForBundle = false; return null; } // Send the elements to the StreamObserver associated with this aggregator. 
public void sendElements(Elements elements) { - outboundObserver.onNext(elements); + if (timeLimit > 0) { + synchronized (flushLock) { + checkNotNull(outboundObserver).onNext(elements); + } + } else { + checkNotNull(outboundObserver).onNext(elements); + } } + // Prepares for discarding the aggregator without preserving its output or + // preparing it for reuse. public void discard() { - if (flushFuture != null) { - flushFuture.cancel(true); + if (timeLimit > 0) { + // Short-circuit the possibly concurrently running flush. + synchronized (flushLock) { + bytesWrittenSinceFlush = 0L; + finishInstruction(); + } + if (flushFuture != null) { + flushFuture.cancel(false); + } + } else { + bytesWrittenSinceFlush = 0L; + finishInstruction(); } } private Elements.Builder convertBufferForTransmission() { Elements.Builder bufferedElements = Elements.newBuilder(); + if (bytesWrittenSinceFlush == 0) { + return bufferedElements; + } + bytesWrittenSinceFlush = 0L; + String instructionId = + checkNotNull( + this.instructionId, + "This method should only be called between prepareForInstruction and finishInstruction"); for (Map.Entry> entry : outputDataReceivers.entrySet()) { if (!entry.getValue().hasBufferedOutput()) { continue; @@ -243,7 +295,7 @@ private Elements.Builder convertBufferForTransmission() { ByteString bytes = entry.getValue().toByteStringAndResetBuffer(); bufferedElements .addDataBuilder() - .setInstructionId(processBundleRequestIdSupplier.get()) + .setInstructionId(instructionId) .setTransformId(entry.getKey()) .setData(bytes); } @@ -254,12 +306,11 @@ private Elements.Builder convertBufferForTransmission() { ByteString bytes = entry.getValue().toByteStringAndResetBuffer(); bufferedElements .addTimersBuilder() - .setInstructionId(processBundleRequestIdSupplier.get()) + .setInstructionId(instructionId) .setTransformId(entry.getKey().pTransformId) .setTimerFamilyId(entry.getKey().timerFamilyId) .setTimers(bytes); } - bytesWrittenSinceFlush = 0L; return bufferedElements; 
} @@ -277,7 +328,7 @@ void flush() { /** Check if the flush thread failed with an exception. */ private void checkFlushThreadException() throws IOException { - if (timeLimit > 0 && flushFuture.isDone()) { + if (flushFuture != null && flushFuture.isDone()) { try { flushFuture.get(); throw new IOException("Periodic flushing thread finished unexpectedly."); @@ -353,10 +404,12 @@ public void accept(T input) throws Exception { } } + @VisibleForTesting public long getByteCount() { return perBundleByteCount; } + @VisibleForTesting public long getElementCount() { return perBundleElementCount; } @@ -392,7 +445,7 @@ public TimerEndpoint(String pTransformId, String timerFamilyId) { } @Override - public boolean equals(Object o) { + public boolean equals(@Nullable Object o) { if (this == o) { return true; } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/util/construction/Environments.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/construction/Environments.java index 969bda88d07f..c3b1a7a5235e 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/util/construction/Environments.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/construction/Environments.java @@ -522,6 +522,7 @@ public static Set getJavaCapabilities() { capabilities.add(BeamUrns.getUrn(StandardProtocols.Enum.SDK_CONSUMING_RECEIVED_DATA)); capabilities.add(BeamUrns.getUrn(StandardProtocols.Enum.ORDERED_LIST_STATE)); capabilities.add(BeamUrns.getUrn(StandardProtocols.Enum.MULTIMAP_STATE)); + capabilities.add(BeamUrns.getUrn(StandardProtocols.Enum.NAMED_DATA_STREAMS)); return capabilities.build(); } diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/fn/data/BeamFnDataOutboundAggregatorTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/fn/data/BeamFnDataOutboundAggregatorTest.java index 092ba200c94b..9bcf615d638b 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/fn/data/BeamFnDataOutboundAggregatorTest.java +++ 
b/sdks/java/core/src/test/java/org/apache/beam/sdk/fn/data/BeamFnDataOutboundAggregatorTest.java @@ -20,6 +20,7 @@ import static org.hamcrest.Matchers.empty; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertThrows; import static org.junit.Assert.fail; import java.io.IOException; @@ -75,13 +76,12 @@ public void testWithDefaultBuffer() throws Exception { final List values = new ArrayList<>(); final AtomicBoolean onCompletedWasCalled = new AtomicBoolean(); BeamFnDataOutboundAggregator aggregator = - new BeamFnDataOutboundAggregator( - PipelineOptionsFactory.create(), - endpoint::getInstructionId, - TestStreams.withOnNext(values::add) - .withOnCompleted(() -> onCompletedWasCalled.set(true)) - .build(), - false); + new BeamFnDataOutboundAggregator(PipelineOptionsFactory.create(), false); + aggregator.prepareForInstruction( + endpoint.getInstructionId(), + TestStreams.withOnNext(values::add) + .withOnCompleted(() -> onCompletedWasCalled.set(true)) + .build()); // Test that nothing is emitted till the default buffer size is surpassed. FnDataReceiver dataReceiver = registerOutputLocation(aggregator, endpoint, CODER); @@ -124,14 +124,12 @@ public void testConfiguredBufferLimit() throws Exception { options .as(ExperimentalOptions.class) .setExperiments(Arrays.asList("data_buffer_size_limit=100")); - BeamFnDataOutboundAggregator aggregator = - new BeamFnDataOutboundAggregator( - options, - endpoint::getInstructionId, - TestStreams.withOnNext(values::add) - .withOnCompleted(() -> onCompletedWasCalled.set(true)) - .build(), - false); + BeamFnDataOutboundAggregator aggregator = new BeamFnDataOutboundAggregator(options, false); + aggregator.prepareForInstruction( + endpoint.getInstructionId(), + TestStreams.withOnNext(values::add) + .withOnCompleted(() -> onCompletedWasCalled.set(true)) + .build()); // Test that nothing is emitted till the default buffer size is surpassed. 
FnDataReceiver dataReceiver = registerOutputLocation(aggregator, endpoint, CODER); aggregator.start(); @@ -187,18 +185,16 @@ public void testConfiguredTimeLimit() throws Exception { .as(ExperimentalOptions.class) .setExperiments(Arrays.asList("data_buffer_time_limit_ms=1")); final CountDownLatch waitForFlush = new CountDownLatch(1); - BeamFnDataOutboundAggregator aggregator = - new BeamFnDataOutboundAggregator( - options, - endpoint::getInstructionId, - TestStreams.withOnNext( - (Consumer) - e -> { - values.add(e); - waitForFlush.countDown(); - }) - .build(), - false); + BeamFnDataOutboundAggregator aggregator = new BeamFnDataOutboundAggregator(options, false); + aggregator.prepareForInstruction( + endpoint.getInstructionId(), + TestStreams.withOnNext( + (Consumer) + e -> { + values.add(e); + waitForFlush.countDown(); + }) + .build()); // Test that it emits when time passed the time limit FnDataReceiver dataReceiver = registerOutputLocation(aggregator, endpoint, CODER); @@ -214,17 +210,15 @@ public void testConfiguredTimeLimitExceptionPropagation() throws Exception { options .as(ExperimentalOptions.class) .setExperiments(Arrays.asList("data_buffer_time_limit_ms=1")); - BeamFnDataOutboundAggregator aggregator = - new BeamFnDataOutboundAggregator( - options, - endpoint::getInstructionId, - TestStreams.withOnNext( - (Consumer) - e -> { - throw new RuntimeException(""); - }) - .build(), - false); + BeamFnDataOutboundAggregator aggregator = new BeamFnDataOutboundAggregator(options, false); + aggregator.prepareForInstruction( + endpoint.getInstructionId(), + TestStreams.withOnNext( + (Consumer) + e -> { + throw new RuntimeException(""); + }) + .build()); // Test that it emits when time passed the time limit FnDataReceiver dataReceiver = registerOutputLocation(aggregator, endpoint, CODER); @@ -243,17 +237,15 @@ public void testConfiguredTimeLimitExceptionPropagation() throws Exception { // expected } - aggregator = - new BeamFnDataOutboundAggregator( - options, - 
endpoint::getInstructionId, - TestStreams.withOnNext( - (Consumer) - e -> { - throw new RuntimeException(""); - }) - .build(), - false); + aggregator = new BeamFnDataOutboundAggregator(options, false); + aggregator.prepareForInstruction( + endpoint.getInstructionId(), + TestStreams.withOnNext( + (Consumer) + e -> { + throw new RuntimeException(""); + }) + .build()); dataReceiver = registerOutputLocation(aggregator, endpoint, CODER); aggregator.start(); dataReceiver.accept(new byte[1]); @@ -279,14 +271,12 @@ public void testConfiguredBufferLimitMultipleEndpoints() throws Exception { options .as(ExperimentalOptions.class) .setExperiments(Arrays.asList("data_buffer_size_limit=100")); - BeamFnDataOutboundAggregator aggregator = - new BeamFnDataOutboundAggregator( - options, - endpoint::getInstructionId, - TestStreams.withOnNext(values::add) - .withOnCompleted(() -> onCompletedWasCalled.set(true)) - .build(), - false); + BeamFnDataOutboundAggregator aggregator = new BeamFnDataOutboundAggregator(options, false); + aggregator.prepareForInstruction( + endpoint.getInstructionId(), + TestStreams.withOnNext(values::add) + .withOnCompleted(() -> onCompletedWasCalled.set(true)) + .build()); // Test that nothing is emitted till the default buffer size is surpassed. 
LogicalEndpoint additionalEndpoint = LogicalEndpoint.data( @@ -334,6 +324,37 @@ public void testConfiguredBufferLimitMultipleEndpoints() throws Exception { checkEqualInAnyOrder(builder.build(), values.get(1)); } + @Test + public void testInstructionLifecycle() { + BeamFnDataOutboundAggregator aggregator = + new BeamFnDataOutboundAggregator(PipelineOptionsFactory.create(), false); + assertThrows( + NullPointerException.class, () -> aggregator.sendElements(Elements.getDefaultInstance())); + aggregator.prepareForInstruction( + "testInstruction", + TestStreams.withOnNext( + (Consumer) + e -> { + throw new RuntimeException(""); + }) + .build()); + assertThrows( + IllegalStateException.class, + () -> + aggregator.prepareForInstruction( + "testInstruction", + TestStreams.withOnNext( + (Consumer) + e -> { + throw new RuntimeException(""); + }) + .build())); + aggregator.finishInstruction(); + assertThrows( + NullPointerException.class, () -> aggregator.sendElements(Elements.getDefaultInstance())); + assertThrows(IllegalStateException.class, aggregator::finishInstruction); + } + private void checkEqualInAnyOrder(Elements first, Elements second) { MatcherAssert.assertThat( first.getDataList(), Matchers.containsInAnyOrder(second.getDataList().toArray())); diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnHarness.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnHarness.java index 703e726739a0..60e83251f147 100644 --- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnHarness.java +++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnHarness.java @@ -17,6 +17,8 @@ */ package org.apache.beam.fn.harness; +import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkNotNull; + import com.fasterxml.jackson.databind.JsonNode; import com.fasterxml.jackson.databind.ObjectMapper; import java.nio.charset.StandardCharsets; @@ -30,7 +32,6 @@ import java.util.Set; import 
java.util.concurrent.CompletableFuture; import java.util.function.Function; -import javax.annotation.Nullable; import org.apache.beam.fn.harness.control.BeamFnControlClient; import org.apache.beam.fn.harness.control.ExecutionStateSampler; import org.apache.beam.fn.harness.control.FinalizeBundleHandler; @@ -72,6 +73,7 @@ import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.MoreExecutors; +import org.checkerframework.checker.nullness.qual.Nullable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -92,9 +94,6 @@ * for further details. * */ -@SuppressWarnings({ - "nullness" // TODO(https://github.com/apache/beam/issues/20497) -}) public class FnHarness { private static final String HARNESS_ID = "HARNESS_ID"; private static final String CONTROL_API_SERVICE_DESCRIPTOR = "CONTROL_API_SERVICE_DESCRIPTOR"; @@ -138,22 +137,31 @@ private static void removeKeyRecursively(JsonNode node, String keyToRemove) { } public static void main(String[] args) throws Exception { - main(System::getenv); + Function environmentVarGetter = System::getenv; + main(environmentVarGetter); } @VisibleForTesting - public static void main(Function environmentVarGetter) throws Exception { + public static void main(Function environmentVarGetter) + throws Exception { JvmInitializers.runOnStartup(); Endpoints.ApiServiceDescriptor loggingApiServiceDescriptor = - getApiServiceDescriptor(environmentVarGetter.apply(LOGGING_API_SERVICE_DESCRIPTOR)); + getApiServiceDescriptor( + checkNotNull( + environmentVarGetter.apply(LOGGING_API_SERVICE_DESCRIPTOR), + "LOGGING_API_SERVICE_DESCRIPTOR env var must be set.")); Endpoints.ApiServiceDescriptor controlApiServiceDescriptor = - getApiServiceDescriptor(environmentVarGetter.apply(CONTROL_API_SERVICE_DESCRIPTOR)); - Endpoints.ApiServiceDescriptor 
statusApiServiceDescriptor = - environmentVarGetter.apply(STATUS_API_SERVICE_DESCRIPTOR) == null - ? null - : getApiServiceDescriptor(environmentVarGetter.apply(STATUS_API_SERVICE_DESCRIPTOR)); - String id = environmentVarGetter.apply(HARNESS_ID); + getApiServiceDescriptor( + checkNotNull( + environmentVarGetter.apply(CONTROL_API_SERVICE_DESCRIPTOR), + "CONTROL_API_SERVICE_DESCRIPTOR env var must be set.")); + + @Nullable String envVar = environmentVarGetter.apply(STATUS_API_SERVICE_DESCRIPTOR); + Endpoints.@Nullable ApiServiceDescriptor statusApiServiceDescriptor = + (envVar == null) ? null : getApiServiceDescriptor(envVar); + String id = + checkNotNull(environmentVarGetter.apply(HARNESS_ID), "HARNESS_ID env var must be set."); System.out.format("SDK Fn Harness started%n"); System.out.format("Harness ID %s%n", id); @@ -161,11 +169,11 @@ public static void main(Function environmentVarGetter) throws Ex System.out.format("Control location %s%n", controlApiServiceDescriptor); System.out.format("Status location %s%n", statusApiServiceDescriptor); - String pipelineOptionsJson = environmentVarGetter.apply(PIPELINE_OPTIONS); // Try looking for a file first. 
If that exists it should override PIPELINE_OPTIONS to avoid // maxing out the kernel's environment space + @Nullable String pipelineOptionsJson = null; try { - String pipelineOptionsPath = environmentVarGetter.apply(PIPELINE_OPTIONS_FILE); + @Nullable String pipelineOptionsPath = environmentVarGetter.apply(PIPELINE_OPTIONS_FILE); System.out.format("Pipeline Options File %s%n", pipelineOptionsPath); if (pipelineOptionsPath != null) { Path filePath = Paths.get(pipelineOptionsPath); @@ -179,11 +187,12 @@ public static void main(Function environmentVarGetter) throws Ex } catch (Exception e) { System.out.format("Problem loading pipeline options from file: %s%n", e.getMessage()); } - + if (pipelineOptionsJson == null) { + pipelineOptionsJson = checkNotNull(environmentVarGetter.apply(PIPELINE_OPTIONS)); + } System.out.format("Pipeline options %s%n", pipelineOptionsJson); // TODO: https://github.com/apache/beam/issues/30301 pipelineOptionsJson = removeNestedKey(pipelineOptionsJson, "impersonateServiceAccount"); - PipelineOptions options = PipelineOptionsTranslation.fromJson(pipelineOptionsJson); String runnerCapabilitesOrNull = environmentVarGetter.apply(RUNNER_CAPABILITIES); @@ -219,7 +228,7 @@ public static void main( Set runnerCapabilities, Endpoints.ApiServiceDescriptor loggingApiServiceDescriptor, Endpoints.ApiServiceDescriptor controlApiServiceDescriptor, - @Nullable Endpoints.ApiServiceDescriptor statusApiServiceDescriptor) + Endpoints.@Nullable ApiServiceDescriptor statusApiServiceDescriptor) throws Exception { ManagedChannelFactory channelFactory; if (ExperimentalOptions.hasExperiment(options, "beam_fn_api_epoll")) { @@ -263,7 +272,7 @@ public static void main( Set runnerCapabilites, Endpoints.ApiServiceDescriptor loggingApiServiceDescriptor, Endpoints.ApiServiceDescriptor controlApiServiceDescriptor, - Endpoints.ApiServiceDescriptor statusApiServiceDescriptor, + Endpoints.@Nullable ApiServiceDescriptor statusApiServiceDescriptor, ManagedChannelFactory 
channelFactory, OutboundObserverFactory outboundObserverFactory, Cache processWideCache) @@ -318,7 +327,7 @@ public static void main( BeamFnControlGrpc.newBlockingStub(channel); BeamFnDataGrpcClient beamFnDataMultiplexer = - new BeamFnDataGrpcClient(options, channelFactory::forDescriptor, outboundObserverFactory); + new BeamFnDataGrpcClient(channelFactory::forDescriptor, outboundObserverFactory); BeamFnStateGrpcClientCache beamFnStateGrpcClientCache = new BeamFnStateGrpcClientCache(idGenerator, channelFactory, outboundObserverFactory); diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ProcessBundleHandler.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ProcessBundleHandler.java index 5a57b137bf6b..449afd6a0243 100644 --- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ProcessBundleHandler.java +++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ProcessBundleHandler.java @@ -224,7 +224,6 @@ public ProcessBundleHandler( private void addRunnerAndConsumersForPTransformRecursively( BeamFnStateClient beamFnStateClient, - BeamFnDataClient queueingClient, String pTransformId, PTransform pTransform, Supplier processBundleInstructionId, @@ -257,7 +256,6 @@ private void addRunnerAndConsumersForPTransformRecursively( for (String consumingPTransformId : pCollectionIdsToConsumingPTransforms.get(pCollectionId)) { addRunnerAndConsumersForPTransformRecursively( beamFnStateClient, - queueingClient, consumingPTransformId, processBundleDescriptor.getTransformsMap().get(consumingPTransformId), processBundleInstructionId, @@ -315,7 +313,7 @@ public ShortIdMap getShortIdMap() { @Override public BeamFnDataClient getBeamFnDataClient() { - return queueingClient; + return beamFnDataClient; } @Override @@ -378,9 +376,8 @@ public FnDataReceiver addOutgoingDataEndpoint( outboundAggregatorMap.computeIfAbsent( apiServiceDescriptor, asd -> - queueingClient.createOutboundAggregator( - asd, - 
processBundleInstructionId, + new BeamFnDataOutboundAggregator( + options, runnerCapabilities.contains( BeamUrns.getUrn( StandardRunnerProtocols.Enum @@ -391,21 +388,19 @@ public FnDataReceiver addOutgoingDataEndpoint( @Override public FnDataReceiver> addOutgoingTimersEndpoint( String timerFamilyId, org.apache.beam.sdk.coders.Coder> coder) { - BeamFnDataOutboundAggregator aggregator; if (!processBundleDescriptor.hasTimerApiServiceDescriptor()) { throw new IllegalStateException( String.format( - "Timers are unsupported because the " - + "ProcessBundleRequest %s does not provide a timer ApiServiceDescriptor.", + "Timers are unsupported because the ProcessBundleRequest %s does not" + + " provide a timer ApiServiceDescriptor.", processBundleInstructionId.get())); } - aggregator = + BeamFnDataOutboundAggregator aggregator = outboundAggregatorMap.computeIfAbsent( processBundleDescriptor.getTimerApiServiceDescriptor(), asd -> - queueingClient.createOutboundAggregator( - asd, - processBundleInstructionId, + new BeamFnDataOutboundAggregator( + options, runnerCapabilities.contains( BeamUrns.getUrn( StandardRunnerProtocols.Enum @@ -499,6 +494,8 @@ public BundleFinalizer getBundleFinalizer() { */ public BeamFnApi.InstructionResponse.Builder processBundle(InstructionRequest request) throws Exception { + String instructionId = request.getInstructionId(); + String dataStreamId = request.getProcessBundle().getDataStreamId(); @Nullable BundleProcessor bundleProcessor = null; try { bundleProcessor = @@ -515,13 +512,20 @@ public BeamFnApi.InstructionResponse.Builder processBundle(InstructionRequest re } })); + for (Map.Entry entry : + bundleProcessor.getOutboundAggregators().entrySet()) { + BeamFnDataOutboundAggregator aggregator = entry.getValue(); + aggregator.prepareForInstruction( + instructionId, beamFnDataClient.getOutboundObserver(entry.getKey(), dataStreamId)); + } + PTransformFunctionRegistry startFunctionRegistry = bundleProcessor.getStartFunctionRegistry(); 
PTransformFunctionRegistry finishFunctionRegistry = bundleProcessor.getFinishFunctionRegistry(); ExecutionStateTracker stateTracker = bundleProcessor.getStateTracker(); ProcessBundleResponse.Builder response = ProcessBundleResponse.newBuilder(); try (HandleStateCallsForBundle beamFnStateClient = bundleProcessor.getBeamFnStateClient()) { - stateTracker.start(request.getInstructionId()); + stateTracker.start(instructionId); try { // Already in reverse topological order so we don't need to do anything. for (ThrowingRunnable startFunction : startFunctionRegistry.getFunctions()) { @@ -545,12 +549,14 @@ public BeamFnApi.InstructionResponse.Builder processBundle(InstructionRequest re } else if (!bundleProcessor.getInboundEndpointApiServiceDescriptors().isEmpty()) { BeamFnDataInboundObserver observer = bundleProcessor.getInboundObserver(); beamFnDataClient.registerReceiver( - request.getInstructionId(), + instructionId, + dataStreamId, bundleProcessor.getInboundEndpointApiServiceDescriptors(), observer); observer.awaitCompletion(); beamFnDataClient.unregisterReceiver( - request.getInstructionId(), + instructionId, + dataStreamId, bundleProcessor.getInboundEndpointApiServiceDescriptors()); } @@ -581,7 +587,7 @@ public BeamFnApi.InstructionResponse.Builder processBundle(InstructionRequest re if (!bundleProcessor.getBundleFinalizationCallbackRegistrations().isEmpty()) { finalizeBundleHandler.registerCallbacks( - bundleProcessor.getInstructionId(), + instructionId, ImmutableList.copyOf(bundleProcessor.getBundleFinalizationCallbackRegistrations())); response.setRequiresFinalization(true); } @@ -599,7 +605,7 @@ public BeamFnApi.InstructionResponse.Builder processBundle(InstructionRequest re } catch (Exception e) { LOG.debug( "Error processing bundle {} with bundleProcessor for {} after exception", - request.getInstructionId(), + instructionId, request.getProcessBundle().getProcessBundleDescriptorId(), e); if (bundleProcessor != null) { @@ -607,7 +613,7 @@ public 
BeamFnApi.InstructionResponse.Builder processBundle(InstructionRequest re bundleProcessorCache.discard(bundleProcessor); } // Ensure that if more data arrives for the instruction it is discarded. - beamFnDataClient.poisonInstructionId(request.getInstructionId()); + beamFnDataClient.poisonInstructionId(instructionId); throw e; } } @@ -629,6 +635,10 @@ private void embedOutboundElementsIfApplicable( collectedElements.add(elements); } if (!hasFlushedAggregator) { + for (BeamFnDataOutboundAggregator aggregator : + bundleProcessor.getOutboundAggregators().values()) { + aggregator.finishInstruction(); + } Elements.Builder elementsToEmbed = Elements.newBuilder(); for (Elements collectedElement : collectedElements) { elementsToEmbed.mergeFrom(collectedElement); @@ -645,6 +655,7 @@ private void embedOutboundElementsIfApplicable( if (elements != null) { aggregator.sendElements(elements); } + aggregator.finishInstruction(); } } } @@ -875,7 +886,6 @@ public void afterBundleCommit(Instant callbackExpiry, Callback callback) { addRunnerAndConsumersForPTransformRecursively( beamFnStateClient, - beamFnDataClient, entry.getKey(), entry.getValue(), bundleProcessor::getInstructionId, diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/data/BeamFnDataClient.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/data/BeamFnDataClient.java index 94d59d0fcb62..1a50f5b448c5 100644 --- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/data/BeamFnDataClient.java +++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/data/BeamFnDataClient.java @@ -18,13 +18,12 @@ package org.apache.beam.fn.harness.data; import java.util.List; -import java.util.function.Supplier; import org.apache.beam.model.fnexecution.v1.BeamFnApi.Elements; import org.apache.beam.model.pipeline.v1.Endpoints; import org.apache.beam.model.pipeline.v1.Endpoints.ApiServiceDescriptor; -import org.apache.beam.sdk.fn.data.BeamFnDataOutboundAggregator; import 
org.apache.beam.sdk.fn.data.CloseableFnDataReceiver; import org.apache.beam.sdk.fn.data.FnDataReceiver; +import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.stub.StreamObserver; /** * The {@link BeamFnDataClient} is able to forward inbound elements to a {@link FnDataReceiver} and @@ -47,6 +46,7 @@ public interface BeamFnDataClient { */ void registerReceiver( String instructionId, + String dataStreamId, List apiServiceDescriptors, CloseableFnDataReceiver receiver); @@ -58,7 +58,8 @@ void registerReceiver( * to the {@link BeamFnDataClient} during a future {@link FnDataReceiver#accept} invocation or via * a call to {@link #poisonInstructionId}. */ - void unregisterReceiver(String instructionId, List apiServiceDescriptors); + void unregisterReceiver( + String instructionId, String dataStreamId, List apiServiceDescriptors); /** * Poisons the instruction id, indicating that future data arriving for it should be discarded. @@ -68,22 +69,7 @@ void registerReceiver( */ void poisonInstructionId(String instructionId); - /** - * Creates a {@link BeamFnDataOutboundAggregator} for buffering and sending outbound data and - * timers over the data plane. It is important that {@link - * BeamFnDataOutboundAggregator#sendOrCollectBufferedDataAndFinishOutboundStreams()} is called on - * the returned BeamFnDataOutboundAggregator at the end of each bundle. If - * collectElementsIfNoFlushes is set to true, {@link - * BeamFnDataOutboundAggregator#sendOrCollectBufferedDataAndFinishOutboundStreams()} returns the - * buffered elements instead of sending it through the outbound StreamObserver if there's no - * previous flush. - * - *

Closing the returned aggregator signals the end of the streams. - * - *

The returned aggregator is not thread safe. - */ - BeamFnDataOutboundAggregator createOutboundAggregator( - Endpoints.ApiServiceDescriptor apiServiceDescriptor, - Supplier processBundleRequestIdSupplier, - boolean collectElementsIfNoFlushes); + /** Get the outbound observer for the specified apiServiceDescriptor and dataStreamId. */ + StreamObserver getOutboundObserver( + Endpoints.ApiServiceDescriptor apiServiceDescriptor, String dataStreamId); } diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/data/BeamFnDataGrpcClient.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/data/BeamFnDataGrpcClient.java index 499d816f8cc0..2f2a6b0fc660 100644 --- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/data/BeamFnDataGrpcClient.java +++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/data/BeamFnDataGrpcClient.java @@ -18,20 +18,22 @@ package org.apache.beam.fn.harness.data; import java.util.List; +import java.util.Objects; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; import java.util.function.Function; -import java.util.function.Supplier; +import javax.annotation.Nullable; import org.apache.beam.model.fnexecution.v1.BeamFnApi.Elements; import org.apache.beam.model.fnexecution.v1.BeamFnDataGrpc; import org.apache.beam.model.pipeline.v1.Endpoints; import org.apache.beam.model.pipeline.v1.Endpoints.ApiServiceDescriptor; import org.apache.beam.sdk.fn.data.BeamFnDataGrpcMultiplexer; -import org.apache.beam.sdk.fn.data.BeamFnDataOutboundAggregator; import org.apache.beam.sdk.fn.data.CloseableFnDataReceiver; import org.apache.beam.sdk.fn.stream.OutboundObserverFactory; -import org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.ManagedChannel; +import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.Metadata; +import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.stub.MetadataUtils; +import 
org.apache.beam.vendor.grpc.v1p69p0.io.grpc.stub.StreamObserver; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -44,17 +46,42 @@ public class BeamFnDataGrpcClient implements BeamFnDataClient { private static final Logger LOG = LoggerFactory.getLogger(BeamFnDataGrpcClient.class); - private final ConcurrentMap - multiplexerCache; + private static class MultiplexerKey { + private final Endpoints.ApiServiceDescriptor apiServiceDescriptor; + private final String dataStreamId; + + private MultiplexerKey( + Endpoints.ApiServiceDescriptor apiServiceDescriptor, String dataStreamId) { + this.apiServiceDescriptor = apiServiceDescriptor; + this.dataStreamId = dataStreamId; + } + + @Override + public boolean equals(@Nullable Object o) { + if (this == o) { + return true; + } + if (!(o instanceof MultiplexerKey)) { + return false; + } + MultiplexerKey that = (MultiplexerKey) o; + return Objects.equals(dataStreamId, that.dataStreamId) + && Objects.equals(apiServiceDescriptor, that.apiServiceDescriptor); + } + + @Override + public int hashCode() { + return Objects.hash(apiServiceDescriptor, dataStreamId); + } + } + + private final ConcurrentMap multiplexerCache; private final Function channelFactory; private final OutboundObserverFactory outboundObserverFactory; - private final PipelineOptions options; public BeamFnDataGrpcClient( - PipelineOptions options, Function channelFactory, OutboundObserverFactory outboundObserverFactory) { - this.options = options; this.channelFactory = channelFactory; this.outboundObserverFactory = outboundObserverFactory; this.multiplexerCache = new ConcurrentHashMap<>(); @@ -63,21 +90,22 @@ public BeamFnDataGrpcClient( @Override public void registerReceiver( String instructionId, + String dataStreamId, List apiServiceDescriptors, CloseableFnDataReceiver receiver) { LOG.debug("Registering consumer for {}", instructionId); for (int i = 0, size = apiServiceDescriptors.size(); i < size; i++) { - BeamFnDataGrpcMultiplexer client = 
getClientFor(apiServiceDescriptors.get(i)); + BeamFnDataGrpcMultiplexer client = getMultiplexer(apiServiceDescriptors.get(i), dataStreamId); client.registerConsumer(instructionId, receiver); } } @Override public void unregisterReceiver( - String instructionId, List apiServiceDescriptors) { + String instructionId, String dataStreamId, List apiServiceDescriptors) { LOG.debug("Unregistering consumer for {}", instructionId); for (int i = 0, size = apiServiceDescriptors.size(); i < size; i++) { - BeamFnDataGrpcMultiplexer client = getClientFor(apiServiceDescriptors.get(i)); + BeamFnDataGrpcMultiplexer client = getMultiplexer(apiServiceDescriptors.get(i), dataStreamId); client.unregisterConsumer(instructionId); } } @@ -91,25 +119,32 @@ public void poisonInstructionId(String instructionId) { } @Override - public BeamFnDataOutboundAggregator createOutboundAggregator( - ApiServiceDescriptor apiServiceDescriptor, - Supplier processBundleRequestIdSupplier, - boolean collectElementsIfNoFlushes) { - return new BeamFnDataOutboundAggregator( - options, - processBundleRequestIdSupplier, - getClientFor(apiServiceDescriptor).getOutboundObserver(), - collectElementsIfNoFlushes); + public StreamObserver getOutboundObserver( + ApiServiceDescriptor apiServiceDescriptor, String dataStreamId) { + return getMultiplexer(apiServiceDescriptor, dataStreamId).getOutboundObserver(); } - private BeamFnDataGrpcMultiplexer getClientFor( - Endpoints.ApiServiceDescriptor apiServiceDescriptor) { + private BeamFnDataGrpcMultiplexer getMultiplexer( + Endpoints.ApiServiceDescriptor apiServiceDescriptor, String dataStreamId) { + MultiplexerKey key = new MultiplexerKey(apiServiceDescriptor, dataStreamId); return multiplexerCache.computeIfAbsent( - apiServiceDescriptor, - (Endpoints.ApiServiceDescriptor descriptor) -> - new BeamFnDataGrpcMultiplexer( - descriptor, - outboundObserverFactory, - BeamFnDataGrpc.newStub(channelFactory.apply(apiServiceDescriptor))::data)); + key, + k -> { + 
OutboundObserverFactory.BasicFactory baseOutboundObserverFactory = + inboundObserver -> { + BeamFnDataGrpc.BeamFnDataStub stub = + BeamFnDataGrpc.newStub(channelFactory.apply(apiServiceDescriptor)); + if (dataStreamId != null && !dataStreamId.isEmpty()) { + Metadata headers = new Metadata(); + headers.put( + Metadata.Key.of("data_stream_id", Metadata.ASCII_STRING_MARSHALLER), + dataStreamId); + stub = stub.withInterceptors(MetadataUtils.newAttachHeadersInterceptor(headers)); + } + return stub.data(inboundObserver); + }; + return new BeamFnDataGrpcMultiplexer( + apiServiceDescriptor, outboundObserverFactory, baseOutboundObserverFactory); + }); } } diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/BeamFnDataWriteRunnerTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/BeamFnDataWriteRunnerTest.java index 70a894e7b375..2882b75a2593 100644 --- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/BeamFnDataWriteRunnerTest.java +++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/BeamFnDataWriteRunnerTest.java @@ -32,7 +32,6 @@ import java.util.Map; import java.util.ServiceLoader; import java.util.concurrent.atomic.AtomicReference; -import java.util.function.Supplier; import org.apache.beam.fn.harness.PTransformRunnerFactory.Registrar; import org.apache.beam.fn.harness.data.BeamFnDataClient; import org.apache.beam.model.fnexecution.v1.BeamFnApi; @@ -107,19 +106,29 @@ public void setUp() { MockitoAnnotations.initMocks(this); } - private BeamFnDataOutboundAggregator createRecordingAggregator( - Map>> output, Supplier bundleId) { + @Test + public void testReuseForMultipleBundles() throws Exception { + AtomicReference bundleId = new AtomicReference<>("0"); + String localInputId = "inputPC"; + RunnerApi.PTransform pTransform = + RemoteGrpcPortWrite.writeToPort(localInputId, PORT_SPEC).toPTransform(); + + List> output0 = new ArrayList<>(); + List> output1 = new ArrayList<>(); + Map aggregators = new HashMap<>(); + 
PipelineOptions options = PipelineOptionsFactory.create(); options.as(ExperimentalOptions.class).setExperiments(Arrays.asList("data_buffer_size_limit=0")); - return new BeamFnDataOutboundAggregator( - options, - bundleId, + BeamFnDataOutboundAggregator aggregator = new BeamFnDataOutboundAggregator(options, false); + + Map>> outputs = ImmutableMap.of("0", output0, "1", output1); + StreamObserver observer = new StreamObserver() { @Override public void onNext(Elements elements) { for (Data data : elements.getDataList()) { try { - output.get(bundleId.get()).add(WIRE_CODER.decode(data.getData().newInput())); + outputs.get(bundleId.get()).add(WIRE_CODER.decode(data.getData().newInput())); } catch (IOException e) { throw new RuntimeException("Failed to decode output."); } @@ -131,22 +140,9 @@ public void onError(Throwable throwable) {} @Override public void onCompleted() {} - }, - false); - } - - @Test - public void testReuseForMultipleBundles() throws Exception { - AtomicReference bundleId = new AtomicReference<>("0"); - String localInputId = "inputPC"; - RunnerApi.PTransform pTransform = - RemoteGrpcPortWrite.writeToPort(localInputId, PORT_SPEC).toPTransform(); + }; - List> output0 = new ArrayList<>(); - List> output1 = new ArrayList<>(); - Map aggregators = new HashMap<>(); - BeamFnDataOutboundAggregator aggregator = - createRecordingAggregator(ImmutableMap.of("0", output0, "1", output1), bundleId::get); + aggregator.prepareForInstruction(bundleId.get(), observer); aggregators.put(PORT_SPEC.getApiServiceDescriptor(), aggregator); PTransformRunnerFactoryTestContext context = @@ -172,18 +168,20 @@ public void testReuseForMultipleBundles() throws Exception { FnDataReceiver pCollectionConsumer = context.getPCollectionConsumer(localInputId); pCollectionConsumer.accept(valueInGlobalWindow("ABC")); pCollectionConsumer.accept(valueInGlobalWindow("DEF")); - assertThat(output0, contains(valueInGlobalWindow("ABC"), valueInGlobalWindow("DEF"))); + aggregator.finishInstruction(); 
output0.clear(); // Process for bundle id 1 bundleId.set("1"); + aggregator.prepareForInstruction(bundleId.get(), observer); pCollectionConsumer.accept(valueInGlobalWindow("GHI")); pCollectionConsumer.accept(valueInGlobalWindow("JKL")); assertThat(output1, contains(valueInGlobalWindow("GHI"), valueInGlobalWindow("JKL"))); + aggregator.finishInstruction(); verifyNoMoreInteractions(mockBeamFnDataClient); } diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/FnApiDoFnRunnerTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/FnApiDoFnRunnerTest.java index 50a2fec0b5a2..2aa555e83cd5 100644 --- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/FnApiDoFnRunnerTest.java +++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/FnApiDoFnRunnerTest.java @@ -839,7 +839,7 @@ private class TestBeamFnDataOutboundAggregator extends BeamFnDataOutboundAggrega private Supplier processBundleRequestIdSupplier; public TestBeamFnDataOutboundAggregator(Supplier bundleIdSupplier) { - super(PipelineOptionsFactory.create(), bundleIdSupplier, null, false); + super(PipelineOptionsFactory.create(), false); this.timers = new HashMap<>(); this.dataOutput = new HashMap<>(); this.processBundleRequestIdSupplier = bundleIdSupplier; diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/PTransformRunnerFactoryTestContext.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/PTransformRunnerFactoryTestContext.java index 7b4387738a4c..51e49953b406 100644 --- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/PTransformRunnerFactoryTestContext.java +++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/PTransformRunnerFactoryTestContext.java @@ -54,6 +54,7 @@ import org.apache.beam.sdk.transforms.DoFn.BundleFinalizer; import org.apache.beam.sdk.util.construction.Timer; import org.apache.beam.sdk.values.WindowedValue; +import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.stub.StreamObserver; import 
org.joda.time.Instant; /** @@ -74,6 +75,7 @@ public static Builder builder(String pTransformId, RunnerApi.PTransform pTransfo @Override public void registerReceiver( String instructionId, + String dataStreamId, List apiServiceDescriptors, CloseableFnDataReceiver receiver) { throw new UnsupportedOperationException("Unexpected call during test."); @@ -81,15 +83,15 @@ public void registerReceiver( @Override public void unregisterReceiver( - String instructionId, List apiServiceDescriptors) { + String instructionId, + String dataStreamId, + List apiServiceDescriptors) { throw new UnsupportedOperationException("Unexpected call during test."); } @Override - public BeamFnDataOutboundAggregator createOutboundAggregator( - ApiServiceDescriptor apiServiceDescriptor, - Supplier processBundleRequestIdSupplier, - boolean collectElementsIfNoFlushes) { + public StreamObserver getOutboundObserver( + ApiServiceDescriptor apiServiceDescriptor, String dataStreamId) { throw new UnsupportedOperationException("Unexpected call during test."); } diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/ProcessBundleHandlerTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/ProcessBundleHandlerTest.java index 47f85178b0a1..c03f82726740 100644 --- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/ProcessBundleHandlerTest.java +++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/ProcessBundleHandlerTest.java @@ -37,7 +37,6 @@ import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.anyBoolean; import static org.mockito.Mockito.argThat; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.eq; @@ -1071,28 +1070,20 @@ private ProcessBundleHandler setupProcessBundleHandlerForSimpleRecordingDoFn( dataOutput.add(input.getValue()); })); - Mockito.doAnswer( - 
(invocation) -> - new BeamFnDataOutboundAggregator( - PipelineOptionsFactory.create(), - invocation.getArgument(1), - new StreamObserver() { - @Override - public void onNext(Elements elements) { - for (Timers timer : elements.getTimersList()) { - timerOutput.addAll(elements.getTimersList()); - } - } + Mockito.when(beamFnDataClient.getOutboundObserver(any(), any())) + .thenReturn( + new StreamObserver() { + @Override + public void onNext(Elements elements) { + timerOutput.addAll(elements.getTimersList()); + } - @Override - public void onError(Throwable throwable) {} + @Override + public void onError(Throwable throwable) {} - @Override - public void onCompleted() {} - }, - invocation.getArgument(2))) - .when(beamFnDataClient) - .createOutboundAggregator(any(), any(), anyBoolean()); + @Override + public void onCompleted() {} + }); return new ProcessBundleHandler( PipelineOptionsFactory.create(), @@ -1409,7 +1400,7 @@ public void testInstructionIsUnregisteredFromBeamFnDataClientOnSuccess() throws (invocation) -> { String instructionId = invocation.getArgument(0, String.class); CloseableFnDataReceiver data = - invocation.getArgument(2, CloseableFnDataReceiver.class); + invocation.getArgument(3, CloseableFnDataReceiver.class); data.accept( BeamFnApi.Elements.newBuilder() .addData( @@ -1421,7 +1412,7 @@ public void testInstructionIsUnregisteredFromBeamFnDataClientOnSuccess() throws return null; }) .when(beamFnDataClient) - .registerReceiver(any(), any(), any()); + .registerReceiver(any(), any(), any(), any()); ProcessBundleHandler handler = new ProcessBundleHandler( @@ -1451,8 +1442,8 @@ public void testInstructionIsUnregisteredFromBeamFnDataClientOnSuccess() throws .build()); // Ensure that we unregister during successful processing - verify(beamFnDataClient).registerReceiver(eq("instructionId"), any(), any()); - verify(beamFnDataClient).unregisterReceiver(eq("instructionId"), any()); + verify(beamFnDataClient).registerReceiver(eq("instructionId"), any(), any(), any()); 
+ verify(beamFnDataClient).unregisterReceiver(eq("instructionId"), any(), any()); verifyNoMoreInteractions(beamFnDataClient); } @@ -1475,7 +1466,7 @@ public void testDataProcessingExceptionsArePropagated() throws Exception { StringUtf8Coder.of().encode("A", encodedData); String instructionId = invocation.getArgument(0, String.class); CloseableFnDataReceiver data = - invocation.getArgument(2, CloseableFnDataReceiver.class); + invocation.getArgument(3, CloseableFnDataReceiver.class); data.accept( BeamFnApi.Elements.newBuilder() .addData( @@ -1489,7 +1480,7 @@ public void testDataProcessingExceptionsArePropagated() throws Exception { return null; }) .when(beamFnDataClient) - .registerReceiver(any(), any(), any()); + .registerReceiver(any(), any(), any(), any()); ProcessBundleHandler handler = new ProcessBundleHandler( @@ -1526,7 +1517,7 @@ public void testDataProcessingExceptionsArePropagated() throws Exception { .build())); // Ensure that we unregister during successful processing - verify(beamFnDataClient).registerReceiver(eq("instructionId"), any(), any()); + verify(beamFnDataClient).registerReceiver(eq("instructionId"), any(), any(), any()); verify(beamFnDataClient).poisonInstructionId(eq("instructionId")); verifyNoMoreInteractions(beamFnDataClient); } diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/data/BeamFnDataGrpcClientTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/data/BeamFnDataGrpcClientTest.java index 15f83f2582c7..9d9efa0b9c49 100644 --- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/data/BeamFnDataGrpcClientTest.java +++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/data/BeamFnDataGrpcClientTest.java @@ -23,8 +23,8 @@ import static org.hamcrest.Matchers.contains; import static org.hamcrest.Matchers.empty; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; -import static org.junit.Assert.fail; 
import java.util.Arrays; import java.util.Collection; @@ -49,6 +49,7 @@ import org.apache.beam.sdk.fn.data.LogicalEndpoint; import org.apache.beam.sdk.fn.stream.OutboundObserverFactory; import org.apache.beam.sdk.fn.test.TestStreams; +import org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.sdk.options.PipelineOptionsFactory; import org.apache.beam.sdk.transforms.windowing.GlobalWindow; import org.apache.beam.sdk.values.WindowedValue; @@ -169,7 +170,6 @@ public StreamObserver data( BeamFnDataGrpcClient clientFactory = new BeamFnDataGrpcClient( - PipelineOptionsFactory.create(), (Endpoints.ApiServiceDescriptor descriptor) -> channel, OutboundObserverFactory.trivial()); @@ -183,7 +183,7 @@ public StreamObserver data( Collections.emptyList()); clientFactory.registerReceiver( - INSTRUCTION_ID_A, Arrays.asList(apiServiceDescriptor), observerA); + INSTRUCTION_ID_A, "", Arrays.asList(apiServiceDescriptor), observerA); waitForClientToConnect.await(); outboundServerObserver.get().onNext(ELEMENTS_A_1); @@ -193,7 +193,7 @@ public StreamObserver data( Thread.sleep(100); clientFactory.registerReceiver( - INSTRUCTION_ID_B, Arrays.asList(apiServiceDescriptor), observerB); + INSTRUCTION_ID_B, "", Arrays.asList(apiServiceDescriptor), observerB); // Show that out of order stream completion can occur. 
observerB.awaitCompletion(); @@ -245,7 +245,6 @@ public StreamObserver data( BeamFnDataGrpcClient clientFactory = new BeamFnDataGrpcClient( - PipelineOptionsFactory.create(), (Endpoints.ApiServiceDescriptor descriptor) -> channel, OutboundObserverFactory.trivial()); @@ -262,7 +261,7 @@ public StreamObserver data( Collections.emptyList()); clientFactory.registerReceiver( - INSTRUCTION_ID_A, Arrays.asList(apiServiceDescriptor), observer); + INSTRUCTION_ID_A, "", Arrays.asList(apiServiceDescriptor), observer); waitForClientToConnect.await(); @@ -270,12 +269,8 @@ public StreamObserver data( outboundServerObserver.get().onNext(ELEMENTS_A_1); outboundServerObserver.get().onNext(ELEMENTS_A_2); - try { - observer.awaitCompletion(); - fail("Expected channel to fail"); - } catch (Exception e) { - assertEquals(exceptionToThrow, e); - } + Exception e = assertThrows(Exception.class, observer::awaitCompletion); + assertEquals(exceptionToThrow, e); // The server should not have received any values assertThat(inboundServerValues, empty()); // The consumer should have only been invoked once @@ -321,7 +316,6 @@ public StreamObserver data( BeamFnDataGrpcClient clientFactory = new BeamFnDataGrpcClient( - PipelineOptionsFactory.create(), (Endpoints.ApiServiceDescriptor descriptor) -> channel, OutboundObserverFactory.trivial()); @@ -347,7 +341,7 @@ public StreamObserver data( }); clientFactory.registerReceiver( - INSTRUCTION_ID_A, Arrays.asList(apiServiceDescriptor), observerA); + INSTRUCTION_ID_A, "", Arrays.asList(apiServiceDescriptor), observerA); waitForClientToConnect.await(); outboundServerObserver.get().onNext(ELEMENTS_B_1); @@ -358,11 +352,9 @@ public StreamObserver data( assertTrue(receivedAElement.await(5, TimeUnit.SECONDS)); clientFactory.poisonInstructionId(INSTRUCTION_ID_A); - try { - future.get(); - fail(); // We expect the awaitCompletion to fail due to closing. - } catch (Exception ignored) { - } + // We expect the awaitCompletion to fail due to closing. + // Expected. 
+ assertThrows(Exception.class, future::get); outboundServerObserver.get().onNext(ELEMENTS_A_2); @@ -404,16 +396,15 @@ public StreamObserver data( ManagedChannel channel = InProcessChannelBuilder.forName(apiServiceDescriptor.getUrl()).build(); + PipelineOptions options = + PipelineOptionsFactory.fromArgs("--experiments=data_buffer_size_limit=20").create(); BeamFnDataGrpcClient clientFactory = new BeamFnDataGrpcClient( - PipelineOptionsFactory.fromArgs( - new String[] {"--experiments=data_buffer_size_limit=20"}) - .create(), (Endpoints.ApiServiceDescriptor descriptor) -> channel, OutboundObserverFactory.trivial()); - BeamFnDataOutboundAggregator aggregator = - clientFactory.createOutboundAggregator( - apiServiceDescriptor, () -> INSTRUCTION_ID_A, false); + BeamFnDataOutboundAggregator aggregator = new BeamFnDataOutboundAggregator(options, false); + aggregator.prepareForInstruction( + INSTRUCTION_ID_A, clientFactory.getOutboundObserver(apiServiceDescriptor, "")); FnDataReceiver> fnDataReceiver = aggregator.registerOutputDataLocation(TRANSFORM_ID_A, CODER); fnDataReceiver.accept(valueInGlobalWindow("ABC"));