Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
126 changes: 120 additions & 6 deletions dotnet/src/Client.cs
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,11 @@ namespace GitHub.Copilot.SDK;
/// </example>
public sealed partial class CopilotClient : IDisposable, IAsyncDisposable
{
/// <summary>
/// Minimum protocol version this SDK can communicate with.
/// </summary>
private const int MinProtocolVersion = 2;

private readonly ConcurrentDictionary<string, CopilotSession> _sessions = new();
private readonly CopilotClientOptions _options;
private readonly ILogger _logger;
Expand All @@ -62,6 +67,7 @@ public sealed partial class CopilotClient : IDisposable, IAsyncDisposable
private readonly int? _optionsPort;
private readonly string? _optionsHost;
private int? _actualPort;
private int? _negotiatedProtocolVersion;
private List<ModelInfo>? _modelsCache;
private readonly SemaphoreSlim _modelsCacheLock = new(1, 1);
private readonly List<Action<SessionLifecycleEvent>> _lifecycleHandlers = [];
Expand Down Expand Up @@ -923,27 +929,30 @@ private Task<Connection> EnsureConnectedAsync(CancellationToken cancellationToke
return (Task<Connection>)StartAsync(cancellationToken);
}

private static async Task VerifyProtocolVersionAsync(Connection connection, CancellationToken cancellationToken)
private async Task VerifyProtocolVersionAsync(Connection connection, CancellationToken cancellationToken)
{
var expectedVersion = SdkProtocolVersion.GetVersion();
var maxVersion = SdkProtocolVersion.GetVersion();
var pingResponse = await InvokeRpcAsync<PingResponse>(
connection.Rpc, "ping", [new PingRequest()], connection.StderrBuffer, cancellationToken);

if (!pingResponse.ProtocolVersion.HasValue)
{
throw new InvalidOperationException(
$"SDK protocol version mismatch: SDK expects version {expectedVersion}, " +
$"SDK protocol version mismatch: SDK supports versions {MinProtocolVersion}-{maxVersion}, " +
$"but server does not report a protocol version. " +
$"Please update your server to ensure compatibility.");
}

if (pingResponse.ProtocolVersion.Value != expectedVersion)
var serverVersion = pingResponse.ProtocolVersion.Value;
if (serverVersion < MinProtocolVersion || serverVersion > maxVersion)
{
throw new InvalidOperationException(
$"SDK protocol version mismatch: SDK expects version {expectedVersion}, " +
$"but server reports version {pingResponse.ProtocolVersion.Value}. " +
$"SDK protocol version mismatch: SDK supports versions {MinProtocolVersion}-{maxVersion}, " +
$"but server reports version {serverVersion}. " +
$"Please update your SDK or server to ensure compatibility.");
}

_negotiatedProtocolVersion = serverVersion;
}

private static async Task<(Process Process, int? DetectedLocalhostTcpPort, StringBuilder StderrBuffer)> StartCliServerAsync(CopilotClientOptions options, ILogger logger, CancellationToken cancellationToken)
Expand Down Expand Up @@ -1137,6 +1146,12 @@ private async Task<Connection> ConnectToServerAsync(Process? cliProcess, string?
var handler = new RpcHandler(this);
rpc.AddLocalRpcMethod("session.event", handler.OnSessionEvent);
rpc.AddLocalRpcMethod("session.lifecycle", handler.OnSessionLifecycle);
// Protocol v3 servers send tool calls / permission requests as broadcast events.
// Protocol v2 servers use the older tool.call / permission.request RPC model.
// We always register v2 adapters because handlers are set up before version
// negotiation; a v3 server will simply never send these requests.
rpc.AddLocalRpcMethod("tool.call", handler.OnToolCallV2);
rpc.AddLocalRpcMethod("permission.request", handler.OnPermissionRequestV2);
rpc.AddLocalRpcMethod("userInput.request", handler.OnUserInputRequest);
rpc.AddLocalRpcMethod("hooks.invoke", handler.OnHooksInvoke);
rpc.StartListening();
Expand Down Expand Up @@ -1257,6 +1272,96 @@ public async Task<HooksInvokeResponse> OnHooksInvoke(string sessionId, string ho
var output = await session.HandleHooksInvokeAsync(hookType, input);
return new HooksInvokeResponse(output);
}

// Protocol v2 backward-compatibility adapters

public async Task<ToolCallResponseV2> OnToolCallV2(string sessionId,
string toolCallId,
string toolName,
object? arguments)
{
var session = client.GetSession(sessionId) ?? throw new ArgumentException($"Unknown session {sessionId}");
if (session.GetTool(toolName) is not { } tool)
{
return new ToolCallResponseV2(new ToolResultObject
{
TextResultForLlm = $"Tool '{toolName}' is not supported.",
ResultType = "failure",
Error = $"tool '{toolName}' not supported"
});
}

try
{
var invocation = new ToolInvocation
{
SessionId = sessionId,
ToolCallId = toolCallId,
ToolName = toolName,
Arguments = arguments
};

var aiFunctionArgs = new AIFunctionArguments
{
Context = new Dictionary<object, object?>
{
[typeof(ToolInvocation)] = invocation
}
};

if (arguments is not null)
{
if (arguments is not JsonElement incomingJsonArgs)
{
throw new InvalidOperationException($"Incoming arguments must be a {nameof(JsonElement)}; received {arguments.GetType().Name}");
}

foreach (var prop in incomingJsonArgs.EnumerateObject())
{
aiFunctionArgs[prop.Name] = prop.Value;
}
}

var result = await tool.InvokeAsync(aiFunctionArgs);

var toolResultObject = result is ToolResultAIContent trac ? trac.Result : new ToolResultObject
{
ResultType = "success",
TextResultForLlm = result is JsonElement { ValueKind: JsonValueKind.String } je
? je.GetString()!
: JsonSerializer.Serialize(result, tool.JsonSerializerOptions.GetTypeInfo(typeof(object))),
};
return new ToolCallResponseV2(toolResultObject);
}
catch (Exception ex)
{
return new ToolCallResponseV2(new ToolResultObject
{
TextResultForLlm = "Invoking this tool produced an error. Detailed information is not available.",
ResultType = "failure",
Error = ex.Message
});
}
Comment on lines +1336 to +1344
}

public async Task<PermissionRequestResponseV2> OnPermissionRequestV2(string sessionId, JsonElement permissionRequest)
{
var session = client.GetSession(sessionId)
?? throw new ArgumentException($"Unknown session {sessionId}");

try
{
var result = await session.HandlePermissionRequestAsync(permissionRequest);
return new PermissionRequestResponseV2(result);
}
catch (Exception)
{
return new PermissionRequestResponseV2(new PermissionRequestResult
{
Kind = PermissionRequestResultKind.DeniedCouldNotRequestFromUser
});
}
Comment on lines +1357 to +1363
}
}

private class Connection(
Expand Down Expand Up @@ -1376,6 +1481,13 @@ internal record UserInputRequestResponse(
internal record HooksInvokeResponse(
object? Output);

// Protocol v2 backward-compatibility response types
internal record ToolCallResponseV2(
ToolResultObject Result);

internal record PermissionRequestResponseV2(
PermissionRequestResult Result);

/// <summary>Trace source that forwards all logs to the ILogger.</summary>
internal sealed class LoggerTraceSource : TraceSource
{
Expand Down Expand Up @@ -1469,11 +1581,13 @@ private static LogLevel MapLevel(TraceEventType eventType)
[JsonSerializable(typeof(ListSessionsRequest))]
[JsonSerializable(typeof(ListSessionsResponse))]
[JsonSerializable(typeof(PermissionRequestResult))]
[JsonSerializable(typeof(PermissionRequestResponseV2))]
[JsonSerializable(typeof(ProviderConfig))]
[JsonSerializable(typeof(ResumeSessionRequest))]
[JsonSerializable(typeof(ResumeSessionResponse))]
[JsonSerializable(typeof(SessionMetadata))]
[JsonSerializable(typeof(SystemMessageConfig))]
[JsonSerializable(typeof(ToolCallResponseV2))]
[JsonSerializable(typeof(ToolDefinition))]
[JsonSerializable(typeof(ToolResultAIContent))]
[JsonSerializable(typeof(ToolResultObject))]
Expand Down
Loading
Loading