diff --git a/src/RestSharp/AsyncHelpers.cs b/src/RestSharp/AsyncHelpers.cs
index 5d3db9db4..c520d0fa8 100644
--- a/src/RestSharp/AsyncHelpers.cs
+++ b/src/RestSharp/AsyncHelpers.cs
@@ -25,16 +25,7 @@ static class AsyncHelpers {
///
/// Callback for asynchronous task to run
static void RunSync(Func task) {
- var currentContext = SynchronizationContext.Current;
- var customContext = new CustomSynchronizationContext(task);
-
- try {
- SynchronizationContext.SetSynchronizationContext(customContext);
- customContext.Run();
- }
- finally {
- SynchronizationContext.SetSynchronizationContext(currentContext);
- }
+ CustomSynchronizationContext.Run(task);
}
///
@@ -63,7 +54,7 @@ class CustomSynchronizationContext : SynchronizationContext {
/// Constructor for the custom context
///
/// Task to execute
- public CustomSynchronizationContext(Func task) =>
+ private CustomSynchronizationContext(Func task) =>
_task = task ?? throw new ArgumentNullException(nameof(task), "Please remember to pass a Task to be executed");
///
@@ -79,27 +70,40 @@ public override void Post(SendOrPostCallback function, object? state) {
///
/// Enqueues the function to be executed and executes all resulting continuations until it is completely done
///
- public void Run() {
- Post(PostCallback, null);
-
- while (!_done) {
- if (_items.TryDequeue(out var task)) {
- task.Item1(task.Item2);
- if (_caughtException == null) {
- continue;
+ private void Run() {
+ var currentContext = SynchronizationContext.Current;
+
+ try {
+ SynchronizationContext.SetSynchronizationContext(this);
+
+ Post(PostCallback, null);
+
+ while (!_done) {
+ if (_items.TryDequeue(out var task)) {
+ task.Item1(task.Item2);
+ if (_caughtException == null) {
+ continue;
+ }
+ _caughtException.Throw();
+ }
+ else {
+ _workItemsWaiting.WaitOne();
}
- _caughtException.Throw();
- }
- else {
- _workItemsWaiting.WaitOne();
}
}
+ finally {
+ SynchronizationContext.SetSynchronizationContext(currentContext);
+ }
+
return;
+ // This method is only called from within this custom context for the initial task.
async void PostCallback(object? _) {
try {
- await _task().ConfigureAwait(false);
+ // Do not call ConfigureAwait(false) here to ensure all continuations are
+ // queued on this context, not the thread pool.
+ await _task();
}
catch (Exception exception) {
_caughtException = ExceptionDispatchInfo.Capture(exception);
@@ -111,6 +115,12 @@ async void PostCallback(object? _) {
}
}
+ public static void Run(Func task) {
+ var customContext = new CustomSynchronizationContext(task);
+
+ customContext.Run();
+ }
+
///
/// When overridden in a derived class, dispatches a synchronous message to a synchronization context.
///