From d4d1e64414287fe68ba317d1f000da9915aa7a84 Mon Sep 17 00:00:00 2001 From: Anton Nekipelov <226657+anton-107@users.noreply.github.com> Date: Wed, 24 Jun 2026 16:54:01 +0200 Subject: [PATCH] experimental/ssh: add --environment flag to ssh connect Adds a user-facing --environment flag mapping to compute.Environment.BaseEnvironment in the serverless submit path. It accepts an env.yaml path, a workspace-base-environments resource ID, or a bare display name (resolved via ListWorkspaceBaseEnvironments). - Rejects --environment with --environment-version or --cluster. - Serializes through ToProxyCommand so it survives serverless reconnect. - Folds the environment into the default connection name so distinct environments map to distinct servers (mirrors --accelerator). - Adds an environment slot to SshTunnelEvent telemetry. DECO-27423 Co-authored-by: Isaac --- experimental/ssh/cmd/connect.go | 13 ++- experimental/ssh/internal/client/client.go | 85 +++++++++++++++++-- .../internal/client/client_internal_test.go | 61 +++++++++++++ .../ssh/internal/client/client_test.go | 45 +++++++++- libs/telemetry/protos/ssh_tunnel.go | 4 + 5 files changed, 196 insertions(+), 12 deletions(-) diff --git a/experimental/ssh/cmd/connect.go b/experimental/ssh/cmd/connect.go index 6daaca6db7a..5f32f343af5 100644 --- a/experimental/ssh/cmd/connect.go +++ b/experimental/ssh/cmd/connect.go @@ -18,6 +18,7 @@ func newConnectCommand() *cobra.Command { Connect to serverless: databricks ssh connect databricks ssh connect --accelerator= # AI Runtime + databricks ssh connect --environment= # custom base environment Connect to a dedicated cluster: databricks ssh connect --cluster=`, @@ -38,6 +39,7 @@ Connect to a dedicated cluster: var liteswap string var skipSettingsCheck bool var environmentVersion int + var environment string var autoApprove bool cmd.Flags().StringVar(&clusterID, "cluster", "", "Databricks dedicated cluster ID") @@ -71,6 +73,8 @@ Connect to a dedicated cluster: cmd.Flags().IntVar(&environmentVersion, "environment-version", defaultEnvironmentVersion, "Environment version for AI Runtime") cmd.Flags().MarkHidden("environment-version") + cmd.Flags().StringVar(&environment, "environment", "", "Custom base environment for serverless compute: an env.yaml path, a workspace-base-environments resource ID, or a display name") + cmd.Flags().BoolVar(&autoApprove, "auto-approve", false, "Skip confirmation prompts, installing IDE extensions and applying IDE settings without asking") cmd.PreRunE = func(cmd *cobra.Command, args []string) error { @@ -88,7 +92,7 @@ Connect to a dedicated cluster: ctx := cmd.Context() wsClient := cmdctx.WorkspaceClient(ctx) if connectionName == "" && clusterID == "" && !proxyMode { - connectionName = client.GenerateDefaultConnectionName(wsClient.Config.Host, accelerator) + connectionName = client.GenerateDefaultConnectionName(wsClient.Config.Host, accelerator, environment) } // Serverless GPU compute can take much longer to provision than CPU compute, // so allow extra time for the SSH server job to start. @@ -96,6 +100,12 @@ Connect to a dedicated cluster: if accelerator != "" { startupTimeout = gpuTaskStartupTimeout } + // Only carry an explicitly-set environment version. Leaving it at 0 otherwise + // lets the submit path default to minEnvironmentVersion and lets Validate + // detect a real --environment-version + --environment conflict. + if !cmd.Flags().Changed("environment-version") { + environmentVersion = 0 + } opts := client.ClientOptions{ Profile: wsClient.Config.Profile, ClusterID: clusterID, @@ -117,6 +127,7 @@ Connect to a dedicated cluster: Liteswap: liteswap, SkipSettingsCheck: skipSettingsCheck, EnvironmentVersion: environmentVersion, + Environment: environment, AdditionalArgs: args, AutoApprove: autoApprove, } diff --git a/experimental/ssh/internal/client/client.go b/experimental/ssh/internal/client/client.go index e79f99259fb..478298a2eab 100644 --- a/experimental/ssh/internal/client/client.go +++ b/experimental/ssh/internal/client/client.go @@ -34,6 +34,7 @@ import ( "github.com/databricks/databricks-sdk-go" "github.com/databricks/databricks-sdk-go/retries" "github.com/databricks/databricks-sdk-go/service/compute" + "github.com/databricks/databricks-sdk-go/service/environments" "github.com/databricks/databricks-sdk-go/service/iam" "github.com/databricks/databricks-sdk-go/service/jobs" "github.com/databricks/databricks-sdk-go/service/workspace" @@ -112,6 +113,10 @@ type ClientOptions struct { SkipSettingsCheck bool // Environment version for serverless compute. EnvironmentVersion int + // Base environment for serverless compute. Accepts an env.yaml path (leading "/"), + // a "workspace-base-environments/..." resource ID, or a bare display name resolved + // against the workspace base environments. Maps to compute.Environment.BaseEnvironment. + Environment string // If true, skip confirmation prompts for IDE extension install and IDE settings updates. AutoApprove bool } @@ -135,15 +140,31 @@ func (o *ClientOptions) Validate() error { if o.EnvironmentVersion > 0 && o.EnvironmentVersion < minEnvironmentVersion { return fmt.Errorf("environment version must be >= %d, got %d", minEnvironmentVersion, o.EnvironmentVersion) } + // base_environment and environment_version are mutually exclusive in the SDK, + // and a custom base environment only applies to serverless compute. + if o.Environment != "" && o.EnvironmentVersion > 0 { + return errors.New("--environment cannot be used together with --environment-version") + } + if o.Environment != "" && o.ClusterID != "" { + return errors.New("--environment can only be used with serverless compute") + } return nil } // GenerateDefaultConnectionName creates a deterministic connection name from -// the workspace host and accelerator type. The name includes a hash of the -// workspace host so that different workspaces produce different names, -// avoiding SSH known_hosts conflicts. -func GenerateDefaultConnectionName(host, accelerator string) string { - h := md5.Sum([]byte(host)) +// the workspace host, accelerator type, and base environment. The name includes +// a hash so that different workspaces produce different names (avoiding SSH +// known_hosts conflicts). The environment is folded into the hash because a +// serverless server bakes in its environment at startup: distinct environments +// must map to distinct connection names so they don't reuse each other's server. +func GenerateDefaultConnectionName(host, accelerator, environment string) string { + // Keep the hash host-only when no environment is set so existing default + // connection names are preserved. + hashInput := host + if environment != "" { + hashInput = host + "\x00" + environment + } + h := md5.Sum([]byte(hashInput)) hashStr := hex.EncodeToString(h[:4]) if accelerator != "" { acc := strings.ToLower(strings.ReplaceAll(accelerator, "_", "-")) @@ -218,6 +239,12 @@ func (o *ClientOptions) ToProxyCommand() (string, error) { proxyCommand += " --environment-version=" + strconv.Itoa(o.EnvironmentVersion) } + if o.Environment != "" { + // Quote the value: env.yaml paths and display names may contain spaces, + // unlike the other (space-free) flag values serialized above. + proxyCommand += " --environment=" + strconv.Quote(o.Environment) + } + return proxyCommand, nil } @@ -601,12 +628,22 @@ func submitSSHTunnelJob(ctx context.Context, client *databricks.WorkspaceClient, } if opts.IsServerlessMode() { + // base_environment and environment_version are mutually exclusive: a custom + // base environment carries its own version, so we don't also set one. + var spec compute.Environment + if opts.Environment != "" { + baseEnvironment, err := resolveBaseEnvironment(ctx, client, opts.Environment) + if err != nil { + return 0, err + } + spec.BaseEnvironment = baseEnvironment + } else { + spec.EnvironmentVersion = strconv.Itoa(max(opts.EnvironmentVersion, minEnvironmentVersion)) + } submitRequest.Environments = []jobs.JobEnvironment{ { EnvironmentKey: serverlessEnvironmentKey, - Spec: &compute.Environment{ - EnvironmentVersion: strconv.Itoa(max(opts.EnvironmentVersion, minEnvironmentVersion)), - }, + Spec: &spec, }, } } @@ -622,6 +659,37 @@ func submitSSHTunnelJob(ctx context.Context, client *databricks.WorkspaceClient, return waiter.RunId, waitForJobToStart(ctx, client, waiter.RunId, opts) } +// resolveBaseEnvironment maps the user-provided --environment value to a +// compute.Environment.BaseEnvironment string. A leading "/" is an env.yaml path and a +// "workspace-base-environments/" prefix is a resource ID; both are passed through +// verbatim. Anything else is treated as a display name and resolved to its resource ID +// via the workspace base environments listing. +func resolveBaseEnvironment(ctx context.Context, client *databricks.WorkspaceClient, input string) (string, error) { + if strings.HasPrefix(input, "/") || strings.HasPrefix(input, "workspace-base-environments/") { + return input, nil + } + + envs, err := client.Environments.ListWorkspaceBaseEnvironmentsAll(ctx, environments.ListWorkspaceBaseEnvironmentsRequest{}) + if err != nil { + return "", fmt.Errorf("failed to list workspace base environments: %w", err) + } + + var matches []string + for _, e := range envs { + if e.DisplayName == input { + matches = append(matches, e.Name) + } + } + switch len(matches) { + case 0: + return "", fmt.Errorf("no workspace base environment found with display name %q", input) + case 1: + return matches[0], nil + default: + return "", fmt.Errorf("multiple workspace base environments found with display name %q", input) + } +} + // shellSingleQuote wraps s in single quotes for safe inclusion in a shell // command, escaping any embedded single quotes. func shellSingleQuote(s string) string { @@ -1042,6 +1110,7 @@ func logSshTunnelEvent(ctx context.Context, opts ClientOptions, isSuccess, isRec SshTunnelEvent: &protos.SshTunnelEvent{ ComputeType: computeType, AcceleratorType: opts.Accelerator, + Environment: opts.Environment, IdeType: opts.IDE, ClientMode: clientMode, IsReconnect: isReconnect, diff --git a/experimental/ssh/internal/client/client_internal_test.go b/experimental/ssh/internal/client/client_internal_test.go index 292f9dd0db3..d2bfde57833 100644 --- a/experimental/ssh/internal/client/client_internal_test.go +++ b/experimental/ssh/internal/client/client_internal_test.go @@ -7,6 +7,7 @@ import ( "github.com/databricks/cli/libs/cmdio" "github.com/databricks/databricks-sdk-go/experimental/mocks" + "github.com/databricks/databricks-sdk-go/service/environments" "github.com/databricks/databricks-sdk-go/service/jobs" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" @@ -29,6 +30,66 @@ func terminatedRun(runID, taskRunID int64, message, pageURL string) *jobs.Run { } } +func TestResolveBaseEnvironment(t *testing.T) { + ctx := cmdio.MockDiscard(t.Context()) + + // Paths and resource IDs are passed through without hitting the API. + t.Run("env.yaml path passthrough", func(t *testing.T) { + m := mocks.NewMockWorkspaceClient(t) + got, err := resolveBaseEnvironment(ctx, m.WorkspaceClient, "/Workspace/path/to/env.yaml") + require.NoError(t, err) + assert.Equal(t, "/Workspace/path/to/env.yaml", got) + }) + + t.Run("resource ID passthrough", func(t *testing.T) { + m := mocks.NewMockWorkspaceClient(t) + got, err := resolveBaseEnvironment(ctx, m.WorkspaceClient, "workspace-base-environments/dbe_123") + require.NoError(t, err) + assert.Equal(t, "workspace-base-environments/dbe_123", got) + }) + + t.Run("display name resolves to resource ID", func(t *testing.T) { + m := mocks.NewMockWorkspaceClient(t) + m.GetMockEnvironmentsAPI().EXPECT(). + ListWorkspaceBaseEnvironmentsAll(mock.Anything, environments.ListWorkspaceBaseEnvironmentsRequest{}). + Return([]environments.WorkspaceBaseEnvironment{ + {DisplayName: "other", Name: "workspace-base-environments/dbe_other"}, + {DisplayName: "my-env", Name: "workspace-base-environments/dbe_mine"}, + }, nil) + + got, err := resolveBaseEnvironment(ctx, m.WorkspaceClient, "my-env") + require.NoError(t, err) + assert.Equal(t, "workspace-base-environments/dbe_mine", got) + }) + + t.Run("display name not found", func(t *testing.T) { + m := mocks.NewMockWorkspaceClient(t) + m.GetMockEnvironmentsAPI().EXPECT(). + ListWorkspaceBaseEnvironmentsAll(mock.Anything, environments.ListWorkspaceBaseEnvironmentsRequest{}). + Return([]environments.WorkspaceBaseEnvironment{ + {DisplayName: "other", Name: "workspace-base-environments/dbe_other"}, + }, nil) + + _, err := resolveBaseEnvironment(ctx, m.WorkspaceClient, "my-env") + require.Error(t, err) + assert.Contains(t, err.Error(), `no workspace base environment found with display name "my-env"`) + }) + + t.Run("display name ambiguous", func(t *testing.T) { + m := mocks.NewMockWorkspaceClient(t) + m.GetMockEnvironmentsAPI().EXPECT(). + ListWorkspaceBaseEnvironmentsAll(mock.Anything, environments.ListWorkspaceBaseEnvironmentsRequest{}). + Return([]environments.WorkspaceBaseEnvironment{ + {DisplayName: "my-env", Name: "workspace-base-environments/dbe_1"}, + {DisplayName: "my-env", Name: "workspace-base-environments/dbe_2"}, + }, nil) + + _, err := resolveBaseEnvironment(ctx, m.WorkspaceClient, "my-env") + require.Error(t, err) + assert.Contains(t, err.Error(), `multiple workspace base environments found with display name "my-env"`) + }) +} + func TestDescribeRunFailureIncludesMessageTraceAndURL(t *testing.T) { ctx := cmdio.MockDiscard(t.Context()) m := mocks.NewMockWorkspaceClient(t) diff --git a/experimental/ssh/internal/client/client_test.go b/experimental/ssh/internal/client/client_test.go index ef9e6fb53b7..468ae7dfb0a 100644 --- a/experimental/ssh/internal/client/client_test.go +++ b/experimental/ssh/internal/client/client_test.go @@ -94,6 +94,24 @@ func TestValidate(t *testing.T) { name: "valid environment version", opts: client.ClientOptions{ClusterID: "abc-123", EnvironmentVersion: 4}, }, + { + name: "environment with environment version", + opts: client.ClientOptions{ConnectionName: "my-conn", Environment: "my-env", EnvironmentVersion: 4}, + wantErr: "--environment cannot be used together with --environment-version", + }, + { + name: "environment with cluster", + opts: client.ClientOptions{ClusterID: "abc-123", Environment: "my-env"}, + wantErr: "--environment can only be used with serverless compute", + }, + { + name: "valid environment with connection name", + opts: client.ClientOptions{ConnectionName: "my-conn", Environment: "/Workspace/path/to/env.yaml"}, + }, + { + name: "environment with serverless GPU accelerator", + opts: client.ClientOptions{ConnectionName: "my-conn", Accelerator: "GPU_1xA10", Environment: "my-gpu-env"}, + }, } for _, tt := range tests { @@ -146,10 +164,23 @@ func TestGenerateDefaultConnectionName(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got := client.GenerateDefaultConnectionName(tt.host, tt.accelerator) + got := client.GenerateDefaultConnectionName(tt.host, tt.accelerator, "") assert.Equal(t, tt.want, got) }) } + + // A serverless server bakes in its environment, so distinct environments must + // map to distinct default names (otherwise --environment is silently ignored + // when an existing server for the default name is reused). + t.Run("environment differentiates the name", func(t *testing.T) { + const host = "https://my-workspace.cloud.databricks.com" + base := client.GenerateDefaultConnectionName(host, "", "") + withEnv := client.GenerateDefaultConnectionName(host, "", "my-env") + otherEnv := client.GenerateDefaultConnectionName(host, "", "other-env") + assert.NotEqual(t, base, withEnv, "setting --environment must change the default name") + assert.NotEqual(t, withEnv, otherEnv, "different environments must produce different names") + assert.Equal(t, withEnv, client.GenerateDefaultConnectionName(host, "", "my-env"), "must be deterministic") + }) } func TestGenerateDefaultConnectionNameMatchesRegex(t *testing.T) { @@ -159,12 +190,15 @@ func TestGenerateDefaultConnectionNameMatchesRegex(t *testing.T) { "https://workspace3.gcp.databricks.com", } accelerators := []string{"", "GPU_1xA10", "GPU_8xH100"} + environments := []string{"", "my-env", "/Workspace/Users/me@example.com/env.yaml"} nameRegex := regexp.MustCompile(`^[a-zA-Z0-9][a-zA-Z0-9_-]*$`) for _, host := range hosts { for _, acc := range accelerators { - name := client.GenerateDefaultConnectionName(host, acc) - assert.Regexp(t, nameRegex, name, "host=%q accelerator=%q name=%q", host, acc, name) + for _, env := range environments { + name := client.GenerateDefaultConnectionName(host, acc, env) + assert.Regexp(t, nameRegex, name, "host=%q accelerator=%q environment=%q name=%q", host, acc, env, name) + } } } } @@ -224,6 +258,11 @@ func TestToProxyCommand(t *testing.T) { opts: client.ClientOptions{ClusterID: "abc-123", EnvironmentVersion: 4}, want: quoted + " ssh connect --proxy --cluster=abc-123 --auto-start-cluster=false --shutdown-delay=0s --environment-version=4", }, + { + name: "serverless with environment", + opts: client.ClientOptions{ConnectionName: "my-conn", Environment: "my env", ShutdownDelay: 2 * time.Minute}, + want: quoted + ` ssh connect --proxy --name=my-conn --shutdown-delay=2m0s --environment="my env"`, + }, } for _, tt := range tests { diff --git a/libs/telemetry/protos/ssh_tunnel.go b/libs/telemetry/protos/ssh_tunnel.go index 42be8233b7f..0f867ed7cf5 100644 --- a/libs/telemetry/protos/ssh_tunnel.go +++ b/libs/telemetry/protos/ssh_tunnel.go @@ -26,6 +26,10 @@ type SshTunnelEvent struct { // GPU accelerator type for serverless compute. AcceleratorType string `json:"accelerator_type,omitempty"` + // Base environment specified for serverless compute (raw --environment input: + // an env.yaml path, a workspace-base-environments resource ID, or a display name). + Environment string `json:"environment,omitempty"` + // IDE that initiated the connection (e.g., "vscode", "cursor"). IdeType string `json:"ide_type,omitempty"`