From f757f6b01bb898eaea8079a32daef83b59e0451b Mon Sep 17 00:00:00 2001 From: riddhibhagwat-db Date: Thu, 18 Jun 2026 20:21:45 +0000 Subject: [PATCH 1/5] experimental/air: add run config launch accessors Flatten the validated runConfig schema into the derived values the launch path consumes (timeout seconds, retry default, docker image URL, requirements file vs inline dependencies, runtime version), replacing the Python CLI's _convert_to_run_config step. handle_run reads runConfig directly, so these are accessors rather than a separate internal config type. Co-authored-by: Isaac --- experimental/air/cmd/runconfig_launch.go | 62 ++++++++++++++ experimental/air/cmd/runconfig_launch_test.go | 80 +++++++++++++++++++ 2 files changed, 142 insertions(+) create mode 100644 experimental/air/cmd/runconfig_launch.go create mode 100644 experimental/air/cmd/runconfig_launch_test.go diff --git a/experimental/air/cmd/runconfig_launch.go b/experimental/air/cmd/runconfig_launch.go new file mode 100644 index 0000000000..6864e4c454 --- /dev/null +++ b/experimental/air/cmd/runconfig_launch.go @@ -0,0 +1,62 @@ +package aircmd + +// This file flattens the validated runConfig schema into the derived values the +// launch path consumes, replacing the Python CLI's _convert_to_run_config step. +// There is no separate internal config type: handle_run reads runConfig directly, +// using these accessors for the values that need computing rather than a plain +// field read. + +const defaultMaxRetries = 3 + +// timeoutSeconds converts timeout_minutes to seconds. Zero means the user set no +// timeout and the backend default applies. +func (c *runConfig) timeoutSeconds() int { + if c.TimeoutMinutes == nil { + return 0 + } + return *c.TimeoutMinutes * 60 +} + +// maxRetries returns the retry count, applying the schema default when unset. +func (c *runConfig) maxRetries() int { + if c.MaxRetries == nil { + return defaultMaxRetries + } + return *c.MaxRetries +} + +// dockerImageURL returns the custom docker image URL, or "" when none is set. +func (c *runConfig) dockerImageURL() string { + if c.Environment != nil && c.Environment.DockerImage != nil { + return c.Environment.DockerImage.URL + } + return "" +} + +// requirementsFile returns the path to a requirements file when +// environment.dependencies is a string, and whether it was set. +func (c *runConfig) requirementsFile() (string, bool) { + if c.Environment == nil || !c.Environment.Dependencies.set || c.Environment.Dependencies.isList { + return "", false + } + return c.Environment.Dependencies.path, true +} + +// inlineDependencies returns the inline package list when +// environment.dependencies is a list, and whether it was set. +func (c *runConfig) inlineDependencies() ([]string, bool) { + if c.Environment == nil || !c.Environment.Dependencies.set || !c.Environment.Dependencies.isList { + return nil, false + } + return c.Environment.Dependencies.list, true +} + +// runtimeVersion returns the client image version from environment.version when +// set. For a requirements-file dependency set, the version lives in that file and +// is resolved at launch, not here. +func (c *runConfig) runtimeVersion() (string, bool) { + if c.Environment == nil || !c.Environment.Version.set { + return "", false + } + return c.Environment.Version.raw, true +} diff --git a/experimental/air/cmd/runconfig_launch_test.go b/experimental/air/cmd/runconfig_launch_test.go new file mode 100644 index 0000000000..289db91c7d --- /dev/null +++ b/experimental/air/cmd/runconfig_launch_test.go @@ -0,0 +1,80 @@ +package aircmd + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestRunConfigTimeoutSeconds(t *testing.T) { + c := &runConfig{} + assert.Equal(t, 0, c.timeoutSeconds()) + + c.TimeoutMinutes = new(2) + assert.Equal(t, 120, c.timeoutSeconds()) +} + +func TestRunConfigMaxRetries(t *testing.T) { + c := &runConfig{} + assert.Equal(t, defaultMaxRetries, c.maxRetries()) + + c.MaxRetries = new(0) + assert.Equal(t, 0, c.maxRetries()) + + c.MaxRetries = new(7) + assert.Equal(t, 7, c.maxRetries()) +} + +func TestRunConfigDockerImageURL(t *testing.T) { + c := &runConfig{} + assert.Empty(t, c.dockerImageURL()) + + c.Environment = &environmentConfig{} + assert.Empty(t, c.dockerImageURL()) + + c.Environment.DockerImage = &dockerImageConfig{URL: "org/repo:tag"} + assert.Equal(t, "org/repo:tag", c.dockerImageURL()) +} + +func TestRunConfigDependencies(t *testing.T) { + t.Run("unset", func(t *testing.T) { + c := &runConfig{} + _, ok := c.requirementsFile() + assert.False(t, ok) + _, ok = c.inlineDependencies() + assert.False(t, ok) + }) + + t.Run("file path", func(t *testing.T) { + c := &runConfig{Environment: &environmentConfig{ + Dependencies: dependencies{set: true, isList: false, path: "req.yaml"}, + }} + path, ok := c.requirementsFile() + assert.True(t, ok) + assert.Equal(t, "req.yaml", path) + _, ok = c.inlineDependencies() + assert.False(t, ok) + }) + + t.Run("inline list", func(t *testing.T) { + c := &runConfig{Environment: &environmentConfig{ + Dependencies: dependencies{set: true, isList: true, list: []string{"torch", "numpy"}}, + }} + list, ok := c.inlineDependencies() + assert.True(t, ok) + assert.Equal(t, []string{"torch", "numpy"}, list) + _, ok = c.requirementsFile() + assert.False(t, ok) + }) +} + +func TestRunConfigRuntimeVersion(t *testing.T) { + c := &runConfig{} + _, ok := c.runtimeVersion() + assert.False(t, ok) + + c.Environment = &environmentConfig{Version: stringOrInt{set: true, raw: "5"}} + v, ok := c.runtimeVersion() + assert.True(t, ok) + assert.Equal(t, "5", v) +} From 2e7cf854c8bc9164a7aff4f25d906ec84b74026e Mon Sep 17 00:00:00 2001 From: riddhibhagwat-db Date: Thu, 18 Jun 2026 20:37:18 +0000 Subject: [PATCH 2/5] experimental/air: wire run command for load, validate, dry-run Wire `air run`'s RunE to load and structurally validate the YAML config, and implement --dry-run (validate without submitting). The non-dry-run submission path returns "not implemented" until the submit phase lands; --override is rejected with a clear error since the override pipeline is not ported yet. Drop `run` from the not-implemented stub test now that it does real work. Co-authored-by: Isaac --- acceptance/experimental/air/run/invalid.yaml | 5 ++ acceptance/experimental/air/run/out.test.toml | 3 ++ acceptance/experimental/air/run/output.txt | 39 ++++++++++++++ acceptance/experimental/air/run/script | 17 +++++++ acceptance/experimental/air/run/test.toml | 4 ++ acceptance/experimental/air/run/valid.yaml | 5 ++ .../experimental/air/unimplemented/output.txt | 6 --- .../experimental/air/unimplemented/script | 3 -- experimental/air/cmd/run.go | 51 +++++++++++++++++-- experimental/air/cmd/stubs_test.go | 1 - 10 files changed, 121 insertions(+), 13 deletions(-) create mode 100644 acceptance/experimental/air/run/invalid.yaml create mode 100644 acceptance/experimental/air/run/out.test.toml create mode 100644 acceptance/experimental/air/run/output.txt create mode 100644 acceptance/experimental/air/run/script create mode 100644 acceptance/experimental/air/run/test.toml create mode 100644 acceptance/experimental/air/run/valid.yaml diff --git a/acceptance/experimental/air/run/invalid.yaml b/acceptance/experimental/air/run/invalid.yaml new file mode 100644 index 0000000000..c011fc81b3 --- /dev/null +++ b/acceptance/experimental/air/run/invalid.yaml @@ -0,0 +1,5 @@ +experiment_name: bad.name +command: x +compute: + accelerator_type: GPU_8xH100 + num_accelerators: 3 diff --git a/acceptance/experimental/air/run/out.test.toml b/acceptance/experimental/air/run/out.test.toml new file mode 100644 index 0000000000..d6187dcb04 --- /dev/null +++ b/acceptance/experimental/air/run/out.test.toml @@ -0,0 +1,3 @@ +Local = true +Cloud = false +EnvMatrix.DATABRICKS_BUNDLE_ENGINE = [] diff --git a/acceptance/experimental/air/run/output.txt b/acceptance/experimental/air/run/output.txt new file mode 100644 index 0000000000..180886290b --- /dev/null +++ b/acceptance/experimental/air/run/output.txt @@ -0,0 +1,39 @@ + +=== dry-run (text) +>>> [CLI] experimental air run -f valid.yaml --dry-run +Dry run: configuration for "smoke-test" is valid; not submitting. + +=== dry-run (json) +>>> [CLI] experimental air run -f valid.yaml --dry-run -o json +{ + "v": 1, + "ts": "[TIMESTAMP]", + "data": { + "status": "DRY_RUN_OK", + "dry_run": true + } +} + +=== override not yet supported +>>> [CLI] experimental air run -f valid.yaml --dry-run --override a=b +Error: --override is not yet supported + +Exit code: 1 + +=== watch not yet supported +>>> [CLI] experimental air run -f valid.yaml --dry-run --watch +Error: --watch is not yet supported + +Exit code: 1 + +=== invalid config is rejected +>>> [CLI] experimental air run -f invalid.yaml --dry-run +Error: invalid experiment_name "bad.name": only alphanumeric characters, hyphens (-), and underscores (_) are allowed + +Exit code: 1 + +=== missing --file +>>> [CLI] experimental air run --dry-run +Error: required flag(s) "file" not set + +Exit code: 1 diff --git a/acceptance/experimental/air/run/script b/acceptance/experimental/air/run/script new file mode 100644 index 0000000000..806bd321e6 --- /dev/null +++ b/acceptance/experimental/air/run/script @@ -0,0 +1,17 @@ +title "dry-run (text)" +trace $CLI experimental air run -f valid.yaml --dry-run + +title "dry-run (json)" +trace $CLI experimental air run -f valid.yaml --dry-run -o json + +title "override not yet supported" +errcode trace $CLI experimental air run -f valid.yaml --dry-run --override a=b + +title "watch not yet supported" +errcode trace $CLI experimental air run -f valid.yaml --dry-run --watch + +title "invalid config is rejected" +errcode trace $CLI experimental air run -f invalid.yaml --dry-run + +title "missing --file" +errcode trace $CLI experimental air run --dry-run diff --git a/acceptance/experimental/air/run/test.toml b/acceptance/experimental/air/run/test.toml new file mode 100644 index 0000000000..2f971c3ed2 --- /dev/null +++ b/acceptance/experimental/air/run/test.toml @@ -0,0 +1,4 @@ +# `air run --dry-run` validates the config locally and makes no workspace calls, +# so no engine matrix or server stubs are needed. +[EnvMatrix] +DATABRICKS_BUNDLE_ENGINE = [] diff --git a/acceptance/experimental/air/run/valid.yaml b/acceptance/experimental/air/run/valid.yaml new file mode 100644 index 0000000000..b82a321b05 --- /dev/null +++ b/acceptance/experimental/air/run/valid.yaml @@ -0,0 +1,5 @@ +experiment_name: smoke-test +command: python train.py +compute: + accelerator_type: GPU_1xH100 + num_accelerators: 1 diff --git a/acceptance/experimental/air/unimplemented/output.txt b/acceptance/experimental/air/unimplemented/output.txt index 66ddc34d58..21c3c891af 100644 --- a/acceptance/experimental/air/unimplemented/output.txt +++ b/acceptance/experimental/air/unimplemented/output.txt @@ -1,10 +1,4 @@ -=== run ->>> [CLI] experimental air run -Error: `air run` is not implemented yet - -Exit code: 1 - === logs >>> [CLI] experimental air logs 123 Error: `air logs` is not implemented yet diff --git a/acceptance/experimental/air/unimplemented/script b/acceptance/experimental/air/unimplemented/script index d00d045368..4c53586b16 100644 --- a/acceptance/experimental/air/unimplemented/script +++ b/acceptance/experimental/air/unimplemented/script @@ -1,8 +1,5 @@ # Each stub must fail with "not implemented"; errcode records the exit code. -title "run" -errcode trace $CLI experimental air run - title "logs" errcode trace $CLI experimental air logs 123 diff --git a/experimental/air/cmd/run.go b/experimental/air/cmd/run.go index 0bc3d1fd94..95bf360b83 100644 --- a/experimental/air/cmd/run.go +++ b/experimental/air/cmd/run.go @@ -1,10 +1,23 @@ package aircmd import ( + "errors" + "fmt" + "github.com/databricks/cli/cmd/root" + "github.com/databricks/cli/libs/cmdio" + "github.com/databricks/cli/libs/flags" "github.com/spf13/cobra" ) +// runResult is the JSON payload for `air run`. +type runResult struct { + Status string `json:"status"` + DryRun bool `json:"dry_run,omitempty"` + RunID string `json:"run_id,omitempty"` + DashboardURL string `json:"dashboard_url,omitempty"` +} + func newRunCommand() *cobra.Command { var ( file string @@ -21,9 +34,6 @@ func newRunCommand() *cobra.Command { Long: `Submit a training workload to Databricks serverless GPU compute. The workload is described by a YAML config file (see --file).`, - RunE: func(cmd *cobra.Command, args []string) error { - return notImplemented("run") - }, } cmd.Flags().StringVarP(&file, "file", "f", "", "Path to the workload YAML config") @@ -31,6 +41,41 @@ The workload is described by a YAML config file (see --file).`, cmd.Flags().StringArrayVar(&overrides, "override", nil, "Override a YAML field, e.g. compute.num_accelerators=8 (repeatable)") cmd.Flags().BoolVar(&dryRun, "dry-run", false, "Validate the config without submitting") cmd.Flags().StringVar(&idempotencyKey, "idempotency-key", "", "Return the existing run if this key was already used") + _ = cmd.MarkFlagRequired("file") + + // --dry-run only validates the config locally, so it needs no workspace. + // Submission requires an authenticated client. + cmd.PreRunE = func(cmd *cobra.Command, args []string) error { + if dryRun { + return nil + } + return root.MustWorkspaceClient(cmd, args) + } + + cmd.RunE = func(cmd *cobra.Command, args []string) error { + ctx := cmd.Context() + + // --override is parsed and applied before validation; that pipeline is + // not ported yet, so reject it rather than silently ignore the flag. + if len(overrides) > 0 { + return errors.New("--override is not yet supported") + } + + cfg, err := loadRunConfig(file) + if err != nil { + return err + } + + if dryRun { + if root.OutputType(cmd) == flags.OutputText { + cmdio.LogString(ctx, fmt.Sprintf("Dry run: configuration for %q is valid; not submitting.", cfg.ExperimentName)) + return nil + } + return renderEnvelope(ctx, runResult{Status: "DRY_RUN_OK", DryRun: true}) + } + + return notImplemented("run submission") + } return cmd } diff --git a/experimental/air/cmd/stubs_test.go b/experimental/air/cmd/stubs_test.go index b9f5c330f0..4607d7d9ea 100644 --- a/experimental/air/cmd/stubs_test.go +++ b/experimental/air/cmd/stubs_test.go @@ -13,7 +13,6 @@ import ( // fails with a "not implemented" error. Drop a command here once it lands. func TestStubCommandsReturnNotImplemented(t *testing.T) { stubs := map[string]*cobra.Command{ - "run": newRunCommand(), "logs": newLogsCommand(), "cancel": newCancelCommand(), "register-image": newRegisterImageCommand(), From f46d5e0dfd5382d632eb050ce5e5b07513caabb1 Mon Sep 17 00:00:00 2001 From: riddhibhagwat-db Date: Mon, 22 Jun 2026 23:15:39 +0000 Subject: [PATCH 3/5] experimental/air: add run pre-submit resolution helpers Resolve the workspace context air run needs before uploading and submitting: the current user, the per-user workspace home (with env override), a unique cli_launch directory for a run's artifacts, the MLflow experiment path, and ensuring a custom experiment_directory exists (created if missing, matching the CLI's convention for its other artifact directories). Co-authored-by: Isaac --- experimental/air/cmd/runlaunch.go | 82 ++++++++++++++++++++++++++ experimental/air/cmd/runlaunch_test.go | 71 ++++++++++++++++++++++ 2 files changed, 153 insertions(+) create mode 100644 experimental/air/cmd/runlaunch.go create mode 100644 experimental/air/cmd/runlaunch_test.go diff --git a/experimental/air/cmd/runlaunch.go b/experimental/air/cmd/runlaunch.go new file mode 100644 index 0000000000..163655d9dd --- /dev/null +++ b/experimental/air/cmd/runlaunch.go @@ -0,0 +1,82 @@ +package aircmd + +import ( + "context" + "errors" + "fmt" + "path" + "strings" + + "github.com/databricks/cli/libs/env" + "github.com/databricks/databricks-sdk-go" + "github.com/databricks/databricks-sdk-go/apierr" + "github.com/databricks/databricks-sdk-go/service/iam" + "github.com/databricks/databricks-sdk-go/service/workspace" + "github.com/google/uuid" +) + +// userWorkspaceDirEnv overrides the per-user workspace directory; mirrors the +// Python CLI's DATABRICKS_INTERNAL_USER_WORKSPACE_DIR escape hatch. +const userWorkspaceDirEnv = "DATABRICKS_INTERNAL_USER_WORKSPACE_DIR" + +// currentUserEmail returns the authenticated user's email (works for any domain). +func currentUserEmail(ctx context.Context, w *databricks.WorkspaceClient) (string, error) { + me, err := w.CurrentUser.Me(ctx, iam.MeRequest{}) + if err != nil { + return "", fmt.Errorf("failed to resolve current user: %w", err) + } + return me.UserName, nil +} + +// userWorkspaceDir returns the user's workspace home, honoring the env override. +func userWorkspaceDir(ctx context.Context, w *databricks.WorkspaceClient) (string, error) { + if override := env.Get(ctx, userWorkspaceDirEnv); override != "" { + return override, nil + } + email, err := currentUserEmail(ctx, w) + if err != nil { + return "", err + } + return "/Workspace/Users/" + email, nil +} + +// cliLaunchDir returns a unique workspace directory for a run's launch artifacts: +// /.air/cli_launch//_. run defaults to experiment. +func cliLaunchDir(base, experiment, run string) string { + if run == "" { + run = experiment + } + unique := strings.ReplaceAll(uuid.NewString(), "-", "")[:16] + return path.Join(base, ".air", "cli_launch", experiment, run+"_"+unique) +} + +// mlflowExperimentName builds the full MLflow experiment path. A custom directory +// is used as-is; otherwise it defaults under the user's home. +func mlflowExperimentName(experiment, experimentDir, userEmail string) string { + if experimentDir != "" { + return strings.TrimRight(experimentDir, "/") + "/" + experiment + } + return "/Users/" + userEmail + "/" + experiment +} + +// ensureExperimentDirectory creates experimentDir if it is missing, matching the +// CLI's convention for its other artifact directories. Without this, a missing +// parent surfaces only as a server-side INTERNAL_ERROR after the run is wasted. +// An empty dir means the default (/Users//...), which always exists. +func ensureExperimentDirectory(ctx context.Context, w *databricks.WorkspaceClient, experimentDir string) error { + if experimentDir == "" { + return nil + } + + info, err := w.Workspace.GetStatusByPath(ctx, experimentDir) + if errors.Is(err, apierr.ErrNotFound) { + return w.Workspace.MkdirsByPath(ctx, experimentDir) + } + if err != nil { + return fmt.Errorf("failed to check experiment_directory %q: %w", experimentDir, err) + } + if info.ObjectType != workspace.ObjectTypeDirectory { + return fmt.Errorf("experiment_directory %q is not a directory (object_type=%s)", experimentDir, info.ObjectType) + } + return nil +} diff --git a/experimental/air/cmd/runlaunch_test.go b/experimental/air/cmd/runlaunch_test.go new file mode 100644 index 0000000000..df8c3a087a --- /dev/null +++ b/experimental/air/cmd/runlaunch_test.go @@ -0,0 +1,71 @@ +package aircmd + +import ( + "strings" + "testing" + + "github.com/databricks/cli/libs/filer" + "github.com/databricks/cli/libs/testserver" + "github.com/databricks/databricks-sdk-go" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestMlflowExperimentName(t *testing.T) { + assert.Equal(t, "/Users/me@example.com/exp", mlflowExperimentName("exp", "", "me@example.com")) + assert.Equal(t, "/Workspace/shared/exp", mlflowExperimentName("exp", "/Workspace/shared", "me@example.com")) + assert.Equal(t, "/Workspace/shared/exp", mlflowExperimentName("exp", "/Workspace/shared/", "me@example.com")) +} + +func TestCliLaunchDir(t *testing.T) { + dir := cliLaunchDir("/Workspace/Users/me@example.com", "my-exp", "") + assert.True(t, strings.HasPrefix(dir, "/Workspace/Users/me@example.com/.air/cli_launch/my-exp/my-exp_"), dir) + // run name overrides the leaf; the unique suffix keeps successive dirs distinct. + withRun := cliLaunchDir("/base", "exp", "run1") + assert.True(t, strings.HasPrefix(withRun, "/base/.air/cli_launch/exp/run1_"), withRun) + assert.NotEqual(t, dir, cliLaunchDir("/Workspace/Users/me@example.com", "my-exp", "")) +} + +func newFakeWorkspaceClient(t *testing.T) *databricks.WorkspaceClient { + server := testserver.New(t) + t.Cleanup(server.Close) + testserver.AddDefaultHandlers(server) + w, err := databricks.NewWorkspaceClient(&databricks.Config{Host: server.URL, Token: "token"}) + require.NoError(t, err) + return w +} + +func TestUserWorkspaceDir(t *testing.T) { + w := newFakeWorkspaceClient(t) + dir, err := userWorkspaceDir(t.Context(), w) + require.NoError(t, err) + assert.True(t, strings.HasPrefix(dir, "/Workspace/Users/"), dir) + + // The env override wins without an API call. + t.Setenv(userWorkspaceDirEnv, "/Workspace/custom") + dir, err = userWorkspaceDir(t.Context(), w) + require.NoError(t, err) + assert.Equal(t, "/Workspace/custom", dir) +} + +func TestEnsureExperimentDirectory(t *testing.T) { + ctx := t.Context() + w := newFakeWorkspaceClient(t) + + // Empty means default (always exists) — no API call, no error. + require.NoError(t, ensureExperimentDirectory(ctx, w, "")) + + // A missing path is created. + require.NoError(t, ensureExperimentDirectory(ctx, w, "/Workspace/Users/me/exp")) + + // An existing directory is accepted as-is. + require.NoError(t, w.Workspace.MkdirsByPath(ctx, "/Workspace/Users/me/existing")) + require.NoError(t, ensureExperimentDirectory(ctx, w, "/Workspace/Users/me/existing")) + + // A path that exists but is a file is rejected. + fc, err := filer.NewWorkspaceFilesClient(w, "/Workspace/Users/me") + require.NoError(t, err) + require.NoError(t, fc.Write(ctx, "afile", strings.NewReader("x"))) + err = ensureExperimentDirectory(ctx, w, "/Workspace/Users/me/afile") + require.ErrorContains(t, err, "is not a directory") +} From 185f533a05bdfca7dafc6dfd1561a65209d30c46 Mon Sep 17 00:00:00 2001 From: riddhibhagwat-db Date: Mon, 22 Jun 2026 23:18:02 +0000 Subject: [PATCH 4/5] experimental/air: upload run launch artifacts Assemble and upload the launch artifacts for a run into its cli_launch directory: the merged config (training_config.yaml, 1 MB cap), the inline command as command.sh, requirements.yaml (from a file or synthesized from inline dependencies), and hyperparameters.yaml. buildArtifacts is pure; the upload writes through a narrow fileWriter (a workspace filer in production). A TODO(DABs) marks the client-side upload path as a future candidate for reuse of DABs' file-staging (libs/sync / bundle deploy). Co-authored-by: Isaac --- experimental/air/cmd/runupload.go | 114 +++++++++++++++++++++ experimental/air/cmd/runupload_test.go | 135 +++++++++++++++++++++++++ 2 files changed, 249 insertions(+) create mode 100644 experimental/air/cmd/runupload.go create mode 100644 experimental/air/cmd/runupload_test.go diff --git a/experimental/air/cmd/runupload.go b/experimental/air/cmd/runupload.go new file mode 100644 index 0000000000..c9c4047381 --- /dev/null +++ b/experimental/air/cmd/runupload.go @@ -0,0 +1,114 @@ +package aircmd + +import ( + "bytes" + "context" + "fmt" + "io" + "os" + "path/filepath" + + "github.com/databricks/cli/libs/filer" + "go.yaml.in/yaml/v3" +) + +// Launch artifact basenames, uploaded into the run's cli_launch directory. The +// server-side launcher derives requirements.yaml / hyperparameters.yaml from the +// same directory, so these names are part of the contract. +const ( + trainingConfigName = "training_config.yaml" + commandScriptName = "command.sh" + requirementsName = "requirements.yaml" + hyperparametersName = "hyperparameters.yaml" +) + +// maxConfigYAMLBytes caps training_config.yaml. It is referenced by the Jobs +// payload and rendered on the run page, so an oversized parameters/command block +// is rejected here; full parameters still ship in hyperparameters.yaml. +const maxConfigYAMLBytes = 1024 * 1024 + +// uploadItem is a single artifact to write into the launch directory. +type uploadItem struct { + name string + data []byte +} + +// fileWriter is the subset of filer.Filer the upload path needs; a narrow +// interface keeps buildArtifacts/upload testable without a live workspace. +type fileWriter interface { + Write(ctx context.Context, name string, reader io.Reader, mode ...filer.WriteMode) error +} + +// requirementsDoc mirrors the on-disk requirements.yaml format so the worker +// parses synthesized inline dependencies identically to a user-provided file. +type requirementsDoc struct { + Version string `yaml:"version,omitempty"` + Dependencies []string `yaml:"dependencies"` +} + +// buildArtifacts assembles the files to upload for a run: the merged config, the +// inline command as a script, requirements (from a file or synthesized from +// inline dependencies), and hyperparameters. configPath is the local YAML path. +func buildArtifacts(cfg *runConfig, configPath string) ([]uploadItem, error) { + // TODO(DABs): with no _bases_/overrides ported yet, the merged config is the + // file as-is; once those land, upload the re-serialized merged YAML instead. + configData, err := os.ReadFile(configPath) + if err != nil { + return nil, fmt.Errorf("failed to read config %s: %w", configPath, err) + } + if len(configData) > maxConfigYAMLBytes { + return nil, fmt.Errorf("config YAML is %.2f MB, over the %d MB limit; reduce 'parameters' or 'command'", + float64(len(configData))/(1024*1024), maxConfigYAMLBytes/(1024*1024)) + } + + items := []uploadItem{ + {trainingConfigName, configData}, + {commandScriptName, []byte(*cfg.Command)}, + } + + switch reqPath, ok := cfg.requirementsFile(); { + case ok: + // Resolve a relative requirements path against the config's directory. + if !filepath.IsAbs(reqPath) { + reqPath = filepath.Join(filepath.Dir(configPath), reqPath) + } + data, err := os.ReadFile(reqPath) + if err != nil { + return nil, fmt.Errorf("failed to read requirements file %s: %w", reqPath, err) + } + items = append(items, uploadItem{requirementsName, data}) + default: + if deps, ok := cfg.inlineDependencies(); ok { + version, _ := cfg.runtimeVersion() + data, err := yaml.Marshal(requirementsDoc{Version: version, Dependencies: deps}) + if err != nil { + return nil, fmt.Errorf("failed to synthesize requirements.yaml: %w", err) + } + items = append(items, uploadItem{requirementsName, data}) + } + } + + if len(cfg.Parameters) > 0 { + data, err := yaml.Marshal(cfg.Parameters) + if err != nil { + return nil, fmt.Errorf("failed to serialize parameters: %w", err) + } + items = append(items, uploadItem{hyperparametersName, data}) + } + + return items, nil +} + +// uploadArtifacts writes each artifact into the launch directory, overwriting and +// creating parents as needed. +// +// TODO(DABs): this client-side upload could move onto libs/sync / a bundle deploy +// so the CLI reuses DABs' file-staging machinery instead of writing files itself. +func uploadArtifacts(ctx context.Context, w fileWriter, items []uploadItem) error { + for _, it := range items { + if err := w.Write(ctx, it.name, bytes.NewReader(it.data), filer.OverwriteIfExists, filer.CreateParentDirectories); err != nil { + return fmt.Errorf("failed to upload %s: %w", it.name, err) + } + } + return nil +} diff --git a/experimental/air/cmd/runupload_test.go b/experimental/air/cmd/runupload_test.go new file mode 100644 index 0000000000..ec700a9229 --- /dev/null +++ b/experimental/air/cmd/runupload_test.go @@ -0,0 +1,135 @@ +package aircmd + +import ( + "context" + "errors" + "io" + "os" + "path/filepath" + "strings" + "testing" + + "github.com/databricks/cli/libs/filer" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// fakeWriter records artifact writes in place of a workspace filer. +type fakeWriter struct { + written map[string]string +} + +func (f *fakeWriter) Write(ctx context.Context, name string, reader io.Reader, mode ...filer.WriteMode) error { + if f.written == nil { + f.written = map[string]string{} + } + data, err := io.ReadAll(reader) + if err != nil { + return err + } + f.written[name] = string(data) + return nil +} + +func writeConfigFile(t *testing.T, name, content string) string { + t.Helper() + path := filepath.Join(t.TempDir(), name) + require.NoError(t, os.WriteFile(path, []byte(content), 0o600)) + return path +} + +func itemNames(items []uploadItem) []string { + names := make([]string, len(items)) + for i, it := range items { + names[i] = it.name + } + return names +} + +func TestBuildArtifacts_CommandAndConfig(t *testing.T) { + path := writeConfigFile(t, "run.yaml", minimalConfig) + cfg := &runConfig{Command: new("python train.py")} + + items, err := buildArtifacts(cfg, path) + require.NoError(t, err) + assert.Equal(t, []string{trainingConfigName, commandScriptName}, itemNames(items)) + assert.Equal(t, minimalConfig, string(items[0].data)) + assert.Equal(t, "python train.py", string(items[1].data)) +} + +func TestBuildArtifacts_InlineRequirementsAndParameters(t *testing.T) { + path := writeConfigFile(t, "run.yaml", "x: y\n") + cfg := &runConfig{ + Command: new("echo hi"), + Environment: &environmentConfig{ + Dependencies: dependencies{set: true, isList: true, list: []string{"torch", "numpy"}}, + Version: stringOrInt{set: true, raw: "5"}, + }, + Parameters: map[string]any{"lr": 0.1}, + } + + items, err := buildArtifacts(cfg, path) + require.NoError(t, err) + assert.Equal(t, []string{trainingConfigName, commandScriptName, requirementsName, hyperparametersName}, itemNames(items)) + + var reqIdx int + for i, it := range items { + if it.name == requirementsName { + reqIdx = i + } + } + req := string(items[reqIdx].data) + assert.Contains(t, req, "version: \"5\"") + assert.Contains(t, req, "- torch") +} + +func TestBuildArtifacts_RequirementsFile(t *testing.T) { + dir := t.TempDir() + require.NoError(t, os.WriteFile(filepath.Join(dir, "run.yaml"), []byte("x: y\n"), 0o600)) + require.NoError(t, os.WriteFile(filepath.Join(dir, "reqs.yaml"), []byte("version: 4\n"), 0o600)) + cfg := &runConfig{ + Command: new("echo hi"), + Environment: &environmentConfig{Dependencies: dependencies{set: true, isList: false, path: "reqs.yaml"}}, + } + + items, err := buildArtifacts(cfg, filepath.Join(dir, "run.yaml")) + require.NoError(t, err) + assert.Contains(t, itemNames(items), requirementsName) +} + +func TestBuildArtifacts_OversizeConfigRejected(t *testing.T) { + path := writeConfigFile(t, "run.yaml", strings.Repeat("a", maxConfigYAMLBytes+1)) + _, err := buildArtifacts(&runConfig{Command: new("x")}, path) + require.Error(t, err) + assert.Contains(t, err.Error(), "over the 1 MB limit") +} + +func TestUploadArtifacts(t *testing.T) { + w := &fakeWriter{} + items := []uploadItem{{trainingConfigName, []byte("cfg")}, {commandScriptName, []byte("cmd")}} + require.NoError(t, uploadArtifacts(t.Context(), w, items)) + assert.Equal(t, "cfg", w.written[trainingConfigName]) + assert.Equal(t, "cmd", w.written[commandScriptName]) +} + +// errWriter fails every Write, exercising the upload error path. +type errWriter struct{} + +func (errWriter) Write(ctx context.Context, name string, reader io.Reader, mode ...filer.WriteMode) error { + return errors.New("boom") +} + +func TestUploadArtifacts_WriteError(t *testing.T) { + err := uploadArtifacts(t.Context(), errWriter{}, []uploadItem{{trainingConfigName, []byte("x")}}) + require.ErrorContains(t, err, "failed to upload "+trainingConfigName) +} + +func TestBuildArtifacts_MissingRequirementsFile(t *testing.T) { + cfgPath := writeConfigFile(t, "run.yaml", "x: y\n") + cfg := &runConfig{ + Command: new("echo hi"), + Environment: &environmentConfig{Dependencies: dependencies{set: true, isList: false, path: "nope.yaml"}}, + } + _, err := buildArtifacts(cfg, cfgPath) + require.ErrorContains(t, err, "failed to read requirements file") +} From a5d851b7080fba99ef9374317e06be2add30bd29 Mon Sep 17 00:00:00 2001 From: riddhibhagwat-db Date: Mon, 22 Jun 2026 23:30:39 +0000 Subject: [PATCH 5/5] experimental/air: assemble and submit a training run MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Wire `air run` end to end: ensure the experiment directory, upload launch artifacts, build the native ai_runtime_task payload, and submit it via a direct POST to /api/2.2/jobs/runs/submit. The ai_runtime_task routes straight to the training service with no genai-mapi forwarding — the MAPI path is deprecated. The proto is lean: env vars and secrets are staged as co-located env_vars.json / secret_env_vars.json workspace files rather than inline, and requirements / hyperparameters are derived server-side from the command directory. The non-dry-run path resolves the workspace context, uploads, submits, and prints the run id + dashboard URL. usage_policy_name, code_source snapshots, and --watch are rejected with clear errors until their phases land. environment.docker_image is accepted by the schema as scaffolding but not conveyed (the native path has no docker field). Co-authored-by: Isaac --- experimental/air/cmd/run.go | 23 ++- experimental/air/cmd/runlaunch.go | 9 - experimental/air/cmd/runlaunch_test.go | 6 - experimental/air/cmd/runsubmit.go | 244 +++++++++++++++++++++++++ experimental/air/cmd/runsubmit_test.go | 145 +++++++++++++++ experimental/air/cmd/runupload.go | 56 ++++++ experimental/air/cmd/runupload_test.go | 20 ++ 7 files changed, 485 insertions(+), 18 deletions(-) create mode 100644 experimental/air/cmd/runsubmit.go create mode 100644 experimental/air/cmd/runsubmit_test.go diff --git a/experimental/air/cmd/run.go b/experimental/air/cmd/run.go index 95bf360b83..bd32810e9b 100644 --- a/experimental/air/cmd/run.go +++ b/experimental/air/cmd/run.go @@ -3,8 +3,10 @@ package aircmd import ( "errors" "fmt" + "strconv" "github.com/databricks/cli/cmd/root" + "github.com/databricks/cli/libs/cmdctx" "github.com/databricks/cli/libs/cmdio" "github.com/databricks/cli/libs/flags" "github.com/spf13/cobra" @@ -55,11 +57,14 @@ The workload is described by a YAML config file (see --file).`, cmd.RunE = func(cmd *cobra.Command, args []string) error { ctx := cmd.Context() - // --override is parsed and applied before validation; that pipeline is - // not ported yet, so reject it rather than silently ignore the flag. + // These flags' pipelines are not ported yet; reject rather than silently + // ignore them. if len(overrides) > 0 { return errors.New("--override is not yet supported") } + if watch { + return errors.New("--watch is not yet supported") + } cfg, err := loadRunConfig(file) if err != nil { @@ -74,7 +79,19 @@ The workload is described by a YAML config file (see --file).`, return renderEnvelope(ctx, runResult{Status: "DRY_RUN_OK", DryRun: true}) } - return notImplemented("run submission") + w := cmdctx.WorkspaceClient(ctx) + runID, dashboardURL, err := submitWorkload(ctx, w, cfg, file, idempotencyKey) + if err != nil { + return err + } + + runIDStr := strconv.FormatInt(runID, 10) + if root.OutputType(cmd) == flags.OutputText { + cmdio.LogString(ctx, "Submitted run "+runIDStr) + cmdio.LogString(ctx, "View at: "+dashboardURL) + return nil + } + return renderEnvelope(ctx, runResult{Status: "SUBMITTED", RunID: runIDStr, DashboardURL: dashboardURL}) } return cmd diff --git a/experimental/air/cmd/runlaunch.go b/experimental/air/cmd/runlaunch.go index 163655d9dd..b2a7215e66 100644 --- a/experimental/air/cmd/runlaunch.go +++ b/experimental/air/cmd/runlaunch.go @@ -50,15 +50,6 @@ func cliLaunchDir(base, experiment, run string) string { return path.Join(base, ".air", "cli_launch", experiment, run+"_"+unique) } -// mlflowExperimentName builds the full MLflow experiment path. A custom directory -// is used as-is; otherwise it defaults under the user's home. -func mlflowExperimentName(experiment, experimentDir, userEmail string) string { - if experimentDir != "" { - return strings.TrimRight(experimentDir, "/") + "/" + experiment - } - return "/Users/" + userEmail + "/" + experiment -} - // ensureExperimentDirectory creates experimentDir if it is missing, matching the // CLI's convention for its other artifact directories. Without this, a missing // parent surfaces only as a server-side INTERNAL_ERROR after the run is wasted. diff --git a/experimental/air/cmd/runlaunch_test.go b/experimental/air/cmd/runlaunch_test.go index df8c3a087a..af6f0f70d3 100644 --- a/experimental/air/cmd/runlaunch_test.go +++ b/experimental/air/cmd/runlaunch_test.go @@ -11,12 +11,6 @@ import ( "github.com/stretchr/testify/require" ) -func TestMlflowExperimentName(t *testing.T) { - assert.Equal(t, "/Users/me@example.com/exp", mlflowExperimentName("exp", "", "me@example.com")) - assert.Equal(t, "/Workspace/shared/exp", mlflowExperimentName("exp", "/Workspace/shared", "me@example.com")) - assert.Equal(t, "/Workspace/shared/exp", mlflowExperimentName("exp", "/Workspace/shared/", "me@example.com")) -} - func TestCliLaunchDir(t *testing.T) { dir := cliLaunchDir("/Workspace/Users/me@example.com", "my-exp", "") assert.True(t, strings.HasPrefix(dir, "/Workspace/Users/me@example.com/.air/cli_launch/my-exp/my-exp_"), dir) diff --git a/experimental/air/cmd/runsubmit.go b/experimental/air/cmd/runsubmit.go new file mode 100644 index 0000000000..08f7c99399 --- /dev/null +++ b/experimental/air/cmd/runsubmit.go @@ -0,0 +1,244 @@ +package aircmd + +import ( + "context" + "errors" + "net/http" + "path" + "strconv" + "strings" + + "github.com/databricks/cli/libs/auth" + "github.com/databricks/cli/libs/env" + "github.com/databricks/cli/libs/filer" + "github.com/databricks/databricks-sdk-go" + "github.com/databricks/databricks-sdk-go/client" + "github.com/google/uuid" +) + +// jobsRunsSubmitPath is the Jobs one-time-run endpoint. air builds the full +// payload and POSTs it here directly — the native ai_runtime_task is not modeled +// by the typed SDK, and we want no genai-mapi forwarding. +const jobsRunsSubmitPath = "/api/2.2/jobs/runs/submit" + +// dlRuntimeImageEnv overrides the default deep-learning runtime image. +const dlRuntimeImageEnv = "DATABRICKS_DL_RUNTIME_IMAGE" + +const defaultDlRuntimeImage = "CLIENT-GPU-4" + +// aiRuntimeEnvironmentKey ties the task to the serverless environment that +// carries the runtime channel. +const aiRuntimeEnvironmentKey = "default" + +// aiRuntimeCompute is a deployment's accelerator request. +type aiRuntimeCompute struct { + AcceleratorType string `json:"accelerator_type"` + AcceleratorCount int `json:"accelerator_count"` +} + +// aiRuntimeDeployment is one worker deployment of the run. +type aiRuntimeDeployment struct { + CommandPath string `json:"command_path"` + Compute aiRuntimeCompute `json:"compute"` +} + +// aiRuntimeTask is the native AI Runtime task. It routes straight to the training +// service — no genai-mapi forwarding. The proto is lean: env vars, secrets, +// requirements, and hyperparameters are staged as workspace files co-located with +// command.sh (see runupload.go), not carried inline. +type aiRuntimeTask struct { + Experiment string `json:"experiment"` + Deployments []aiRuntimeDeployment `json:"deployments"` + MlflowRun string `json:"mlflow_run,omitempty"` + MlflowExperimentDirectory string `json:"mlflow_experiment_directory,omitempty"` +} + +// environmentSpec carries the bare runtime channel ("4", "5", ...). +type environmentSpec struct { + EnvironmentVersion string `json:"environment_version"` +} + +// jobEnvironment is the serverless environment a task references for its runtime. +type jobEnvironment struct { + EnvironmentKey string `json:"environment_key"` + Spec environmentSpec `json:"spec"` +} + +// submitTask is the single task air submits: a native ai_runtime_task. +type submitTask struct { + TaskKey string `json:"task_key"` + RunIf string `json:"run_if"` + AiRuntimeTask aiRuntimeTask `json:"ai_runtime_task"` + EnvironmentKey string `json:"environment_key"` + MaxRetries int `json:"max_retries,omitempty"` + RetryOnTimeout bool `json:"retry_on_timeout,omitempty"` +} + +// jobsSubmitRun is the Jobs runs/submit payload. +type jobsSubmitRun struct { + RunName string `json:"run_name"` + TimeoutSeconds int `json:"timeout_seconds,omitempty"` + Tasks []submitTask `json:"tasks"` + Environments []jobEnvironment `json:"environments"` + BudgetPolicyID string `json:"budget_policy_id,omitempty"` + IdempotencyToken string `json:"idempotency_token,omitempty"` +} + +// dlRuntimeImage resolves the runtime channel. A config version wins; otherwise +// the env override or default applies. The CLIENT-GPU- prefix is stripped because +// the native path wants the bare channel. +func dlRuntimeImage(ctx context.Context, runtimeVersion string) string { + if runtimeVersion != "" { + return runtimeVersion + } + img := env.Get(ctx, dlRuntimeImageEnv) + if img == "" { + img = defaultDlRuntimeImage + } + return strings.TrimPrefix(img, "CLIENT-GPU-") +} + +// buildSubmitPayload assembles the runs/submit payload. commandPath is the +// workspace path of the uploaded command.sh; dlImage is the runtime channel. +func buildSubmitPayload(cfg *runConfig, commandPath, dlImage string) jobsSubmitRun { + task := aiRuntimeTask{ + Experiment: cfg.ExperimentName, + Deployments: []aiRuntimeDeployment{{ + CommandPath: commandPath, + Compute: aiRuntimeCompute{ + AcceleratorType: cfg.Compute.AcceleratorType, + AcceleratorCount: cfg.Compute.NumAccelerators, + }, + }}, + } + if cfg.MLflowRunName != nil { + task.MlflowRun = *cfg.MLflowRunName + } + if cfg.MLflowExperimentDirectory != nil { + task.MlflowExperimentDirectory = *cfg.MLflowExperimentDirectory + } + + st := submitTask{ + TaskKey: cfg.ExperimentName, + RunIf: "ALL_SUCCESS", + AiRuntimeTask: task, + EnvironmentKey: aiRuntimeEnvironmentKey, + } + // retry_on_timeout pairs with max_retries, matching the Python payload. + if r := cfg.maxRetries(); r > 0 { + st.MaxRetries = r + st.RetryOnTimeout = true + } + + return jobsSubmitRun{ + RunName: cfg.ExperimentName, + TimeoutSeconds: cfg.timeoutSeconds(), + Tasks: []submitTask{st}, + Environments: []jobEnvironment{{ + EnvironmentKey: aiRuntimeEnvironmentKey, + Spec: environmentSpec{EnvironmentVersion: dlImage}, + }}, + } +} + +// jobsSubmitClient submits one-time runs through the Jobs API. +type jobsSubmitClient struct { + c *client.DatabricksClient +} + +func newJobsSubmitClient(w *databricks.WorkspaceClient) (*jobsSubmitClient, error) { + c, err := client.New(w.Config) + if err != nil { + return nil, err + } + return &jobsSubmitClient{c: c}, nil +} + +type submitRunResponse struct { + RunID int64 `json:"run_id,omitempty"` +} + +// submit POSTs the payload to runs/submit and returns the new run_id. +func (j *jobsSubmitClient) submit(ctx context.Context, payload jobsSubmitRun) (int64, error) { + var resp submitRunResponse + if err := j.c.Do(ctx, http.MethodPost, jobsRunsSubmitPath, auth.WorkspaceIDHeaders(j.c.Config), nil, payload, &resp); err != nil { + return 0, err + } + return resp.RunID, nil +} + +// submitToken resolves the idempotency token: the --idempotency-key flag wins, +// then the config's token, else a generated one. Capped at the Jobs API's 64. +func submitToken(flag string, cfg *runConfig) string { + token := flag + if token == "" && cfg.IdempotencyToken != nil { + token = *cfg.IdempotencyToken + } + if token == "" { + token = uuid.NewString() + } + if len(token) > 64 { + token = token[:64] + } + return token +} + +// submitWorkload runs the submit happy path: ensure the experiment directory, +// upload the launch artifacts, assemble the Jobs payload, and submit it. It +// returns the new run_id and its dashboard URL. +func submitWorkload(ctx context.Context, w *databricks.WorkspaceClient, cfg *runConfig, configPath, idempotencyKey string) (int64, string, error) { + // Resolving usage_policy_name to a budget policy id and packaging a + // code_source snapshot are not ported yet; reject rather than silently drop. + if cfg.UsagePolicyName != nil { + return 0, "", errors.New("usage_policy_name is not yet supported") + } + if cfg.CodeSource != nil { + return 0, "", errors.New("code_source is not yet supported") + } + + experimentDir := "" + if cfg.MLflowExperimentDirectory != nil { + experimentDir = *cfg.MLflowExperimentDirectory + } + if err := ensureExperimentDirectory(ctx, w, experimentDir); err != nil { + return 0, "", err + } + + base, err := userWorkspaceDir(ctx, w) + if err != nil { + return 0, "", err + } + runName := "" + if cfg.MLflowRunName != nil { + runName = *cfg.MLflowRunName + } + funcDir := cliLaunchDir(base, cfg.ExperimentName, runName) + + fc, err := filer.NewWorkspaceFilesClient(w, funcDir) + if err != nil { + return 0, "", err + } + items, err := buildArtifacts(cfg, configPath) + if err != nil { + return 0, "", err + } + if err := uploadArtifacts(ctx, fc, items); err != nil { + return 0, "", err + } + + runtimeVersion, _ := cfg.runtimeVersion() + payload := buildSubmitPayload(cfg, path.Join(funcDir, commandScriptName), dlRuntimeImage(ctx, runtimeVersion)) + payload.IdempotencyToken = submitToken(idempotencyKey, cfg) + + jc, err := newJobsSubmitClient(w) + if err != nil { + return 0, "", err + } + runID, err := jc.submit(ctx, payload) + if err != nil { + return 0, "", err + } + + dashboardURL := strings.TrimRight(w.Config.Host, "/") + "/jobs/runs/" + strconv.FormatInt(runID, 10) + return runID, dashboardURL, nil +} diff --git a/experimental/air/cmd/runsubmit_test.go b/experimental/air/cmd/runsubmit_test.go new file mode 100644 index 0000000000..c43dda0fcd --- /dev/null +++ b/experimental/air/cmd/runsubmit_test.go @@ -0,0 +1,145 @@ +package aircmd + +import ( + "encoding/json" + "strings" + "testing" + + "github.com/databricks/cli/libs/testserver" + "github.com/databricks/databricks-sdk-go" + "github.com/databricks/databricks-sdk-go/config" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestDlRuntimeImage(t *testing.T) { + ctx := t.Context() + // A config runtime version wins and is used bare. + assert.Equal(t, "5", dlRuntimeImage(ctx, "5")) + // Default, with the CLIENT-GPU- prefix stripped for the GPU_* path. + assert.Equal(t, "4", dlRuntimeImage(ctx, "")) + // Env override, prefix stripped. + t.Setenv(dlRuntimeImageEnv, "CLIENT-GPU-7") + assert.Equal(t, "7", dlRuntimeImage(ctx, "")) +} + +func TestBuildSubmitPayload(t *testing.T) { + cfg := &runConfig{ + ExperimentName: "exp", + Command: new("python train.py"), + Compute: &computeConfig{AcceleratorType: "GPU_8xH100", NumAccelerators: 16}, + MaxRetries: new(2), + TimeoutMinutes: new(30), + MLflowRunName: new("run-v2"), + MLflowExperimentDirectory: new("/Workspace/Users/me/exp"), + } + + p := buildSubmitPayload(cfg, "/d/command.sh", "5") + + assert.Equal(t, "exp", p.RunName) + assert.Equal(t, 1800, p.TimeoutSeconds) + require.Len(t, p.Environments, 1) + assert.Equal(t, aiRuntimeEnvironmentKey, p.Environments[0].EnvironmentKey) + assert.Equal(t, "5", p.Environments[0].Spec.EnvironmentVersion) + + require.Len(t, p.Tasks, 1) + task := p.Tasks[0] + assert.Equal(t, "exp", task.TaskKey) + assert.Equal(t, "ALL_SUCCESS", task.RunIf) + assert.Equal(t, aiRuntimeEnvironmentKey, task.EnvironmentKey) + assert.Equal(t, 2, task.MaxRetries) + assert.True(t, task.RetryOnTimeout) + + at := task.AiRuntimeTask + assert.Equal(t, "exp", at.Experiment) + assert.Equal(t, "run-v2", at.MlflowRun) + assert.Equal(t, "/Workspace/Users/me/exp", at.MlflowExperimentDirectory) + require.Len(t, at.Deployments, 1) + assert.Equal(t, "/d/command.sh", at.Deployments[0].CommandPath) + assert.Equal(t, aiRuntimeCompute{AcceleratorType: "GPU_8xH100", AcceleratorCount: 16}, at.Deployments[0].Compute) +} + +func TestSubmitToken(t *testing.T) { + cfg := &runConfig{IdempotencyToken: new("from-config")} + assert.Equal(t, "from-flag", submitToken("from-flag", cfg)) // flag wins + assert.Equal(t, "from-config", submitToken("", cfg)) // then config + assert.NotEmpty(t, submitToken("", &runConfig{})) // else generated + assert.Len(t, submitToken(string(make([]byte, 80)), cfg), 64) // capped +} + +func TestJobsSubmitClient(t *testing.T) { + server := testserver.New(t) + t.Cleanup(server.Close) + + var got jobsSubmitRun + server.Handle("POST", "/api/2.2/jobs/runs/submit", func(req testserver.Request) any { + require.NoError(t, json.Unmarshal(req.Body, &got)) + return submitRunResponse{RunID: 999} + }) + + w := &databricks.WorkspaceClient{Config: &config.Config{Host: server.URL, Token: "token"}} + jc, err := newJobsSubmitClient(w) + require.NoError(t, err) + + runID, err := jc.submit(t.Context(), jobsSubmitRun{RunName: "exp", Tasks: []submitTask{{TaskKey: "exp"}}}) + require.NoError(t, err) + assert.Equal(t, int64(999), runID) + assert.Equal(t, "exp", got.RunName) +} + +func TestSubmitWorkload(t *testing.T) { + server := testserver.New(t) + t.Cleanup(server.Close) + testserver.AddDefaultHandlers(server) + + var got jobsSubmitRun + server.Handle("POST", "/api/2.2/jobs/runs/submit", func(req testserver.Request) any { + require.NoError(t, json.Unmarshal(req.Body, &got)) + return submitRunResponse{RunID: 777} + }) + w, err := databricks.NewWorkspaceClient(&databricks.Config{Host: server.URL, Token: "token"}) + require.NoError(t, err) + + cfgPath := writeConfigFile(t, "run.yaml", minimalConfig) + cfg, err := loadRunConfig(cfgPath) + require.NoError(t, err) + + runID, dashboardURL, err := submitWorkload(t.Context(), w, cfg, cfgPath, "idem-key") + require.NoError(t, err) + assert.Equal(t, int64(777), runID) + assert.Contains(t, dashboardURL, "/jobs/runs/777") + + // The submitted payload is a native ai_runtime_task pointing at the uploaded + // command.sh under the run's launch directory. + assert.Equal(t, "my-run", got.RunName) + assert.Equal(t, "idem-key", got.IdempotencyToken) + require.Len(t, got.Environments, 1) + require.Len(t, got.Tasks, 1) + at := got.Tasks[0].AiRuntimeTask + require.Len(t, at.Deployments, 1) + d := at.Deployments[0] + assert.True(t, strings.HasSuffix(d.CommandPath, "/"+commandScriptName), d.CommandPath) + assert.Contains(t, d.CommandPath, "/.air/cli_launch/") + assert.Equal(t, aiRuntimeCompute{AcceleratorType: "GPU_1xH100", AcceleratorCount: 1}, d.Compute) +} + +func TestSubmitWorkloadGuards(t *testing.T) { + w := newFakeWorkspaceClient(t) + cfgPath := writeConfigFile(t, "run.yaml", minimalConfig) + base, err := loadRunConfig(cfgPath) + require.NoError(t, err) + + t.Run("usage_policy_name rejected", func(t *testing.T) { + cfg := *base + cfg.UsagePolicyName = new("p") + _, _, err := submitWorkload(t.Context(), w, &cfg, cfgPath, "") + require.ErrorContains(t, err, "usage_policy_name is not yet supported") + }) + + t.Run("code_source rejected", func(t *testing.T) { + cfg := *base + cfg.CodeSource = &codeSourceConfig{Type: "snapshot"} + _, _, err := submitWorkload(t.Context(), w, &cfg, cfgPath, "") + require.ErrorContains(t, err, "code_source is not yet supported") + }) +} diff --git a/experimental/air/cmd/runupload.go b/experimental/air/cmd/runupload.go index c9c4047381..fb9ca00b98 100644 --- a/experimental/air/cmd/runupload.go +++ b/experimental/air/cmd/runupload.go @@ -3,10 +3,14 @@ package aircmd import ( "bytes" "context" + "encoding/json" "fmt" "io" + "maps" "os" "path/filepath" + "slices" + "strings" "github.com/databricks/cli/libs/filer" "go.yaml.in/yaml/v3" @@ -20,6 +24,8 @@ const ( commandScriptName = "command.sh" requirementsName = "requirements.yaml" hyperparametersName = "hyperparameters.yaml" + envVarsName = "env_vars.json" + secretEnvVarsName = "secret_env_vars.json" ) // maxConfigYAMLBytes caps training_config.yaml. It is referenced by the Jobs @@ -96,9 +102,59 @@ func buildArtifacts(cfg *runConfig, configPath string) ([]uploadItem, error) { items = append(items, uploadItem{hyperparametersName, data}) } + // The ai_runtime_task proto carries no inline env vars or secrets; stage them + // as JSON files co-located with command.sh for the server-side launcher. + if len(cfg.EnvVariables) > 0 { + data, err := json.Marshal(envVarEntries(cfg.EnvVariables)) + if err != nil { + return nil, fmt.Errorf("failed to serialize env_variables: %w", err) + } + items = append(items, uploadItem{envVarsName, data}) + } + if len(cfg.Secrets) > 0 { + data, err := json.Marshal(secretEnvVarEntries(cfg.Secrets)) + if err != nil { + return nil, fmt.Errorf("failed to serialize secrets: %w", err) + } + items = append(items, uploadItem{secretEnvVarsName, data}) + } + return items, nil } +// envVarEntry is one entry in env_vars.json. +type envVarEntry struct { + Name string `json:"name"` + Value string `json:"value"` +} + +// secretEnvVarEntry is one entry in secret_env_vars.json. The YAML side is +// {ENV_VAR: "scope/key"}; the launcher wants the split form. +type secretEnvVarEntry struct { + Name string `json:"name"` + SecretScope string `json:"secret_scope"` + SecretKey string `json:"secret_key"` +} + +// envVarEntries renders env_variables sorted by name for deterministic output. +func envVarEntries(vars map[string]string) []envVarEntry { + out := make([]envVarEntry, 0, len(vars)) + for _, name := range slices.Sorted(maps.Keys(vars)) { + out = append(out, envVarEntry{Name: name, Value: vars[name]}) + } + return out +} + +// secretEnvVarEntries renders secrets sorted by name for deterministic output. +func secretEnvVarEntries(secrets map[string]string) []secretEnvVarEntry { + out := make([]secretEnvVarEntry, 0, len(secrets)) + for _, name := range slices.Sorted(maps.Keys(secrets)) { + scope, key, _ := strings.Cut(secrets[name], "/") + out = append(out, secretEnvVarEntry{Name: name, SecretScope: scope, SecretKey: key}) + } + return out +} + // uploadArtifacts writes each artifact into the launch directory, overwriting and // creating parents as needed. // diff --git a/experimental/air/cmd/runupload_test.go b/experimental/air/cmd/runupload_test.go index ec700a9229..0c87524735 100644 --- a/experimental/air/cmd/runupload_test.go +++ b/experimental/air/cmd/runupload_test.go @@ -83,6 +83,26 @@ func TestBuildArtifacts_InlineRequirementsAndParameters(t *testing.T) { assert.Contains(t, req, "- torch") } +func TestBuildArtifacts_EnvVarsAndSecrets(t *testing.T) { + path := writeConfigFile(t, "run.yaml", "x: y\n") + cfg := &runConfig{ + Command: new("echo hi"), + EnvVariables: map[string]string{"WANDB": "demo"}, + Secrets: map[string]string{"HF_TOKEN": "myscope/hf"}, + } + + items, err := buildArtifacts(cfg, path) + require.NoError(t, err) + assert.Subset(t, itemNames(items), []string{envVarsName, secretEnvVarsName}) + + byName := map[string][]byte{} + for _, it := range items { + byName[it.name] = it.data + } + assert.JSONEq(t, `[{"name":"WANDB","value":"demo"}]`, string(byName[envVarsName])) + assert.JSONEq(t, `[{"name":"HF_TOKEN","secret_scope":"myscope","secret_key":"hf"}]`, string(byName[secretEnvVarsName])) +} + func TestBuildArtifacts_RequirementsFile(t *testing.T) { dir := t.TempDir() require.NoError(t, os.WriteFile(filepath.Join(dir, "run.yaml"), []byte("x: y\n"), 0o600))