Skip to content

Commit a2cbbcd

Browse files
committed
Adjust code changes to the main launcher
1 parent 55358d4 commit a2cbbcd

4 files changed

Lines changed: 204 additions & 69 deletions

File tree

SharedStatic.V1Ext_Update4.cs

Lines changed: 41 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,9 @@ public partial class SharedStaticV1Ext
1313
/// Registers the function of the Speed Throttler Service from the Main Application.
1414
/// To disable the speed throttler service, set the <paramref name="addBytesOrWaitAsyncCallback"/> to <see cref="nint.Zero"/>.
1515
/// </summary>
16-
/// <param name="addBytesOrWaitAsyncCallback">The address of the speed throttler service callback</param>
17-
internal delegate HResult RegisterSpeedThrottlerServiceDelegate(nint addBytesOrWaitAsyncCallback);
16+
/// <param name="addBytesOrWaitAsyncCallback">The address of the adder wait callback</param>
17+
/// <param name="addBytesOrWaitAsyncCallback">The address of the current speed getter callback</param>
18+
internal delegate HResult RegisterSpeedThrottlerServiceDelegate(nint addBytesOrWaitAsyncCallback, nint getSharedThrottleBytesCallback);
1819
}
1920

2021
public partial class SharedStaticV1Ext<T>
@@ -37,53 +38,61 @@ private static void InitExtension_Update4Exports()
3738
/// <summary>
3839
/// This method is an ABI proxy and installer for the Speed Throttler Service functionality. To use Speed Throttler Service functionalities, Use <see cref="SpeedLimiterService"/> instead.
3940
/// </summary>
40-
private static unsafe HResult RegisterSpeedThrottlerService(nint addBytesOrWaitAsyncCallback)
41+
private static unsafe HResult RegisterSpeedThrottlerService(nint addBytesOrWaitAsyncCallback, nint getSharedThrottleBytesCallback)
4142
{
42-
SpeedLimiterService.AddBytesOrWaitAsyncCallback = (delegate* unmanaged[Stdcall]<nint, long, nint, out nint, int>)addBytesOrWaitAsyncCallback;
43-
if (addBytesOrWaitAsyncCallback == nint.Zero)
43+
if (addBytesOrWaitAsyncCallback == nint.Zero && getSharedThrottleBytesCallback == nint.Zero)
4444
{
4545
SpeedLimiterService.AddBytesOrWaitAsyncCallback = null;
46+
SpeedLimiterService.GetSharedThrottleBytesCallback = null;
4647
InstanceLogger.LogTrace("[RegisterSpeedThrottlerService] Speed Throttler Service has been uninstalled");
4748

4849
return HResult.Ok;
4950
}
5051

51-
// Test the delegate first before registering it.
52-
nint contextP = SpeedLimiterService.CreateServiceContext();
53-
try
52+
if (addBytesOrWaitAsyncCallback != nint.Zero && getSharedThrottleBytesCallback != nint.Zero)
5453
{
55-
// Try call the increment function with 16 bytes of load, then check for exception
54+
SpeedLimiterService.AddBytesOrWaitAsyncCallback = (delegate* unmanaged[Cdecl]<nint, long, nint, out nint, int>)addBytesOrWaitAsyncCallback;
55+
56+
// Test the delegate first before registering it.
57+
nint contextP = SpeedLimiterService.CreateServiceContext();
58+
try
59+
{
60+
// Try call the increment function with 16 bytes of load, then check for exception
5661
#pragma warning disable CA2012
57-
ValueTask task = SpeedLimiterService.AddBytesOrWaitAsync(contextP, 16);
62+
ValueTask task = SpeedLimiterService.AddBytesOrWaitAsync(contextP, 16);
5863
#pragma warning restore CA2012
59-
if (task is { IsCompleted: true, IsFaulted: false } ||
60-
task.IsCompletedSuccessfully)
61-
{
62-
LogSuccess();
63-
return HResult.Ok;
64-
}
64+
if (task is { IsCompleted: true, IsFaulted: false } ||
65+
task.IsCompletedSuccessfully)
66+
{
67+
LogSuccess();
68+
return HResult.Ok;
69+
}
6570

66-
// Try block and await if task is still going.
67-
task.GetAwaiter().GetResult();
71+
// Try block and await if task is still going.
72+
task.GetAwaiter().GetResult();
6873

69-
// If nothing blown up, return OK.
70-
LogSuccess();
71-
return HResult.Ok;
74+
// If nothing blown up, return OK.
75+
LogSuccess();
76+
return HResult.Ok;
7277

73-
void LogSuccess()
78+
void LogSuccess()
79+
{
80+
InstanceLogger.LogTrace("[RegisterSpeedThrottlerService] Speed Throttler Service has been installed. Service's callback is located at address: 0x{Ptr:x8}", addBytesOrWaitAsyncCallback);
81+
}
82+
}
83+
catch (Exception ex)
7484
{
75-
InstanceLogger.LogTrace("[RegisterSpeedThrottlerService] Speed Throttler Service has been installed. Service's callback is located at address: 0x{Ptr:x8}", addBytesOrWaitAsyncCallback);
85+
SpeedLimiterService.AddBytesOrWaitAsyncCallback = null; // Reset the callback
86+
return Marshal.GetHRForException(ex);
87+
}
88+
finally
89+
{
90+
SpeedLimiterService.FreeServiceContext(contextP);
7691
}
7792
}
78-
catch (Exception ex)
79-
{
80-
SpeedLimiterService.AddBytesOrWaitAsyncCallback = null; // Reset the callback
81-
return Marshal.GetHRForException(ex);
82-
}
83-
finally
84-
{
85-
SpeedLimiterService.FreeServiceContext(contextP);
86-
}
93+
94+
InstanceLogger.LogError("[RegisterSpeedThrottlerService] Failed to install/uninstall Speed Throttler Service. You must provide both arguments either all null or not-null!");
95+
return 0x80070057; // ERROR_INVALID_PARAMETER
8796
}
8897
#endregion
8998
}

Utility/RetryableCopyToStreamTask.cs

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,9 @@ public async ValueTask DisposeAsync()
7575
GC.SuppressFinalize(this);
7676
}
7777

78-
private RetryableCopyToStreamTask(SourceStreamFactory sourceStreamFactory, Stream targetStream, RetryableCopyToStreamTaskOptions options)
78+
private RetryableCopyToStreamTask(SourceStreamFactory sourceStreamFactory,
79+
Stream targetStream,
80+
RetryableCopyToStreamTaskOptions options)
7981
{
8082
_sourceStreamFactory = sourceStreamFactory;
8183
_targetStream = targetStream;
@@ -135,7 +137,9 @@ private async ValueTask WriteTaskCore(ReadDelegate? readDelegate, byte[] buffer,
135137

136138
while (retryAttemptLeft > 0)
137139
{
138-
var (timedOutCts, coopCts) = RenewTimeOutCancelToken(in timeoutSpan, in token);
140+
(CancellationTokenSource timedOutCts, CancellationTokenSource coopCts) =
141+
RenewTimeOutCancelToken(in timeoutSpan, in token);
142+
139143
_sourceStream = await _sourceStreamFactory(lastBytesPosition, coopCts.Token);
140144
if (_sourceStream == null)
141145
{
@@ -150,8 +154,13 @@ private async ValueTask WriteTaskCore(ReadDelegate? readDelegate, byte[] buffer,
150154
.ReadAsync(buffer, coopCts.Token)
151155
.ConfigureAwait(false)) > 0)
152156
{
157+
await SpeedLimiterService.AddBytesOrWaitAsync(_options.SpeedLimiterServiceContext,
158+
read,
159+
coopCts.Token);
160+
153161
lastBytesPosition += read;
154162
readDelegate?.Invoke(read);
163+
155164
await _targetStream
156165
.WriteAsync(buffer.AsMemory(0, read), coopCts.Token)
157166
.ConfigureAwait(false);
@@ -213,7 +222,7 @@ private static bool IsRetryableHttpStatusCode(HttpStatusCode? statusCode)
213222
private static (CancellationTokenSource TimedOutCts, CancellationTokenSource CoopCts)
214223
RenewTimeOutCancelToken(in TimeSpan timeoutSpan, in CancellationToken innerToken)
215224
{
216-
CancellationTokenSource timedOutCts = new CancellationTokenSource(timeoutSpan);
225+
CancellationTokenSource timedOutCts = new(timeoutSpan);
217226
CancellationTokenSource coopCts = CancellationTokenSource.CreateLinkedTokenSource(timedOutCts.Token, innerToken);
218227

219228
return (timedOutCts, coopCts);

Utility/RetryableCopyToStreamTaskOptions.cs

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,15 @@ public class RetryableCopyToStreamTaskOptions
4747
/// Whether to dispose the target <see cref="Stream"/> once <see cref="IDisposable.Dispose()"/> is being called.
4848
/// </summary>
4949
/// <remarks>
50-
/// Default: false
50+
/// Default: <see langword="false"/>
5151
/// </remarks>
5252
public bool IsDisposeTargetStream { get; init; }
53+
54+
/// <summary>
55+
/// The context of speed limiter service to be used for download speed throttling.
56+
/// </summary>
57+
/// <remarks>
58+
/// Default: <see cref="nint.Zero"/>
59+
/// </remarks>
60+
public nint SpeedLimiterServiceContext { get; init; }
5361
}

Utility/SpeedLimiterService.cs

Lines changed: 142 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
1-
using System.Runtime.CompilerServices;
1+
using Microsoft.Win32.SafeHandles;
2+
using System;
3+
using System.Runtime.CompilerServices;
24
using System.Runtime.InteropServices;
35
using System.Threading;
46
using System.Threading.Tasks;
5-
using Hi3Helper.Plugin.Core.Utility.Windows;
6-
using Microsoft.Win32.SafeHandles;
7+
using System.Threading.Tasks.Sources;
78

89
namespace Hi3Helper.Plugin.Core.Utility;
910

@@ -35,15 +36,46 @@ namespace Hi3Helper.Plugin.Core.Utility;
3536
/// </remarks>
3637
public static class SpeedLimiterService
3738
{
38-
internal static unsafe delegate* unmanaged[Stdcall]<nint, long, nint, out nint, int> AddBytesOrWaitAsyncCallback =
39+
internal static unsafe delegate* unmanaged[Cdecl]<nint, long, nint, out nint, int> AddBytesOrWaitAsyncCallback =
40+
null;
41+
42+
internal static unsafe delegate* unmanaged[Cdecl]<ref long, ref long, void> GetSharedThrottleBytesCallback =
3943
null;
4044

4145
/// <summary>
4246
/// Creates a context to be used for the speed limiter service. This context can be used into multiple instances or threads of your downloader.
4347
/// </summary>
4448
/// <returns></returns>
4549
public static unsafe nint CreateServiceContext()
46-
=> (nint)Mem.Alloc<long>(2); // Context struct is 16 bytes in size.
50+
{
51+
ThrottleServiceContext* alloc = Mem.Alloc<ThrottleServiceContext>();
52+
alloc->AvailableTokens = 0;
53+
alloc->LastTimestamp = Environment.TickCount64;
54+
55+
if (GetSharedThrottleBytesCallback == null)
56+
return (nint)alloc;
57+
58+
try
59+
{
60+
long bytesPerSecond = 0;
61+
long burstBytes = 0;
62+
63+
GetSharedThrottleBytesCallback(ref bytesPerSecond, ref burstBytes);
64+
65+
if (bytesPerSecond < burstBytes)
66+
{
67+
bytesPerSecond = burstBytes;
68+
}
69+
70+
alloc->AvailableTokens = bytesPerSecond;
71+
}
72+
catch
73+
{
74+
// ignored
75+
}
76+
77+
return (nint)alloc;
78+
}
4779

4880
/// <summary>
4981
/// Free the speed limiter service context.
@@ -65,49 +97,126 @@ public static unsafe ValueTask AddBytesOrWaitAsync(
6597
long readBytes,
6698
CancellationToken token = default)
6799
{
68-
if (AddBytesOrWaitAsyncCallback == null)
69-
{
100+
if (context == nint.Zero || AddBytesOrWaitAsyncCallback == null)
70101
return ValueTask.CompletedTask;
71-
}
72102

73-
nint tokenHandle = token.WaitHandle.SafeWaitHandle.DangerousGetHandle();
74103
int hr = AddBytesOrWaitAsyncCallback(context,
75-
readBytes,
76-
tokenHandle,
77-
out nint asyncWaitHandle);
104+
readBytes,
105+
nint.Zero,
106+
out nint completionHandle);
78107

79-
AsyncValueTaskMethodBuilder valueTaskCs = new();
80-
if (Marshal.GetExceptionForHR(hr) is { } exception)
81-
{
82-
valueTaskCs.SetException(exception);
83-
return valueTaskCs.Task;
84-
}
108+
if (Marshal.GetExceptionForHR(hr) is { } ex)
109+
return ValueTask.FromException(ex);
110+
111+
NativeThrottleOperation op = new();
112+
op.Initialize(completionHandle, token);
85113

86-
SafeWaitHandle safeHandle = new(asyncWaitHandle, false);
87-
WaitHandle waitHandle = new EventWaitHandle(false, EventResetMode.ManualReset)
114+
return op.AsValueTask();
115+
}
116+
117+
[StructLayout(LayoutKind.Sequential, Pack = 8)] // Pack to 8 bytes to ensure aligning
118+
private struct ThrottleServiceContext
119+
{
120+
public long AvailableTokens;
121+
public long LastTimestamp;
122+
}
123+
124+
private sealed class NativeThrottleOperation : IValueTaskSource
125+
{
126+
private ManualResetValueTaskSourceCore<bool> _core = new()
88127
{
89-
SafeWaitHandle = safeHandle
128+
RunContinuationsAsynchronously = true
90129
};
91130

92-
ThreadPool.UnsafeRegisterWaitForSingleObject(waitHandle,
93-
DisposeWaitHandleCallback,
94-
null,
95-
-1,
96-
true);
131+
private int _isCompleted;
132+
private EventWaitHandle? _completionWait;
133+
private SafeWaitHandle? _completionSafe;
134+
private RegisteredWaitHandle? _registeredWait;
135+
private CancellationTokenRegistration _ctr;
97136

98-
return valueTaskCs.Task;
137+
public ValueTask AsValueTask()
138+
=> new(this, _core.Version);
99139

100-
void DisposeWaitHandleCallback(object? state, bool isTimedOut)
140+
public void Initialize(
141+
nint completionHandle,
142+
CancellationToken token)
101143
{
102-
safeHandle.Dispose();
103-
waitHandle.Dispose();
144+
_completionSafe = new SafeWaitHandle(completionHandle, true);
145+
_completionWait = new EventWaitHandle(false, EventResetMode.ManualReset)
146+
{
147+
SafeWaitHandle = _completionSafe
148+
};
149+
150+
_registeredWait =
151+
ThreadPool.RegisterWaitForSingleObject(_completionWait,
152+
OnWaitSingleCompleted,
153+
this,
154+
-1,
155+
true);
156+
157+
if (token.CanBeCanceled)
158+
{
159+
_ctr = token.Register(OnCancellationRequested, this);
160+
}
161+
}
162+
163+
private static void OnWaitSingleCompleted(object? state, bool isTimedOut)
164+
{
165+
NativeThrottleOperation op = (NativeThrottleOperation)state!;
166+
op.Complete();
167+
}
104168

105-
if (asyncWaitHandle != nint.Zero)
169+
private static void OnCancellationRequested(object? state)
170+
{
171+
NativeThrottleOperation op = (NativeThrottleOperation)state!;
172+
op.Cancel();
173+
}
174+
175+
private void Complete()
176+
{
177+
if (Interlocked.Exchange(ref _isCompleted, 1) == 1)
106178
{
107-
_ = PInvoke.CloseHandle(asyncWaitHandle);
179+
return;
108180
}
109181

110-
valueTaskCs.SetResult();
182+
Cleanup();
183+
_core.SetResult(true);
111184
}
185+
186+
private void Cancel()
187+
{
188+
if (Interlocked.Exchange(ref _isCompleted, 1) == 1)
189+
{
190+
return;
191+
}
192+
193+
Cleanup();
194+
_core.SetException(new OperationCanceledException());
195+
}
196+
197+
private void Cleanup()
198+
{
199+
_registeredWait?.Unregister(null);
200+
_registeredWait = null;
201+
202+
_ctr.Dispose();
203+
204+
_completionWait?.Dispose();
205+
_completionWait = null;
206+
_completionSafe = null;
207+
}
208+
209+
public void GetResult(short token)
210+
=> _core.GetResult(token);
211+
212+
public ValueTaskSourceStatus GetStatus(short token)
213+
=> _core.GetStatus(token);
214+
215+
public void OnCompleted(
216+
Action<object?> continuation,
217+
object? state,
218+
short token,
219+
ValueTaskSourceOnCompletedFlags flags)
220+
=> _core.OnCompleted(continuation, state, token, flags);
112221
}
113222
}

0 commit comments

Comments
 (0)