diff --git a/experimental/ssh/cmd/server.go b/experimental/ssh/cmd/server.go index efe283f28a..d73ad2d5c4 100644 --- a/experimental/ssh/cmd/server.go +++ b/experimental/ssh/cmd/server.go @@ -30,6 +30,7 @@ and proxies them to local SSH daemon processes. var version string var secretScopeName string var authorizedKeySecretName string + var serverless bool cmd.Flags().StringVar(&clusterID, "cluster", "", "Databricks cluster ID") cmd.MarkFlagRequired("cluster") @@ -43,6 +44,7 @@ and proxies them to local SSH daemon processes. cmd.Flags().IntVar(&maxClients, "max-clients", defaultMaxClients, "Maximum number of SSH clients") cmd.Flags().DurationVar(&shutdownDelay, "shutdown-delay", defaultShutdownDelay, "Delay before shutting down after no pings from clients") cmd.Flags().StringVar(&version, "version", "", "Client version of the Databricks CLI") + cmd.Flags().BoolVar(&serverless, "serverless", false, "Enable serverless mode for Jupyter initialization") cmd.PreRunE = func(cmd *cobra.Command, args []string) error { // The server can be executed under a directory with an invalid bundle configuration. @@ -70,6 +72,7 @@ and proxies them to local SSH daemon processes. 
AuthorizedKeySecretName: authorizedKeySecretName, DefaultPort: defaultServerPort, PortRange: serverPortRange, + Serverless: serverless, } return server.Run(ctx, wsc, opts) } diff --git a/experimental/ssh/internal/client/client.go b/experimental/ssh/internal/client/client.go index 940f792f0e..3d0e1a5527 100644 --- a/experimental/ssh/internal/client/client.go +++ b/experimental/ssh/internal/client/client.go @@ -429,6 +429,7 @@ func submitSSHTunnelJob(ctx context.Context, client *databricks.WorkspaceClient, "shutdownDelay": opts.ShutdownDelay.String(), "maxClients": strconv.Itoa(opts.MaxClients), "sessionId": sessionID, + "serverless": strconv.FormatBool(opts.IsServerlessMode()), } cmdio.LogString(ctx, "Submitting a job to start the ssh server...") diff --git a/experimental/ssh/internal/client/ssh-server-bootstrap.py b/experimental/ssh/internal/client/ssh-server-bootstrap.py index 8dc0aff16a..397254436f 100644 --- a/experimental/ssh/internal/client/ssh-server-bootstrap.py +++ b/experimental/ssh/internal/client/ssh-server-bootstrap.py @@ -17,7 +17,8 @@ dbutils.widgets.text("authorizedKeySecretName", "") dbutils.widgets.text("maxClients", "10") dbutils.widgets.text("shutdownDelay", "10m") -dbutils.widgets.text("sessionId", "") # Required: unique identifier for the session +dbutils.widgets.text("sessionId", "") +dbutils.widgets.text("serverless", "false") def cleanup(): @@ -115,6 +116,7 @@ def run_ssh_server(): session_id = dbutils.widgets.get("sessionId") if not session_id: raise RuntimeError("Session ID is required. 
Please provide it using the 'sessionId' widget.") + serverless = dbutils.widgets.get("serverless") arch = platform.machine() if arch == "x86_64": @@ -137,6 +139,7 @@ def run_ssh_server(): "server", f"--cluster={ctx.clusterId}", f"--session-id={session_id}", + f"--serverless={serverless}", f"--secret-scope-name={secrets_scope}", f"--authorized-key-secret-name={public_key_secret_name}", f"--max-clients={max_clients}", diff --git a/experimental/ssh/internal/server/jupyter-init.py b/experimental/ssh/internal/server/jupyter-init.py index 3e58b2f94a..e212f9d9e8 100644 --- a/experimental/ssh/internal/server/jupyter-init.py +++ b/experimental/ssh/internal/server/jupyter-init.py @@ -1,7 +1,7 @@ from typing import List, Optional from IPython.core.getipython import get_ipython from IPython.display import display as ip_display -from dbruntime import UserNamespaceInitializer +import os def _log_exceptions(func): @@ -18,18 +18,21 @@ def wrapper(*args, **kwargs): return wrapper -_user_namespace_initializer = UserNamespaceInitializer.getOrCreate() -_entry_point = _user_namespace_initializer.get_spark_entry_point() -_globals = _user_namespace_initializer.get_namespace_globals() -for name, value in _globals.items(): - print(f"Registering global: {name} = {value}") - if name not in globals(): - globals()[name] = value +@_log_exceptions +def _setup_dedicated_session(): + from dbruntime import UserNamespaceInitializer + _user_namespace_initializer = UserNamespaceInitializer.getOrCreate() + _entry_point = _user_namespace_initializer.get_spark_entry_point() + _globals = _user_namespace_initializer.get_namespace_globals() + for name, value in _globals.items(): + print(f"Registering global: {name} = {value}") + if name not in globals(): + globals()[name] = value -# 'display' from the runtime uses custom widgets that don't work in Jupyter. -# We use the IPython display instead (in combination with the html formatter for DataFrames). 
-globals()["display"] = ip_display
+    # 'display' from the runtime uses custom widgets that don't work in Jupyter.
+    # We use the IPython display instead (in combination with the html formatter for DataFrames).
+    globals()["display"] = ip_display
 
 
 @_log_exceptions
@@ -157,19 +160,29 @@ def _parse_line_for_databricks_magics(lines: List[str]) -> List[str]:
 
 
 @_log_exceptions
-def _register_magics():
-    """Register the magic command parser with IPython."""
+def _register_common_magics():
+    """Register the common magic command parser with IPython."""
+    ip = get_ipython()
+    ip.input_transformers_cleanup.append(_parse_line_for_databricks_magics)
+
+
+@_log_exceptions
+def _register_pip_magics():
+    """Register the pip magic command parser with IPython (dedicated clusters only)."""
     from dbruntime.DatasetInfo import UserNamespaceDict
     from dbruntime.PipMagicOverrides import PipMagicOverrides
+    from dbruntime import UserNamespaceInitializer
 
+    # Fetch the singleton here instead of at import time: the module-level
+    # _globals/_entry_point were removed so serverless mode never touches dbruntime.
+    user_namespace_initializer = UserNamespaceInitializer.getOrCreate()
+    entry_point = user_namespace_initializer.get_spark_entry_point()
     user_ns = UserNamespaceDict(
-        _user_namespace_initializer.get_namespace_globals(),
-        _entry_point.getDriverConf(),
-        _entry_point,
+        user_namespace_initializer.get_namespace_globals(),
+        entry_point.getDriverConf(),
+        entry_point,
     )
     ip = get_ipython()
-    ip.input_transformers_cleanup.append(_parse_line_for_databricks_magics)
-    ip.register_magics(PipMagicOverrides(_entry_point, _globals["sc"]._conf, user_ns))
+    ip.register_magics(PipMagicOverrides(entry_point, user_namespace_initializer.get_namespace_globals()["sc"]._conf, user_ns))
 
 
 @_log_exceptions
@@ -186,6 +199,35 @@ def df_html(df: DataFrame) -> str:
     html_formatter.for_type(DataFrame, df_html)
 
 
-_register_magics()
+@_log_exceptions
+def _setup_serverless_session():
+    import IPython
+    from databricks.connect import DatabricksSession
+
+    user_ns = getattr(IPython.get_ipython(), "user_ns", {})
+    existing_session = user_ns.get("spark")
+    try:
+        # Clear the existing local spark session, otherwise DatabricksSession will re-use it.
+        user_ns["spark"] = None
+        globals()["spark"] = None
+        # DatabricksSession will use the existing env vars for the connection.
+        spark_session = DatabricksSession.builder.serverless(True).getOrCreate()
+        user_ns["spark"] = spark_session
+        globals()["spark"] = spark_session
+    except Exception:
+        # Restore the previous session on failure, then re-raise with the original traceback.
+        user_ns["spark"] = existing_session
+        globals()["spark"] = existing_session
+        raise
+
+
+if os.environ.get("DATABRICKS_JUPYTER_SERVERLESS") == "true":
+    _setup_serverless_session()
+else:
+    _setup_dedicated_session()
+    _register_pip_magics()
+
+
+_register_common_magics()
 _register_formatters()
 _register_runtime_hooks()
diff --git a/experimental/ssh/internal/server/server.go b/experimental/ssh/internal/server/server.go
index 92fa76050a..05243ec3c3 100644
--- a/experimental/ssh/internal/server/server.go
+++ b/experimental/ssh/internal/server/server.go
@@ -34,6 +34,8 @@ type ServerOptions struct {
 	// SessionID is the unique identifier for the session (cluster ID for dedicated clusters, connection name for serverless).
 	// Used for metadata storage path. Defaults to ClusterID if not set.
 	SessionID string
+	// Serverless indicates whether the server is running on serverless compute.
+ Serverless bool // The directory to store sshd configuration ConfigDir string // The name of the secrets scope to use for client and server keys diff --git a/experimental/ssh/internal/server/sshd.go b/experimental/ssh/internal/server/sshd.go index c8f23d02a5..6a8125d2b3 100644 --- a/experimental/ssh/internal/server/sshd.go +++ b/experimental/ssh/internal/server/sshd.go @@ -67,6 +67,9 @@ func prepareSSHDConfig(ctx context.Context, client *databricks.WorkspaceClient, setEnv += " GIT_CONFIG_GLOBAL=/Workspace/.proc/self/git/config" setEnv += " ENABLE_DATABRICKS_CLI=true" setEnv += " PYTHONPYCACHEPREFIX=/tmp/pycache" + if opts.Serverless { + setEnv += " DATABRICKS_JUPYTER_SERVERLESS=true" + } sshdConfigContent := "PubkeyAuthentication yes\n" + "PasswordAuthentication no\n" +