From eae00f3e3a7b2dec91613e6319ec3805054745e1 Mon Sep 17 00:00:00 2001 From: slinkydeveloper Date: Fri, 8 May 2026 17:57:10 +0200 Subject: [PATCH] Coop suspension WIP --- .../restate/sdk/core/HandlerContextImpl.java | 32 +- .../core/statemachine/AsyncResultsState.java | 306 +++++++++++++++++- .../sdk/core/statemachine/MessageType.java | 7 + .../core/statemachine/ProcessingState.java | 46 +-- .../sdk/core/statemachine/ReplayingState.java | 23 +- .../sdk/core/statemachine/RunState.java | 6 +- .../core/statemachine/ServiceProtocol.java | 8 +- .../restate/sdk/core/statemachine/State.java | 45 ++- .../sdk/core/statemachine/StateMachine.java | 18 +- .../core/statemachine/StateMachineImpl.java | 9 +- .../core/statemachine/UnresolvedFuture.java | 57 ++++ .../dev/restate/service/protocol.proto | 49 ++- .../dev/restate/sdk/core/CallTestSuite.java | 4 +- 13 files changed, 522 insertions(+), 88 deletions(-) create mode 100644 sdk-core/src/main/java/dev/restate/sdk/core/statemachine/UnresolvedFuture.java diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/HandlerContextImpl.java b/sdk-core/src/main/java/dev/restate/sdk/core/HandlerContextImpl.java index 57906ef86..aee0b1835 100644 --- a/sdk-core/src/main/java/dev/restate/sdk/core/HandlerContextImpl.java +++ b/sdk-core/src/main/java/dev/restate/sdk/core/HandlerContextImpl.java @@ -18,6 +18,7 @@ import dev.restate.sdk.core.statemachine.InvocationState; import dev.restate.sdk.core.statemachine.NotificationValue; import dev.restate.sdk.core.statemachine.StateMachine; +import dev.restate.sdk.core.statemachine.UnresolvedFuture; import dev.restate.sdk.endpoint.definition.AsyncResult; import dev.restate.sdk.endpoint.definition.HandlerType; import dev.restate.sdk.endpoint.definition.ServiceType; @@ -28,7 +29,6 @@ import java.util.concurrent.CompletableFuture; import java.util.concurrent.Executor; import java.util.function.Consumer; -import java.util.stream.Stream; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.jspecify.annotations.Nullable; @@ -422,33 +422,39 @@ private void pollAsyncResultInner(AsyncResultInternal asyncResult) { asyncResult.tryComplete(this.stateMachine); // Now let's take the unprocessed leaves - List uncompletedLeaves = - Stream.concat(asyncResult.uncompletedLeaves(), Stream.of(CANCEL_HANDLE)).toList(); - if (uncompletedLeaves.size() == 1) { + List uncompletedLeaves = asyncResult.uncompletedLeaves().toList(); + if (uncompletedLeaves.isEmpty()) { // Nothing else to do! return; } + // Build the UnresolvedFuture from the leaf handles + UnresolvedFuture future = + uncompletedLeaves.size() == 1 + ? new UnresolvedFuture.Single(uncompletedLeaves.get(0)) + : new UnresolvedFuture.FirstCompleted( + uncompletedLeaves.stream() + .map(h -> (UnresolvedFuture) new UnresolvedFuture.Single(h)) + .toList()); + // Not ready yet, let's try to do some progress - StateMachine.DoProgressResponse response; + StateMachine.AwaitResponse response; try { - response = this.stateMachine.doProgress(uncompletedLeaves); + response = this.stateMachine.doAwait(future); } catch (Throwable e) { this.failWithoutContextSwitch(e); asyncResult.publicFuture().completeExceptionally(AbortedExecutionException.INSTANCE); return; } - if (response instanceof StateMachine.DoProgressResponse.AnyCompleted) { + if (response instanceof StateMachine.AwaitResponse.AnyCompleted) { // Let it loop now - } else if (response instanceof StateMachine.DoProgressResponse.ReadFromInput - || response instanceof StateMachine.DoProgressResponse.WaitingPendingRun) { + } else if (response instanceof StateMachine.AwaitResponse.WaitingExternalProgress wep) { this.stateMachine.onNextEvent( - () -> this.pollAsyncResultInner(asyncResult), - response instanceof StateMachine.DoProgressResponse.ReadFromInput); + () -> this.pollAsyncResultInner(asyncResult), wep.waitingInput()); return; - } else if (response instanceof StateMachine.DoProgressResponse.ExecuteRun) { - triggerScheduledRun(((StateMachine.DoProgressResponse.ExecuteRun) response).handle()); + } else if (response instanceof StateMachine.AwaitResponse.ExecuteRun) { + triggerScheduledRun(((StateMachine.AwaitResponse.ExecuteRun) response).handle()); // Let it loop now } } diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/AsyncResultsState.java b/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/AsyncResultsState.java index 0594624c4..f715e127f 100644 --- a/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/AsyncResultsState.java +++ b/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/AsyncResultsState.java @@ -74,21 +74,198 @@ public int createHandleMapping(NotificationId notificationId) { return assignedHandle; } - public boolean processNextUntilAnyFound(Set ids) { - while (!toProcess.isEmpty()) { - Map.Entry notif = toProcess.removeFirst(); - boolean anyFound = ids.contains(notif.getKey()); - ready.put(notif.getKey(), notif.getValue()); - if (anyFound) { - return true; + /** + * Try to resolve the given future against available notifications. + * + *

Operates on a deep-mutable copy of {@code unresolved} so the caller's object is unchanged. + * + * @return {@link ResolveFutureResult.AnyCompleted} if the future can be resolved, or {@link + * ResolveFutureResult.WaitExternalInput} with the remaining (reduced) unresolved future if + * not. + */ + public ResolveFutureResult tryResolveFuture(UnresolvedFuture unresolved) { + // Work on a mutable copy so we can prune completed children in place. + unresolved = deepMutableCopy(unresolved); + + while (true) { + TryResolveResult result = tryResolveFutureInternal(unresolved); + + if (result == TryResolveResult.SHORT_CIRCUITED || result.handleState().isCompleted()) { + return ResolveFutureResult.ANY_COMPLETED; + } + + // Not completed yet — try popping the next queued notification and retry + if (!popNotificationQueue()) { + return new ResolveFutureResult.WaitExternalInput(unresolved); } } - return false; } + /** Create a deep copy of a future tree with all children stored in mutable {@link ArrayList}s. */ + private static UnresolvedFuture deepMutableCopy(UnresolvedFuture fut) { + if (fut instanceof UnresolvedFuture.Single) { + return fut; + } else if (fut instanceof UnresolvedFuture.FirstCompleted fc) { + var copy = new ArrayList(fc.children().size()); + for (var c : fc.children()) copy.add(deepMutableCopy(c)); + return new UnresolvedFuture.FirstCompleted(copy); + } else if (fut instanceof UnresolvedFuture.AllCompleted ac) { + var copy = new ArrayList(ac.children().size()); + for (var c : ac.children()) copy.add(deepMutableCopy(c)); + return new UnresolvedFuture.AllCompleted(copy); + } else if (fut instanceof UnresolvedFuture.FirstSucceededOrAllFailed fsaf) { + var copy = new ArrayList(fsaf.children().size()); + for (var c : fsaf.children()) copy.add(deepMutableCopy(c)); + return new UnresolvedFuture.FirstSucceededOrAllFailed(copy); + } else if (fut instanceof UnresolvedFuture.AllSucceededOrFirstFailed asff) { + var copy = new ArrayList(asff.children().size()); + for (var c : asff.children()) copy.add(deepMutableCopy(c)); + return new UnresolvedFuture.AllSucceededOrFirstFailed(copy); + } else if (fut instanceof UnresolvedFuture.Unknown u) { + var copy = new ArrayList(u.children().size()); + for (var c : u.children()) copy.add(deepMutableCopy(c)); + return new UnresolvedFuture.Unknown(copy); + } + throw new IllegalStateException("Unknown UnresolvedFuture type: " + fut); + } + + /** Returns false if there's nothing left in toProcess. */ + private boolean popNotificationQueue() { + Map.Entry notif = toProcess.pollFirst(); + if (notif == null) { + return false; + } + ready.put(notif.getKey(), notif.getValue()); + return true; + } + + /** + * Internal recursive resolution. Returns {@link TryResolveResult#SHORT_CIRCUITED} to signal early + * exit (a combinator completed and wants to propagate up). + * + *

This method mutates {@code unresolved} in place when children are removed (e.g. completed + * children are removed from AllCompleted lists). + */ + private TryResolveResult tryResolveFutureInternal(UnresolvedFuture unresolved) { + if (unresolved instanceof UnresolvedFuture.Single s) { + return new TryResolveResult(resolveHandleState(s.handle())); + + } else if (unresolved instanceof UnresolvedFuture.FirstCompleted fc) { + return resolveFirstCompleted(fc.children()); + + } else if (unresolved instanceof UnresolvedFuture.Unknown u) { + return resolveFirstCompleted(u.children()); + + } else if (unresolved instanceof UnresolvedFuture.AllCompleted ac) { + return resolveAllCompleted(ac.children()); + + } else if (unresolved instanceof UnresolvedFuture.FirstSucceededOrAllFailed fsaf) { + return resolveFirstSucceededOrAllFailed(fsaf.children()); + + } else if (unresolved instanceof UnresolvedFuture.AllSucceededOrFirstFailed asff) { + return resolveAllSucceededOrFirstFailed(asff.children()); + } + + throw new IllegalStateException("Unknown UnresolvedFuture type: " + unresolved); + } + + /** FirstCompleted / Unknown: resolve as soon as any child completes (success or failure). */ + private TryResolveResult resolveFirstCompleted(List children) { + for (UnresolvedFuture child : children) { + TryResolveResult childResult = tryResolveFutureInternal(child); + if (childResult == TryResolveResult.SHORT_CIRCUITED + || childResult.handleState().isCompleted()) { + children.clear(); + return TryResolveResult.SHORT_CIRCUITED; + } + } + return TryResolveResult.PENDING; + } + + /** AllCompleted: wait for every child to complete (success or failure). */ + private TryResolveResult resolveAllCompleted(List children) { + var it = children.listIterator(); + while (it.hasNext()) { + UnresolvedFuture child = it.next(); + TryResolveResult childResult = tryResolveFutureInternal(child); + if (childResult == TryResolveResult.SHORT_CIRCUITED) { + // A nested combinator short-circuited — propagate immediately + return TryResolveResult.SHORT_CIRCUITED; + } else if (childResult.handleState().isCompleted()) { + it.remove(); + } + } + if (children.isEmpty()) { + return new TryResolveResult(HandleState.SUCCEEDED); + } + return TryResolveResult.PENDING; + } + + /** FirstSucceededOrAllFailed: first success wins; fail only if all fail. */ + private TryResolveResult resolveFirstSucceededOrAllFailed(List children) { + var it = children.listIterator(); + while (it.hasNext()) { + UnresolvedFuture child = it.next(); + TryResolveResult childResult = tryResolveFutureInternal(child); + if (childResult == TryResolveResult.SHORT_CIRCUITED) { + // A nested combinator short-circuited — treat as succeeded, propagate + children.clear(); + return TryResolveResult.SHORT_CIRCUITED; + } + HandleState state = childResult.handleState(); + if (state == HandleState.SUCCEEDED) { + children.clear(); + return TryResolveResult.SHORT_CIRCUITED; + } else if (state == HandleState.FAILED) { + it.remove(); + } + } + if (children.isEmpty()) { + return new TryResolveResult(HandleState.FAILED); + } + return TryResolveResult.PENDING; + } + + /** AllSucceededOrFirstFailed: all must succeed; first failure short-circuits. */ + private TryResolveResult resolveAllSucceededOrFirstFailed(List children) { + var it = children.listIterator(); + while (it.hasNext()) { + UnresolvedFuture child = it.next(); + TryResolveResult childResult = tryResolveFutureInternal(child); + if (childResult == TryResolveResult.SHORT_CIRCUITED) { + // A nested combinator short-circuited — propagate immediately + return TryResolveResult.SHORT_CIRCUITED; + } + HandleState state = childResult.handleState(); + if (state == HandleState.FAILED) { + children.clear(); + return TryResolveResult.SHORT_CIRCUITED; + } else if (state == HandleState.SUCCEEDED) { + it.remove(); + } + } + if (children.isEmpty()) { + return new TryResolveResult(HandleState.SUCCEEDED); + } + return TryResolveResult.PENDING; + } + + private HandleState resolveHandleState(int handle) { + NotificationId id = handleMapping.get(handle); + if (id == null) { + return HandleState.PENDING; + } + NotificationValue val = ready.get(id); + if (val == null) { + return HandleState.PENDING; + } + return (val instanceof NotificationValue.Failure) ? HandleState.FAILED : HandleState.SUCCEEDED; + } + + /** After {@code take_handle} the mapping is gone, so unknown handles are treated as completed. */ public boolean isHandleCompleted(int handle) { NotificationId id = handleMapping.get(handle); - return id != null && ready.containsKey(id); + return id == null || ready.containsKey(id); } public boolean nonDeterministicFindId(NotificationId id) { @@ -128,4 +305,115 @@ public Optional takeHandle(int handle) { } return Optional.empty(); } + + public Optional copyHandle(int handle) { + NotificationId id = handleMapping.get(handle); + if (id == null) { + return Optional.empty(); + } + return Optional.ofNullable(ready.get(id)); + } + + /** + * Convert an {@link UnresolvedFuture} tree to the wire-format {@link Protocol.Future} message. + * Single children are inlined into the parent's waiting_* fields; all other children become + * nested Future messages. + */ + public Protocol.Future resolveUnresolvedFuture(UnresolvedFuture unresolved) { + var builder = Protocol.Future.newBuilder(); + + if (unresolved instanceof UnresolvedFuture.Single s) { + builder.setCombinatorType(Protocol.CombinatorType.FIRST_COMPLETED); + pushHandle(builder, s.handle()); + return builder.build(); + } + + List children; + if (unresolved instanceof UnresolvedFuture.Unknown u) { + builder.setCombinatorType(Protocol.CombinatorType.COMBINATOR_UNKNOWN); + children = u.children(); + } else if (unresolved instanceof UnresolvedFuture.FirstCompleted fc) { + builder.setCombinatorType(Protocol.CombinatorType.FIRST_COMPLETED); + children = fc.children(); + } else if (unresolved instanceof UnresolvedFuture.AllCompleted ac) { + builder.setCombinatorType(Protocol.CombinatorType.ALL_COMPLETED); + children = ac.children(); + } else if (unresolved instanceof UnresolvedFuture.FirstSucceededOrAllFailed fsaf) { + builder.setCombinatorType(Protocol.CombinatorType.FIRST_SUCCEEDED_OR_ALL_FAILED); + children = fsaf.children(); + } else if (unresolved instanceof UnresolvedFuture.AllSucceededOrFirstFailed asff) { + builder.setCombinatorType(Protocol.CombinatorType.ALL_SUCCEEDED_OR_FIRST_FAILED); + children = asff.children(); + } else { + throw new IllegalStateException("Unknown UnresolvedFuture type: " + unresolved); + } + + for (UnresolvedFuture child : children) { + if (child instanceof UnresolvedFuture.Single s) { + pushHandle(builder, s.handle()); + } else { + builder.addNestedFutures(resolveUnresolvedFuture(child)); + } + } + + return builder.build(); + } + + private void pushHandle(Protocol.Future.Builder builder, int handle) { + NotificationId id = handleMapping.get(handle); + if (id == null) { + return; + } + if (id instanceof NotificationId.CompletionId cid) { + builder.addWaitingCompletions(cid.id()); + } else if (id instanceof NotificationId.SignalId sid) { + builder.addWaitingSignals(sid.id()); + } else if (id instanceof NotificationId.SignalName sn) { + builder.addWaitingNamedSignals(sn.name()); + } + } + + // --- Inner types --- + + sealed interface ResolveFutureResult + permits ResolveFutureResult.AnyCompleted, ResolveFutureResult.WaitExternalInput { + + ResolveFutureResult ANY_COMPLETED = new AnyCompleted(); + + record AnyCompleted() implements ResolveFutureResult {} + + record WaitExternalInput(UnresolvedFuture remaining) implements ResolveFutureResult {} + } + + private enum HandleState { + SUCCEEDED, + FAILED, + PENDING; + + boolean isCompleted() { + return this == SUCCEEDED || this == FAILED; + } + } + + /** + * Wrapper for the internal resolution result. A sentinel {@link #SHORT_CIRCUITED} value signals + * that a nested combinator completed and the loop should stop. + */ + private static final class TryResolveResult { + static final TryResolveResult SHORT_CIRCUITED = new TryResolveResult(null); + static final TryResolveResult PENDING = new TryResolveResult(HandleState.PENDING); + + private final HandleState state; + + private TryResolveResult(HandleState state) { + this.state = state; + } + + HandleState handleState() { + if (state == null) { + throw new IllegalStateException("SHORT_CIRCUITED has no HandleState"); + } + return state; + } + } } diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/MessageType.java b/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/MessageType.java index b4dd6e10f..f207b1d51 100644 --- a/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/MessageType.java +++ b/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/MessageType.java @@ -19,6 +19,7 @@ public enum MessageType { ErrorMessage, EndMessage, ProposeRunCompletionMessage, + AwaitingOnMessage, InputCommandMessage, OutputCommandMessage, @@ -58,6 +59,7 @@ public enum MessageType { public static final short ErrorMessage_TYPE = (short) 0x0002; public static final short EndMessage_TYPE = (short) 0x0003; public static final short ProposeRunCompletionMessage_TYPE = (short) 0x0005; + public static final short AwaitingOnMessage_TYPE = (short) 0x0006; public static final short InputCommandMessage_TYPE = (short) 0x0400; public static final short OutputCommandMessage_TYPE = (short) 0x0401; public static final short GetLazyStateCommandMessage_TYPE = (short) 0x0402; @@ -98,6 +100,7 @@ public Parser messageParser() { case ErrorMessage -> Protocol.ErrorMessage.parser(); case EndMessage -> Protocol.EndMessage.parser(); case ProposeRunCompletionMessage -> Protocol.ProposeRunCompletionMessage.parser(); + case AwaitingOnMessage -> Protocol.AwaitingOnMessage.parser(); case InputCommandMessage -> Protocol.InputCommandMessage.parser(); case OutputCommandMessage -> Protocol.OutputCommandMessage.parser(); case GetLazyStateCommandMessage -> Protocol.GetLazyStateCommandMessage.parser(); @@ -141,6 +144,7 @@ public short encode() { case ErrorMessage -> ErrorMessage_TYPE; case EndMessage -> EndMessage_TYPE; case ProposeRunCompletionMessage -> ProposeRunCompletionMessage_TYPE; + case AwaitingOnMessage -> AwaitingOnMessage_TYPE; case InputCommandMessage -> InputCommandMessage_TYPE; case OutputCommandMessage -> OutputCommandMessage_TYPE; case GetLazyStateCommandMessage -> GetLazyStateCommandMessage_TYPE; @@ -236,6 +240,7 @@ public static MessageType decode(short value) throws ProtocolException { case ErrorMessage_TYPE -> ErrorMessage; case EndMessage_TYPE -> EndMessage; case ProposeRunCompletionMessage_TYPE -> ProposeRunCompletionMessage; + case AwaitingOnMessage_TYPE -> AwaitingOnMessage; case InputCommandMessage_TYPE -> InputCommandMessage; case OutputCommandMessage_TYPE -> OutputCommandMessage; case GetLazyStateCommandMessage_TYPE -> GetLazyStateCommandMessage; @@ -290,6 +295,8 @@ public static MessageType fromMessage(MessageLite msg) { return MessageType.EndMessage; } else if (msg instanceof Protocol.ProposeRunCompletionMessage) { return MessageType.ProposeRunCompletionMessage; + } else if (msg instanceof Protocol.AwaitingOnMessage) { + return MessageType.AwaitingOnMessage; } else if (msg instanceof Protocol.InputCommandMessage) { return MessageType.InputCommandMessage; } else if (msg instanceof Protocol.OutputCommandMessage) { diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/ProcessingState.java b/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/ProcessingState.java index 5300444d0..3efc73f19 100644 --- a/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/ProcessingState.java +++ b/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/ProcessingState.java @@ -20,9 +20,9 @@ import dev.restate.sdk.core.ExceptionUtils; import dev.restate.sdk.core.ProtocolException; import dev.restate.sdk.core.generated.protocol.Protocol; -import dev.restate.sdk.core.statemachine.StateMachine.DoProgressResponse; +import dev.restate.sdk.core.statemachine.AsyncResultsState.ResolveFutureResult; +import dev.restate.sdk.core.statemachine.StateMachine.AwaitResponse; import java.time.Duration; -import java.util.List; import java.util.Optional; import java.util.concurrent.CompletableFuture; import org.apache.logging.log4j.LogManager; @@ -61,35 +61,43 @@ public void onNewMessage( } @Override - public DoProgressResponse doProgress(List awaitingOn, StateContext stateContext) { - if (awaitingOn.stream().anyMatch(this.asyncResultsState::isHandleCompleted)) { - return DoProgressResponse.AnyCompleted.INSTANCE; - } + public AwaitResponse doAwait(UnresolvedFuture future, StateContext stateContext) { + ResolveFutureResult resolveResult = asyncResultsState.tryResolveFuture(future); - var notificationIds = asyncResultsState.resolveNotificationHandles(awaitingOn); - if (notificationIds.isEmpty()) { - return DoProgressResponse.AnyCompleted.INSTANCE; + if (resolveResult instanceof ResolveFutureResult.AnyCompleted) { + return AwaitResponse.AnyCompleted.INSTANCE; } - if (asyncResultsState.processNextUntilAnyFound(notificationIds)) { - return DoProgressResponse.AnyCompleted.INSTANCE; - } + var remaining = ((ResolveFutureResult.WaitExternalInput) resolveResult).remaining(); + var awaitingOnHandles = remaining.handles(); - Integer maybeRunHandle = runState.tryExecuteRun(awaitingOn); + Integer maybeRunHandle = runState.tryExecuteRun(awaitingOnHandles); if (maybeRunHandle != null) { - return new DoProgressResponse.ExecuteRun(maybeRunHandle); + return new AwaitResponse.ExecuteRun(maybeRunHandle); } + boolean waitingRunProposal = runState.anyExecutingInThisSet(awaitingOnHandles); + if (stateContext.isInputClosed()) { - if (runState.anyExecuting(awaitingOn)) { - return DoProgressResponse.WaitingPendingRun.INSTANCE; + if (waitingRunProposal) { + return new AwaitResponse.WaitingExternalProgress(false, true); } - - this.hitSuspended(notificationIds, stateContext); + this.hitSuspended(remaining, asyncResultsState, stateContext); ExceptionUtils.sneakyThrow(AbortedExecutionException.INSTANCE); } - return DoProgressResponse.ReadFromInput.INSTANCE; + // Send AwaitingOnMessage if protocol version >= V7 + if (stateContext.getNegotiatedProtocolVersion().getNumber() + >= Protocol.ServiceProtocolVersion.V7_VALUE + && !runState.anyExecuting()) { + stateContext.maybeWriteMessageOut( + Protocol.AwaitingOnMessage.newBuilder() + .setAwaitingOn(asyncResultsState.resolveUnresolvedFuture(remaining)) + .setExecutingSideEffects(false) + .build()); + } + + return new AwaitResponse.WaitingExternalProgress(true, waitingRunProposal); } @Override diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/ReplayingState.java b/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/ReplayingState.java index c88974797..052ebcd02 100644 --- a/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/ReplayingState.java +++ b/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/ReplayingState.java @@ -17,7 +17,8 @@ import dev.restate.sdk.core.ExceptionUtils; import dev.restate.sdk.core.ProtocolException; import dev.restate.sdk.core.generated.protocol.Protocol; -import dev.restate.sdk.core.statemachine.StateMachine.DoProgressResponse; +import dev.restate.sdk.core.statemachine.AsyncResultsState.ResolveFutureResult; +import dev.restate.sdk.core.statemachine.StateMachine.AwaitResponse; import java.util.*; import java.util.concurrent.CompletableFuture; import java.util.stream.Collectors; @@ -87,19 +88,14 @@ public void onNewMessage( } @Override - public DoProgressResponse doProgress(List awaitingOn, StateContext stateContext) { - if (awaitingOn.stream().anyMatch(this.asyncResultsState::isHandleCompleted)) { - return DoProgressResponse.AnyCompleted.INSTANCE; - } + public AwaitResponse doAwait(UnresolvedFuture future, StateContext stateContext) { + ResolveFutureResult resolveResult = asyncResultsState.tryResolveFuture(future); - var notificationIds = asyncResultsState.resolveNotificationHandles(awaitingOn); - if (notificationIds.isEmpty()) { - return DoProgressResponse.AnyCompleted.INSTANCE; + if (resolveResult instanceof ResolveFutureResult.AnyCompleted) { + return AwaitResponse.AnyCompleted.INSTANCE; } - if (asyncResultsState.processNextUntilAnyFound(notificationIds)) { - return DoProgressResponse.AnyCompleted.INSTANCE; - } + var remaining = ((ResolveFutureResult.WaitExternalInput) resolveResult).remaining(); // This assertion proves the user mutated the code, adding an await point. // @@ -113,12 +109,15 @@ public DoProgressResponse doProgress(List awaitingOn, StateContext stat // This contradiction proves the code was mutated: an await must have been added after // the journal was originally created. + var awaitingOnHandles = remaining.handles(); + var notificationIds = asyncResultsState.resolveNotificationHandles(awaitingOnHandles); + // Prepare error metadata to make it easier to debug Map knownNotificationMetadata = new HashMap<>(); CommandRelationship relatedCommand = null; // Collect run info - for (int handle : awaitingOn) { + for (int handle : awaitingOnHandles) { RunState.Run runInfo = runState.getRunInfo(handle); if (runInfo != null) { var notifId = asyncResultsState.mustResolveNotificationHandle(handle); diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/RunState.java b/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/RunState.java index dedae9b66..0fe385113 100644 --- a/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/RunState.java +++ b/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/RunState.java @@ -36,11 +36,15 @@ public void insertRunToExecute(int handle, int commandIndex, String commandName) return runs.get(handle); } - public boolean anyExecuting(Collection anyHandle) { + public boolean anyExecutingInThisSet(Collection anyHandle) { return anyHandle.stream() .anyMatch(h -> runs.containsKey(h) && runs.get(h).state == RunStateInner.Executing); } + public boolean anyExecuting() { + return runs.values().stream().anyMatch(r -> r.state == RunStateInner.Executing); + } + /** * Notifies that execution has completed for the given handle. * diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/ServiceProtocol.java b/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/ServiceProtocol.java index 214fb0e2b..0afb476eb 100644 --- a/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/ServiceProtocol.java +++ b/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/ServiceProtocol.java @@ -15,7 +15,7 @@ public class ServiceProtocol { public static final Protocol.ServiceProtocolVersion MIN_SERVICE_PROTOCOL_VERSION = Protocol.ServiceProtocolVersion.V5; public static final Protocol.ServiceProtocolVersion MAX_SERVICE_PROTOCOL_VERSION = - Protocol.ServiceProtocolVersion.V6; + Protocol.ServiceProtocolVersion.V7; static final String CONTENT_TYPE = "content-type"; @@ -43,6 +43,9 @@ static Protocol.ServiceProtocolVersion parseServiceProtocolVersion(String versio if (version.equals("application/vnd.restate.invocation.v6")) { return Protocol.ServiceProtocolVersion.V6; } + if (version.equals("application/vnd.restate.invocation.v7")) { + return Protocol.ServiceProtocolVersion.V7; + } return Protocol.ServiceProtocolVersion.SERVICE_PROTOCOL_VERSION_UNSPECIFIED; } @@ -65,6 +68,9 @@ static String serviceProtocolVersionToHeaderValue(Protocol.ServiceProtocolVersio if (Objects.requireNonNull(version) == Protocol.ServiceProtocolVersion.V6) { return "application/vnd.restate.invocation.v6"; } + if (Objects.requireNonNull(version) == Protocol.ServiceProtocolVersion.V7) { + return "application/vnd.restate.invocation.v7"; + } throw new IllegalArgumentException( String.format("Service protocol version '%s' has no header value", version.getNumber())); } diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/State.java b/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/State.java index 5c8cb6758..fd33694a0 100644 --- a/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/State.java +++ b/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/State.java @@ -17,8 +17,6 @@ import java.io.PrintWriter; import java.io.StringWriter; import java.time.Duration; -import java.util.Collection; -import java.util.List; import java.util.Optional; import java.util.concurrent.CompletableFuture; import org.apache.logging.log4j.LogManager; @@ -41,8 +39,7 @@ default void onNewMessage( throw ProtocolException.badState(this); } - default StateMachine.DoProgressResponse doProgress( - List anyHandle, StateContext stateContext) { + default StateMachine.AwaitResponse doAwait(UnresolvedFuture future, StateContext stateContext) { throw ProtocolException.badState(this); } @@ -162,27 +159,45 @@ default void hitError( stateContext.closeOutputSubscriber(); } - default void hitSuspended(Collection awaitingOn, StateContext stateContext) { + default void hitSuspended( + UnresolvedFuture awaitingOn, AsyncResultsState asyncResultsState, StateContext stateContext) { LOG.info("Invocation suspended"); LOG.debug("Awaiting on {}", awaitingOn); - var suspensionMessageBuilder = Protocol.SuspensionMessage.newBuilder(); - for (var notificationId : awaitingOn) { - if (notificationId instanceof NotificationId.CompletionId completionId) { - suspensionMessageBuilder.addWaitingCompletions(completionId.id()); - } else if (notificationId instanceof NotificationId.SignalId signalId) { - suspensionMessageBuilder.addWaitingSignals(signalId.id()); - } else if (notificationId instanceof NotificationId.SignalName signalName) { - suspensionMessageBuilder.addWaitingNamedSignals(signalName.name()); - } + Protocol.SuspensionMessage suspensionMessage; + if (stateContext.getNegotiatedProtocolVersion().getNumber() + >= Protocol.ServiceProtocolVersion.V7_VALUE) { + var future = asyncResultsState.resolveUnresolvedFuture(awaitingOn); + suspensionMessage = Protocol.SuspensionMessage.newBuilder().setAwaitingOn(future).build(); + } else { + // V6 format: flatten the future tree into the flat waiting_* lists + suspensionMessage = buildV6SuspensionMessage(awaitingOn, asyncResultsState); } - stateContext.maybeWriteMessageOut(suspensionMessageBuilder.build()); + stateContext.maybeWriteMessageOut(suspensionMessage); stateContext.getStateHolder().transition(new ClosedState()); stateContext.closeOutputSubscriber(); } + private static Protocol.SuspensionMessage buildV6SuspensionMessage( + UnresolvedFuture awaitingOn, AsyncResultsState asyncResultsState) { + var builder = Protocol.SuspensionMessage.newBuilder(); + for (int handle : awaitingOn.handles()) { + var notifId = asyncResultsState.resolveNotificationHandles(java.util.List.of(handle)); + for (var id : notifId) { + if (id instanceof NotificationId.CompletionId c) { + builder.addWaitingCompletions(c.id()); + } else if (id instanceof NotificationId.SignalId s) { + builder.addWaitingSignals(s.id()); + } else if (id instanceof NotificationId.SignalName n) { + builder.addWaitingNamedSignals(n.name()); + } + } + } + return builder.build(); + } + default void end(StateContext stateContext) { LOG.info("Invocation ended"); diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/StateMachine.java b/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/StateMachine.java index 14d810c11..e555a1659 100644 --- a/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/StateMachine.java +++ b/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/StateMachine.java @@ -15,7 +15,6 @@ import dev.restate.sdk.endpoint.HeadersAccessor; import java.time.Duration; import java.util.Collection; -import java.util.List; import java.util.Map; import java.util.Optional; import java.util.concurrent.CompletableFuture; @@ -48,23 +47,18 @@ static StateMachine init( // --- Async results - sealed interface DoProgressResponse { - record AnyCompleted() implements DoProgressResponse { + sealed interface AwaitResponse { + record AnyCompleted() implements AwaitResponse { static AnyCompleted INSTANCE = new AnyCompleted(); } - record ReadFromInput() implements DoProgressResponse { - static ReadFromInput INSTANCE = new ReadFromInput(); - } - - record ExecuteRun(int handle) implements DoProgressResponse {} + record WaitingExternalProgress(boolean waitingInput, boolean waitingRunProposal) + implements AwaitResponse {} - record WaitingPendingRun() implements DoProgressResponse { - static WaitingPendingRun INSTANCE = new WaitingPendingRun(); - } + record ExecuteRun(int handle) implements AwaitResponse {} } - DoProgressResponse doProgress(List anyHandle); + AwaitResponse doAwait(UnresolvedFuture future); boolean isCompleted(int handle); diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/StateMachineImpl.java b/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/StateMachineImpl.java index 5d9a5ddfd..94051522a 100644 --- a/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/StateMachineImpl.java +++ b/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/StateMachineImpl.java @@ -178,8 +178,13 @@ public String getResponseContentType() { } @Override - public DoProgressResponse doProgress(List anyHandle) { - return this.stateContext.getCurrentState().doProgress(anyHandle, this.stateContext); + public StateMachine.AwaitResponse doAwait(UnresolvedFuture future) { + // Wrap with cancel signal for implicit cancellation support + var futureWithCancellation = + new UnresolvedFuture.FirstCompleted( + List.of( + future, new UnresolvedFuture.Single(AsyncResultsState.CANCEL_NOTIFICATION_HANDLE))); + return this.stateContext.getCurrentState().doAwait(futureWithCancellation, this.stateContext); } @Override diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/UnresolvedFuture.java b/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/UnresolvedFuture.java new file mode 100644 index 000000000..ed55c4aa5 --- /dev/null +++ b/sdk-core/src/main/java/dev/restate/sdk/core/statemachine/UnresolvedFuture.java @@ -0,0 +1,57 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk.core.statemachine; + +import java.util.ArrayList; +import java.util.List; + +/** Represents an unresolved future tree that the SDK is awaiting on. */ +public sealed interface UnresolvedFuture + permits UnresolvedFuture.Single, + UnresolvedFuture.FirstCompleted, + UnresolvedFuture.AllCompleted, + UnresolvedFuture.FirstSucceededOrAllFailed, + UnresolvedFuture.AllSucceededOrFirstFailed, + UnresolvedFuture.Unknown { + + record Single(int handle) implements UnresolvedFuture {} + + record FirstCompleted(List children) implements UnresolvedFuture {} + + record AllCompleted(List children) implements UnresolvedFuture {} + + record FirstSucceededOrAllFailed(List children) implements UnresolvedFuture {} + + record AllSucceededOrFirstFailed(List children) implements UnresolvedFuture {} + + record Unknown(List children) implements UnresolvedFuture {} + + /** Collect all leaf handles from this future tree. */ + default List handles() { + var result = new ArrayList(); + collectHandles(this, result); + return result; + } + + private static void collectHandles(UnresolvedFuture fut, List out) { + if (fut instanceof Single s) { + out.add(s.handle()); + } else if (fut instanceof FirstCompleted fc) { + fc.children().forEach(c -> collectHandles(c, out)); + } else if (fut instanceof AllCompleted ac) { + ac.children().forEach(c -> collectHandles(c, out)); + } else if (fut instanceof FirstSucceededOrAllFailed fsaf) { + fsaf.children().forEach(c -> collectHandles(c, out)); + } else if (fut instanceof AllSucceededOrFirstFailed asff) { + asff.children().forEach(c -> collectHandles(c, out)); + } else if (fut instanceof Unknown u) { + u.children().forEach(c -> collectHandles(c, out)); + } + } +} diff --git a/sdk-core/src/main/service-protocol/dev/restate/service/protocol.proto b/sdk-core/src/main/service-protocol/dev/restate/service/protocol.proto index 0a6696533..49224e589 100644 --- a/sdk-core/src/main/service-protocol/dev/restate/service/protocol.proto +++ b/sdk-core/src/main/service-protocol/dev/restate/service/protocol.proto @@ -36,6 +36,9 @@ enum ServiceProtocolVersion { // * StartMessage.random_seed // * Failure.metadata V6 = 6; + // Added: + // * Future & AwaitingOnMessage + Changed the SuspensionMessage + V7 = 7; } // --- Core frames --- @@ -83,16 +86,46 @@ message StartMessage { uint64 random_seed = 9; } +// Defines how a set of child futures are combined. +enum CombinatorType { + // Should be treated as FIRST_COMPLETED. + COMBINATOR_UNKNOWN = 0; + // Resolve as soon as any one child future completes with success, or with failure (same as JS Promise.race). + FIRST_COMPLETED = 1; + // Wait for every child to complete, regardless of success or failure (same as JS Promise.allSettled). + ALL_COMPLETED = 2; + // Resolve on the first success; fail only if all children fail (same as JS Promise.any). + FIRST_SUCCEEDED_OR_ALL_FAILED = 3; + // Resolve when all children succeed; short-circuit on the first failure (same as JS Promise.all). + ALL_SUCCEEDED_OR_FIRST_FAILED = 4; +} + +// Recursively describes an await point as a tree of future combinators. +// +// Leaf data is the set of notification IDs this node is waiting for. +// For representation purposes, the list of notification ids is flattened in 3 lists to avoid a per-element oneof wrapper. +// Inner nodes combine their children (leaves + nested) via `combinator_type`. +message Future { + repeated uint32 waiting_completions = 1; + repeated uint32 waiting_signals = 2; + repeated string waiting_named_signals = 3; + repeated Future nested_futures = 4; + CombinatorType combinator_type = 5; +} + // Type: 0x0000 + 1 // Implementations MUST send this message when suspending an invocation. // -// These lists represent any of the notification_idx and/or notification_name the invocation is waiting on to progress. -// The runtime will resume the invocation as soon as either one of the given notification_idx or notification_name is completed. -// Between the two lists there MUST be at least one element. +// V6 and earlier: populate waiting_completions/signals/named_signals (fields 1-3). +// V7 and later: populate awaiting_on (field 4) with the full future tree. +// The field numbers are disjoint, so both formats can coexist in the same message definition. message SuspensionMessage { + // V6: flat lists of notification IDs the invocation is waiting on. repeated uint32 waiting_completions = 1; repeated uint32 waiting_signals = 2; repeated string waiting_named_signals = 3; + // V7: full future tree describing the await point. + Future awaiting_on = 4; } // Type: 0x0000 + 2 @@ -142,6 +175,16 @@ message ProposeRunCompletionMessage { }; } +// Type: 0x0000 + 6 +// The SDK MAY send this message to the runtime when inside an await point, to notify what the user code is currently blocked awaiting on. +// This information SHOULD be considered outdated by the runtime as soon as a notification in the future tree is sent over. +message AwaitingOnMessage { + // Describes the await point. + Future awaiting_on = 1; + // True if any of the notifications the SDK is awaiting on are side effects the SDK is currently executing. + bool executing_side_effects = 2; +} + // --- Commands and Notifications --- // The Journal is modelled as commands and notifications. diff --git a/sdk-core/src/test/java/dev/restate/sdk/core/CallTestSuite.java b/sdk-core/src/test/java/dev/restate/sdk/core/CallTestSuite.java index 511edeb86..be4f79fc2 100644 --- a/sdk-core/src/test/java/dev/restate/sdk/core/CallTestSuite.java +++ b/sdk-core/src/test/java/dev/restate/sdk/core/CallTestSuite.java @@ -61,7 +61,9 @@ public Stream definitions() { callCmd(1, 2, GREETER_SERVICE_TARGET, BODY.toByteArray()), CANCELLATION_SIGNAL) .onlyBidiStream() - .expectingOutput(Protocol.SuspensionMessage.newBuilder().addWaitingCompletions(1)) + // Cancel handle was consumed before suspension, so no waiting_signals here + .expectingOutput( + Protocol.SuspensionMessage.newBuilder().addWaitingCompletions(1).build()) .named("Suspends on waiting the invocation id"), implicitCancellation(GREETER_SERVICE_TARGET, BODY) .withInput(