diff --git a/internal/agent/agent.go b/internal/agent/agent.go new file mode 100644 index 00000000..625fe444 --- /dev/null +++ b/internal/agent/agent.go @@ -0,0 +1,66 @@ +// Package agent detects whether the Go SQL driver is being invoked by an AI +// coding agent by checking for well-known environment variables that agents set +// in their spawned shell processes. +// +// Detection only succeeds when exactly one agent environment variable is +// present, to avoid ambiguous attribution when multiple agent environments +// overlap. +// +// Adding a new agent requires only a new constant and a new entry in +// knownAgents. +// +// References for each environment variable: +// - ANTIGRAVITY_AGENT: Closed source. Google Antigravity sets this variable. +// - CLAUDECODE: https://github.com/anthropics/claude-code (sets CLAUDECODE=1) +// - CLINE_ACTIVE: https://github.com/cline/cline (shipped in v3.24.0) +// - CODEX_CI: https://github.com/openai/codex (part of UNIFIED_EXEC_ENV array in codex-rs) +// - CURSOR_AGENT: Closed source. Referenced in a gist by johnlindquist. +// - GEMINI_CLI: https://google-gemini.github.io/gemini-cli/docs/tools/shell.html (sets GEMINI_CLI=1) +// - OPENCODE: https://github.com/opencode-ai/opencode (sets OPENCODE=1) +package agent + +import "os" + +const ( + Antigravity = "antigravity" + ClaudeCode = "claude-code" + Cline = "cline" + Codex = "codex" + Cursor = "cursor" + GeminiCLI = "gemini-cli" + OpenCode = "opencode" +) + +var knownAgents = []struct { + envVar string + product string +}{ + {"ANTIGRAVITY_AGENT", Antigravity}, + {"CLAUDECODE", ClaudeCode}, + {"CLINE_ACTIVE", Cline}, + {"CODEX_CI", Codex}, + {"CURSOR_AGENT", Cursor}, + {"GEMINI_CLI", GeminiCLI}, + {"OPENCODE", OpenCode}, +} + +// Detect returns the product string of the AI coding agent driving the current +// process, or an empty string if no agent (or multiple agents) are detected. +func Detect() string { + return detect(os.Getenv) +} + +// detect is the internal implementation that accepts an env lookup function +// for testability. +func detect(getenv func(string) string) string { + var detected []string + for _, a := range knownAgents { + if getenv(a.envVar) != "" { + detected = append(detected, a.product) + } + } + if len(detected) == 1 { + return detected[0] + } + return "" +} diff --git a/internal/agent/agent_test.go b/internal/agent/agent_test.go new file mode 100644 index 00000000..a802aa45 --- /dev/null +++ b/internal/agent/agent_test.go @@ -0,0 +1,71 @@ +package agent + +import ( + "testing" +) + +func envWith(vars map[string]string) func(string) string { + return func(key string) string { + return vars[key] + } +} + +func TestDetectsSingleAgent(t *testing.T) { + cases := []struct { + envVar string + product string + }{ + {"ANTIGRAVITY_AGENT", Antigravity}, + {"CLAUDECODE", ClaudeCode}, + {"CLINE_ACTIVE", Cline}, + {"CODEX_CI", Codex}, + {"CURSOR_AGENT", Cursor}, + {"GEMINI_CLI", GeminiCLI}, + {"OPENCODE", OpenCode}, + } + for _, tc := range cases { + t.Run(tc.product, func(t *testing.T) { + got := detect(envWith(map[string]string{tc.envVar: "1"})) + if got != tc.product { + t.Errorf("detect() = %q, want %q", got, tc.product) + } + }) + } +} + +func TestReturnsEmptyWhenNoAgent(t *testing.T) { + got := detect(envWith(map[string]string{})) + if got != "" { + t.Errorf("detect() = %q, want empty", got) + } +} + +func TestReturnsEmptyWhenMultipleAgents(t *testing.T) { + got := detect(envWith(map[string]string{ + "CLAUDECODE": "1", + "CURSOR_AGENT": "1", + })) + if got != "" { + t.Errorf("detect() = %q, want empty", got) + } +} + +func TestIgnoresEmptyValues(t *testing.T) { + got := detect(envWith(map[string]string{"CLAUDECODE": ""})) + if got != "" { + t.Errorf("detect() = %q, want empty", got) + } +} + +func TestDetectUsesOsGetenv(t *testing.T) { + // Clear all known agent env vars, then set one + for _, a := range knownAgents { + t.Setenv(a.envVar, "") + } + t.Setenv("CLAUDECODE", "1") + + got := Detect() + if got != ClaudeCode { + t.Errorf("Detect() = %q, want %q", got, ClaudeCode) + } +} diff --git a/internal/client/client.go b/internal/client/client.go index 0b2ebab0..b644c294 100644 --- a/internal/client/client.go +++ b/internal/client/client.go @@ -20,6 +20,7 @@ import ( "time" dbsqlerr "github.com/databricks/databricks-sql-go/errors" + "github.com/databricks/databricks-sql-go/internal/agent" dbsqlerrint "github.com/databricks/databricks-sql-go/internal/errors" "github.com/apache/thrift/lib/go/thrift" @@ -295,6 +296,9 @@ func InitThriftClient(cfg *config.Config, httpclient *http.Client) (*ThriftServi if cfg.UserAgentEntry != "" { userAgent = fmt.Sprintf("%s/%s (%s)", cfg.DriverName, cfg.DriverVersion, cfg.UserAgentEntry) } + if agentProduct := agent.Detect(); agentProduct != "" { + userAgent = fmt.Sprintf("%s agent/%s", userAgent, agentProduct) + } thriftHttpClient.SetHeader("User-Agent", userAgent) default: