Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions experimental/ssh/cmd/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ and proxies them to local SSH daemon processes.
var version string
var secretScopeName string
var authorizedKeySecretName string
var serverless bool

cmd.Flags().StringVar(&clusterID, "cluster", "", "Databricks cluster ID")
cmd.MarkFlagRequired("cluster")
Expand All @@ -43,6 +44,7 @@ and proxies them to local SSH daemon processes.
cmd.Flags().IntVar(&maxClients, "max-clients", defaultMaxClients, "Maximum number of SSH clients")
cmd.Flags().DurationVar(&shutdownDelay, "shutdown-delay", defaultShutdownDelay, "Delay before shutting down after no pings from clients")
cmd.Flags().StringVar(&version, "version", "", "Client version of the Databricks CLI")
cmd.Flags().BoolVar(&serverless, "serverless", false, "Enable serverless mode for Jupyter initialization")

cmd.PreRunE = func(cmd *cobra.Command, args []string) error {
// The server can be executed under a directory with an invalid bundle configuration.
Expand Down Expand Up @@ -70,6 +72,7 @@ and proxies them to local SSH daemon processes.
AuthorizedKeySecretName: authorizedKeySecretName,
DefaultPort: defaultServerPort,
PortRange: serverPortRange,
Serverless: serverless,
}
return server.Run(ctx, wsc, opts)
}
Expand Down
1 change: 1 addition & 0 deletions experimental/ssh/internal/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -429,6 +429,7 @@ func submitSSHTunnelJob(ctx context.Context, client *databricks.WorkspaceClient,
"shutdownDelay": opts.ShutdownDelay.String(),
"maxClients": strconv.Itoa(opts.MaxClients),
"sessionId": sessionID,
"serverless": strconv.FormatBool(opts.IsServerlessMode()),
}

cmdio.LogString(ctx, "Submitting a job to start the ssh server...")
Expand Down
5 changes: 4 additions & 1 deletion experimental/ssh/internal/client/ssh-server-bootstrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@
dbutils.widgets.text("authorizedKeySecretName", "")
dbutils.widgets.text("maxClients", "10")
dbutils.widgets.text("shutdownDelay", "10m")
dbutils.widgets.text("sessionId", "") # Required: unique identifier for the session
dbutils.widgets.text("sessionId", "")
dbutils.widgets.text("serverless", "false")


def cleanup():
Expand Down Expand Up @@ -115,6 +116,7 @@ def run_ssh_server():
session_id = dbutils.widgets.get("sessionId")
if not session_id:
raise RuntimeError("Session ID is required. Please provide it using the 'sessionId' widget.")
serverless = dbutils.widgets.get("serverless")

arch = platform.machine()
if arch == "x86_64":
Expand All @@ -137,6 +139,7 @@ def run_ssh_server():
"server",
f"--cluster={ctx.clusterId}",
f"--session-id={session_id}",
f"--serverless={serverless}",
f"--secret-scope-name={secrets_scope}",
f"--authorized-key-secret-name={public_key_secret_name}",
f"--max-clients={max_clients}",
Expand Down
78 changes: 59 additions & 19 deletions experimental/ssh/internal/server/jupyter-init.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import List, Optional
from IPython.core.getipython import get_ipython
from IPython.display import display as ip_display
from dbruntime import UserNamespaceInitializer
import os


def _log_exceptions(func):
Expand All @@ -18,18 +18,21 @@ def wrapper(*args, **kwargs):
return wrapper


_user_namespace_initializer = UserNamespaceInitializer.getOrCreate()
_entry_point = _user_namespace_initializer.get_spark_entry_point()
_globals = _user_namespace_initializer.get_namespace_globals()
for name, value in _globals.items():
print(f"Registering global: {name} = {value}")
if name not in globals():
globals()[name] = value
@_log_exceptions
def _setup_dedicated_session():
    """Initialize the Jupyter namespace for dedicated (classic) compute.

    Copies the Databricks runtime globals (e.g. spark, sc, dbutils) into this
    module's namespace and replaces the runtime ``display`` helper with the
    plain IPython one, which works outside the Databricks notebook UI.
    """
    from dbruntime import UserNamespaceInitializer

    initializer = UserNamespaceInitializer.getOrCreate()
    # NOTE(review): the entry point is no longer used here (pip magics fetch
    # their own); the call is kept in case get_spark_entry_point() has
    # initialization side effects -- TODO confirm it can be dropped.
    entry_point = initializer.get_spark_entry_point()
    for name, value in initializer.get_namespace_globals().items():
        print(f"Registering global: {name} = {value}")
        # Do not clobber names already defined in this module.
        if name not in globals():
            globals()[name] = value

    # 'display' from the runtime uses custom widgets that don't work in Jupyter.
    # We use the IPython display instead (in combination with the html formatter for DataFrames).
    globals()["display"] = ip_display


@_log_exceptions
Expand Down Expand Up @@ -157,19 +160,28 @@ def _parse_line_for_databricks_magics(lines: List[str]) -> List[str]:


@_log_exceptions
def _register_common_magics():
    """Register the common magic command parser with IPython.

    Hooks _parse_line_for_databricks_magics into input_transformers_cleanup
    so each cell is rewritten for Databricks magics before execution. Applies
    to both dedicated and serverless sessions.
    """
    ip = get_ipython()
    ip.input_transformers_cleanup.append(_parse_line_for_databricks_magics)


@_log_exceptions
def _register_pip_magics(user_namespace_initializer=None, entry_point=None):
    """Register the pip magic command overrides with IPython.

    Both parameters are ignored and recomputed internally; they keep defaults
    so the zero-argument call site works while any older caller passing them
    is still accepted. Only meaningful on dedicated compute, where the
    dbruntime namespace (including ``sc``) has been registered.
    """
    from dbruntime.DatasetInfo import UserNamespaceDict
    from dbruntime.PipMagicOverrides import PipMagicOverrides
    from dbruntime import UserNamespaceInitializer

    user_namespace_initializer = UserNamespaceInitializer.getOrCreate()
    entry_point = user_namespace_initializer.get_spark_entry_point()
    user_ns = UserNamespaceDict(
        user_namespace_initializer.get_namespace_globals(),
        entry_point.getDriverConf(),
        entry_point,
    )
    ip = get_ipython()
    # BUG FIX: `globals["sc"]` subscripted the builtin `globals` function
    # (TypeError at call time); it must be `globals()["sc"]`, the
    # SparkContext registered into module globals by the dedicated setup.
    ip.register_magics(PipMagicOverrides(entry_point, globals()["sc"]._conf, user_ns))


@_log_exceptions
Expand All @@ -186,6 +198,34 @@ def df_html(df: DataFrame) -> str:
html_formatter.for_type(DataFrame, df_html)


@_log_exceptions
def _setup_serverless_session():
    """Create a serverless Spark session via Databricks Connect.

    Replaces any existing ``spark`` binding with a serverless
    DatabricksSession; on failure the previous binding is restored and the
    exception re-raised (presumably logged by @_log_exceptions -- confirm).
    """
    import IPython
    from databricks.connect import DatabricksSession

    user_ns = getattr(IPython.get_ipython(), "user_ns", {})
    # BUG FIX: user_ns is a dict -- getattr(user_ns, "spark", None) performed
    # an *attribute* lookup and always returned None, so the rollback below
    # could never restore the original session. Use a dict lookup instead.
    existing_session = user_ns.get("spark")
    try:
        # Clear the existing local spark session, otherwise DatabricksSession will re-use it.
        user_ns["spark"] = None
        globals()["spark"] = None
        # DatabricksSession will use the existing env vars for the connection.
        spark_session = DatabricksSession.builder.serverless(True).getOrCreate()
        user_ns["spark"] = spark_session
        globals()["spark"] = spark_session
    except Exception:
        # Restore the previous binding, then re-raise with the original
        # traceback (bare `raise` instead of `raise e`).
        user_ns["spark"] = existing_session
        globals()["spark"] = existing_session
        raise


# Select the initialization path. The SSH server exports
# DATABRICKS_JUPYTER_SERVERLESS=true when started with --serverless (see the
# sshd environment setup), where there is no local dbruntime namespace to
# copy -- we connect through Databricks Connect instead. Pip magics rely on
# the dedicated runtime namespace, so they are only registered on that path.
if os.environ.get("DATABRICKS_JUPYTER_SERVERLESS") == "true":
    _setup_serverless_session()
else:
    _setup_dedicated_session()
    _register_pip_magics()


# Common hooks apply to both dedicated and serverless sessions.
_register_common_magics()
_register_formatters()
_register_runtime_hooks()
2 changes: 2 additions & 0 deletions experimental/ssh/internal/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ type ServerOptions struct {
// SessionID is the unique identifier for the session (cluster ID for dedicated clusters, connection name for serverless).
// Used for metadata storage path. Defaults to ClusterID if not set.
SessionID string
// Serverless indicates whether the server is running on serverless compute.
Serverless bool
// The directory to store sshd configuration
ConfigDir string
// The name of the secrets scope to use for client and server keys
Expand Down
3 changes: 3 additions & 0 deletions experimental/ssh/internal/server/sshd.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,9 @@ func prepareSSHDConfig(ctx context.Context, client *databricks.WorkspaceClient,
setEnv += " GIT_CONFIG_GLOBAL=/Workspace/.proc/self/git/config"
setEnv += " ENABLE_DATABRICKS_CLI=true"
setEnv += " PYTHONPYCACHEPREFIX=/tmp/pycache"
if opts.Serverless {
setEnv += " DATABRICKS_JUPYTER_SERVERLESS=true"
}

sshdConfigContent := "PubkeyAuthentication yes\n" +
"PasswordAuthentication no\n" +
Expand Down