diff --git a/cmd/obol/model.go b/cmd/obol/model.go index c16c062..901d67a 100644 --- a/cmd/obol/model.go +++ b/cmd/obol/model.go @@ -23,7 +23,7 @@ func modelCommand(cfg *config.Config) *cli.Command { Flags: []cli.Flag{ &cli.StringFlag{ Name: "provider", - Usage: "Provider name (anthropic, openai)", + Usage: "Provider name (e.g. anthropic, openai, zai, deepseek)", }, &cli.StringFlag{ Name: "api-key", @@ -38,7 +38,7 @@ func modelCommand(cfg *config.Config) *cli.Command { // Interactive mode if flags not provided if provider == "" || apiKey == "" { var err error - provider, apiKey, err = promptModelConfig() + provider, apiKey, err = promptModelConfig(cfg) if err != nil { return err } @@ -64,7 +64,7 @@ func modelCommand(cfg *config.Config) *cli.Command { fmt.Println("Global llmspy providers:") fmt.Println() - fmt.Printf(" %-12s %-8s %-10s %s\n", "PROVIDER", "ENABLED", "API KEY", "ENV VAR") + fmt.Printf(" %-20s %-8s %-10s %s\n", "PROVIDER", "ENABLED", "API KEY", "ENV VAR") for _, name := range providers { s := status[name] key := "n/a" @@ -75,8 +75,12 @@ func modelCommand(cfg *config.Config) *cli.Command { key = "missing" } } - fmt.Printf(" %-12s %-8t %-10s %s\n", name, s.Enabled, key, s.EnvVar) + fmt.Printf(" %-20s %-8t %-10s %s\n", name, s.Enabled, key, s.EnvVar) } + + // Show hint about available providers + fmt.Println() + fmt.Println("Run 'obol model setup' to configure a provider.") return nil }, }, @@ -85,13 +89,23 @@ func modelCommand(cfg *config.Config) *cli.Command { } // promptModelConfig interactively asks the user for provider and API key. -func promptModelConfig() (string, string, error) { +// It queries the running llmspy pod for available providers. +func promptModelConfig(cfg *config.Config) (string, string, error) { + providers, err := model.GetAvailableProviders(cfg) + if err != nil { + return "", "", fmt.Errorf("failed to discover providers: %w", err) + } + if len(providers) == 0 { + return "", "", fmt.Errorf("no cloud providers found in llmspy") + } + reader := bufio.NewReader(os.Stdin) - fmt.Println("Select a provider:") - fmt.Println(" [1] Anthropic") - fmt.Println(" [2] OpenAI") - fmt.Print("\nChoice [1]: ") + fmt.Println("Available providers:") + for i, p := range providers { + fmt.Printf(" [%d] %s (%s)\n", i+1, p.Name, p.ID) + } + fmt.Printf("\nChoice [1]: ") line, _ := reader.ReadString('\n') choice := strings.TrimSpace(line) @@ -99,24 +113,18 @@ func promptModelConfig() (string, string, error) { choice = "1" } - var provider, display string - switch choice { - case "1": - provider = "anthropic" - display = "Anthropic" - case "2": - provider = "openai" - display = "OpenAI" - default: - return "", "", fmt.Errorf("unknown choice: %s", choice) + idx := 0 + if _, err := fmt.Sscanf(choice, "%d", &idx); err != nil || idx < 1 || idx > len(providers) { + return "", "", fmt.Errorf("invalid choice: %s", choice) } + selected := providers[idx-1] - fmt.Printf("\n%s API key: ", display) + fmt.Printf("\n%s API key (%s): ", selected.Name, selected.EnvVar) apiKey, _ := reader.ReadString('\n') apiKey = strings.TrimSpace(apiKey) if apiKey == "" { return "", "", fmt.Errorf("API key is required") } - return provider, apiKey, nil + return selected.ID, apiKey, nil } diff --git a/internal/model/model.go b/internal/model/model.go index 0399056..6ce25f5 100644 --- a/internal/model/model.go +++ b/internal/model/model.go @@ -19,10 +19,11 @@ const ( deployName = "llmspy" ) -// providerEnvKeys maps provider names to their Secret key names. -var providerEnvKeys = map[string]string{ - "anthropic": "ANTHROPIC_API_KEY", - "openai": "OPENAI_API_KEY", +// ProviderInfo describes an llmspy provider discovered from the running pod. +type ProviderInfo struct { + ID string // provider id (e.g. "zai", "anthropic") + Name string // display name (e.g. "Z.AI", "Anthropic") + EnvVar string // env var for API key (e.g. "ZHIPU_API_KEY") } // ProviderStatus captures effective global llmspy provider state. @@ -33,14 +34,10 @@ type ProviderStatus struct { } // ConfigureLLMSpy enables a cloud provider in the llmspy gateway. -// It patches the llms-secrets Secret with the API key, enables the provider +// It discovers the provider's env var from the running llmspy pod, +// patches the llms-secrets Secret with the API key, enables the provider // in the llmspy-config ConfigMap, and restarts the deployment. func ConfigureLLMSpy(cfg *config.Config, provider, apiKey string) error { - envKey, ok := providerEnvKeys[provider] - if !ok { - return fmt.Errorf("unsupported llmspy provider: %s (supported: anthropic, openai)", provider) - } - kubectlBinary := filepath.Join(cfg.BinDir, "kubectl") kubeconfigPath := filepath.Join(cfg.ConfigDir, "kubeconfig.yaml") @@ -48,6 +45,12 @@ func ConfigureLLMSpy(cfg *config.Config, provider, apiKey string) error { return fmt.Errorf("cluster not running. Run 'obol stack up' first") } + // Discover the env var name from the llmspy pod's providers.json + envKey, err := getProviderEnvKey(kubectlBinary, kubeconfigPath, provider) + if err != nil { + return err + } + // 1. Patch the Secret with the API key fmt.Printf("Configuring llmspy: setting %s key...\n", provider) patchJSON := fmt.Sprintf(`{"stringData":{"%s":"%s"}}`, envKey, apiKey) @@ -83,7 +86,86 @@ func ConfigureLLMSpy(cfg *config.Config, provider, apiKey string) error { return nil } -// GetProviderStatus reads llmspy ConfigMap + Secret and returns global provider status. +// getProviderEnvKey queries the llmspy pod for the env var name a provider uses. +// It reads the merged providers.json inside the pod (package defaults + ConfigMap overrides). +func getProviderEnvKey(kubectlBinary, kubeconfigPath, provider string) (string, error) { + script := fmt.Sprintf(`import json +with open('/home/llms/.llms/providers.json') as f: + d = json.load(f) +p = d.get('%s') +if p and p.get('env'): + print(p['env'][0]) +`, provider) + + output, err := kubectlOutput(kubectlBinary, kubeconfigPath, + "exec", "-n", namespace, fmt.Sprintf("deploy/%s", deployName), "--", + "python3", "-c", script) + if err != nil { + return "", fmt.Errorf("failed to query llmspy for provider %q: %w", provider, err) + } + return parseProviderEnvKey(provider, output) +} + +// parseProviderEnvKey extracts an env var name from kubectl exec output. +func parseProviderEnvKey(provider, output string) (string, error) { + envKey := strings.TrimSpace(output) + if envKey == "" { + return "", fmt.Errorf("unknown provider %q — run 'obol model status' to see available providers", provider) + } + return envKey, nil +} + +// GetAvailableProviders queries the llmspy pod for all providers that accept an API key. +func GetAvailableProviders(cfg *config.Config) ([]ProviderInfo, error) { + kubectlBinary := filepath.Join(cfg.BinDir, "kubectl") + kubeconfigPath := filepath.Join(cfg.ConfigDir, "kubeconfig.yaml") + if _, err := os.Stat(kubeconfigPath); os.IsNotExist(err) { + return nil, fmt.Errorf("cluster not running. Run 'obol stack up' first") + } + + script := `import json +with open('/home/llms/.llms/providers.json') as f: + d = json.load(f) +for pid in sorted(d): + p = d[pid] + env = p.get('env', []) + if env: + print(pid + '\t' + p.get('name', pid) + '\t' + env[0]) +` + output, err := kubectlOutput(kubectlBinary, kubeconfigPath, + "exec", "-n", namespace, fmt.Sprintf("deploy/%s", deployName), "--", + "python3", "-c", script) + if err != nil { + return nil, fmt.Errorf("failed to query llmspy providers: %w", err) + } + + return parseAvailableProviders(output), nil +} + +// parseAvailableProviders parses tab-separated kubectl exec output into ProviderInfo slices. +func parseAvailableProviders(output string) []ProviderInfo { + trimmed := strings.TrimSpace(output) + if trimmed == "" { + return nil + } + var providers []ProviderInfo + for _, line := range strings.Split(trimmed, "\n") { + parts := strings.SplitN(line, "\t", 3) + if len(parts) != 3 { + continue + } + providers = append(providers, ProviderInfo{ + ID: parts[0], + Name: parts[1], + EnvVar: parts[2], + }) + } + return providers +} + +// GetProviderStatus reads llmspy state and returns global provider status. +// It queries the llmspy pod for available providers and cross-references +// with the ConfigMap (enabled/disabled) and Secret (API keys). func GetProviderStatus(cfg *config.Config) (map[string]ProviderStatus, error) { kubectlBinary := filepath.Join(cfg.BinDir, "kubectl") kubeconfigPath := filepath.Join(cfg.ConfigDir, "kubeconfig.yaml") @@ -91,17 +173,47 @@ func GetProviderStatus(cfg *config.Config) (map[string]ProviderStatus, error) { return nil, fmt.Errorf("cluster not running. Run 'obol stack up' first") } + // Get all available providers from llmspy (with env var names) + available, err := GetAvailableProviders(cfg) + if err != nil { + return nil, err + } + + // Read enabled/disabled state from ConfigMap llmsRaw, err := kubectlOutput(kubectlBinary, kubeconfigPath, "get", "configmap", configMapName, "-n", namespace, "-o", "jsonpath={.data.llms\\.json}") if err != nil { return nil, err } + + // Read Secret to check which API keys are set + secretRaw, err := kubectlOutput(kubectlBinary, kubeconfigPath, + "get", "secret", secretName, "-n", namespace, "-o", "json") + if err != nil { + return nil, err + } + + return buildProviderStatus(available, []byte(llmsRaw), []byte(secretRaw)) +} + +// buildProviderStatus is the pure logic for building provider status from raw data. +// available: providers discovered from the llmspy pod +// llmsJSON: the llms.json content from the ConfigMap +// secretJSON: the full Secret JSON (with base64-encoded .data) +func buildProviderStatus(available []ProviderInfo, llmsJSON, secretJSON []byte) (map[string]ProviderStatus, error) { + envKeyByProvider := make(map[string]string) + for _, p := range available { + envKeyByProvider[p.ID] = p.EnvVar + } + var llmsConfig map[string]interface{} - if err := json.Unmarshal([]byte(llmsRaw), &llmsConfig); err != nil { + if err := json.Unmarshal(llmsJSON, &llmsConfig); err != nil { return nil, fmt.Errorf("failed to parse llms.json from ConfigMap: %w", err) } status := make(map[string]ProviderStatus) + + // Seed from ConfigMap providers (shows what's been configured) if providers, ok := llmsConfig["providers"].(map[string]interface{}); ok { for name, raw := range providers { enabled := false @@ -110,38 +222,37 @@ func GetProviderStatus(cfg *config.Config) (map[string]ProviderStatus, error) { enabled = v } } - keyEnv := providerEnvKeys[name] status[name] = ProviderStatus{ - Enabled: enabled, - // Ollama needs no API key, so it's always considered "has key". - // Cloud providers are updated below from the actual K8s Secret. + Enabled: enabled, HasAPIKey: name == "ollama", - EnvVar: keyEnv, + EnvVar: envKeyByProvider[name], } } } - secretRaw, err := kubectlOutput(kubectlBinary, kubeconfigPath, - "get", "secret", secretName, "-n", namespace, "-o", "json") - if err != nil { - return nil, err - } + // Parse Secret var secret struct { Data map[string]string `json:"data"` } - if err := json.Unmarshal([]byte(secretRaw), &secret); err != nil { + if err := json.Unmarshal(secretJSON, &secret); err != nil { return nil, fmt.Errorf("failed to parse llms secret: %w", err) } - for provider, envKey := range providerEnvKeys { - st := status[provider] - st.EnvVar = envKey - if v, ok := secret.Data[envKey]; ok && strings.TrimSpace(v) != "" { + // Cross-reference Secret keys with provider env vars + secretKeys := make(map[string]bool) + for k, v := range secret.Data { + if strings.TrimSpace(v) != "" { + secretKeys[k] = true + } + } + for name, st := range status { + if st.EnvVar != "" && secretKeys[st.EnvVar] { st.HasAPIKey = true + status[name] = st } - status[provider] = st } + // Ensure Ollama always shows if _, ok := status["ollama"]; !ok { status["ollama"] = ProviderStatus{ Enabled: true, @@ -156,45 +267,18 @@ func GetProviderStatus(cfg *config.Config) (map[string]ProviderStatus, error) { // sets providers..enabled = true, and patches the ConfigMap back. func enableProviderInConfigMap(kubectlBinary, kubeconfigPath, provider string) error { // Read current llms.json from ConfigMap - var stdout bytes.Buffer - cmd := exec.Command(kubectlBinary, "get", "configmap", configMapName, - "-n", namespace, "-o", "jsonpath={.data.llms\\.json}") - cmd.Env = append(os.Environ(), fmt.Sprintf("KUBECONFIG=%s", kubeconfigPath)) - cmd.Stdout = &stdout - var stderr bytes.Buffer - cmd.Stderr = &stderr - if err := cmd.Run(); err != nil { - return fmt.Errorf("failed to read ConfigMap: %w\n%s", err, stderr.String()) - } - - // Parse JSON - var llmsConfig map[string]interface{} - if err := json.Unmarshal(stdout.Bytes(), &llmsConfig); err != nil { - return fmt.Errorf("failed to parse llms.json: %w", err) - } - - // Set providers..enabled = true - providers, ok := llmsConfig["providers"].(map[string]interface{}) - if !ok { - providers = make(map[string]interface{}) - llmsConfig["providers"] = providers - } - - providerCfg, ok := providers[provider].(map[string]interface{}) - if !ok { - providerCfg = make(map[string]interface{}) - providers[provider] = providerCfg + raw, err := kubectlOutput(kubectlBinary, kubeconfigPath, + "get", "configmap", configMapName, "-n", namespace, "-o", "jsonpath={.data.llms\\.json}") + if err != nil { + return fmt.Errorf("failed to read ConfigMap: %w", err) } - providerCfg["enabled"] = true - // Marshal back to JSON - updated, err := json.Marshal(llmsConfig) + updated, err := patchLLMsJSON([]byte(raw), provider) if err != nil { - return fmt.Errorf("failed to marshal llms.json: %w", err) + return err } - // Patch ConfigMap - // Use strategic merge patch with the new llms.json + // Build ConfigMap patch patchData := map[string]interface{}{ "data": map[string]string{ "llms.json": string(updated), @@ -210,6 +294,30 @@ func enableProviderInConfigMap(kubectlBinary, kubeconfigPath, provider string) e "-p", string(patchJSON), "--type=merge") } +// patchLLMsJSON takes raw llms.json content and returns updated JSON +// with providers..enabled = true. +func patchLLMsJSON(llmsJSON []byte, provider string) ([]byte, error) { + var llmsConfig map[string]interface{} + if err := json.Unmarshal(llmsJSON, &llmsConfig); err != nil { + return nil, fmt.Errorf("failed to parse llms.json: %w", err) + } + + providers, ok := llmsConfig["providers"].(map[string]interface{}) + if !ok { + providers = make(map[string]interface{}) + llmsConfig["providers"] = providers + } + + providerCfg, ok := providers[provider].(map[string]interface{}) + if !ok { + providerCfg = make(map[string]interface{}) + providers[provider] = providerCfg + } + providerCfg["enabled"] = true + + return json.Marshal(llmsConfig) +} + // kubectl runs a kubectl command with the given kubeconfig and returns any error. func kubectl(binary, kubeconfig string, args ...string) error { cmd := exec.Command(binary, args...) diff --git a/internal/model/model_test.go b/internal/model/model_test.go new file mode 100644 index 0000000..24f93cb --- /dev/null +++ b/internal/model/model_test.go @@ -0,0 +1,369 @@ +package model + +import ( + "encoding/json" + "testing" +) + +func TestParseProviderEnvKey(t *testing.T) { + tests := []struct { + name string + provider string + output string + want string + wantErr bool + }{ + { + name: "anthropic", + provider: "anthropic", + output: "ANTHROPIC_API_KEY\n", + want: "ANTHROPIC_API_KEY", + }, + { + name: "zai with trailing whitespace", + provider: "zai", + output: " ZHIPU_API_KEY \n", + want: "ZHIPU_API_KEY", + }, + { + name: "empty output means unknown provider", + provider: "nosuchprovider", + output: "", + wantErr: true, + }, + { + name: "whitespace-only output means unknown provider", + provider: "nosuchprovider", + output: " \n ", + wantErr: true, + }, + { + name: "openai", + provider: "openai", + output: "OPENAI_API_KEY", + want: "OPENAI_API_KEY", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := parseProviderEnvKey(tt.provider, tt.output) + if tt.wantErr { + if err == nil { + t.Fatalf("expected error, got %q", got) + } + return + } + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if got != tt.want { + t.Errorf("got %q, want %q", got, tt.want) + } + }) + } +} + +func TestParseAvailableProviders(t *testing.T) { + tests := []struct { + name string + output string + want []ProviderInfo + }{ + { + name: "empty output", + output: "", + want: nil, + }, + { + name: "whitespace only", + output: " \n ", + want: nil, + }, + { + name: "single provider", + output: "anthropic\tAnthropic\tANTHROPIC_API_KEY\n", + want: []ProviderInfo{ + {ID: "anthropic", Name: "Anthropic", EnvVar: "ANTHROPIC_API_KEY"}, + }, + }, + { + name: "multiple providers sorted", + output: "anthropic\tAnthropic\tANTHROPIC_API_KEY\n" + + "openai\tOpenAI\tOPENAI_API_KEY\n" + + "zai\tZ.AI\tZHIPU_API_KEY\n", + want: []ProviderInfo{ + {ID: "anthropic", Name: "Anthropic", EnvVar: "ANTHROPIC_API_KEY"}, + {ID: "openai", Name: "OpenAI", EnvVar: "OPENAI_API_KEY"}, + {ID: "zai", Name: "Z.AI", EnvVar: "ZHIPU_API_KEY"}, + }, + }, + { + name: "malformed line skipped", + output: "badline\n" + "anthropic\tAnthropic\tANTHROPIC_API_KEY\n", + want: []ProviderInfo{ + {ID: "anthropic", Name: "Anthropic", EnvVar: "ANTHROPIC_API_KEY"}, + }, + }, + { + name: "tab in name preserved", + output: "deepseek\tDeepSeek\tDEEPSEEK_API_KEY\n", + want: []ProviderInfo{ + {ID: "deepseek", Name: "DeepSeek", EnvVar: "DEEPSEEK_API_KEY"}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := parseAvailableProviders(tt.output) + if len(got) != len(tt.want) { + t.Fatalf("got %d providers, want %d", len(got), len(tt.want)) + } + for i := range got { + if got[i] != tt.want[i] { + t.Errorf("provider[%d]: got %+v, want %+v", i, got[i], tt.want[i]) + } + } + }) + } +} + +func TestBuildProviderStatus(t *testing.T) { + t.Run("basic status with cloud provider key set", func(t *testing.T) { + available := []ProviderInfo{ + {ID: "anthropic", Name: "Anthropic", EnvVar: "ANTHROPIC_API_KEY"}, + {ID: "openai", Name: "OpenAI", EnvVar: "OPENAI_API_KEY"}, + {ID: "zai", Name: "Z.AI", EnvVar: "ZHIPU_API_KEY"}, + } + + llmsJSON := []byte(`{ + "providers": { + "ollama": {"enabled": true}, + "anthropic": {"enabled": true}, + "openai": {"enabled": false}, + "zai": {"enabled": true} + } + }`) + + // Secret .data values are base64 in real k8s, but our code just checks + // if the key exists and the value is non-empty (the cross-reference uses + // the raw string from the JSON — k8s returns base64 in .data). + secretJSON := []byte(`{ + "data": { + "ANTHROPIC_API_KEY": "c2stYW50LXh4eA==", + "OPENAI_API_KEY": "", + "ZHIPU_API_KEY": "ZWU1NjM5Nzk=" + } + }`) + + status, err := buildProviderStatus(available, llmsJSON, secretJSON) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + // Ollama: enabled, always has key + if s := status["ollama"]; !s.Enabled || !s.HasAPIKey { + t.Errorf("ollama: got enabled=%t hasKey=%t, want enabled=true hasKey=true", s.Enabled, s.HasAPIKey) + } + + // Anthropic: enabled, key set + if s := status["anthropic"]; !s.Enabled || !s.HasAPIKey || s.EnvVar != "ANTHROPIC_API_KEY" { + t.Errorf("anthropic: got %+v, want enabled=true hasKey=true envVar=ANTHROPIC_API_KEY", s) + } + + // OpenAI: disabled, key empty + if s := status["openai"]; s.Enabled || s.HasAPIKey || s.EnvVar != "OPENAI_API_KEY" { + t.Errorf("openai: got %+v, want enabled=false hasKey=false envVar=OPENAI_API_KEY", s) + } + + // Z.AI: enabled, key set + if s := status["zai"]; !s.Enabled || !s.HasAPIKey || s.EnvVar != "ZHIPU_API_KEY" { + t.Errorf("zai: got %+v, want enabled=true hasKey=true envVar=ZHIPU_API_KEY", s) + } + }) + + t.Run("ollama injected when missing from configmap", func(t *testing.T) { + available := []ProviderInfo{ + {ID: "anthropic", Name: "Anthropic", EnvVar: "ANTHROPIC_API_KEY"}, + } + llmsJSON := []byte(`{"providers":{"anthropic":{"enabled":false}}}`) + secretJSON := []byte(`{"data":{}}`) + + status, err := buildProviderStatus(available, llmsJSON, secretJSON) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if s, ok := status["ollama"]; !ok || !s.Enabled || !s.HasAPIKey { + t.Errorf("ollama should be injected as enabled with key; got %+v, ok=%t", s, ok) + } + }) + + t.Run("provider in configmap but not in available list gets no env var", func(t *testing.T) { + available := []ProviderInfo{} // no providers discovered + llmsJSON := []byte(`{"providers":{"mystery":{"enabled":true}}}`) + secretJSON := []byte(`{"data":{}}`) + + status, err := buildProviderStatus(available, llmsJSON, secretJSON) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if s := status["mystery"]; !s.Enabled || s.EnvVar != "" { + t.Errorf("mystery: got %+v, want enabled=true envVar=''", s) + } + }) + + t.Run("empty providers section", func(t *testing.T) { + llmsJSON := []byte(`{}`) + secretJSON := []byte(`{"data":{}}`) + + status, err := buildProviderStatus(nil, llmsJSON, secretJSON) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + // Only ollama (injected) + if len(status) != 1 { + t.Errorf("expected 1 provider (ollama), got %d", len(status)) + } + }) + + t.Run("invalid llms json", func(t *testing.T) { + _, err := buildProviderStatus(nil, []byte(`not json`), []byte(`{"data":{}}`)) + if err == nil { + t.Fatal("expected error for invalid llms.json") + } + }) + + t.Run("invalid secret json", func(t *testing.T) { + _, err := buildProviderStatus(nil, []byte(`{}`), []byte(`not json`)) + if err == nil { + t.Fatal("expected error for invalid secret JSON") + } + }) +} + +func TestPatchLLMsJSON(t *testing.T) { + t.Run("enable existing disabled provider", func(t *testing.T) { + input := []byte(`{"providers":{"anthropic":{"enabled":false},"ollama":{"enabled":true}}}`) + + got, err := patchLLMsJSON(input, "anthropic") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + var result map[string]interface{} + if err := json.Unmarshal(got, &result); err != nil { + t.Fatalf("output is not valid JSON: %v", err) + } + + providers := result["providers"].(map[string]interface{}) + anthropic := providers["anthropic"].(map[string]interface{}) + if anthropic["enabled"] != true { + t.Errorf("anthropic.enabled = %v, want true", anthropic["enabled"]) + } + + // Ollama should be untouched + ollama := providers["ollama"].(map[string]interface{}) + if ollama["enabled"] != true { + t.Errorf("ollama.enabled = %v, want true (untouched)", ollama["enabled"]) + } + }) + + t.Run("enable new provider not in config", func(t *testing.T) { + input := []byte(`{"providers":{"ollama":{"enabled":true}}}`) + + got, err := patchLLMsJSON(input, "zai") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + var result map[string]interface{} + if err := json.Unmarshal(got, &result); err != nil { + t.Fatalf("output is not valid JSON: %v", err) + } + + providers := result["providers"].(map[string]interface{}) + zai := providers["zai"].(map[string]interface{}) + if zai["enabled"] != true { + t.Errorf("zai.enabled = %v, want true", zai["enabled"]) + } + }) + + t.Run("create providers section if missing", func(t *testing.T) { + input := []byte(`{"version":"1.0"}`) + + got, err := patchLLMsJSON(input, "deepseek") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + var result map[string]interface{} + if err := json.Unmarshal(got, &result); err != nil { + t.Fatalf("output is not valid JSON: %v", err) + } + + // version preserved + if result["version"] != "1.0" { + t.Errorf("version lost: got %v", result["version"]) + } + + providers := result["providers"].(map[string]interface{}) + ds := providers["deepseek"].(map[string]interface{}) + if ds["enabled"] != true { + t.Errorf("deepseek.enabled = %v, want true", ds["enabled"]) + } + }) + + t.Run("preserves other provider fields", func(t *testing.T) { + input := []byte(`{"providers":{"anthropic":{"enabled":false,"customField":"keep"}}}`) + + got, err := patchLLMsJSON(input, "anthropic") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + var result map[string]interface{} + if err := json.Unmarshal(got, &result); err != nil { + t.Fatalf("output is not valid JSON: %v", err) + } + + providers := result["providers"].(map[string]interface{}) + anthropic := providers["anthropic"].(map[string]interface{}) + if anthropic["enabled"] != true { + t.Errorf("enabled = %v, want true", anthropic["enabled"]) + } + if anthropic["customField"] != "keep" { + t.Errorf("customField = %v, want 'keep'", anthropic["customField"]) + } + }) + + t.Run("invalid json input", func(t *testing.T) { + _, err := patchLLMsJSON([]byte(`{bad`), "anthropic") + if err == nil { + t.Fatal("expected error for invalid JSON") + } + }) + + t.Run("idempotent enable", func(t *testing.T) { + input := []byte(`{"providers":{"anthropic":{"enabled":true}}}`) + + got, err := patchLLMsJSON(input, "anthropic") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + var result map[string]interface{} + if err := json.Unmarshal(got, &result); err != nil { + t.Fatalf("output is not valid JSON: %v", err) + } + + providers := result["providers"].(map[string]interface{}) + anthropic := providers["anthropic"].(map[string]interface{}) + if anthropic["enabled"] != true { + t.Errorf("enabled = %v, want true", anthropic["enabled"]) + } + }) +} diff --git a/internal/openclaw/integration_test.go b/internal/openclaw/integration_test.go index cd72387..b7988e0 100644 --- a/internal/openclaw/integration_test.go +++ b/internal/openclaw/integration_test.go @@ -485,6 +485,46 @@ func TestIntegration_OpenAIInference(t *testing.T) { t.Logf("OpenAI response: %s", reply) } +func TestIntegration_ZaiInference(t *testing.T) { + cfg := requireCluster(t) + apiKey := requireEnvKey(t, "ZHIPU_API_KEY") + + const id = "test-zai" + t.Cleanup(func() { cleanupInstance(t, cfg, id) }) + + // Configure llmspy gateway via obol model setup — this provider was NOT in + // the old hardcoded map, so it only works with dynamic provider discovery. + t.Log("configuring llmspy via: obol model setup --provider zai") + obolRun(t, cfg, "model", "setup", "--provider", "zai", "--api-key", apiKey) + + cloud := &CloudProviderInfo{ + Name: "zai", + APIKey: apiKey, + ModelID: "glm-4-flash", + Display: "GLM-4 Flash", + } + + // Scaffold cloud overlay + deploy via obol openclaw sync + t.Logf("scaffolding OpenClaw instance %q with Z.AI via llmspy", id) + scaffoldCloudInstance(t, cfg, id, cloud) + + t.Log("deploying via: obol openclaw sync " + id) + obolRun(t, cfg, "openclaw", "sync", id) + + namespace := fmt.Sprintf("%s-%s", appName, id) + waitForPodReady(t, cfg, namespace) + + token := getGatewayToken(t, cfg, id) + t.Logf("retrieved gateway token (%d chars)", len(token)) + + baseURL := portForward(t, cfg, namespace) + agentModel := "ollama/glm-4-flash" // routed through llmspy + t.Logf("testing inference with model %s at %s", agentModel, baseURL) + + reply := chatCompletion(t, baseURL, agentModel, token) + t.Logf("Z.AI response: %s", reply) +} + func TestIntegration_MultiInstance(t *testing.T) { cfg := requireCluster(t) models := requireOllama(t)