From c348a9a9e46e963ee5eef588ab8440ee4149aebd Mon Sep 17 00:00:00 2001 From: Heath Stewart Date: Fri, 22 May 2026 20:28:33 -0700 Subject: [PATCH 1/4] Refactor Rust SDK errors to use structs with a `kind()` method MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The conventional `#[non_exhaustive] enum Error { ... }` pattern appears safe but creates problems as a library evolves. This PR changes all error types to the struct-with-`kind()` pattern, which also aligns with the Azure SDK for Rust error design. Why not a flat error enum: - `#[non_exhaustive]` on the enum prevents exhaustive matching, but individual variants are still fixed. Adding a field to any variant — even just to improve an error message with a line number or file path — is a breaking change. - Adding context data is harder than it looks. With a flat enum, new fields touch every affected variant and all match arms across the codebase. With a struct, new fields are added in one place and callers who don't use them are unaffected. - A single enum conflates all failure modes, making it impossible to document or guarantee which variants a given function can actually return. Callers must handle unrelated variants they will never see, or accept a wildcard arm that silently swallows future additions. The struct + kind pattern: | Concern | Flat enum | Struct + `kind()` | |---|---|---| | Categorization | Match directly on variant | Call `.kind()` → `&*Kind` | | Adding context | Breaking: add fields to variant | Non-breaking: add fields to struct | | `non_exhaustive` | On enum; variants are fixed | Not needed on struct with only private fields | | Simple display | Must match all variants | `format!("{err}")` — no match needed | Callers who only want to display or propagate an error with `?` do not need to call `.kind()` at all. Only callers who need to inspect the failure category call `.kind()`, and they get a stable, scoped `*Kind` enum to match against. --- .github/copilot-instructions.md | 1 + .github/lsp.json | 18 + .github/skills/rust-coding-skill/SKILL.md | 29 +- .vscode/settings.json | 3 + rust/Cargo.lock | 1 - rust/Cargo.toml | 13 +- rust/examples/manual_tool_resume.rs | 2 +- rust/examples/session_fs.rs | 8 +- rust/src/embeddedcli.rs | 125 ++++-- rust/src/errors.rs | 507 ++++++++++++++++++++++ rust/src/jsonrpc.rs | 16 +- rust/src/lib.rs | 351 ++++----------- rust/src/resolve.rs | 21 +- rust/src/session.rs | 49 ++- rust/src/session_fs.rs | 161 +++++-- rust/src/subscription.rs | 107 ++++- rust/src/tool.rs | 4 +- rust/src/types.rs | 16 +- rust/tests/e2e/session_fs_sqlite.rs | 22 +- rust/tests/protocol_version_test.rs | 4 +- rust/tests/session_test.rs | 42 +- 21 files changed, 1027 insertions(+), 473 deletions(-) create mode 100644 rust/src/errors.rs diff --git a/.github/copilot-instructions.md b/.github/copilot-instructions.md index b058535ba..f3986ec35 100644 --- a/.github/copilot-instructions.md +++ b/.github/copilot-instructions.md @@ -29,6 +29,7 @@ - **.NET testing note:** Never add `InternalsVisibleTo` to any project file when writing tests. Tests must only access public APIs. - Java: `cd java && mvn clean verify` (full build + tests), `mvn spotless:apply` (format code before commit) - **Java testing note:** Always use `mvn verify` without `-q` and without piping through `grep`. Never add `InternalsVisibleTo` equivalent — tests must only access public APIs. +- Use configured LSPs for supported operations like finding references instead of pattern matching, renaming symbols, etc. ## Testing & E2E tips ⚙️ diff --git a/.github/lsp.json b/.github/lsp.json index e58456ac4..753521284 100644 --- a/.github/lsp.json +++ b/.github/lsp.json @@ -21,6 +21,24 @@ ".go": "go" }, "rootUri": "go" + }, + "rust-analyzer": { + "command": "rust-analyzer", + "fileExtensions": { + ".rs": "rust" + }, + "initializationOptions": { + "cargo": { + "buildScripts": { + "enable": true + }, + "allFeatures": true + }, + "checkOnSave": true, + "check": { + "command": "clippy" + } + } } } } diff --git a/.github/skills/rust-coding-skill/SKILL.md b/.github/skills/rust-coding-skill/SKILL.md index 7e0342f06..858dbea51 100644 --- a/.github/skills/rust-coding-skill/SKILL.md +++ b/.github/skills/rust-coding-skill/SKILL.md @@ -14,10 +14,9 @@ Opinionated Rust rules for the Copilot Rust SDK (`rust/`). Priority order: ## Error handling The SDK's public error type is `crate::Error` (`rust/src/error.rs`). Add new -variants there rather than introducing parallel error enums per module — every -public failure mode is part of the API contract and should be expressible in one -type. Internal modules can use `thiserror` enums when a richer local taxonomy -helps; convert at the boundary. +variants to `crate::ErrorKind` rather than introducing parallel error enums +per module — every public failure mode is part of the API contract and should +be expressible in one type. `anyhow` is reserved for binaries and example code. Library code never returns `anyhow::Result` — callers can't pattern-match on `anyhow::Error`, so it would @@ -42,7 +41,7 @@ it on shutdown. Fire-and-forget spawns silently swallow panics and outlive the session; don't. Blocking calls (filesystem, subprocess wait) belong in -`tokio::task::spawn_blocking`, *not* on the async runtime. The blocking pool is +`tokio::task::spawn_blocking`, _not_ on the async runtime. The blocking pool is bounded, so for genuinely long-lived workers (think: file watchers that run for the lifetime of a session) prefer `std::thread::spawn` with a channel back into async land. @@ -81,12 +80,12 @@ Trivial field re-shaping is best inlined. Closures should stay short (under ~10 **Channels, not callback closures, for event flow.** Closures fight `Send + Sync + 'static` and don't compose with `select!`. Channel choice by semantics: -| Use case | Primitive | -|---|---| -| One producer → one consumer with backpressure | `tokio::sync::mpsc` (cap 1) or `tokio::sync::oneshot` for single value | -| Many producers → one consumer | `tokio::sync::mpsc` | -| One producer → many consumers, every event delivered (pub/sub) | `tokio::sync::broadcast` | -| One producer → many consumers, only the latest value matters | `tokio::sync::watch` | +| Use case | Primitive | +| -------------------------------------------------------------- | ---------------------------------------------------------------------- | +| One producer → one consumer with backpressure | `tokio::sync::mpsc` (cap 1) or `tokio::sync::oneshot` for single value | +| Many producers → one consumer | `tokio::sync::mpsc` | +| One producer → many consumers, every event delivered (pub/sub) | `tokio::sync::broadcast` | +| One producer → many consumers, only the latest value matters | `tokio::sync::watch` | For the **public** API, prefer returning `impl Stream` (wrap a `broadcast::Receiver` in `tokio_stream::wrappers::BroadcastStream`). `Stream` composes with `select!`, `take`, `map`, `filter`, `timeout`. See `EventSubscription` and `LifecycleSubscription`. @@ -115,7 +114,7 @@ JSON: `#[serde(rename_all = "camelCase")]` at the type level, per-field `#[serde Banned via `clippy.toml`. Use manual spans with `error_span!`: - **Almost always use `error_span!`**, not `info_span!`. Span level controls - the *minimum* filter at which the span appears. An `info_span` disappears when + the _minimum_ filter at which the span appears. An `info_span` disappears when the filter is `warn` or `error` — taking all child events with it, even errors. `error_span!` ensures the span is always present. - **Spawned tasks lose parent context.** Attach a span with `.instrument()` or @@ -239,9 +238,9 @@ Match those exact commands locally before pushing. JSON-RPC and session-event types are generated from the Copilot CLI schema: -| Source | Output | -|---|---| -| `nodejs/node_modules/@github/copilot/schemas/api.schema.json` | `rust/src/generated/api_types.rs` | +| Source | Output | +| ------------------------------------------------------------------------ | -------------------------------------- | +| `nodejs/node_modules/@github/copilot/schemas/api.schema.json` | `rust/src/generated/api_types.rs` | | `nodejs/node_modules/@github/copilot/schemas/session-events.schema.json` | `rust/src/generated/session_events.rs` | Regenerate with: diff --git a/.vscode/settings.json b/.vscode/settings.json index 8d5642595..0345a3f38 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -14,6 +14,9 @@ "python.testing.pytestEnabled": true, "python.testing.unittestEnabled": false, "python.testing.pytestArgs": ["python"], + "rust-analyzer.cargo.features": "all", + "rust-analyzer.check.command": "clippy", + "rust-analyzer.check.features": "all", "[python]": { "editor.defaultFormatter": "charliermarsh.ruff" }, diff --git a/rust/Cargo.lock b/rust/Cargo.lock index 56e658ad1..1676f2f91 100644 --- a/rust/Cargo.lock +++ b/rust/Cargo.lock @@ -354,7 +354,6 @@ dependencies = [ "sha2", "tar", "tempfile", - "thiserror 2.0.18", "tokio", "tokio-stream", "tokio-util", diff --git a/rust/Cargo.toml b/rust/Cargo.toml index 1e02f267c..b4d7dd809 100644 --- a/rust/Cargo.toml +++ b/rust/Cargo.toml @@ -43,7 +43,6 @@ async-trait = "0.1" schemars = { version = "1", optional = true } serde = { version = "1", features = ["derive"] } serde_json = "1" -thiserror = "2" tokio = { version = "1", features = ["io-util", "sync", "rt", "process", "net", "time", "macros"] } tokio-stream = { version = "0.1", features = ["sync"] } tokio-util = { version = "0.7", default-features = false } @@ -68,6 +67,18 @@ serial_test = "3" tempfile = "3" tokio = { version = "1", features = ["rt-multi-thread"] } +# Integration tests that call test-support-only Client methods (e.g. +# `from_streams_with_connection_token`, `from_streams_with_trace_provider`) +# require the `test-support` feature because `cfg(test)` is not set on the +# library when Cargo compiles it for integration tests. +[[test]] +name = "session_test" +required-features = ["test-support"] + +[[test]] +name = "protocol_version_test" +required-features = ["test-support"] + [build-dependencies] flate2 = "1" sha2 = "0.10" diff --git a/rust/examples/manual_tool_resume.rs b/rust/examples/manual_tool_resume.rs index dfb2b6232..becb53793 100644 --- a/rust/examples/manual_tool_resume.rs +++ b/rust/examples/manual_tool_resume.rs @@ -10,7 +10,7 @@ use github_copilot_sdk::generated::session_events::{ AssistantMessageData, ExternalToolRequestedData, PermissionRequestedData, SessionEventType, }; use github_copilot_sdk::{ - Client, ClientOptions, EventSubscription, RecvError, ResumeSessionConfig, SessionConfig, + Client, ClientOptions, EventSubscription, subscription::RecvError, ResumeSessionConfig, SessionConfig, }; use serde_json::json; diff --git a/rust/examples/session_fs.rs b/rust/examples/session_fs.rs index 924e6947f..ad31f6849 100644 --- a/rust/examples/session_fs.rs +++ b/rust/examples/session_fs.rs @@ -15,7 +15,7 @@ use std::sync::Arc; use async_trait::async_trait; use github_copilot_sdk::handler::ApproveAllHandler; use github_copilot_sdk::session_fs::{ - DirEntry, DirEntryKind, FileInfo, FsError, SessionFsConfig, SessionFsConventions, + DirEntry, DirEntryKind, FileInfo, FsError, FsErrorKind, SessionFsConfig, SessionFsConventions, SessionFsProvider, }; use github_copilot_sdk::types::{MessageOptions, SessionConfig}; @@ -46,7 +46,7 @@ impl SessionFsProvider for InMemoryProvider { .lock() .get(path) .cloned() - .ok_or_else(|| FsError::NotFound(path.to_string())) + .ok_or_else(|| FsError::from(FsErrorKind::NotFound(path.to_string()))) } async fn write_file( @@ -69,7 +69,7 @@ impl SessionFsProvider for InMemoryProvider { let files = self.files.lock(); let content = files .get(path) - .ok_or_else(|| FsError::NotFound(path.to_string()))?; + .ok_or_else(|| FsError::from(FsErrorKind::NotFound(path.to_string())))?; Ok(FileInfo::new( true, false, @@ -101,7 +101,7 @@ impl SessionFsProvider for InMemoryProvider { async fn rm(&self, path: &str, _recursive: bool, force: bool) -> Result<(), FsError> { if self.files.lock().remove(path).is_none() && !force { - return Err(FsError::NotFound(path.to_string())); + return Err(FsError::from(FsErrorKind::NotFound(path.to_string()))); } Ok(()) } diff --git a/rust/src/embeddedcli.rs b/rust/src/embeddedcli.rs index e1eb147dc..022d8c238 100644 --- a/rust/src/embeddedcli.rs +++ b/rust/src/embeddedcli.rs @@ -15,7 +15,7 @@ use std::fs; #[cfg(all(has_bundled_cli, not(windows)))] use std::io::Read; #[cfg(has_bundled_cli)] -use std::io::{self, Write}; +use std::io::Write; use std::path::{Path, PathBuf}; use std::sync::OnceLock; @@ -109,7 +109,7 @@ fn default_install_dir(version: &str) -> PathBuf { fn install(install_dir: &Path, archive: &[u8]) -> Result { let verbose = std::env::var("COPILOT_CLI_INSTALL_VERBOSE").ok().as_deref() == Some("1"); - fs::create_dir_all(install_dir).map_err(EmbeddedCliError::CreateDir)?; + fs::create_dir_all(install_dir).map_err(|e| EmbeddedCliError::new(EmbeddedCliErrorKind::CreateDir, e))?; let final_path = install_dir.join(build_time::CLI_BINARY_NAME); @@ -141,35 +141,35 @@ fn install(install_dir: &Path, archive: &[u8]) -> Result Result, EmbeddedCliError> { let gz = flate2::read::GzDecoder::new(archive); let mut tar = tar::Archive::new(gz); - for entry in tar.entries().map_err(EmbeddedCliError::Archive)? { - let mut entry = entry.map_err(EmbeddedCliError::Archive)?; - let path = entry.path().map_err(EmbeddedCliError::Archive)?; + for entry in tar.entries().map_err(|e| EmbeddedCliError::new(EmbeddedCliErrorKind::Archive, e))? { + let mut entry = entry.map_err(|e| EmbeddedCliError::new(EmbeddedCliErrorKind::Archive, e))?; + let path = entry.path().map_err(|e| EmbeddedCliError::new(EmbeddedCliErrorKind::Archive, e))?; let name = path.to_string_lossy(); if name == binary_name || name.ends_with(&format!("/{binary_name}")) { let mut bytes = Vec::with_capacity(entry.size() as usize); entry .read_to_end(&mut bytes) - .map_err(EmbeddedCliError::Archive)?; + .map_err(|e| EmbeddedCliError::new(EmbeddedCliErrorKind::Archive, e))?; return Ok(bytes); } } - Err(EmbeddedCliError::BinaryNotFoundInArchive) + Err(EmbeddedCliErrorKind::BinaryNotFoundInArchive.into()) } #[cfg(all(has_bundled_cli, windows))] fn extract_binary(archive: &[u8], binary_name: &str) -> Result, EmbeddedCliError> { let cursor = std::io::Cursor::new(archive); - let mut zip = zip::ZipArchive::new(cursor).map_err(EmbeddedCliError::Zip)?; + let mut zip = zip::ZipArchive::new(cursor).map_err(|e| EmbeddedCliError::new(EmbeddedCliErrorKind::Zip, e))?; for i in 0..zip.len() { - let mut entry = zip.by_index(i).map_err(EmbeddedCliError::Zip)?; + let mut entry = zip.by_index(i).map_err(|e| EmbeddedCliError::new(EmbeddedCliErrorKind::Zip, e))?; let name = entry.name().to_string(); if name == binary_name || name.ends_with(&format!("/{binary_name}")) { let mut bytes = Vec::with_capacity(entry.size() as usize); - std::io::copy(&mut entry, &mut bytes).map_err(EmbeddedCliError::Io)?; + std::io::copy(&mut entry, &mut bytes).map_err(|e| EmbeddedCliError::new(EmbeddedCliErrorKind::Io, e))?; return Ok(bytes); } } - Err(EmbeddedCliError::BinaryNotFoundInArchive) + Err(EmbeddedCliErrorKind::BinaryNotFoundInArchive.into()) } #[cfg(has_bundled_cli)] @@ -190,38 +190,107 @@ fn write_binary(path: &Path, data: &[u8]) -> Result<(), EmbeddedCliError> { .create(true) .truncate(true) .open(path) - .map_err(EmbeddedCliError::Io)?; + .map_err(|e| EmbeddedCliError::new(EmbeddedCliErrorKind::Io, e))?; - file.write_all(data).map_err(EmbeddedCliError::Io)?; + file.write_all(data).map_err(|e| EmbeddedCliError::new(EmbeddedCliErrorKind::Io, e))?; #[cfg(unix)] { use std::os::unix::fs::PermissionsExt; fs::set_permissions(path, fs::Permissions::from_mode(0o755)) - .map_err(EmbeddedCliError::Io)?; + .map_err(|e| EmbeddedCliError::new(EmbeddedCliErrorKind::Io, e))?; } Ok(()) } #[cfg(has_bundled_cli)] -#[derive(Debug, thiserror::Error)] +#[derive(Clone, Copy, Debug, PartialEq, Eq)] #[allow(dead_code)] -enum EmbeddedCliError { - #[error("failed to create install directory: {0}")] - CreateDir(io::Error), - +enum EmbeddedCliErrorKind { + CreateDir, #[cfg(not(windows))] - #[error("failed to read archive entry: {0}")] - Archive(io::Error), - + Archive, #[cfg(windows)] - #[error("failed to read zip archive: {0}")] - Zip(zip::result::ZipError), - - #[error("CLI binary not found in embedded archive")] + Zip, BinaryNotFoundInArchive, + Io, +} + +#[cfg(has_bundled_cli)] +impl std::fmt::Display for EmbeddedCliErrorKind { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + EmbeddedCliErrorKind::CreateDir => f.write_str("failed to create install directory"), + #[cfg(not(windows))] + EmbeddedCliErrorKind::Archive => f.write_str("failed to read archive entry"), + #[cfg(windows)] + EmbeddedCliErrorKind::Zip => f.write_str("failed to read zip archive"), + EmbeddedCliErrorKind::BinaryNotFoundInArchive => { + f.write_str("CLI binary not found in embedded archive") + } + EmbeddedCliErrorKind::Io => f.write_str("I/O error"), + } + } +} - #[error("I/O error: {0}")] - Io(io::Error), +#[cfg(has_bundled_cli)] +#[allow(dead_code)] +struct EmbeddedCliError { + repr: crate::errors::Repr, +} + +#[cfg(has_bundled_cli)] +impl EmbeddedCliError { + fn new(kind: EmbeddedCliErrorKind, error: E) -> Self + where + E: Into>, + { + Self { + repr: crate::errors::Repr::Custom(crate::errors::Custom { + kind, + error: error.into(), + }), + } + } + +} + +#[cfg(has_bundled_cli)] +impl From for EmbeddedCliError { + fn from(kind: EmbeddedCliErrorKind) -> Self { + Self { + repr: crate::errors::Repr::Simple(kind), + } + } +} + +#[cfg(has_bundled_cli)] +impl std::fmt::Display for EmbeddedCliError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match &self.repr { + crate::errors::Repr::Simple(kind) => write!(f, "{kind}"), + crate::errors::Repr::SimpleMessage(_, msg) => write!(f, "{msg}"), + crate::errors::Repr::Custom(crate::errors::Custom { kind, error }) => { + write!(f, "{kind}: {error}") + } + } + } +} + +#[cfg(has_bundled_cli)] +impl std::fmt::Debug for EmbeddedCliError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "EmbeddedCliError({self})") + } +} + +#[cfg(has_bundled_cli)] +impl std::error::Error for EmbeddedCliError { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + match &self.repr { + crate::errors::Repr::Custom(crate::errors::Custom { error, .. }) => Some(&**error), + _ => None, + } + } } diff --git a/rust/src/errors.rs b/rust/src/errors.rs new file mode 100644 index 000000000..75e9739ad --- /dev/null +++ b/rust/src/errors.rs @@ -0,0 +1,507 @@ +//! Crate errors. + +use std::{ + borrow::{Borrow, Cow}, + fmt, + time::Duration, +}; +use crate::types::SessionId; + +/// Crate-specific [`Result`](std::result::Result). +pub type Result = std::result::Result; + +// ── Repr / Custom ───────────────────────────────────────────────────────────── + +/// Internal representation shared by all SDK error structs. +/// +/// `T` is the `*Kind` enum specific to each error struct. Shared across +/// [`Error`], [`ProtocolError`], [`SessionError`], [`FsError`], +/// [`RecvError`], and the crate-internal `EmbeddedCliError`. +#[derive(Debug)] +pub(crate) enum Repr { + Simple(T), + SimpleMessage(T, Cow<'static, str>), + Custom(Custom), + // CustomMessage(Custom, Cow<'static, str>), +} + +/// Custom error representation: a kind tag plus a boxed source error. +#[derive(Debug)] +pub(crate) struct Custom { + pub(crate) kind: T, + pub(crate) error: Box, +} + +// ── ProtocolErrorKind / ProtocolError ───────────────────────────────────────── + +/// Specific protocol-level error kind in the JSON-RPC transport or CLI lifecycle. +#[derive(Clone, Debug, PartialEq, Eq)] +#[non_exhaustive] +pub enum ProtocolErrorKind { + /// Missing `Content-Length` header in a JSON-RPC message. + MissingContentLength, + + /// Invalid `Content-Length` header value. + InvalidContentLength(String), + + /// A pending JSON-RPC request was cancelled (e.g. the response channel was dropped). + RequestCancelled, + + /// The CLI process did not report a listening port within the timeout. + CliStartupTimeout, + + /// The CLI process exited before reporting a listening port. + CliStartupFailed, + + /// The CLI server's protocol version is outside the SDK's supported range. + VersionMismatch { + /// Version reported by the server. + server: u32, + /// Minimum version supported by this SDK. + min: u32, + /// Maximum version supported by this SDK. + max: u32, + }, + + /// The CLI server's protocol version changed between calls. + VersionChanged { + /// Previously negotiated version. + previous: u32, + /// Newly reported version. + current: u32, + }, +} + +impl fmt::Display for ProtocolErrorKind { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + ProtocolErrorKind::MissingContentLength => { + write!(f, "missing Content-Length header") + } + ProtocolErrorKind::InvalidContentLength(v) => { + write!(f, "invalid Content-Length value: \"{v}\"") + } + ProtocolErrorKind::RequestCancelled => write!(f, "request cancelled"), + ProtocolErrorKind::CliStartupTimeout => { + write!(f, "timed out waiting for CLI to report listening port") + } + ProtocolErrorKind::CliStartupFailed => { + write!(f, "CLI exited before reporting listening port") + } + ProtocolErrorKind::VersionMismatch { server, min, max } => { + write!( + f, + "version mismatch: server={server}, supported={min}\u{2013}{max}" + ) + } + ProtocolErrorKind::VersionChanged { previous, current } => { + write!(f, "version changed: was {previous}, now {current}") + } + } + } +} + +/// Errors in the JSON-RPC transport or CLI lifecycle. +/// +/// Accessible via [`Error::kind`] when the kind is +/// [`ErrorKind::Protocol`]. +#[derive(Debug)] +pub struct ProtocolError { + repr: Repr, +} + +impl ProtocolError { + /// The [`ProtocolErrorKind`] of this error. + pub fn kind(&self) -> &ProtocolErrorKind { + match &self.repr { + Repr::Simple(k) + | Repr::SimpleMessage(k, ..) + | Repr::Custom(Custom { kind: k, .. }) => k, + } + } + + /// The message provided when this error was constructed, or `None`. + pub fn message(&self) -> Option<&str> { + match &self.repr { + Repr::SimpleMessage(_, m) => Some(m.borrow()), + _ => None, + } + } +} + +impl fmt::Display for ProtocolError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match &self.repr { + Repr::Simple(k) => write!(f, "{k}"), + Repr::SimpleMessage(_, m) => write!(f, "{m}"), + Repr::Custom(Custom { error, .. }) => write!(f, "{error}"), + } + } +} + +impl std::error::Error for ProtocolError { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + match &self.repr { + Repr::Custom(Custom { error, .. }) => Some(&**error), + _ => None, + } + } +} + +impl From for ProtocolError { + fn from(kind: ProtocolErrorKind) -> Self { + Self { repr: Repr::Simple(kind) } + } +} + +// ── SessionErrorKind / SessionError ─────────────────────────────────────────── + +/// Session-scoped error kind. +#[derive(Clone, Debug, PartialEq, Eq)] +#[non_exhaustive] +pub enum SessionErrorKind { + /// The CLI could not find the requested session. + NotFound(SessionId), + + /// The CLI reported an error during agent execution (via `session.error` event). + AgentError, + + /// A `send_and_wait` call exceeded its timeout. + Timeout(Duration), + + /// `send` was called while a `send_and_wait` is in flight. + SendWhileWaiting, + + /// The session event loop exited before a pending `send_and_wait` completed. + EventLoopClosed, + + /// Elicitation is not supported by the host. + /// Check `session.capabilities().ui.elicitation` before calling UI methods. + ElicitationNotSupported, + + /// The client was started with [`crate::ClientOptions::session_fs`] but this + /// session was created without a [`crate::session_fs::SessionFsProvider`]. Set one via + /// [`crate::SessionConfig::with_session_fs_provider`] (or + /// [`crate::ResumeSessionConfig::with_session_fs_provider`]). + SessionFsProviderRequired, + + /// [`crate::ClientOptions::session_fs`] was provided with empty or invalid + /// fields. All of `initial_cwd` and `session_state_path` must be non-empty. + InvalidSessionFsConfig, + + /// The CLI returned a different session ID than the one the SDK registered. + SessionIdMismatch { + /// Session ID registered by the SDK before the RPC was sent. + requested: SessionId, + /// Session ID returned by the CLI. + returned: SessionId, + }, +} + +impl fmt::Display for SessionErrorKind { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + SessionErrorKind::NotFound(id) => write!(f, "session not found: {id}"), + SessionErrorKind::AgentError => write!(f, "agent error"), + SessionErrorKind::Timeout(d) => write!(f, "timed out after {d:?}"), + SessionErrorKind::SendWhileWaiting => { + write!(f, "cannot send while send_and_wait is in flight") + } + SessionErrorKind::EventLoopClosed => { + write!(f, "event loop closed before session reached idle") + } + SessionErrorKind::ElicitationNotSupported => write!( + f, + "elicitation not supported by host \ + \u{2014} check session.capabilities().ui.elicitation first" + ), + SessionErrorKind::SessionFsProviderRequired => write!( + f, + "session was created on a client with session_fs configured \ + but no SessionFsProvider was supplied" + ), + SessionErrorKind::InvalidSessionFsConfig => { + write!(f, "invalid SessionFsConfig") + } + SessionErrorKind::SessionIdMismatch { requested, returned } => write!( + f, + "CLI returned session ID {returned} after SDK registered {requested}" + ), + } + } +} + +/// Session-scoped errors. +/// +/// Accessible via [`Error::kind`] when the kind is [`ErrorKind::Session`]. +#[derive(Debug)] +pub struct SessionError { + repr: Repr, +} + +impl SessionError { + /// The [`SessionErrorKind`] of this error. + pub fn kind(&self) -> &SessionErrorKind { + match &self.repr { + Repr::Simple(k) + | Repr::SimpleMessage(k, ..) + | Repr::Custom(Custom { kind: k, .. }) => k, + } + } + + /// The message provided when this error was constructed, or `None`. + pub fn message(&self) -> Option<&str> { + match &self.repr { + Repr::SimpleMessage(_, m) => Some(m.borrow()), + _ => None, + } + } +} + +impl fmt::Display for SessionError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + if let &SessionErrorKind::InvalidSessionFsConfig = self.kind() { + write!(f, "{}: ", SessionErrorKind::InvalidSessionFsConfig)?; + } + match &self.repr { + Repr::Simple(k) => write!(f, "{k}"), + Repr::SimpleMessage(_, m) => write!(f, "{m}"), + Repr::Custom(Custom { error, .. }) => write!(f, "{error}"), + } + } +} + +impl std::error::Error for SessionError { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + match &self.repr { + Repr::Custom(Custom { error, .. }) => Some(&**error), + _ => None, + } + } +} + +impl From for SessionError { + fn from(kind: SessionErrorKind) -> Self { + Self { repr: Repr::Simple(kind) } + } +} + +// ── ErrorKind ───────────────────────────────────────────────────────────────── + +/// The kind of [`Error`]. +#[derive(Clone, Debug, PartialEq, Eq)] +#[non_exhaustive] +pub enum ErrorKind { + /// JSON-RPC transport or protocol violation. + Protocol(ProtocolErrorKind), + /// The CLI returned a JSON-RPC error response. + Rpc { + /// JSON-RPC error code. + code: i32, + }, + /// Session-scoped error (not found, agent error, timeout, etc.). + Session(SessionErrorKind), + /// I/O error on the stdio transport or during process spawn. + Io, + /// Failed to serialize or deserialize a JSON-RPC message. + Json, + /// A required binary was not found on the system. + BinaryNotFound { + /// Name of the binary. + name: String, + /// Optional hint for how to resolve the issue. + hint: Option, + }, + /// Invalid combination of options or configuration. + InvalidConfig, +} + +impl fmt::Display for ErrorKind { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + ErrorKind::Protocol(k) => write!(f, "{k}"), + ErrorKind::Rpc { code } => write!(f, "RPC error {code}"), + ErrorKind::Session(k) => write!(f, "{k}"), + ErrorKind::Io => write!(f, "I/O error"), + ErrorKind::Json => write!(f, "JSON error"), + ErrorKind::BinaryNotFound { name, hint: Some(h) } => { + write!(f, "binary not found: {name} ({h})") + } + ErrorKind::BinaryNotFound { name, hint: None } => { + write!(f, "binary not found: {name}") + } + ErrorKind::InvalidConfig => write!(f, "invalid configuration"), + } + } +} + +/// Errors returned by the SDK. +#[derive(Debug)] +pub struct Error { + repr: Repr, +} + +impl Error { + /// Constructs a new `Error` boxing another [`std::error::Error`]. + pub(crate) fn new(kind: ErrorKind, error: E) -> Self + where + E: Into>, + { + Self { + repr: Repr::Custom(Custom { + kind, + error: error.into(), + }), + } + } + + /// The [`ErrorKind`] of this `Error`. + pub fn kind(&self) -> &ErrorKind { + match &self.repr { + Repr::Simple(kind) + | Repr::SimpleMessage(kind, ..) + | Repr::Custom(Custom { kind, .. }) => kind, + } + } + + /// The message provided when this `Error` was constructed, or `None`. + pub fn message(&self) -> Option<&str> { + match &self.repr { + Repr::SimpleMessage(_, message) => Some(message.borrow()), + _ => None, + } + } + + /// Create an `Error` with a message. + #[must_use] + pub fn with_message(kind: ErrorKind, message: C) -> Self + where + C: Into>, + { + Self { + repr: Repr::SimpleMessage(kind, message.into()), + } + } + + /// Returns `true` if this error indicates the transport is broken — the CLI + /// process exited, the connection was lost, or an I/O failure occurred. + /// Callers should discard the client and create a fresh one. + pub fn is_transport_failure(&self) -> bool { + matches!(self.kind(), ErrorKind::Io) + || matches!( + self.kind(), + ErrorKind::Protocol(ProtocolErrorKind::RequestCancelled) + ) + } + + /// Returns the JSON-RPC error code if this is an [`ErrorKind::Rpc`] error. + pub fn rpc_code(&self) -> Option { + match self.kind() { + ErrorKind::Rpc { code } => Some(*code), + _ => None, + } + } +} + +impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + if let &ErrorKind::Rpc { code } = self.kind() { + write!(f, "{}: ", ErrorKind::Rpc { code })?; + } + match &self.repr { + Repr::Simple(kind) => write!(f, "{kind}"), + Repr::SimpleMessage(_, message) => write!(f, "{message}"), + Repr::Custom(Custom { error, .. }) => write!(f, "{error}"), + } + } +} + +impl std::error::Error for Error { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + match &self.repr { + Repr::Custom(Custom { error, .. }) => Some(&**error), + _ => None, + } + } +} + +impl From for Error { + fn from(kind: ErrorKind) -> Self { + Self { + repr: Repr::Simple(kind), + } + } +} + +impl From for Error { + fn from(kind: ProtocolErrorKind) -> Self { + Self::from(ErrorKind::Protocol(kind)) + } +} + +impl From for Error { + fn from(kind: SessionErrorKind) -> Self { + Self::from(ErrorKind::Session(kind)) + } +} + +impl From for Error { + fn from(error: std::io::Error) -> Self { + Self::new(ErrorKind::Io, error) + } +} + +impl From for Error { + fn from(error: serde_json::Error) -> Self { + Self::new(ErrorKind::Json, error) + } +} + +/// Aggregate of errors collected during [`crate::Client::stop`]. +/// +/// `Client::stop` performs cooperative shutdown across every active +/// session before killing the CLI child process. Errors from any +/// per-session `session.destroy` RPC and from the terminal child-kill +/// step are collected here rather than short-circuiting on the first +/// failure, so callers see the full picture of what went wrong during +/// teardown. +/// +/// Implements [`std::error::Error`] and forwards to `Display` for the +/// first error, with a count suffix when there are more. +#[derive(Debug)] +pub struct StopErrors(pub(crate) Vec); + +impl StopErrors { + /// Borrow the collected errors as a slice, in the order they + /// occurred (per-session destroys first, then child-kill last). + pub fn errors(&self) -> &[Error] { + &self.0 + } + + /// Consume the aggregate and return the underlying error vector. + pub fn into_errors(self) -> Vec { + self.0 + } +} + +impl fmt::Display for StopErrors { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self.0.as_slice() { + [] => write!(f, "stop completed with no errors"), + [only] => write!(f, "stop failed: {only}"), + [first, rest @ ..] => write!( + f, + "stop failed with {n} errors; first: {first}", + n = 1 + rest.len(), + ), + } + } +} + +impl std::error::Error for StopErrors { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + self.0 + .first() + .map(|e| e as &(dyn std::error::Error + 'static)) + } +} diff --git a/rust/src/jsonrpc.rs b/rust/src/jsonrpc.rs index 88a9670cd..a4769ac8a 100644 --- a/rust/src/jsonrpc.rs +++ b/rust/src/jsonrpc.rs @@ -11,7 +11,7 @@ use tokio::sync::{broadcast, mpsc, oneshot}; use tokio::task::JoinHandle; use tracing::{Instrument, debug, error, warn}; -use crate::{Error, ProtocolError}; +use crate::{Error, ErrorKind, ProtocolErrorKind}; /// A JSON-RPC 2.0 request message. #[derive(Debug, Clone, Serialize, Deserialize)] @@ -352,15 +352,15 @@ impl JsonRpcClient { if let Some(value) = trimmed.strip_prefix(CONTENT_LENGTH_HEADER) { content_length = Some(value.trim().parse::().map_err(|_| { - Error::Protocol(ProtocolError::InvalidContentLength( + Error::from(ErrorKind::Protocol(ProtocolErrorKind::InvalidContentLength( value.trim().to_string(), - )) + ))) })?); } } let Some(length) = content_length else { - return Err(Error::Protocol(ProtocolError::MissingContentLength)); + return Err(ErrorKind::Protocol(ProtocolErrorKind::MissingContentLength).into()); }; let mut body = vec![0u8; length]; @@ -420,7 +420,7 @@ impl JsonRpcClient { let response = match rx.await { Ok(response) => response, Err(_) => { - let error = Error::Protocol(ProtocolError::RequestCancelled); + let error = ErrorKind::Protocol(ProtocolErrorKind::RequestCancelled).into(); warn!( elapsed_ms = request_start.elapsed().as_millis(), method = %method, @@ -475,7 +475,7 @@ impl JsonRpcClient { self.write_tx .send(WriteCommand { frame, ack: ack_tx }) .map_err(|_| { - Error::Io(std::io::Error::new( + Error::from(std::io::Error::new( std::io::ErrorKind::BrokenPipe, "writer actor has shut down", )) @@ -483,8 +483,8 @@ impl JsonRpcClient { match ack_rx.await { Ok(Ok(())) => Ok(()), - Ok(Err(e)) => Err(Error::Io(e)), - Err(_) => Err(Error::Io(std::io::Error::new( + Ok(Err(e)) => Err(Error::from(e)), + Err(_) => Err(Error::from(std::io::Error::new( std::io::ErrorKind::BrokenPipe, "writer actor dropped ack without responding", ))), diff --git a/rust/src/lib.rs b/rust/src/lib.rs index c7294c3c3..95d772f50 100644 --- a/rust/src/lib.rs +++ b/rust/src/lib.rs @@ -5,6 +5,8 @@ /// Bundled CLI binary extraction and caching. pub(crate) mod embeddedcli; +mod errors; +pub use errors::*; /// Event handler traits for session lifecycle. pub mod handler; /// Lifecycle hook callbacks (pre/post tool use, prompt submission, session start/end). @@ -66,219 +68,11 @@ pub use types::*; mod sdk_protocol_version; pub use sdk_protocol_version::{SDK_PROTOCOL_VERSION, get_sdk_protocol_version}; -pub use subscription::{EventSubscription, Lagged, LifecycleSubscription, RecvError}; +pub use subscription::{EventSubscription, LifecycleSubscription}; /// Minimum protocol version this SDK can communicate with. const MIN_PROTOCOL_VERSION: u32 = 3; -/// Errors returned by the SDK. -#[derive(Debug, thiserror::Error)] -#[non_exhaustive] -pub enum Error { - /// JSON-RPC transport or protocol violation. - #[error("protocol error: {0}")] - Protocol(ProtocolError), - - /// The CLI returned a JSON-RPC error response. - #[error("RPC error {code}: {message}")] - Rpc { - /// JSON-RPC error code. - code: i32, - /// Human-readable error message. - message: String, - }, - - /// Session-scoped error (not found, agent error, timeout, etc.). - #[error("session error: {0}")] - Session(SessionError), - - /// I/O error on the stdio transport or during process spawn. - #[error(transparent)] - Io(#[from] std::io::Error), - - /// Failed to serialize or deserialize a JSON-RPC message. - #[error(transparent)] - Json(#[from] serde_json::Error), - - /// A required binary was not found on the system. - #[error("binary not found: {name} ({hint})")] - BinaryNotFound { - /// Binary name that was searched for. - name: &'static str, - /// Guidance on how to install or configure the binary. - hint: &'static str, - }, - - /// Invalid combination of [`ClientOptions`] supplied to [`Client::start`]. - /// Surfaces consumer-side configuration errors that would otherwise - /// produce confusing runtime failures (e.g. a connection token paired - /// with stdio transport). - #[error("invalid client configuration: {0}")] - InvalidConfig(String), -} - -impl Error { - /// Returns true if this error indicates the transport is broken — the CLI - /// process exited, the connection was lost, or an I/O failure occurred. - /// Callers should discard the client and create a fresh one. - pub fn is_transport_failure(&self) -> bool { - matches!( - self, - Error::Protocol(ProtocolError::RequestCancelled) | Error::Io(_) - ) - } -} - -/// Aggregate of errors collected during [`Client::stop`]. -/// -/// `Client::stop` performs cooperative shutdown across every active -/// session before killing the CLI child process. Errors from any -/// per-session `session.destroy` RPC and from the terminal child-kill -/// step are collected here rather than short-circuiting on the first -/// failure, so callers see the full picture of what went wrong during -/// teardown. -/// -/// Implements [`std::error::Error`] and forwards to `Display` for the -/// first error, with a count suffix when there are more. -#[derive(Debug)] -pub struct StopErrors(Vec); - -impl StopErrors { - /// Borrow the collected errors as a slice, in the order they - /// occurred (per-session destroys first, then child-kill last). - pub fn errors(&self) -> &[Error] { - &self.0 - } - - /// Consume the aggregate and return the underlying error vector. - pub fn into_errors(self) -> Vec { - self.0 - } -} - -impl std::fmt::Display for StopErrors { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self.0.as_slice() { - [] => write!(f, "stop completed with no errors"), - [only] => write!(f, "stop failed: {only}"), - [first, rest @ ..] => write!( - f, - "stop failed with {n} errors; first: {first}", - n = 1 + rest.len(), - ), - } - } -} - -impl std::error::Error for StopErrors { - fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { - self.0 - .first() - .map(|e| e as &(dyn std::error::Error + 'static)) - } -} - -/// Specific protocol-level errors in the JSON-RPC transport or CLI lifecycle. -#[derive(Debug, thiserror::Error)] -#[non_exhaustive] -pub enum ProtocolError { - /// Missing `Content-Length` header in a JSON-RPC message. - #[error("missing Content-Length header")] - MissingContentLength, - - /// Invalid `Content-Length` header value. - #[error("invalid Content-Length value: \"{0}\"")] - InvalidContentLength(String), - - /// A pending JSON-RPC request was cancelled (e.g. the response channel was dropped). - #[error("request cancelled")] - RequestCancelled, - - /// The CLI process did not report a listening port within the timeout. - #[error("timed out waiting for CLI to report listening port")] - CliStartupTimeout, - - /// The CLI process exited before reporting a listening port. - #[error("CLI exited before reporting listening port")] - CliStartupFailed, - - /// The CLI server's protocol version is outside the SDK's supported range. - #[error("version mismatch: server={server}, supported={min}–{max}")] - VersionMismatch { - /// Version reported by the server. - server: u32, - /// Minimum version supported by this SDK. - min: u32, - /// Maximum version supported by this SDK. - max: u32, - }, - - /// The CLI server's protocol version changed between calls. - #[error("version changed: was {previous}, now {current}")] - VersionChanged { - /// Previously negotiated version. - previous: u32, - /// Newly reported version. - current: u32, - }, -} - -/// Session-scoped errors. -#[derive(Debug, thiserror::Error)] -#[non_exhaustive] -pub enum SessionError { - /// The CLI could not find the requested session. - #[error("session not found: {0}")] - NotFound(SessionId), - - /// The CLI reported an error during agent execution (via `session.error` event). - #[error("{0}")] - AgentError(String), - - /// A `send_and_wait` call exceeded its timeout. - #[error("timed out after {0:?}")] - Timeout(std::time::Duration), - - /// `send` was called while a `send_and_wait` is in flight. - #[error("cannot send while send_and_wait is in flight")] - SendWhileWaiting, - - /// The session event loop exited before a pending `send_and_wait` completed. - #[error("event loop closed before session reached idle")] - EventLoopClosed, - - /// Elicitation is not supported by the host. - /// Check `session.capabilities().ui.elicitation` before calling UI methods. - #[error( - "elicitation not supported by host — check session.capabilities().ui.elicitation first" - )] - ElicitationNotSupported, - - /// The client was started with [`ClientOptions::session_fs`] but this - /// session was created without a [`SessionFsProvider`]. Set one via - /// [`SessionConfig::with_session_fs_provider`] (or - /// [`ResumeSessionConfig::with_session_fs_provider`]). - #[error( - "session was created on a client with session_fs configured but no SessionFsProvider was supplied" - )] - SessionFsProviderRequired, - - /// [`ClientOptions::session_fs`] was provided with empty or invalid - /// fields. All of `initial_cwd` and `session_state_path` must be - /// non-empty. - #[error("invalid SessionFsConfig: {0}")] - InvalidSessionFsConfig(String), - - /// The CLI returned a different session ID than the one the SDK registered. - #[error("CLI returned session ID {returned} after SDK registered {requested}")] - SessionIdMismatch { - /// Session ID registered by the SDK before the RPC was sent. - requested: SessionId, - /// Session ID returned by the CLI. - returned: SessionId, - }, -} - /// How the SDK communicates with the CLI server. #[derive(Debug, Default)] #[non_exhaustive] @@ -473,7 +267,7 @@ impl std::fmt::Debug for ClientOptions { #[async_trait] pub trait ListModelsHandler: Send + Sync + 'static { /// Return the list of available models. - async fn list_models(&self) -> Result, Error>; + async fn list_models(&self) -> Result>; } /// Log verbosity for the CLI server (passed via `--log-level`). @@ -831,16 +625,18 @@ impl ClientOptions { } /// Validate a [`SessionFsConfig`] before sending `sessionFs.setProvider`. -fn validate_session_fs_config(cfg: &SessionFsConfig) -> Result<(), Error> { +fn validate_session_fs_config(cfg: &SessionFsConfig) -> Result<()> { if cfg.initial_cwd.trim().is_empty() { - return Err(Error::Session(SessionError::InvalidSessionFsConfig( - "initial_cwd must not be empty".to_string(), - ))); + return Err(Error::with_message( + ErrorKind::Session(SessionErrorKind::InvalidSessionFsConfig), + "invalid SessionFsConfig: initial_cwd must not be empty", + )); } if cfg.session_state_path.trim().is_empty() { - return Err(Error::Session(SessionError::InvalidSessionFsConfig( - "session_state_path must not be empty".to_string(), - ))); + return Err(Error::with_message( + ErrorKind::Session(SessionErrorKind::InvalidSessionFsConfig), + "invalid SessionFsConfig: session_state_path must not be empty", + )); } Ok(()) } @@ -916,7 +712,7 @@ impl Client { /// When [`ClientOptions::session_fs`] is set, also calls /// `sessionFs.setProvider` to register the SDK as the filesystem /// backend. - pub async fn start(options: ClientOptions) -> Result { + pub async fn start(options: ClientOptions) -> Result { let start_time = Instant::now(); if let Some(cfg) = &options.session_fs { validate_session_fs_config(cfg)?; @@ -925,17 +721,17 @@ impl Client { // external server, the server manages its own auth. if matches!(options.transport, Transport::External { .. }) { if options.github_token.is_some() { - return Err(Error::InvalidConfig( - "github_token cannot be used with Transport::External \ - (external server manages its own auth)" - .to_string(), + return Err(Error::with_message( + ErrorKind::InvalidConfig, + "invalid client configuration: github_token cannot be used with \ + Transport::External (external server manages its own auth)", )); } if options.use_logged_in_user == Some(true) { - return Err(Error::InvalidConfig( - "use_logged_in_user cannot be used with Transport::External \ - (external server manages its own auth)" - .to_string(), + return Err(Error::with_message( + ErrorKind::InvalidConfig, + "invalid client configuration: use_logged_in_user cannot be used with \ + Transport::External (external server manages its own auth)", )); } } @@ -951,8 +747,9 @@ impl Client { connection_token: Some(t), .. } if t.is_empty() => { - return Err(Error::InvalidConfig( - "connection_token must be a non-empty string".to_string(), + return Err(Error::with_message( + ErrorKind::InvalidConfig, + "invalid client configuration: connection_token must be a non-empty string", )); } _ => {} @@ -1122,7 +919,7 @@ impl Client { reader: impl AsyncRead + Unpin + Send + 'static, writer: impl AsyncWrite + Unpin + Send + 'static, cwd: PathBuf, - ) -> Result { + ) -> Result { Self::from_transport(reader, writer, None, cwd, None, false, false, None, None) } @@ -1139,7 +936,7 @@ impl Client { writer: impl AsyncWrite + Unpin + Send + 'static, cwd: PathBuf, provider: Arc, - ) -> Result { + ) -> Result { Self::from_transport( reader, writer, @@ -1162,7 +959,7 @@ impl Client { writer: impl AsyncWrite + Unpin + Send + 'static, cwd: PathBuf, token: Option, - ) -> Result { + ) -> Result { Self::from_transport(reader, writer, None, cwd, None, false, false, None, token) } @@ -1187,7 +984,7 @@ impl Client { session_fs_sqlite_declared: bool, on_get_trace_context: Option>, effective_connection_token: Option, - ) -> Result { + ) -> Result { let setup_start = Instant::now(); let (request_tx, request_rx) = mpsc::unbounded_channel::(); let (notification_broadcast_tx, _) = broadcast::channel::(1024); @@ -1382,7 +1179,7 @@ impl Client { } } - fn spawn_stdio(program: &Path, options: &ClientOptions) -> Result { + fn spawn_stdio(program: &Path, options: &ClientOptions) -> Result { info!(cwd = ?options.working_directory, program = %program.display(), "spawning copilot CLI (stdio)"); let mut command = Self::build_command(program, options); command @@ -1406,7 +1203,7 @@ impl Client { program: &Path, options: &ClientOptions, port: u16, - ) -> Result<(Child, u16), Error> { + ) -> Result<(Child, u16)> { info!(cwd = ?options.working_directory, program = %program.display(), port = %port, "spawning copilot CLI (tcp)"); let mut command = Self::build_command(program, options); command @@ -1454,8 +1251,8 @@ impl Client { let port_wait_start = Instant::now(); let actual_port = tokio::time::timeout(std::time::Duration::from_secs(10), port_rx) .await - .map_err(|_| Error::Protocol(ProtocolError::CliStartupTimeout))? - .map_err(|_| Error::Protocol(ProtocolError::CliStartupFailed))?; + .map_err(|_| Error::from(ErrorKind::Protocol(ProtocolErrorKind::CliStartupTimeout)))? + .map_err(|_| Error::from(ErrorKind::Protocol(ProtocolErrorKind::CliStartupFailed)))?; debug!( elapsed_ms = port_wait_start.elapsed().as_millis(), @@ -1505,7 +1302,7 @@ impl Client { &self, method: &str, params: Option, - ) -> Result { + ) -> Result { self.inner.rpc.send_request(method, params).await } @@ -1532,7 +1329,7 @@ impl Client { &self, method: &str, params: Option, - ) -> Result { + ) -> Result { let session_id: Option = params .as_ref() .and_then(|p| p.get("sessionId")) @@ -1541,20 +1338,20 @@ impl Client { let response = self.send_request(method, params).await?; if let Some(err) = response.error { if err.message.contains("Session not found") { - return Err(Error::Session(SessionError::NotFound( + return Err(ErrorKind::Session(SessionErrorKind::NotFound( session_id.unwrap_or_else(|| "unknown".into()), - ))); + )).into()); } - return Err(Error::Rpc { - code: err.code, - message: err.message, - }); + return Err(Error::with_message( + ErrorKind::Rpc { code: err.code }, + err.message, + )); } Ok(response.result.unwrap_or(serde_json::Value::Null)) } /// Send a JSON-RPC response back to the CLI (e.g. for permission or tool call requests). - pub(crate) async fn send_response(&self, response: &JsonRpcResponse) -> Result<(), Error> { + pub(crate) async fn send_response(&self, response: &JsonRpcResponse) -> Result<()> { self.inner.rpc.write(response).await } @@ -1621,7 +1418,7 @@ impl Client { /// Returns an error if the negotiated `protocolVersion` is outside /// `MIN_PROTOCOL_VERSION`..=[`SDK_PROTOCOL_VERSION`]. If the server /// doesn't report a version, logs a warning and succeeds. - pub async fn verify_protocol_version(&self) -> Result<(), Error> { + pub async fn verify_protocol_version(&self) -> Result<()> { let handshake_start = Instant::now(); let mut used_fallback_ping = false; // Try the new `connect` handshake first (sends the connection @@ -1629,7 +1426,7 @@ impl Client { // that don't expose `connect` (-32601 MethodNotFound). let server_version = match self.connect_handshake().await { Ok(v) => v, - Err(Error::Rpc { code, .. }) if code == error_codes::METHOD_NOT_FOUND => { + Err(ref e) if e.rpc_code() == Some(error_codes::METHOD_NOT_FOUND) => { used_fallback_ping = true; self.ping(None).await?.protocol_version } @@ -1641,19 +1438,19 @@ impl Client { warn!("CLI server did not report protocolVersion; skipping version check"); } Some(v) if !(MIN_PROTOCOL_VERSION..=SDK_PROTOCOL_VERSION).contains(&v) => { - return Err(Error::Protocol(ProtocolError::VersionMismatch { + return Err(ErrorKind::Protocol(ProtocolErrorKind::VersionMismatch { server: v, min: MIN_PROTOCOL_VERSION, max: SDK_PROTOCOL_VERSION, - })); +}).into()); } Some(v) => { if let Some(&existing) = self.inner.negotiated_protocol_version.get() { if existing != v { - return Err(Error::Protocol(ProtocolError::VersionChanged { + return Err(ErrorKind::Protocol(ProtocolErrorKind::VersionChanged { previous: existing, current: v, - })); + }).into()); } } else { let _ = self.inner.negotiated_protocol_version.set(v); @@ -1676,7 +1473,7 @@ impl Client { /// auto-generated token for SDK-spawned TCP servers) as the `token` /// param. Server-side, the token is required when the server was /// started with `COPILOT_CONNECTION_TOKEN`. - async fn connect_handshake(&self) -> Result, Error> { + async fn connect_handshake(&self) -> Result> { let result = self .rpc() .connect(crate::generated::api_types::ConnectRequest { @@ -1693,7 +1490,7 @@ impl Client { /// the CLI reports one. /// /// [`PingResponse`]: crate::types::PingResponse - pub async fn ping(&self, message: Option<&str>) -> Result { + pub async fn ping(&self, message: Option<&str>) -> Result { let params = match message { Some(m) => serde_json::json!({ "message": m }), None => serde_json::json!({}), @@ -1709,7 +1506,7 @@ impl Client { pub async fn list_sessions( &self, filter: Option, - ) -> Result, Error> { + ) -> Result> { let params = match filter { Some(f) => serde_json::json!({ "filter": f }), None => serde_json::json!({}), @@ -1739,7 +1536,7 @@ impl Client { pub async fn get_session_metadata( &self, session_id: &SessionId, - ) -> Result, Error> { + ) -> Result> { let result = self .call( "session.getMetadata", @@ -1751,7 +1548,7 @@ impl Client { } /// Delete a persisted session by ID. - pub async fn delete_session(&self, session_id: &SessionId) -> Result<(), Error> { + pub async fn delete_session(&self, session_id: &SessionId) -> Result<()> { self.call( "session.delete", Some(serde_json::json!({ "sessionId": session_id })), @@ -1775,7 +1572,7 @@ impl Client { /// # Ok(()) /// # } /// ``` - pub async fn get_last_session_id(&self) -> Result, Error> { + pub async fn get_last_session_id(&self) -> Result> { let result = self .call("session.getLastId", Some(serde_json::json!({}))) .await?; @@ -1787,7 +1584,7 @@ impl Client { /// /// Only meaningful when connected to a server running in TUI+server mode /// (`--ui-server`). Returns `Ok(None)` if no foreground session is set. - pub async fn get_foreground_session_id(&self) -> Result, Error> { + pub async fn get_foreground_session_id(&self) -> Result> { let result = self .call("session.getForeground", Some(serde_json::json!({}))) .await?; @@ -1799,7 +1596,7 @@ impl Client { /// /// Only meaningful when connected to a server running in TUI+server mode /// (`--ui-server`). - pub async fn set_foreground_session_id(&self, session_id: &SessionId) -> Result<(), Error> { + pub async fn set_foreground_session_id(&self, session_id: &SessionId) -> Result<()> { self.call( "session.setForeground", Some(serde_json::json!({ "sessionId": session_id })), @@ -1809,13 +1606,13 @@ impl Client { } /// Get the CLI server status. - pub async fn get_status(&self) -> Result { + pub async fn get_status(&self) -> Result { let result = self.call("status.get", Some(serde_json::json!({}))).await?; Ok(serde_json::from_value(result)?) } /// Get authentication status. - pub async fn get_auth_status(&self) -> Result { + pub async fn get_auth_status(&self) -> Result { let result = self .call("auth.getStatus", Some(serde_json::json!({}))) .await?; @@ -1826,7 +1623,7 @@ impl Client { /// /// When [`ClientOptions::on_list_models`] is set, returns the handler's /// result without making a `models.list` RPC. Otherwise queries the CLI. - pub async fn list_models(&self) -> Result, Error> { + pub async fn list_models(&self) -> Result> { let cache = self.inner.models_cache.lock().clone(); let models = cache .get_or_try_init(|| async { @@ -1880,7 +1677,7 @@ impl Client { /// or call `stop()` again with a fresh future. The documented /// `tokio::time::timeout(..., client.stop())` pattern in the example /// below uses `force_stop` as the fallback for exactly this case. - pub async fn stop(&self) -> Result<(), StopErrors> { + pub async fn stop(&self) -> std::result::Result<(), StopErrors> { let pid = self.pid(); info!(pid = ?pid, "stopping CLI process"); let mut errors: Vec = Vec::new(); @@ -1914,7 +1711,7 @@ impl Client { if let Some(mut child) = child && let Err(e) = child.kill().await { - errors.push(Error::Io(e)); + errors.push(e.into()); } info!(pid = ?pid, errors = errors.len(), "CLI process stopped"); @@ -1983,7 +1780,8 @@ impl Client { /// /// Each subscriber maintains its own queue. If a consumer cannot keep /// up, the oldest events are dropped and `recv` returns - /// [`RecvError::Lagged`] with the count of skipped events; consumers + /// [`RecvErrorKind::Lagged`](crate::subscription::RecvErrorKind::Lagged) + /// with the count of skipped events; consumers /// should match on it and continue. Slow consumers do not block the /// producer. /// @@ -2027,28 +1825,25 @@ mod tests { #[test] fn is_transport_failure_matches_request_cancelled() { - let err = Error::Protocol(ProtocolError::RequestCancelled); + let err = Error::from(ErrorKind::Protocol(ProtocolErrorKind::RequestCancelled)); assert!(err.is_transport_failure()); } #[test] fn is_transport_failure_matches_io_error() { - let err = Error::Io(std::io::Error::new(std::io::ErrorKind::BrokenPipe, "gone")); + let err = Error::from(std::io::Error::new(std::io::ErrorKind::BrokenPipe, "gone")); assert!(err.is_transport_failure()); } #[test] fn is_transport_failure_rejects_rpc_error() { - let err = Error::Rpc { - code: -1, - message: "bad".into(), - }; + let err = Error::with_message(ErrorKind::Rpc { code: -1 }, "bad"); assert!(!err.is_transport_failure()); } #[test] fn is_transport_failure_rejects_session_error() { - let err = Error::Session(SessionError::NotFound("s1".into())); + let err = Error::from(ErrorKind::Session(SessionErrorKind::NotFound("s1".into()))); assert!(!err.is_transport_failure()); } @@ -2087,7 +1882,7 @@ mod tests { #[test] fn is_transport_failure_rejects_other_protocol_errors() { - let err = Error::Protocol(ProtocolError::CliStartupTimeout); + let err = Error::from(ErrorKind::Protocol(ProtocolErrorKind::CliStartupTimeout)); assert!(!err.is_transport_failure()); } @@ -2323,7 +2118,7 @@ mod tests { }) .with_program(CliProgram::Path(PathBuf::from("/bin/echo"))); let err = Client::start(opts).await.unwrap_err(); - assert!(matches!(err, Error::InvalidConfig(_)), "got {err:?}"); + assert!(matches!(err.kind(), ErrorKind::InvalidConfig), "got {err:?}"); } #[tokio::test] @@ -2336,7 +2131,7 @@ mod tests { }) .with_program(CliProgram::Path(PathBuf::from("/bin/echo"))); let err = Client::start(opts).await.unwrap_err(); - assert!(matches!(err, Error::InvalidConfig(_)), "got {err:?}"); + assert!(matches!(err.kind(), ErrorKind::InvalidConfig), "got {err:?}"); } #[test] @@ -2454,7 +2249,7 @@ mod tests { struct StubHandler; #[async_trait] impl ListModelsHandler for StubHandler { - async fn list_models(&self) -> Result, Error> { + async fn list_models(&self) -> Result> { Ok(vec![]) } } @@ -2479,7 +2274,7 @@ mod tests { } #[async_trait] impl ListModelsHandler for CountingHandler { - async fn list_models(&self) -> Result, Error> { + async fn list_models(&self) -> Result> { self.calls.fetch_add(1, Ordering::SeqCst); Ok(self.models.clone()) } @@ -2514,7 +2309,7 @@ mod tests { } #[async_trait] impl ListModelsHandler for SlowCountingHandler { - async fn list_models(&self) -> Result, Error> { + async fn list_models(&self) -> Result> { self.calls.fetch_add(1, Ordering::SeqCst); tokio::time::sleep(std::time::Duration::from_millis(25)).await; Ok(self.models.clone()) diff --git a/rust/src/resolve.rs b/rust/src/resolve.rs index 7a1b29a04..58a37675b 100644 --- a/rust/src/resolve.rs +++ b/rust/src/resolve.rs @@ -12,14 +12,14 @@ //! If you've opted out of bundling (via `default-features = false`) and //! neither `CliProgram::Path` nor `COPILOT_CLI_PATH` is set, //! [`Client::start`](crate::Client::start) returns -//! [`Error::BinaryNotFound`](crate::Error::BinaryNotFound). +//! an [`ErrorKind::BinaryNotFound`](crate::ErrorKind::BinaryNotFound) error. use std::env; use std::path::{Path, PathBuf}; use tracing::warn; -use crate::Error; +use crate::{Error, ErrorKind}; /// Resolve the CLI binary, optionally overriding the directory the bundled /// CLI is extracted to. Called by `Client::start` to thread @@ -47,11 +47,14 @@ pub(crate) fn copilot_binary_with_extract_dir( return Ok(path); } - Err(Error::BinaryNotFound { - name: "copilot", - hint: "the Copilot CLI is not bundled in this build of github-copilot-sdk and \ - COPILOT_CLI_PATH is not set. Either keep the default `bundled-cli` cargo \ - feature enabled, set COPILOT_CLI_PATH, or supply an explicit path via \ - `CliProgram::Path(...)` on `ClientOptions::program`.", - }) + Err(ErrorKind::BinaryNotFound { + name: "copilot".into(), + hint: Some( + "the Copilot CLI is not bundled in this build of github-copilot-sdk and \ + COPILOT_CLI_PATH is not set. Either keep the default `bundled-cli` cargo \ + feature enabled, set COPILOT_CLI_PATH, or supply an explicit path via \ + `CliProgram::Path(...)` on `ClientOptions::program`." + .into(), + ), + }.into()) } diff --git a/rust/src/session.rs b/rust/src/session.rs index f8a35ce0c..61d471efb 100644 --- a/rust/src/session.rs +++ b/rust/src/session.rs @@ -30,7 +30,7 @@ use crate::types::{ SessionConfig, SessionEvent, SessionId, SetModelOptions, SystemMessageConfig, ToolInvocation, ToolResult, ToolResultExpanded, TraceContext, UiInputOptions, ensure_attachment_display_names, }; -use crate::{Client, Error, JsonRpcResponse, SessionError, SessionEventNotification, error_codes}; +use crate::{Client, Error, ErrorKind, JsonRpcResponse, SessionErrorKind, SessionEventNotification, error_codes}; /// Bundle of the per-session callbacks the SDK dispatches to. Built from a /// [`SessionConfig`] / [`ResumeSessionConfig`] at @@ -66,7 +66,7 @@ struct IdleWaiter { /// Without this, an outer cancellation between "install waiter" and /// "drain channel" would leave the slot occupied, causing all subsequent /// `send` and `send_and_wait` calls on the session to return -/// [`SendWhileWaiting`](SessionError::SendWhileWaiting). Closes RFD-400 +/// [`SendWhileWaiting`](SessionErrorKind::SendWhileWaiting). Closes RFD-400 /// review finding #2. struct WaiterGuard { slot: Arc>>, @@ -247,7 +247,7 @@ impl Session { /// /// Each subscriber maintains its own queue. If a consumer cannot keep /// up, the oldest events are dropped and `recv` returns - /// [`RecvError::Lagged`](crate::subscription::RecvError::Lagged) + /// [`RecvErrorKind::Lagged`](crate::subscription::RecvErrorKind::Lagged) /// reporting the count of skipped events. Slow consumers do not block /// the session's event loop. /// @@ -303,7 +303,7 @@ impl Session { if let Some(waiter) = self.idle_waiter.lock().take() { let _ = waiter .tx - .send(Err(Error::Session(SessionError::EventLoopClosed))); + .send(Err(ErrorKind::Session(SessionErrorKind::EventLoopClosed).into())); } } @@ -333,7 +333,7 @@ impl Session { /// message ID. pub async fn send(&self, opts: impl Into) -> Result { if self.idle_waiter.lock().is_some() { - return Err(Error::Session(SessionError::SendWhileWaiting)); + return Err(ErrorKind::Session(SessionErrorKind::SendWhileWaiting).into()); } self.send_inner(opts.into()).await } @@ -411,7 +411,7 @@ impl Session { { let mut guard = self.idle_waiter.lock(); if guard.is_some() { - return Err(Error::Session(SessionError::SendWhileWaiting)); + return Err(ErrorKind::Session(SessionErrorKind::SendWhileWaiting).into()); } *guard = Some(IdleWaiter { tx, @@ -433,7 +433,7 @@ impl Session { self.send_inner(opts).await?; match rx.await { Ok(result) => result, - Err(_) => Err(Error::Session(SessionError::EventLoopClosed)), + Err(_) => Err(ErrorKind::Session(SessionErrorKind::EventLoopClosed).into()), } }) .await; @@ -455,7 +455,7 @@ impl Session { completed_by = "timeout", "Session::send_and_wait failed" ); - Err(Error::Session(SessionError::Timeout(timeout_duration))) + Err(ErrorKind::Session(SessionErrorKind::Timeout(timeout_duration)).into()) } } } @@ -592,7 +592,7 @@ impl Session { .and_then(|u| u.elicitation) != Some(true) { - return Err(Error::Session(SessionError::ElicitationNotSupported)); + return Err(ErrorKind::Session(SessionErrorKind::ElicitationNotSupported).into()); } Ok(()) } @@ -810,16 +810,16 @@ impl Client { let command_handlers = build_command_handler_map(runtime.commands.as_deref()); let session_fs_provider = runtime.session_fs_provider.take(); if self.inner.session_fs_configured && session_fs_provider.is_none() { - return Err(Error::Session(SessionError::SessionFsProviderRequired)); + return Err(ErrorKind::Session(SessionErrorKind::SessionFsProviderRequired).into()); } if self.inner.session_fs_sqlite_declared && let Some(ref provider) = session_fs_provider && provider.sqlite().is_none() { - return Err(Error::InvalidConfig( + return Err(Error::with_message( + ErrorKind::InvalidConfig, "SessionFs capabilities declare SQLite support but the provider \ - does not implement SessionFsSqliteProvider" - .to_string(), + does not implement SessionFsSqliteProvider", )); } @@ -879,10 +879,10 @@ impl Client { }; if create_result.session_id != session_id { registration.cleanup(event_loop).await; - return Err(Error::Session(SessionError::SessionIdMismatch { + return Err(ErrorKind::Session(SessionErrorKind::SessionIdMismatch { requested: session_id, returned: create_result.session_id, - })); + }).into()); } *capabilities.write() = create_result.capabilities.unwrap_or_default(); @@ -947,16 +947,16 @@ impl Client { let command_handlers = build_command_handler_map(runtime.commands.as_deref()); let session_fs_provider = runtime.session_fs_provider.take(); if self.inner.session_fs_configured && session_fs_provider.is_none() { - return Err(Error::Session(SessionError::SessionFsProviderRequired)); + return Err(ErrorKind::Session(SessionErrorKind::SessionFsProviderRequired).into()); } if self.inner.session_fs_sqlite_declared && let Some(ref provider) = session_fs_provider && provider.sqlite().is_none() { - return Err(Error::InvalidConfig( + return Err(Error::with_message( + ErrorKind::InvalidConfig, "SessionFs capabilities declare SQLite support but the provider \ - does not implement SessionFsSqliteProvider" - .to_string(), + does not implement SessionFsSqliteProvider", )); } @@ -1017,10 +1017,10 @@ impl Client { .into(); if cli_session_id != session_id { registration.cleanup(event_loop).await; - return Err(Error::Session(SessionError::SessionIdMismatch { + return Err(ErrorKind::Session(SessionErrorKind::SessionIdMismatch { requested: session_id, returned: cli_session_id, - })); + }).into()); } let resume_capabilities: Option = result @@ -1150,7 +1150,7 @@ fn spawn_event_loop( if let Some(waiter) = idle_waiter.lock().take() { let _ = waiter .tx - .send(Err(Error::Session(SessionError::EventLoopClosed))); + .send(Err(ErrorKind::Session(SessionErrorKind::EventLoopClosed).into())); } } .instrument(span), @@ -1255,7 +1255,10 @@ async fn handle_notification( .unwrap_or_else(|| "session error".to_string()); let _ = waiter .tx - .send(Err(Error::Session(SessionError::AgentError(error_msg)))); + .send(Err(Error::with_message( + ErrorKind::Session(SessionErrorKind::AgentError), + error_msg, + ))); } } } diff --git a/rust/src/session_fs.rs b/rust/src/session_fs.rs index 0e13be7d7..e4433e909 100644 --- a/rust/src/session_fs.rs +++ b/rust/src/session_fs.rs @@ -17,8 +17,9 @@ //! //! Provider methods return [`Result`]. The SDK adapts these into //! the schema's `{ ..., error: Option }` payload, mapping -//! [`FsError::NotFound`] to the wire's `ENOENT` and everything else to -//! `UNKNOWN`. A [`From`] conversion is provided so handlers +//! [`FsErrorKind::NotFound`](crate::session_fs::FsErrorKind::NotFound) to +//! the wire's `ENOENT` and everything else to `UNKNOWN`. +//! A [`From`] conversion is provided so handlers //! backed by [`tokio::fs`](https://docs.rs/tokio/latest/tokio/fs/index.html) //! can propagate `io::Error` with `?`. //! @@ -40,15 +41,13 @@ //! } //! ``` -use std::collections::HashMap; - use async_trait::async_trait; - +use std::{borrow::{Borrow, Cow}, collections::HashMap, fmt}; pub use crate::generated::api_types::SessionFsSqliteQueryType; -use crate::generated::api_types::{ +use crate::{Custom, Repr, generated::api_types::{ SessionFsError, SessionFsErrorCode, SessionFsReaddirWithTypesEntry, SessionFsReaddirWithTypesEntryType, SessionFsSetProviderConventions, SessionFsStatResult, -}; +}}; /// Optional capabilities declared by a session filesystem provider. #[non_exhaustive] @@ -135,49 +134,129 @@ impl SessionFsConventions { } } -/// Error returned by a [`SessionFsProvider`] method. +/// Error kind returned by a [`SessionFsProvider`] method. /// -/// The SDK maps this onto the wire schema's [`SessionFsError`]: -/// [`FsError::NotFound`] becomes `ENOENT`, everything else becomes `UNKNOWN`. +/// The SDK maps this onto the wire schema's `SessionFsError`: +/// [`FsErrorKind::NotFound`] becomes `ENOENT`, everything else becomes `UNKNOWN`. +#[derive(Clone, Debug, PartialEq, Eq)] #[non_exhaustive] -#[derive(Debug, Clone, thiserror::Error)] -pub enum FsError { +pub enum FsErrorKind { /// File or directory does not exist. - #[error("not found: {0}")] NotFound(String), /// Any other filesystem error (permission denied, I/O error, etc.). - /// - /// The wire mapping always uses `UNKNOWN` as the code; the message is - /// preserved for diagnostics. - #[error("{0}")] - Other(String), + Other, +} + +impl fmt::Display for FsErrorKind { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + FsErrorKind::NotFound(path) => write!(f, "not found: {path}"), + FsErrorKind::Other => write!(f, "filesystem error"), + } + } +} + +/// Error returned by a [`crate::session_fs::SessionFsProvider`] method. +/// +/// The SDK maps this onto the wire schema's `SessionFsError`: +/// [`FsErrorKind::NotFound`] becomes `ENOENT`, everything else becomes `UNKNOWN`. +#[derive(Debug)] +pub struct FsError { + repr: Repr, } impl FsError { + /// Construct a `FsError` wrapping a source error. + pub fn new(kind: FsErrorKind, error: E) -> Self + where + E: Into>, + { + Self { + repr: Repr::Custom(Custom { + kind, + error: error.into(), + }), + } + } + + /// The [`FsErrorKind`] of this error. + pub fn kind(&self) -> &FsErrorKind { + match &self.repr { + Repr::Simple(k) + | Repr::SimpleMessage(k, ..) + | Repr::Custom(Custom { kind: k, .. }) => k, + } + } + + /// The message provided when this error was constructed, or `None`. + pub fn message(&self) -> Option<&str> { + match &self.repr { + Repr::SimpleMessage(_, m) => Some(m.borrow()), + _ => None, + } + } + + /// Create a `FsError` with a custom message. + #[must_use] + pub fn with_message(kind: FsErrorKind, message: C) -> Self + where + C: Into>, + { + Self { + repr: Repr::SimpleMessage(kind, message.into()), + } + } + pub(crate) fn into_wire(self) -> SessionFsError { - match self { - Self::NotFound(message) => SessionFsError { + match self.kind() { + FsErrorKind::NotFound(message) => SessionFsError { code: SessionFsErrorCode::ENOENT, - message: Some(message), + message: Some(message.clone()), }, - Self::Other(message) => SessionFsError { + FsErrorKind::Other => SessionFsError { code: SessionFsErrorCode::UNKNOWN, - message: Some(message), + message: Some(self.to_string()), }, } } } +impl fmt::Display for FsError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match &self.repr { + Repr::Simple(k) => write!(f, "{k}"), + Repr::SimpleMessage(_, m) => write!(f, "{m}"), + Repr::Custom(Custom { error, .. }) => write!(f, "{error}"), + } + } +} + +impl std::error::Error for FsError { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + match &self.repr { + Repr::Custom(Custom { error, .. }) => Some(&**error), + _ => None, + } + } +} + +impl From for FsError { + fn from(kind: FsErrorKind) -> Self { + Self { repr: Repr::Simple(kind) } + } +} + impl From for FsError { fn from(err: std::io::Error) -> Self { match err.kind() { - std::io::ErrorKind::NotFound => Self::NotFound(err.to_string()), - _ => Self::Other(err.to_string()), + std::io::ErrorKind::NotFound => Self::new(FsErrorKind::NotFound(err.to_string()), err), + _ => Self::new(FsErrorKind::Other, err), } } } + /// File or directory metadata returned by [`SessionFsProvider::stat`]. /// /// The SDK adapts this into the wire's [`SessionFsStatResult`]. @@ -296,7 +375,7 @@ impl DirEntry { /// # Forward compatibility /// /// Methods on this trait have default implementations that return -/// `Err(FsError::Other("operation not supported".into()))`. When the CLI +/// `Err(FsError::with_message(FsErrorKind::Other, "operation not supported"))`. When the CLI /// schema grows new `sessionFs.*` methods, the SDK adds them to this trait /// with default impls so existing implementations continue to compile. /// Override only the methods relevant to your backing store. @@ -305,7 +384,7 @@ pub trait SessionFsProvider: Send + Sync + 'static { /// Read the full contents of a file as UTF-8. async fn read_file(&self, path: &str) -> Result { let _ = path; - Err(FsError::Other("read_file not supported".to_string())) + Err(FsError::with_message(FsErrorKind::Other, "read_file not supported")) } /// Write content to a file, creating parent directories if needed. @@ -316,7 +395,7 @@ pub trait SessionFsProvider: Send + Sync + 'static { mode: Option, ) -> Result<(), FsError> { let _ = (path, content, mode); - Err(FsError::Other("write_file not supported".to_string())) + Err(FsError::with_message(FsErrorKind::Other, "write_file not supported")) } /// Append content to a file, creating parent directories if needed. @@ -327,54 +406,52 @@ pub trait SessionFsProvider: Send + Sync + 'static { mode: Option, ) -> Result<(), FsError> { let _ = (path, content, mode); - Err(FsError::Other("append_file not supported".to_string())) + Err(FsError::with_message(FsErrorKind::Other, "append_file not supported")) } /// Check whether a path exists. /// - /// Returns `Ok(false)` for non-existent paths, not [`FsError::NotFound`]. + /// Returns `Ok(false)` for non-existent paths, not [`FsErrorKind::NotFound`]. async fn exists(&self, path: &str) -> Result { let _ = path; - Err(FsError::Other("exists not supported".to_string())) + Err(FsError::with_message(FsErrorKind::Other, "exists not supported")) } /// Get metadata about a file or directory. async fn stat(&self, path: &str) -> Result { let _ = path; - Err(FsError::Other("stat not supported".to_string())) + Err(FsError::with_message(FsErrorKind::Other, "stat not supported")) } /// Create a directory. When `recursive`, missing parents are also created. async fn mkdir(&self, path: &str, recursive: bool, mode: Option) -> Result<(), FsError> { let _ = (path, recursive, mode); - Err(FsError::Other("mkdir not supported".to_string())) + Err(FsError::with_message(FsErrorKind::Other, "mkdir not supported")) } /// List entry names in a directory. async fn readdir(&self, path: &str) -> Result, FsError> { let _ = path; - Err(FsError::Other("readdir not supported".to_string())) + Err(FsError::with_message(FsErrorKind::Other, "readdir not supported")) } /// List directory entries with type information. async fn readdir_with_types(&self, path: &str) -> Result, FsError> { let _ = path; - Err(FsError::Other( - "readdir_with_types not supported".to_string(), - )) + Err(FsError::with_message(FsErrorKind::Other, "readdir_with_types not supported")) } /// Remove a file or directory. When `force`, missing paths are not an /// error. When `recursive`, directory contents are removed as well. async fn rm(&self, path: &str, recursive: bool, force: bool) -> Result<(), FsError> { let _ = (path, recursive, force); - Err(FsError::Other("rm not supported".to_string())) + Err(FsError::with_message(FsErrorKind::Other, "rm not supported")) } /// Rename or move a file or directory. async fn rename(&self, src: &str, dest: &str) -> Result<(), FsError> { let _ = (src, dest); - Err(FsError::Other("rename not supported".to_string())) + Err(FsError::with_message(FsErrorKind::Other, "rename not supported")) } /// Return a reference to the SQLite provider, if this provider supports @@ -443,7 +520,7 @@ mod tests { fn fs_error_maps_io_not_found_to_enoent() { let io_err = std::io::Error::new(std::io::ErrorKind::NotFound, "missing.txt"); let fs_err: FsError = io_err.into(); - assert!(matches!(fs_err, FsError::NotFound(_))); + assert!(matches!(fs_err.kind(), FsErrorKind::NotFound(message) if message == "missing.txt")); let wire = fs_err.into_wire(); assert_eq!(wire.code, SessionFsErrorCode::ENOENT); } @@ -452,7 +529,7 @@ mod tests { fn fs_error_maps_other_io_to_unknown() { let io_err = std::io::Error::other("disk full"); let fs_err: FsError = io_err.into(); - assert!(matches!(fs_err, FsError::Other(_))); + assert!(matches!(fs_err.kind(), FsErrorKind::Other)); let wire = fs_err.into_wire(); assert_eq!(wire.code, SessionFsErrorCode::UNKNOWN); assert!(wire.message.unwrap().contains("disk full")); @@ -478,6 +555,6 @@ mod tests { async fn default_impls_return_unsupported() { let p = DefaultProvider; let err = p.read_file("/x").await.unwrap_err(); - assert!(matches!(err, FsError::Other(ref m) if m.contains("not supported"))); + assert!(matches!(err.kind(), FsErrorKind::Other) && err.to_string().contains("not supported")); } } diff --git a/rust/src/subscription.rs b/rust/src/subscription.rs index 69886a195..134822a73 100644 --- a/rust/src/subscription.rs +++ b/rust/src/subscription.rs @@ -23,9 +23,10 @@ //! //! Each subscriber maintains its own internal queue. If a consumer cannot //! keep up, the oldest events are dropped and the next call yields -//! [`Lagged`] reporting how many events were skipped. Slow subscribers do -//! not block the producer. +//! [`Lagged`](crate::subscription::Lagged) reporting how many events were skipped. +//! Slow subscribers do not block the producer. +use std::fmt; use std::pin::Pin; use std::task::{Context, Poll}; @@ -34,6 +35,7 @@ use tokio_stream::wrappers::BroadcastStream; use tokio_stream::wrappers::errors::BroadcastStreamRecvError; use tokio_stream::{Stream, StreamExt as _}; +use crate::{Custom, Repr}; use crate::types::{SessionEvent, SessionLifecycleEvent}; /// The subscription fell behind the producer. @@ -43,9 +45,8 @@ use crate::types::{SessionEvent, SessionLifecycleEvent}; /// after this error, starting from the next live event — callers who care /// about lag should match on it and decide whether to resync, re-fetch, or /// log and continue. -#[derive(Debug, Clone, Copy, PartialEq, Eq, thiserror::Error)] -#[error("subscription lagged behind by {0} events")] -pub struct Lagged(u64); +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct Lagged(pub(crate) u64); impl Lagged { /// Number of events skipped before this consumer could read them. @@ -54,19 +55,82 @@ impl Lagged { } } -/// Error returned by [`EventSubscription::recv`] and -/// [`LifecycleSubscription::recv`]. -#[derive(Debug, thiserror::Error)] +impl fmt::Display for Lagged { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "subscription lagged behind by {} events", self.0) + } +} + +impl std::error::Error for Lagged {} + +/// Error kind for subscription receive operations. +#[derive(Clone, Copy, Debug, PartialEq, Eq)] #[non_exhaustive] -pub enum RecvError { +pub enum RecvErrorKind { /// The producer is gone — the session has shut down or the client has /// stopped. No further events will be delivered. - #[error("subscription closed")] Closed, /// The subscriber fell behind. See [`Lagged`]. - #[error(transparent)] - Lagged(#[from] Lagged), + Lagged(Lagged), +} + +impl fmt::Display for RecvErrorKind { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + RecvErrorKind::Closed => write!(f, "subscription closed"), + RecvErrorKind::Lagged(l) => write!(f, "{l}"), + } + } +} + +/// Error returned by [`crate::subscription::EventSubscription::recv`] and +/// [`crate::subscription::LifecycleSubscription::recv`]. +#[derive(Debug)] +pub struct RecvError { + repr: Repr, +} + +impl RecvError { + /// The [`RecvErrorKind`] of this error. + pub fn kind(&self) -> &RecvErrorKind { + match &self.repr { + Repr::Simple(k) + | Repr::SimpleMessage(k, ..) + | Repr::Custom(Custom { kind: k, .. }) => k, + } + } +} + +impl fmt::Display for RecvError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match &self.repr { + Repr::Simple(k) => write!(f, "{k}"), + Repr::SimpleMessage(_, m) => write!(f, "{m}"), + Repr::Custom(Custom { error, .. }) => write!(f, "{error}"), + } + } +} + +impl std::error::Error for RecvError { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + match &self.repr { + Repr::Custom(Custom { error, .. }) => Some(&**error), + _ => None, + } + } +} + +impl From for RecvError { + fn from(kind: RecvErrorKind) -> Self { + Self { repr: Repr::Simple(kind) } + } +} + +impl From for RecvError { + fn from(lagged: Lagged) -> Self { + Self::from(RecvErrorKind::Lagged(lagged)) + } } macro_rules! define_subscription { @@ -92,9 +156,9 @@ macro_rules! define_subscription { /// Returns: /// /// - `Ok(event)` for the next delivered event. - /// - `Err(`[`RecvError::Lagged`]`)` if the subscriber fell behind; + /// - `Err(`[`RecvErrorKind::Lagged`]`)` if the subscriber fell behind; /// call `recv` again to continue from the next live event. - /// - `Err(`[`RecvError::Closed`]`)` once the producer is gone. + /// - `Err(`[`RecvErrorKind::Closed`]`)` once the producer is gone. /// /// # Cancel safety /// @@ -107,9 +171,9 @@ macro_rules! define_subscription { match self.inner.next().await { Some(Ok(event)) => Ok(event), Some(Err(BroadcastStreamRecvError::Lagged(n))) => { - Err(RecvError::Lagged(Lagged(n))) + Err(Lagged(n).into()) } - None => Err(RecvError::Closed), + None => Err(RecvErrorKind::Closed.into()), } } } @@ -184,7 +248,7 @@ mod tests { assert_eq!(sub.recv().await.unwrap().id, "a"); assert_eq!(sub.recv().await.unwrap().id, "b"); - assert!(matches!(sub.recv().await, Err(RecvError::Closed))); + assert!(matches!(sub.recv().await.unwrap_err().kind(), RecvErrorKind::Closed)); } #[tokio::test] @@ -194,10 +258,11 @@ mod tests { for id in ["a", "b", "c", "d"] { tx.send(make_event(id)).unwrap(); } - match sub.recv().await { - Err(RecvError::Lagged(l)) => assert_eq!(l.skipped(), 2), - other => panic!("expected Lagged, got {other:?}"), - } + let err = sub.recv().await.expect_err("expected a Lagged error"); + let RecvErrorKind::Lagged(l) = err.kind() else { + panic!("expected Lagged, got {:?}", err.kind()); + }; + assert_eq!(l.skipped(), 2); // Subscription continues with the live tail. assert_eq!(sub.recv().await.unwrap().id, "c"); assert_eq!(sub.recv().await.unwrap().id, "d"); diff --git a/rust/src/tool.rs b/rust/src/tool.rs index b9b44bc0a..189bc6f21 100644 --- a/rust/src/tool.rs +++ b/rust/src/tool.rs @@ -621,7 +621,7 @@ mod tests { use serde::Deserialize; use super::super::*; - use crate::SessionId; + use crate::{ErrorKind, SessionId}; #[derive(Deserialize, schemars::JsonSchema)] struct GetWeatherParams { @@ -712,7 +712,7 @@ mod tests { }; let err = tool.call(inv).await.unwrap_err(); - assert!(matches!(err, Error::Json(_))); + assert!(matches!(err.kind(), ErrorKind::Json)); } #[tokio::test] diff --git a/rust/src/types.rs b/rust/src/types.rs index df5767aa4..bbffcb1a0 100644 --- a/rust/src/types.rs +++ b/rust/src/types.rs @@ -1374,10 +1374,10 @@ impl SessionConfig { if let Some(handler) = tool.handler.take() && tool_handlers.insert(tool.name.clone(), handler).is_some() { - return Err(crate::Error::InvalidConfig(format!( - "duplicate tool handler registered for name {:?}", - tool.name - ))); + return Err(crate::Error::with_message( + crate::ErrorKind::InvalidConfig, + format!("duplicate tool handler registered for name {:?}", tool.name), + )); } } } @@ -1964,10 +1964,10 @@ impl ResumeSessionConfig { if let Some(handler) = tool.handler.take() && tool_handlers.insert(tool.name.clone(), handler).is_some() { - return Err(crate::Error::InvalidConfig(format!( - "duplicate tool handler registered for name {:?}", - tool.name - ))); + return Err(crate::Error::with_message( + crate::ErrorKind::InvalidConfig, + format!("duplicate tool handler registered for name {:?}", tool.name), + )); } } } diff --git a/rust/tests/e2e/session_fs_sqlite.rs b/rust/tests/e2e/session_fs_sqlite.rs index cd8758c31..7f485c3c2 100644 --- a/rust/tests/e2e/session_fs_sqlite.rs +++ b/rust/tests/e2e/session_fs_sqlite.rs @@ -3,7 +3,7 @@ use std::sync::{Arc, Mutex}; use async_trait::async_trait; use github_copilot_sdk::{ - Client, DirEntry, DirEntryKind, FileInfo, FsError, SessionConfig, SessionFsCapabilities, + Client, DirEntry, DirEntryKind, FileInfo, session_fs::{FsError, FsErrorKind}, SessionConfig, SessionFsCapabilities, SessionFsConfig, SessionFsConventions, SessionFsProvider, SessionFsSqliteProvider, SessionFsSqliteQueryResult, SessionFsSqliteQueryType, }; @@ -53,9 +53,9 @@ impl InMemorySqliteProvider { fn get_or_create_db(db: &mut Option) -> Result<&mut Connection, FsError> { if db.is_none() { - let conn = Connection::open_in_memory().map_err(|e| FsError::Other(e.to_string()))?; + let conn = Connection::open_in_memory().map_err(|e| FsError::new(FsErrorKind::Other, e))?; conn.execute_batch("PRAGMA busy_timeout = 5000;") - .map_err(|e| FsError::Other(e.to_string()))?; + .map_err(|e| FsError::new(FsErrorKind::Other, e))?; *db = Some(conn); } Ok(db.as_mut().unwrap()) @@ -69,7 +69,7 @@ impl SessionFsProvider for InMemorySqliteProvider { files .get(path) .cloned() - .ok_or_else(|| FsError::NotFound(path.to_string())) + .ok_or_else(|| FsError::from(FsErrorKind::NotFound(path.to_string()))) } async fn write_file( @@ -114,7 +114,7 @@ impl SessionFsProvider for InMemorySqliteProvider { } else if let Some(content) = files.get(path) { Ok(FileInfo::new(true, false, content.len() as i64, now, now)) } else { - Err(FsError::NotFound(path.to_string())) + Err(FsError::from(FsErrorKind::NotFound(path.to_string()))) } } @@ -244,7 +244,7 @@ impl SessionFsSqliteProvider for InMemorySqliteProvider { match query_type { SessionFsSqliteQueryType::Exec => { db.execute_batch(trimmed) - .map_err(|e| FsError::Other(e.to_string()))?; + .map_err(|e| FsError::new(FsErrorKind::Other, e))?; Ok(Some(SessionFsSqliteQueryResult { columns: vec![], rows: vec![], @@ -255,21 +255,21 @@ impl SessionFsSqliteProvider for InMemorySqliteProvider { SessionFsSqliteQueryType::Query => { let mut stmt = db .prepare(trimmed) - .map_err(|e| FsError::Other(e.to_string()))?; + .map_err(|e| FsError::new(FsErrorKind::Other, e))?; let col_count = stmt.column_count(); let columns: Vec = (0..col_count) .map(|i| stmt.column_name(i).unwrap().to_string()) .collect(); let mut rows = vec![]; - let mut query_rows = stmt.query([]).map_err(|e| FsError::Other(e.to_string()))?; + let mut query_rows = stmt.query([]).map_err(|e| FsError::new(FsErrorKind::Other, e))?; while let Some(row) = query_rows .next() - .map_err(|e| FsError::Other(e.to_string()))? + .map_err(|e| FsError::new(FsErrorKind::Other, e))? { let mut map = HashMap::new(); for (i, col) in columns.iter().enumerate() { let val: rusqlite::types::Value = - row.get(i).map_err(|e| FsError::Other(e.to_string()))?; + row.get(i).map_err(|e| FsError::new(FsErrorKind::Other, e))?; let json_val = match val { rusqlite::types::Value::Null => serde_json::Value::Null, rusqlite::types::Value::Integer(n) => { @@ -297,7 +297,7 @@ impl SessionFsSqliteProvider for InMemorySqliteProvider { SessionFsSqliteQueryType::Run => { let affected = db .execute(trimmed, []) - .map_err(|e| FsError::Other(e.to_string()))?; + .map_err(|e| FsError::new(FsErrorKind::Other, e))?; let last_id = db.last_insert_rowid(); Ok(Some(SessionFsSqliteQueryResult { columns: vec![], diff --git a/rust/tests/protocol_version_test.rs b/rust/tests/protocol_version_test.rs index fd4eecada..903c01c4a 100644 --- a/rust/tests/protocol_version_test.rs +++ b/rust/tests/protocol_version_test.rs @@ -91,8 +91,8 @@ async fn rejected_when_version_out_of_range() { let (res, version) = verify_with_result(serde_json::json!({ "protocolVersion": 1 })).await; let err = res.unwrap_err(); assert!(matches!( - err, - github_copilot_sdk::Error::Protocol(github_copilot_sdk::ProtocolError::VersionMismatch { + err.kind(), + github_copilot_sdk::ErrorKind::Protocol(github_copilot_sdk::ProtocolErrorKind::VersionMismatch { server: 1, .. }) diff --git a/rust/tests/session_test.rs b/rust/tests/session_test.rs index ed3698951..4922ee8b6 100644 --- a/rust/tests/session_test.rs +++ b/rust/tests/session_test.rs @@ -1597,7 +1597,8 @@ async fn send_and_wait_returns_error_on_session_error() { .unwrap() .unwrap_err(); assert!( - matches!(err, github_copilot_sdk::Error::Session(github_copilot_sdk::SessionError::AgentError(ref msg)) if msg.contains("something went wrong")) + matches!(err.kind(), github_copilot_sdk::ErrorKind::Session(github_copilot_sdk::SessionErrorKind::AgentError)) + && err.to_string().contains("something went wrong") ); } @@ -1626,8 +1627,8 @@ async fn send_and_wait_times_out() { .unwrap() .unwrap_err(); assert!(matches!( - err, - github_copilot_sdk::Error::Session(github_copilot_sdk::SessionError::Timeout(_)) + err.kind(), + github_copilot_sdk::ErrorKind::Session(github_copilot_sdk::SessionErrorKind::Timeout(_)) )); } @@ -2391,17 +2392,17 @@ async fn elicitation_methods_fail_without_capability() { .await .unwrap_err(); assert!(matches!( - err, - github_copilot_sdk::Error::Session( - github_copilot_sdk::SessionError::ElicitationNotSupported + err.kind(), + github_copilot_sdk::ErrorKind::Session( + github_copilot_sdk::SessionErrorKind::ElicitationNotSupported ) )); let err = session.ui().confirm("ok?").await.unwrap_err(); assert!(matches!( - err, - github_copilot_sdk::Error::Session( - github_copilot_sdk::SessionError::ElicitationNotSupported + err.kind(), + github_copilot_sdk::ErrorKind::Session( + github_copilot_sdk::SessionErrorKind::ElicitationNotSupported ) )); } @@ -2878,8 +2879,11 @@ impl CommandHandler for CountingCommandHandler { async fn on_command(&self, ctx: CommandContext) -> Result<(), github_copilot_sdk::Error> { *self.last_ctx.lock() = Some(ctx); if let Some(message) = &self.error_to_return { - Err(github_copilot_sdk::Error::Session( - github_copilot_sdk::SessionError::AgentError(message.clone()), + Err(github_copilot_sdk::Error::with_message( + github_copilot_sdk::ErrorKind::Session( + github_copilot_sdk::SessionErrorKind::AgentError, + ), + message.clone(), )) } else { Ok(()) @@ -3100,7 +3104,7 @@ async fn command_execute_handler_error_propagates_to_ack() { // SessionFsProvider tests -------------------------------------------------- use github_copilot_sdk::session_fs::{ - DirEntry, DirEntryKind, FileInfo, FsError, SessionFsConventions, SessionFsProvider, + DirEntry, DirEntryKind, FileInfo, FsError, FsErrorKind, SessionFsConventions, SessionFsProvider, SessionFsSqliteProvider, SessionFsSqliteQueryResult, SessionFsSqliteQueryType, }; @@ -3130,7 +3134,7 @@ impl SessionFsProvider for RecordingFsProvider { .lock() .get(path) .cloned() - .ok_or_else(|| FsError::NotFound(path.to_string())) + .ok_or_else(|| FsError::from(FsErrorKind::NotFound(path.to_string()))) } async fn write_file( @@ -3149,7 +3153,7 @@ impl SessionFsProvider for RecordingFsProvider { let files = self.files.lock(); let content = files .get(path) - .ok_or_else(|| FsError::NotFound(path.to_string()))?; + .ok_or_else(|| FsError::from(FsErrorKind::NotFound(path.to_string())))?; Ok(FileInfo::new( true, false, @@ -3169,7 +3173,7 @@ impl SessionFsProvider for RecordingFsProvider { async fn rm(&self, path: &str, _recursive: bool, force: bool) -> Result<(), FsError> { let mut files = self.files.lock(); if files.remove(path).is_none() && !force { - return Err(FsError::NotFound(path.to_string())); + return Err(FsError::from(FsErrorKind::NotFound(path.to_string()))); } Ok(()) } @@ -3311,7 +3315,7 @@ async fn session_fs_maps_other_to_unknown() { #[async_trait] impl SessionFsProvider for AlwaysFails { async fn stat(&self, _path: &str) -> Result { - Err(FsError::Other("backing store unavailable".to_string())) + Err(FsError::with_message(FsErrorKind::Other, "backing store unavailable")) } } @@ -3402,11 +3406,11 @@ async fn session_fs_maps_sqlite_errors_to_results() { _query: &str, _params: Option<&std::collections::HashMap>, ) -> Result, FsError> { - Err(FsError::Other("sqlite unavailable".to_string())) + Err(FsError::with_message(FsErrorKind::Other, "sqlite unavailable")) } async fn sqlite_exists(&self) -> Result { - Err(FsError::Other("sqlite unavailable".to_string())) + Err(FsError::with_message(FsErrorKind::Other, "sqlite unavailable")) } } @@ -3534,7 +3538,7 @@ async fn create_session_errors_when_provider_required_but_missing() { // through Client::start; the unit-level behavior is covered by the // SessionError::SessionFsProviderRequired variant being constructible. // This test asserts the error type's display formatting is stable. - let err = github_copilot_sdk::SessionError::SessionFsProviderRequired; + let err = github_copilot_sdk::SessionError::from(github_copilot_sdk::SessionErrorKind::SessionFsProviderRequired); assert!(format!("{err}").contains("session_fs")); } From 10a6c1344b6b4f4df8b0a9f42a946ffeb339f900 Mon Sep 17 00:00:00 2001 From: Heath Stewart Date: Fri, 22 May 2026 20:42:27 -0700 Subject: [PATCH 2/4] Remove unnecessary error structs Sticking with `*Kind` as a convention for error enums. --- rust/src/errors.rs | 112 +------------------------------------ rust/tests/session_test.rs | 2 +- 2 files changed, 3 insertions(+), 111 deletions(-) diff --git a/rust/src/errors.rs b/rust/src/errors.rs index 75e9739ad..a90ccc57c 100644 --- a/rust/src/errors.rs +++ b/rust/src/errors.rs @@ -32,7 +32,7 @@ pub(crate) struct Custom { pub(crate) error: Box, } -// ── ProtocolErrorKind / ProtocolError ───────────────────────────────────────── +// ── ProtocolErrorKind ───────────────────────────────────────── /// Specific protocol-level error kind in the JSON-RPC transport or CLI lifecycle. #[derive(Clone, Debug, PartialEq, Eq)] @@ -101,60 +101,7 @@ impl fmt::Display for ProtocolErrorKind { } } -/// Errors in the JSON-RPC transport or CLI lifecycle. -/// -/// Accessible via [`Error::kind`] when the kind is -/// [`ErrorKind::Protocol`]. -#[derive(Debug)] -pub struct ProtocolError { - repr: Repr, -} - -impl ProtocolError { - /// The [`ProtocolErrorKind`] of this error. - pub fn kind(&self) -> &ProtocolErrorKind { - match &self.repr { - Repr::Simple(k) - | Repr::SimpleMessage(k, ..) - | Repr::Custom(Custom { kind: k, .. }) => k, - } - } - - /// The message provided when this error was constructed, or `None`. - pub fn message(&self) -> Option<&str> { - match &self.repr { - Repr::SimpleMessage(_, m) => Some(m.borrow()), - _ => None, - } - } -} - -impl fmt::Display for ProtocolError { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match &self.repr { - Repr::Simple(k) => write!(f, "{k}"), - Repr::SimpleMessage(_, m) => write!(f, "{m}"), - Repr::Custom(Custom { error, .. }) => write!(f, "{error}"), - } - } -} - -impl std::error::Error for ProtocolError { - fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { - match &self.repr { - Repr::Custom(Custom { error, .. }) => Some(&**error), - _ => None, - } - } -} - -impl From for ProtocolError { - fn from(kind: ProtocolErrorKind) -> Self { - Self { repr: Repr::Simple(kind) } - } -} - -// ── SessionErrorKind / SessionError ─────────────────────────────────────────── +// ── SessionErrorKind ─────────────────────────────────────────── /// Session-scoped error kind. #[derive(Clone, Debug, PartialEq, Eq)] @@ -231,61 +178,6 @@ impl fmt::Display for SessionErrorKind { } } -/// Session-scoped errors. -/// -/// Accessible via [`Error::kind`] when the kind is [`ErrorKind::Session`]. -#[derive(Debug)] -pub struct SessionError { - repr: Repr, -} - -impl SessionError { - /// The [`SessionErrorKind`] of this error. - pub fn kind(&self) -> &SessionErrorKind { - match &self.repr { - Repr::Simple(k) - | Repr::SimpleMessage(k, ..) - | Repr::Custom(Custom { kind: k, .. }) => k, - } - } - - /// The message provided when this error was constructed, or `None`. - pub fn message(&self) -> Option<&str> { - match &self.repr { - Repr::SimpleMessage(_, m) => Some(m.borrow()), - _ => None, - } - } -} - -impl fmt::Display for SessionError { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - if let &SessionErrorKind::InvalidSessionFsConfig = self.kind() { - write!(f, "{}: ", SessionErrorKind::InvalidSessionFsConfig)?; - } - match &self.repr { - Repr::Simple(k) => write!(f, "{k}"), - Repr::SimpleMessage(_, m) => write!(f, "{m}"), - Repr::Custom(Custom { error, .. }) => write!(f, "{error}"), - } - } -} - -impl std::error::Error for SessionError { - fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { - match &self.repr { - Repr::Custom(Custom { error, .. }) => Some(&**error), - _ => None, - } - } -} - -impl From for SessionError { - fn from(kind: SessionErrorKind) -> Self { - Self { repr: Repr::Simple(kind) } - } -} - // ── ErrorKind ───────────────────────────────────────────────────────────────── /// The kind of [`Error`]. diff --git a/rust/tests/session_test.rs b/rust/tests/session_test.rs index 4922ee8b6..6950aedd8 100644 --- a/rust/tests/session_test.rs +++ b/rust/tests/session_test.rs @@ -3538,7 +3538,7 @@ async fn create_session_errors_when_provider_required_but_missing() { // through Client::start; the unit-level behavior is covered by the // SessionError::SessionFsProviderRequired variant being constructible. // This test asserts the error type's display formatting is stable. - let err = github_copilot_sdk::SessionError::from(github_copilot_sdk::SessionErrorKind::SessionFsProviderRequired); + let err = github_copilot_sdk::SessionErrorKind::SessionFsProviderRequired; assert!(format!("{err}").contains("session_fs")); } From 1ee00b738f847f240069057a010db122102b172f Mon Sep 17 00:00:00 2001 From: Heath Stewart Date: Fri, 22 May 2026 20:50:34 -0700 Subject: [PATCH 3/4] Add backtrace support to Error struct Enhanced the Error struct to include an optional backtrace, which is captured only when `RUST_BACKTRACE` is set. This change helps in debugging by providing context on error occurrences without inflating the Error size unnecessarily. --- rust/src/errors.rs | 32 ++++++++++++++++++++++++++++---- 1 file changed, 28 insertions(+), 4 deletions(-) diff --git a/rust/src/errors.rs b/rust/src/errors.rs index a90ccc57c..ed9c49194 100644 --- a/rust/src/errors.rs +++ b/rust/src/errors.rs @@ -1,9 +1,7 @@ //! Crate errors. use std::{ - borrow::{Borrow, Cow}, - fmt, - time::Duration, + backtrace::{Backtrace, BacktraceStatus}, borrow::{Borrow, Cow}, fmt, time::Duration }; use crate::types::SessionId; @@ -228,9 +226,11 @@ impl fmt::Display for ErrorKind { } /// Errors returned by the SDK. -#[derive(Debug)] pub struct Error { repr: Repr, + // Only `Some` when `RUST_BACKTRACE` is set; boxed so the `Some` variant + // doesn't inflate `Error` beyond `clippy::result_large_err` limits. + backtrace: Option>, } impl Error { @@ -244,6 +244,7 @@ impl Error { kind, error: error.into(), }), + backtrace: capture_backtrace(), } } @@ -272,6 +273,7 @@ impl Error { { Self { repr: Repr::SimpleMessage(kind, message.into()), + backtrace: capture_backtrace(), } } @@ -308,6 +310,17 @@ impl fmt::Display for Error { } } +impl fmt::Debug for Error { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let mut dbg = f.debug_struct("Error"); + dbg.field("context", &self.repr); + if let Some(backtrace) = &self.backtrace { + return dbg.field("backtrace", backtrace).finish(); + } + dbg.finish_non_exhaustive() + } +} + impl std::error::Error for Error { fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { match &self.repr { @@ -321,6 +334,7 @@ impl From for Error { fn from(kind: ErrorKind) -> Self { Self { repr: Repr::Simple(kind), + backtrace: capture_backtrace(), } } } @@ -349,6 +363,16 @@ impl From for Error { } } +#[inline(always)] +fn capture_backtrace() -> Option> { + let backtrace = Backtrace::capture(); + if backtrace.status() == BacktraceStatus::Captured { + Some(Box::new(backtrace)) + } else { + None + } +} + /// Aggregate of errors collected during [`crate::Client::stop`]. /// /// `Client::stop` performs cooperative shutdown across every active From e99a1e7683125c0e14fb9a20576af4feb410b311 Mon Sep 17 00:00:00 2001 From: Heath Stewart Date: Fri, 22 May 2026 21:15:00 -0700 Subject: [PATCH 4/4] Resolve PR feedback --- .github/skills/rust-coding-skill/SKILL.md | 2 +- .vscode/settings.json | 4 +++- rust/src/errors.rs | 26 +++++++++++++++++------ rust/src/subscription.rs | 4 ++-- 4 files changed, 25 insertions(+), 11 deletions(-) diff --git a/.github/skills/rust-coding-skill/SKILL.md b/.github/skills/rust-coding-skill/SKILL.md index 858dbea51..b33cd2c43 100644 --- a/.github/skills/rust-coding-skill/SKILL.md +++ b/.github/skills/rust-coding-skill/SKILL.md @@ -13,7 +13,7 @@ Opinionated Rust rules for the Copilot Rust SDK (`rust/`). Priority order: ## Error handling -The SDK's public error type is `crate::Error` (`rust/src/error.rs`). Add new +The SDK's public error type is `crate::Error` (`rust/src/errors.rs`). Add new variants to `crate::ErrorKind` rather than introducing parallel error enums per module — every public failure mode is part of the API contract and should be expressible in one type. diff --git a/.vscode/settings.json b/.vscode/settings.json index 0345a3f38..482d009aa 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -16,7 +16,9 @@ "python.testing.pytestArgs": ["python"], "rust-analyzer.cargo.features": "all", "rust-analyzer.check.command": "clippy", - "rust-analyzer.check.features": "all", + "[rust]": { + "editor.defaultFormatter": "rust-lang.rust-analyzer" + }, "[python]": { "editor.defaultFormatter": "charliermarsh.ruff" }, diff --git a/rust/src/errors.rs b/rust/src/errors.rs index ed9c49194..6670e77aa 100644 --- a/rust/src/errors.rs +++ b/rust/src/errors.rs @@ -1,9 +1,12 @@ //! Crate errors. +use crate::types::SessionId; use std::{ - backtrace::{Backtrace, BacktraceStatus}, borrow::{Borrow, Cow}, fmt, time::Duration + backtrace::{Backtrace, BacktraceStatus}, + borrow::{Borrow, Cow}, + fmt, + time::Duration, }; -use crate::types::SessionId; /// Crate-specific [`Result`](std::result::Result). pub type Result = std::result::Result; @@ -168,7 +171,10 @@ impl fmt::Display for SessionErrorKind { SessionErrorKind::InvalidSessionFsConfig => { write!(f, "invalid SessionFsConfig") } - SessionErrorKind::SessionIdMismatch { requested, returned } => write!( + SessionErrorKind::SessionIdMismatch { + requested, + returned, + } => write!( f, "CLI returned session ID {returned} after SDK registered {requested}" ), @@ -214,7 +220,10 @@ impl fmt::Display for ErrorKind { ErrorKind::Session(k) => write!(f, "{k}"), ErrorKind::Io => write!(f, "I/O error"), ErrorKind::Json => write!(f, "JSON error"), - ErrorKind::BinaryNotFound { name, hint: Some(h) } => { + ErrorKind::BinaryNotFound { + name, + hint: Some(h), + } => { write!(f, "binary not found: {name} ({h})") } ErrorKind::BinaryNotFound { name, hint: None } => { @@ -299,12 +308,15 @@ impl Error { impl fmt::Display for Error { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - if let &ErrorKind::Rpc { code } = self.kind() { - write!(f, "{}: ", ErrorKind::Rpc { code })?; - } match &self.repr { Repr::Simple(kind) => write!(f, "{kind}"), + Repr::SimpleMessage(kind, message) if matches!(kind, ErrorKind::Rpc { code: _ }) => { + write!(f, "{kind}: {message}") + } Repr::SimpleMessage(_, message) => write!(f, "{message}"), + Repr::Custom(Custom { kind, error }) if matches!(kind, ErrorKind::Rpc { code: _ }) => { + write!(f, "{kind}: {error}") + } Repr::Custom(Custom { error, .. }) => write!(f, "{error}"), } } diff --git a/rust/src/subscription.rs b/rust/src/subscription.rs index 134822a73..83d31dfc2 100644 --- a/rust/src/subscription.rs +++ b/rust/src/subscription.rs @@ -156,9 +156,9 @@ macro_rules! define_subscription { /// Returns: /// /// - `Ok(event)` for the next delivered event. - /// - `Err(`[`RecvErrorKind::Lagged`]`)` if the subscriber fell behind; + /// - `Err(`[`RecvError`]`)` with [`RecvError::kind()`] [`RecvErrorKind::Lagged`] if the subscriber fell behind; /// call `recv` again to continue from the next live event. - /// - `Err(`[`RecvErrorKind::Closed`]`)` once the producer is gone. + /// - `Err(`[`RecvError`]`)` with [`RecvError::kind()`] [`RecvErrorKind::Closed`] once the producer is gone. /// /// # Cancel safety ///