Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 36 additions & 2 deletions experimental/ssh/cmd/connect.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package ssh

import (
"errors"
"time"

"github.com/databricks/cli/cmd/root"
Expand All @@ -18,10 +19,18 @@ func newConnectCommand() *cobra.Command {
This command establishes an SSH connection to Databricks compute, setting up
the SSH server and handling the connection proxy.

For dedicated clusters:
databricks ssh connect --cluster=<cluster-id>

For serverless compute:
databricks ssh connect --name=<connection-name> [--accelerator=<accelerator>]

` + disclaimer,
}

var clusterID string
var connectionName string
var accelerator string
var proxyMode bool
var serverMetadata string
var shutdownDelay time.Duration
Expand All @@ -30,9 +39,11 @@ the SSH server and handling the connection proxy.
var releasesDir string
var autoStartCluster bool
var userKnownHostsFile string
var liteswap string

cmd.Flags().StringVar(&clusterID, "cluster", "", "Databricks cluster ID (required)")
cmd.MarkFlagRequired("cluster")
cmd.Flags().StringVar(&clusterID, "cluster", "", "Databricks cluster ID (for dedicated clusters)")
cmd.Flags().StringVar(&connectionName, "name", "", "Connection name (for serverless compute)")
cmd.Flags().StringVar(&accelerator, "accelerator", "", "GPU accelerator type for serverless compute (GPU_1xA10 or GPU_8xH100)")
cmd.Flags().DurationVar(&shutdownDelay, "shutdown-delay", defaultShutdownDelay, "Delay before shutting down the server after the last client disconnects")
cmd.Flags().IntVar(&maxClients, "max-clients", defaultMaxClients, "Maximum number of SSH clients")
cmd.Flags().BoolVar(&autoStartCluster, "auto-start-cluster", true, "Automatically start the cluster if it is not running")
Expand All @@ -50,6 +61,9 @@ the SSH server and handling the connection proxy.
cmd.Flags().StringVar(&userKnownHostsFile, "user-known-hosts-file", "", "Path to user known hosts file for SSH client")
cmd.Flags().MarkHidden("user-known-hosts-file")

cmd.Flags().StringVar(&liteswap, "liteswap", "", "Liteswap header value for traffic routing (dev/test only)")
cmd.Flags().MarkHidden("liteswap")

cmd.PreRunE = func(cmd *cobra.Command, args []string) error {
// CLI in the proxy mode is executed by the ssh client and can't prompt for input
if proxyMode {
Expand All @@ -64,20 +78,40 @@ the SSH server and handling the connection proxy.
cmd.RunE = func(cmd *cobra.Command, args []string) error {
ctx := cmd.Context()
wsClient := cmdctx.WorkspaceClient(ctx)

if !proxyMode && clusterID == "" && connectionName == "" {
return errors.New("please provide --cluster flag with the cluster ID, or --name flag with the serverless connection name")
}

if accelerator != "" && connectionName == "" {
return errors.New("--accelerator flag can only be used with serverless compute (--name flag)")
}

// Remove when we add support for serverless CPU
if connectionName != "" && accelerator == "" {
return errors.New("--name flag requires --accelerator to be set (e.g. for now we only support serverless GPU compute)")
}

// TODO: validate connectionName if provided

opts := client.ClientOptions{
Profile: wsClient.Config.Profile,
ClusterID: clusterID,
ConnectionName: connectionName,
Accelerator: accelerator,
ProxyMode: proxyMode,
ServerMetadata: serverMetadata,
ShutdownDelay: shutdownDelay,
MaxClients: maxClients,
HandoverTimeout: handoverTimeout,
ReleasesDir: releasesDir,
ServerTimeout: serverTimeout,
TaskStartupTimeout: taskStartupTimeout,
AutoStartCluster: autoStartCluster,
ClientPublicKeyName: clientPublicKeyName,
ClientPrivateKeyName: clientPrivateKeyName,
UserKnownHostsFile: userKnownHostsFile,
Liteswap: liteswap,
AdditionalArgs: args,
}
return client.Run(ctx, wsClient, opts)
Expand Down
1 change: 1 addition & 0 deletions experimental/ssh/cmd/constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ const (
defaultHandoverTimeout = 30 * time.Minute

serverTimeout = 24 * time.Hour
taskStartupTimeout = 10 * time.Minute
serverPortRange = 100
serverConfigDir = ".ssh-tunnel"
serverPrivateKeyName = "server-private-key"
Expand Down
4 changes: 4 additions & 0 deletions experimental/ssh/cmd/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,15 @@ and proxies them to local SSH daemon processes.
var maxClients int
var shutdownDelay time.Duration
var clusterID string
var sessionID string
var version string
var secretScopeName string
var authorizedKeySecretName string

cmd.Flags().StringVar(&clusterID, "cluster", "", "Databricks cluster ID")
cmd.MarkFlagRequired("cluster")
cmd.Flags().StringVar(&sessionID, "session-id", "", "Session identifier (cluster ID or serverless connection name)")
cmd.MarkFlagRequired("session-id")
cmd.Flags().StringVar(&secretScopeName, "secret-scope-name", "", "Databricks secret scope name to store SSH keys")
cmd.MarkFlagRequired("secret-scope-name")
cmd.Flags().StringVar(&authorizedKeySecretName, "authorized-key-secret-name", "", "Name of the secret containing the client public key")
Expand All @@ -56,6 +59,7 @@ and proxies them to local SSH daemon processes.
wsc := cmdctx.WorkspaceClient(ctx)
opts := server.ServerOptions{
ClusterID: clusterID,
SessionID: sessionID,
MaxClients: maxClients,
ShutdownDelay: shutdownDelay,
Version: version,
Expand Down
Loading
Loading