diff --git a/src/RestSharp/AsyncHelpers.cs b/src/RestSharp/AsyncHelpers.cs index 5d3db9db4..c520d0fa8 100644 --- a/src/RestSharp/AsyncHelpers.cs +++ b/src/RestSharp/AsyncHelpers.cs @@ -25,16 +25,7 @@ static class AsyncHelpers { /// /// Callback for asynchronous task to run static void RunSync(Func task) { - var currentContext = SynchronizationContext.Current; - var customContext = new CustomSynchronizationContext(task); - - try { - SynchronizationContext.SetSynchronizationContext(customContext); - customContext.Run(); - } - finally { - SynchronizationContext.SetSynchronizationContext(currentContext); - } + CustomSynchronizationContext.Run(task); } /// @@ -63,7 +54,7 @@ class CustomSynchronizationContext : SynchronizationContext { /// Constructor for the custom context /// /// Task to execute - public CustomSynchronizationContext(Func task) => + private CustomSynchronizationContext(Func task) => _task = task ?? throw new ArgumentNullException(nameof(task), "Please remember to pass a Task to be executed"); /// @@ -79,27 +70,40 @@ public override void Post(SendOrPostCallback function, object? state) { /// /// Enqueues the function to be executed and executes all resulting continuations until it is completely done /// - public void Run() { - Post(PostCallback, null); - - while (!_done) { - if (_items.TryDequeue(out var task)) { - task.Item1(task.Item2); - if (_caughtException == null) { - continue; + private void Run() { + var currentContext = SynchronizationContext.Current; + + try { + SynchronizationContext.SetSynchronizationContext(this); + + Post(PostCallback, null); + + while (!_done) { + if (_items.TryDequeue(out var task)) { + task.Item1(task.Item2); + if (_caughtException == null) { + continue; + } + _caughtException.Throw(); + } + else { + _workItemsWaiting.WaitOne(); } - _caughtException.Throw(); - } - else { - _workItemsWaiting.WaitOne(); } } + finally { + SynchronizationContext.SetSynchronizationContext(currentContext); + } + return; + // This method is only called from within this custom context for the initial task. async void PostCallback(object? _) { try { - await _task().ConfigureAwait(false); + // Do not call ConfigureAwait(false) here to ensure all continuations are + // queued on this context, not the thread pool. + await _task(); } catch (Exception exception) { _caughtException = ExceptionDispatchInfo.Capture(exception); @@ -111,6 +115,12 @@ async void PostCallback(object? _) { } } + public static void Run(Func task) { + var customContext = new CustomSynchronizationContext(task); + + customContext.Run(); + } + /// /// When overridden in a derived class, dispatches a synchronous message to a synchronization context. ///