Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion experimental/ssh/cmd/connect.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ func newConnectCommand() *cobra.Command {
Connect to serverless:
databricks ssh connect
databricks ssh connect --accelerator=<GPU_type> # AI Runtime
databricks ssh connect --environment=<name> # custom base environment

Connect to a dedicated cluster:
databricks ssh connect --cluster=<cluster-id>`,
Expand All @@ -38,6 +39,7 @@ Connect to a dedicated cluster:
var liteswap string
var skipSettingsCheck bool
var environmentVersion int
var environment string
var autoApprove bool

cmd.Flags().StringVar(&clusterID, "cluster", "", "Databricks dedicated cluster ID")
Expand Down Expand Up @@ -71,6 +73,8 @@ Connect to a dedicated cluster:
cmd.Flags().IntVar(&environmentVersion, "environment-version", defaultEnvironmentVersion, "Environment version for AI Runtime")
cmd.Flags().MarkHidden("environment-version")

cmd.Flags().StringVar(&environment, "environment", "", "Custom base environment for serverless compute: an env.yaml path, a workspace-base-environments resource ID, or a display name")

cmd.Flags().BoolVar(&autoApprove, "auto-approve", false, "Skip confirmation prompts, installing IDE extensions and applying IDE settings without asking")

cmd.PreRunE = func(cmd *cobra.Command, args []string) error {
Expand All @@ -88,14 +92,20 @@ Connect to a dedicated cluster:
ctx := cmd.Context()
wsClient := cmdctx.WorkspaceClient(ctx)
if connectionName == "" && clusterID == "" && !proxyMode {
connectionName = client.GenerateDefaultConnectionName(wsClient.Config.Host, accelerator)
connectionName = client.GenerateDefaultConnectionName(wsClient.Config.Host, accelerator, environment)
}
// Serverless GPU compute can take much longer to provision than CPU compute,
// so allow extra time for the SSH server job to start.
startupTimeout := taskStartupTimeout
if accelerator != "" {
startupTimeout = gpuTaskStartupTimeout
}
// Only carry an explicitly-set environment version. Leaving it at 0 otherwise
// lets the submit path default to minEnvironmentVersion and lets Validate
// detect a real --environment-version + --environment conflict.
if !cmd.Flags().Changed("environment-version") {
environmentVersion = 0
}
opts := client.ClientOptions{
Profile: wsClient.Config.Profile,
ClusterID: clusterID,
Expand All @@ -117,6 +127,7 @@ Connect to a dedicated cluster:
Liteswap: liteswap,
SkipSettingsCheck: skipSettingsCheck,
EnvironmentVersion: environmentVersion,
Environment: environment,
AdditionalArgs: args,
AutoApprove: autoApprove,
}
Expand Down
85 changes: 77 additions & 8 deletions experimental/ssh/internal/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ import (
"github.com/databricks/databricks-sdk-go"
"github.com/databricks/databricks-sdk-go/retries"
"github.com/databricks/databricks-sdk-go/service/compute"
"github.com/databricks/databricks-sdk-go/service/environments"
"github.com/databricks/databricks-sdk-go/service/iam"
"github.com/databricks/databricks-sdk-go/service/jobs"
"github.com/databricks/databricks-sdk-go/service/workspace"
Expand Down Expand Up @@ -112,6 +113,10 @@ type ClientOptions struct {
SkipSettingsCheck bool
// Environment version for serverless compute.
EnvironmentVersion int
// Base environment for serverless compute. Accepts an env.yaml path (leading "/"),
// a "workspace-base-environments/..." resource ID, or a bare display name resolved
// against the workspace base environments. Maps to compute.Environment.BaseEnvironment.
Environment string
// If true, skip confirmation prompts for IDE extension install and IDE settings updates.
AutoApprove bool
}
Expand All @@ -135,15 +140,31 @@ func (o *ClientOptions) Validate() error {
if o.EnvironmentVersion > 0 && o.EnvironmentVersion < minEnvironmentVersion {
return fmt.Errorf("environment version must be >= %d, got %d", minEnvironmentVersion, o.EnvironmentVersion)
}
// base_environment and environment_version are mutually exclusive in the SDK,
// and a custom base environment only applies to serverless compute.
if o.Environment != "" && o.EnvironmentVersion > 0 {
return errors.New("--environment cannot be used together with --environment-version")
}
if o.Environment != "" && o.ClusterID != "" {
return errors.New("--environment can only be used with serverless compute")
}
return nil
}

// GenerateDefaultConnectionName creates a deterministic connection name from
// the workspace host and accelerator type. The name includes a hash of the
// workspace host so that different workspaces produce different names,
// avoiding SSH known_hosts conflicts.
func GenerateDefaultConnectionName(host, accelerator string) string {
h := md5.Sum([]byte(host))
// the workspace host, accelerator type, and base environment. The name includes
// a hash so that different workspaces produce different names (avoiding SSH
// known_hosts conflicts). The environment is folded into the hash because a
// serverless server bakes in its environment at startup: distinct environments
// must map to distinct connection names so they don't reuse each other's server.
func GenerateDefaultConnectionName(host, accelerator, environment string) string {
// Keep the hash host-only when no environment is set so existing default
// connection names are preserved.
hashInput := host
if environment != "" {
hashInput = host + "\x00" + environment
}
h := md5.Sum([]byte(hashInput))
hashStr := hex.EncodeToString(h[:4])
if accelerator != "" {
acc := strings.ToLower(strings.ReplaceAll(accelerator, "_", "-"))
Expand Down Expand Up @@ -218,6 +239,12 @@ func (o *ClientOptions) ToProxyCommand() (string, error) {
proxyCommand += " --environment-version=" + strconv.Itoa(o.EnvironmentVersion)
}

if o.Environment != "" {
// Quote the value: env.yaml paths and display names may contain spaces,
// unlike the other (space-free) flag values serialized above.
proxyCommand += " --environment=" + strconv.Quote(o.Environment)
}

return proxyCommand, nil
}

Expand Down Expand Up @@ -601,12 +628,22 @@ func submitSSHTunnelJob(ctx context.Context, client *databricks.WorkspaceClient,
}

if opts.IsServerlessMode() {
// base_environment and environment_version are mutually exclusive: a custom
// base environment carries its own version, so we don't also set one.
var spec compute.Environment
if opts.Environment != "" {
baseEnvironment, err := resolveBaseEnvironment(ctx, client, opts.Environment)
if err != nil {
return 0, err
}
spec.BaseEnvironment = baseEnvironment
} else {
spec.EnvironmentVersion = strconv.Itoa(max(opts.EnvironmentVersion, minEnvironmentVersion))
}
submitRequest.Environments = []jobs.JobEnvironment{
{
EnvironmentKey: serverlessEnvironmentKey,
Spec: &compute.Environment{
EnvironmentVersion: strconv.Itoa(max(opts.EnvironmentVersion, minEnvironmentVersion)),
},
Spec: &spec,
},
}
}
Expand All @@ -622,6 +659,37 @@ func submitSSHTunnelJob(ctx context.Context, client *databricks.WorkspaceClient,
return waiter.RunId, waitForJobToStart(ctx, client, waiter.RunId, opts)
}

// resolveBaseEnvironment maps the user-provided --environment value to a
// compute.Environment.BaseEnvironment string. A leading "/" is an env.yaml path and a
// "workspace-base-environments/" prefix is a resource ID; both are passed through
// verbatim. Anything else is treated as a display name and resolved to its resource ID
// via the workspace base environments listing.
func resolveBaseEnvironment(ctx context.Context, client *databricks.WorkspaceClient, input string) (string, error) {
if strings.HasPrefix(input, "/") || strings.HasPrefix(input, "workspace-base-environments/") {
return input, nil
}

envs, err := client.Environments.ListWorkspaceBaseEnvironmentsAll(ctx, environments.ListWorkspaceBaseEnvironmentsRequest{})
if err != nil {
return "", fmt.Errorf("failed to list workspace base environments: %w", err)
}

var matches []string
for _, e := range envs {
if e.DisplayName == input {
matches = append(matches, e.Name)
}
}
switch len(matches) {
case 0:
return "", fmt.Errorf("no workspace base environment found with display name %q", input)
case 1:
return matches[0], nil
default:
return "", fmt.Errorf("multiple workspace base environments found with display name %q", input)
}
}

// shellSingleQuote wraps s in single quotes for safe inclusion in a shell
// command, escaping any embedded single quotes.
func shellSingleQuote(s string) string {
Expand Down Expand Up @@ -1042,6 +1110,7 @@ func logSshTunnelEvent(ctx context.Context, opts ClientOptions, isSuccess, isRec
SshTunnelEvent: &protos.SshTunnelEvent{
ComputeType: computeType,
AcceleratorType: opts.Accelerator,
Environment: opts.Environment,
IdeType: opts.IDE,
ClientMode: clientMode,
IsReconnect: isReconnect,
Expand Down
61 changes: 61 additions & 0 deletions experimental/ssh/internal/client/client_internal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (

"github.com/databricks/cli/libs/cmdio"
"github.com/databricks/databricks-sdk-go/experimental/mocks"
"github.com/databricks/databricks-sdk-go/service/environments"
"github.com/databricks/databricks-sdk-go/service/jobs"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
Expand All @@ -29,6 +30,66 @@ func terminatedRun(runID, taskRunID int64, message, pageURL string) *jobs.Run {
}
}

func TestResolveBaseEnvironment(t *testing.T) {
ctx := cmdio.MockDiscard(t.Context())

// Paths and resource IDs are passed through without hitting the API.
t.Run("env.yaml path passthrough", func(t *testing.T) {
m := mocks.NewMockWorkspaceClient(t)
got, err := resolveBaseEnvironment(ctx, m.WorkspaceClient, "/Workspace/path/to/env.yaml")
require.NoError(t, err)
assert.Equal(t, "/Workspace/path/to/env.yaml", got)
})

t.Run("resource ID passthrough", func(t *testing.T) {
m := mocks.NewMockWorkspaceClient(t)
got, err := resolveBaseEnvironment(ctx, m.WorkspaceClient, "workspace-base-environments/dbe_123")
require.NoError(t, err)
assert.Equal(t, "workspace-base-environments/dbe_123", got)
})

t.Run("display name resolves to resource ID", func(t *testing.T) {
m := mocks.NewMockWorkspaceClient(t)
m.GetMockEnvironmentsAPI().EXPECT().
ListWorkspaceBaseEnvironmentsAll(mock.Anything, environments.ListWorkspaceBaseEnvironmentsRequest{}).
Return([]environments.WorkspaceBaseEnvironment{
{DisplayName: "other", Name: "workspace-base-environments/dbe_other"},
{DisplayName: "my-env", Name: "workspace-base-environments/dbe_mine"},
}, nil)

got, err := resolveBaseEnvironment(ctx, m.WorkspaceClient, "my-env")
require.NoError(t, err)
assert.Equal(t, "workspace-base-environments/dbe_mine", got)
})

t.Run("display name not found", func(t *testing.T) {
m := mocks.NewMockWorkspaceClient(t)
m.GetMockEnvironmentsAPI().EXPECT().
ListWorkspaceBaseEnvironmentsAll(mock.Anything, environments.ListWorkspaceBaseEnvironmentsRequest{}).
Return([]environments.WorkspaceBaseEnvironment{
{DisplayName: "other", Name: "workspace-base-environments/dbe_other"},
}, nil)

_, err := resolveBaseEnvironment(ctx, m.WorkspaceClient, "my-env")
require.Error(t, err)
assert.Contains(t, err.Error(), `no workspace base environment found with display name "my-env"`)
})

t.Run("display name ambiguous", func(t *testing.T) {
m := mocks.NewMockWorkspaceClient(t)
m.GetMockEnvironmentsAPI().EXPECT().
ListWorkspaceBaseEnvironmentsAll(mock.Anything, environments.ListWorkspaceBaseEnvironmentsRequest{}).
Return([]environments.WorkspaceBaseEnvironment{
{DisplayName: "my-env", Name: "workspace-base-environments/dbe_1"},
{DisplayName: "my-env", Name: "workspace-base-environments/dbe_2"},
}, nil)

_, err := resolveBaseEnvironment(ctx, m.WorkspaceClient, "my-env")
require.Error(t, err)
assert.Contains(t, err.Error(), `multiple workspace base environments found with display name "my-env"`)
})
}

func TestDescribeRunFailureIncludesMessageTraceAndURL(t *testing.T) {
ctx := cmdio.MockDiscard(t.Context())
m := mocks.NewMockWorkspaceClient(t)
Expand Down
45 changes: 42 additions & 3 deletions experimental/ssh/internal/client/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,24 @@ func TestValidate(t *testing.T) {
name: "valid environment version",
opts: client.ClientOptions{ClusterID: "abc-123", EnvironmentVersion: 4},
},
{
name: "environment with environment version",
opts: client.ClientOptions{ConnectionName: "my-conn", Environment: "my-env", EnvironmentVersion: 4},
wantErr: "--environment cannot be used together with --environment-version",
},
{
name: "environment with cluster",
opts: client.ClientOptions{ClusterID: "abc-123", Environment: "my-env"},
wantErr: "--environment can only be used with serverless compute",
},
{
name: "valid environment with connection name",
opts: client.ClientOptions{ConnectionName: "my-conn", Environment: "/Workspace/path/to/env.yaml"},
},
{
name: "environment with serverless GPU accelerator",
opts: client.ClientOptions{ConnectionName: "my-conn", Accelerator: "GPU_1xA10", Environment: "my-gpu-env"},
},
}

for _, tt := range tests {
Expand Down Expand Up @@ -146,10 +164,23 @@ func TestGenerateDefaultConnectionName(t *testing.T) {

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := client.GenerateDefaultConnectionName(tt.host, tt.accelerator)
got := client.GenerateDefaultConnectionName(tt.host, tt.accelerator, "")
assert.Equal(t, tt.want, got)
})
}

// A serverless server bakes in its environment, so distinct environments must
// map to distinct default names (otherwise --environment is silently ignored
// when an existing server for the default name is reused).
t.Run("environment differentiates the name", func(t *testing.T) {
const host = "https://my-workspace.cloud.databricks.com"
base := client.GenerateDefaultConnectionName(host, "", "")
withEnv := client.GenerateDefaultConnectionName(host, "", "my-env")
otherEnv := client.GenerateDefaultConnectionName(host, "", "other-env")
assert.NotEqual(t, base, withEnv, "setting --environment must change the default name")
assert.NotEqual(t, withEnv, otherEnv, "different environments must produce different names")
assert.Equal(t, withEnv, client.GenerateDefaultConnectionName(host, "", "my-env"), "must be deterministic")
})
}

func TestGenerateDefaultConnectionNameMatchesRegex(t *testing.T) {
Expand All @@ -159,12 +190,15 @@ func TestGenerateDefaultConnectionNameMatchesRegex(t *testing.T) {
"https://workspace3.gcp.databricks.com",
}
accelerators := []string{"", "GPU_1xA10", "GPU_8xH100"}
environments := []string{"", "my-env", "/Workspace/Users/me@example.com/env.yaml"}
nameRegex := regexp.MustCompile(`^[a-zA-Z0-9][a-zA-Z0-9_-]*$`)

for _, host := range hosts {
for _, acc := range accelerators {
name := client.GenerateDefaultConnectionName(host, acc)
assert.Regexp(t, nameRegex, name, "host=%q accelerator=%q name=%q", host, acc, name)
for _, env := range environments {
name := client.GenerateDefaultConnectionName(host, acc, env)
assert.Regexp(t, nameRegex, name, "host=%q accelerator=%q environment=%q name=%q", host, acc, env, name)
}
}
}
}
Expand Down Expand Up @@ -224,6 +258,11 @@ func TestToProxyCommand(t *testing.T) {
opts: client.ClientOptions{ClusterID: "abc-123", EnvironmentVersion: 4},
want: quoted + " ssh connect --proxy --cluster=abc-123 --auto-start-cluster=false --shutdown-delay=0s --environment-version=4",
},
{
name: "serverless with environment",
opts: client.ClientOptions{ConnectionName: "my-conn", Environment: "my env", ShutdownDelay: 2 * time.Minute},
want: quoted + ` ssh connect --proxy --name=my-conn --shutdown-delay=2m0s --environment="my env"`,
},
}

for _, tt := range tests {
Expand Down
4 changes: 4 additions & 0 deletions libs/telemetry/protos/ssh_tunnel.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@ type SshTunnelEvent struct {
// GPU accelerator type for serverless compute.
AcceleratorType string `json:"accelerator_type,omitempty"`

// Base environment specified for serverless compute (raw --environment input:
// an env.yaml path, a workspace-base-environments resource ID, or a display name).
Environment string `json:"environment,omitempty"`

// IDE that initiated the connection (e.g., "vscode", "cursor").
IdeType string `json:"ide_type,omitempty"`

Expand Down
Loading