diff --git a/src/ChibiRuby/RFiber.cs b/src/ChibiRuby/RFiber.cs index 6b1f6f8..a4019de 100644 --- a/src/ChibiRuby/RFiber.cs +++ b/src/ChibiRuby/RFiber.cs @@ -28,6 +28,8 @@ public sealed class RFiber : RObject readonly MRubyContext context = new(); readonly MRubyState state; readonly MultiConsumerValueTaskNotifier resumeSource = new(); + readonly TaskCompletionSource terminationSource = + new(TaskCreationOptions.RunContinuationsAsynchronously); internal RFiber(MRubyState state, RClass c) : base(MRubyVType.Fiber, c) { @@ -48,17 +50,17 @@ public ValueTask WaitForResumeAsync(CancellationToken cancellation = return resumeSource.WaitAsync(cancellation); } - public async ValueTask WaitForTerminateAsync(CancellationToken cancellation = default) + public ValueTask WaitForTerminateAsync(CancellationToken cancellation = default) { - // Wait for fiber completion - MRubyValue result = default; - while (IsAlive) + var task = terminationSource.Task; + if (cancellation.CanBeCanceled) { - var wait = WaitForResumeAsync(cancellation); - if (wait.IsCompleted) continue; - result = await wait; + cancellation.Register(() => + { + terminationSource.TrySetCanceled(cancellation); + }); } - return result; + return new ValueTask(task); } public async IAsyncEnumerable AsAsyncEnumerable([EnumeratorCancellation] CancellationToken cancellation = default) @@ -302,11 +304,19 @@ internal MRubyValue MoveNext(ReadOnlySpan args, bool transfer, bool if (pending is not null) state.Raise(pending); } + if (context.State == FiberState.Terminated) + { + terminationSource.TrySetResult(result); + } resumeSource.SetResult(result); return result; } catch (Exception ex) { + if (context.State == FiberState.Terminated) + { + terminationSource.TrySetException(ex); + } resumeSource.SetException(ex); throw; }