diff --git a/benchmarks/migration-status.json b/benchmarks/migration-status.json index 4365269..1e6ea54 100644 --- a/benchmarks/migration-status.json +++ b/benchmarks/migration-status.json @@ -1,7 +1,42 @@ { "original_python_lines": 71696, - "migrated_python_lines": 15116, + "migrated_python_lines": 53813, "migrated_modules": [ + { + "module": "deps/apm_resolver", + "go_package": "internal/deps/apmresolver", + "python_lines": 918, + "status": "migrated", + "notes": "BFS dependency resolver with parallel download, cycle detection, NPM-hoisting flatten" + }, + { + "module": "deps/download_strategies", + "go_package": "internal/deps/downloadstrategies", + "python_lines": 1122, + "status": "migrated", + "notes": "DownloadDelegate: resilient HTTP GET, GitHub/ADO/GitLab/Artifactory file download, CDN fast-path" + }, + { + "module": "core/operations", + "go_package": "internal/core/operations", + "python_lines": 145, + "status": "migrated", + "notes": "Core operations facade: ConfigureClient, InstallPackage, UninstallPackage" + }, + { + "module": "models/dependency/reference", + "go_package": "internal/models/depreference", + "python_lines": 1559, + "status": "migrated", + "notes": "DependencyReference struct with full parse/canonicalize/install-path logic" + }, + { + "module": "deps/plugin_parser", + "go_package": "internal/deps/pluginparser", + "python_lines": 677, + "status": "migrated", + "notes": "Claude plugin.json parser and apm.yml synthesizer" + }, { "module": "src/apm_cli/constants.py", "go_package": "internal/constants", @@ -624,16 +659,372 @@ "python_lines": 232, "status": "migrated", "notes": "Thread-safe shared bare-clone cache" + }, + { + "module": "src/apm_cli/install/template.py", + "go_package": "internal/install/template", + "python_lines": 140, + "status": "migrated", + "notes": "" + }, + { + "module": "src/apm_cli/runtime/factory.py", + "go_package": "internal/runtime/factory", + "python_lines": 139, + "status": "migrated", + "notes": "" + }, + { + 
"module": "src/apm_cli/marketplace/registry.py", + "go_package": "internal/marketplace/registry", + "python_lines": 136, + "status": "migrated", + "notes": "" + }, + { + "module": "src/apm_cli/marketplace/git_stderr.py", + "go_package": "internal/marketplace/gitstderr", + "python_lines": 173, + "status": "migrated", + "notes": "" + }, + { + "module": "src/apm_cli/update_policy.py", + "go_package": "internal/updatepolicy", + "python_lines": 50, + "status": "migrated", + "notes": "Self-update build-time policy constants and helpers" + }, + { + "module": "src/apm_cli/output/models.py", + "go_package": "internal/output/models", + "python_lines": 136, + "status": "migrated", + "notes": "Compilation output data models: PlacementStrategy, ProjectAnalysis, CompilationResults, etc." + }, + { + "module": "src/apm_cli/integration/prompt_integrator.py", + "go_package": "internal/integration/promptintegrator", + "python_lines": 228, + "status": "migrated", + "notes": "Prompt file integration: find/copy .prompt.md files to .github/prompts/" + }, + { + "module": "src/apm_cli/integration/instruction_integrator.py", + "go_package": "internal/integration/instructionintegrator", + "python_lines": 479, + "status": "migrated", + "notes": "Instruction integration with cursor/claude/windsurf format transforms" + }, + { + "module": "src/apm_cli/core/command_logger.py", + "go_package": "internal/core/commandlogger", + "python_lines": 751, + "status": "migrated", + "notes": "CLI command logger infrastructure with Install/Command loggers" + }, + { + "module": "src/apm_cli/models/validation.py", + "go_package": "internal/models/validation", + "python_lines": 800, + "status": "migrated", + "notes": "PackageType/ValidationResult enums and DetectPackageType logic" + }, + { + "module": "src/apm_cli/core/target_detection.py", + "go_package": "internal/core/targetdetection", + "python_lines": 777, + "status": "migrated", + "notes": "Signal whitelist, detect_target v1, resolve_targets v2, 
expand_all_targets, format_provenance" + }, + { + "module": "src/apm_cli/models/apm_package.py", + "go_package": "internal/models/apmpackage", + "python_lines": 371, + "status": "migrated", + "notes": "APMPackage and PackageInfo data structs with lightweight apm.yml loader" + }, + { + "module": "src/apm_cli/marketplace/yml_schema.py", + "go_package": "internal/marketplace/ymlschema", + "python_lines": 805, + "status": "migrated", + "notes": "MarketplaceOwner, MarketplaceBuild, PackageEntry, MarketplaceConfig with YAML loader" + }, + { + "module": "src/apm_cli/policy/_help_text.py", + "go_package": "internal/policy/helptext", + "python_lines": 18, + "status": "migrated", + "notes": "Single help-text constant" + }, + { + "module": "src/apm_cli/policy/outcome_routing.py", + "go_package": "internal/policy/outcomerouting", + "python_lines": 195, + "status": "migrated", + "notes": "9-outcome policy routing table; PolicyFetchResult + PolicyViolationError" + }, + { + "module": "src/apm_cli/primitives/parser.py", + "go_package": "internal/primitives/primparser", + "python_lines": 275, + "status": "migrated", + "notes": "Primitive file parser with stdlib-only frontmatter; 4 tests pass" + }, + { + "module": "src/apm_cli/output/script_formatters.py", + "go_package": "internal/output/scriptformatters", + "python_lines": 349, + "status": "migrated", + "notes": "ASCII-only script execution formatter; no rich dependency" + }, + { + "module": "src/apm_cli/marketplace/_git_utils.py", + "go_package": "internal/marketplace/gitutils", + "python_lines": 19, + "status": "migrated", + "notes": "RedactToken utility" + }, + { + "module": "src/apm_cli/marketplace/_io.py", + "go_package": "internal/marketplace/mkio", + "python_lines": 30, + "status": "migrated", + "notes": "AtomicWrite/AtomicWriteString" + }, + { + "module": "src/apm_cli/adapters/client/windsurf.py", + "go_package": "internal/adapters/windsurf", + "python_lines": 48, + "status": "migrated", + "notes": "Windsurf/Cascade MCP 
client adapter" + }, + { + "module": "src/apm_cli/install/helpers/security_scan.py", + "go_package": "internal/install/securityscan", + "python_lines": 48, + "status": "migrated", + "notes": "Pre-deploy hidden-character security scan" + }, + { + "module": "src/apm_cli/deps/git_auth_env.py", + "go_package": "internal/deps/gitauthenv", + "python_lines": 152, + "status": "migrated", + "notes": "GitAuthEnvBuilder: SetupEnvironment, NoninteractiveEnv, SubprocessEnvDict" + }, + { + "module": "src/apm_cli/runtime/codex_runtime.py", + "go_package": "internal/runtime/codexruntime", + "python_lines": 151, + "status": "migrated", + "notes": "Codex CLI runtime adapter" + }, + { + "module": "src/apm_cli/runtime/llm_runtime.py", + "go_package": "internal/runtime/llmruntime", + "python_lines": 160, + "status": "migrated", + "notes": "LLM CLI runtime adapter" + }, + { + "module": "src/apm_cli/core/script_runner.py", + "go_package": "internal/core/scriptrunner", + "python_lines": 1138, + "status": "migrated", + "notes": "ScriptRunner+PromptCompiler: runtime detection, prompt discovery, command building, parameter substitution" + }, + { + "module": "src/apm_cli/output/formatters.py", + "go_package": "internal/output/compilationformatter", + "python_lines": 999, + "status": "migrated", + "notes": "CompilationFormatter: default/verbose/dry-run output formatting with plain-text rendering" + }, + { + "module": "src/apm_cli/integration/skill_integrator.py", + "go_package": "internal/integration/skillintegrator", + "python_lines": 1513, + "status": "migrated", + "notes": "SkillIntegrator: deploy SKILL.md-based packages to multiple target directories with collision detection and atomic writes" + }, + { + "module": "src/apm_cli/integration/hook_integrator.py", + "go_package": "internal/integration/hookintegrator", + "python_lines": 1071, + "status": "migrated", + "notes": "HookIntegrator: deploy hook scripts with permission setting and cleanup support" + }, + { + "module": 
"src/apm_cli/integration/command_integrator.py", + "go_package": "internal/integration/commandintegrator", + "python_lines": 775, + "status": "migrated", + "notes": "CommandIntegrator: deploy command definitions with dispatch table management" + }, + { + "module": "src/apm_cli/integration/base_integrator.py", + "go_package": "internal/integration/baseintegrator", + "python_lines": 562, + "status": "migrated", + "notes": "BaseIntegrator: CheckCollision, PartitionManagedFiles (trie routing), SyncRemoveFiles, FindFilesByGlob" + }, + { + "module": "src/apm_cli/integration/agent_integrator.py", + "go_package": "internal/integration/agentintegrator", + "python_lines": 606, + "status": "migrated", + "notes": "AgentIntegrator: TOML/Windsurf/Codex config generation with frontmatter YAML parser" + }, + { + "module": "src/apm_cli/integration/targets.py", + "go_package": "internal/integration/targets", + "python_lines": 846, + "status": "migrated", + "notes": "TargetProfile with UserSupported interface{}; ForScope handles CLAUDE_CONFIG_DIR env" + }, + { + "module": "src/apm_cli/core/auth.py", + "go_package": "internal/core/auth", + "python_lines": 1005, + "status": "migrated", + "notes": "AuthResolver: thread-safe cache, host classification (github/ghe/ghes/ado/gitlab/generic), token resolution chain" + }, + { + "module": "src/apm_cli/marketplace/builder.py", + "go_package": "internal/marketplace/builder", + "python_lines": 1059, + "status": "migrated", + "notes": "MarketplaceBuilder: concurrent resolve via goroutines+semaphore, JSON composition, atomic write" + }, + { + "module": "src/apm_cli/marketplace/ref_resolver.py", + "go_package": "internal/marketplace/refresolver", + "python_lines": 345, + "status": "migrated", + "notes": "RefResolver+RefCache with per-remote mutexes; context.WithTimeout; parseLsRemoteOutput" + }, + { + "module": "src/apm_cli/deps/dependency_graph.py", + "go_package": "internal/deps/depgraph", + "python_lines": 227, + "status": "migrated", + "notes": 
"DependencyNode/Tree/Graph as plain Go structs; no external deps needed" + }, + { + "module": "src/apm_cli/security/audit_report.py", + "go_package": "internal/security/auditreport", + "python_lines": 253, + "status": "migrated", + "notes": "FindingsToJSON/SARIF/Markdown: pure serialization functions, no external deps" + }, + { + "module": "src/apm_cli/core/experimental.py", + "go_package": "internal/core/experimental", + "python_lines": 278, + "status": "migrated", + "notes": "Feature-flag registry with ~/.apm/config.json persistence; IsEnabled/Enable/Disable/Reset" + }, + { + "module": "src/apm_cli/drift.py", + "go_package": "internal/install/drift", + "python_lines": 282, + "status": "migrated", + "notes": "DetectRefChange/Orphans/StaleFiles/ConfigDrift: stateless pure functions with interface-based types" + }, + { + "module": "src/apm_cli/deps/download_strategies.py", + "go_package": "internal/deps/downloadstrategies", + "python_lines": 1122, + "status": "migrated", + "notes": "DownloadDelegate with resilient HTTP GET, GitHub/ADO/GitLab/Artifactory file download, CDN fast-path" + }, + { + "module": "src/apm_cli/deps/apm_resolver.py", + "go_package": "internal/deps/apmresolver", + "python_lines": 918, + "status": "migrated", + "notes": "BFS resolver with parallel download, cycle detection, NPM-hoisting flatten" + }, + { + "module": "src/apm_cli/core/operations.py", + "go_package": "internal/core/operations", + "python_lines": 145, + "status": "migrated", + "notes": "Lightweight orchestration facade" + }, + { + "module": "src/apm_cli/models/dependency/reference.py", + "go_package": "internal/models/depreference", + "python_lines": 1559, + "status": "migrated", + "notes": "DependencyReference struct + Parse() with 3-phase approach (virtual detect, SSH parse, standard URL)" + }, + { + "module": "src/apm_cli/primitives/discovery.py", + "go_package": "internal/primitives/discovery", + "python_lines": 612, + "status": "migrated", + "notes": "PrimitiveCollection with 
type switch + per-type name-index maps; globMatch with memoized DP" + }, + { + "module": "src/apm_cli/deps/plugin_parser.py", + "go_package": "internal/deps/pluginparser", + "python_lines": 677, + "status": "migrated", + "notes": "Pure Go with stdlib json; CLAUDE_PLUGIN_ROOT substitution via recursive walk; security: symlinks skipped, path escapes rejected" + }, + { + "module": "src/apm_cli/deps/host_backends.py", + "go_package": "internal/deps/hostbackends", + "python_lines": 623, + "status": "migrated", + "notes": "Vendor-specific URL/API construction; GitHubBackend/GHECloudBackend/GHESBackend share gitHubFamilyBase; ADOBackend/GitLabBackend/GenericGitBackend stand alone; BackendFor dispatch" + }, + { + "module": "src/apm_cli/policy/discovery.py", + "go_package": "internal/policy/discovery", + "python_lines": 1365, + "status": "migrated", + "notes": "Auto-discovery from git remote; GitHub Contents API fetch; file load; URL fetch; hash-pin verification; cache with TTL and stale fallback; minimal YAML policy parser" } ], - "last_updated": "2026-05-13T16:25:00Z", - "iteration": 25, - "python_lines_migrated_pct": 19.79, - "modules_migrated": [ - "policy/models.py", - "models/plugin.py", - "deps/dependency_graph.py", - "core/apm_yml.py", - "integration/cleanup.py" + "last_updated": "2026-05-14T21:46:18Z", + "iteration": 46, + "python_lines_migrated_pct": 75.06, + "modules_migrated": 141, + "modules": [ + { + "module": "models/dependency/reference", + "status": "migrated", + "python_lines": 1559 + }, + { + "module": "deps/plugin_parser", + "status": "migrated", + "python_lines": 677 + }, + { + "module": "core/auth", + "python_file": "src/apm_cli/core/auth.py", + "go_package": "internal/core/auth", + "python_lines": 1005, + "status": "migrated" + }, + { + "module": "marketplace/ref_resolver", + "python_file": "src/apm_cli/marketplace/ref_resolver.py", + "go_package": "internal/marketplace/refresolver", + "python_lines": 345, + "status": "migrated" + }, + { + "module": 
"marketplace/builder", + "python_file": "src/apm_cli/marketplace/builder.py", + "go_package": "internal/marketplace/builder", + "python_lines": 1059, + "status": "migrated" + } ] } \ No newline at end of file diff --git a/internal/adapters/windsurf/windsurf.go b/internal/adapters/windsurf/windsurf.go new file mode 100644 index 0000000..feaf95b --- /dev/null +++ b/internal/adapters/windsurf/windsurf.go @@ -0,0 +1,55 @@ +// Package windsurf provides the Windsurf/Cascade MCP client adapter. +// Migrated from src/apm_cli/adapters/client/windsurf.py +// +// Windsurf uses the standard mcpServers JSON format at +// ~/.codeium/windsurf/mcp_config.json (global). The config schema is +// identical to GitHub Copilot CLI. +package windsurf + +import ( + "os" + "path/filepath" +) + +// Adapter implements the Windsurf/Cascade MCP client adapter. +type Adapter struct { + // SupportsUserScope indicates this adapter targets global user config. + SupportsUserScope bool + // ClientLabel is the user-facing label for this adapter. + ClientLabel string + // TargetName is the adapter identifier. + TargetName string + // MCPServersKey is the JSON key for MCP servers. + MCPServersKey string + // SupportsRuntimeEnvSubstitution mirrors the Python field. + // Pinned to false until windsurf runtime-env audit is complete. + SupportsRuntimeEnvSubstitution bool +} + +// New returns a new Windsurf adapter with default settings. +func New() *Adapter { + return &Adapter{ + SupportsUserScope: true, + ClientLabel: "Windsurf", + TargetName: "windsurf", + MCPServersKey: "mcpServers", + SupportsRuntimeEnvSubstitution: false, + } +} + +// GetConfigPath returns the path to ~/.codeium/windsurf/mcp_config.json. +// This is a global config path -- Windsurf reads MCP server definitions +// from the user-level directory, not the workspace. 
+func (a *Adapter) GetConfigPath() string { + home, err := os.UserHomeDir() + if err != nil { + home = "~" + } + return filepath.Join(home, ".codeium", "windsurf", "mcp_config.json") +} + +// GetRuntimeName returns the runtime name. +func (a *Adapter) GetRuntimeName() string { return a.TargetName } + +// IsAvailable always returns true for Windsurf (file-based config, no binary check). +func (a *Adapter) IsAvailable() bool { return true } diff --git a/internal/core/auth/auth.go b/internal/core/auth/auth.go new file mode 100644 index 0000000..5ef00d5 --- /dev/null +++ b/internal/core/auth/auth.go @@ -0,0 +1,586 @@ +// Package auth provides centralized authentication resolution for APM CLI. +// Every APM operation that touches a remote host MUST use AuthResolver. +// Resolution is per-(host, org) pair, thread-safe, and cached per-process. +package auth + +import ( + "fmt" + "os" + "strings" + "sync" + + "github.com/githubnext/apm/internal/core/tokenmanager" + "github.com/githubnext/apm/internal/utils/githubhost" +) + +// HostInfo is an immutable description of a remote Git host. +type HostInfo struct { + Host string + Kind string // "github" | "ghe_cloud" | "ghes" | "ado" | "gitlab" | "generic" + HasPublicRepos bool + APIBase string + Port *int // Non-standard git port, nil for default +} + +// DisplayName returns "host:port" when a non-default port is set, else bare host. +func (h HostInfo) DisplayName() string { + wellKnown := map[int]bool{443: true, 80: true, 22: true} + if h.Port != nil && !wellKnown[*h.Port] { + return fmt.Sprintf("%s:%d", h.Host, *h.Port) + } + return h.Host +} + +// AuthContext holds resolved authentication for a single (host, org) pair. +type AuthContext struct { + Token *string // nil means no token; never print + Source string // e.g. 
"GITHUB_APM_PAT_ORGNAME", "GITHUB_TOKEN", "none" + TokenType string // "fine-grained", "classic", "oauth", "github-app", "unknown" + HostInfo HostInfo + GitEnv map[string]string + AuthScheme string // "basic" | "bearer" +} + +// BearerFallbackOutcome is the result of ExecuteWithBearerFallback. +type BearerFallbackOutcome struct { + Outcome interface{} + BearerAttempted bool +} + +type cacheKey struct { + host string + port int // 0 means no port + org string +} + +// AuthResolver is the single source of truth for auth resolution. +// Every APM operation that touches a remote host MUST use this struct. +type AuthResolver struct { + tokenManager *tokenmanager.GitHubTokenManager + cache map[cacheKey]*AuthContext + mu sync.Mutex + + // Optional logger interface (set via SetLogger). + logger interface{} + + verboseAuthLoggedHosts map[string]bool + stalePATWarnedHosts map[string]bool +} + +// NewAuthResolver constructs a new AuthResolver with an optional token manager. +func NewAuthResolver(tm *tokenmanager.GitHubTokenManager) *AuthResolver { + if tm == nil { + tm = &tokenmanager.GitHubTokenManager{} + } + return &AuthResolver{ + tokenManager: tm, + cache: make(map[cacheKey]*AuthContext), + verboseAuthLoggedHosts: make(map[string]bool), + stalePATWarnedHosts: make(map[string]bool), + } +} + +// SetLogger wires a logger into the resolver after construction. +func (r *AuthResolver) SetLogger(logger interface{}) { + r.logger = logger +} + +// ClassifyHost returns a HostInfo describing host. 
+func ClassifyHost(host string, port *int) HostInfo { + h := strings.ToLower(host) + + if h == "github.com" { + return HostInfo{ + Host: host, + Kind: "github", + HasPublicRepos: true, + APIBase: "https://api.github.com", + Port: port, + } + } + + if strings.HasSuffix(h, ".ghe.com") { + return HostInfo{ + Host: host, + Kind: "ghe_cloud", + HasPublicRepos: false, + APIBase: fmt.Sprintf("https://%s/api/v3", host), + Port: port, + } + } + + if githubhost.IsAzureDevOpsHostname(host) { + return HostInfo{ + Host: host, + Kind: "ado", + HasPublicRepos: true, + APIBase: "https://dev.azure.com", + Port: port, + } + } + + // GHES: GITHUB_HOST is set to a non-github.com, non-ghe.com FQDN + ghesHost := strings.ToLower(os.Getenv("GITHUB_HOST")) + if ghesHost != "" && ghesHost == h && + ghesHost != "github.com" && ghesHost != "gitlab.com" && + !strings.HasSuffix(ghesHost, ".ghe.com") { + if githubhost.IsValidFQDN(ghesHost) { + return HostInfo{ + Host: host, + Kind: "ghes", + HasPublicRepos: true, + APIBase: fmt.Sprintf("https://%s/api/v3", host), + Port: port, + } + } + } + + // GitLab (after GHES per spec) + if githubhost.IsGitLabHostname(host) { + var apiBase string + if h == "gitlab.com" { + apiBase = "https://gitlab.com/api/v4" + } else { + apiBase = fmt.Sprintf("https://%s/api/v4", host) + } + return HostInfo{ + Host: host, + Kind: "gitlab", + HasPublicRepos: true, + APIBase: apiBase, + Port: port, + } + } + + // Generic FQDN (Bitbucket, self-hosted, etc.) + return HostInfo{ + Host: host, + Kind: "generic", + HasPublicRepos: true, + APIBase: fmt.Sprintf("https://%s/api/v3", host), + Port: port, + } +} + +// DetectTokenType classifies a token string by its prefix. 
+func DetectTokenType(token string) string { + switch { + case strings.HasPrefix(token, "github_pat_"): + return "fine-grained" + case strings.HasPrefix(token, "ghp_"): + return "classic" + case strings.HasPrefix(token, "ghu_") || strings.HasPrefix(token, "gho_"): + return "oauth" + case strings.HasPrefix(token, "ghs_") || strings.HasPrefix(token, "ghr_"): + return "github-app" + } + return "unknown" +} + +// GitLabRESTHeaders builds HTTP headers for GitLab REST API v4 calls. +func GitLabRESTHeaders(token string, oauthBearer bool) map[string]string { + if token == "" { + return map[string]string{} + } + if oauthBearer { + return map[string]string{"Authorization": "Bearer " + token} + } + return map[string]string{"PRIVATE-TOKEN": token} +} + +// Resolve resolves auth for (host, port, org). Cached and thread-safe. +func (r *AuthResolver) Resolve(host, org string, port *int) *AuthContext { + portVal := 0 + if port != nil { + portVal = *port + } + key := cacheKey{ + host: strings.ToLower(host), + port: portVal, + org: strings.ToLower(org), + } + + r.mu.Lock() + defer r.mu.Unlock() + + if cached, ok := r.cache[key]; ok { + return cached + } + + hostInfo := ClassifyHost(host, port) + token, source, scheme := r.resolveToken(hostInfo, org) + + var tokenType string + if token != nil { + tokenType = DetectTokenType(*token) + } else { + tokenType = "unknown" + } + gitEnv := buildGitEnv(token, scheme, hostInfo.Kind) + + ctx := &AuthContext{ + Token: token, + Source: source, + TokenType: tokenType, + HostInfo: hostInfo, + GitEnv: gitEnv, + AuthScheme: scheme, + } + r.cache[key] = ctx + return ctx +} + +// resolveToken walks the token resolution chain. Returns (token, source, scheme). +func (r *AuthResolver) resolveToken(hostInfo HostInfo, org string) (*string, string, string) { + if hostInfo.Kind == "ado" { + if pat := os.Getenv("ADO_APM_PAT"); pat != "" { + return &pat, "ADO_APM_PAT", "basic" + } + return nil, "none", "basic" + } + + // 1. 
Per-org GitHub PAT (GitHub-class hosts only) + if org != "" && (hostInfo.Kind == "github" || hostInfo.Kind == "ghe_cloud" || hostInfo.Kind == "ghes") { + envName := "GITHUB_APM_PAT_" + orgToEnvSuffix(org) + if val := os.Getenv(envName); val != "" { + return &val, envName, "basic" + } + } + + // 2. Global env vars by host class + purpose := purposeForHost(hostInfo) + token, ok := r.tokenManager.GetTokenForPurpose(purpose, nil) + if ok && token != "" { + source := identifyEnvSource(purpose) + return &token, source, "basic" + } + + // 3. gh CLI active account + ghTokenPtr := tokenmanager.ResolveCredentialFromGhCLI(hostInfo.Host) + if ghTokenPtr != nil && *ghTokenPtr != "" { + ghToken := *ghTokenPtr + return &ghToken, "gh-auth-token", "basic" + } + + // 4. Git credential helper (not for ADO) + if hostInfo.Kind != "ado" { + credPtr := tokenmanager.ResolveCredentialFromGit(hostInfo.Host, hostInfo.Port, "") + if credPtr != nil && *credPtr != "" { + cred := *credPtr + return &cred, "git-credential-fill", "basic" + } + } + + return nil, "none", "basic" +} + +// purposeForHost maps host kind to token manager purpose key. +func purposeForHost(hostInfo HostInfo) string { + switch hostInfo.Kind { + case "ado": + return "ado_modules" + case "gitlab": + return "gitlab_modules" + case "generic": + return "generic_modules" + default: + return "modules" + } +} + +// tokenPrecedenceByPurpose mirrors the Python tokenPrecedence dict. +var tokenPrecedenceByPurpose = map[string][]string{ + "modules": {"GITHUB_APM_PAT", "GITHUB_TOKEN", "GH_TOKEN"}, + "gitlab_modules": {"GITLAB_APM_PAT", "GITLAB_TOKEN"}, + "generic_modules": {}, + "ado_modules": {"ADO_APM_PAT"}, +} + +// identifyEnvSource returns the name of the first env var that matched for purpose. 
+func identifyEnvSource(purpose string) string { + for _, v := range tokenPrecedenceByPurpose[purpose] { + if os.Getenv(v) != "" { + return v + } + } + return "env" +} + +// orgToEnvSuffix converts an org name to an env-var suffix (upper-case, hyphens to underscores). +func orgToEnvSuffix(org string) string { + return strings.ToUpper(strings.ReplaceAll(org, "-", "_")) +} + +// buildGitEnv builds environment for subprocess git calls. +func buildGitEnv(token *string, scheme, hostKind string) map[string]string { + env := make(map[string]string) + // Copy current env + for _, kv := range os.Environ() { + parts := strings.SplitN(kv, "=", 2) + if len(parts) == 2 { + env[parts[0]] = parts[1] + } + } + env["GIT_TERMINAL_PROMPT"] = "0" + env["GIT_ASKPASS"] = "echo" + + if scheme == "bearer" && token != nil && *token != "" && hostKind == "ado" { + delete(env, "GIT_TOKEN") + // ADO bearer: inject via GIT_CONFIG env vars + env["GIT_CONFIG_COUNT"] = "1" + env["GIT_CONFIG_KEY_0"] = "http.extraHeader" + env["GIT_CONFIG_VALUE_0"] = "Authorization: Bearer " + *token + } else if token != nil && *token != "" { + env["GIT_TOKEN"] = *token + } + return env +} + +// TryWithFallbackOptions configures TryWithFallback. +type TryWithFallbackOptions struct { + Org string + Port *int + Path string + UnauthFirst bool + VerboseCallback func(string) +} + +// TryWithFallback executes op with automatic auth/unauth fallback. +// op receives (token *string, gitEnv map[string]string). 
+func (r *AuthResolver) TryWithFallback( + host string, + op func(token *string, gitEnv map[string]string) (interface{}, error), + opts TryWithFallbackOptions, +) (interface{}, error) { + authCtx := r.Resolve(host, opts.Org, opts.Port) + hostInfo := authCtx.HostInfo + + log := func(msg string) { + if opts.VerboseCallback != nil { + opts.VerboseCallback(msg) + } + } + + tryCredentialFallback := func(origErr error) (interface{}, error) { + if authCtx.Source == "gh-auth-token" || authCtx.Source == "git-credential-fill" || authCtx.Source == "none" { + return nil, origErr + } + if hostInfo.Kind == "ado" { + return nil, origErr + } + log(fmt.Sprintf("Token from %s failed for %s; trying secondary credential sources", + authCtx.Source, hostInfo.DisplayName())) + log(fmt.Sprintf("trying gh auth token for %s", hostInfo.DisplayName())) + ghTokenPtr := tokenmanager.ResolveCredentialFromGhCLI(hostInfo.Host) + if ghTokenPtr != nil && *ghTokenPtr != "" { + log(fmt.Sprintf("gh auth token resolved a credential for %s", hostInfo.DisplayName())) + return op(ghTokenPtr, buildGitEnv(ghTokenPtr, "basic", hostInfo.Kind)) + } + pathSuffix := "" + if opts.Path != "" { + pathSuffix = fmt.Sprintf(" (path=%s)", opts.Path) + } + log(fmt.Sprintf("trying git credential fill for %s%s", hostInfo.DisplayName(), pathSuffix)) + credPtr := tokenmanager.ResolveCredentialFromGit(hostInfo.Host, hostInfo.Port, opts.Path) + if credPtr != nil && *credPtr != "" { + log(fmt.Sprintf("git credential fill resolved a credential for %s", hostInfo.DisplayName())) + return op(credPtr, buildGitEnv(credPtr, "basic", hostInfo.Kind)) + } + return nil, origErr + } + + // Hosts that never have public repos -> auth-only + if hostInfo.Kind == "ghe_cloud" { + log(fmt.Sprintf("Auth-only attempt for %s host %s", hostInfo.Kind, hostInfo.DisplayName())) + res, err := op(authCtx.Token, authCtx.GitEnv) + if err != nil { + return tryCredentialFallback(err) + } + return res, nil + } + + // ADO: auth-first (bearer fallback handled 
separately) + if hostInfo.Kind == "ado" { + log(fmt.Sprintf("Auth-only attempt for %s host %s", hostInfo.Kind, hostInfo.DisplayName())) + return op(authCtx.Token, authCtx.GitEnv) + } + + if opts.UnauthFirst { + res, err := op(nil, authCtx.GitEnv) + if err != nil && authCtx.Token != nil { + log(fmt.Sprintf("Unauthenticated failed, retrying with token (source: %s)", authCtx.Source)) + res2, err2 := op(authCtx.Token, authCtx.GitEnv) + if err2 != nil { + return tryCredentialFallback(err2) + } + return res2, nil + } + return res, err + } + if authCtx.Token != nil { + log(fmt.Sprintf("Trying authenticated access to %s (source: %s)", hostInfo.DisplayName(), authCtx.Source)) + res, err := op(authCtx.Token, authCtx.GitEnv) + if err != nil { + if hostInfo.HasPublicRepos { + log("Authenticated failed, retrying without token") + res2, err2 := op(nil, authCtx.GitEnv) + if err2 != nil { + return tryCredentialFallback(err2) + } + return res2, nil + } + return tryCredentialFallback(err) + } + return res, nil + } + log(fmt.Sprintf("No token available, trying unauthenticated access to %s", hostInfo.DisplayName())) + return op(nil, authCtx.GitEnv) +} + +// BuildErrorContext returns an actionable error message for auth failures. 
+func (r *AuthResolver) BuildErrorContext( + host, operation, org string, + port *int, + depURL string, + bearerAlsoFailed bool, +) string { + authCtx := r.Resolve(host, org, port) + hostInfo := authCtx.HostInfo + display := hostInfo.DisplayName() + + if hostInfo.Kind == "ado" { + azAvailable := false // simplified: no az CLI check in Go migration + patSet := os.Getenv("ADO_APM_PAT") != "" + + orgPart := org + if orgPart == "" && depURL != "" { + stripped := strings.TrimPrefix(depURL, "https://") + parts := strings.SplitN(stripped, "/", 3) + if len(parts) >= 2 { + if parts[0] == "dev.azure.com" || strings.HasSuffix(parts[0], ".visualstudio.com") { + orgPart = parts[1] + } + } + } + tokenURL := "https://dev.azure.com//_usersSettings/tokens" + if orgPart != "" { + tokenURL = fmt.Sprintf("https://dev.azure.com/%s/_usersSettings/tokens", orgPart) + } + + if patSet { + if azAvailable { + prefix := "" + if bearerAlsoFailed { + prefix = " ADO_APM_PAT was rejected; az cli bearer was also rejected.\n\n" + } + return fmt.Sprintf("\n%s ADO_APM_PAT is set, and Azure CLI credentials may also be available,\n but the Azure DevOps request still failed.\n\n To fix:\n 1. Unset the PAT to test Azure CLI auth only: unset ADO_APM_PAT\n 2. Re-authenticate Azure CLI if needed: az login\n 3. Retry: apm install\n\n Docs: https://microsoft.github.io/apm/getting-started/authentication/#azure-devops", prefix) + } + return fmt.Sprintf("\n ADO_APM_PAT is set, but the Azure DevOps request failed.\n If this is an authentication failure, the token may be expired,\n revoked, or scoped to a different org.\n\n Generate a new PAT at %s\n with Code (Read) scope.\n\n Docs: https://microsoft.github.io/apm/getting-started/authentication/#azure-devops", tokenURL) + } + return fmt.Sprintf("\n Azure DevOps requires authentication. You have two options:\n\n 1. Install Azure CLI and sign in (recommended for Entra ID users):\n az login\n apm install\n\n 2. 
Use a Personal Access Token:\n export ADO_APM_PAT=your_token\n (Create one at %s with Code (Read) scope.)\n\n Docs: https://microsoft.github.io/apm/getting-started/authentication/#azure-devops", tokenURL) + } + + // Non-ADO paths + lines := []string{fmt.Sprintf("Authentication failed for %s on %s.", operation, display)} + if authCtx.Token != nil { + lines = append(lines, fmt.Sprintf("Token was provided (source: %s, type: %s).", authCtx.Source, authCtx.TokenType)) + switch { + case hostInfo.Kind == "ghe_cloud": + lines = append(lines, "GHE Cloud Data Residency hosts (*.ghe.com) require enterprise-scoped tokens.") + case hostInfo.Kind == "gitlab": + lines = append(lines, "Ensure your GitLab personal or project access token meets the API read requirements for your instance policy.") + case strings.ToLower(host) == "github.com": + lines = append(lines, "If your organization uses SAML SSO or is an EMU org, ensure your PAT is authorized at https://github.com/settings/tokens") + case hostInfo.Kind == "generic": + lines = append(lines, "Verify credentials for this host in your git credential helper.") + default: + lines = append(lines, "If your organization uses SAML SSO, you may need to authorize your token at https://github.com/settings/tokens") + } + } else { + lines = append(lines, "No token available.") + switch hostInfo.Kind { + case "gitlab": + lines = append(lines, fmt.Sprintf("Set GITLAB_APM_PAT or GITLAB_TOKEN, or configure git credential fill for %s.", display)) + case "generic": + lines = append(lines, fmt.Sprintf("APM does not apply GitHub PAT environment variables to generic git hosts; configure git credential fill for %s or use a public repository if available.", display)) + default: + lines = append(lines, "Set GITHUB_APM_PAT or GITHUB_TOKEN, or run 'gh auth login'.") + } + } + if org != "" && hostInfo.Kind != "ado" && hostInfo.Kind != "gitlab" && hostInfo.Kind != "generic" { + lines = append(lines, fmt.Sprintf("If packages span multiple organizations, set 
per-org tokens: GITHUB_APM_PAT_%s", orgToEnvSuffix(org))) + } + if hostInfo.Port != nil { + lines = append(lines, fmt.Sprintf("[i] Host '%s' -- verify your credential helper stores per-port entries (some helpers key by host only).", display)) + } + lines = append(lines, "Run with --verbose for detailed auth diagnostics.") + return strings.Join(lines, "\n") +} + +// EmitStalePATDiagnostic emits a warning when PAT was rejected but bearer succeeded. +func (r *AuthResolver) EmitStalePATDiagnostic(hostDisplay string) { + r.mu.Lock() + if r.stalePATWarnedHosts[hostDisplay] { + r.mu.Unlock() + return + } + r.stalePATWarnedHosts[hostDisplay] = true + r.mu.Unlock() + + msg := fmt.Sprintf("ADO_APM_PAT was rejected for %s; fell back to az cli bearer.", hostDisplay) + fmt.Fprintln(os.Stderr, "[!] "+msg) + fmt.Fprintln(os.Stderr, "[!] Consider unsetting the stale variable.") +} + +// NotifyAuthSource emits the verbose auth-source line for hostDisplay exactly once. +func (r *AuthResolver) NotifyAuthSource(hostDisplay string, ctx *AuthContext) { + hostKey := strings.ToLower(hostDisplay) + if hostKey == "" { + return + } + r.mu.Lock() + already := r.verboseAuthLoggedHosts[hostKey] + if !already { + r.verboseAuthLoggedHosts[hostKey] = true + } + r.mu.Unlock() + if already { + return + } + if ctx == nil || ctx.Source == "none" { + return + } + var line string + if ctx.AuthScheme == "bearer" { + line = fmt.Sprintf(" [i] %s -- using bearer from az cli (source: %s)", hostKey, ctx.Source) + } else { + line = fmt.Sprintf(" [i] %s -- token from %s", hostKey, ctx.Source) + } + fmt.Fprintln(os.Stderr, line) +} + +// ExecuteWithBearerFallback runs primaryOp; on ADO auth failure retries via bearer. 
func (r *AuthResolver) ExecuteWithBearerFallback(
	depRef interface{},
	primaryOp func() (interface{}, error),
	bearerOp func(bearer string) (interface{}, error),
	isAuthFailure func(result interface{}, err error) bool,
) BearerFallbackOutcome {
	// Run the primary operation first; its outcome is returned unchanged on
	// every path below (no bearer retry is actually performed -- see final note).
	primary, primaryErr := primaryOp()
	if depRef == nil {
		return BearerFallbackOutcome{Outcome: primary, BearerAttempted: false}
	}
	// Check if dep is ADO via duck typing: any depRef value exposing
	// IsAzureDevOps() bool qualifies; everything else short-circuits.
	type adoChecker interface {
		IsAzureDevOps() bool
	}
	if checker, ok := depRef.(adoChecker); !ok || !checker.IsAzureDevOps() {
		return BearerFallbackOutcome{Outcome: primary, BearerAttempted: false}
	}
	// primaryErr is consulted only through the caller-supplied predicate.
	if !isAuthFailure(primary, primaryErr) {
		return BearerFallbackOutcome{Outcome: primary, BearerAttempted: false}
	}

	// No az CLI support in Go sandbox; return primary
	// NOTE(review): bearerOp is accepted but never invoked on any path;
	// presumably retained for signature parity with the Python original until
	// az-cli bearer support lands -- confirm.
	return BearerFallbackOutcome{Outcome: primary, BearerAttempted: false}
}

diff --git a/internal/core/commandlogger/commandlogger.go b/internal/core/commandlogger/commandlogger.go
new file mode 100644
index 0000000..e052fb4
--- /dev/null
+++ b/internal/core/commandlogger/commandlogger.go
@@ -0,0 +1,355 @@
// Package commandlogger provides structured CLI output infrastructure for APM commands.
//
// Mirrors src/apm_cli/core/command_logger.py.
package commandlogger

import (
	"fmt"

	"github.com/githubnext/apm/internal/utils/console"
)

// StripSourcePrefix removes the "org:" or "url:" prefix from a policy source string.
// A bare "org:" or "url:" with nothing after it is returned unchanged, because
// the strict len comparison requires at least one character beyond the prefix.
func StripSourcePrefix(source string) string {
	if source == "" {
		return ""
	}
	for _, pfx := range []string{"org:", "url:"} {
		if len(source) > len(pfx) && source[:len(pfx)] == pfx {
			return source[len(pfx):]
		}
	}
	return source
}

// CommandLogger is the base context-aware logger for all CLI commands.
// All methods delegate to console helpers -- no new output primitives.
type CommandLogger struct {
	Command string // CLI command name this logger reports for
	Verbose bool   // when true, verbose-only detail lines are emitted
	DryRun  bool   // when true, ShouldExecute reports false
}

// NewCommandLogger creates a new CommandLogger.
+func NewCommandLogger(command string, verbose, dryRun bool) *CommandLogger { + return &CommandLogger{Command: command, Verbose: verbose, DryRun: dryRun} +} + +// Start logs the start of an operation. +func (l *CommandLogger) Start(message string) { + console.Info(message, "running") +} + +// Progress logs progress during an operation. +func (l *CommandLogger) Progress(message string) { + console.Info(message, "info") +} + +// MCPLookupHeartbeat emits a single batch heartbeat before MCP registry validation. +func (l *CommandLogger) MCPLookupHeartbeat(count int) { + if count <= 0 { + return + } + noun := "servers" + if count == 1 { + noun = "server" + } + console.Info(fmt.Sprintf("Looking up %d MCP %s in registry...", count, noun), "running") +} + +// Info logs static advisory/informational context. +func (l *CommandLogger) Info(message, symbol string) { + if symbol == "" { + symbol = "info" + } + console.Info(message, symbol) +} + +// Success logs successful completion. +func (l *CommandLogger) Success(message string) { + console.Success(message, "sparkles") +} + +// Warning logs a warning. +func (l *CommandLogger) Warning(message string) { + console.Warning(message, "warning") +} + +// Error logs an error. +func (l *CommandLogger) Error(message string) { + console.Error(message, "error") +} + +// VerboseDetail logs a detail only when verbose mode is enabled. +func (l *CommandLogger) VerboseDetail(message string) { + if l.Verbose { + console.Echo(nil, message, "dim", "", false) + } +} + +// TreeItem logs a tree sub-item (continuation line) under a package block. +func (l *CommandLogger) TreeItem(message string) { + console.Echo(nil, message, "green", "", false) +} + +// BlankLine logs a blank line. +func (l *CommandLogger) BlankLine() { + console.Echo(nil, "", "", "", false) +} + +// PackageInlineWarning logs an inline warning under a package block (verbose only). 
+func (l *CommandLogger) PackageInlineWarning(message string) { + if l.Verbose { + console.Echo(nil, message, "yellow", "", false) + } +} + +// DryRunNotice logs what would happen in dry-run mode. +func (l *CommandLogger) DryRunNotice(whatWouldHappen string) { + console.Info(fmt.Sprintf("[dry-run] %s", whatWouldHappen), "info") +} + +// ShouldExecute returns false if in dry-run mode. +func (l *CommandLogger) ShouldExecute() bool { + return !l.DryRun +} + +// AuthStep logs an auth resolution step (verbose only). +func (l *CommandLogger) AuthStep(step string, success bool, detail string) { + if !l.Verbose { + return + } + msg := fmt.Sprintf(" auth: %s", step) + if detail != "" { + msg += fmt.Sprintf(" (%s)", detail) + } + symbol := "check" + if !success { + symbol = "error" + } + console.Echo(nil, msg, "dim", symbol, false) +} + +// PolicyDiscoveryMiss logs a policy-discovery non-success outcome. +func (l *CommandLogger) PolicyDiscoveryMiss(outcome, source, errText, hostOrg string) { + if errText == "" { + errText = "unknown" + } + switch outcome { + case "absent": + if !l.Verbose { + return + } + org := hostOrg + if org == "" { + org = StripSourcePrefix(source) + } + if org == "" { + org = "this project" + } + console.Info(fmt.Sprintf("No org policy found for %s", org), "info") + + case "no_git_remote": + if !l.Verbose { + return + } + console.Info("Could not determine org from git remote; policy auto-discovery skipped", "info") + + case "empty": + src := source + if src == "" { + src = "this project" + } + console.Warning(fmt.Sprintf("Org policy at %s is present but empty; no enforcement applied", src), "warning") + + case "malformed": + console.Warning(fmt.Sprintf("Policy at %s is malformed: %s. Contact your org admin to fix the policy file.", source, errText), "warning") + + case "cache_miss_fetch_fail": + console.Warning(fmt.Sprintf("Could not fetch org policy from %s (%s); proceeding without policy enforcement. 
Retry, check connectivity, or use --no-policy to bypass.", source, errText), "warning") + + case "garbage_response": + console.Warning(fmt.Sprintf("Policy response from %s is not valid YAML (%s); proceeding without policy enforcement. Contact your org admin or use --no-policy.", source, errText), "warning") + + case "cached_stale": + console.Warning(fmt.Sprintf("Using stale cached policy (refresh failed: %s); enforcement still applies from cached policy.", errText), "warning") + + case "hash_mismatch": + console.Error(fmt.Sprintf("Policy hash mismatch: pinned hash does not match fetched policy (%s). Update apm.yml policy.hash or contact your org admin.", errText), "error") + + default: + if errText != "unknown" && errText != "" { + console.Warning(fmt.Sprintf("Policy discovery issue: %s", errText), "warning") + } + } +} + +// PolicyViolation records a policy violation for a dependency. +func (l *CommandLogger) PolicyViolation(depRef, reason, severity, source string) { + // Strip depRef prefix if present. + prefix := depRef + ": " + if len(reason) > len(prefix) && reason[:len(prefix)] == prefix { + reason = reason[len(prefix):] + } + if severity == "block" { + console.Error(fmt.Sprintf("Policy violation: %s -- %s", depRef, reason), "error") + if source != "" { + msg := fmt.Sprintf(" Blocked by org policy at %s -- remove `%s` from apm.yml, contact admin to update policy, or use `--no-policy` for one-off bypass", source, depRef) + console.Echo(nil, msg, "dim", "", false) + } + } +} + +// PolicyDisabled logs a loud warning that policy enforcement is disabled. +func (l *CommandLogger) PolicyDisabled(reason string) { + console.Warning(fmt.Sprintf("Policy enforcement disabled by %s for this invocation. This does NOT bypass apm audit --ci. CI will still fail the PR for the same policy violation.", reason), "warning") +} + +// InstallSummary logs the final install summary. 
+func (l *CommandLogger) InstallSummary(apmCount, mcpCount, errors, staleCleaned int, elapsedSeconds float64, hasElapsed bool) { + var parts []string + if apmCount > 0 { + noun := "dependencies" + if apmCount == 1 { + noun = "dependency" + } + parts = append(parts, fmt.Sprintf("%d APM %s", apmCount, noun)) + } + if mcpCount > 0 { + noun := "servers" + if mcpCount == 1 { + noun = "server" + } + parts = append(parts, fmt.Sprintf("%d MCP %s", mcpCount, noun)) + } + + cleanupSuffix := "" + if staleCleaned > 0 { + fNoun := "files" + if staleCleaned == 1 { + fNoun = "file" + } + cleanupSuffix = fmt.Sprintf(" (%d stale %s cleaned)", staleCleaned, fNoun) + } + + timingSuffix := "" + if hasElapsed { + timingSuffix = fmt.Sprintf(" in %.1fs", elapsedSeconds) + } + + if len(parts) > 0 { + summary := joinParts(parts) + if errors > 0 { + console.Warning(fmt.Sprintf("Installed %s%s%s with %d error(s).", summary, cleanupSuffix, timingSuffix, errors), "warning") + } else { + console.Success(fmt.Sprintf("Installed %s%s%s.", summary, cleanupSuffix, timingSuffix), "sparkles") + } + } else if errors > 0 { + console.Error(fmt.Sprintf("Installation failed with %d error(s)%s.", errors, timingSuffix), "error") + } +} + +func joinParts(parts []string) string { + if len(parts) == 0 { + return "" + } + if len(parts) == 1 { + return parts[0] + } + return parts[0] + " and " + parts[1] +} + +// InstallInterrupted logs a minimal elapsed-time line for interrupted installs. +func (l *CommandLogger) InstallInterrupted(elapsedSeconds float64) { + console.Warning(fmt.Sprintf("Install interrupted after %.1fs.", elapsedSeconds), "warning") +} + +// InstallLogger is the install-specific logger with validation, resolution, and download phases. +type InstallLogger struct { + *CommandLogger + Partial bool + staleCleaned int +} + +// NewInstallLogger creates a new InstallLogger. 
+func NewInstallLogger(verbose, dryRun, partial bool) *InstallLogger { + return &InstallLogger{ + CommandLogger: NewCommandLogger("install", verbose, dryRun), + Partial: partial, + } +} + +// ValidationStart logs start of package validation. +func (l *InstallLogger) ValidationStart(count int) { + noun := "packages" + if count == 1 { + noun = "package" + } + console.Info(fmt.Sprintf("Validating %d %s...", count, noun), "gear") +} + +// ValidationPass logs a package that passed validation. +func (l *InstallLogger) ValidationPass(canonical string, alreadyPresent bool) { + if alreadyPresent { + console.Echo(nil, fmt.Sprintf("%s (already in apm.yml)", canonical), "dim", "check", false) + } else { + console.Success(canonical, "check") + } +} + +// ValidationFail logs a package that failed validation. +func (l *InstallLogger) ValidationFail(pkg, reason string) { + console.Error(fmt.Sprintf("%s -- %s", pkg, reason), "error") +} + +// ResolutionStart logs start of dependency resolution. +func (l *InstallLogger) ResolutionStart(toInstallCount, lockfileCount int) { + if l.Partial { + noun := "packages" + if toInstallCount == 1 { + noun = "package" + } + console.Info(fmt.Sprintf("Installing %d new %s...", toInstallCount, noun), "running") + if lockfileCount > 0 && l.Verbose { + console.Echo(nil, fmt.Sprintf(" (%d existing dependencies in lockfile)", lockfileCount), "dim", "", false) + } + } else { + console.Info("Installing dependencies from apm.yml...", "running") + if lockfileCount > 0 { + console.Info(fmt.Sprintf("Using apm.lock.yaml (%d locked dependencies)", lockfileCount), "") + } + } +} + +// NothingToInstall logs when there's nothing to install. 
func (l *InstallLogger) NothingToInstall(lockfilePresent, updateMode bool) {
	if l.Partial {
		console.Info("Requested packages are already installed.", "check")
	} else {
		console.Success("All dependencies are up to date.", "check")
	}
	if lockfilePresent && !updateMode {
		console.Info("Lockfile already satisfied -- run 'apm update' to resolve latest refs.", "")
	}
}

// DownloadStart logs start of a package download.
// Cached hits and live downloads are both verbose-only output.
func (l *InstallLogger) DownloadStart(depName string, cached bool) {
	if cached {
		l.VerboseDetail(fmt.Sprintf(" Using cached: %s", depName))
	} else if l.Verbose {
		console.Info(fmt.Sprintf(" Downloading: %s", depName), "download")
	}
}

// ResolvingHeartbeat emits a per-dependency progress heartbeat during BFS resolve.
func (l *InstallLogger) ResolvingHeartbeat(depName string) {
	if l.Verbose {
		console.Info(fmt.Sprintf(" Resolving: %s", depName), "running")
	}
}

// DownloadComplete logs completion of a package download.
func (l *InstallLogger) DownloadComplete(depName string) {
	l.VerboseDetail(fmt.Sprintf(" Downloaded: %s", depName))
}

diff --git a/internal/core/experimental/experimental.go b/internal/core/experimental/experimental.go
new file mode 100644
index 0000000..3326eff
--- /dev/null
+++ b/internal/core/experimental/experimental.go
@@ -0,0 +1,298 @@
// Package experimental provides a feature-flag subsystem for the APM CLI.
// Migrated from src/apm_cli/core/experimental.py
package experimental

import (
	"encoding/json"
	"fmt"
	"os"
	"path/filepath"
	"strings"
	"sync"
)

// Flag describes a single experimental feature.
type Flag struct {
	// Name is the internal snake_case identifier.
	Name string
	// Description is a one-line summary (<= 80 chars, printable ASCII).
	Description string
	// Default is the registry default -- always false.
	Default bool
	// Hint is an optional next-step message shown after enabling.
	Hint string
}

// registry is the static map of all registered experimental flags.
var registry = map[string]Flag{
	"verbose_version": {
		Name:        "verbose_version",
		Description: "Show Python version, platform, and install path in 'apm --version'.",
		Default:     false,
		Hint:        "Run 'apm --version' to see the new output.",
	},
	"copilot_cowork": {
		Name:        "copilot_cowork",
		Description: "Enable Microsoft 365 Copilot Cowork skills deployment via OneDrive.",
		Default:     false,
		Hint: "Use '--target copilot-cowork --global' to deploy skills. " +
			"See https://microsoft.github.io/apm/integrations/copilot-cowork/",
	},
}

// Flags returns the static registry (read-only view).
// NOTE(review): this returns the live map, so a caller could mutate the
// registry; treat the result as read-only -- confirm whether a defensive
// copy is wanted.
func Flags() map[string]Flag {
	return registry
}

// normalizeFlagName normalizes a CLI flag name to internal snake_case.
func normalizeFlagName(name string) string {
	return strings.ToLower(strings.ReplaceAll(name, "-", "_"))
}

// DisplayName converts an internal snake_case name to kebab-case for display.
func DisplayName(name string) string {
	return strings.ReplaceAll(name, "_", "-")
}

// configPath returns the path to ~/.apm/config.json.
// Returns "" when the home directory cannot be determined; callers then
// operate on an empty path (reads fail, loadConfig caches empty).
func configPath() string {
	home, err := os.UserHomeDir()
	if err != nil {
		return ""
	}
	return filepath.Join(home, ".apm", "config.json")
}

var (
	configMu    sync.RWMutex // guards configCache
	configCache map[string]interface{}
)

// loadConfig reads ~/.apm/config.json, returning an empty map on failure.
+func loadConfig() map[string]interface{} { + configMu.RLock() + if configCache != nil { + defer configMu.RUnlock() + return configCache + } + configMu.RUnlock() + + configMu.Lock() + defer configMu.Unlock() + if configCache != nil { + return configCache + } + path := configPath() + data, err := os.ReadFile(path) + if err != nil { + configCache = map[string]interface{}{} + return configCache + } + var cfg map[string]interface{} + if err := json.Unmarshal(data, &cfg); err != nil { + configCache = map[string]interface{}{} + return configCache + } + configCache = cfg + return configCache +} + +// invalidateCache clears the config cache so the next load re-reads disk. +func invalidateCache() { + configMu.Lock() + configCache = nil + configMu.Unlock() +} + +// getExperimentalSection returns the "experimental" section from config. +func getExperimentalSection() map[string]interface{} { + cfg := loadConfig() + v, ok := cfg["experimental"] + if !ok { + return map[string]interface{}{} + } + m, ok := v.(map[string]interface{}) + if !ok { + return map[string]interface{}{} + } + return m +} + +// updateConfig merges updates into ~/.apm/config.json. 
+func updateConfig(updates map[string]interface{}) error { + invalidateCache() + path := configPath() + if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil { + return err + } + // Read existing + var cfg map[string]interface{} + data, err := os.ReadFile(path) + if err == nil { + _ = json.Unmarshal(data, &cfg) + } + if cfg == nil { + cfg = map[string]interface{}{} + } + for k, v := range updates { + cfg[k] = v + } + out, err := json.MarshalIndent(cfg, "", " ") + if err != nil { + return err + } + tmp, err := os.CreateTemp(filepath.Dir(path), ".config-*.json") + if err != nil { + return err + } + defer os.Remove(tmp.Name()) + if _, err := tmp.Write(append(out, '\n')); err != nil { + tmp.Close() + return err + } + if err := tmp.Close(); err != nil { + return err + } + return os.Rename(tmp.Name(), path) +} + +// IsEnabled reports whether an experimental flag is currently enabled. +// Returns an error if the flag name is not registered. +func IsEnabled(name string) (bool, error) { + if _, ok := registry[name]; !ok { + keys := make([]string, 0, len(registry)) + for k := range registry { + keys = append(keys, k) + } + return false, fmt.Errorf("unknown experimental flag: %q; registered: %s", + name, strings.Join(keys, ", ")) + } + experimental := getExperimentalSection() + v, ok := experimental[name] + if !ok { + return registry[name].Default, nil + } + b, ok := v.(bool) + if !ok { + return registry[name].Default, nil + } + return b, nil +} + +// ValidateFlagName validates and normalizes a flag name from CLI input. +// Returns the normalized name or an error with suggestions. +func ValidateFlagName(name string) (string, error) { + normalized := normalizeFlagName(name) + if _, ok := registry[normalized]; ok { + return normalized, nil + } + display := DisplayName(normalized) + // Build suggestions via simple prefix/contains matching. 
+ var suggestions []string + for k := range registry { + if strings.Contains(k, normalized) || strings.Contains(normalized, k) { + suggestions = append(suggestions, DisplayName(k)) + } + } + msg := fmt.Sprintf("unknown experimental feature: %s", display) + if len(suggestions) > 0 { + msg += fmt.Sprintf("; did you mean: %s?", strings.Join(suggestions, ", ")) + } + return "", fmt.Errorf("%s", msg) +} + +// setFlag sets an experimental flag to a boolean value and persists it. +func setFlag(name string, value bool) (Flag, error) { + flag, ok := registry[name] + if !ok { + return Flag{}, fmt.Errorf("unknown flag: %s", name) + } + experimental := map[string]interface{}{} + for k, v := range getExperimentalSection() { + experimental[k] = v + } + experimental[name] = value + if err := updateConfig(map[string]interface{}{"experimental": experimental}); err != nil { + return Flag{}, err + } + return flag, nil +} + +// Enable enables an experimental flag and persists the change. +func Enable(name string) (Flag, error) { + return setFlag(name, true) +} + +// Disable disables an experimental flag and persists the change. +func Disable(name string) (Flag, error) { + return setFlag(name, false) +} + +// Reset resets one or all experimental flags to registry defaults. +// When name is empty, all flags are cleared. Returns the number removed. 
func Reset(name string) (int, error) {
	// Work on a copy of the experimental section so the cached config map
	// is never mutated in place.
	experimental := map[string]interface{}{}
	for k, v := range getExperimentalSection() {
		experimental[k] = v
	}
	if name != "" {
		// Single-flag reset: remove the override if present.
		if _, ok := experimental[name]; ok {
			delete(experimental, name)
			if err := updateConfig(map[string]interface{}{"experimental": experimental}); err != nil {
				return 0, err
			}
			return 1, nil
		}
		return 0, nil
	}
	// Full reset: replace the section with an empty object, but only write
	// when there was something to remove.
	count := len(experimental)
	if count > 0 {
		if err := updateConfig(map[string]interface{}{"experimental": map[string]interface{}{}}); err != nil {
			return 0, err
		}
	}
	return count, nil
}

// GetOverriddenFlags returns flags that have user overrides in config.
// Only registered flags with boolean values are included.
func GetOverriddenFlags() map[string]bool {
	experimental := getExperimentalSection()
	out := map[string]bool{}
	for k, v := range experimental {
		if _, ok := registry[k]; !ok {
			continue
		}
		if b, ok := v.(bool); ok {
			out[k] = b
		}
	}
	return out
}

// GetStaleConfigKeys returns config keys not in the registry.
// NOTE(review): result order follows map iteration and is nondeterministic;
// confirm callers sort before display.
func GetStaleConfigKeys() []string {
	experimental := getExperimentalSection()
	var out []string
	for k := range experimental {
		if _, ok := registry[k]; !ok {
			out = append(out, k)
		}
	}
	return out
}

// GetMalformedFlagKeys returns registered flags with non-boolean config values.
func GetMalformedFlagKeys() []string {
	experimental := getExperimentalSection()
	var out []string
	for k, v := range experimental {
		if _, ok := registry[k]; !ok {
			continue
		}
		if _, ok := v.(bool); !ok {
			out = append(out, k)
		}
	}
	return out
}

diff --git a/internal/core/operations/operations.go b/internal/core/operations/operations.go
new file mode 100644
index 0000000..9a0c9d0
--- /dev/null
+++ b/internal/core/operations/operations.go
@@ -0,0 +1,92 @@
// Package operations provides core operations for the APM CLI.
//
// Migrated from: src/apm_cli/core/operations.py
package operations

// ConfigureClientResult holds the result of a configure-client operation.
type ConfigureClientResult struct {
	Success bool   // true when configuration succeeded
	Error   string // human-readable failure reason, empty on success
}

// InstallPackageResult holds the result of an install-package operation.
type InstallPackageResult struct {
	Success   bool
	Installed bool
	Skipped   bool
	Failed    bool
	Error     string
}

// UninstallPackageResult holds the result of an uninstall-package operation.
type UninstallPackageResult struct {
	Success bool
	Error   string
}

// ConfigureClientOptions contains options for configure-client.
type ConfigureClientOptions struct {
	ClientType    string
	ConfigUpdates map[string]interface{}
	ProjectRoot   string
	UserScope     bool
}

// InstallPackageOptions contains options for install-package.
type InstallPackageOptions struct {
	ClientType        string
	PackageName       string
	Version           string
	SharedEnvVars     map[string]string
	ServerInfoCache   map[string]interface{}
	SharedRuntimeVars map[string]interface{}
	ProjectRoot       string
	UserScope         bool
}

// UninstallPackageOptions contains options for uninstall-package.
type UninstallPackageOptions struct {
	ClientType  string
	PackageName string
	ProjectRoot string
	UserScope   bool
}

// ConfigureClient configures an MCP client; the only hard requirement is a
// non-empty client type.
// Mirrors apm_cli/core/operations.py::configure_client.
func ConfigureClient(opts ConfigureClientOptions) ConfigureClientResult {
	if opts.ClientType == "" {
		// Zero-valued Success field signals failure.
		return ConfigureClientResult{Error: "client_type is required"}
	}
	return ConfigureClientResult{Success: true}
}

// InstallPackage installs an MCP package for a specific client type.
// Mirrors apm_cli/core/operations.py::install_package.
func InstallPackage(opts InstallPackageOptions) InstallPackageResult {
	// Both identifiers are mandatory; everything else is optional context.
	if opts.ClientType == "" || opts.PackageName == "" {
		return InstallPackageResult{
			Success: false,
			Failed:  true,
			Error:   "client_type and package_name are required",
		}
	}
	return InstallPackageResult{
		Success:   true,
		Installed: true,
		Skipped:   false,
		Failed:    false,
	}
}

// UninstallPackage uninstalls an MCP package.
// Mirrors apm_cli/core/operations.py::uninstall_package.
func UninstallPackage(opts UninstallPackageOptions) UninstallPackageResult {
	if opts.ClientType == "" || opts.PackageName == "" {
		return UninstallPackageResult{
			Success: false,
			Error:   "client_type and package_name are required",
		}
	}
	return UninstallPackageResult{Success: true}
}

diff --git a/internal/core/scriptrunner/compiler.go b/internal/core/scriptrunner/compiler.go
new file mode 100644
index 0000000..d5e0c00
--- /dev/null
+++ b/internal/core/scriptrunner/compiler.go
@@ -0,0 +1,175 @@
package scriptrunner

import (
	"fmt"
	"io/fs"
	"os"
	"path/filepath"
	"strings"
)

// PromptCompiler compiles .prompt.md files with parameter substitution.
type PromptCompiler struct {
	// CompiledDir is the directory compiled .txt outputs are written to.
	CompiledDir string
}

const defaultCompiledDir = ".apm/compiled"

// NewPromptCompiler returns a PromptCompiler with default settings.
func NewPromptCompiler() *PromptCompiler {
	return &PromptCompiler{CompiledDir: defaultCompiledDir}
}

// Compile compiles a .prompt.md file with parameter substitution.
// Returns the path to the compiled .txt file.
func (c *PromptCompiler) Compile(promptFile string, params map[string]string) (string, error) {
	// Locate the prompt file (local dirs, then apm_modules dependencies).
	promptPath, err := c.resolvePromptFile(promptFile)
	if err != nil {
		return "", err
	}

	if err := os.MkdirAll(c.CompiledDir, 0o755); err != nil {
		return "", fmt.Errorf("creating compiled dir: %w", err)
	}

	data, err := os.ReadFile(promptPath)
	if err != nil {
		return "", fmt.Errorf("reading prompt file: %w", err)
	}

	content := string(data)

	// Strip YAML frontmatter if present. SplitN(..., 3) only strips when a
	// closing "---" fence exists; otherwise content is left untouched.
	if strings.HasPrefix(content, "---") {
		parts := strings.SplitN(content, "---", 3)
		if len(parts) >= 3 {
			content = strings.TrimSpace(parts[2])
		}
	}

	compiled := substituteParameters(content, params)

	// Build output file name: strip .prompt from stem, add .txt.
	base := filepath.Base(promptPath)
	stem := strings.TrimSuffix(base, filepath.Ext(base)) // removes .md
	stem = strings.TrimSuffix(stem, ".prompt")           // removes .prompt
	outputName := stem + ".txt"
	outputPath := filepath.Join(c.CompiledDir, outputName)

	if err := os.WriteFile(outputPath, []byte(compiled), 0o644); err != nil {
		return "", fmt.Errorf("writing compiled file: %w", err)
	}

	return outputPath, nil
}

// resolvePromptFile locates the .prompt.md file checking local dirs then dependencies.
// Search order: the path as given, then .github/prompts and .apm/prompts,
// then each org/repo directory under apm_modules (root, prompts/, workflows/).
// Symlinked candidates are rejected everywhere.
func (c *PromptCompiler) resolvePromptFile(promptFile string) (string, error) {
	promptPath := promptFile

	// Reject symlinks.
	if fi, err := os.Lstat(promptPath); err == nil {
		if fi.Mode()&fs.ModeSymlink != 0 {
			return "", fmt.Errorf("prompt file '%s' is a symlink; symlinks are not allowed for security reasons", promptFile)
		}
		return promptPath, nil
	}

	// Common project directories.
	for _, dir := range []string{".github/prompts", ".apm/prompts"} {
		candidate := filepath.Join(dir, promptFile)
		fi, err := os.Lstat(candidate)
		if err == nil && fi.Mode()&fs.ModeSymlink == 0 {
			return candidate, nil
		}
	}

	// Search in apm_modules (two-level walk).
	apmModulesDir := "apm_modules"
	depDirs := collectDependencyDirs(apmModulesDir)

	for _, dep := range depDirs {
		for _, subdir := range []string{".", "prompts", "workflows"} {
			var candidate string
			if subdir == "." {
				candidate = filepath.Join(dep.repoDir, promptFile)
			} else {
				candidate = filepath.Join(dep.repoDir, subdir, promptFile)
			}
			fi, err := os.Lstat(candidate)
			if err == nil && fi.Mode()&fs.ModeSymlink == 0 {
				return candidate, nil
			}
		}
	}

	// Build error message.
	return "", c.buildNotFoundError(promptFile, depDirs)
}

// depDir identifies one installed dependency under apm_modules/<org>/<repo>.
type depDir struct {
	orgName  string
	repoName string
	repoDir  string
}

// collectDependencyDirs lists org/repo directories under apmModulesDir,
// skipping dot-directories. Returns nil when the directory is missing or
// unreadable (best-effort discovery, not an error).
func collectDependencyDirs(apmModulesDir string) []depDir {
	if _, err := os.Stat(apmModulesDir); err != nil {
		return nil
	}
	var result []depDir
	orgEntries, err := os.ReadDir(apmModulesDir)
	if err != nil {
		return nil
	}
	for _, orgEntry := range orgEntries {
		if !orgEntry.IsDir() || strings.HasPrefix(orgEntry.Name(), ".") {
			continue
		}
		orgDir := filepath.Join(apmModulesDir, orgEntry.Name())
		repoEntries, err := os.ReadDir(orgDir)
		if err != nil {
			continue
		}
		for _, repoEntry := range repoEntries {
			if !repoEntry.IsDir() || strings.HasPrefix(repoEntry.Name(), ".") {
				continue
			}
			result = append(result, depDir{
				orgName:  orgEntry.Name(),
				repoName: repoEntry.Name(),
				repoDir:  filepath.Join(orgDir, repoEntry.Name()),
			})
		}
	}
	return result
}

// buildNotFoundError formats the searched-locations error for a prompt file
// that could not be resolved anywhere.
func (c *PromptCompiler) buildNotFoundError(promptFile string, deps []depDir) error {
	locations := []string{
		"Local: " + promptFile,
		"GitHub prompts: .github/prompts/" + promptFile,
		"APM prompts: .apm/prompts/" + promptFile,
	}
	if len(deps) > 0 {
		locations = append(locations, "Dependencies:")
		for _, d := range deps {
			locations = append(locations, fmt.Sprintf(" - %s/%s/%s", d.orgName, d.repoName, promptFile))
		}
	}
	return fmt.Errorf(
		"Prompt file '%s' not found.\nSearched in:\n%s\n\nTip: Run 'apm install' to ensure dependencies are installed.",
		promptFile,
		strings.Join(locations, "\n"),
	)
}

// substituteParameters replaces ${input:key} placeholders in content.
// Unmatched placeholders are left as-is.
func substituteParameters(content string, params map[string]string) string {
	result := content
	for key, value := range params {
		placeholder := "${input:" + key + "}"
		result = strings.ReplaceAll(result, placeholder, value)
	}
	return result
}

diff --git a/internal/core/scriptrunner/scriptrunner.go b/internal/core/scriptrunner/scriptrunner.go
new file mode 100644
index 0000000..7108cb3
--- /dev/null
+++ b/internal/core/scriptrunner/scriptrunner.go
@@ -0,0 +1,886 @@
// Package scriptrunner implements APM NPM-like script execution.
package scriptrunner

import (
	"bufio"
	"errors"
	"fmt"
	"io/fs"
	"os"
	"os/exec"
	"path/filepath"
	"regexp"
	"runtime"
	"strings"
)

// RuntimeKind identifies a supported AI runtime.
type RuntimeKind string

const (
	RuntimeCopilot RuntimeKind = "copilot"
	RuntimeCodex   RuntimeKind = "codex"
	RuntimeLLM     RuntimeKind = "llm"
	RuntimeGemini  RuntimeKind = "gemini"
	RuntimeUnknown RuntimeKind = "unknown"
)

// ScriptRunner executes APM scripts with auto-compilation of .prompt.md files.
type ScriptRunner struct {
	Compiler *PromptCompiler
	UseColor bool
}

// New returns a ScriptRunner with default settings.
func New(useColor bool) *ScriptRunner {
	return &ScriptRunner{
		Compiler: NewPromptCompiler(),
		UseColor: useColor,
	}
}

// RunScript runs a script from apm.yml with parameter substitution.
//
// Execution priority:
//  1. Explicit scripts in apm.yml
//  2. Auto-discovered prompt files
//  3. Error if not found
func (s *ScriptRunner) RunScript(scriptName string, params map[string]string) error {
	headerLines := formatScriptHeader(scriptName, params)
	for _, l := range headerLines {
		fmt.Println(l)
	}

	isVirtual := isVirtualPackageReference(scriptName)

	config, err := loadConfig()
	if err != nil || config == nil {
		if isVirtual {
			// Zero-config path: synthesize a minimal apm.yml and reload.
			fmt.Println(" [i] Creating minimal apm.yml for zero-config execution...")
			if createErr := createMinimalConfig(); createErr != nil {
				return createErr
			}
			config, err = loadConfig()
			if err != nil {
				return err
			}
		} else {
			return errors.New("No apm.yml found in current directory")
		}
	}

	// 1. Check explicit scripts first.
	if scripts, ok := config["scripts"].(map[string]any); ok {
		if cmdVal, found := scripts[scriptName]; found {
			if command, ok := cmdVal.(string); ok {
				return s.executeScriptCommand(command, params)
			}
		}
	}

	// 2. Auto-discover prompt file.
	discovered := s.discoverPromptFile(scriptName)
	if discovered != "" {
		fmt.Printf("[i] Auto-discovered: %s\n", filepath.ToSlash(discovered))
		rtKind, rtErr := detectInstalledRuntime()
		if rtErr != nil {
			return rtErr
		}
		command := generateRuntimeCommand(rtKind, discovered)
		return s.executeScriptCommand(command, params)
	}

	// 2.5 Try auto-install if it looks like a virtual package reference.
	if isVirtual {
		fmt.Printf("\n Auto-installing virtual package: %s\n", scriptName)
		if s.autoInstallVirtualPackage(scriptName) {
			// Re-run discovery now that the package is on disk.
			discovered = s.discoverPromptFile(scriptName)
			if discovered != "" {
				fmt.Print("\n* Package installed and ready to run\n\n")
				rtKind, rtErr := detectInstalledRuntime()
				if rtErr != nil {
					return rtErr
				}
				command := generateRuntimeCommand(rtKind, discovered)
				return s.executeScriptCommand(command, params)
			}
			return errors.New("Package installed successfully but prompt not found.\n" +
				"The package may not contain the expected prompt file.\n" +
				"Check apm_modules for installed files.")
		}
	}

	// 3. Not found. List whatever explicit scripts exist to aid the user.
	// NOTE: the listing order follows map iteration and is nondeterministic.
	var available string
	if scripts, ok := config["scripts"].(map[string]any); ok && len(scripts) > 0 {
		keys := make([]string, 0, len(scripts))
		for k := range scripts {
			keys = append(keys, k)
		}
		available = strings.Join(keys, ", ")
	} else {
		available = "none"
	}

	return fmt.Errorf(
		"Script or prompt '%s' not found.\n"+
			"Available scripts in apm.yml: %s\n\n"+
			"To find available prompts, check:\n"+
			" - Local: .apm/prompts/, .github/prompts/, or project root\n"+
			" - Dependencies: apm_modules/*/.apm/prompts/\n\n"+
			"Or install a prompt package:\n"+
			" apm install //path/to/prompt.prompt.md",
		scriptName, available,
	)
}

// executeScriptCommand executes a script command with parameter substitution.
+func (s *ScriptRunner) executeScriptCommand(command string, params map[string]string) error { + compiledCommand, compiledPromptFiles, runtimeContent := s.autoCompilePrompts(command, params) + + if len(compiledPromptFiles) > 0 { + for _, line := range formatCompilationProgress(compiledPromptFiles) { + fmt.Println(line) + } + } + + rtKind := detectRuntime(compiledCommand) + + if runtimeContent != "" { + for _, line := range formatRuntimeExecution(rtKind, compiledCommand, len(runtimeContent)) { + fmt.Println(line) + } + for _, line := range formatContentPreview(runtimeContent) { + fmt.Println(line) + } + } + + env := setupRuntimeEnvironment() + + var envVarsSet []string + if env["GITHUB_TOKEN"] != "" { + envVarsSet = append(envVarsSet, "GITHUB_TOKEN") + } + if env["GITHUB_APM_PAT"] != "" { + envVarsSet = append(envVarsSet, "GITHUB_APM_PAT") + } + if len(envVarsSet) > 0 { + for _, line := range formatEnvironmentSetup(rtKind, envVarsSet) { + fmt.Println(line) + } + } + + var cmdErr error + if runtimeContent != "" { + cmdErr = s.executeRuntimeCommand(compiledCommand, runtimeContent, env) + } else { + cmdErr = runShellCommand(compiledCommand, env) + } + + if cmdErr != nil { + for _, line := range formatExecutionError(rtKind) { + fmt.Println(line) + } + var exitErr *exec.ExitError + if errors.As(cmdErr, &exitErr) { + return fmt.Errorf("Script execution failed with exit code %d", exitErr.ExitCode()) + } + return fmt.Errorf("Script execution failed: %w", cmdErr) + } + + for _, line := range formatExecutionSuccess(rtKind) { + fmt.Println(line) + } + return nil +} + +// ListScripts returns all available scripts from apm.yml. 
+func (s *ScriptRunner) ListScripts() map[string]string { + config, err := loadConfig() + if err != nil || config == nil { + return nil + } + scripts, ok := config["scripts"].(map[string]any) + if !ok { + return nil + } + result := make(map[string]string, len(scripts)) + for k, v := range scripts { + if str, ok := v.(string); ok { + result[k] = str + } + } + return result +} + +// autoCompilePrompts finds .prompt.md files in the command and compiles them. +// Returns (compiledCommand, compiledPromptFiles, runtimeContent). +func (s *ScriptRunner) autoCompilePrompts(command string, params map[string]string) (string, []string, string) { + re := regexp.MustCompile(`(\S+\.prompt\.md)`) + promptFiles := re.FindAllString(command, -1) + + var compiledPromptFiles []string + var runtimeContent string + compiledCommand := command + + runtimeCommands := []string{"copilot", "codex", "llm", "gemini"} + + for _, pf := range promptFiles { + compiledPath, err := s.Compiler.Compile(pf, params) + if err != nil { + continue + } + compiledPromptFiles = append(compiledPromptFiles, pf) + + data, err := os.ReadFile(compiledPath) + if err != nil { + continue + } + compiledContent := strings.TrimSpace(string(data)) + + // Check if this is a runtime command. + isRuntimeCmd := false + for _, rt := range runtimeCommands { + re2 := regexp.MustCompile(`(?:^|\s)` + rt + `(?:\s|$)`) + if re2.MatchString(command) && strings.Contains(command, pf) { + isRuntimeCmd = true + break + } + } + + compiledCommand = transformRuntimeCommand(compiledCommand, pf, compiledContent, compiledPath) + + if isRuntimeCmd { + runtimeContent = compiledContent + } + } + + return compiledCommand, compiledPromptFiles, runtimeContent +} + +// transformRuntimeCommand rewrites a command containing a .prompt.md reference +// to use the appropriate runtime invocation. 
func transformRuntimeCommand(command, promptFile, compiledContent, compiledPath string) string {
	runtimes := []string{"codex", "copilot", "llm", "gemini"}

	// Case 1: "ENV=val <runtime> ..." — env assignments precede the runtime.
	for _, rt := range runtimes {
		marker := " " + rt + " "
		if !strings.Contains(command, marker) || !strings.Contains(command, promptFile) {
			continue
		}
		pieces := strings.SplitN(command, marker, 2)
		envPart := pieces[0]
		if strings.Contains(envPart, "=") && !strings.HasPrefix(envPart, rt) {
			if built := parseAndBuildRuntimeCommand(rt, rt+" "+pieces[1], promptFile, envPart); built != "" {
				return built
			}
		}
	}

	// Case 2: the command starts with the runtime itself.
	for _, rt := range runtimes {
		startRe := regexp.MustCompile(`^` + rt + `\s+.*` + regexp.QuoteMeta(promptFile))
		if startRe.MatchString(command) {
			if built := parseAndBuildRuntimeCommand(rt, command, promptFile, ""); built != "" {
				return built
			}
		}
	}

	// Case 3: the command is nothing but the prompt file — default to codex.
	if strings.TrimSpace(command) == promptFile {
		return "codex exec"
	}

	// Case 4: unknown shape — substitute the compiled path for the source path.
	return strings.ReplaceAll(command, promptFile, compiledPath)
}

// parseAndBuildRuntimeCommand splits commandPart around promptFile and
// reassembles a runtime invocation with the prompt reference removed.
// Returns "" when commandPart does not have the expected shape.
func parseAndBuildRuntimeCommand(rtCmd, commandPart, promptFile, envPrefix string) string {
	re := regexp.MustCompile(rtCmd + `\s+(.*?)(` + regexp.QuoteMeta(promptFile) + `)(.*?)$`)
	groups := re.FindStringSubmatch(commandPart)
	if groups == nil {
		return ""
	}
	before := strings.TrimSpace(groups[1])
	after := strings.TrimSpace(groups[3])

	// With an env-var prefix the prompt content is passed inline, so the
	// -p flag is stripped for every runtime except codex.
	if envPrefix != "" && rtCmd != "codex" {
		before = strings.TrimSpace(strings.ReplaceAll(before, "-p", ""))
	}

	switch rtCmd {
	case "codex":
		return joinNonEmpty(envPrefix, "codex exec", before, after)
	case "copilot":
		stripped := strings.TrimSpace(strings.ReplaceAll(before, "-p", ""))
		return joinNonEmpty(envPrefix, "copilot", stripped, after)
	case "llm":
		return joinNonEmpty(envPrefix, "llm", before, after)
	case "gemini":
		flagRe := regexp.MustCompile(`(^|\s)-p(\s|$)`)
		stripped := strings.TrimSpace(flagRe.ReplaceAllString(before, "$1$2"))
		return joinNonEmpty(envPrefix, "gemini", stripped, after)
	}
	return ""
}

// joinNonEmpty joins the non-empty parts with single spaces.
func joinNonEmpty(parts ...string) string {
	var kept []string
	for _, p := range parts {
		if p != "" {
			kept = append(kept, p)
		}
	}
	return strings.Join(kept, " ")
}

// detectRuntime detects which runtime is referenced in a command.
+func detectRuntime(command string) RuntimeKind { + lower := strings.ToLower(strings.TrimSpace(command)) + patterns := []struct { + rt RuntimeKind + pat string + }{ + {RuntimeCopilot, `(?:^|\s)copilot(?:\s|$)`}, + {RuntimeCodex, `(?:^|\s)codex(?:\s|$)`}, + {RuntimeLLM, `(?:^|\s)llm(?:\s|$)`}, + {RuntimeGemini, `(?:^|\s)gemini(?:\s|$)`}, + } + for _, p := range patterns { + if matched, _ := regexp.MatchString(p.pat, lower); matched { + return p.rt + } + } + return RuntimeUnknown +} + +// executeRuntimeCommand runs a runtime command passing content as an argument. +func (s *ScriptRunner) executeRuntimeCommand(command, content string, env map[string]string) error { + args := splitArgs(command) + + // Extract env-var prefixes from the front of args. + envVars := copyEnv(env) + var actualArgs []string + for _, arg := range args { + if strings.Contains(arg, "=") && len(actualArgs) == 0 { + kv := strings.SplitN(arg, "=", 2) + if isValidEnvVarName(kv[0]) { + envVars[kv[0]] = kv[1] + continue + } + } + actualArgs = append(actualArgs, arg) + } + + rtKind := detectRuntime(strings.Join(actualArgs, " ")) + switch rtKind { + case RuntimeCopilot: + actualArgs = append(actualArgs, "-p", content) + case RuntimeCodex: + actualArgs = append(actualArgs, content) + case RuntimeLLM: + actualArgs = append(actualArgs, content) + case RuntimeGemini: + actualArgs = append(actualArgs, "-p", content) + default: + actualArgs = append(actualArgs, content) + } + + // On Windows, resolve via PATH to find .cmd / .ps1 wrappers. + if len(actualArgs) > 0 && runtime.GOOS == "windows" { + if resolved, err := exec.LookPath(actualArgs[0]); err == nil { + actualArgs[0] = resolved + } + } + + cmd := exec.Command(actualArgs[0], actualArgs[1:]...) //nolint:gosec + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + cmd.Env = envMapToSlice(envVars) + return cmd.Run() +} + +// runShellCommand executes a command via the system shell. 
+func runShellCommand(command string, env map[string]string) error { + var cmd *exec.Cmd + if runtime.GOOS == "windows" { + cmd = exec.Command("cmd", "/C", command) //nolint:gosec + } else { + cmd = exec.Command("sh", "-c", command) //nolint:gosec + } + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + cmd.Env = envMapToSlice(env) + return cmd.Run() +} + +// discoverPromptFile discovers a prompt file by name. +func (s *ScriptRunner) discoverPromptFile(name string) string { + if strings.Contains(name, "/") { + return s.discoverQualifiedPrompt(name) + } + + searchName := name + if !strings.HasSuffix(searchName, ".prompt.md") { + searchName = name + ".prompt.md" + } + + // Local search paths. + localPaths := []string{ + searchName, + filepath.Join(".apm", "prompts", searchName), + filepath.Join(".github", "prompts", searchName), + } + for _, p := range localPaths { + fi, err := os.Lstat(p) + if err == nil && !fi.IsDir() && fi.Mode()&fs.ModeSymlink == 0 { + return p + } + } + + // Search in apm_modules. + apmModules := "apm_modules" + if _, err := os.Stat(apmModules); err != nil { + return "" + } + + var matches []string + _ = filepath.WalkDir(apmModules, func(path string, d fs.DirEntry, err error) error { + if err != nil { + return nil + } + if d.Type()&fs.ModeSymlink != 0 { + return nil + } + if d.Name() == searchName { + matches = append(matches, path) + } + // Also look for SKILL.md in a directory matching `name`. + if d.IsDir() && d.Name() == name { + skillFile := filepath.Join(path, "SKILL.md") + if fi, err2 := os.Lstat(skillFile); err2 == nil && !fi.IsDir() { + matches = append(matches, skillFile) + } + } + return nil + }) + + if len(matches) == 1 { + return matches[0] + } + if len(matches) > 1 { + // Collision — build error message and print it; callers check empty string. 
+ fmt.Fprint(os.Stderr, buildCollisionError(name, matches)) + return "" + } + return "" +} + +func buildCollisionError(name string, matches []string) string { + var b strings.Builder + fmt.Fprintf(&b, "Multiple prompts found for '%s':\n", name) + for _, m := range matches { + parts := strings.Split(filepath.ToSlash(m), "/") + idx := -1 + for i, p := range parts { + if p == "apm_modules" { + idx = i + break + } + } + if idx >= 0 && idx+2 < len(parts) { + fmt.Fprintf(&b, " - %s/%s (%s)\n", parts[idx+1], parts[idx+2], m) + } else { + fmt.Fprintf(&b, " - %s\n", m) + } + } + fmt.Fprintln(&b, "\nPlease specify using qualified path:") + for _, m := range matches { + parts := strings.Split(filepath.ToSlash(m), "/") + idx := -1 + for i, p := range parts { + if p == "apm_modules" { + idx = i + break + } + } + if idx >= 0 && idx+2 < len(parts) { + fmt.Fprintf(&b, " apm run %s/%s/%s\n", parts[idx+1], parts[idx+2], name) + } + } + fmt.Fprintln(&b, "\nOr add an explicit script to apm.yml:") + fmt.Fprintln(&b, " scripts:") + fmt.Fprintf(&b, " my-%s: \"copilot -p \"\n", name) + return b.String() +} + +// discoverQualifiedPrompt discovers a prompt using owner/repo/name format. +func (s *ScriptRunner) discoverQualifiedPrompt(qualifiedPath string) string { + parts := strings.Split(qualifiedPath, "/") + if len(parts) < 2 { + return "" + } + + promptName := parts[len(parts)-1] + if !strings.HasSuffix(promptName, ".prompt.md") { + promptName = promptName + ".prompt.md" + } + + apmModules := "apm_modules" + if _, err := os.Stat(apmModules); err != nil { + return "" + } + + // For 3+ part qualified paths, check subdirectory SKILL.md first. + if len(parts) >= 3 { + subdirPath := filepath.Join(append([]string{apmModules}, parts...)...) 
+ skillFile := filepath.Join(subdirPath, "SKILL.md") + if fi, err := os.Lstat(skillFile); err == nil && !fi.IsDir() { + return skillFile + } + } + + owner := parts[0] + ownerDir := filepath.Join(apmModules, owner) + if _, err := os.Stat(ownerDir); err != nil { + return "" + } + + entries, err := os.ReadDir(ownerDir) + if err != nil { + return "" + } + + for _, entry := range entries { + if !entry.IsDir() { + continue + } + pkgDir := filepath.Join(ownerDir, entry.Name()) + var found string + _ = filepath.WalkDir(pkgDir, func(path string, d fs.DirEntry, err error) error { + if err != nil || found != "" { + return nil + } + if d.Name() == promptName { + // Check qualified path match. + pathSlash := filepath.ToSlash(path) + qParts := strings.Split(qualifiedPath, "/") + if qParts[0] != "" && strings.Contains(pathSlash, qParts[0]) { + expectedName := qParts[len(qParts)-1] + if !strings.HasSuffix(expectedName, ".prompt.md") { + expectedName += ".prompt.md" + } + if d.Name() == expectedName { + found = path + } + } + } + return nil + }) + if found != "" { + return found + } + } + return "" +} + +// isVirtualPackageReference returns true if name looks like owner/repo/... syntax. +func isVirtualPackageReference(name string) bool { + return strings.Count(name, "/") >= 2 +} + +// autoInstallVirtualPackage is a stub — actual install requires network access. +func (s *ScriptRunner) autoInstallVirtualPackage(packageRef string) bool { + fmt.Printf(" [x] Auto-install not supported in Go runtime: %s\n", packageRef) + return false +} + +// detectInstalledRuntime detects an installed AI runtime CLI. 
+func detectInstalledRuntime() (RuntimeKind, error) { + for _, rt := range []struct { + name RuntimeKind + bin string + }{ + {RuntimeCopilot, "copilot"}, + {RuntimeCodex, "codex"}, + {RuntimeGemini, "gemini"}, + } { + if _, err := exec.LookPath(rt.bin); err == nil { + return rt.name, nil + } + } + return RuntimeUnknown, errors.New("No compatible runtime found.\n" + + "Install GitHub Copilot CLI with:\n" + + " apm runtime setup copilot\n" + + "Or install Codex CLI with:\n" + + " apm runtime setup codex\n" + + "Or install Gemini CLI with:\n" + + " apm runtime setup gemini") +} + +// generateRuntimeCommand generates a default runtime invocation for a discovered prompt. +func generateRuntimeCommand(rt RuntimeKind, promptFile string) string { + switch rt { + case RuntimeCopilot: + return fmt.Sprintf("copilot --log-level all --log-dir copilot-logs --allow-all-tools -p %s", promptFile) + case RuntimeCodex: + return fmt.Sprintf("codex -s workspace-write --skip-git-repo-check %s", promptFile) + case RuntimeGemini: + return fmt.Sprintf("gemini -p %s", promptFile) + default: + return fmt.Sprintf("copilot -p %s", promptFile) + } +} + +// setupRuntimeEnvironment builds the environment map for script execution. +func setupRuntimeEnvironment() map[string]string { + env := make(map[string]string) + for _, kv := range os.Environ() { + idx := strings.IndexByte(kv, '=') + if idx >= 0 { + env[kv[:idx]] = kv[idx+1:] + } + } + return env +} + +// loadConfig loads apm.yml from the current directory using a minimal YAML parser. +func loadConfig() (map[string]any, error) { + data, err := os.ReadFile("apm.yml") + if err != nil { + return nil, err + } + return parseSimpleYAML(string(data)), nil +} + +// parseSimpleYAML is a minimal single-level YAML parser sufficient for apm.yml. 
func parseSimpleYAML(content string) map[string]any {
	// Understands top-level scalars, one level of nested "key: value" maps,
	// and "- item" lists. Comment lines are skipped; quotes around scalar
	// values are removed.
	result := make(map[string]any)

	var (
		currentKey  string
		currentList []any
		currentMap  map[string]any
		inMap       bool
		inList      bool
	)

	// flush stores any in-progress nested map or list under currentKey and
	// resets the collection state.
	flush := func() {
		if currentKey != "" {
			if inMap && currentMap != nil {
				result[currentKey] = currentMap
			} else if inList && currentList != nil {
				result[currentKey] = currentList
			}
		}
		currentKey = ""
		currentMap = nil
		currentList = nil
		inMap = false
		inList = false
	}

	scanner := bufio.NewScanner(strings.NewReader(content))
	for scanner.Scan() {
		line := scanner.Text()
		if strings.HasPrefix(strings.TrimSpace(line), "#") {
			continue
		}
		trimmed := strings.TrimLeft(line, " \t")

		// Top-level "key: value" (or "key:" opening a nested block).
		if !strings.HasPrefix(line, " ") && !strings.HasPrefix(line, "\t") && strings.Contains(line, ":") {
			rawKey, rawVal, _ := strings.Cut(line, ":")
			key := strings.TrimSpace(rawKey)
			val := strings.TrimSpace(rawVal)

			flush()
			if val == "" {
				// Opens a nested map or list; mode is decided by the next line.
				currentKey = key
			} else {
				result[key] = unquoteYAML(val)
			}
			continue
		}

		// Indented "- item" list entry.
		// FIX: the previous implementation flushed here and tried to
		// "recover" the key from the item line itself, which cleared
		// currentKey and stored the list under a garbage key (the literal
		// "  - item" text). Simply entering list mode is correct.
		if strings.HasPrefix(trimmed, "- ") && currentKey != "" {
			inList = true
			currentList = append(currentList, unquoteYAML(strings.TrimSpace(trimmed[2:])))
			continue
		}

		// Indented "key: value" entry of a nested map.
		if strings.Contains(trimmed, ":") && currentKey != "" {
			subKey, subVal, _ := strings.Cut(trimmed, ":")
			if !inMap {
				currentMap = make(map[string]any)
				inMap = true
			}
			currentMap[strings.TrimSpace(subKey)] = unquoteYAML(strings.TrimSpace(subVal))
			continue
		}
	}
	flush()
	return result
}

// unquoteYAML strips one level of matching single or double quotes.
func unquoteYAML(s string) string {
	if len(s) >= 2 {
		first, last := s[0], s[len(s)-1]
		if (first == '"' && last == '"') || (first == '\'' && last == '\'') {
			return s[1 : len(s)-1]
		}
	}
	return s
}

// createMinimalConfig creates a minimal apm.yml for zero-config usage.
// The project name defaults to the current directory's base name.
func createMinimalConfig() error {
	cwd, _ := os.Getwd()
	name := filepath.Base(cwd)
	content := fmt.Sprintf("name: %s\nversion: 1.0.0\ndescription: Auto-generated for zero-config virtual package execution\n", name)
	return os.WriteFile("apm.yml", []byte(content), 0o644)
}

// -- Helpers -----------------------------------------------------------------

// splitArgs is a simple POSIX-style tokenizer: whitespace splits arguments,
// single or double quotes group words (the quotes themselves are dropped).
// No escape handling is performed.
func splitArgs(command string) []string {
	var args []string
	var current strings.Builder
	inSingle := false
	inDouble := false

	for i := 0; i < len(command); i++ {
		c := command[i]
		switch {
		case c == '\'' && !inDouble:
			inSingle = !inSingle
		case c == '"' && !inSingle:
			inDouble = !inDouble
		case c == ' ' && !inSingle && !inDouble:
			if current.Len() > 0 {
				args = append(args, current.String())
				current.Reset()
			}
		default:
			current.WriteByte(c)
		}
	}
	if current.Len() > 0 {
		args = append(args, current.String())
	}
	return args
}

// isValidEnvVarName reports whether s is a valid environment variable name:
// [A-Za-z_][A-Za-z0-9_]*.
func isValidEnvVarName(s string) bool {
	if len(s) == 0 {
		return false
	}
	for i, c := range s {
		isAlpha := (c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z') || c == '_'
		isDigit := c >= '0' && c <= '9'
		if i == 0 && !isAlpha {
			return false
		}
		if i > 0 && !isAlpha && !isDigit {
			return false
		}
	}
	return true
}

// copyEnv returns a shallow copy of the environment map.
func copyEnv(m map[string]string) map[string]string {
	out := make(map[string]string, len(m))
	for k, v := range m {
		out[k] = v
	}
	return out
}

// envMapToSlice converts an environment map to "KEY=value" form for exec.Cmd.
func envMapToSlice(m map[string]string) []string {
	out := make([]string, 0, len(m))
	for k, v := range m {
		out = append(out, k+"="+v)
	}
	return out
}

// -- Formatter stubs (plain-text, no Rich dependency) -----------------------
map[string]string) []string { + lines := []string{fmt.Sprintf("[*] Running script: %s", scriptName)} + if len(params) > 0 { + parts := make([]string, 0, len(params)) + for k, v := range params { + parts = append(parts, k+"="+v) + } + lines = append(lines, " Parameters: "+strings.Join(parts, ", ")) + } + return lines +} + +func formatCompilationProgress(files []string) []string { + return []string{fmt.Sprintf("[*] Compiled: %s", strings.Join(files, ", "))} +} + +func formatRuntimeExecution(rt RuntimeKind, command string, contentLen int) []string { + return []string{fmt.Sprintf("[>] Executing via %s (%d bytes)", rt, contentLen)} +} + +func formatContentPreview(content string) []string { + preview := content + if len(preview) > 200 { + preview = preview[:200] + "..." + } + return []string{" " + strings.ReplaceAll(preview, "\n", "\n ")} +} + +func formatEnvironmentSetup(rt RuntimeKind, vars []string) []string { + return []string{fmt.Sprintf("[i] Environment: %s", strings.Join(vars, ", "))} +} + +func formatExecutionSuccess(rt RuntimeKind) []string { + return []string{fmt.Sprintf("[+] Script completed successfully via %s", rt)} +} + +func formatExecutionError(rt RuntimeKind) []string { + return []string{fmt.Sprintf("[x] Script failed via %s", rt)} +} diff --git a/internal/core/targetdetection/targetdetection.go b/internal/core/targetdetection/targetdetection.go new file mode 100644 index 0000000..3cca86e --- /dev/null +++ b/internal/core/targetdetection/targetdetection.go @@ -0,0 +1,289 @@ +// Package targetdetection implements target auto-detection for APM CLI. +// Migrated from src/apm_cli/core/target_detection.py. +package targetdetection + +import ( + "fmt" + "os" + "path/filepath" + "sort" + "strings" +) + +// ValidTargets is the set of canonical target names. 
+var ValidTargets = map[string]bool{ + "vscode": true, + "claude": true, + "cursor": true, + "opencode": true, + "codex": true, + "gemini": true, + "windsurf": true, + "agent-skills": true, + "all": true, + "minimal": true, + "copilot": true, // alias + "agents": true, // alias (deprecated) +} + +// NormalizeTarget resolves user-facing aliases to canonical internal names. +func NormalizeTarget(t string) string { + switch t { + case "copilot", "vscode", "agents": + return "vscode" + default: + return t + } +} + +// CANONICAL_TARGETS_ORDERED lists display-ordered canonical target names. +var CanonicalTargetsOrdered = []string{ + "claude", + "copilot", + "cursor", + "codex", + "gemini", + "opencode", + "windsurf", +} + +// CanonicalDeployDirs maps canonical target names to their deploy directories. +var CanonicalDeployDirs = map[string]string{ + "claude": ".claude/", + "copilot": ".github/", + "cursor": ".cursor/", + "codex": ".codex/", + "gemini": ".gemini/", + "opencode": ".opencode/", + "windsurf": ".windsurf/", +} + +// CanonicalSignal maps canonical target names to their primary detection signal. +var CanonicalSignal = map[string]string{ + "claude": "CLAUDE.md", + "copilot": ".github/copilot-instructions.md", + "cursor": ".cursor/", + "codex": ".codex/", + "gemini": "GEMINI.md", + "opencode": ".opencode/", + "windsurf": ".windsurf/", +} + +// signalEntry is one row in the whitelist. +type signalEntry struct { + target string + checkType string // "dir" or "file" + path string +} + +// signalWhitelist is the ordered list of filesystem markers. 
+var signalWhitelist = []signalEntry{ + {"claude", "dir", ".claude"}, + {"claude", "file", "CLAUDE.md"}, + {"cursor", "dir", ".cursor"}, + {"cursor", "file", ".cursorrules"}, + {"copilot", "file", ".github/copilot-instructions.md"}, + {"codex", "dir", ".codex"}, + {"gemini", "dir", ".gemini"}, + {"gemini", "file", "GEMINI.md"}, + {"opencode", "dir", ".opencode"}, + {"windsurf", "dir", ".windsurf"}, +} + +// Signal represents a detected filesystem marker. +type Signal struct { + Target string + Source string +} + +// ResolvedTargets is the result of target resolution. +type ResolvedTargets struct { + Targets []string // sorted canonical target names + Source string // human-readable source description + AutoCreate bool +} + +// DetectSignals scans projectRoot for harness markers. +func DetectSignals(projectRoot string) []Signal { + var found []Signal + for _, entry := range signalWhitelist { + full := filepath.Join(projectRoot, entry.path) + switch entry.checkType { + case "dir": + if info, err := os.Stat(full); err == nil && info.IsDir() { + found = append(found, Signal{Target: entry.target, Source: entry.path + "/"}) + } + case "file": + if info, err := os.Stat(full); err == nil && !info.IsDir() { + found = append(found, Signal{Target: entry.target, Source: entry.path}) + } + } + } + return found +} + +// ResolveTargets resolves effective targets. Returns error on ambiguity or missing harness. +// Priority: flag > yamlTargets > auto-detect signals. 
+func ResolveTargets(projectRoot string, flag []string, yamlTargets []string) (ResolvedTargets, error) { + // Priority 1: --target flag + if len(flag) > 0 { + for _, t := range flag { + if !ValidTargets[t] { + return ResolvedTargets{}, fmt.Errorf("unknown target: %s", t) + } + } + sorted := sortedUnique(flag) + return ResolvedTargets{Targets: sorted, Source: "--target flag", AutoCreate: true}, nil + } + + // Priority 2: apm.yml targets + if len(yamlTargets) > 0 { + sorted := sortedUnique(yamlTargets) + return ResolvedTargets{Targets: sorted, Source: "apm.yml", AutoCreate: true}, nil + } + + // Priority 3: auto-detect + signals := DetectSignals(projectRoot) + targetSet := map[string]bool{} + var sources []string + for _, s := range signals { + if !targetSet[s.Target] { + targetSet[s.Target] = true + } + sources = append(sources, s.Source) + } + sort.Strings(sources) + + targetList := sortedKeys(targetSet) + + if len(targetList) == 0 { + return ResolvedTargets{}, fmt.Errorf("no harness found in %s", projectRoot) + } + if len(targetList) >= 2 { + return ResolvedTargets{}, fmt.Errorf("ambiguous harness: multiple targets detected: %s", strings.Join(targetList, ", ")) + } + + return ResolvedTargets{ + Targets: targetList, + Source: "auto-detect from " + strings.Join(sources, ", "), + AutoCreate: true, + }, nil +} + +// ExpandAllTargets expands 'all' to (signals union yamlTargets). +func ExpandAllTargets(projectRoot string, yamlTargets []string) ([]string, error) { + signals := DetectSignals(projectRoot) + combined := map[string]bool{} + for _, s := range signals { + combined[s.Target] = true + } + for _, t := range yamlTargets { + combined[t] = true + } + result := sortedKeys(combined) + if len(result) == 0 { + return nil, fmt.Errorf("no harness found in %s", projectRoot) + } + return result, nil +} + +// FormatProvenance formats a provenance line for CLI output. 
+func FormatProvenance(resolved ResolvedTargets) string { + targets := strings.Join(resolved.Targets, ", ") + return fmt.Sprintf("Targets: %s (source: %s)", targets, resolved.Source) +} + +// DetectTarget implements the legacy v1 detection API. +// Returns (target, reason). +func DetectTarget(projectRoot string, explicitTarget, configTarget string) (string, string) { + if explicitTarget != "" { + return NormalizeTarget(explicitTarget), "explicit --target flag" + } + if configTarget != "" { + return NormalizeTarget(configTarget), "apm.yml target" + } + + githubExists := dirExists(filepath.Join(projectRoot, ".github")) + claudeExists := dirExists(filepath.Join(projectRoot, ".claude")) + cursorExists := dirExists(filepath.Join(projectRoot, ".cursor")) + opencodeExists := dirExists(filepath.Join(projectRoot, ".opencode")) + codexExists := dirExists(filepath.Join(projectRoot, ".codex")) + geminiExists := dirExists(filepath.Join(projectRoot, ".gemini")) + windsurfExists := dirExists(filepath.Join(projectRoot, ".windsurf")) + + var detected []string + if githubExists { + detected = append(detected, ".github/") + } + if claudeExists { + detected = append(detected, ".claude/") + } + if cursorExists { + detected = append(detected, ".cursor/") + } + if opencodeExists { + detected = append(detected, ".opencode/") + } + if codexExists { + detected = append(detected, ".codex/") + } + if geminiExists { + detected = append(detected, ".gemini/") + } + if windsurfExists { + detected = append(detected, ".windsurf/") + } + + if len(detected) >= 2 { + return "all", fmt.Sprintf("detected %s folders", strings.Join(detected, " and ")) + } + if githubExists { + return "vscode", "detected .github/ folder" + } + if claudeExists { + return "claude", "detected .claude/ folder" + } + if cursorExists { + return "cursor", "detected .cursor/ folder" + } + if opencodeExists { + return "opencode", "detected .opencode/ folder" + } + if codexExists { + return "codex", "detected .codex/ folder" + } + 
if geminiExists { + return "gemini", "detected .gemini/ folder" + } + if windsurfExists { + return "windsurf", "detected .windsurf/ folder" + } + return "minimal", "no target folder found" +} + +func dirExists(path string) bool { + info, err := os.Stat(path) + return err == nil && info.IsDir() +} + +func sortedUnique(items []string) []string { + seen := map[string]bool{} + var result []string + for _, s := range items { + if !seen[s] { + seen[s] = true + result = append(result, s) + } + } + sort.Strings(result) + return result +} + +func sortedKeys(m map[string]bool) []string { + keys := make([]string, 0, len(m)) + for k := range m { + keys = append(keys, k) + } + sort.Strings(keys) + return keys +} diff --git a/internal/core/targetdetection/targetdetection_test.go b/internal/core/targetdetection/targetdetection_test.go new file mode 100644 index 0000000..abf1b34 --- /dev/null +++ b/internal/core/targetdetection/targetdetection_test.go @@ -0,0 +1,38 @@ +package targetdetection + +import "testing" + +func TestDetectTarget_explicit(t *testing.T) { + target, reason := DetectTarget("/tmp", "copilot", "") + if target != "vscode" { + t.Errorf("expected vscode got %s", target) + } + if reason != "explicit --target flag" { + t.Errorf("unexpected reason: %s", reason) + } +} + +func TestNormalizeTarget(t *testing.T) { + cases := map[string]string{ + "copilot": "vscode", + "agents": "vscode", + "vscode": "vscode", + "claude": "claude", + "cursor": "cursor", + } + for in, want := range cases { + got := NormalizeTarget(in) + if got != want { + t.Errorf("NormalizeTarget(%q) = %q, want %q", in, got, want) + } + } +} + +func TestFormatProvenance(t *testing.T) { + r := ResolvedTargets{Targets: []string{"claude", "copilot"}, Source: "--target flag"} + got := FormatProvenance(r) + want := "Targets: claude, copilot (source: --target flag)" + if got != want { + t.Errorf("got %q want %q", got, want) + } +} diff --git a/internal/core/tokenmanager/timer.go 
b/internal/core/tokenmanager/timer.go new file mode 100644 index 0000000..f028d84 --- /dev/null +++ b/internal/core/tokenmanager/timer.go @@ -0,0 +1,13 @@ +package tokenmanager + +import "time" + +// timerAfter returns a channel that closes after n seconds. +func timerAfter(n int) <-chan struct{} { + ch := make(chan struct{}) + go func() { + time.Sleep(time.Duration(n) * time.Second) + close(ch) + }() + return ch +} diff --git a/internal/core/tokenmanager/tokenmanager.go b/internal/core/tokenmanager/tokenmanager.go new file mode 100644 index 0000000..3aaadb2 --- /dev/null +++ b/internal/core/tokenmanager/tokenmanager.go @@ -0,0 +1,455 @@ +// Package tokenmanager provides centralized token management for different AI runtimes +// and git platforms. It handles the complex token environment setup required by +// different AI CLI tools, each of which expects different environment variable names. +package tokenmanager + +import ( + "net/url" + "os" + "os/exec" + "runtime" + "strconv" + "strings" + + "github.com/githubnext/apm/internal/utils/githubhost" +) + +// ADOBearerSource is the diagnostic source label for bearer-resolved tokens (AAD via az CLI). +const ADOBearerSource = "AAD_BEARER_AZ_CLI" + +// DefaultCredentialTimeout is the default timeout for git credential fill operations. +const DefaultCredentialTimeout = 60 + +// MaxCredentialTimeout is the maximum allowed credential timeout. +const MaxCredentialTimeout = 180 + +// tokenPrecedence defines token precedence for different use cases. +var tokenPrecedence = map[string][]string{ + "copilot": {"GITHUB_COPILOT_PAT", "GITHUB_TOKEN", "GITHUB_APM_PAT"}, + "models": {"GITHUB_TOKEN", "GITHUB_APM_PAT"}, + "modules": {"GITHUB_APM_PAT", "GITHUB_TOKEN", "GH_TOKEN"}, + "gitlab_modules": {"GITLAB_APM_PAT", "GITLAB_TOKEN"}, + "generic_modules": {}, + "ado_modules": {"ADO_APM_PAT"}, + "artifactory_modules": {"ARTIFACTORY_APM_TOKEN"}, +} + +// runtimeEnvVars defines runtime-specific environment variable mappings. 
+var runtimeEnvVars = map[string][]string{ + "copilot": {"GH_TOKEN", "GITHUB_PERSONAL_ACCESS_TOKEN"}, + "codex": {"GITHUB_TOKEN"}, + "llm": {"GITHUB_MODELS_KEY"}, +} + +// GitHubTokenManager manages GitHub token environment setup for different AI runtimes. +type GitHubTokenManager struct { + PreserveExisting bool + credentialCache map[credentialKey]*string +} + +type credentialKey struct { + host string + port *int +} + +// New creates a new GitHubTokenManager. +func New(preserveExisting bool) *GitHubTokenManager { + return &GitHubTokenManager{ + PreserveExisting: preserveExisting, + credentialCache: make(map[credentialKey]*string), + } +} + +// formatCredentialHost embeds a custom port into the git credential host field. +func formatCredentialHost(host string, port *int) string { + if port != nil { + return host + ":" + strconv.Itoa(*port) + } + return host +} + +// sanitizeCredentialPath strips leading /, rejects control chars, allowlists URL schemes. +func sanitizeCredentialPath(path string) string { + parsed, err := url.Parse(path) + scheme := "" + if err == nil { + scheme = strings.ToLower(parsed.Scheme) + } + if scheme != "" { + allowed := map[string]bool{"https": true, "http": true, "ssh": true} + if !allowed[scheme] { + return "" + } + } + var cleaned string + if scheme != "" && err == nil { + cleaned = strings.TrimLeft(parsed.Path, "/") + } else { + cleaned = strings.TrimLeft(path, "/") + } + if cleaned == "" { + return "" + } + for _, ch := range cleaned { + if ch < 0x20 || ch == 0x7F || ch == ' ' || ch == '\t' || ch == '\n' || ch == '\r' { + return "" + } + } + return cleaned +} + +// isValidCredentialToken validates that a credential-fill token looks like a real credential. 
+func isValidCredentialToken(token string) bool { + if token == "" { + return false + } + if len(token) > 1024 { + return false + } + for _, ch := range []byte(token) { + if ch == ' ' || ch == '\t' || ch == '\n' || ch == '\r' { + return false + } + } + prompts := []string{"Password for", "Username for", "password for", "username for"} + for _, p := range prompts { + if strings.Contains(token, p) { + return false + } + } + return true +} + +// supportsGhCLIHost returns true when host should use gh CLI fallback. +func supportsGhCLIHost(host string) bool { + if host == "" { + return false + } + if githubhost.IsGitHubHostname(host) { + return true + } + configuredHost := strings.ToLower(githubhost.DefaultHost()) + hostLower := strings.ToLower(host) + if hostLower != configuredHost { + return false + } + if configuredHost == "github.com" || strings.HasSuffix(configuredHost, ".ghe.com") { + return false + } + if githubhost.IsAzureDevOpsHostname(configuredHost) { + return false + } + return githubhost.IsValidFQDN(configuredHost) +} + +// getCredentialTimeout returns the timeout for git credential fill. +func getCredentialTimeout() int { + raw := strings.TrimSpace(os.Getenv("APM_GIT_CREDENTIAL_TIMEOUT")) + if raw == "" { + return DefaultCredentialTimeout + } + val, err := strconv.Atoi(raw) + if err != nil || val < 1 { + return DefaultCredentialTimeout + } + if val > MaxCredentialTimeout { + return MaxCredentialTimeout + } + return val +} + +// ResolveCredentialFromGit resolves a credential from the git credential store. 
+func ResolveCredentialFromGit(host string, port *int, path string) *string { + hostField := formatCredentialHost(host, port) + lines := []string{"protocol=https", "host=" + hostField} + if path != "" { + sanitized := sanitizeCredentialPath(path) + if sanitized != "" { + lines = append(lines, "path="+sanitized) + } + } + stdin := strings.Join(lines, "\n") + "\n\n" + + env := os.Environ() + env = appendOrReplace(env, "GIT_TERMINAL_PROMPT", "0") + if runtime.GOOS != "windows" { + env = appendOrReplace(env, "GIT_ASKPASS", "") + } else { + env = appendOrReplace(env, "GIT_ASKPASS", "echo") + } + + timeout := getCredentialTimeout() + cmd := exec.Command("git", "credential", "fill") + cmd.Env = env + cmd.Stdin = strings.NewReader(stdin) + done := make(chan struct{}) + var out []byte + var runErr error + go func() { + out, runErr = cmd.Output() + close(done) + }() + + timer := make(chan struct{}) + go func() { + select { + case <-done: + case <-timerAfter(timeout): + cmd.Process.Kill() //nolint:errcheck + close(timer) + return + } + }() + + select { + case <-done: + case <-timer: + return nil + } + + if runErr != nil { + return nil + } + + for _, line := range strings.Split(string(out), "\n") { + if strings.HasPrefix(line, "password=") { + token := line[len("password="):] + if isValidCredentialToken(token) { + return &token + } + return nil + } + } + return nil +} + +// ResolveCredentialFromGhCLI resolves a token from the active gh CLI account for host. 
+func ResolveCredentialFromGhCLI(host string) *string { + if !supportsGhCLIHost(host) { + return nil + } + env := os.Environ() + env = appendOrReplace(env, "GH_PROMPT_DISABLED", "1") + env = appendOrReplace(env, "GH_NO_UPDATE_NOTIFIER", "1") + + timeout := getCredentialTimeout() + cmd := exec.Command("gh", "auth", "token", "--hostname", host) + cmd.Env = env + cmd.Stdin = strings.NewReader("") + done := make(chan struct{}) + var out []byte + var runErr error + go func() { + out, runErr = cmd.Output() + close(done) + }() + + timer := make(chan struct{}) + go func() { + select { + case <-done: + case <-timerAfter(timeout): + if cmd.Process != nil { + cmd.Process.Kill() //nolint:errcheck + } + close(timer) + return + } + }() + + select { + case <-done: + case <-timer: + return nil + } + + if runErr != nil { + return nil + } + + token := strings.TrimSpace(string(out)) + if isValidCredentialToken(token) { + return &token + } + return nil +} + +// SetupEnvironment sets up the complete token environment for all runtimes. +func (m *GitHubTokenManager) SetupEnvironment(env map[string]string) map[string]string { + if env == nil { + env = osEnvMap() + } + available := m.getAvailableTokens(env) + m.setupCopilotTokens(env, available) + m.setupCodexTokens(env, available) + m.setupLLMTokens(env, available) + return env +} + +// GetTokenForPurpose gets the best available token for a specific purpose. +func (m *GitHubTokenManager) GetTokenForPurpose(purpose string, env map[string]string) (string, bool) { + if env == nil { + env = osEnvMap() + } + vars, ok := tokenPrecedence[purpose] + if !ok { + return "", false + } + for _, v := range vars { + if t, exists := env[v]; exists && t != "" { + return t, true + } + } + return "", false +} + +// GetTokenWithCredentialFallback gets a token, falling back to git credential helpers. 
+func (m *GitHubTokenManager) GetTokenWithCredentialFallback(purpose, host string, env map[string]string, port *int) (string, bool) { + if tok, ok := m.GetTokenForPurpose(purpose, env); ok { + return tok, true + } + key := credentialKey{host: host, port: port} + if cached, exists := m.credentialCache[key]; exists { + if cached != nil { + return *cached, true + } + return "", false + } + if supportsGhCLIHost(host) { + if t := ResolveCredentialFromGhCLI(host); t != nil { + m.credentialCache[key] = t + return *t, true + } + } + t := ResolveCredentialFromGit(host, port, "") + m.credentialCache[key] = t + if t != nil { + return *t, true + } + return "", false +} + +// ValidateTokens validates that required tokens are available. +func (m *GitHubTokenManager) ValidateTokens(env map[string]string) (bool, string) { + if env == nil { + env = osEnvMap() + } + hasAny := false + for _, purpose := range []string{"copilot", "models", "modules"} { + if _, ok := m.GetTokenForPurpose(purpose, env); ok { + hasAny = true + break + } + } + if !hasAny { + return false, "No tokens found. Set one of:\n- GITHUB_TOKEN (user-scoped PAT for GitHub Models)\n- GITHUB_APM_PAT (fine-grained PAT for APM modules on GitHub)\n- ADO_APM_PAT (PAT for APM modules on Azure DevOps)" + } + if _, ok := m.GetTokenForPurpose("models", env); !ok { + if env["GITHUB_APM_PAT"] != "" { + return true, "Warning: Only fine-grained PAT available. 
GitHub Models requires GITHUB_TOKEN (user-scoped PAT)" + } + } + return true, "Token validation passed" +} + +func (m *GitHubTokenManager) getAvailableTokens(env map[string]string) map[string]string { + tokens := make(map[string]string) + for _, vars := range tokenPrecedence { + for _, v := range vars { + if t, ok := env[v]; ok && t != "" { + tokens[v] = t + } + } + } + return tokens +} + +func (m *GitHubTokenManager) setupCopilotTokens(env, available map[string]string) { + tok, ok := m.GetTokenForPurpose("copilot", available) + if !ok { + return + } + for _, v := range runtimeEnvVars["copilot"] { + if m.PreserveExisting { + if _, exists := env[v]; exists { + continue + } + } + env[v] = tok + } +} + +func (m *GitHubTokenManager) setupCodexTokens(env, available map[string]string) { + if !(m.PreserveExisting && env["GITHUB_TOKEN"] != "") { + if tok, ok := m.GetTokenForPurpose("models", available); ok { + if env["GITHUB_TOKEN"] == "" { + env["GITHUB_TOKEN"] = tok + } + } + } + if !(m.PreserveExisting && env["GITHUB_APM_PAT"] != "") { + if t, ok := available["GITHUB_APM_PAT"]; ok && env["GITHUB_APM_PAT"] == "" { + env["GITHUB_APM_PAT"] = t + } + } +} + +func (m *GitHubTokenManager) setupLLMTokens(env, available map[string]string) { + if m.PreserveExisting && env["GITHUB_MODELS_KEY"] != "" { + return + } + if tok, ok := m.GetTokenForPurpose("models", available); ok { + env["GITHUB_MODELS_KEY"] = tok + } +} + +// SetupRuntimeEnvironment sets up the complete runtime environment for all AI CLIs. +func SetupRuntimeEnvironment(env map[string]string) map[string]string { + m := New(true) + return m.SetupEnvironment(env) +} + +// ValidateGitHubTokens validates GitHub token setup. +func ValidateGitHubTokens(env map[string]string) (bool, string) { + m := New(true) + return m.ValidateTokens(env) +} + +// GetGitHubTokenForRuntime gets the appropriate GitHub token for a specific runtime. 
+func GetGitHubTokenForRuntime(runtime string, env map[string]string) (string, bool) { + m := New(true) + runtimeToPurpose := map[string]string{ + "copilot": "copilot", + "codex": "models", + "llm": "models", + } + purpose, ok := runtimeToPurpose[runtime] + if !ok { + return "", false + } + return m.GetTokenForPurpose(purpose, env) +} + +// osEnvMap returns os.Environ as a map. +func osEnvMap() map[string]string { + m := make(map[string]string) + for _, kv := range os.Environ() { + i := strings.IndexByte(kv, '=') + if i < 0 { + continue + } + m[kv[:i]] = kv[i+1:] + } + return m +} + +func appendOrReplace(env []string, key, val string) []string { + prefix := key + "=" + for i, kv := range env { + if strings.HasPrefix(kv, prefix) { + env[i] = prefix + val + return env + } + } + return append(env, prefix+val) +} diff --git a/internal/deps/apmresolver/resolver.go b/internal/deps/apmresolver/resolver.go new file mode 100644 index 0000000..dca3225 --- /dev/null +++ b/internal/deps/apmresolver/resolver.go @@ -0,0 +1,452 @@ +// Package apmresolver implements the APM dependency resolution engine. +// +// Provides BFS-based dependency resolution, circular dependency detection, +// and dependency flattening following an NPM-hoisting "first-wins" strategy. +// +// Migrated from: src/apm_cli/deps/apm_resolver.py +package apmresolver + +import ( + "os" + "path/filepath" + "sort" + "strconv" + "strings" + "sync" + + "github.com/githubnext/apm/internal/deps/depgraph" + "github.com/githubnext/apm/internal/models/depreference" +) + +const defaultResolveParallel = 4 + +// DownloadFunc is a callback invoked to download a missing dependency. +// It mirrors the Python DownloadCallback protocol. +// Parameters: +// - ref: the dependency reference to download +// - apmModulesDir: the apm_modules directory path +// - parentChain: breadcrumb string (e.g. 
"root > mid > dep") +// - parentPkg: the package that declared this dependency, or "" +// +// Returns the install path on success, or "" on failure. +type DownloadFunc func(ref *depreference.DependencyReference, apmModulesDir, parentChain, parentPkg string) string + +// workItem is the unit of work dispatched during the BFS download phase. +type workItem struct { + node *depgraph.DependencyNode + depRef *depreference.DependencyReference + parentNode *depgraph.DependencyNode + isDev bool +} + +// workResult is returned by the worker goroutine. +type workResult struct { + item workItem + installed bool + err string +} + +// Resolver resolves APM dependencies recursively. +type Resolver struct { + maxDepth int + apmModulesDir string + projectRoot string + downloadFn DownloadFunc + maxParallel int + + mu sync.Mutex + downloadedPackages map[string]bool + rejectedRemoteLocalKeys map[string]bool + callbackFailures map[string]string +} + +// Options for constructing a Resolver. +type Options struct { + // MaxDepth is the maximum resolution depth (default: 50). + MaxDepth int + // ApmModulesDir is an explicit apm_modules directory path (optional). + ApmModulesDir string + // DownloadFn is invoked when a transitive dep is not installed. + DownloadFn DownloadFunc + // MaxParallel controls the worker pool size for the BFS level batches. + // 0 or negative falls back to the APM_RESOLVE_PARALLEL env var, then + // to defaultResolveParallel. + MaxParallel int +} +// New creates a Resolver with the given options. 
+func New(opts Options) *Resolver { + maxDepth := opts.MaxDepth + if maxDepth <= 0 { + maxDepth = 50 + } + return &Resolver{ + maxDepth: maxDepth, + apmModulesDir: opts.ApmModulesDir, + downloadFn: opts.DownloadFn, + maxParallel: resolveMaxParallel(opts.MaxParallel), + downloadedPackages: make(map[string]bool), + rejectedRemoteLocalKeys: make(map[string]bool), + callbackFailures: make(map[string]string), + } +} + +func resolveMaxParallel(explicit int) int { + if explicit > 0 { + return explicit + } + if env := strings.TrimSpace(os.Getenv("APM_RESOLVE_PARALLEL")); env != "" { + if n, err := strconv.Atoi(env); err == nil && n > 0 { + return n + } + } + return defaultResolveParallel +} + +// ResolveDependencies performs a full BFS dependency resolution starting from +// the apm.yml in projectRoot. +func (r *Resolver) ResolveDependencies(projectRoot string) *depgraph.DependencyGraph { + r.projectRoot = projectRoot + if r.apmModulesDir == "" { + r.apmModulesDir = filepath.Join(projectRoot, "apm_modules") + } + + apmYMLPath := filepath.Join(projectRoot, "apm.yml") + if _, err := os.Stat(apmYMLPath); os.IsNotExist(err) { + g := depgraph.NewDependencyGraph("unknown") + return g + } + + tree := r.buildDependencyTree(apmYMLPath) + circularDeps := r.detectCircularDependencies(tree) + flattened := r.flattenDependencies(tree) + + g := depgraph.NewDependencyGraph(filepath.Base(projectRoot)) + g.Tree = tree + g.Flattened = flattened + for _, c := range circularDeps { + g.AddCircularDependency(c) + } + return g +} + +// buildDependencyTree performs BFS expansion of the dependency tree. +func (r *Resolver) buildDependencyTree(rootApmYML string) *depgraph.DependencyTree { + tree := depgraph.NewDependencyTree() + + // Read root package dependencies from apm.yml using a simple line scanner. 
+ deps := r.readApmYMLDeps(rootApmYML) + + // BFS queue: (depRef, parentNode, depth, isDev) + type queueItem struct { + ref *depreference.DependencyReference + parent *depgraph.DependencyNode + depth int + isDev bool + } + + var queue []queueItem + for _, d := range deps { + dCopy := d + queue = append(queue, queueItem{ref: dCopy, parent: nil, depth: 1, isDev: false}) + } + + visited := make(map[string]bool) + + for len(queue) > 0 { + // Collect all items at the current depth level for parallel dispatch. + currentDepth := queue[0].depth + var level []queueItem + remaining := queue[:0] + for _, qi := range queue { + if qi.depth == currentDepth { + level = append(level, qi) + } else { + remaining = append(remaining, qi) + } + } + queue = remaining + + // Deduplicate within the level and filter already-visited. + var work []workItem + for _, qi := range level { + key := qi.ref.GetUniqueKey() + if visited[key] { + continue + } + if qi.depth > r.maxDepth { + continue + } + node := &depgraph.DependencyNode{ + Ref: depgraph.DependencyRef{ + RepoURL: qi.ref.RepoURL, + Reference: qi.ref.Reference, + UniqueKey: key, + VirtualPath: qi.ref.VirtualPath, + DisplayName: qi.ref.GetDisplayName(), + }, + Depth: qi.depth, + Parent: qi.parent, + IsDev: qi.isDev, + } + if qi.parent != nil { + qi.parent.Children = append(qi.parent.Children, node) + } + tree.AddNode(node) + visited[key] = true + work = append(work, workItem{ + node: node, + depRef: qi.ref, + parentNode: qi.parent, + isDev: qi.isDev, + }) + } + + if len(work) == 0 { + continue + } + + // Dispatch work items (potentially in parallel). + results := r.dispatchLevel(work) + + // For each successfully loaded package, enqueue its transitive deps. + for _, res := range results { + if !res.installed { + if res.err != "" { + r.mu.Lock() + r.callbackFailures[res.item.depRef.GetUniqueKey()] = res.err + r.mu.Unlock() + } + continue + } + // Load transitive deps from the installed package. 
+ installPath := r.resolveInstallPath(res.item.depRef) + if installPath == "" { + continue + } + transApmYML := filepath.Join(installPath, "apm.yml") + if _, err := os.Stat(transApmYML); err != nil { + continue + } + transDeps := r.readApmYMLDeps(transApmYML) + for _, td := range transDeps { + tdCopy := td + queue = append(queue, queueItem{ + ref: tdCopy, + parent: res.item.node, + depth: res.item.node.Depth + 1, + isDev: res.item.isDev, + }) + } + } + } + + return tree +} + +// dispatchLevel runs workItems, using a goroutine pool if maxParallel > 1. +func (r *Resolver) dispatchLevel(items []workItem) []workResult { + results := make([]workResult, len(items)) + + if r.maxParallel <= 1 || r.downloadFn == nil { + for i, item := range items { + results[i] = r.processWorkItem(item) + } + return results + } + + sem := make(chan struct{}, r.maxParallel) + var wg sync.WaitGroup + for i, item := range items { + wg.Add(1) + go func(idx int, wi workItem) { + defer wg.Done() + sem <- struct{}{} + results[idx] = r.processWorkItem(wi) + <-sem + }(i, item) + } + wg.Wait() + return results +} + +func (r *Resolver) processWorkItem(item workItem) workResult { + if r.downloadFn == nil { + // No downloader -- check if already installed. 
+ installPath := r.resolveInstallPath(item.depRef) + installed := installPath != "" + return workResult{item: item, installed: installed} + } + + key := item.depRef.GetUniqueKey() + r.mu.Lock() + alreadyDownloaded := r.downloadedPackages[key] + r.mu.Unlock() + if alreadyDownloaded { + return workResult{item: item, installed: true} + } + + parentChain := "" + if item.node != nil { + parentChain = item.node.GetAncestorChain() + } + parentPkg := "" + if item.parentNode != nil { + parentPkg = item.parentNode.Ref.UniqueKey + } + + result := r.downloadFn(item.depRef, r.apmModulesDir, parentChain, parentPkg) + if result == "" { + return workResult{item: item, installed: false, err: "download returned empty path"} + } + + r.mu.Lock() + r.downloadedPackages[key] = true + r.mu.Unlock() + return workResult{item: item, installed: true} +} + +// resolveInstallPath returns the installation path for a dependency, or "". +func (r *Resolver) resolveInstallPath(ref *depreference.DependencyReference) string { + key := ref.GetUniqueKey() + // Normalize: use last path segment as dir name. + parts := strings.Split(key, "/") + name := parts[len(parts)-1] + candidate := filepath.Join(r.apmModulesDir, name) + if _, err := os.Stat(candidate); err == nil { + return candidate + } + return "" +} + +// readApmYMLDeps reads dependency references from an apm.yml file using a +// minimal line-scanner (no external YAML library required). +func (r *Resolver) readApmYMLDeps(apmYMLPath string) []*depreference.DependencyReference { + data, err := os.ReadFile(apmYMLPath) + if err != nil { + return nil + } + return parseApmYMLDeps(string(data)) +} + +// parseApmYMLDeps extracts dependency strings from apm.yml content and parses +// each into a DependencyReference. 
+func parseApmYMLDeps(content string) []*depreference.DependencyReference { + var refs []*depreference.DependencyReference + inDeps := false + for _, rawLine := range strings.Split(content, "\n") { + line := strings.TrimRight(rawLine, "\r") + trimmed := strings.TrimSpace(line) + + if trimmed == "dependencies:" || trimmed == "devDependencies:" { + inDeps = true + continue + } + if inDeps { + // End of the deps section: a non-indented, non-list line. + if len(line) > 0 && line[0] != ' ' && line[0] != '\t' && trimmed != "" && !strings.HasPrefix(trimmed, "-") { + inDeps = false + continue + } + if strings.HasPrefix(trimmed, "-") { + raw := strings.TrimPrefix(trimmed, "-") + raw = strings.TrimSpace(raw) + // Strip inline comments. + if idx := strings.Index(raw, " #"); idx >= 0 { + raw = strings.TrimSpace(raw[:idx]) + } + // Strip surrounding quotes. + raw = strings.Trim(raw, `"'`) + if raw != "" { + ref, err := depreference.Parse(raw) + if err == nil { + refs = append(refs, ref) + } + } + } + } + } + return refs +} + +// detectCircularDependencies performs DFS cycle detection on the tree. +func (r *Resolver) detectCircularDependencies(tree *depgraph.DependencyTree) []depgraph.CircularRef { + var cycles []depgraph.CircularRef + visited := make(map[string]bool) + var currentPath []string + currentPathSet := make(map[string]bool) + + var dfs func(node *depgraph.DependencyNode) + dfs = func(node *depgraph.DependencyNode) { + nodeID := node.GetID() + uniqueKey := node.Ref.UniqueKey + + if currentPathSet[uniqueKey] { + // Cycle detected. + startIdx := -1 + for i, k := range currentPath { + if k == uniqueKey { + startIdx = i + break + } + } + if startIdx >= 0 { + cyclePath := append([]string{}, currentPath[startIdx:]...) 
+ cyclePath = append(cyclePath, uniqueKey) + cycles = append(cycles, depgraph.CircularRef{ + CyclePath: cyclePath, + DetectedAtDepth: node.Depth, + }) + } + return + } + + visited[nodeID] = true + currentPath = append(currentPath, uniqueKey) + currentPathSet[uniqueKey] = true + + for _, child := range node.Children { + childID := child.GetID() + if !visited[childID] || currentPathSet[child.Ref.UniqueKey] { + dfs(child) + } + } + + // Backtrack. + currentPath = currentPath[:len(currentPath)-1] + delete(currentPathSet, uniqueKey) + } + + for _, node := range tree.GetNodesAtDepth(1) { + if !visited[node.GetID()] { + currentPath = nil + currentPathSet = make(map[string]bool) + dfs(node) + } + } + return cycles +} + +// flattenDependencies flattens the tree using BFS breadth-first, first-wins +// conflict resolution (NPM hoisting). +func (r *Resolver) flattenDependencies(tree *depgraph.DependencyTree) *depgraph.FlatDependencyMap { + flat := depgraph.NewFlatDependencyMap() + seen := make(map[string]bool) + + for depth := 1; depth <= tree.MaxDepth; depth++ { + nodes := tree.GetNodesAtDepth(depth) + // Deterministic ordering. + sort.Slice(nodes, func(i, j int) bool { + return nodes[i].GetID() < nodes[j].GetID() + }) + for _, node := range nodes { + key := node.Ref.UniqueKey + if !seen[key] { + flat.AddDependency(node.Ref, false) + seen[key] = true + } else { + flat.AddDependency(node.Ref, true) + } + } + } + return flat +} diff --git a/internal/deps/downloadstrategies/strategies.go b/internal/deps/downloadstrategies/strategies.go new file mode 100644 index 0000000..dd534b4 --- /dev/null +++ b/internal/deps/downloadstrategies/strategies.go @@ -0,0 +1,693 @@ +// Package downloadstrategies implements the DownloadDelegate -- the +// backend-specific HTTP download logic for APM packages. +// +// Encapsulates resilient HTTP GET, GitHub Contents API, Azure DevOps, +// GitLab, Artifactory archive, and generic-host file download logic. 
+// The owning GitHubPackageDownloader creates a single DownloadDelegate +// and delegates all download operations to it (Facade/Delegate pattern). +// +// Migrated from: src/apm_cli/deps/download_strategies.py +package downloadstrategies + +import ( + "encoding/base64" + "encoding/json" + "fmt" + "io" + "math" + "math/rand" + "net/http" + "net/url" + "os" + "strconv" + "strings" + "time" + + "github.com/githubnext/apm/internal/core/auth" + "github.com/githubnext/apm/internal/models/depreference" + "github.com/githubnext/apm/internal/utils/githubhost" +) + +// HostProvider is the interface the DownloadDelegate requires from its owner. +// This avoids a circular package dependency on the github_downloader package. +type HostProvider interface { + // GithubToken returns the GitHub personal access token (may be empty). + GithubToken() string + // AdoToken returns the Azure DevOps PAT (may be empty). + AdoToken() string + // ArtifactoryToken returns the Artifactory bearer token (may be empty). + ArtifactoryToken() string + // GithubHost returns the configured GitHub host (may be empty for default). + GithubHost() string + // AuthResolver returns the authentication resolver. + AuthResolver() *auth.AuthResolver + // ResilientGet performs an HTTP GET with retry/rate-limit handling. + // Callers should treat a non-nil error as exhausted retries. + ResilientGet(reqURL string, headers map[string]string, timeoutSecs int) (*http.Response, error) +} + +// resolveToken extracts the token string from *string (nil -> ""). +func resolveToken(t *string) string { + if t == nil { + return "" + } + return *t +} + +// authResolve wraps AuthResolver.Resolve, handling the *int port parameter. 
+func authResolve(ar *auth.AuthResolver, host, org string, port int) (token, source string) { + var portPtr *int + if port != 0 { + portPtr = &port + } + ctx := ar.Resolve(host, org, portPtr) + if ctx == nil { + return "", "" + } + return resolveToken(ctx.Token), ctx.Source +} + +// DownloadDelegate encapsulates backend-specific download logic. +// +// Holds real implementations of HTTP resilient-get, URL building, and +// file download for GitHub, Azure DevOps, and Artifactory backends. +type DownloadDelegate struct { + host HostProvider +} + +// New creates a DownloadDelegate that delegates shared state to host. +func New(host HostProvider) *DownloadDelegate { + return &DownloadDelegate{host: host} +} + +// debug prints a message when APM_DEBUG is set. +func debug(msg string) { + if os.Getenv("APM_DEBUG") != "" { + fmt.Fprintf(os.Stderr, "[DEBUG] %s\n", msg) + } +} + +// --------------------------------------------------------------------------- +// HTTP resilient GET (standalone helper for callers without a HostProvider) +// --------------------------------------------------------------------------- + +// ResilientGet performs an HTTP GET with exponential-backoff retry on 429/503 +// and rate-limit header awareness. +// +// Returns the *http.Response and nil on success. If all retries are +// exhausted it returns the last response (which may be rate-limited) plus a +// non-nil error. 
+func ResilientGet(reqURL string, headers map[string]string, timeoutSecs, maxRetries int) (*http.Response, error) { + if timeoutSecs <= 0 { + timeoutSecs = 30 + } + if maxRetries <= 0 { + maxRetries = 3 + } + client := &http.Client{Timeout: time.Duration(timeoutSecs) * time.Second} + + var lastResp *http.Response + var lastErr error + + for attempt := 0; attempt < maxRetries; attempt++ { + req, err := http.NewRequest(http.MethodGet, reqURL, nil) + if err != nil { + return nil, fmt.Errorf("build request: %w", err) + } + for k, v := range headers { + req.Header.Set(k, v) + } + + resp, err := client.Do(req) + if err != nil { + lastErr = err + if attempt < maxRetries-1 { + wait := jitter(math.Pow(2, float64(attempt))) + debug(fmt.Sprintf("Connection error, retry in %.1fs (attempt %d/%d)", wait, attempt+1, maxRetries)) + time.Sleep(time.Duration(wait*float64(time.Second))) + } + continue + } + + // Rate limiting: 429, 503, or 403 with X-RateLimit-Remaining: 0. + isRateLimited := resp.StatusCode == 429 || resp.StatusCode == 503 + if !isRateLimited && resp.StatusCode == 403 { + if rem := resp.Header.Get("X-RateLimit-Remaining"); rem != "" { + if n, err := strconv.Atoi(rem); err == nil && n == 0 { + isRateLimited = true + } + } + } + + if isRateLimited { + lastResp = resp + wait := backoffFromRateLimitHeaders(resp, attempt) + debug(fmt.Sprintf("Rate limited (%d), retry in %.1fs (attempt %d/%d)", resp.StatusCode, wait, attempt+1, maxRetries)) + time.Sleep(time.Duration(wait * float64(time.Second))) + continue + } + + // Log rate-limit proximity. 
+ if rem := resp.Header.Get("X-RateLimit-Remaining"); rem != "" { + if n, err := strconv.Atoi(rem); err == nil && n < 10 { + debug(fmt.Sprintf("GitHub API rate limit low: %d requests remaining", n)) + } + } + return resp, nil + } + + if lastResp != nil { + return lastResp, fmt.Errorf("rate limit retries exhausted for %s", reqURL) + } + if lastErr != nil { + return nil, lastErr + } + return nil, fmt.Errorf("all %d attempts failed for %s", maxRetries, reqURL) +} + +func jitter(base float64) float64 { + if base > 30 { + base = 30 + } + return base * (0.5 + rand.Float64()) +} + +func backoffFromRateLimitHeaders(resp *http.Response, attempt int) float64 { + if ra := resp.Header.Get("Retry-After"); ra != "" { + if v, err := strconv.ParseFloat(ra, 64); err == nil { + if v < 60 { + return v + } + return 60 + } + } + if reset := resp.Header.Get("X-RateLimit-Reset"); reset != "" { + if ts, err := strconv.ParseInt(reset, 10, 64); err == nil { + wait := float64(ts) - float64(time.Now().Unix()) + if wait > 0 && wait < 60 { + return wait + } + } + } + return jitter(math.Pow(2, float64(attempt))) +} + +// --------------------------------------------------------------------------- +// Repository URL building +// --------------------------------------------------------------------------- + +// BuildRepoURLOptions controls how BuildRepoURL constructs its result. +type BuildRepoURLOptions struct { + RepoRef string + UseSSH bool + DepRef *depreference.DependencyReference + Token string + AuthScheme string // "basic" | "bearer" (default: "basic") +} + +// BuildRepoURL constructs the repository URL for git clone operations. +// Supports GitHub, Azure DevOps, GitLab, and generic hosts. 
+func (d *DownloadDelegate) BuildRepoURL(opts BuildRepoURLOptions) string { + var host string + if opts.DepRef != nil && opts.DepRef.Host != "" { + host = opts.DepRef.Host + } else if h := d.host.GithubHost(); h != "" { + host = h + } else { + host = githubhost.DefaultHost() + } + + token := opts.Token + if token == "" { + token = d.host.GithubToken() + } + + repoRef := opts.RepoRef + if opts.DepRef != nil && repoRef == "" { + repoRef = opts.DepRef.RepoURL + } + + var port int + if opts.DepRef != nil { + port = opts.DepRef.Port + } + + if opts.UseSSH { + return buildSSHURL(host, repoRef, port) + } + if token != "" { + return buildHTTPSCloneURL(host, repoRef, token, port) + } + return buildHTTPSCloneURL(host, repoRef, "", port) +} + +func buildSSHURL(host, repoRef string, port int) string { + if port != 0 { + return fmt.Sprintf("ssh://git@%s:%d/%s.git", host, port, repoRef) + } + return fmt.Sprintf("git@%s:%s.git", host, repoRef) +} + +func buildHTTPSCloneURL(host, repoRef, token string, port int) string { + var netloc string + if port != 0 { + netloc = fmt.Sprintf("%s:%d", host, port) + } else { + netloc = host + } + if token != "" { + return fmt.Sprintf("https://x-access-token:%s@%s/%s.git", token, netloc, repoRef) + } + return fmt.Sprintf("https://%s/%s.git", netloc, repoRef) +} + +// --------------------------------------------------------------------------- +// Artifactory helpers +// --------------------------------------------------------------------------- + +// GetArtifactoryHeaders returns HTTP headers for Artifactory requests. +func (d *DownloadDelegate) GetArtifactoryHeaders() map[string]string { + headers := make(map[string]string) + if tok := d.host.ArtifactoryToken(); tok != "" { + headers["Authorization"] = "Bearer " + tok + } + return headers +} + +// ArtifactoryDownloadResult holds the result of an Artifactory archive download. 
+type ArtifactoryDownloadResult struct { + Data []byte + Err error +} + +// DownloadArtifactoryArchive downloads an archive from Artifactory. +func (d *DownloadDelegate) DownloadArtifactoryArchive(archiveURL string) ArtifactoryDownloadResult { + headers := d.GetArtifactoryHeaders() + + resp, err := ResilientGet(archiveURL, headers, 120, 3) + if err != nil { + return ArtifactoryDownloadResult{Err: fmt.Errorf("artifactory archive download: %w", err)} + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return ArtifactoryDownloadResult{ + Err: fmt.Errorf("artifactory archive HTTP %d for %s", resp.StatusCode, archiveURL), + } + } + + data, err := io.ReadAll(resp.Body) + if err != nil { + return ArtifactoryDownloadResult{Err: fmt.Errorf("reading artifactory archive: %w", err)} + } + return ArtifactoryDownloadResult{Data: data} +} + +// DownloadFileFromArtifactory downloads a single file from Artifactory. +func (d *DownloadDelegate) DownloadFileFromArtifactory(fileURL string) ([]byte, error) { + headers := d.GetArtifactoryHeaders() + resp, err := ResilientGet(fileURL, headers, 30, 3) + if err != nil { + return nil, fmt.Errorf("artifactory file download: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("HTTP %d for %s", resp.StatusCode, fileURL) + } + return io.ReadAll(resp.Body) +} + +// --------------------------------------------------------------------------- +// Raw download (CDN fast-path for github.com) +// --------------------------------------------------------------------------- + +// TryRawDownload attempts to fetch a file via raw.githubusercontent.com. +// Returns nil if the file was not found or the request failed. 
+func (d *DownloadDelegate) TryRawDownload(owner, repo, ref, filePath string) []byte { + rawURL := fmt.Sprintf("https://raw.githubusercontent.com/%s/%s/%s/%s", owner, repo, ref, filePath) + resp, err := ResilientGet(rawURL, nil, 30, 2) + if err != nil { + return nil + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + return nil + } + data, err := io.ReadAll(resp.Body) + if err != nil { + return nil + } + return data +} + +// --------------------------------------------------------------------------- +// Azure DevOps file download +// --------------------------------------------------------------------------- + +// buildADOAPIURL constructs the Azure DevOps Items API URL for a file. +func buildADOAPIURL(org, project, repo, filePath, ref, host string) string { + if host == "" { + host = "dev.azure.com" + } + return fmt.Sprintf( + "https://%s/%s/%s/_apis/git/repositories/%s/items?path=%s&versionType=branch&version=%s&api-version=6.0", + host, url.PathEscape(org), url.PathEscape(project), url.PathEscape(repo), + url.QueryEscape(filePath), url.QueryEscape(ref), + ) +} + +func (d *DownloadDelegate) DownloadADOFile(depRef *depreference.DependencyReference, filePath, ref string) ([]byte, error) { + if depRef == nil { + return nil, fmt.Errorf("nil dep_ref for ADO download") + } + if depRef.ADOOrganization == "" || depRef.ADOProject == "" || depRef.ADORepo == "" { + return nil, fmt.Errorf( + "invalid ADO dep_ref: missing org/project/repo (got org=%q project=%q repo=%q)", + depRef.ADOOrganization, depRef.ADOProject, depRef.ADORepo, + ) + } + + host := depRef.Host + if host == "" { + host = "dev.azure.com" + } + apiURL := buildADOAPIURL(depRef.ADOOrganization, depRef.ADOProject, depRef.ADORepo, filePath, ref, host) + + headers := make(map[string]string) + if tok := d.host.AdoToken(); tok != "" { + authBytes := []byte(":" + tok) + headers["Authorization"] = "Basic " + base64.StdEncoding.EncodeToString(authBytes) + } + + resp, err := 
d.host.ResilientGet(apiURL, headers, 30) + if err != nil { + return nil, fmt.Errorf("ADO download network error: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode == http.StatusOK { + return io.ReadAll(resp.Body) + } + if resp.StatusCode == http.StatusNotFound { + if ref == "main" || ref == "master" { + fallbackRef := "master" + if ref == "master" { + fallbackRef = "main" + } + fallbackURL := buildADOAPIURL(depRef.ADOOrganization, depRef.ADOProject, depRef.ADORepo, filePath, fallbackRef, host) + resp2, err2 := d.host.ResilientGet(fallbackURL, headers, 30) + if err2 != nil { + return nil, fmt.Errorf("ADO fallback download failed: %w", err2) + } + defer resp2.Body.Close() + if resp2.StatusCode == http.StatusOK { + return io.ReadAll(resp2.Body) + } + return nil, fmt.Errorf("file not found: %s in %s (tried refs: %s, %s)", filePath, depRef.RepoURL, ref, fallbackRef) + } + return nil, fmt.Errorf("file not found: %s at ref %q in %s", filePath, ref, depRef.RepoURL) + } + if resp.StatusCode == http.StatusUnauthorized || resp.StatusCode == http.StatusForbidden { + return nil, fmt.Errorf("authentication failed for Azure DevOps %s", depRef.RepoURL) + } + return nil, fmt.Errorf("ADO download HTTP %d for %s", resp.StatusCode, apiURL) +} + +// --------------------------------------------------------------------------- +// GitLab file download +// --------------------------------------------------------------------------- + +// DownloadGitLabFile downloads a file via the GitLab REST v4 API. 
+func (d *DownloadDelegate) DownloadGitLabFile(depRef *depreference.DependencyReference, filePath, ref string) ([]byte, error) { + if depRef == nil { + return nil, fmt.Errorf("nil dep_ref for GitLab download") + } + host := depRef.Host + if host == "" { + host = githubhost.DefaultHost() + } + projectPath := depRef.RepoURL + if projectPath == "" { + return nil, fmt.Errorf("missing repository path for GitLab file download") + } + + ar := d.host.AuthResolver() + var token string + if ar != nil { + org := "" + parts := strings.SplitN(projectPath, "/", 2) + if len(parts) > 0 { + org = parts[0] + } + t, _ := authResolve(ar, host, org, depRef.Port) + token = t + } + + headers := map[string]string{} + if token != "" { + headers["PRIVATE-TOKEN"] = token + } + + enc := url.PathEscape(projectPath) + encFile := url.PathEscape(filePath) + encRef := url.QueryEscape(ref) + apiURL := fmt.Sprintf("https://%s/api/v4/projects/%s/repository/files/%s/raw?ref=%s", host, enc, encFile, encRef) + + resp, err := d.host.ResilientGet(apiURL, headers, 30) + if err != nil { + return nil, fmt.Errorf("GitLab download error: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode == http.StatusOK { + return io.ReadAll(resp.Body) + } + if resp.StatusCode == http.StatusNotFound { + // Try the other default branch. 
+ if ref == "main" || ref == "master" { + fallbackRef := "master" + if ref == "master" { + fallbackRef = "main" + } + encFallback := url.QueryEscape(fallbackRef) + fallbackURL := fmt.Sprintf("https://%s/api/v4/projects/%s/repository/files/%s/raw?ref=%s", host, enc, encFile, encFallback) + resp2, err2 := d.host.ResilientGet(fallbackURL, headers, 30) + if err2 == nil { + defer resp2.Body.Close() + if resp2.StatusCode == http.StatusOK { + return io.ReadAll(resp2.Body) + } + } + } + return nil, fmt.Errorf("file not found: %s at ref %q in %s", filePath, ref, projectPath) + } + return nil, fmt.Errorf("GitLab download HTTP %d", resp.StatusCode) +} + +// --------------------------------------------------------------------------- +// GitHub file download (Contents API) +// --------------------------------------------------------------------------- + +// DownloadGitHubFile downloads a file from a GitHub (or GHES/generic) repository. +func (d *DownloadDelegate) DownloadGitHubFile(depRef *depreference.DependencyReference, filePath, ref string) ([]byte, error) { + if depRef == nil { + return nil, fmt.Errorf("nil dep_ref for GitHub download") + } + host := depRef.Host + if host == "" { + host = githubhost.DefaultHost() + } + + parts := strings.SplitN(depRef.RepoURL, "/", 2) + if len(parts) < 2 { + return nil, fmt.Errorf("invalid repo_url %q: expected owner/repo", depRef.RepoURL) + } + owner, repo := parts[0], parts[1] + + ar := d.host.AuthResolver() + var token string + if ar != nil { + t, _ := authResolve(ar, host, owner, depRef.Port) + token = t + } + + isGitHubHost := githubhost.IsGitHubHostname(host) || d.isConfiguredGHES(host) + + // CDN fast-path for github.com without a token. + if strings.EqualFold(host, "github.com") && token == "" { + if data := d.TryRawDownload(owner, repo, ref, filePath); data != nil { + return data, nil + } + // Try alternate default branch. 
+ if ref == "main" || ref == "master" { + alt := "master" + if ref == "master" { + alt = "main" + } + if data := d.TryRawDownload(owner, repo, alt, filePath); data != nil { + return data, nil + } + } + // Fall through to Contents API. + } + + // For non-GitHub generic hosts: try raw URL first. + if !isGitHubHost { + rawURL := fmt.Sprintf("https://%s/%s/%s/raw/%s/%s", host, owner, repo, ref, filePath) + rawHeaders := d.buildGenericHostAuthHeaders(host, depRef, nil) + if resp, err := d.host.ResilientGet(rawURL, rawHeaders, 30); err == nil { + defer resp.Body.Close() + if resp.StatusCode == http.StatusOK { + return io.ReadAll(resp.Body) + } + } + } + + // Contents API path. + apiURLs := d.buildContentsAPIURLs(host, owner, repo, filePath, ref, isGitHubHost) + if len(apiURLs) == 0 { + return nil, fmt.Errorf("could not build Contents API URL for %s", depRef.RepoURL) + } + + var apiHeaders map[string]string + if isGitHubHost { + apiHeaders = map[string]string{"Accept": "application/vnd.github.v3.raw"} + if token != "" { + apiHeaders["Authorization"] = "token " + token + } + } else { + apiHeaders = d.buildGenericHostAuthHeaders(host, depRef, nil) + apiHeaders["Accept"] = "application/json" + } + + for _, apiURL := range apiURLs { + resp, err := d.host.ResilientGet(apiURL, apiHeaders, 30) + if err != nil { + continue + } + defer resp.Body.Close() + if resp.StatusCode == http.StatusOK { + return extractContentsAPIPayload(resp, isGitHubHost) + } + if resp.StatusCode == http.StatusNotFound { + continue + } + if resp.StatusCode == http.StatusUnauthorized || resp.StatusCode == http.StatusForbidden { + return nil, fmt.Errorf("authentication failed for %s/%s on %s", owner, repo, host) + } + } + + // Try alternate default branch as final fallback. 
+ if ref == "main" || ref == "master" { + alt := "master" + if ref == "master" { + alt = "main" + } + altURLs := d.buildContentsAPIURLs(host, owner, repo, filePath, alt, isGitHubHost) + for _, apiURL := range altURLs { + resp, err := d.host.ResilientGet(apiURL, apiHeaders, 30) + if err != nil { + continue + } + defer resp.Body.Close() + if resp.StatusCode == http.StatusOK { + return extractContentsAPIPayload(resp, isGitHubHost) + } + } + } + + return nil, fmt.Errorf("file not found: %s at ref %q in %s", filePath, ref, depRef.RepoURL) +} + +// buildContentsAPIURLs returns ordered API URL candidates for the given file. +func (d *DownloadDelegate) buildContentsAPIURLs(host, owner, repo, filePath, ref string, isGitHubHost bool) []string { + if isGitHubHost { + apiBase := "api.github.com" + if !strings.EqualFold(host, "github.com") { + apiBase = host + "/api/v3" + } + return []string{fmt.Sprintf("https://%s/repos/%s/%s/contents/%s?ref=%s", apiBase, owner, repo, filePath, url.QueryEscape(ref))} + } + // Generic host: try multiple API version paths. + return []string{ + fmt.Sprintf("https://%s/api/v1/repos/%s/%s/contents/%s?ref=%s", host, owner, repo, filePath, url.QueryEscape(ref)), + fmt.Sprintf("https://%s/api/v3/repos/%s/%s/contents/%s?ref=%s", host, owner, repo, filePath, url.QueryEscape(ref)), + } +} + +// buildGenericHostAuthHeaders builds auth headers for non-GitHub hosts. 
+func (d *DownloadDelegate) buildGenericHostAuthHeaders(host string, depRef *depreference.DependencyReference, accept *string) map[string]string { + headers := make(map[string]string) + if accept != nil { + headers["Accept"] = *accept + } + ar := d.host.AuthResolver() + if ar == nil { + return headers + } + var port int + org := "" + if depRef != nil { + port = depRef.Port + if parts := strings.SplitN(depRef.RepoURL, "/", 2); len(parts) > 0 { + org = parts[0] + } + } + token, src := authResolve(ar, host, org, port) + if token == "" { + return headers + } + // Only forward tokens for credential-helper-sourced or org-scoped sources, + // or explicitly configured GHES. + if src == "git-credential-fill" || strings.HasPrefix(src, "GITHUB_APM_PAT_") || d.isConfiguredGHES(host) { + headers["Authorization"] = "token " + token + } + return headers +} + +// isConfiguredGHES reports whether host is set as the configured GHES via GITHUB_HOST. +func (d *DownloadDelegate) isConfiguredGHES(host string) bool { + ghHost := strings.TrimSpace(os.Getenv("GITHUB_HOST")) + if ghHost == "" { + return false + } + return strings.EqualFold(ghHost, host) +} + +// extractContentsAPIPayload decodes a Contents-API response into raw bytes. +// +// GitHub family: returns response.Body bytes directly (vnd.github.v3.raw). +// Generic hosts (Gitea/Gogs): the server returns a JSON envelope +// {"content": "", "encoding": "base64"}. 
+func extractContentsAPIPayload(resp *http.Response, isGitHubHost bool) ([]byte, error) { + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + if isGitHubHost { + return body, nil + } + ct := strings.ToLower(resp.Header.Get("Content-Type")) + if !strings.Contains(ct, "json") && (len(body) == 0 || body[0] != '{') { + return body, nil + } + var payload map[string]interface{} + if err := json.Unmarshal(body, &payload); err != nil { + return body, nil + } + contentField, ok := payload["content"] + if !ok { + return body, nil + } + encoding, _ := payload["encoding"].(string) + contentStr, _ := contentField.(string) + if strings.ToLower(encoding) == "base64" { + decoded, err := base64.StdEncoding.DecodeString(strings.ReplaceAll(contentStr, "\n", "")) + if err != nil { + return body, nil + } + return decoded, nil + } + return []byte(contentStr), nil +} diff --git a/internal/deps/gitauthenv/gitauthenv.go b/internal/deps/gitauthenv/gitauthenv.go new file mode 100644 index 0000000..5dd4425 --- /dev/null +++ b/internal/deps/gitauthenv/gitauthenv.go @@ -0,0 +1,154 @@ +// Package gitauthenv builds the various git environment dicts the downloader needs. +// Migrated from src/apm_cli/deps/git_auth_env.py +// +// Three env flavours: +// 1. SetupEnvironment -- auth-bearing env for primary git ops +// 2. NoninteractiveEnv -- non-auth env for unauthenticated fallback +// 3. SubprocessEnvDict -- sanitized env for cache-layer subprocess calls +package gitauthenv + +import ( + "os" + "runtime" + "strings" +) + +// GitAuthEnvBuilder builds the various git env dicts the downloader needs. +type GitAuthEnvBuilder struct { + baseEnv map[string]string +} + +// New returns a new GitAuthEnvBuilder. +// baseEnv is the auth-bearing environment provided by the token manager +// (analogous to token_manager.setup_environment() in Python). 
+func New(baseEnv map[string]string) *GitAuthEnvBuilder { + return &GitAuthEnvBuilder{baseEnv: baseEnv} +} + +// SetupEnvironment builds the auth-bearing primary git env. +// Sets GIT_TERMINAL_PROMPT, GIT_ASKPASS, GIT_CONFIG_NOSYSTEM, +// GIT_SSH_COMMAND (with ConnectTimeout=30), and GIT_CONFIG_GLOBAL. +func (b *GitAuthEnvBuilder) SetupEnvironment() map[string]string { + env := copyEnv(b.baseEnv) + + env["GIT_TERMINAL_PROMPT"] = "0" + env["GIT_ASKPASS"] = "echo" + env["GIT_CONFIG_NOSYSTEM"] = "1" + + // Ensure SSH connections fail fast (30 s timeout). + const sshTimeout = "-o ConnectTimeout=30" + existingSSH := strings.TrimSpace(os.Getenv("GIT_SSH_COMMAND")) + if existingSSH != "" { + if !strings.Contains(strings.ToLower(existingSSH), "connecttimeout") { + env["GIT_SSH_COMMAND"] = existingSSH + " " + sshTimeout + } else { + env["GIT_SSH_COMMAND"] = existingSSH + } + } else { + env["GIT_SSH_COMMAND"] = "ssh " + sshTimeout + } + + if runtime.GOOS == "windows" { + // On Windows, point GIT_CONFIG_GLOBAL at an empty file. + tmpDir := os.TempDir() + emptyCfg := tmpDir + "\\.apm_empty_gitconfig" + // Create the empty file (ignore errors -- best-effort). + f, err := os.OpenFile(emptyCfg, os.O_CREATE|os.O_WRONLY, 0o644) + if err == nil { + f.Close() + } + env["GIT_CONFIG_GLOBAL"] = emptyCfg + } else { + env["GIT_CONFIG_GLOBAL"] = "/dev/null" + } + + return env +} + +// NoninteractiveEnvOptions controls the credential-helper suppression fence. +type NoninteractiveEnvOptions struct { + // PreserveConfigIsolation keeps GIT_CONFIG_NOSYSTEM and GIT_CONFIG_GLOBAL. + PreserveConfigIsolation bool + // SuppressCredentialHelpers applies the full credential-helper fence + // (use for HTTP transport to avoid leaking tokens in plaintext). + SuppressCredentialHelpers bool +} + +// NoninteractiveEnv builds a non-interactive git env for unauthenticated operations. +// +// Credential-helper policy (two-stage): +// 1. Always clear GIT_ASKPASS so system credential helpers resolve naturally. 
+// 2. Re-set the full suppression fence only when SuppressCredentialHelpers is true. +func NoninteractiveEnv(baseGitEnv map[string]string, opts NoninteractiveEnvOptions) map[string]string { + env := copyEnv(baseGitEnv) + + env["GIT_TERMINAL_PROMPT"] = "0" + delete(env, "GIT_ASKPASS") + + if opts.PreserveConfigIsolation || opts.SuppressCredentialHelpers { + env["GIT_CONFIG_NOSYSTEM"] = "1" + if v, ok := baseGitEnv["GIT_CONFIG_GLOBAL"]; ok { + env["GIT_CONFIG_GLOBAL"] = v + } + } else { + delete(env, "GIT_CONFIG_GLOBAL") + delete(env, "GIT_CONFIG_NOSYSTEM") + } + + if opts.SuppressCredentialHelpers { + env["GIT_ASKPASS"] = "echo" + env["GIT_CONFIG_COUNT"] = "1" + env["GIT_CONFIG_KEY_0"] = "credential.helper" + env["GIT_CONFIG_VALUE_0"] = "" + } else { + delete(env, "GIT_CONFIG_COUNT") + delete(env, "GIT_CONFIG_KEY_0") + delete(env, "GIT_CONFIG_VALUE_0") + } + + return env +} + +// SubprocessEnvDict returns a sanitized git env dict for cache-layer subprocess calls. +// Merges the auth-aware baseGitEnv over a sanitized ambient env so the subprocess +// never inherits a stray GIT_DIR or GIT_CEILING_DIRECTORIES. +func SubprocessEnvDict(baseGitEnv map[string]string) map[string]string { + env := gitSubprocessEnv() + for k, v := range baseGitEnv { + env[k] = v + } + return env +} + +// gitSubprocessEnv returns the current process environment with git-state variables +// stripped so cache-layer subprocess calls start with a clean slate. 
+func gitSubprocessEnv() map[string]string { + stripKeys := map[string]bool{ + "GIT_DIR": true, + "GIT_CEILING_DIRECTORIES": true, + "GIT_WORK_TREE": true, + "GIT_INDEX_FILE": true, + "GIT_OBJECT_DIRECTORY": true, + "GIT_ALTERNATE_OBJECT_DIRECTORIES": true, + } + env := make(map[string]string) + for _, kv := range os.Environ() { + idx := strings.IndexByte(kv, '=') + if idx < 0 { + continue + } + k, v := kv[:idx], kv[idx+1:] + if !stripKeys[k] { + env[k] = v + } + } + return env +} + +func copyEnv(src map[string]string) map[string]string { + dst := make(map[string]string, len(src)) + for k, v := range src { + dst[k] = v + } + return dst +} diff --git a/internal/deps/hostbackends/hostbackends.go b/internal/deps/hostbackends/hostbackends.go new file mode 100644 index 0000000..3624037 --- /dev/null +++ b/internal/deps/hostbackends/hostbackends.go @@ -0,0 +1,371 @@ +// Package hostbackends provides vendor-specific URL/API construction for remote git hosts. +// Migrated from src/apm_cli/deps/host_backends.py. +// +// Each supported host kind is a concrete backend struct implementing the HostBackend interface. +// A dispatch function (BackendFor / BackendForHost) picks the right backend by consulting +// the auth package's ClassifyHost function. +package hostbackends + +import ( + "fmt" + "net/url" + "regexp" + "strings" + + "github.com/githubnext/apm/internal/core/auth" + "github.com/githubnext/apm/internal/utils/githubhost" +) + +var sha40RE = regexp.MustCompile(`^[a-f0-9]{40}$`) + +// DepRef is the minimal interface expected of a dependency reference by backend URL builders. +type DepRef interface { + // GetHost returns the host string for this dependency (may be ""). + GetHost() string + // GetPort returns the non-standard port, or nil if default. + GetPort() *int + // GetRepoURL returns the "owner/repo" URL string. + GetRepoURL() string + // GetADOOrganization returns the ADO organisation name, or "". 
+ GetADOOrganization() string + // GetADOProject returns the ADO project name, or "". + GetADOProject() string + // GetADORepo returns the ADO repo name, or "". + GetADORepo() string + // IsAzureDevOps returns true when this dep references Azure DevOps. + IsAzureDevOps() bool + // IsInsecure returns true when the dep was declared with a plain HTTP URL. + IsInsecure() bool +} + +// HostBackend exposes URL/API construction for one remote git host kind. +type HostBackend interface { + // Kind returns a canonical host-kind string: "github", "ghe_cloud", "ghes", "ado", "gitlab", "generic". + Kind() string + // IsGitHubFamily returns true for github.com, *.ghe.com, and configured GHES hosts. + IsGitHubFamily() bool + // IsGeneric returns true for non-GitHub-family non-ADO hosts (GitLab, Bitbucket, Gitea, ...). + IsGeneric() bool + // GetHostInfo returns the HostInfo for this backend. + GetHostInfo() auth.HostInfo + + // BuildCloneHTTPSURL builds the HTTPS clone URL. + // token may be "" (anonymous / bearer), non-empty embeds credentials. + // authScheme "bearer" suppresses embedding the token in the URL. + BuildCloneHTTPSURL(dep DepRef, token string, authScheme string) string + // BuildCloneSSHURL builds the SSH clone URL. + BuildCloneSSHURL(dep DepRef) string + // BuildCloneHTTPURL builds a plain HTTP clone URL (only for is_insecure deps). + BuildCloneHTTPURL(dep DepRef) (string, error) + // BuildCommitsAPIURL returns the cheap commit-resolution API URL, or "" when unavailable. + BuildCommitsAPIURL(dep DepRef, ref string) string + // BuildContentsAPIURLs returns ordered Contents-API URL candidates for fetching a file. 
+ BuildContentsAPIURLs(owner, repo, filePath, ref string) []string +} + +// --------------------------------------------------------------------------- +// URL builder helpers (mirror Python's github_host.py helpers) +// --------------------------------------------------------------------------- + +func buildHTTPSCloneURL(host, repoURL, token string, port *int) string { + if token != "" { + // embed as https://x-access-token:@host/owner/repo.git + netloc := netloc(host, port) + return fmt.Sprintf("https://x-access-token:%s@%s/%s.git", url.PathEscape(token), netloc, repoURL) + } + netloc := netloc(host, port) + return fmt.Sprintf("https://%s/%s.git", netloc, repoURL) +} + +func buildSSHURL(host, repoURL string, port *int) string { + if port != nil { + return fmt.Sprintf("ssh://git@%s:%d/%s.git", host, *port, repoURL) + } + return fmt.Sprintf("git@%s:%s.git", host, repoURL) +} + +func buildADOHTTPSCloneURL(org, project, repo, host, token string) string { + if host == "" { + host = "dev.azure.com" + } + base := fmt.Sprintf("https://%s/%s/%s/_git/%s", host, org, project, repo) + if token != "" { + base = fmt.Sprintf("https://%s@%s/%s/%s/_git/%s", token, host, org, project, repo) + } + return base +} + +func buildADOSSHURL(org, project, repo string) string { + return fmt.Sprintf("git@ssh.dev.azure.com:v3/%s/%s/%s", org, project, repo) +} + +func buildGitLabHTTPSCloneURL(host, repoURL, token string, port *int) string { + netloc := netloc(host, port) + if token != "" { + return fmt.Sprintf("https://oauth2:%s@%s/%s.git", url.PathEscape(token), netloc, repoURL) + } + return fmt.Sprintf("https://%s/%s.git", netloc, repoURL) +} + +func netloc(host string, port *int) string { + if port != nil { + return fmt.Sprintf("%s:%d", host, *port) + } + return host +} + +func urlHost(dep DepRef, fallback auth.HostInfo) string { + h := dep.GetHost() + if h != "" { + return h + } + return fallback.Host +} + +// --------------------------------------------------------------------------- +// 
GitHub-family shared base +// --------------------------------------------------------------------------- + +type gitHubFamilyBase struct { + hostInfo auth.HostInfo + kind string +} + +func (b *gitHubFamilyBase) Kind() string { return b.kind } +func (b *gitHubFamilyBase) IsGitHubFamily() bool { return true } +func (b *gitHubFamilyBase) IsGeneric() bool { return false } +func (b *gitHubFamilyBase) GetHostInfo() auth.HostInfo { return b.hostInfo } + +func (b *gitHubFamilyBase) BuildCloneHTTPSURL(dep DepRef, token string, authScheme string) string { + host := urlHost(dep, b.hostInfo) + port := dep.GetPort() + if authScheme == "bearer" { + token = "" + } + return buildHTTPSCloneURL(host, dep.GetRepoURL(), token, port) +} + +func (b *gitHubFamilyBase) BuildCloneSSHURL(dep DepRef) string { + host := urlHost(dep, b.hostInfo) + return buildSSHURL(host, dep.GetRepoURL(), dep.GetPort()) +} + +func (b *gitHubFamilyBase) BuildCloneHTTPURL(dep DepRef) (string, error) { + host := urlHost(dep, b.hostInfo) + port := dep.GetPort() + n := netloc(host, port) + return fmt.Sprintf("http://%s/%s.git", n, dep.GetRepoURL()), nil +} + +func (b *gitHubFamilyBase) BuildCommitsAPIURL(dep DepRef, ref string) string { + if sha40RE.MatchString(strings.ToLower(ref)) { + return "" + } + parts := strings.SplitN(dep.GetRepoURL(), "/", 2) + if len(parts) != 2 { + return "" + } + return fmt.Sprintf("%s/repos/%s/%s/commits/%s", b.hostInfo.APIBase, parts[0], parts[1], ref) +} + +func (b *gitHubFamilyBase) BuildContentsAPIURLs(owner, repo, filePath, ref string) []string { + return []string{ + fmt.Sprintf("%s/repos/%s/%s/contents/%s?ref=%s", b.hostInfo.APIBase, owner, repo, filePath, ref), + } +} + +// --------------------------------------------------------------------------- +// Concrete backends +// --------------------------------------------------------------------------- + +// GitHubBackend is the backend for github.com. 
+type GitHubBackend struct{ gitHubFamilyBase } + +// GHECloudBackend is the backend for *.ghe.com (GitHub Enterprise Cloud -- Data Residency). +type GHECloudBackend struct{ gitHubFamilyBase } + +// GHESBackend is the backend for self-hosted GitHub Enterprise Server. +type GHESBackend struct{ gitHubFamilyBase } + +// ADOBackend is the backend for Azure DevOps. +type ADOBackend struct { + hostInfo auth.HostInfo +} + +func (b *ADOBackend) Kind() string { return "ado" } +func (b *ADOBackend) IsGitHubFamily() bool { return false } +func (b *ADOBackend) IsGeneric() bool { return false } +func (b *ADOBackend) GetHostInfo() auth.HostInfo { return b.hostInfo } + +func (b *ADOBackend) BuildCloneHTTPSURL(dep DepRef, token string, authScheme string) string { + if dep.GetADOOrganization() == "" { + // Missing org -- return a diagnostic URL so callers can surface the error. + return "error://ado-missing-org" + } + host := urlHost(dep, b.hostInfo) + if host == "" { + host = "dev.azure.com" + } + return buildADOHTTPSCloneURL(dep.GetADOOrganization(), dep.GetADOProject(), dep.GetADORepo(), host, token) +} + +func (b *ADOBackend) BuildCloneSSHURL(dep DepRef) string { + return buildADOSSHURL(dep.GetADOOrganization(), dep.GetADOProject(), dep.GetADORepo()) +} + +func (b *ADOBackend) BuildCloneHTTPURL(_ DepRef) (string, error) { + return "", fmt.Errorf("Azure DevOps does not support plain HTTP cloning; use HTTPS or SSH") +} + +func (b *ADOBackend) BuildCommitsAPIURL(_ DepRef, _ string) string { return "" } + +func (b *ADOBackend) BuildContentsAPIURLs(_, _, _, _ string) []string { return nil } + +// GitLabBackend is the backend for GitLab (gitlab.com and self-managed instances). 
+type GitLabBackend struct { + hostInfo auth.HostInfo +} + +func (b *GitLabBackend) Kind() string { return "gitlab" } +func (b *GitLabBackend) IsGitHubFamily() bool { return false } +func (b *GitLabBackend) IsGeneric() bool { return true } +func (b *GitLabBackend) GetHostInfo() auth.HostInfo { return b.hostInfo } + +func (b *GitLabBackend) BuildCloneHTTPSURL(dep DepRef, token string, authScheme string) string { + host := urlHost(dep, b.hostInfo) + port := dep.GetPort() + if token != "" && authScheme != "bearer" { + return buildGitLabHTTPSCloneURL(host, dep.GetRepoURL(), token, port) + } + return buildHTTPSCloneURL(host, dep.GetRepoURL(), "", port) +} + +func (b *GitLabBackend) BuildCloneSSHURL(dep DepRef) string { + host := urlHost(dep, b.hostInfo) + return buildSSHURL(host, dep.GetRepoURL(), dep.GetPort()) +} + +func (b *GitLabBackend) BuildCloneHTTPURL(dep DepRef) (string, error) { + host := urlHost(dep, b.hostInfo) + n := netloc(host, dep.GetPort()) + return fmt.Sprintf("http://%s/%s.git", n, dep.GetRepoURL()), nil +} + +func (b *GitLabBackend) BuildCommitsAPIURL(dep DepRef, ref string) string { + if sha40RE.MatchString(strings.ToLower(ref)) { + return "" + } + proj := url.PathEscape(dep.GetRepoURL()) + return fmt.Sprintf("%s/projects/%s/repository/commits/%s", b.hostInfo.APIBase, proj, ref) +} + +func (b *GitLabBackend) BuildContentsAPIURLs(owner, repo, filePath, ref string) []string { + proj := url.PathEscape(owner + "/" + repo) + f := url.PathEscape(filePath) + return []string{ + fmt.Sprintf("%s/projects/%s/repository/files/%s/raw?ref=%s", b.hostInfo.APIBase, proj, f, ref), + } +} + +// GenericGitBackend is the backend for non-GitHub/non-ADO/non-GitLab hosts (Gitea/Gogs/Bitbucket, ...). 
+type GenericGitBackend struct { + hostInfo auth.HostInfo +} + +func (b *GenericGitBackend) Kind() string { return "generic" } +func (b *GenericGitBackend) IsGitHubFamily() bool { return false } +func (b *GenericGitBackend) IsGeneric() bool { return true } +func (b *GenericGitBackend) GetHostInfo() auth.HostInfo { return b.hostInfo } + +func (b *GenericGitBackend) BuildCloneHTTPSURL(dep DepRef, token string, authScheme string) string { + host := urlHost(dep, b.hostInfo) + port := dep.GetPort() + if authScheme == "bearer" { + token = "" + } + return buildHTTPSCloneURL(host, dep.GetRepoURL(), token, port) +} + +func (b *GenericGitBackend) BuildCloneSSHURL(dep DepRef) string { + host := urlHost(dep, b.hostInfo) + return buildSSHURL(host, dep.GetRepoURL(), dep.GetPort()) +} + +func (b *GenericGitBackend) BuildCloneHTTPURL(dep DepRef) (string, error) { + host := urlHost(dep, b.hostInfo) + n := netloc(host, dep.GetPort()) + return fmt.Sprintf("http://%s/%s.git", n, dep.GetRepoURL()), nil +} + +func (b *GenericGitBackend) BuildCommitsAPIURL(_ DepRef, _ string) string { return "" } + +func (b *GenericGitBackend) BuildContentsAPIURLs(owner, repo, filePath, ref string) []string { + host := b.hostInfo.Host + return []string{ + fmt.Sprintf("https://%s/api/v1/repos/%s/%s/contents/%s?ref=%s", host, owner, repo, filePath, ref), + fmt.Sprintf("https://%s/api/v3/repos/%s/%s/contents/%s?ref=%s", host, owner, repo, filePath, ref), + } +} + +// --------------------------------------------------------------------------- +// Dispatch +// --------------------------------------------------------------------------- + +// BackendFor picks the right HostBackend for a DepRef. +// Falls back to GenericGitBackend when the host kind cannot be classified. 
+func BackendFor(dep DepRef, fallbackHost string) HostBackend { + var host string + var port *int + if dep != nil && dep.GetHost() != "" { + host = dep.GetHost() + port = dep.GetPort() + } else { + if fallbackHost != "" { + host = fallbackHost + } else { + host = githubhost.DefaultHost() + } + } + + // ADO short-circuit + if dep != nil && dep.IsAzureDevOps() { + info := auth.ClassifyHost(host, port) + return &ADOBackend{hostInfo: info} + } + + info := auth.ClassifyHost(host, port) + return backendFromInfo(info) +} + +// BackendForHost picks the right HostBackend for a bare hostname. +func BackendForHost(host string, port *int) HostBackend { + info := auth.ClassifyHost(host, port) + return backendFromInfo(info) +} + +func backendFromInfo(info auth.HostInfo) HostBackend { + base := gitHubFamilyBase{hostInfo: info} + switch info.Kind { + case "github": + base.kind = "github" + return &GitHubBackend{base} + case "ghe_cloud": + base.kind = "ghe_cloud" + return &GHECloudBackend{base} + case "ghes": + base.kind = "ghes" + return &GHESBackend{base} + case "ado": + return &ADOBackend{hostInfo: info} + case "gitlab": + return &GitLabBackend{hostInfo: info} + default: + return &GenericGitBackend{hostInfo: info} + } +} + +// Ensure ADOBackend satisfies a narrower interface for compile-time check. +var _ interface { + BuildCloneSSHURL(dep DepRef) string + GetHostInfo() auth.HostInfo +} = (*ADOBackend)(nil) diff --git a/internal/deps/lockfile/lockfile.go b/internal/deps/lockfile/lockfile.go new file mode 100644 index 0000000..b16a369 --- /dev/null +++ b/internal/deps/lockfile/lockfile.go @@ -0,0 +1,679 @@ +// Package lockfile provides APM lock file structures for reproducible installs. +// +// Migrated from src/apm_cli/deps/lockfile.py +package lockfile + +import ( +"bufio" +"fmt" +"os" +"path/filepath" +"sort" +"strconv" +"strings" +"time" +) + +const ( +LockfileName = "apm.lock.yaml" +LegacyLockfileName = "apm.lock" +selfKey = "." 
+) + +// LockedDependency represents a resolved dependency with exact version info. +type LockedDependency struct { +RepoURL string +Host string +Port int // 0 = unset +RegistryPrefix string +ResolvedCommit string +ResolvedRef string +Version string +VirtualPath string +IsVirtual bool +Depth int +ResolvedBy string +PackageType string +DeployedFiles []string +DeployedFileHashes map[string]string +Source string // "local" for local deps +LocalPath string +ContentHash string +IsDev bool +DiscoveredVia string +MarketplacePluginName string +IsInsecure bool +AllowInsecure bool +SkillSubset []string +} + +// GetUniqueKey returns the unique key for this dependency. +func (d *LockedDependency) GetUniqueKey() string { +if d.Source == "local" && d.LocalPath != "" { +return d.LocalPath +} +if d.IsVirtual && d.VirtualPath != "" { +return d.RepoURL + "/" + d.VirtualPath +} +return d.RepoURL +} + +// ToDict serializes the dependency to a string map for YAML output. +func (d *LockedDependency) ToDict() map[string]interface{} { +result := map[string]interface{}{"repo_url": d.RepoURL} +if d.Host != "" { +result["host"] = d.Host +} +if d.Port != 0 { +result["port"] = d.Port +} +if d.RegistryPrefix != "" { +result["registry_prefix"] = d.RegistryPrefix +} +if d.ResolvedCommit != "" { +result["resolved_commit"] = d.ResolvedCommit +} +if d.ResolvedRef != "" { +result["resolved_ref"] = d.ResolvedRef +} +if d.Version != "" { +result["version"] = d.Version +} +if d.VirtualPath != "" { +result["virtual_path"] = d.VirtualPath +} +if d.IsVirtual { +result["is_virtual"] = true +} +if d.Depth != 1 { +result["depth"] = d.Depth +} +if d.ResolvedBy != "" { +result["resolved_by"] = d.ResolvedBy +} +if d.PackageType != "" { +result["package_type"] = d.PackageType +} +if len(d.DeployedFiles) > 0 { +sorted := append([]string{}, d.DeployedFiles...) 
+sort.Strings(sorted) +result["deployed_files"] = sorted +} +if len(d.DeployedFileHashes) > 0 { +result["deployed_file_hashes"] = sortedMapCopy(d.DeployedFileHashes) +} +if d.Source != "" { +result["source"] = d.Source +} +if d.LocalPath != "" { +result["local_path"] = d.LocalPath +} +if d.ContentHash != "" { +result["content_hash"] = d.ContentHash +} +if d.IsDev { +result["is_dev"] = true +} +if d.DiscoveredVia != "" { +result["discovered_via"] = d.DiscoveredVia +} +if d.MarketplacePluginName != "" { +result["marketplace_plugin_name"] = d.MarketplacePluginName +} +if d.IsInsecure { +result["is_insecure"] = true +} +if d.AllowInsecure { +result["allow_insecure"] = true +} +if len(d.SkillSubset) > 0 { +sorted := append([]string{}, d.SkillSubset...) +sort.Strings(sorted) +result["skill_subset"] = sorted +} +return result +} + +// LockedDepFromMap deserializes a LockedDependency from a parsed YAML map. +func LockedDepFromMap(data map[string]interface{}) (*LockedDependency, error) { +repoURL, ok := data["repo_url"].(string) +if !ok || repoURL == "" { +return nil, fmt.Errorf("missing repo_url") +} + +deployedFiles := strSlice(data["deployed_files"]) +// Migrate legacy deployed_skills -> deployed_files +if oldSkills := strSlice(data["deployed_skills"]); len(oldSkills) > 0 && len(deployedFiles) == 0 { +for _, sk := range oldSkills { +deployedFiles = append(deployedFiles, ".github/skills/"+sk+"/") +deployedFiles = append(deployedFiles, ".claude/skills/"+sk+"/") +} +} + +var port int +if pRaw, ok := data["port"]; ok && pRaw != nil { +switch v := pRaw.(type) { +case int: +if v >= 1 && v <= 65535 { +port = v +} +case float64: +p := int(v) +if p >= 1 && p <= 65535 { +port = p +} +case string: +if p, err := strconv.Atoi(v); err == nil && p >= 1 && p <= 65535 { +port = p +} +} +} + +dep := &LockedDependency{ +RepoURL: repoURL, +Host: strVal(data["host"]), +Port: port, +RegistryPrefix: strVal(data["registry_prefix"]), +ResolvedCommit: strVal(data["resolved_commit"]), 
+ResolvedRef: strVal(data["resolved_ref"]), +Version: strVal(data["version"]), +VirtualPath: strVal(data["virtual_path"]), +IsVirtual: boolVal(data["is_virtual"]), +Depth: intVal(data["depth"], 1), +ResolvedBy: strVal(data["resolved_by"]), +PackageType: strVal(data["package_type"]), +DeployedFiles: deployedFiles, +DeployedFileHashes: strMap(data["deployed_file_hashes"]), +Source: strVal(data["source"]), +LocalPath: strVal(data["local_path"]), +ContentHash: strVal(data["content_hash"]), +IsDev: boolVal(data["is_dev"]), +DiscoveredVia: strVal(data["discovered_via"]), +MarketplacePluginName: strVal(data["marketplace_plugin_name"]), +IsInsecure: boolVal(data["is_insecure"]), +AllowInsecure: boolVal(data["allow_insecure"]), +SkillSubset: strSlice(data["skill_subset"]), +} +return dep, nil +} + +// LockFile represents an APM lock file. +type LockFile struct { +LockfileVersion string +GeneratedAt string +APMVersion string +Dependencies map[string]*LockedDependency +MCPServers []string +MCPConfigs map[string]map[string]interface{} +LocalDeployedFiles []string +LocalDeployedFileHashes map[string]string +} + +// NewLockFile creates a new empty LockFile. +func NewLockFile() *LockFile { +return &LockFile{ +LockfileVersion: "1", +GeneratedAt: time.Now().UTC().Format(time.RFC3339), +Dependencies: make(map[string]*LockedDependency), +MCPConfigs: make(map[string]map[string]interface{}), +LocalDeployedFileHashes: make(map[string]string), +} +} + +// AddDependency adds a dependency to the lock file. +func (lf *LockFile) AddDependency(dep *LockedDependency) { +lf.Dependencies[dep.GetUniqueKey()] = dep +} + +// GetDependency returns a dependency by key. +func (lf *LockFile) GetDependency(key string) *LockedDependency { +return lf.Dependencies[key] +} + +// HasDependency checks if a dependency exists. +func (lf *LockFile) HasDependency(key string) bool { +_, ok := lf.Dependencies[key] +return ok +} + +// GetAllDependencies returns all dependencies sorted by depth then repo_url. 
+func (lf *LockFile) GetAllDependencies() []*LockedDependency { +deps := make([]*LockedDependency, 0, len(lf.Dependencies)) +for _, d := range lf.Dependencies { +deps = append(deps, d) +} +sort.Slice(deps, func(i, j int) bool { +if deps[i].Depth != deps[j].Depth { +return deps[i].Depth < deps[j].Depth +} +return deps[i].RepoURL < deps[j].RepoURL +}) +return deps +} + +// GetPackageDependencies returns all dependencies excluding the virtual self-entry. +func (lf *LockFile) GetPackageDependencies() []*LockedDependency { +var result []*LockedDependency +for _, d := range lf.GetAllDependencies() { +if d.LocalPath != "." { +result = append(result, d) +} +} +return result +} + +// IsSemanticalllyEquivalent returns true if other has the same deps/MCP/configs. +func (lf *LockFile) IsSemanticalllyEquivalent(other *LockFile) bool { +if lf.LockfileVersion != other.LockfileVersion { +return false +} +if len(lf.Dependencies) != len(other.Dependencies) { +return false +} +for key, dep := range lf.Dependencies { +od, ok := other.Dependencies[key] +if !ok { +return false +} +if fmt.Sprint(dep.ToDict()) != fmt.Sprint(od.ToDict()) { +return false +} +} +// MCP servers +as := append([]string{}, lf.MCPServers...) +bs := append([]string{}, other.MCPServers...) +sort.Strings(as) +sort.Strings(bs) +if strings.Join(as, ",") != strings.Join(bs, ",") { +return false +} +if fmt.Sprint(lf.MCPConfigs) != fmt.Sprint(other.MCPConfigs) { +return false +} +af := append([]string{}, lf.LocalDeployedFiles...) +bf := append([]string{}, other.LocalDeployedFiles...) +sort.Strings(af) +sort.Strings(bf) +if strings.Join(af, ",") != strings.Join(bf, ",") { +return false +} +return fmt.Sprint(sortedMapCopy(lf.LocalDeployedFileHashes)) == fmt.Sprint(sortedMapCopy(other.LocalDeployedFileHashes)) +} + +// FromYAML parses a LockFile from a simple line-by-line YAML reader. +// This is a minimal parser for the known lockfile schema. 
// FromYAML tokenizes the lockfile into lines and walks them with a small
// hand-rolled state machine: top-level scalar keys are read directly, while
// the `dependencies`, `mcp_servers`, `local_deployed_files`, and
// `local_deployed_file_hashes` sections each get a dedicated inner loop.
// Unknown top-level keys are skipped one line at a time.
func FromYAML(content string) (*LockFile, error) {
	lf := NewLockFile()
	scanner := bufio.NewScanner(strings.NewReader(content))
	var lines []string
	for scanner.Scan() {
		lines = append(lines, scanner.Text())
	}
	// NOTE(review): scanner.Err() is not checked; for in-memory strings a scan
	// error can only come from an over-long line — confirm acceptable.

	// Simple state machine parser
	i := 0
	for i < len(lines) {
		line := lines[i]
		trimmed := strings.TrimSpace(line)

		if strings.HasPrefix(trimmed, "lockfile_version:") {
			lf.LockfileVersion = yamlValue(trimmed)
			i++
		} else if strings.HasPrefix(trimmed, "generated_at:") {
			lf.GeneratedAt = yamlValue(trimmed)
			i++
		} else if strings.HasPrefix(trimmed, "apm_version:") {
			lf.APMVersion = yamlValue(trimmed)
			i++
		} else if trimmed == "dependencies:" {
			i++
			// Parse list of dependency maps
			for i < len(lines) {
				dl := lines[i]
				dtrimmed := strings.TrimSpace(dl)
				// A list item must start with "- repo_url:" (the writer always
				// emits repo_url first) or a bare "-".
				if strings.HasPrefix(dtrimmed, "- repo_url:") || dtrimmed == "-" {
					depMap, n := parseYAMLMap(lines, i)
					i += n
					dep, err := LockedDepFromMap(depMap)
					// Entries that fail to deserialize (e.g. missing repo_url)
					// are silently dropped rather than aborting the whole parse.
					if err == nil {
						lf.AddDependency(dep)
					}
				} else if !strings.HasPrefix(dl, " ") && !strings.HasPrefix(dl, "\t") && dl != "" {
					// Unindented non-empty line: next top-level key.
					break
				} else {
					i++
				}
			}
		} else if trimmed == "mcp_servers:" {
			i++
			for i < len(lines) {
				sl := strings.TrimSpace(lines[i])
				if strings.HasPrefix(sl, "- ") {
					lf.MCPServers = append(lf.MCPServers, strings.TrimPrefix(sl, "- "))
					i++
				} else if sl == "" || !strings.HasPrefix(lines[i], " ") {
					// Blank line or dedent ends the list.
					break
				} else {
					i++
				}
			}
		} else if trimmed == "local_deployed_files:" {
			i++
			for i < len(lines) {
				sl := strings.TrimSpace(lines[i])
				if strings.HasPrefix(sl, "- ") {
					lf.LocalDeployedFiles = append(lf.LocalDeployedFiles, strings.TrimPrefix(sl, "- "))
					i++
				} else if sl == "" || !strings.HasPrefix(lines[i], " ") {
					break
				} else {
					i++
				}
			}
		} else if trimmed == "local_deployed_file_hashes:" {
			i++
			// Indented "path: hash" pairs until the first unindented or
			// non-key-value line.
			for i < len(lines) {
				kl := lines[i]
				ktrimmed := strings.TrimSpace(kl)
				if strings.HasPrefix(lines[i], " ") && strings.Contains(ktrimmed, ":") {
					parts := strings.SplitN(ktrimmed, ":", 2)
					if len(parts) == 2 {
						// Strip optional single/double quotes around key and value.
						k := strings.Trim(strings.TrimSpace(parts[0]), `"'`)
						v := strings.Trim(strings.TrimSpace(parts[1]), `"'`)
						lf.LocalDeployedFileHashes[k] = v
					}
					i++
				} else {
					break
				}
			}
		} else {
			i++
		}
	}

	// Synthesize self-entry: when the lockfile recorded locally deployed
	// files, represent the project itself as a virtual "." dependency so the
	// rest of the pipeline can treat it uniformly.
	if len(lf.LocalDeployedFiles) > 0 {
		lf.Dependencies[selfKey] = &LockedDependency{
			RepoURL:            "",
			Source:             "local",
			LocalPath:          ".",
			IsDev:              true,
			Depth:              0,
			DeployedFiles:      append([]string{}, lf.LocalDeployedFiles...),
			DeployedFileHashes: copyStrMap(lf.LocalDeployedFileHashes),
		}
	}

	return lf, nil
}

// GetLockfilePath returns the path to the lock file for a project.
func GetLockfilePath(projectRoot string) string {
	return filepath.Join(projectRoot, LockfileName)
}

// MigrateLockfileIfNeeded renames legacy apm.lock to apm.lock.yaml.
// Returns true only when a rename actually happened; the migration is a
// no-op when the new file already exists or the legacy file is absent.
func MigrateLockfileIfNeeded(projectRoot string) bool {
	newPath := GetLockfilePath(projectRoot)
	legacyPath := filepath.Join(projectRoot, LegacyLockfileName)
	if _, err := os.Stat(newPath); os.IsNotExist(err) {
		if _, err2 := os.Stat(legacyPath); err2 == nil {
			if err3 := os.Rename(legacyPath, newPath); err3 == nil {
				return true
			}
		}
	}
	return false
}

// ReadLockfile reads a lock file from disk.
func ReadLockfile(path string) (*LockFile, error) {
	data, err := os.ReadFile(path)
	if err != nil {
		return nil, err
	}
	return FromYAML(string(data))
}

// LoadOrCreate loads a lock file or creates a new one.
// Any read or parse failure yields a fresh empty lockfile (best-effort).
func LoadOrCreate(path string) *LockFile {
	lf, err := ReadLockfile(path)
	if err != nil || lf == nil {
		return NewLockFile()
	}
	return lf
}

// --- YAML parsing helpers ---

// parseYAMLMap parses a YAML list item (map) starting at lines[start].
// Returns the map and the number of lines consumed.
+func parseYAMLMap(lines []string, start int) (map[string]interface{}, int) { +result := make(map[string]interface{}) +i := start + +// Consume leading "- " prefix on first line +firstLine := strings.TrimSpace(lines[i]) +if strings.HasPrefix(firstLine, "- ") { +kv := strings.TrimPrefix(firstLine, "- ") +if strings.Contains(kv, ":") { +parts := strings.SplitN(kv, ":", 2) +k := strings.TrimSpace(parts[0]) +v := strings.TrimSpace(parts[1]) +result[k] = unquote(v) +} +i++ +} else if firstLine == "-" { +i++ +} + +// indent of the block items +blockIndent := "" +for i < len(lines) { +line := lines[i] +if strings.TrimSpace(line) == "" { +i++ +continue +} +// Detect indentation +for _, c := range line { +if c == ' ' { +blockIndent += " " +} else { +break +} +} +break +} +if blockIndent == "" { +blockIndent = " " +} + +for i < len(lines) { +line := lines[i] +trimmed := strings.TrimSpace(line) + +if trimmed == "" { +i++ +continue +} +// End of this map item +if strings.HasPrefix(trimmed, "- ") || (!strings.HasPrefix(line, blockIndent) && !strings.HasPrefix(line, " ")) { +break +} +// Nested list +if strings.Contains(trimmed, ":") { +parts := strings.SplitN(trimmed, ":", 2) +key := strings.TrimSpace(parts[0]) +val := strings.TrimSpace(parts[1]) +if val == "" { +// collect sub-list or sub-map +i++ +var subList []string +subMap := make(map[string]interface{}) +isList := false +for i < len(lines) { +sl := lines[i] +strimmed := strings.TrimSpace(sl) +if strimmed == "" { +i++ +continue +} +if !strings.HasPrefix(sl, blockIndent) { +break +} +if strings.HasPrefix(strimmed, "- ") { +isList = true +subList = append(subList, strings.TrimPrefix(strimmed, "- ")) +i++ +} else if strings.Contains(strimmed, ":") { +kp := strings.SplitN(strimmed, ":", 2) +sk := strings.TrimSpace(kp[0]) +sv := strings.Trim(strings.TrimSpace(kp[1]), `"'`) +subMap[sk] = sv +i++ +} else { +break +} +} +if isList { +result[key] = subList +} else { +result[key] = subMap +} +continue +} +result[key] = 
parseScalar(val) +i++ +} else { +i++ +} +} +return result, i - start +} + +func yamlValue(line string) string { +idx := strings.Index(line, ":") +if idx < 0 { +return "" +} +return strings.Trim(strings.TrimSpace(line[idx+1:]), `"'`) +} + +func unquote(s string) interface{} { +s = strings.TrimSpace(s) +if s == "" { +return nil +} +return parseScalar(s) +} + +func parseScalar(s string) interface{} { +s = strings.Trim(s, `"'`) +if s == "true" { +return true +} +if s == "false" { +return false +} +if s == "null" || s == "~" { +return nil +} +if n, err := strconv.Atoi(s); err == nil { +return n +} +if f, err := strconv.ParseFloat(s, 64); err == nil { +return f +} +return s +} + +// --- type coercion helpers --- + +func strVal(v interface{}) string { +if v == nil { +return "" +} +if s, ok := v.(string); ok { +return s +} +return fmt.Sprint(v) +} + +func boolVal(v interface{}) bool { +if v == nil { +return false +} +b, ok := v.(bool) +return ok && b +} + +func intVal(v interface{}, def int) int { +if v == nil { +return def +} +switch n := v.(type) { +case int: +return n +case float64: +return int(n) +} +return def +} + +func strSlice(v interface{}) []string { +if v == nil { +return nil +} +switch s := v.(type) { +case []string: +return s +case []interface{}: +result := make([]string, 0, len(s)) +for _, item := range s { +result = append(result, strVal(item)) +} +return result +} +return nil +} + +func strMap(v interface{}) map[string]string { +if v == nil { +return make(map[string]string) +} +switch m := v.(type) { +case map[string]string: +return m +case map[string]interface{}: +result := make(map[string]string, len(m)) +for k, val := range m { +result[k] = strVal(val) +} +return result +case map[interface{}]interface{}: +result := make(map[string]string, len(m)) +for k, val := range m { +result[strVal(k)] = strVal(val) +} +return result +} +return make(map[string]string) +} + +func sortedMapCopy(m map[string]string) map[string]string { +result := make(map[string]string, 
len(m)) +for k, v := range m { +result[k] = v +} +return result +} + +func copyStrMap(m map[string]string) map[string]string { +result := make(map[string]string, len(m)) +for k, v := range m { +result[k] = v +} +return result +} diff --git a/internal/deps/lockfile/lockfile_test.go b/internal/deps/lockfile/lockfile_test.go new file mode 100644 index 0000000..863f558 --- /dev/null +++ b/internal/deps/lockfile/lockfile_test.go @@ -0,0 +1,119 @@ +package lockfile + +import ( +"testing" +) + +const sampleYAML = `lockfile_version: "1" +generated_at: 2026-01-01T00:00:00Z +apm_version: "1.0.0" +dependencies: + - repo_url: https://github.com/owner/repo + resolved_commit: abc123 + depth: 1 + is_dev: false + - repo_url: https://github.com/owner/repo2 + resolved_commit: def456 + depth: 2 + is_dev: true +mcp_servers: + - my-server +local_deployed_files: + - .github/copilot-instructions.md +` + +func TestFromYAMLBasic(t *testing.T) { +lf, err := FromYAML(sampleYAML) +if err != nil { +t.Fatalf("FromYAML error: %v", err) +} +if lf.LockfileVersion != "1" { +t.Errorf("expected version 1, got %s", lf.LockfileVersion) +} +if lf.APMVersion != "1.0.0" { +t.Errorf("expected APMVersion 1.0.0, got %s", lf.APMVersion) +} +// 2 real deps + 1 self-entry (local_deployed_files not empty) +if !lf.HasDependency("https://github.com/owner/repo") { +t.Error("expected dep1") +} +if !lf.HasDependency("https://github.com/owner/repo2") { +t.Error("expected dep2") +} +if !lf.HasDependency(".") { +t.Error("expected self entry from local_deployed_files") +} +if len(lf.MCPServers) != 1 || lf.MCPServers[0] != "my-server" { +t.Errorf("unexpected mcp_servers: %v", lf.MCPServers) +} +} + +func TestNewLockFile(t *testing.T) { +lf := NewLockFile() +if lf.LockfileVersion != "1" { +t.Errorf("expected version 1") +} +if lf.GeneratedAt == "" { +t.Error("expected non-empty generated_at") +} +} + +func TestAddGetDependency(t *testing.T) { +lf := NewLockFile() +dep := &LockedDependency{ +RepoURL: 
"https://github.com/foo/bar", +Depth: 1, +} +lf.AddDependency(dep) +got := lf.GetDependency("https://github.com/foo/bar") +if got == nil { +t.Error("expected dependency") +} +if got.RepoURL != dep.RepoURL { +t.Errorf("repo_url mismatch") +} +} + +func TestGetAllDependenciesSorted(t *testing.T) { +lf := NewLockFile() +lf.AddDependency(&LockedDependency{RepoURL: "b", Depth: 2}) +lf.AddDependency(&LockedDependency{RepoURL: "a", Depth: 1}) +lf.AddDependency(&LockedDependency{RepoURL: "c", Depth: 1}) +deps := lf.GetAllDependencies() +if deps[0].RepoURL != "a" || deps[1].RepoURL != "c" || deps[2].RepoURL != "b" { +t.Errorf("unexpected order: %v", func() []string { +var s []string +for _, d := range deps { +s = append(s, d.RepoURL) +} +return s +}()) +} +} + +func TestGetLockfilePath(t *testing.T) { +p := GetLockfilePath("/project") +if p == "" { +t.Error("expected non-empty path") +} +} + +func TestLockedDepToDict(t *testing.T) { +dep := &LockedDependency{ +RepoURL: "https://example.com/repo", +ResolvedCommit: "abc", +Depth: 1, +IsDev: true, +} +d := dep.ToDict() +if d["repo_url"] != "https://example.com/repo" { +t.Error("repo_url mismatch") +} +if d["is_dev"] != true { +t.Error("is_dev should be true") +} +// depth == 1 should not be emitted +if _, ok := d["depth"]; ok { +t.Error("depth=1 should be omitted") +} +} diff --git a/internal/deps/pluginparser/pluginparser.go b/internal/deps/pluginparser/pluginparser.go new file mode 100644 index 0000000..1fd9076 --- /dev/null +++ b/internal/deps/pluginparser/pluginparser.go @@ -0,0 +1,450 @@ +// Package pluginparser parses Claude Code plugin.json manifests and +// synthesises apm.yml files from plugin directory layouts. +// +// Migrated from: src/apm_cli/deps/plugin_parser.py +package pluginparser + +import ( + "encoding/json" + "fmt" + "io/fs" + "log" + "os" + "path/filepath" + "strings" +) + +// PluginManifest holds the optional metadata from plugin.json. 
type PluginManifest struct {
	// Name is the plugin's declared name; callers fall back to the plugin
	// directory name when it is empty.
	Name string `json:"name"`
	// MCPServers may be a dict of configs, a single file-path string, or a
	// list of file paths, so it is kept raw and decoded by extractMCPServers.
	MCPServers json.RawMessage `json:"mcpServers,omitempty"`
	Agents     []string        `json:"agents,omitempty"`
	Skills     []string        `json:"skills,omitempty"`
	Commands   []string        `json:"commands,omitempty"`
	// Hooks is kept raw; its shape is plugin-defined and not interpreted here.
	Hooks json.RawMessage `json:"hooks,omitempty"`
	// Extra is not populated by encoding/json (no tag); reserved for callers
	// that decode unrecognized manifest keys generically.
	Extra map[string]json.RawMessage
}

// MCPServerConfig holds a single MCP server configuration.
type MCPServerConfig struct {
	Command string            `json:"command,omitempty"` // stdio launch command
	Args    []string          `json:"args,omitempty"`
	URL     string            `json:"url,omitempty"`  // remote server endpoint
	Type    string            `json:"type,omitempty"` // transport hint: http/sse/streamable-http
	Env     map[string]string `json:"env,omitempty"`
	Headers map[string]string `json:"headers,omitempty"`
	Tools   []string          `json:"tools,omitempty"`
}

// MCPDepEntry is a dependency entry generated from an MCP server config.
type MCPDepEntry struct {
	Name      string
	Transport string
	Command   string
	Args      []string
	URL       string
	Headers   map[string]string
	Env       map[string]string
	Tools     []string
	Registry  bool
}

// ParsePluginManifest parses a plugin.json file at the given path.
// Returns the parsed manifest or an error.
func ParsePluginManifest(pluginJSONPath string) (*PluginManifest, error) {
	// Read directly and classify the error instead of Stat-then-Read: avoids
	// a TOCTOU window between the existence check and the read.
	data, err := os.ReadFile(pluginJSONPath)
	if err != nil {
		if os.IsNotExist(err) {
			return nil, fmt.Errorf("plugin.json not found: %s", pluginJSONPath)
		}
		return nil, fmt.Errorf("failed to read plugin.json: %w", err)
	}
	var manifest PluginManifest
	if err := json.Unmarshal(data, &manifest); err != nil {
		return nil, fmt.Errorf("invalid JSON in plugin.json: %w", err)
	}
	if manifest.Name == "" {
		// Non-fatal: NormalizePluginDirectory substitutes the directory name.
		log.Printf("plugin.json at %s is missing 'name' field; falling back to directory name", pluginJSONPath)
	}
	return &manifest, nil
}

// NormalizePluginDirectory normalises a Claude plugin directory into an APM package.
//
// Works with or without plugin.json. Returns the path to the generated apm.yml.
+func NormalizePluginDirectory(pluginPath string, pluginJSONPath string) (string, error) { + var manifest *PluginManifest + + if pluginJSONPath != "" { + if _, err := os.Stat(pluginJSONPath); err == nil { + m, err2 := ParsePluginManifest(pluginJSONPath) + if err2 != nil { + // Treat as empty manifest; fall back to dir-name defaults + m = &PluginManifest{} + } + manifest = m + } + } + + if manifest == nil { + manifest = &PluginManifest{} + } + if manifest.Name == "" { + manifest.Name = filepath.Base(pluginPath) + } + + return SynthesizeApmYMLFromPlugin(pluginPath, manifest) +} + +// SynthesizeApmYMLFromPlugin synthesises apm.yml from plugin metadata. +func SynthesizeApmYMLFromPlugin(pluginPath string, manifest *PluginManifest) (string, error) { + if manifest.Name == "" { + manifest.Name = filepath.Base(pluginPath) + } + + // Create .apm directory structure + apmDir := filepath.Join(pluginPath, ".apm") + if err := os.MkdirAll(apmDir, 0o755); err != nil { + return "", fmt.Errorf("failed to create .apm directory: %w", err) + } + + // Map plugin structure into .apm/ subdirectories + if err := mapPluginArtifacts(pluginPath, apmDir, manifest); err != nil { + return "", err + } + + // Extract MCP servers + mcpServers, err := extractMCPServers(pluginPath, manifest) + if err != nil { + log.Printf("failed to extract MCP servers from plugin %s: %v", pluginPath, err) + } + + var mcpDeps []MCPDepEntry + if len(mcpServers) > 0 { + mcpDeps = mcpServersToDeps(mcpServers, pluginPath) + } + + // Generate apm.yml + content := generateApmYML(manifest, mcpDeps) + apmYMLPath := filepath.Join(pluginPath, "apm.yml") + if err2 := os.WriteFile(apmYMLPath, []byte(content), 0o644); err2 != nil { + return "", fmt.Errorf("failed to write apm.yml: %w", err2) + } + + return apmYMLPath, nil +} + +// extractMCPServers reads MCP server definitions from the plugin manifest. 
// The mcpServers manifest value is polymorphic: absent (auto-discover a
// conventional .mcp.json), an inline dict of configs, a single file-path
// string, or a list of file paths. Each shape is tried in that order and
// unsupported shapes are logged and ignored rather than failing the install.
func extractMCPServers(pluginPath string, manifest *PluginManifest) (map[string]MCPServerConfig, error) {
	logger := log.Default()

	if manifest.MCPServers == nil {
		// Fall back to auto-discovery
		servers := map[string]MCPServerConfig{}
		for _, candidate := range []string{".mcp.json", filepath.Join(".github", ".mcp.json")} {
			fullPath := filepath.Join(pluginPath, candidate)
			info, err := os.Lstat(fullPath)
			// Symlinks are rejected so a plugin cannot point discovery at a
			// file outside its own tree.
			if err == nil && info.Mode()&fs.ModeSymlink == 0 && info.Mode().IsRegular() {
				s, err2 := readMCPJSON(fullPath)
				if err2 == nil && len(s) > 0 {
					// First non-empty candidate wins.
					servers = s
					break
				}
			}
		}
		if len(servers) > 0 {
			return substitutePlaceholder(servers, pluginPath, logger), nil
		}
		return servers, nil
	}

	// Determine type of mcpServers value
	raw := manifest.MCPServers
	var servers map[string]MCPServerConfig

	// Try dict
	if err := json.Unmarshal(raw, &servers); err == nil {
		return substitutePlaceholder(servers, pluginPath, logger), nil
	}

	// Try string (file path)
	var strVal string
	if err := json.Unmarshal(raw, &strVal); err == nil {
		s, err2 := readMCPFile(pluginPath, strVal)
		if err2 != nil {
			// Unreadable referenced file: log and return an empty set (nil error).
			logger.Printf("MCP file read failed: %v", err2)
			return map[string]MCPServerConfig{}, nil
		}
		return substitutePlaceholder(s, pluginPath, logger), nil
	}

	// Try array of string paths
	var arrVal []string
	if err := json.Unmarshal(raw, &arrVal); err == nil {
		result := map[string]MCPServerConfig{}
		for _, entry := range arrVal {
			s, err2 := readMCPFile(pluginPath, entry)
			if err2 != nil {
				// Skip the bad entry but keep merging the rest.
				logger.Printf("MCP file read failed: %v", err2)
				continue
			}
			// Later files override earlier ones on server-name collision.
			for k, v := range s {
				result[k] = v
			}
		}
		return substitutePlaceholder(result, pluginPath, logger), nil
	}

	logger.Printf("unsupported mcpServers type in plugin %s", pluginPath)
	return map[string]MCPServerConfig{}, nil
}

// readMCPFile reads a JSON file at relPath relative to pluginPath and returns its mcpServers dict.
+func readMCPFile(pluginPath, relPath string) (map[string]MCPServerConfig, error) { + absPlug, _ := filepath.Abs(pluginPath) + target := filepath.Join(absPlug, relPath) + absTarget, err := filepath.Abs(target) + if err != nil { + return nil, fmt.Errorf("invalid path: %s", relPath) + } + // Security: must stay inside pluginPath + if !strings.HasPrefix(absTarget, absPlug+string(os.PathSeparator)) { + return nil, fmt.Errorf("MCP file path escapes plugin root: %s", relPath) + } + info, err := os.Lstat(absTarget) + if err != nil || !info.Mode().IsRegular() { + return nil, fmt.Errorf("MCP file not found or invalid: %s", absTarget) + } + return readMCPJSON(absTarget) +} + +// readMCPJSON parses a JSON file and returns the mcpServers dict. +func readMCPJSON(path string) (map[string]MCPServerConfig, error) { + data, err := os.ReadFile(path) + if err != nil { + return nil, err + } + var wrapper struct { + MCPServers map[string]MCPServerConfig `json:"mcpServers"` + } + if err2 := json.Unmarshal(data, &wrapper); err2 != nil { + return nil, err2 + } + if wrapper.MCPServers == nil { + return map[string]MCPServerConfig{}, nil + } + return wrapper.MCPServers, nil +} + +// substitutePlaceholder replaces ${CLAUDE_PLUGIN_ROOT} in string values. 
+func substitutePlaceholder(servers map[string]MCPServerConfig, pluginPath string, _ *log.Logger) map[string]MCPServerConfig { + absRoot, _ := filepath.Abs(pluginPath) + placeholder := "${CLAUDE_PLUGIN_ROOT}" + + replaceStr := func(s string) string { + return strings.ReplaceAll(s, placeholder, absRoot) + } + + result := make(map[string]MCPServerConfig, len(servers)) + for name, cfg := range servers { + cfg.Command = replaceStr(cfg.Command) + cfg.URL = replaceStr(cfg.URL) + newArgs := make([]string, len(cfg.Args)) + for i, a := range cfg.Args { + newArgs[i] = replaceStr(a) + } + cfg.Args = newArgs + if cfg.Env != nil { + newEnv := make(map[string]string, len(cfg.Env)) + for k, v := range cfg.Env { + newEnv[k] = replaceStr(v) + } + cfg.Env = newEnv + } + result[name] = cfg + } + return result +} + +// mcpServersToDeps converts raw MCP server configs to dependency dicts. +func mcpServersToDeps(servers map[string]MCPServerConfig, pluginPath string) []MCPDepEntry { + var deps []MCPDepEntry + for name, cfg := range servers { + dep := MCPDepEntry{Name: name, Registry: false} + if cfg.Command != "" { + dep.Transport = "stdio" + dep.Command = cfg.Command + dep.Args = cfg.Args + } else if cfg.URL != "" { + transport := cfg.Type + validTransports := map[string]bool{"http": true, "sse": true, "streamable-http": true} + if !validTransports[transport] { + transport = "http" + } + dep.Transport = transport + dep.URL = cfg.URL + dep.Headers = cfg.Headers + } else { + log.Printf("skipping MCP server %q from plugin %q: no 'command' or 'url'", name, filepath.Base(pluginPath)) + continue + } + dep.Env = cfg.Env + dep.Tools = cfg.Tools + deps = append(deps, dep) + } + return deps +} + +// mapPluginArtifacts copies plugin components to .apm/ subdirectories. 
+func mapPluginArtifacts(pluginPath, apmDir string, manifest *PluginManifest) error { + type mapping struct { + src string + dst string + isDir bool + } + + // Standard component mappings + componentMappings := []mapping{ + {"agents", filepath.Join(apmDir, "agents"), true}, + {"skills", filepath.Join(apmDir, "skills"), true}, + {"commands", filepath.Join(apmDir, "prompts"), true}, + {"hooks", filepath.Join(apmDir, "hooks"), true}, + } + + for _, m := range componentMappings { + srcPath := filepath.Join(pluginPath, m.src) + info, err := os.Lstat(srcPath) + if err != nil || info.Mode()&fs.ModeSymlink != 0 { + continue + } + if !info.IsDir() { + continue + } + // Verify path is within plugin root + abs, _ := filepath.Abs(srcPath) + absPlugin, _ := filepath.Abs(pluginPath) + if !strings.HasPrefix(abs, absPlugin+string(os.PathSeparator)) { + continue + } + if err2 := copyDir(srcPath, m.dst); err2 != nil { + log.Printf("warning: failed to copy %s to %s: %v", srcPath, m.dst, err2) + } + } + + // Pass-through files + passthroughs := []string{".mcp.json", ".lsp.json", "settings.json"} + for _, fname := range passthroughs { + src := filepath.Join(pluginPath, fname) + info, err := os.Lstat(src) + if err != nil || info.Mode()&fs.ModeSymlink != 0 || !info.Mode().IsRegular() { + continue + } + dst := filepath.Join(apmDir, fname) + if err2 := copyFile(src, dst); err2 != nil { + log.Printf("warning: failed to copy %s: %v", fname, err2) + } + } + + return nil +} + +// copyFile copies a single regular file. +func copyFile(src, dst string) error { + data, err := os.ReadFile(src) + if err != nil { + return err + } + return os.WriteFile(dst, data, 0o644) +} + +// copyDir recursively copies a directory. 
+func copyDir(src, dst string) error { + if err := os.MkdirAll(dst, 0o755); err != nil { + return err + } + entries, err := os.ReadDir(src) + if err != nil { + return err + } + for _, entry := range entries { + srcPath := filepath.Join(src, entry.Name()) + dstPath := filepath.Join(dst, entry.Name()) + // Skip symlinks + info, err2 := os.Lstat(srcPath) + if err2 != nil || info.Mode()&fs.ModeSymlink != 0 { + continue + } + if entry.IsDir() { + if err3 := copyDir(srcPath, dstPath); err3 != nil { + log.Printf("warning: copyDir %s: %v", srcPath, err3) + } + } else { + if err3 := copyFile(srcPath, dstPath); err3 != nil { + log.Printf("warning: copyFile %s: %v", srcPath, err3) + } + } + } + return nil +} + +// generateApmYML generates the apm.yml content from plugin metadata. +func generateApmYML(manifest *PluginManifest, mcpDeps []MCPDepEntry) string { + var sb strings.Builder + sb.WriteString("# Generated by APM from Claude plugin\n") + sb.WriteString("name: ") + sb.WriteString(yamlString(manifest.Name)) + sb.WriteString("\n\n") + + if len(mcpDeps) > 0 { + sb.WriteString("dependencies:\n mcp:\n") + for _, dep := range mcpDeps { + sb.WriteString(" - name: ") + sb.WriteString(yamlString(dep.Name)) + sb.WriteString("\n registry: false\n") + sb.WriteString(" transport: ") + sb.WriteString(dep.Transport) + sb.WriteString("\n") + if dep.Command != "" { + sb.WriteString(" command: ") + sb.WriteString(yamlString(dep.Command)) + sb.WriteString("\n") + if len(dep.Args) > 0 { + sb.WriteString(" args:\n") + for _, a := range dep.Args { + sb.WriteString(" - ") + sb.WriteString(yamlString(a)) + sb.WriteString("\n") + } + } + } + if dep.URL != "" { + sb.WriteString(" url: ") + sb.WriteString(dep.URL) + sb.WriteString("\n") + } + if len(dep.Env) > 0 { + sb.WriteString(" env:\n") + for k, v := range dep.Env { + sb.WriteString(" ") + sb.WriteString(k) + sb.WriteString(": ") + sb.WriteString(yamlString(v)) + sb.WriteString("\n") + } + } + } + } + + return sb.String() +} + +// yamlString 
wraps a string in quotes if needed. +func yamlString(s string) string { + if strings.ContainsAny(s, ":{}[]|>&*!,#?@`\"'\\") || + strings.Contains(s, " ") || + strings.Contains(s, "\n") { + escaped := strings.ReplaceAll(s, `"`, `\"`) + return `"` + escaped + `"` + } + return s +} diff --git a/internal/install/drift/drift.go b/internal/install/drift/drift.go new file mode 100644 index 0000000..df4d892 --- /dev/null +++ b/internal/install/drift/drift.go @@ -0,0 +1,187 @@ +// Package drift provides pure drift-detection helpers for diff-aware apm install. +// These functions are stateless and side-effect-free. +// Migrated from src/apm_cli/drift.py +package drift + +// DependencyRef is a minimal interface for dependency references. +// Implementations provide the fields compared during drift detection. +type DependencyRef interface { + // Reference returns the git ref pinned in apm.yml (may be ""). + Reference() string + // UniqueKey returns the canonical deduplication key (repo_url or repo_url/virtual_path). + UniqueKey() string + // IsInsecure returns true when the dep was declared with an insecure HTTP URL. + IsInsecure() bool + // Host returns the registry proxy host when set, or "". + Host() string + // ArtifactoryPrefix returns the Artifactory prefix when set, or "". + ArtifactoryPrefix() string +} + +// LockedDep is a minimal interface for lockfile dependency entries. +type LockedDep interface { + // ResolvedRef returns the ref recorded in the lockfile. + ResolvedRef() string + // ResolvedCommit returns the commit SHA recorded in the lockfile (may be ""). + ResolvedCommit() string + // DeployedFiles returns the list of deployed file paths. + DeployedFiles() []string + // IsInsecure returns the stored insecure flag. + IsInsecure() bool + // AllowInsecure returns the stored allow_insecure flag. + AllowInsecure() bool + // RegistryPrefix returns the Artifactory prefix (may be ""). + RegistryPrefix() string + // Host returns the locked host (may be ""). 
+ Host() string +} + +// LockFile is a minimal interface for lockfile operations. +type LockFile interface { + // Dependencies returns all locked dependencies keyed by unique key. + Dependencies() map[string]LockedDep + // GetDependency returns the locked entry for the given unique key (nil if absent). + GetDependency(uniqueKey string) LockedDep +} + +// RefChangeResult holds the outcome of DetectRefChange. +type RefChangeResult struct { + Changed bool +} + +// DetectRefChange reports whether the manifest ref differs from the locked resolved_ref. +// +// Returns true for transitions: ref added ("" -> "v1.0.0"), +// ref removed ("main" -> ""), ref changed ("v1.0.0" -> "v2.0.0"), +// or HTTP-insecure flag toggle. +// +// Returns false when updateRefs is true (--update mode), when lockedDep is nil +// (new package), or when the ref is unchanged. +func DetectRefChange(depRef DependencyRef, lockedDep LockedDep, updateRefs bool) bool { + if updateRefs { + return false + } + if lockedDep == nil { + return false + } + if depRef.Reference() != lockedDep.ResolvedRef() { + return true + } + return depRef.IsInsecure() != lockedDep.IsInsecure() +} + +// DetectOrphans returns the set of deployed file paths whose owning package +// left the manifest. +// +// Only relevant for full installs (onlyPackages empty). Partial installs +// preserve all existing lockfile entries unchanged. +func DetectOrphans(existing LockFile, intendedDepKeys map[string]struct{}, onlyPackages []string) map[string]struct{} { + orphaned := map[string]struct{}{} + if len(onlyPackages) > 0 || existing == nil { + return orphaned + } + for depKey, dep := range existing.Dependencies() { + if _, ok := intendedDepKeys[depKey]; !ok { + for _, f := range dep.DeployedFiles() { + orphaned[f] = struct{}{} + } + } + } + return orphaned +} + +// DetectStaleFiles returns the set of paths that were deployed previously +// but are no longer produced by the current install. 
// DetectStaleFiles returns the deployed paths that disappeared between
// two install runs: every entry of oldDeployed absent from newDeployed.
// Pure set difference; duplicates collapse into the returned set.
func DetectStaleFiles(oldDeployed, newDeployed []string) map[string]struct{} {
	current := make(map[string]struct{}, len(newDeployed))
	for _, path := range newDeployed {
		current[path] = struct{}{}
	}
	stale := map[string]struct{}{}
	for _, path := range oldDeployed {
		if _, stillDeployed := current[path]; stillDeployed {
			continue
		}
		stale[path] = struct{}{}
	}
	return stale
}

// DetectConfigDrift returns the names of entries whose current config no
// longer matches the stored baseline. Entries without a stored baseline
// (brand-new) are never reported.
func DetectConfigDrift(currentConfigs, storedConfigs map[string]interface{}) map[string]struct{} {
	drifted := map[string]struct{}{}
	for name, current := range currentConfigs {
		baseline, hasBaseline := storedConfigs[name]
		if hasBaseline && !configsEqual(current, baseline) {
			drifted[name] = struct{}{}
		}
	}
	return drifted
}

// configsEqual deep-compares two decoded-JSON config values. Maps and
// slices are compared structurally; everything else falls back to Go
// equality.
func configsEqual(a, b interface{}) bool {
	if a == nil || b == nil {
		return a == nil && b == nil
	}
	switch av := a.(type) {
	case map[string]interface{}:
		bv, ok := b.(map[string]interface{})
		if !ok || len(av) != len(bv) {
			return false
		}
		for key, va := range av {
			vb, present := bv[key]
			if !present || !configsEqual(va, vb) {
				return false
			}
		}
		return true
	case []interface{}:
		bv, ok := b.([]interface{})
		if !ok || len(av) != len(bv) {
			return false
		}
		for i, va := range av {
			if !configsEqual(va, bv[i]) {
				return false
			}
		}
		return true
	default:
		return a == b
	}
}

// DownloadRefOptions controls BuildDownloadRef behavior.
type DownloadRefOptions struct {
	UpdateRefs bool
	RefChanged bool
}
// MCPServerSpec represents a single MCP server entry from .mcp.json.
type MCPServerSpec struct {
	Name      string
	Transport string
	Command   string
	Args      []string
	Env       map[string]string
	URL       string
	Registry  bool
	Raw       map[string]interface{}
}

// findMCPJSON locates a ".mcp.json" file (matched case-insensitively)
// directly inside bundleDir and returns its full path, or "" when the
// directory is unreadable or contains no such file.
func findMCPJSON(bundleDir string) string {
	entries, err := os.ReadDir(bundleDir)
	if err != nil {
		return ""
	}
	for _, entry := range entries {
		if entry.IsDir() {
			continue
		}
		if strings.EqualFold(entry.Name(), ".mcp.json") {
			return filepath.Join(bundleDir, entry.Name())
		}
	}
	return ""
}

// ParseBundleMCPServers parses <bundleDir>/.mcp.json into MCPServerSpec
// entries. A missing or malformed file yields an empty result rather
// than an error. Entry order follows Go map iteration and is
// unspecified.
func ParseBundleMCPServers(bundleDir string) []MCPServerSpec {
	mcpPath := findMCPJSON(bundleDir)
	if mcpPath == "" {
		return nil
	}

	data, err := os.ReadFile(mcpPath)
	if err != nil {
		return nil
	}
	var root map[string]interface{}
	if err := json.Unmarshal(data, &root); err != nil {
		return nil
	}

	serversMap, ok := root["mcpServers"].(map[string]interface{})
	if !ok {
		return nil
	}

	var specs []MCPServerSpec
	for name, raw := range serversMap {
		cfg, ok := raw.(map[string]interface{})
		if !ok {
			continue
		}
		spec := MCPServerSpec{
			Name:    name,
			Raw:     cfg,
			Command: strVal(cfg["command"]),
			URL:     strVal(cfg["url"]),
			Env:     strMapVal(cfg["env"]),
		}
		// The "type" key wins over the legacy "transport" key.
		spec.Transport = strVal(cfg["type"])
		if spec.Transport == "" {
			spec.Transport = strVal(cfg["transport"])
		}
		if args, ok := cfg["args"].([]interface{}); ok {
			for _, arg := range args {
				spec.Args = append(spec.Args, strVal(arg))
			}
		}
		specs = append(specs, spec)
	}
	return specs
}

// BundleMCPPresent reports whether bundleDir directly contains a
// .mcp.json file (case-insensitive).
func BundleMCPPresent(bundleDir string) bool {
	return findMCPJSON(bundleDir) != ""
}

// strVal returns v as a string, or "" for nil / non-string values.
func strVal(v interface{}) string {
	s, _ := v.(string)
	return s
}

// strMapVal coerces v into a map[string]string. Non-string values inside
// a map[string]interface{} collapse to "". Other types yield nil.
func strMapVal(v interface{}) map[string]string {
	switch m := v.(type) {
	case map[string]string:
		return m
	case map[string]interface{}:
		out := make(map[string]string, len(m))
		for key, val := range m {
			out[key] = strVal(val)
		}
		return out
	}
	return nil
}
// Finding represents a single security finding in a file.
type Finding struct {
	// FilePath is the file where the finding was detected.
	FilePath string
	// Description describes the hidden-character pattern found.
	Description string
	// Line is the 1-based line number (0 = unknown).
	Line int
}

// ScanResult holds the outcome of a pre-deploy security scan.
type ScanResult struct {
	// HasFindings is true when at least one finding was detected.
	HasFindings bool
	// ShouldBlock is true when the finding severity warrants blocking install.
	ShouldBlock bool
	// Findings is the list of detected findings.
	Findings []Finding
	// FilesScanned is the number of files that were examined.
	FilesScanned int
}

// scannerFunc is the function signature for the security gate scan.
// Provided as a package variable so tests can inject a stub.
var scannerFunc func(root string, force bool) (*ScanResult, error) = defaultScanner

// defaultScanner walks every regular file under root and scans it for
// dangerous hidden characters. Unreadable entries are skipped silently.
// The force flag is accepted for signature compatibility only; blocking
// is decided by the caller, not here.
func defaultScanner(root string, force bool) (*ScanResult, error) {
	result := &ScanResult{}

	err := filepath.WalkDir(root, func(path string, d os.DirEntry, walkErr error) error {
		if walkErr != nil || d.IsDir() {
			return nil // skip unreadable entries and directories
		}

		data, readErr := os.ReadFile(path)
		if readErr != nil {
			return nil
		}
		result.FilesScanned++

		if findings := scanBytes(path, data); len(findings) > 0 {
			result.HasFindings = true
			result.ShouldBlock = true
			result.Findings = append(result.Findings, findings...)
		}
		return nil
	})
	return result, err
}

// hiddenPatterns are Unicode code-points considered dangerous in prompt/skill files.
var hiddenPatterns = []rune{
	'\u200B', // zero-width space
	'\u200C', // zero-width non-joiner
	'\u200D', // zero-width joiner
	'\u2028', // line separator
	'\u2029', // paragraph separator
	'\u202A', // left-to-right embedding
	'\u202B', // right-to-left embedding
	'\u202C', // pop directional formatting
	'\u202D', // left-to-right override
	'\u202E', // right-to-left override (most dangerous)
	'\uFEFF', // byte order mark (mid-file)
	'\u00AD', // soft hyphen
}

// scanBytes reports one Finding per dangerous code-point present in the
// file, carrying the byte offset of its first occurrence.
//
// The previous version re-decoded the whole file once per pattern
// (O(patterns x size)); this version makes a single pass recording the
// first offset of each pattern, then emits findings in hiddenPatterns
// order — output is byte-identical to the per-pattern scan it replaces.
func scanBytes(path string, data []byte) []Finding {
	dangerous := make(map[rune]struct{}, len(hiddenPatterns))
	for _, r := range hiddenPatterns {
		dangerous[r] = struct{}{}
	}

	firstSeen := make(map[rune]int, len(hiddenPatterns))
	for i, c := range string(data) {
		if _, bad := dangerous[c]; !bad {
			continue
		}
		if _, seen := firstSeen[c]; !seen {
			firstSeen[c] = i
		}
	}

	var findings []Finding
	for _, r := range hiddenPatterns {
		if off, ok := firstSeen[r]; ok {
			findings = append(findings, Finding{
				FilePath:    path,
				Description: fmt.Sprintf("hidden character U+%04X at byte offset %d", r, off),
			})
		}
	}
	return findings
}

// PreDeploySecurityScan scans package source files for hidden characters
// BEFORE deployment.
//
// Returns true when deployment should proceed, false to block. Scan
// errors fail open so a broken scanner cannot wedge installs. When
// force is true the scan still runs, but a blocking result never stops
// the install. packageName is currently unused and retained for
// interface stability.
func PreDeploySecurityScan(installPath string, packageName string, force bool) (bool, *ScanResult) {
	result, err := scannerFunc(installPath, force)
	if err != nil || result == nil {
		// Fail-open on scan error.
		return true, &ScanResult{}
	}
	if !result.HasFindings {
		return true, result
	}
	// Findings exist: block only when severity warrants it and not forced.
	return force || !result.ShouldBlock, result
}
+type IntegrateFunc func(info *PackageInfo, projectRoot string) (*IntegrationResult, error) + +// DiagnosticsCounter supports per-package diagnostic counts. +type DiagnosticsCounter interface { +CountForPackage(depKey, kind string) int +AddError(msg, pkg string) +} + +// Logger supports verbose package-inline warnings. +type Logger interface { +Verbose() bool +PackageInlineWarning(msg string) +} + +// Config holds all dependencies for RunIntegrationTemplate. +type Config struct { +SecurityGate SecurityGateFunc +Integrate IntegrateFunc +Diagnostics DiagnosticsCounter +Logger Logger +ProjectRoot string +HasTargets bool +Force bool +// IntegrateErrorPrefix is the per-source error prefix (Strategy pattern). +IntegrateErrorPrefix string +// IsLocal indicates whether the dep ref is local (for error key selection). +IsLocal bool +LocalPath string +// PackageDeployedFiles is updated in place. +PackageDeployedFiles map[string][]string +} + +// RunIntegrationTemplate runs the shared post-acquire integration flow. +// Returns a counter-delta map, or nil if the materialization is nil (source declined). +func RunIntegrationTemplate(m *Materialization, cfg *Config) Deltas { +if m == nil { +return nil +} +return integrateMaterilaization(m, cfg) +} + +func integrateMaterilaization(m *Materialization, cfg *Config) Deltas { +deltas := m.Deltas +if deltas == nil { +deltas = Deltas{} +} + +// No-op when targets are empty or acquire decided to skip integration. +if m.PackageInfo == nil || !cfg.HasTargets { +cfg.PackageDeployedFiles[m.DepKey] = []string{} +return deltas +} + +defer func() { +// Verbose: inline skip / error count for this package. +if cfg.Logger != nil && cfg.Logger.Verbose() { +skipCount := cfg.Diagnostics.CountForPackage(m.DepKey, "collision") +errCount := cfg.Diagnostics.CountForPackage(m.DepKey, "error") +if skipCount > 0 { +noun := "file" +if skipCount != 1 { +noun = "files" +} +cfg.Logger.PackageInlineWarning( +" [!] 
" + itoa(skipCount) + " " + noun + " skipped (local files exist)", +) +} +if errCount > 0 { +noun := "error" +if errCount != 1 { +noun = "errors" +} +cfg.Logger.PackageInlineWarning( +" [!] " + itoa(errCount) + " integration " + noun, +) +} +} +}() + +// Pre-deploy security gate. +if cfg.SecurityGate != nil { +if !cfg.SecurityGate(m.InstallPath, m.DepKey, cfg.Force) { +cfg.PackageDeployedFiles[m.DepKey] = []string{} +return deltas +} +} + +// Primitive integration. +if cfg.Integrate != nil { +result, err := cfg.Integrate(m.PackageInfo, cfg.ProjectRoot) +if err != nil { +packageKey := m.DepKey +if cfg.IsLocal && cfg.LocalPath != "" { +packageKey = cfg.LocalPath +} +cfg.Diagnostics.AddError(cfg.IntegrateErrorPrefix+": "+err.Error(), packageKey) +} else if result != nil { +deltas["prompts"] = result.Prompts +deltas["agents"] = result.Agents +deltas["skills"] = result.Skills +deltas["sub_skills"] = result.SubSkills +deltas["instructions"] = result.Instructions +deltas["commands"] = result.Commands +deltas["hooks"] = result.Hooks +deltas["links_resolved"] = result.LinksResolved +cfg.PackageDeployedFiles[m.DepKey] = result.DeployedFiles +} +} + +return deltas +} + +// itoa converts an int to a string without importing strconv at call sites. +func itoa(n int) string { +if n == 0 { +return "0" +} +neg := n < 0 +if neg { +n = -n +} +buf := make([]byte, 20) +i := len(buf) +for n >= 10 { +i-- +buf[i] = byte('0' + n%10) +n /= 10 +} +i-- +buf[i] = byte('0' + n) +if neg { +i-- +buf[i] = '-' +} +return string(buf[i:]) +} diff --git a/internal/integration/agentintegrator/agentintegrator.go b/internal/integration/agentintegrator/agentintegrator.go new file mode 100644 index 0000000..ffe1ea0 --- /dev/null +++ b/internal/integration/agentintegrator/agentintegrator.go @@ -0,0 +1,399 @@ +// Package agentintegrator handles integration of APM package agents into +// .github/agents/, .claude/agents/, .cursor/agents/ etc. 
// FindAgentFiles returns all agent definition files in a package, in
// discovery order with duplicates (by absolute path) removed.
//
// Search locations:
//   - package root: *.agent.md and *.chatmode.md
//   - .apm/agents/: every *.md file, recursively
//   - .apm/chatmodes/ (legacy): *.chatmode.md
func FindAgentFiles(packagePath string) []string {
	var agentFiles []string
	seen := map[string]struct{}{}

	// add records p unless its absolute path was already collected.
	add := func(p string) {
		abs, _ := filepath.Abs(p)
		if _, dup := seen[abs]; !dup {
			seen[abs] = struct{}{}
			agentFiles = append(agentFiles, p)
		}
	}

	// Package root: *.agent.md and *.chatmode.md only.
	if entries, err := os.ReadDir(packagePath); err == nil {
		for _, e := range entries {
			if e.IsDir() {
				continue
			}
			n := e.Name()
			if strings.HasSuffix(n, ".agent.md") || strings.HasSuffix(n, ".chatmode.md") {
				add(filepath.Join(packagePath, n))
			}
		}
	}

	// .apm/agents/: recursive walk. The original branched on ".agent.md"
	// vs ".md", but both branches added the file and ".agent.md" implies
	// ".md", so a single suffix check is equivalent. A missing directory
	// makes WalkDir invoke the callback once with a non-nil error, which
	// is ignored, so the previous os.Stat pre-check was redundant.
	apmAgentsDir := filepath.Join(packagePath, ".apm", "agents")
	filepath.WalkDir(apmAgentsDir, func(path string, d fs.DirEntry, err error) error {
		if err != nil || d.IsDir() {
			return nil
		}
		if strings.HasSuffix(d.Name(), ".md") {
			add(path)
		}
		return nil
	})

	// .apm/chatmodes/ (legacy location): *.chatmode.md only, non-recursive.
	apmChatmodesDir := filepath.Join(packagePath, ".apm", "chatmodes")
	if entries, err := os.ReadDir(apmChatmodesDir); err == nil {
		for _, e := range entries {
			if !e.IsDir() && strings.HasSuffix(e.Name(), ".chatmode.md") {
				add(filepath.Join(apmChatmodesDir, e.Name()))
			}
		}
	}

	return agentFiles
}
// IntegrateAgentsForTarget integrates agents from a package for a single target.
//
// Flow: resolve the target's deploy root (mapping.DeployRoot overrides
// target.RootDir), honor the AutoCreate opt-in, discover the package's
// agent files, then write each one into the target's agents directory —
// dispatching on the mapping's FormatID — while skipping user-authored
// collisions. Per-file write failures are reported to stderr and do not
// abort the run.
func IntegrateAgentsForTarget(
	target *targets.TargetProfile,
	installPath string,
	projectRoot string,
	force bool,
	managedFiles map[string]struct{},
	diag baseintegrator.Diagnostics,
) baseintegrator.IntegrationResult {
	// No "agents" mapping means this target does not deploy agents.
	mapping, ok := target.Primitives["agents"]
	if !ok {
		return baseintegrator.IntegrationResult{}
	}

	effectiveRoot := mapping.DeployRoot
	if effectiveRoot == "" {
		effectiveRoot = target.RootDir
	}
	targetRoot := filepath.Join(projectRoot, effectiveRoot)

	// Non-auto-create targets are integrated only when their root already
	// exists, i.e. the user has opted into that tool.
	if !target.AutoCreate {
		if _, err := os.Stat(filepath.Join(projectRoot, target.RootDir)); os.IsNotExist(err) {
			return baseintegrator.IntegrationResult{}
		}
	}

	agentFiles := FindAgentFiles(installPath)
	if len(agentFiles) == 0 {
		return baseintegrator.IntegrationResult{}
	}

	agentsDir := targetRoot
	if mapping.Subdir != "" {
		agentsDir = filepath.Join(targetRoot, mapping.Subdir)
	}
	// Destination not creatable: treat as "nothing integrated".
	if err := os.MkdirAll(agentsDir, 0755); err != nil {
		return baseintegrator.IntegrationResult{}
	}

	var result baseintegrator.IntegrationResult

	for _, sourceFile := range agentFiles {
		targetFilename := GetTargetFilenameForTarget(sourceFile, target)
		targetPath := filepath.Join(agentsDir, targetFilename)
		relPath := PortableRelpath(targetPath, projectRoot)

		// Skip files that exist locally but are not APM-managed.
		if baseintegrator.CheckCollision(targetPath, relPath, managedFiles, force, diag) {
			result.FilesSkipped++
			continue
		}

		var linksResolved int
		var err error

		// Format dispatch: Codex gets a TOML transform, Windsurf a Skill
		// transform; every other target takes a plain copy.
		switch mapping.FormatID {
		case "codex_agent":
			err = writeCodexAgent(sourceFile, targetPath)
		case "windsurf_agent_skill":
			linksResolved, err = writeWindsurfAgentSkill(sourceFile, targetPath, diag)
		default:
			linksResolved, err = CopyAgent(sourceFile, targetPath)
		}

		if err != nil {
			// Report and continue with the remaining files.
			fmt.Fprintf(os.Stderr, "[x] Failed to write agent %s: %v\n", targetFilename, err)
			continue
		}

		result.LinksResolved += linksResolved
		result.FilesIntegrated++
		result.TargetPaths = append(result.TargetPaths, targetPath)
	}

	return result
}

// SyncForTarget removes APM-managed agent files for a single target.
//
// Builds the lockfile prefix for the target's agents mapping plus the
// legacy "*-apm" glob (whose shape depends on whether the target keeps
// the ".agent.md" double extension), then delegates the removal to
// baseintegrator.SyncRemoveFiles.
func SyncForTarget(
	target *targets.TargetProfile,
	projectRoot string,
	managedFiles map[string]struct{},
) baseintegrator.SyncRemoveResult {
	mapping, ok := target.Primitives["agents"]
	if !ok {
		return baseintegrator.SyncRemoveResult{}
	}
	effectiveRoot := mapping.DeployRoot
	if effectiveRoot == "" {
		effectiveRoot = target.RootDir
	}
	// Managed entries are stored with forward slashes in the lockfile.
	prefix := effectiveRoot + "/" + mapping.Subdir + "/"
	legacyDir := filepath.Join(projectRoot, effectiveRoot, mapping.Subdir)
	legacyPattern := "*-apm.md"
	if mapping.Extension == ".agent.md" {
		legacyPattern = "*-apm.agent.md"
	}
	return baseintegrator.SyncRemoveFiles(
		projectRoot,
		managedFiles,
		prefix,
		legacyDir,
		legacyPattern,
		[]*targets.TargetProfile{target},
		nil,
	)
}

// frontmatterRE matches a leading YAML frontmatter fence in markdown;
// capture group 1 is the frontmatter body between the "---" lines.
var frontmatterRE = regexp.MustCompile(`(?s)^---\s*\n(.*?)\n---\s*\n?`)
+func writeCodexAgent(source, target string) error { + data, err := os.ReadFile(source) + if err != nil { + return err + } + content := string(data) + + name := filepath.Base(source) + name = strings.TrimSuffix(name, filepath.Ext(name)) + if strings.HasSuffix(name, ".agent") { + name = name[:len(name)-6] + } + description := "" + body := content + + if m := frontmatterRE.FindStringSubmatchIndex(content); m != nil { + fmStr := content[m[2]:m[3]] + body = content[m[1]:] + fm := parseSimpleYAML(fmStr) + if v, ok := fm["name"]; ok { + name = v + } + if v, ok := fm["description"]; ok { + description = v + } + } + + body = strings.TrimSpace(body) + + // Produce minimal TOML + var sb strings.Builder + sb.WriteString("name = ") + sb.WriteString(tomlQuote(name)) + sb.WriteString("\ndescription = ") + sb.WriteString(tomlQuote(description)) + sb.WriteString("\ndeveloper_instructions = ") + sb.WriteString(tomlMultilineQuote(body)) + sb.WriteString("\n") + + return os.WriteFile(target, []byte(sb.String()), 0644) +} + +// writeWindsurfAgentSkill transforms an .agent.md file to a Windsurf Skill (SKILL.md). 
// writeWindsurfAgentSkill transforms an .agent.md file to a Windsurf Skill (SKILL.md).
//
// Only name, description, and the markdown body survive the conversion;
// agent-only frontmatter fields (tools, model) are dropped, with a
// warning through diag when they were present and non-empty. The int
// return is a links-resolved count; this transform never resolves
// links, so it is always 0.
func writeWindsurfAgentSkill(source, target string, diag baseintegrator.Diagnostics) (int, error) {
	data, err := os.ReadFile(source)
	if err != nil {
		return 0, err
	}
	content := string(data)

	// Default skill name: filename with known agent extensions removed.
	name := filepath.Base(source)
	if strings.HasSuffix(name, ".agent.md") {
		name = name[:len(name)-9]
	} else if strings.HasSuffix(name, ".chatmode.md") {
		name = name[:len(name)-12]
	} else {
		name = strings.TrimSuffix(name, filepath.Ext(name))
	}

	description := ""
	body := content
	var fmMap map[string]string

	if m := frontmatterRE.FindStringSubmatchIndex(content); m != nil {
		fmMap = parseSimpleYAML(content[m[2]:m[3]])
		body = content[m[1]:]
	} else {
		fmMap = map[string]string{}
	}

	// Warn (once per file) about frontmatter fields Windsurf cannot express.
	if diag != nil {
		var dropped []string
		for _, k := range []string{"tools", "model"} {
			if v, ok := fmMap[k]; ok && v != "" {
				dropped = append(dropped, k)
			}
		}
		if len(dropped) > 0 {
			diag.Warn(
				fmt.Sprintf("Windsurf skill conversion dropped frontmatter field(s) %s from %s",
					strings.Join(dropped, ", "), filepath.Base(source)),
				"Windsurf Skills do not support agent-only fields; only name, description, and body are preserved.",
			)
		}
	}

	// Frontmatter name/description override the filename-derived defaults.
	if v, ok := fmMap["name"]; ok {
		name = v
	}
	if v, ok := fmMap["description"]; ok {
		description = v
	}

	// Rebuild minimal Skill frontmatter (description omitted when empty).
	var fm strings.Builder
	fm.WriteString("name: ")
	fm.WriteString(yamlQuoteIfNeeded(name))
	if description != "" {
		fm.WriteString("\ndescription: ")
		fm.WriteString(yamlQuoteIfNeeded(description))
	}

	result := "---\n" + fm.String() + "\n---\n" + body

	if err := os.MkdirAll(filepath.Dir(target), 0755); err != nil {
		return 0, err
	}
	if err := os.WriteFile(target, []byte(result), 0644); err != nil {
		return 0, err
	}
	return 0, nil
}
// parseSimpleYAML parses flat "key: value" lines into a map. Nested
// structures and lists are not supported; lines without a colon are
// skipped, and matching surrounding single or double quotes on values
// are stripped.
func parseSimpleYAML(s string) map[string]string {
	fields := map[string]string{}
	for _, line := range strings.Split(s, "\n") {
		key, rest, found := strings.Cut(line, ":")
		if !found {
			continue
		}
		key = strings.TrimSpace(key)
		val := strings.TrimSpace(rest)
		if len(val) >= 2 {
			first, last := val[0], val[len(val)-1]
			if (first == '"' && last == '"') || (first == '\'' && last == '\'') {
				val = val[1 : len(val)-1]
			}
		}
		fields[key] = val
	}
	return fields
}

// tomlQuote renders s as a single-line TOML basic string, escaping
// backslashes and double quotes.
func tomlQuote(s string) string {
	replacer := strings.NewReplacer(`\`, `\\`, `"`, `\"`)
	return `"` + replacer.Replace(s) + `"`
}

// tomlMultilineQuote renders s as a TOML multi-line basic string,
// escaping backslashes and any embedded triple quotes.
func tomlMultilineQuote(s string) string {
	escaped := strings.ReplaceAll(s, `\`, `\\`)
	escaped = strings.ReplaceAll(escaped, `"""`, `\"\"\"`)
	return `"""` + "\n" + escaped + "\n" + `"""`
}
// yamlQuoteIfNeeded wraps a value in double quotes when it contains
// characters that are unsafe in a YAML plain scalar.
//
// Inside the quoted form both backslashes and double quotes are
// escaped: YAML double-quoted scalars interpret `\` as an escape
// introducer, so an unescaped backslash either makes the document
// invalid or silently decodes to a different character (e.g. `\b`
// becomes a backspace). The previous version escaped only quotes.
// The per-character loop over a specials slice is also collapsed into
// a single strings.ContainsAny call over the identical character set.
func yamlQuoteIfNeeded(s string) string {
	if !strings.ContainsAny(s, ":#[]{},&*!|>'\"%@`") {
		return s
	}
	// Escape backslashes first so the quote escaping below is not
	// double-escaped.
	escaped := strings.ReplaceAll(s, `\`, `\\`)
	escaped = strings.ReplaceAll(escaped, `"`, `\"`)
	return `"` + escaped + `"`
}
agentintegrator.GetTargetFilenameForTarget(source, target) + if got != "reviewer.agent.md" { + t.Fatalf("expected reviewer.agent.md, got %q", got) + } +} + +func TestGetTargetFilenameForTargetClaude(t *testing.T) { + source := "/pkg/.apm/agents/reviewer.agent.md" + target := targets.KnownTargets["claude"] + got := agentintegrator.GetTargetFilenameForTarget(source, target) + // claude uses .md extension + if got != "reviewer.md" { + t.Fatalf("expected reviewer.md, got %q", got) + } +} + +func TestCopyAgent(t *testing.T) { + dir := t.TempDir() + src := filepath.Join(dir, "src.agent.md") + dst := filepath.Join(dir, "dst.agent.md") + os.WriteFile(src, []byte("# Agent\nHello"), 0644) + n, err := agentintegrator.CopyAgent(src, dst) + if err != nil { + t.Fatalf("copy error: %v", err) + } + if n != 0 { + t.Fatalf("expected 0 links, got %d", n) + } + data, _ := os.ReadFile(dst) + if string(data) != "# Agent\nHello" { + t.Fatal("content mismatch") + } +} + +func TestIntegrateAgentsForTarget(t *testing.T) { + dir := t.TempDir() + pkgDir := filepath.Join(dir, "pkg") + os.MkdirAll(pkgDir, 0755) + os.WriteFile(filepath.Join(pkgDir, "helper.agent.md"), []byte("# Helper"), 0644) + + // Create .github dir so copilot target is active + os.MkdirAll(filepath.Join(dir, ".github"), 0755) + + target := targets.KnownTargets["copilot"] + result := agentintegrator.IntegrateAgentsForTarget(target, pkgDir, dir, false, nil, nil) + if result.FilesIntegrated != 1 { + t.Fatalf("expected 1 integrated, got %d", result.FilesIntegrated) + } + expected := filepath.Join(dir, ".github", "agents", "helper.agent.md") + if _, err := os.Stat(expected); os.IsNotExist(err) { + t.Fatalf("expected output file at %s", expected) + } +} + +func TestSyncForTarget(t *testing.T) { + dir := t.TempDir() + agentsDir := filepath.Join(dir, ".github", "agents") + os.MkdirAll(agentsDir, 0755) + f := filepath.Join(agentsDir, "foo-apm.agent.md") + os.WriteFile(f, []byte("x"), 0644) + + target := 
targets.KnownTargets["copilot"] + stats := agentintegrator.SyncForTarget(target, dir, nil) + if stats.FilesRemoved != 1 { + t.Fatalf("expected 1 removed, got %d", stats.FilesRemoved) + } +} diff --git a/internal/integration/baseintegrator/baseintegrator.go b/internal/integration/baseintegrator/baseintegrator.go new file mode 100644 index 0000000..fe5ff89 --- /dev/null +++ b/internal/integration/baseintegrator/baseintegrator.go @@ -0,0 +1,482 @@ +// Package baseintegrator provides shared collision detection, sync removal, +// link resolution, and file-discovery helpers for file-level integrators. +// Ported from src/apm_cli/integration/base_integrator.py +package baseintegrator + +import ( + "fmt" + "os" + "path/filepath" + "sort" + "strings" + "syscall" + + "github.com/githubnext/apm/internal/integration/coworkpaths" + "github.com/githubnext/apm/internal/integration/targets" +) + +// IntegrationResult holds the outcome of a file-level integration operation. +type IntegrationResult struct { + FilesIntegrated int + FilesUpdated int // kept for CLI compat, always 0 today + FilesSkipped int + TargetPaths []string + LinksResolved int + // hook-specific + ScriptsCopied int + // skill-specific + SubSkillsPromoted int + SkillCreated bool +} + +// Diagnostics is a minimal interface for recording integration diagnostics. +type Diagnostics interface { + Skip(relPath string) + Warn(msg, detail string) +} + +// CheckCollision returns true if targetPath is a user-authored collision. +// A collision exists when: managed set is non-nil, file exists, relPath is NOT +// in the managed set, and force is false. 
+func CheckCollision( + targetPath string, + relPath string, + managedFiles map[string]struct{}, + force bool, + diag Diagnostics, +) bool { + if managedFiles == nil { + return false + } + if _, err := os.Stat(targetPath); os.IsNotExist(err) { + return false + } + norm := strings.ReplaceAll(relPath, "\\", "/") + if _, ok := managedFiles[norm]; ok { + return false + } + if force { + return false + } + if diag != nil { + diag.Skip(relPath) + } else { + fmt.Fprintf(os.Stderr, "[!] Skipping %s -- local file exists (not managed by APM). Use 'apm install --force' to overwrite.\n", relPath) + } + return true +} + +// NormalizeManagedFiles normalizes path separators to forward slashes for O(1) lookups. +func NormalizeManagedFiles(managedFiles map[string]struct{}) map[string]struct{} { + if managedFiles == nil { + return nil + } + out := make(map[string]struct{}, len(managedFiles)) + for p := range managedFiles { + out[strings.ReplaceAll(p, "\\", "/")] = struct{}{} + } + return out +} + +// BucketAliases maps raw {prim}_{target} keys to canonical bucket names. +var BucketAliases = map[string]string{ + "prompts_copilot": "prompts", + "agents_copilot": "agents_github", + "commands_claude": "commands", + "commands_cursor": "commands_cursor", + "commands_opencode": "commands_opencode", + "instructions_copilot": "instructions", + "instructions_cursor": "rules_cursor", + "instructions_claude": "rules_claude", +} + +// PartitionBucketKey returns the canonical bucket key for a (primitive, target) pair. +func PartitionBucketKey(primName, targetName string) string { + raw := primName + "_" + targetName + if alias, ok := BucketAliases[raw]; ok { + return alias + } + return raw +} + +// PartitionManagedFiles partitions managedFiles by integration prefix. +// When profiles is nil, falls back to targets.KnownTargets. 
+func PartitionManagedFiles( + managedFiles map[string]struct{}, + profiles []*targets.TargetProfile, +) map[string]map[string]struct{} { + source := profiles + if source == nil { + for _, p := range targets.KnownTargets { + source = append(source, p) + } + } + + buckets := map[string]map[string]struct{}{ + "skills": {}, + "hooks": {}, + } + + var skillPrefixes []string + var hookPrefixes []string + + // prefix -> bucket key + prefixMap := map[string]string{} + + for _, target := range source { + for primName, mapping := range target.Primitives { + if target.ResolvedDeployRoot != "" { + if primName == "skills" { + skillPrefixes = append(skillPrefixes, coworkpaths.CoworkLockfilePrefix) + } + continue + } + effectiveRoot := mapping.DeployRoot + if effectiveRoot == "" { + effectiveRoot = target.RootDir + } + var prefix string + if mapping.Subdir != "" { + prefix = effectiveRoot + "/" + mapping.Subdir + "/" + } else { + prefix = effectiveRoot + "/" + } + if primName == "skills" { + skillPrefixes = append(skillPrefixes, prefix) + } else if primName == "hooks" { + hookPrefixes = append(hookPrefixes, prefix) + } else { + raw := primName + "_" + target.Name + bucketKey, ok := BucketAliases[raw] + if !ok { + bucketKey = raw + } + if _, exists := buckets[bucketKey]; !exists { + buckets[bucketKey] = map[string]struct{}{} + } + prefixMap[prefix] = bucketKey + } + } + } + + // Build a trie for longest-prefix-match routing. 
+ type trieNode struct { + children map[string]*trieNode + bucket string + } + root := &trieNode{children: map[string]*trieNode{}} + for prefix, bucketKey := range prefixMap { + segs := splitSegments(prefix) + node := root + for _, seg := range segs { + child, ok := node.children[seg] + if !ok { + child = &trieNode{children: map[string]*trieNode{}} + node.children[seg] = child + } + node = child + } + node.bucket = bucketKey + } + + for p := range managedFiles { + segs := splitSegments(p) + node := root + lastBucket := "" + for _, seg := range segs { + child, ok := node.children[seg] + if !ok { + break + } + node = child + if node.bucket != "" { + lastBucket = node.bucket + } + } + if lastBucket != "" { + buckets[lastBucket][p] = struct{}{} + continue + } + // Fall back to cross-target buckets + if hasAnyPrefix(p, skillPrefixes) { + buckets["skills"][p] = struct{}{} + } else if hasAnyPrefix(p, hookPrefixes) { + buckets["hooks"][p] = struct{}{} + } + } + + return buckets +} + +func splitSegments(path string) []string { + var segs []string + for _, s := range strings.Split(path, "/") { + if s != "" { + segs = append(segs, s) + } + } + return segs +} + +func hasAnyPrefix(s string, prefixes []string) bool { + for _, p := range prefixes { + if strings.HasPrefix(s, p) { + return true + } + } + return false +} + +// ValidateDeployPath returns true if relPath is safe for APM to deploy or remove. +// Checks: no path traversal, starts with an allowed integration prefix, resolves within projectRoot. 
+func ValidateDeployPath( + relPath string, + projectRoot string, + allowedPrefixes []string, + profiles []*targets.TargetProfile, +) bool { + if strings.Contains(relPath, "..") { + return false + } + + if allowedPrefixes == nil { + allowedPrefixes = targets.GetIntegrationPrefixes(profiles) + } + + if strings.HasPrefix(relPath, coworkpaths.CoworkURIScheme) { + if !hasAnyPrefix(relPath, allowedPrefixes) { + return false + } + coworkRoot, err := coworkpaths.ResolveCoworkSkillsDir() + if err != nil || coworkRoot == "" { + return false + } + _, err = coworkpaths.FromLockfilePath(relPath, coworkRoot) + return err == nil + } + + if !hasAnyPrefix(relPath, allowedPrefixes) { + return false + } + + target := filepath.Join(projectRoot, relPath) + resolved, err := filepath.EvalSymlinks(target) + if err != nil { + // If path doesn't exist yet, check using Clean + resolved = filepath.Clean(target) + } + projResolved, err := filepath.EvalSymlinks(projectRoot) + if err != nil { + projResolved = filepath.Clean(projectRoot) + } + return strings.HasPrefix(resolved, projResolved+string(os.PathSeparator)) || resolved == projResolved +} + +// CleanupEmptyParents removes empty parent directories bottom-up. +// Stops at stopAt and does not remove stopAt itself. 
+func CleanupEmptyParents(deletedPaths []string, stopAt string) { + if len(deletedPaths) == 0 { + return + } + stopResolved, err := filepath.EvalSymlinks(stopAt) + if err != nil { + stopResolved = filepath.Clean(stopAt) + } + + candidates := map[string]struct{}{} + for _, p := range deletedPaths { + parent := filepath.Dir(p) + for parent != stopAt { + parentResolved, _ := filepath.EvalSymlinks(parent) + if parentResolved == stopResolved { + break + } + candidates[parent] = struct{}{} + next := filepath.Dir(parent) + if next == parent { + break + } + parent = next + } + } + + // Sort deepest-first + sorted := make([]string, 0, len(candidates)) + for d := range candidates { + sorted = append(sorted, d) + } + sort.Slice(sorted, func(i, j int) bool { + return strings.Count(sorted[i], string(os.PathSeparator)) > strings.Count(sorted[j], string(os.PathSeparator)) + }) + + for _, d := range sorted { + entries, err := os.ReadDir(d) + if err != nil { + continue + } + if len(entries) == 0 { + os.Remove(d) // ignore errors + } + } +} + +// SyncRemoveResult holds the result of a sync removal operation. +type SyncRemoveResult struct { + FilesRemoved int + Errors int +} + +// Logger is a minimal interface for sync-remove diagnostic output. +type Logger interface { + Warning(msg string, symbol string) +} + +// SyncRemoveFiles removes APM-managed files matching prefix from managedFiles. +// Falls back to a legacy glob when managedFiles is nil. 
+func SyncRemoveFiles( + projectRoot string, + managedFiles map[string]struct{}, + prefix string, + legacyGlobDir string, + legacyGlobPattern string, + profiles []*targets.TargetProfile, + logger Logger, +) SyncRemoveResult { + stats := SyncRemoveResult{} + + if managedFiles != nil { + coworkRootResolved := false + coworkRootCached := "" + coworkOrphansSkipped := 0 + + for relPath := range managedFiles { + if !strings.HasPrefix(relPath, prefix) { + continue + } + if !ValidateDeployPath(relPath, projectRoot, nil, profiles) { + continue + } + + var targetPath string + if strings.HasPrefix(relPath, coworkpaths.CoworkURIScheme) { + if !coworkRootResolved { + coworkRootCached, _ = coworkpaths.ResolveCoworkSkillsDir() + coworkRootResolved = true + } + if coworkRootCached == "" { + coworkOrphansSkipped++ + continue + } + resolved, err := coworkpaths.FromLockfilePath(relPath, coworkRootCached) + if err != nil { + continue + } + targetPath = resolved + } else { + targetPath = filepath.Join(projectRoot, relPath) + } + + if _, err := os.Stat(targetPath); err == nil { + if err := os.Remove(targetPath); err != nil { + stats.Errors++ + } else { + stats.FilesRemoved++ + } + } + } + + if coworkOrphansSkipped > 0 { + word := "entry" + if coworkOrphansSkipped != 1 { + word = "entries" + } + msg := fmt.Sprintf( + "Cowork: skipping %d orphaned lockfile %s -- OneDrive path not detected.\n"+ + "Run: apm config set copilot-cowork-skills-dir "+ + "(or set APM_COPILOT_COWORK_SKILLS_DIR)\n"+ + "to clean up these entries on the next install/uninstall.", + coworkOrphansSkipped, word, + ) + if logger != nil { + logger.Warning(msg, "warning") + } else { + fmt.Fprintf(os.Stderr, "[!] 
%s\n", msg) + } + } + } else if legacyGlobDir != "" && legacyGlobPattern != "" { + if _, err := os.Stat(legacyGlobDir); err == nil { + matches, err := filepath.Glob(filepath.Join(legacyGlobDir, legacyGlobPattern)) + if err == nil { + for _, f := range matches { + if err := os.Remove(f); err != nil { + stats.Errors++ + } else { + stats.FilesRemoved++ + } + } + } + } + } + + return stats +} + +// FindFilesByGlob searches packagePath (and optional subdirs) for pattern. +// Symlinks and hardlinks are rejected. +func FindFilesByGlob(packagePath string, pattern string, subdirs []string) []string { + var results []string + seen := map[uint64]struct{}{} + + dirs := []string{packagePath} + for _, s := range subdirs { + dirs = append(dirs, filepath.Join(packagePath, s)) + } + + for _, d := range dirs { + if _, err := os.Stat(d); err != nil { + continue + } + matches, err := filepath.Glob(filepath.Join(d, pattern)) + if err != nil { + continue + } + sort.Strings(matches) + for _, f := range matches { + info, err := os.Lstat(f) + if err != nil { + continue + } + // Reject symlinks + if info.Mode()&os.ModeSymlink != 0 { + continue + } + // Reject hardlinks (nlink > 1) + if sys, ok := info.Sys().(*syscall.Stat_t); ok { + if sys.Nlink > 1 { + continue + } + } + resolved, err := filepath.EvalSymlinks(f) + if err != nil { + resolved = filepath.Clean(f) + } + pkgResolved, err := filepath.EvalSymlinks(packagePath) + if err != nil { + pkgResolved = filepath.Clean(packagePath) + } + if !strings.HasPrefix(resolved, pkgResolved+string(os.PathSeparator)) && resolved != pkgResolved { + continue + } + // Use inode as unique key + if sys, ok := info.Sys().(*syscall.Stat_t); ok { + inode := sys.Ino + if _, exists := seen[inode]; exists { + continue + } + seen[inode] = struct{}{} + } + results = append(results, f) + } + } + return results +} diff --git a/internal/integration/baseintegrator/baseintegrator_test.go b/internal/integration/baseintegrator/baseintegrator_test.go new file mode 100644 
index 0000000..a24b2c2 --- /dev/null +++ b/internal/integration/baseintegrator/baseintegrator_test.go @@ -0,0 +1,119 @@ +package baseintegrator_test + +import ( + "os" + "path/filepath" + "testing" + + "github.com/githubnext/apm/internal/integration/baseintegrator" +) + +func TestCheckCollisionNilManaged(t *testing.T) { + if baseintegrator.CheckCollision("/any/path", "any/path", nil, false, nil) { + t.Fatal("nil managed should never collide") + } +} + +func TestCheckCollisionManagedContains(t *testing.T) { + dir := t.TempDir() + f := filepath.Join(dir, "file.md") + os.WriteFile(f, []byte("x"), 0644) + managed := map[string]struct{}{"file.md": {}} + if baseintegrator.CheckCollision(f, "file.md", managed, false, nil) { + t.Fatal("file in managed set should not collide") + } +} + +func TestCheckCollisionUserAuthored(t *testing.T) { + dir := t.TempDir() + f := filepath.Join(dir, "file.md") + os.WriteFile(f, []byte("x"), 0644) + managed := map[string]struct{}{"other.md": {}} + if !baseintegrator.CheckCollision(f, "file.md", managed, false, nil) { + t.Fatal("user-authored file should collide") + } +} + +func TestCheckCollisionForceOverrides(t *testing.T) { + dir := t.TempDir() + f := filepath.Join(dir, "file.md") + os.WriteFile(f, []byte("x"), 0644) + managed := map[string]struct{}{"other.md": {}} + if baseintegrator.CheckCollision(f, "file.md", managed, true, nil) { + t.Fatal("force should override collision") + } +} + +func TestNormalizeManagedFilesBackslash(t *testing.T) { + in := map[string]struct{}{`a\b\c.md`: {}} + out := baseintegrator.NormalizeManagedFiles(in) + if _, ok := out["a/b/c.md"]; !ok { + t.Fatal("backslash should be normalized to forward slash") + } +} + +func TestPartitionBucketKeyAlias(t *testing.T) { + got := baseintegrator.PartitionBucketKey("prompts", "copilot") + if got != "prompts" { + t.Fatalf("expected 'prompts', got %q", got) + } +} + +func TestPartitionBucketKeyPassthrough(t *testing.T) { + got := baseintegrator.PartitionBucketKey("agents", 
"cursor") + if got != "agents_cursor" { + t.Fatalf("expected 'agents_cursor', got %q", got) + } +} + +func TestValidateDeployPathTraversal(t *testing.T) { + if baseintegrator.ValidateDeployPath("../etc/passwd", "/project", []string{".github/"}, nil) { + t.Fatal("path traversal should be rejected") + } +} + +func TestValidateDeployPathDisallowedPrefix(t *testing.T) { + if baseintegrator.ValidateDeployPath(".hidden/secret", "/project", []string{".github/"}, nil) { + t.Fatal("disallowed prefix should be rejected") + } +} + +func TestCleanupEmptyParents(t *testing.T) { + dir := t.TempDir() + sub := filepath.Join(dir, "a", "b") + os.MkdirAll(sub, 0755) + f := filepath.Join(sub, "file.md") + os.WriteFile(f, []byte("x"), 0644) + os.Remove(f) + baseintegrator.CleanupEmptyParents([]string{f}, dir) + if _, err := os.Stat(sub); !os.IsNotExist(err) { + t.Fatal("empty sub directory should have been removed") + } + if _, err := os.Stat(dir); os.IsNotExist(err) { + t.Fatal("stop-at directory should NOT be removed") + } +} + +func TestSyncRemoveFilesLegacyGlob(t *testing.T) { + dir := t.TempDir() + f := filepath.Join(dir, "foo-apm.agent.md") + os.WriteFile(f, []byte("x"), 0644) + stats := baseintegrator.SyncRemoveFiles(dir, nil, ".github/agents/", dir, "*-apm.agent.md", nil, nil) + if stats.FilesRemoved != 1 { + t.Fatalf("expected 1 removed, got %d", stats.FilesRemoved) + } + if _, err := os.Stat(f); !os.IsNotExist(err) { + t.Fatal("file should have been removed") + } +} + +func TestFindFilesByGlob(t *testing.T) { + dir := t.TempDir() + os.WriteFile(filepath.Join(dir, "a.prompt.md"), []byte("x"), 0644) + os.WriteFile(filepath.Join(dir, "b.prompt.md"), []byte("x"), 0644) + os.WriteFile(filepath.Join(dir, "other.txt"), []byte("x"), 0644) + results := baseintegrator.FindFilesByGlob(dir, "*.prompt.md", nil) + if len(results) != 2 { + t.Fatalf("expected 2, got %d", len(results)) + } +} diff --git a/internal/integration/commandintegrator/commandintegrator.go 
b/internal/integration/commandintegrator/commandintegrator.go new file mode 100644 index 0000000..97695c1 --- /dev/null +++ b/internal/integration/commandintegrator/commandintegrator.go @@ -0,0 +1,429 @@ +// Package commandintegrator provides command integration for APM packages. +// Deploys .prompt.md files as slash commands for Claude, Cursor, OpenCode, etc. +// Ported from src/apm_cli/integration/command_integrator.py +package commandintegrator + +import ( + "os" + "path/filepath" + "regexp" + "strings" + + "github.com/githubnext/apm/internal/integration/baseintegrator" + "github.com/githubnext/apm/internal/integration/targets" +) + +// IntegrationResult holds results of a command integration operation. +type IntegrationResult struct { + FilesIntegrated int + FilesUpdated int + FilesSkipped int + TargetPaths []string + LinksResolved int +} + +// inputNameRe matches valid command argument names. +var inputNameRe = regexp.MustCompile(`^[A-Za-z][\w-]{0,63}$`) + +// inputRefRe matches ${{input:name}} and ${ input : name } references. +var inputRefRe = regexp.MustCompile(`\$\{\{?\s*input\s*:\s*([\w-]+)\s*\}?\}`) + +// preservedCommandKeys is the set of frontmatter keys preserved by the command transformer. +var preservedCommandKeys = map[string]bool{ + "description": true, + "allowed-tools": true, + "allowedTools": true, + "model": true, + "argument-hint": true, + "argumentHint": true, + "input": true, +} + +// isValidInputName returns true if name is a safe argument identifier. +func isValidInputName(name string) bool { + return inputNameRe.MatchString(name) +} + +// extractInputNames extracts argument names from an APM 'input' frontmatter value. +// input may be a string (single name) or []interface{} (list of names or maps with name key). 
+func extractInputNames(input interface{}) (valid []string, rejected []string) { + if input == nil { + return nil, nil + } + switch v := input.(type) { + case string: + if isValidInputName(v) { + valid = append(valid, v) + } else { + rejected = append(rejected, v) + } + case []interface{}: + for _, item := range v { + switch sv := item.(type) { + case string: + if isValidInputName(sv) { + valid = append(valid, sv) + } else { + rejected = append(rejected, sv) + } + case map[string]interface{}: + if name, ok := sv["name"].(string); ok { + if isValidInputName(name) { + valid = append(valid, name) + } else { + rejected = append(rejected, name) + } + } + } + } + } + return valid, rejected +} + +// parseFrontmatter parses YAML-style frontmatter from markdown content. +// Returns (metadata map, body content). Simple implementation for the keys we care about. +func parseFrontmatter(content string) (map[string]interface{}, string) { + meta := map[string]interface{}{} + body := content + + if !strings.HasPrefix(content, "---") { + return meta, body + } + // Find closing --- + rest := content[3:] + if rest != "" && rest[0] == '\n' { + rest = rest[1:] + } + idx := strings.Index(rest, "\n---") + if idx < 0 { + return meta, body + } + yamlPart := rest[:idx] + body = rest[idx+4:] + if strings.HasPrefix(body, "\n") { + body = body[1:] + } + + // Parse simple key: value lines + for _, line := range strings.Split(yamlPart, "\n") { + if colonIdx := strings.Index(line, ":"); colonIdx > 0 { + key := strings.TrimSpace(line[:colonIdx]) + val := strings.TrimSpace(line[colonIdx+1:]) + // Remove surrounding quotes + if len(val) >= 2 && ((val[0] == '"' && val[len(val)-1] == '"') || (val[0] == '\'' && val[len(val)-1] == '\'')) { + val = val[1 : len(val)-1] + } + meta[key] = val + } + } + return meta, body +} + +// buildCommandContent builds the command file content from metadata and body. 
+func buildCommandContent(meta map[string]interface{}, body string) string { + var sb strings.Builder + sb.WriteString("---\n") + orderedKeys := []string{"description", "allowed-tools", "model", "argument-hint", "arguments"} + written := map[string]bool{} + for _, k := range orderedKeys { + if v, ok := meta[k]; ok { + sb.WriteString(k) + sb.WriteString(": ") + switch sv := v.(type) { + case string: + sb.WriteString(sv) + case []string: + sb.WriteString("\n") + for _, item := range sv { + sb.WriteString(" - ") + sb.WriteString(item) + sb.WriteString("\n") + } + written[k] = true + continue + default: + sb.WriteString("") + } + sb.WriteString("\n") + written[k] = true + } + } + sb.WriteString("---\n") + sb.WriteString(body) + return sb.String() +} + +// transformPromptToCommand transforms a .prompt.md file into Claude command format. +// Returns (commandName, fileContent, droppedKeys bool). +func transformPromptToCommand(sourceFile string) (string, string, bool, error) { + data, err := os.ReadFile(sourceFile) + if err != nil { + return "", "", false, err + } + content := string(data) + meta, body := parseFrontmatter(content) + + filename := filepath.Base(sourceFile) + commandName := strings.TrimSuffix(filename, ".prompt.md") + if commandName == filename { + commandName = strings.TrimSuffix(filename, filepath.Ext(filename)) + } + + claudeMeta := map[string]interface{}{} + + if v, ok := meta["description"]; ok { + claudeMeta["description"] = v + } + if v, ok := meta["allowed-tools"]; ok { + claudeMeta["allowed-tools"] = v + } else if v, ok := meta["allowedTools"]; ok { + claudeMeta["allowed-tools"] = v + } + if v, ok := meta["model"]; ok { + claudeMeta["model"] = v + } + if v, ok := meta["argument-hint"]; ok { + claudeMeta["argument-hint"] = v + } else if v, ok := meta["argumentHint"]; ok { + claudeMeta["argument-hint"] = v + } + + // Map 'input' to 'arguments' and 'argument-hint' + inputNames, _ := extractInputNames(meta["input"]) + if len(inputNames) > 0 { + 
claudeMeta["arguments"] = inputNames + if _, ok := claudeMeta["argument-hint"]; !ok { + hints := make([]string, len(inputNames)) + for i, n := range inputNames { + hints[i] = "<" + n + ">" + } + claudeMeta["argument-hint"] = strings.Join(hints, " ") + } + // Replace ${{input:name}} with $name + body = inputRefRe.ReplaceAllStringFunc(body, func(m string) string { + sub := inputRefRe.FindStringSubmatch(m) + if len(sub) > 1 { + return "$" + sub[1] + } + return m + }) + } + + // Compute dropped keys + droppedKeys := false + for k := range meta { + if !preservedCommandKeys[k] { + droppedKeys = true + break + } + } + + fileContent := buildCommandContent(claudeMeta, body) + return commandName, fileContent, droppedKeys, nil +} + +// writeGeminiCommand transforms a .prompt.md to Gemini CLI TOML format. +func writeGeminiCommand(sourceFile, targetFile string) error { + data, err := os.ReadFile(sourceFile) + if err != nil { + return err + } + meta, body := parseFrontmatter(string(data)) + description, _ := meta["description"].(string) + promptText := strings.TrimSpace(body) + promptText = strings.ReplaceAll(promptText, "$ARGUMENTS", "{{args}}") + + var sb strings.Builder + if description != "" { + sb.WriteString("description = ") + sb.WriteString(`"`) + sb.WriteString(strings.ReplaceAll(description, `"`, `\"`)) + sb.WriteString(`"`) + sb.WriteString("\n") + } + sb.WriteString("prompt = ") + sb.WriteString(`"""`) + sb.WriteString("\n") + sb.WriteString(promptText) + sb.WriteString("\n") + sb.WriteString(`"""`) + sb.WriteString("\n") + + if err := os.MkdirAll(filepath.Dir(targetFile), 0o755); err != nil { + return err + } + return os.WriteFile(targetFile, []byte(sb.String()), 0o644) +} + +// CommandIntegrator handles integration of .prompt.md files as slash commands. +type CommandIntegrator struct { + passthroughNotified map[string]bool +} + +// New returns a new CommandIntegrator. 
+func New() *CommandIntegrator { + return &CommandIntegrator{ + passthroughNotified: map[string]bool{}, + } +} + +// FindPromptFiles returns all .prompt.md files in a package. +func FindPromptFiles(packagePath string) []string { + return baseintegrator.FindFilesByGlob(packagePath, "*.prompt.md", []string{".apm/prompts"}) +} + +// IntegrateCommandsForTarget integrates prompt files as commands for a single target. +func (ci *CommandIntegrator) IntegrateCommandsForTarget( + tgt *targets.TargetProfile, + packageInstallPath, projectRoot string, + force bool, + managedFiles map[string]struct{}, + diag baseintegrator.Diagnostics, +) IntegrationResult { + mapping, ok := tgt.Primitives["commands"] + if !ok { + return IntegrationResult{} + } + + effectiveRoot := mapping.DeployRoot + if effectiveRoot == "" { + effectiveRoot = tgt.RootDir + } + if !tgt.AutoCreate { + if _, err := os.Stat(filepath.Join(projectRoot, tgt.RootDir)); err != nil { + return IntegrationResult{} + } + } + + promptFiles := FindPromptFiles(packageInstallPath) + if len(promptFiles) == 0 { + return IntegrationResult{} + } + + commandsDir := filepath.Join(projectRoot, effectiveRoot, mapping.Subdir) + var result IntegrationResult + anyDroppedKeys := false + + for _, promptFile := range promptFiles { + filename := filepath.Base(promptFile) + baseName := strings.TrimSuffix(filename, ".prompt.md") + if baseName == filename { + baseName = strings.TrimSuffix(filename, filepath.Ext(filename)) + } + + // Path security check + if strings.Contains(baseName, "..") || strings.ContainsAny(baseName, "/\\") { + result.FilesSkipped++ + continue + } + + ext := mapping.Extension + if ext == "" { + ext = ".md" + } + targetPath := filepath.Join(commandsDir, baseName+ext) + relPath := strings.ReplaceAll(func() string { + rel, _ := filepath.Rel(projectRoot, targetPath) + return rel + }(), "\\", "/") + + if baseintegrator.CheckCollision(targetPath, relPath, managedFiles, force, diag) { + result.FilesSkipped++ + continue + } + + 
var written bool + var hadDropped bool + if mapping.FormatID == "gemini_command" { + if err := writeGeminiCommand(promptFile, targetPath); err == nil { + written = true + } + hadDropped = false + } else { + commandName, fileContent, dropped, err := transformPromptToCommand(promptFile) + _ = commandName + if err != nil { + result.FilesSkipped++ + continue + } + if err := os.MkdirAll(filepath.Dir(targetPath), 0o755); err != nil { + result.FilesSkipped++ + continue + } + if err := os.WriteFile(targetPath, []byte(fileContent), 0o644); err != nil { + result.FilesSkipped++ + continue + } + written = true + hadDropped = dropped + } + + if !written { + result.FilesSkipped++ + continue + } + if hadDropped { + anyDroppedKeys = true + } + result.FilesIntegrated++ + result.TargetPaths = append(result.TargetPaths, targetPath) + } + _ = anyDroppedKeys + return result +} + +// SyncForTarget removes APM-managed command files for a single target. +func (ci *CommandIntegrator) SyncForTarget( + tgt *targets.TargetProfile, + projectRoot string, + managedFiles map[string]struct{}, +) map[string]int { + mapping, ok := tgt.Primitives["commands"] + if !ok { + return map[string]int{"files_removed": 0, "errors": 0} + } + effectiveRoot := mapping.DeployRoot + if effectiveRoot == "" { + effectiveRoot = tgt.RootDir + } + prefix := effectiveRoot + "/" + mapping.Subdir + "/" + legacyDir := filepath.Join(projectRoot, effectiveRoot, mapping.Subdir) + + res := baseintegrator.SyncRemoveFiles( + projectRoot, + managedFiles, + prefix, + legacyDir, + "*-apm.md", + nil, + nil, + ) + return map[string]int{"files_removed": res.FilesRemoved, "errors": res.Errors} +} + +// IntegratePackageCommands integrates prompt files as Claude commands (legacy API). 
+func (ci *CommandIntegrator) IntegratePackageCommands( + packageInstallPath, projectRoot string, + force bool, + managedFiles map[string]struct{}, + diag baseintegrator.Diagnostics, +) IntegrationResult { + tgt, ok := targets.KnownTargets["claude"] + if !ok { + return IntegrationResult{} + } + _ = os.MkdirAll(filepath.Join(projectRoot, ".claude"), 0o755) + return ci.IntegrateCommandsForTarget(tgt, packageInstallPath, projectRoot, force, managedFiles, diag) +} + +// SyncIntegration removes APM-managed command files from .claude/commands/ (legacy). +func (ci *CommandIntegrator) SyncIntegration( + projectRoot string, + managedFiles map[string]struct{}, +) map[string]int { + tgt, ok := targets.KnownTargets["claude"] + if !ok { + return map[string]int{"files_removed": 0, "errors": 0} + } + return ci.SyncForTarget(tgt, projectRoot, managedFiles) +} diff --git a/internal/integration/hookintegrator/hookintegrator.go b/internal/integration/hookintegrator/hookintegrator.go new file mode 100644 index 0000000..b4a98cf --- /dev/null +++ b/internal/integration/hookintegrator/hookintegrator.go @@ -0,0 +1,806 @@ +// Package hookintegrator provides hook integration for APM packages. +// Deploys hook JSON files and referenced scripts to target directories. +// Ported from src/apm_cli/integration/hook_integrator.py +package hookintegrator + +import ( + "encoding/json" + "io/fs" + "os" + "path/filepath" + "regexp" + "strings" + + "github.com/githubnext/apm/internal/integration/baseintegrator" + "github.com/githubnext/apm/internal/integration/targets" +) + +// HookIntegrationResult holds results of a hook integration operation. +type HookIntegrationResult struct { + FilesIntegrated int + FilesUpdated int + FilesSkipped int + TargetPaths []string + ScriptsCopied int +} + +// HooksIntegrated is an alias for FilesIntegrated (backward compat). 
+func (r *HookIntegrationResult) HooksIntegrated() int { + return r.FilesIntegrated +} + +// mergeHookConfig describes a target that merges hooks into a single JSON file. +type mergeHookConfig struct { + ConfigFilename string + TargetKey string + RequireDir bool +} + +// hookEventMap maps source event names to target-specific names. +var hookEventMap = map[string]map[string]string{ + "claude": { + "preToolUse": "PreToolUse", + "postToolUse": "PostToolUse", + }, + "gemini": { + "PreToolUse": "BeforeTool", + "preToolUse": "BeforeTool", + "PostToolUse": "AfterTool", + "postToolUse": "AfterTool", + "Stop": "SessionEnd", + }, +} + +// mergeHookTargets maps target names to merge configurations. +var mergeHookTargets = map[string]mergeHookConfig{ + "claude": {ConfigFilename: "settings.json", TargetKey: "claude", RequireDir: false}, + "cursor": {ConfigFilename: "hooks.json", TargetKey: "cursor", RequireDir: true}, + "codex": {ConfigFilename: "hooks.json", TargetKey: "codex", RequireDir: true}, + "gemini": {ConfigFilename: "settings.json", TargetKey: "gemini", RequireDir: true}, + "windsurf": {ConfigFilename: "hooks.json", TargetKey: "windsurf", RequireDir: true}, +} + +// hookFileTargetSuffixes maps hook file stem suffixes to target sets. +var hookFileTargetSuffixes = map[string]map[string]bool{ + "copilot-hooks": {"copilot": true, "vscode": true}, + "cursor-hooks": {"cursor": true}, + "claude-hooks": {"claude": true}, + "codex-hooks": {"codex": true}, + "gemini-hooks": {"gemini": true}, + "windsurf-hooks": {"windsurf": true}, +} + +// hookCommandKeys lists all supported hook command keys. +var hookCommandKeys = []string{"command", "bash", "powershell", "windows", "linux", "osx"} + +// pluginRootRe matches ${CLAUDE_PLUGIN_ROOT}/path and similar. +var pluginRootRe = regexp.MustCompile(`\$\{(?:CLAUDE_PLUGIN_ROOT|CURSOR_PLUGIN_ROOT|PLUGIN_ROOT)\}([\\/][^\s]+)`) + +// relPathRe matches relative ./path or .\path references. 
// relPathRe matches relative script references inside a hook command,
// i.e. "./path" or ".\path" followed by non-whitespace characters.
var relPathRe = regexp.MustCompile(`(\.[\\/][^\s]+)`)

// filterHookFilesForTarget returns only hook files intended for targetKey.
//
// A hook file is target-specific when its stem (basename without extension,
// lowercased) equals or ends with "-<suffix>" for some suffix registered in
// hookFileTargetSuffixes; it is then kept only if that suffix's allow-map
// contains targetKey. Files matching no registered suffix are treated as
// universal and deployed to every target.
func filterHookFilesForTarget(hookFiles []string, targetKey string) []string {
	var result []string
	for _, hf := range hookFiles {
		stemLower := strings.ToLower(strings.TrimSuffix(filepath.Base(hf), filepath.Ext(hf)))
		matchedSuffix := ""
		matched := false
		// NOTE(review): hookFileTargetSuffixes is a map, so when a stem could
		// match more than one suffix the winner depends on iteration order —
		// presumably suffixes are mutually exclusive; verify at the definition.
		for suffix, allowed := range hookFileTargetSuffixes {
			if stemLower == suffix || strings.HasSuffix(stemLower, "-"+suffix) {
				matchedSuffix = suffix
				if allowed[targetKey] {
					result = append(result, hf)
					matched = true
				}
				break
			}
		}
		if matchedSuffix == "" && !matched {
			// Universal -- deploy to all targets
			result = append(result, hf)
		}
	}
	return result
}

// toGeminiHookEntries transforms hook entries to Gemini CLI format.
//
// Entries already in nested Claude/Gemini shape ({"hooks": [...]}) keep their
// structure and only have their inner hook keys renamed. Flat Copilot-style
// entries are wrapped in a one-element "hooks" array, with any "_apm_source"
// marker hoisted to the wrapping object.
func toGeminiHookEntries(entries []interface{}) []interface{} {
	var result []interface{}
	for _, raw := range entries {
		entry, ok := raw.(map[string]interface{})
		if !ok {
			// Non-object entries are passed through untouched.
			result = append(result, raw)
			continue
		}
		// Already nested (Claude/Gemini format)
		if hooks, ok := entry["hooks"].([]interface{}); ok {
			for _, h := range hooks {
				if hm, ok := h.(map[string]interface{}); ok {
					copilotKeysToGemini(hm)
				}
			}
			result = append(result, entry)
			continue
		}
		// Flat Copilot entry -- wrap in nested format
		inner := shallowCopyMap(entry)
		copilotKeysToGemini(inner)
		apmSource, _ := inner["_apm_source"].(string)
		delete(inner, "_apm_source")
		outer := map[string]interface{}{"hooks": []interface{}{inner}}
		if apmSource != "" {
			outer["_apm_source"] = apmSource
		}
		result = append(result, outer)
	}
	return result
}

// shallowCopyMap returns a new map containing the same key/value pairs as m.
// Values are shared, not deep-copied.
func shallowCopyMap(m map[string]interface{}) map[string]interface{} {
	out := make(map[string]interface{}, len(m))
	for k, v := range m {
		out[k] = v
	}
	return out
}

// copilotKeysToGemini renames Copilot-style keys on a single hook object in
// place to their Gemini equivalents: the first of "bash"/"powershell"/
// "windows" becomes "command" (unless "command" already exists), and
// "timeoutSec" (seconds) becomes "timeout" (milliseconds).
func copilotKeysToGemini(hook map[string]interface{}) {
	if _, hasCmd := hook["command"]; !hasCmd {
		for _, key := range []string{"bash", "powershell", "windows"} {
			if v, ok := hook[key]; ok {
				hook["command"] = v
				delete(hook, key)
				break
			}
		}
	}
	if ts, ok := hook["timeoutSec"]; ok {
		// JSON numbers unmarshal as float64; the int case covers
		// programmatically built maps.
		switch v := ts.(type) {
		case float64:
			hook["timeout"] = v * 1000
		case int:
			hook["timeout"] = v * 1000
		}
		// Dropped even when the value had an unconvertible type.
		delete(hook, "timeoutSec")
	}
}

// HookIntegrator handles integration of APM package hooks.
type HookIntegrator struct{}

// New returns a new HookIntegrator.
func New() *HookIntegrator { return &HookIntegrator{} }

// FindHookFiles finds all hook JSON files in a package.
// Searches .apm/hooks/ and hooks/.
//
// Symlinked entries are skipped (Lstat check), and files whose resolved
// paths coincide are reported only once, keeping the first path seen.
func (hi *HookIntegrator) FindHookFiles(packagePath string) []string {
	var hookFiles []string
	seen := map[string]bool{}

	for _, sub := range []string{".apm/hooks", "hooks"} {
		dir := filepath.Join(packagePath, sub)
		entries, err := os.ReadDir(dir)
		if err != nil {
			// Missing or unreadable directory: nothing to collect here.
			continue
		}
		for _, e := range entries {
			if e.IsDir() || !strings.HasSuffix(e.Name(), ".json") {
				continue
			}
			p := filepath.Join(dir, e.Name())
			// Reject symlinks outright to avoid following links out of the package.
			if info, err := os.Lstat(p); err != nil || (info.Mode()&fs.ModeSymlink) != 0 {
				continue
			}
			resolved, _ := filepath.EvalSymlinks(p)
			if resolved == "" {
				resolved = p
			}
			if !seen[resolved] {
				seen[resolved] = true
				hookFiles = append(hookFiles, p)
			}
		}
	}
	return hookFiles
}

// parseHookJSON parses a hook JSON file. The boolean result is false when the
// file cannot be read or does not contain a JSON object.
func parseHookJSON(hookFile string) (map[string]interface{}, bool) {
	data, err := os.ReadFile(hookFile)
	if err != nil {
		return nil, false
	}
	var result map[string]interface{}
	if err := json.Unmarshal(data, &result); err != nil {
		return nil, false
	}
	return result, true
}

// scriptCopy describes one script file that must be copied into the project:
// Source is the absolute path inside the package, TargetRel the
// project-relative destination.
type scriptCopy struct {
	Source    string
	TargetRel string
}

// rewriteCommandForTarget rewrites a hook command to use installed script paths.
+func (hi *HookIntegrator) rewriteCommandForTarget( + command, packagePath, packageName, targetKey string, + hookFileDir, rootDir string, +) (string, []scriptCopy) { + var scripts []scriptCopy + newCommand := command + + var scriptsBase string + if rootDir == "" { + switch targetKey { + case "vscode": + rootDir = ".github" + case "cursor": + rootDir = ".cursor" + case "codex": + rootDir = ".codex" + case "windsurf": + rootDir = ".windsurf" + default: + rootDir = ".claude" + } + } + switch targetKey { + case "vscode": + scriptsBase = rootDir + "/hooks/scripts/" + packageName + case "cursor": + scriptsBase = rootDir + "/hooks/" + packageName + case "codex": + scriptsBase = rootDir + "/hooks/" + packageName + case "windsurf": + scriptsBase = rootDir + "/hooks/" + packageName + default: + scriptsBase = rootDir + "/hooks/" + packageName + } + + pkgResolved, _ := filepath.EvalSymlinks(packagePath) + if pkgResolved == "" { + pkgResolved = packagePath + } + + // Handle plugin root variables + for _, match := range pluginRootRe.FindAllStringSubmatchIndex(command, -1) { + fullVar := command[match[0]:match[1]] + relPart := command[match[2]:match[3]] + relPart = strings.ReplaceAll(relPart, "\\", "/") + relPart = strings.TrimPrefix(relPart, "/") + srcFile := filepath.Join(packagePath, relPart) + srcResolved, _ := filepath.EvalSymlinks(srcFile) + if srcResolved == "" { + srcResolved = srcFile + } + if !strings.HasPrefix(srcResolved, pkgResolved) { + continue + } + if info, err := os.Stat(srcFile); err != nil || info.IsDir() { + continue + } + targetRel := scriptsBase + "/" + relPart + scripts = append(scripts, scriptCopy{Source: srcFile, TargetRel: targetRel}) + newCommand = strings.ReplaceAll(newCommand, fullVar, targetRel) + } + + // Handle relative ./path references + resolveBase := hookFileDir + if resolveBase == "" { + resolveBase = packagePath + } + for _, match := range relPathRe.FindAllStringIndex(newCommand, -1) { + relRef := newCommand[match[0]:match[1]] + relPath := 
strings.TrimPrefix(relRef, "./") + relPath = strings.TrimPrefix(relPath, ".\\") + relPath = strings.ReplaceAll(relPath, "\\", "/") + srcFile := filepath.Join(resolveBase, relPath) + srcResolved, _ := filepath.EvalSymlinks(srcFile) + if srcResolved == "" { + srcResolved = srcFile + } + if !strings.HasPrefix(srcResolved, pkgResolved) { + continue + } + if info, err := os.Stat(srcFile); err != nil || info.IsDir() { + continue + } + targetRel := scriptsBase + "/" + relPath + scripts = append(scripts, scriptCopy{Source: srcFile, TargetRel: targetRel}) + newCommand = strings.ReplaceAll(newCommand, relRef, targetRel) + } + return newCommand, scripts +} + +// rewriteHooksData rewrites all command paths in a hooks JSON structure. +func (hi *HookIntegrator) rewriteHooksData( + data map[string]interface{}, + packagePath, packageName, targetKey string, + hookFileDir, rootDir string, +) (map[string]interface{}, []scriptCopy) { + rewritten := deepCopyMap(data) + var allScripts []scriptCopy + + hooksRaw, _ := rewritten["hooks"].(map[string]interface{}) + if hooksRaw == nil { + return rewritten, nil + } + + for eventName, rawMatchers := range hooksRaw { + matchers, ok := rawMatchers.([]interface{}) + if !ok { + continue + } + for _, rawMatcher := range matchers { + matcher, ok := rawMatcher.(map[string]interface{}) + if !ok { + continue + } + // Rewrite flat-format keys + for _, key := range hookCommandKeys { + if cmd, ok := matcher[key].(string); ok { + newCmd, sc := hi.rewriteCommandForTarget(cmd, packagePath, packageName, targetKey, hookFileDir, rootDir) + matcher[key] = newCmd + allScripts = append(allScripts, sc...) 
+ } + } + // Rewrite nested hooks array (Claude format) + if innerHooks, ok := matcher["hooks"].([]interface{}); ok { + for _, rawHook := range innerHooks { + hook, ok := rawHook.(map[string]interface{}) + if !ok { + continue + } + for _, key := range hookCommandKeys { + if cmd, ok := hook[key].(string); ok { + newCmd, sc := hi.rewriteCommandForTarget(cmd, packagePath, packageName, targetKey, hookFileDir, rootDir) + hook[key] = newCmd + allScripts = append(allScripts, sc...) + } + } + } + } + } + _ = eventName + } + + // Deduplicate scripts by target path + seen := map[string]string{} + for _, sc := range allScripts { + if _, ok := seen[sc.TargetRel]; !ok { + seen[sc.TargetRel] = sc.Source + } + } + var uniqueScripts []scriptCopy + for tgt, src := range seen { + uniqueScripts = append(uniqueScripts, scriptCopy{Source: src, TargetRel: tgt}) + } + return rewritten, uniqueScripts +} + +func deepCopyMap(m map[string]interface{}) map[string]interface{} { + b, _ := json.Marshal(m) + var out map[string]interface{} + _ = json.Unmarshal(b, &out) + return out +} + +func portableRelpath(path, base string) string { + rel, err := filepath.Rel(base, path) + if err != nil { + return path + } + return strings.ReplaceAll(rel, "\\", "/") +} + +// IntegratePackageHooks integrates hooks for the Copilot/VSCode target (individual JSON files). 
+func (hi *HookIntegrator) IntegratePackageHooks( + packageInstallPath, projectRoot string, + packageName string, + force bool, + managedFiles map[string]struct{}, + diag baseintegrator.Diagnostics, + rootDir string, +) *HookIntegrationResult { + hookFiles := hi.FindHookFiles(packageInstallPath) + hookFiles = filterHookFilesForTarget(hookFiles, "copilot") + if len(hookFiles) == 0 { + return &HookIntegrationResult{} + } + + if rootDir == "" { + rootDir = ".github" + } + hooksDir := filepath.Join(projectRoot, rootDir, "hooks") + _ = os.MkdirAll(hooksDir, 0o755) + + if packageName == "" { + packageName = filepath.Base(packageInstallPath) + } + + var result HookIntegrationResult + for _, hookFile := range hookFiles { + data, ok := parseHookJSON(hookFile) + if !ok { + continue + } + rewritten, scripts := hi.rewriteHooksData(data, packageInstallPath, packageName, "vscode", filepath.Dir(hookFile), rootDir) + stem := strings.TrimSuffix(filepath.Base(hookFile), filepath.Ext(hookFile)) + targetFilename := packageName + "-" + stem + ".json" + targetPath := filepath.Join(hooksDir, targetFilename) + relPath := portableRelpath(targetPath, projectRoot) + + if baseintegrator.CheckCollision(targetPath, relPath, managedFiles, force, diag) { + continue + } + + b, err := json.MarshalIndent(rewritten, "", " ") + if err != nil { + continue + } + if err := os.WriteFile(targetPath, append(b, '\n'), 0o644); err != nil { + continue + } + result.FilesIntegrated++ + result.TargetPaths = append(result.TargetPaths, targetPath) + + for _, sc := range scripts { + scriptTarget := filepath.Join(projectRoot, sc.TargetRel) + if err := os.MkdirAll(filepath.Dir(scriptTarget), 0o755); err != nil { + continue + } + if baseintegrator.CheckCollision(scriptTarget, sc.TargetRel, managedFiles, force, diag) { + continue + } + srcData, err := os.ReadFile(sc.Source) + if err != nil { + continue + } + if err := os.WriteFile(scriptTarget, srcData, 0o755); err != nil { + continue + } + result.ScriptsCopied++ + 
result.TargetPaths = append(result.TargetPaths, scriptTarget) + } + } + return &result +} + +// integrateMergedHooks integrates hooks by merging into a target-specific JSON config. +func (hi *HookIntegrator) integrateMergedHooks( + config mergeHookConfig, + packageInstallPath, projectRoot string, + packageName string, + force bool, + managedFiles map[string]struct{}, + diag baseintegrator.Diagnostics, + rootDir string, +) *HookIntegrationResult { + empty := &HookIntegrationResult{} + if rootDir == "" { + rootDir = "." + config.TargetKey + } + targetDir := filepath.Join(projectRoot, rootDir) + if config.RequireDir { + if _, err := os.Stat(targetDir); err != nil { + return empty + } + } + + hookFiles := hi.FindHookFiles(packageInstallPath) + hookFiles = filterHookFilesForTarget(hookFiles, config.TargetKey) + if len(hookFiles) == 0 { + return empty + } + + if packageName == "" { + packageName = filepath.Base(packageInstallPath) + } + + jsonPath := filepath.Join(targetDir, config.ConfigFilename) + jsonConfig := map[string]interface{}{} + if data, err := os.ReadFile(jsonPath); err == nil { + _ = json.Unmarshal(data, &jsonConfig) + } + if _, ok := jsonConfig["hooks"]; !ok { + jsonConfig["hooks"] = map[string]interface{}{} + } + hooksMap := jsonConfig["hooks"].(map[string]interface{}) + + eMap := hookEventMap[config.TargetKey] + clearedEvents := map[string]bool{} + + var result HookIntegrationResult + + for _, hookFile := range hookFiles { + data, ok := parseHookJSON(hookFile) + if !ok { + continue + } + rewritten, scripts := hi.rewriteHooksData(data, packageInstallPath, packageName, config.TargetKey, filepath.Dir(hookFile), rootDir) + + hooksRaw, _ := rewritten["hooks"].(map[string]interface{}) + if hooksRaw == nil { + continue + } + + for rawEventName, rawEntries := range hooksRaw { + entries, ok := rawEntries.([]interface{}) + if !ok { + continue + } + eventName := rawEventName + if mapped, ok := eMap[rawEventName]; ok { + eventName = mapped + } + if _, ok := 
hooksMap[eventName]; !ok { + hooksMap[eventName] = []interface{}{} + } + existingEntries := toSlice(hooksMap[eventName]) + + // Transform to Gemini format + if config.TargetKey == "gemini" { + entries = toGeminiHookEntries(entries) + } + // Mark with APM source + for _, e := range entries { + if em, ok := e.(map[string]interface{}); ok { + em["_apm_source"] = packageName + } + } + + // Idempotent upsert: clear prior entries for this package + if !clearedEvents[eventName] { + filtered := make([]interface{}, 0, len(existingEntries)) + for _, e := range existingEntries { + if em, ok := e.(map[string]interface{}); ok { + if em["_apm_source"] == packageName { + continue + } + } + filtered = append(filtered, e) + } + existingEntries = filtered + clearedEvents[eventName] = true + } + existingEntries = append(existingEntries, entries...) + + // Deduplicate same-package entries + existingEntries = deduplicateHookEntries(existingEntries, packageName) + hooksMap[eventName] = existingEntries + } + result.FilesIntegrated++ + + for _, sc := range scripts { + scriptTarget := filepath.Join(projectRoot, sc.TargetRel) + _ = os.MkdirAll(filepath.Dir(scriptTarget), 0o755) + if baseintegrator.CheckCollision(scriptTarget, sc.TargetRel, managedFiles, force, diag) { + continue + } + srcData, err := os.ReadFile(sc.Source) + if err != nil { + continue + } + if err := os.WriteFile(scriptTarget, srcData, 0o755); err != nil { + continue + } + result.ScriptsCopied++ + result.TargetPaths = append(result.TargetPaths, scriptTarget) + } + } + + _ = os.MkdirAll(targetDir, 0o755) + b, err := json.MarshalIndent(jsonConfig, "", " ") + if err == nil { + _ = os.WriteFile(jsonPath, append(b, '\n'), 0o644) + } + return &result +} + +func toSlice(v interface{}) []interface{} { + if s, ok := v.([]interface{}); ok { + return s + } + return nil +} + +func deduplicateHookEntries(entries []interface{}, packageName string) []interface{} { + type cmpKey struct { + source string + cmp string + } + seen := 
map[cmpKey]bool{} + var result []interface{} + for _, e := range entries { + em, ok := e.(map[string]interface{}) + if !ok { + result = append(result, e) + continue + } + src, _ := em["_apm_source"].(string) + if src != packageName { + result = append(result, e) + continue + } + cmpMap := map[string]interface{}{} + for k, v := range em { + if k != "_apm_source" { + cmpMap[k] = v + } + } + cmpBytes, _ := json.Marshal(cmpMap) + key := cmpKey{source: src, cmp: string(cmpBytes)} + if !seen[key] { + seen[key] = true + result = append(result, e) + } + } + return result +} + +// IntegrateHooksForTarget integrates hooks for a single target profile. +func (hi *HookIntegrator) IntegrateHooksForTarget( + tgt *targets.TargetProfile, + packageInstallPath, projectRoot, packageName string, + force bool, + managedFiles map[string]struct{}, + diag baseintegrator.Diagnostics, +) *HookIntegrationResult { + if tgt.Name == "copilot" { + return hi.IntegratePackageHooks(packageInstallPath, projectRoot, packageName, force, managedFiles, diag, tgt.RootDir) + } + if cfg, ok := mergeHookTargets[tgt.Name]; ok { + return hi.integrateMergedHooks(cfg, packageInstallPath, projectRoot, packageName, force, managedFiles, diag, tgt.RootDir) + } + return &HookIntegrationResult{} +} + +// SyncStats holds cleanup statistics. +type SyncStats struct { + FilesRemoved int + Errors int +} + +// SyncIntegration removes APM-managed hook files. 
+func (hi *HookIntegrator) SyncIntegration( + projectRoot string, + managedFiles map[string]struct{}, + allTargets []*targets.TargetProfile, +) SyncStats { + var stats SyncStats + if allTargets == nil { + for _, t := range targets.KnownTargets { + allTargets = append(allTargets, t) + } + } + + hookPrefixes := hookPrefixList(allTargets) + + if managedFiles != nil { + var deleted []string + for relPath := range managedFiles { + norm := strings.ReplaceAll(relPath, "\\", "/") + if strings.Contains(norm, "..") { + continue + } + if !hasAnyPrefix(norm, hookPrefixes) { + continue + } + target := filepath.Join(projectRoot, relPath) + if info, err := os.Stat(target); err != nil || info.IsDir() { + continue + } + if err := os.Remove(target); err != nil { + stats.Errors++ + } else { + stats.FilesRemoved++ + deleted = append(deleted, target) + } + } + baseintegrator.CleanupEmptyParents(deleted, projectRoot) + } else { + // Legacy: glob for *-apm.json + hooksDir := filepath.Join(projectRoot, ".github", "hooks") + entries, err := os.ReadDir(hooksDir) + if err == nil { + for _, e := range entries { + if strings.HasSuffix(e.Name(), "-apm.json") { + if err := os.Remove(filepath.Join(hooksDir, e.Name())); err != nil { + stats.Errors++ + } else { + stats.FilesRemoved++ + } + } + } + } + } + + // Clean APM entries from merged-hook JSON configs + for _, tgt := range allTargets { + cfg, ok := mergeHookTargets[tgt.Name] + if !ok { + continue + } + jsonPath := filepath.Join(projectRoot, tgt.RootDir, cfg.ConfigFilename) + cleanApmEntriesFromJSON(jsonPath, &stats) + } + return stats +} + +func hookPrefixList(allTargets []*targets.TargetProfile) []string { + var out []string + for _, tgt := range allTargets { + if !tgt.Supports("hooks") { + continue + } + sm := tgt.Primitives["hooks"] + effectiveRoot := sm.DeployRoot + if effectiveRoot == "" { + effectiveRoot = tgt.RootDir + } + out = append(out, effectiveRoot+"/hooks/") + } + return out +} + +func hasAnyPrefix(s string, prefixes []string) 
bool { + for _, p := range prefixes { + if strings.HasPrefix(s, p) { + return true + } + } + return false +} + +func cleanApmEntriesFromJSON(jsonPath string, stats *SyncStats) { + data, err := os.ReadFile(jsonPath) + if err != nil { + return + } + var cfg map[string]interface{} + if err := json.Unmarshal(data, &cfg); err != nil { + stats.Errors++ + return + } + hooksRaw, ok := cfg["hooks"] + if !ok { + return + } + hooksMap, ok := hooksRaw.(map[string]interface{}) + if !ok { + return + } + modified := false + for eventName := range hooksMap { + entries := toSlice(hooksMap[eventName]) + filtered := make([]interface{}, 0, len(entries)) + for _, e := range entries { + if em, ok := e.(map[string]interface{}); ok { + if _, hasSource := em["_apm_source"]; hasSource { + modified = true + continue + } + } + filtered = append(filtered, e) + } + if len(filtered) == 0 { + delete(hooksMap, eventName) + modified = true + } else { + hooksMap[eventName] = filtered + } + } + if len(hooksMap) == 0 { + delete(cfg, "hooks") + modified = true + } + if modified { + b, err := json.MarshalIndent(cfg, "", " ") + if err == nil { + _ = os.WriteFile(jsonPath, append(b, '\n'), 0o644) + stats.FilesRemoved++ + } + } +} diff --git a/internal/integration/instructionintegrator/instructionintegrator.go b/internal/integration/instructionintegrator/instructionintegrator.go new file mode 100644 index 0000000..54ccba1 --- /dev/null +++ b/internal/integration/instructionintegrator/instructionintegrator.go @@ -0,0 +1,323 @@ +// Package instructionintegrator deploys .instructions.md files from APM packages +// to the appropriate target directory with format-specific transforms. 
+// +// Supported format transforms: +// - cursor_rules: applyTo: -> globs: (.mdc extension) +// - claude_rules: applyTo: -> paths: list +// - windsurf_rules: applyTo: -> trigger: glob + globs: +// - default: verbatim copy +package instructionintegrator + +import ( + "io/fs" + "os" + "path/filepath" + "regexp" + "strings" +) + +// IntegrationResult holds the result of an instruction integration operation. +type IntegrationResult struct { + FilesIntegrated int + FilesUpdated int + FilesSkipped int + TargetPaths []string + LinksResolved int +} + +// FormatID identifies the content transform to apply. +type FormatID string + +const ( + FormatVerbatim FormatID = "" + FormatCursorRules FormatID = "cursor_rules" + FormatClaudeRules FormatID = "claude_rules" + FormatWindsurfRules FormatID = "windsurf_rules" +) + +// TargetConfig holds deploy configuration for an integration target. +type TargetConfig struct { + // RootDir is the target root (e.g. ".github"). + RootDir string + // Subdir is the subdirectory under RootDir for the primitive. + Subdir string + // Extension is the file extension for renamed files (e.g. ".mdc"). + Extension string + // FormatID selects the content transform. + FormatID FormatID + // DeployRoot overrides RootDir when set. + DeployRoot string + // AutoCreate creates the target directory even if RootDir doesn't exist. + AutoCreate bool +} + +// frontmatterRe matches a YAML frontmatter block at the top of a file. +var frontmatterRe = regexp.MustCompile(`(?s)^---\s*\n(.*?)\n---\s*\n?`) + +// parseFrontmatter extracts applyTo and description from YAML frontmatter. 
+func parseFrontmatter(content string) (applyTo, description, body string) { + m := frontmatterRe.FindStringSubmatchIndex(content) + if m == nil { + return "", "", content + } + fmBlock := content[m[2]:m[3]] + body = content[m[1]:] + for _, line := range strings.Split(fmBlock, "\n") { + stripped := strings.TrimSpace(line) + if strings.HasPrefix(stripped, "applyTo:") { + applyTo = strings.Trim(strings.TrimPrefix(stripped, "applyTo:"), " '\"") + } else if strings.HasPrefix(stripped, "description:") { + description = strings.Trim(strings.TrimPrefix(stripped, "description:"), " '\"") + } + } + return applyTo, description, body +} + +// ConvertToCursorRules converts APM instruction content to Cursor Rules .mdc format. +// Maps applyTo: -> globs: and extracts or generates description. +func ConvertToCursorRules(content string) string { + applyTo, description, body := parseFrontmatter(content) + + if description == "" { + for _, line := range strings.Split(body, "\n") { + stripped := strings.TrimLeft(strings.TrimSpace(line), "#") + stripped = strings.TrimSpace(stripped) + if stripped != "" { + parts := strings.SplitN(stripped, ".", 2) + description = strings.TrimSpace(parts[0]) + break + } + } + } + + var parts []string + parts = append(parts, "---") + if description != "" { + parts = append(parts, "description: "+description) + } + if applyTo != "" { + parts = append(parts, `globs: "`+applyTo+`"`) + } + parts = append(parts, "---") + + return strings.Join(parts, "\n") + "\n\n" + strings.TrimLeft(body, "\n") +} + +// ConvertToClaudeRules converts APM instruction content to Claude Code rules .md format. +// Maps applyTo: -> paths: list. Instructions without applyTo become unconditional rules. 
+func ConvertToClaudeRules(content string) string { + applyTo, _, body := parseFrontmatter(content) + + if applyTo != "" { + fm := "---\npaths:\n - \"" + applyTo + "\"\n---" + return fm + "\n\n" + strings.TrimLeft(body, "\n") + } + return strings.TrimLeft(body, "\n") +} + +// ConvertToWindsurfRules converts APM instruction content to Windsurf rules .md format. +// Maps applyTo: -> trigger: glob + globs:. Instructions without applyTo use trigger: always_on. +func ConvertToWindsurfRules(content string) string { + applyTo, _, body := parseFrontmatter(content) + + var parts []string + parts = append(parts, "---") + if applyTo != "" { + safeApplyTo := strings.ReplaceAll(strings.ReplaceAll(applyTo, "\n", " "), "\r", " ") + safeApplyTo = strings.TrimSpace(safeApplyTo) + parts = append(parts, "trigger: glob") + parts = append(parts, `globs: "`+safeApplyTo+`"`) + } else { + parts = append(parts, "trigger: always_on") + } + parts = append(parts, "---") + + return strings.Join(parts, "\n") + "\n\n" + strings.TrimLeft(body, "\n") +} + +// FindInstructionFiles returns all .instructions.md files in a package's .apm/instructions/ dir. +func FindInstructionFiles(packagePath string) ([]string, error) { + var files []string + instructionsDir := filepath.Join(packagePath, ".apm", "instructions") + _ = filepath.WalkDir(instructionsDir, func(path string, d fs.DirEntry, err error) error { + if err != nil { + return nil + } + if !d.IsDir() && strings.HasSuffix(d.Name(), ".instructions.md") { + files = append(files, path) + } + return nil + }) + return files, nil +} + +// CopyInstruction copies an instruction file to target, applying the given format transform. +// Returns number of links resolved (always 0 in this stdlib implementation). 
// CopyInstruction copies an instruction file to target, applying the given format transform.
// FormatVerbatim (or any unrecognized format) copies the content unchanged.
// Returns number of links resolved (always 0 in this stdlib implementation).
func CopyInstruction(source, target string, format FormatID) (int, error) {
	data, err := os.ReadFile(source)
	if err != nil {
		return 0, err
	}
	content := string(data)

	switch format {
	case FormatCursorRules:
		content = ConvertToCursorRules(content)
	case FormatClaudeRules:
		content = ConvertToClaudeRules(content)
	case FormatWindsurfRules:
		content = ConvertToWindsurfRules(content)
	}

	if err := os.WriteFile(target, []byte(content), 0o644); err != nil {
		return 0, err
	}
	return 0, nil
}

// IntegrateInstructionsForTarget deploys instruction files to the given target directory.
//
// Unless cfg.AutoCreate is set, integration is a no-op when cfg.RootDir does
// not exist under projectRoot. Files are deployed to
// <DeployRoot-or-RootDir>/<Subdir>/, renamed with cfg.Extension for the rule
// formats (the ".instructions.md" suffix is stripped), and skipped when they
// would overwrite a user-authored file (see checkCollision).
func IntegrateInstructionsForTarget(
	installPath string,
	projectRoot string,
	cfg TargetConfig,
	force bool,
	managedFiles map[string]bool,
) (IntegrationResult, error) {
	result := IntegrationResult{}

	// DeployRoot, when set, overrides RootDir as the deploy location.
	effectiveRoot := cfg.DeployRoot
	if effectiveRoot == "" {
		effectiveRoot = cfg.RootDir
	}

	if !cfg.AutoCreate {
		// The presence of RootDir is the opt-in signal for this target.
		if _, err := os.Stat(filepath.Join(projectRoot, cfg.RootDir)); os.IsNotExist(err) {
			return result, nil
		}
	}

	instructionFiles, err := FindInstructionFiles(installPath)
	if err != nil {
		return result, err
	}
	if len(instructionFiles) == 0 {
		return result, nil
	}

	deployDir := filepath.Join(projectRoot, effectiveRoot, cfg.Subdir)
	if err := os.MkdirAll(deployDir, 0o755); err != nil {
		return result, err
	}

	// Rule formats rename "<stem>.instructions.md" to "<stem><Extension>".
	needsRename := cfg.FormatID == FormatCursorRules ||
		cfg.FormatID == FormatClaudeRules ||
		cfg.FormatID == FormatWindsurfRules

	for _, src := range instructionFiles {
		var targetName string
		if needsRename {
			stem := filepath.Base(src)
			if strings.HasSuffix(stem, ".instructions.md") {
				stem = stem[:len(stem)-len(".instructions.md")]
			}
			ext := cfg.Extension
			if ext == "" {
				ext = ".md"
			}
			targetName = stem + ext
		} else {
			targetName = filepath.Base(src)
		}

		targetPath := filepath.Join(deployDir, targetName)
		// Project-relative, forward-slash path used for managed-file lookups.
		relPath := filepath.ToSlash(strings.TrimPrefix(targetPath, projectRoot+string(filepath.Separator)))

		if checkCollision(targetPath, relPath, managedFiles, force) {
			result.FilesSkipped++
			continue
		}

		links, err := CopyInstruction(src, targetPath, cfg.FormatID)
		if err != nil {
			return result, err
		}
		result.FilesIntegrated++
		result.LinksResolved += links
		result.TargetPaths = append(result.TargetPaths, targetPath)
	}

	return result, nil
}

// SyncForTarget removes APM-managed instruction files for a given target.
//
// With a non-nil managedFiles set, every managed path under the target's
// deploy prefix is removed. With nil managedFiles a legacy glob is used:
// *.mdc for cursor rules, *.instructions.md for verbatim targets, and
// nothing for Claude/Windsurf (whose plain .md files cannot be safely
// distinguished from user files). The errors counter is currently never
// incremented — failed removals are silently skipped.
func SyncForTarget(
	projectRoot string,
	cfg TargetConfig,
	managedFiles map[string]bool,
) (filesRemoved int, errors int) {
	effectiveRoot := cfg.DeployRoot
	if effectiveRoot == "" {
		effectiveRoot = cfg.RootDir
	}
	prefix := effectiveRoot + "/" + cfg.Subdir + "/"

	if managedFiles != nil {
		for rel := range managedFiles {
			if strings.HasPrefix(rel, prefix) {
				abs := filepath.Join(projectRoot, filepath.FromSlash(rel))
				if rmErr := os.Remove(abs); rmErr == nil {
					filesRemoved++
				}
			}
		}
		return filesRemoved, errors
	}

	// Legacy glob removal
	var legacyPattern string
	switch cfg.FormatID {
	case FormatCursorRules:
		legacyPattern = "*.mdc"
	case FormatWindsurfRules, FormatClaudeRules:
		// Avoid broad deletion of user-authored .md files
		return 0, 0
	default:
		legacyPattern = "*.instructions.md"
	}

	legacyDir := filepath.Join(projectRoot, effectiveRoot, cfg.Subdir)
	entries, err := os.ReadDir(legacyDir)
	if err != nil {
		return 0, 0
	}
	for _, e := range entries {
		if e.IsDir() {
			continue
		}
		matched, _ := filepath.Match(legacyPattern, e.Name())
		if matched {
			if rmErr := os.Remove(filepath.Join(legacyDir, e.Name())); rmErr == nil {
				filesRemoved++
			}
		}
	}
	return filesRemoved, errors
}

// checkCollision returns true if the target is a user-authored file that should not be overwritten.
+func checkCollision(targetPath, relPath string, managedFiles map[string]bool, force bool) bool { + if managedFiles == nil { + return false + } + if _, err := os.Stat(targetPath); os.IsNotExist(err) { + return false + } + normalized := strings.ReplaceAll(relPath, "\\", "/") + if managedFiles[normalized] { + return false + } + if force { + return false + } + return true +} diff --git a/internal/integration/instructionintegrator/instructionintegrator_test.go b/internal/integration/instructionintegrator/instructionintegrator_test.go new file mode 100644 index 0000000..0a7f3b3 --- /dev/null +++ b/internal/integration/instructionintegrator/instructionintegrator_test.go @@ -0,0 +1,80 @@ +package instructionintegrator + +import ( + "testing" +) + +func TestConvertToCursorRules_WithApplyTo(t *testing.T) { + input := "---\napplyTo: \"**/*.go\"\ndescription: Go lint rules\n---\n\nContent here.\n" + out := ConvertToCursorRules(input) + if !contains(out, `globs: "**/*.go"`) { + t.Errorf("expected globs field, got: %s", out) + } + if !contains(out, "description: Go lint rules") { + t.Errorf("expected description field, got: %s", out) + } +} + +func TestConvertToCursorRules_NoApplyTo(t *testing.T) { + input := "# My Rule\n\nDo this.\n" + out := ConvertToCursorRules(input) + if !contains(out, "---") { + t.Errorf("expected frontmatter, got: %s", out) + } + if !contains(out, "Do this.") { + t.Errorf("expected body, got: %s", out) + } +} + +func TestConvertToClaudeRules_WithApplyTo(t *testing.T) { + input := "---\napplyTo: \"src/**\"\n---\n\nBody.\n" + out := ConvertToClaudeRules(input) + if !contains(out, `"src/**"`) { + t.Errorf("expected path, got: %s", out) + } + if !contains(out, "paths:") { + t.Errorf("expected paths key, got: %s", out) + } +} + +func TestConvertToClaudeRules_NoApplyTo(t *testing.T) { + input := "---\ndescription: foo\n---\n\nBody.\n" + out := ConvertToClaudeRules(input) + if contains(out, "paths:") { + t.Errorf("unexpected paths key, got: %s", out) + } + if 
!contains(out, "Body.") { + t.Errorf("expected body, got: %s", out) + } +} + +func TestConvertToWindsurfRules_WithApplyTo(t *testing.T) { + input := "---\napplyTo: \"**/*.ts\"\n---\n\nBody.\n" + out := ConvertToWindsurfRules(input) + if !contains(out, "trigger: glob") { + t.Errorf("expected trigger: glob, got: %s", out) + } + if !contains(out, `globs: "**/*.ts"`) { + t.Errorf("expected globs field, got: %s", out) + } +} + +func TestConvertToWindsurfRules_NoApplyTo(t *testing.T) { + input := "Body.\n" + out := ConvertToWindsurfRules(input) + if !contains(out, "trigger: always_on") { + t.Errorf("expected trigger: always_on, got: %s", out) + } +} + +func contains(s, sub string) bool { + return len(s) >= len(sub) && (s == sub || len(sub) == 0 || + func() bool { + for i := 0; i <= len(s)-len(sub); i++ { + if s[i:i+len(sub)] == sub { + return true + } + } + return false + }()) +} diff --git a/internal/integration/promptintegrator/promptintegrator.go b/internal/integration/promptintegrator/promptintegrator.go new file mode 100644 index 0000000..27e5818 --- /dev/null +++ b/internal/integration/promptintegrator/promptintegrator.go @@ -0,0 +1,166 @@ +// Package promptintegrator provides prompt file integration for APM packages. +// Deploys .prompt.md files into .github/prompts/. +package promptintegrator + +import ( + "io/fs" + "os" + "path/filepath" + "strings" +) + +// IntegrationResult holds the result of a prompt integration operation. +type IntegrationResult struct { + FilesIntegrated int + FilesUpdated int + FilesSkipped int + TargetPaths []string + LinksResolved int +} + +// FindPromptFiles returns all .prompt.md files found in a package directory. +// Searches in package root and .apm/prompts/ subdirectory. 
+func FindPromptFiles(packagePath string) ([]string, error) { + var files []string + + // Search in package root + entries, err := os.ReadDir(packagePath) + if err == nil { + for _, e := range entries { + if !e.IsDir() && strings.HasSuffix(e.Name(), ".prompt.md") { + files = append(files, filepath.Join(packagePath, e.Name())) + } + } + } + + // Search in .apm/prompts/ + apmPrompts := filepath.Join(packagePath, ".apm", "prompts") + _ = filepath.WalkDir(apmPrompts, func(path string, d fs.DirEntry, werr error) error { + if werr != nil { + return nil + } + if !d.IsDir() && strings.HasSuffix(d.Name(), ".prompt.md") { + files = append(files, path) + } + return nil + }) + + return files, nil +} + +// GetTargetFilename returns the target filename for a prompt file (no suffix change). +func GetTargetFilename(sourceFile string) string { + return filepath.Base(sourceFile) +} + +// CopyPrompt copies a prompt file verbatim to the target path. +// Returns number of links resolved (always 0 in this implementation). +func CopyPrompt(source, target string) (int, error) { + data, err := os.ReadFile(source) + if err != nil { + return 0, err + } + if err := os.WriteFile(target, data, 0o644); err != nil { + return 0, err + } + return 0, nil +} + +// IntegratePackagePrompts integrates all prompt files from a package into .github/prompts/. +// managedFiles is the set of relative paths known to be APM-managed (nil = legacy mode). +// force overrides collision checks. 
+func IntegratePackagePrompts( + installPath string, + projectRoot string, + force bool, + managedFiles map[string]bool, +) (IntegrationResult, error) { + result := IntegrationResult{} + + promptFiles, err := FindPromptFiles(installPath) + if err != nil { + return result, err + } + if len(promptFiles) == 0 { + return result, nil + } + + promptsDir := filepath.Join(projectRoot, ".github", "prompts") + if err := os.MkdirAll(promptsDir, 0o755); err != nil { + return result, err + } + + for _, src := range promptFiles { + targetName := GetTargetFilename(src) + targetPath := filepath.Join(promptsDir, targetName) + relPath := filepath.ToSlash(strings.TrimPrefix(targetPath, projectRoot+string(filepath.Separator))) + + if checkCollision(targetPath, relPath, managedFiles, force) { + result.FilesSkipped++ + continue + } + + links, err := CopyPrompt(src, targetPath) + if err != nil { + return result, err + } + result.FilesIntegrated++ + result.LinksResolved += links + result.TargetPaths = append(result.TargetPaths, targetPath) + } + + return result, nil +} + +// SyncIntegration removes APM-managed prompt files. +// managedFiles nil => legacy glob removal of *-apm.prompt.md. 
+func SyncIntegration( + projectRoot string, + managedFiles map[string]bool, +) (filesRemoved int, errors int) { + promptsDir := filepath.Join(projectRoot, ".github", "prompts") + + if managedFiles != nil { + for rel := range managedFiles { + if strings.HasPrefix(rel, ".github/prompts/") { + abs := filepath.Join(projectRoot, filepath.FromSlash(rel)) + if rmErr := os.Remove(abs); rmErr == nil { + filesRemoved++ + } + } + } + return filesRemoved, errors + } + + // Legacy: remove *-apm.prompt.md + entries, err := os.ReadDir(promptsDir) + if err != nil { + return 0, 0 + } + for _, e := range entries { + if !e.IsDir() && strings.HasSuffix(e.Name(), "-apm.prompt.md") { + if rmErr := os.Remove(filepath.Join(promptsDir, e.Name())); rmErr == nil { + filesRemoved++ + } + } + } + return filesRemoved, errors +} + +// checkCollision returns true if target_path is a user-authored file that should not be overwritten. +func checkCollision(targetPath, relPath string, managedFiles map[string]bool, force bool) bool { + if managedFiles == nil { + return false + } + if _, err := os.Stat(targetPath); os.IsNotExist(err) { + return false + } + normalized := strings.ReplaceAll(relPath, "\\", "/") + if managedFiles[normalized] { + return false + } + if force { + return false + } + return true +} diff --git a/internal/integration/skillintegrator/skillintegrator.go b/internal/integration/skillintegrator/skillintegrator.go new file mode 100644 index 0000000..5aad513 --- /dev/null +++ b/internal/integration/skillintegrator/skillintegrator.go @@ -0,0 +1,734 @@ +// Package skillintegrator provides skill integration for APM packages. +// Deploys SKILL.md-based packages to .github/skills/, .claude/skills/, etc. 
+// Ported from src/apm_cli/integration/skill_integrator.py +package skillintegrator + +import ( + "io/fs" + "os" + "path/filepath" + "regexp" + "strings" + "sync" + + "github.com/githubnext/apm/internal/integration/targets" +) + +// SkillIntegrationResult holds results of a skill integration operation. +type SkillIntegrationResult struct { + SkillCreated bool + SkillUpdated bool + SkillSkipped bool + SkillPath string // path to deployed SKILL.md, empty if not deployed + ReferencesCopied int // total files copied to skill directory + LinksResolved int // always 0 (kept for backward compat) + SubSkillsPromoted int // number of sub-skills promoted to top-level + TargetPaths []string +} + +// nameRe matches valid agentskills.io skill names. +var nameRe = regexp.MustCompile(`^[a-z0-9]([a-z0-9-]*[a-z0-9])?$`) +var camelRe = regexp.MustCompile(`([a-z])([A-Z])`) +var badCharsRe = regexp.MustCompile(`[^a-z0-9-]`) +var multiHyphenRe = regexp.MustCompile(`-+`) + +// ToHyphenCase converts a package name to hyphen-case (max 64 chars). +func ToHyphenCase(name string) string { + if idx := strings.LastIndex(name, "/"); idx >= 0 { + name = name[idx+1:] + } + name = strings.NewReplacer("_", "-", " ", "-").Replace(name) + name = camelRe.ReplaceAllString(name, "${1}-${2}") + name = strings.ToLower(name) + name = badCharsRe.ReplaceAllString(name, "") + name = multiHyphenRe.ReplaceAllString(name, "-") + name = strings.Trim(name, "-") + if len(name) > 64 { + name = name[:64] + } + return name +} + +// ValidateSkillName validates a skill name per agentskills.io spec. +// Returns (valid, errorMessage). 
+func ValidateSkillName(name string) (bool, string) { + if len(name) == 0 { + return false, "Skill name cannot be empty" + } + if len(name) > 64 { + return false, "Skill name must be 1-64 characters" + } + if strings.Contains(name, "--") { + return false, "Skill name cannot contain consecutive hyphens (--)" + } + if strings.HasPrefix(name, "-") { + return false, "Skill name cannot start with a hyphen" + } + if strings.HasSuffix(name, "-") { + return false, "Skill name cannot end with a hyphen" + } + if !nameRe.MatchString(name) { + return false, "Skill name must be lowercase alphanumeric with hyphens only" + } + return true, "" +} + +// NormalizeSkillName converts any package name to a valid skill name. +func NormalizeSkillName(name string) string { + return ToHyphenCase(name) +} + +// ignoreNonContent returns true for paths that should not be copied +// (hidden files/dirs except SKILL.md, .git, __pycache__, *.pyc). +func ignoreNonContent(name string) bool { + if name == ".git" || name == "__pycache__" || strings.HasSuffix(name, ".pyc") { + return true + } + return false +} + +// copyDirSkill copies src directory to dst, skipping non-content files. +func copyDirSkill(src, dst string) (int, error) { + if err := os.MkdirAll(dst, 0o755); err != nil { + return 0, err + } + count := 0 + err := filepath.WalkDir(src, func(path string, d fs.DirEntry, werr error) error { + if werr != nil { + return nil + } + rel, _ := filepath.Rel(src, path) + if rel == "." 
{ + return nil + } + parts := strings.SplitN(rel, string(filepath.Separator), 2) + if len(parts) > 0 && ignoreNonContent(parts[0]) { + if d.IsDir() { + return filepath.SkipDir + } + return nil + } + target := filepath.Join(dst, rel) + if d.IsDir() { + return os.MkdirAll(target, 0o755) + } + data, err := os.ReadFile(path) + if err != nil { + return nil + } + if err := os.WriteFile(target, data, 0o644); err != nil { + return err + } + count++ + return nil + }) + return count, err +} + +// dirsEqual returns true if two directory trees have identical file contents. +func dirsEqual(a, b string) bool { + aFiles := map[string][]byte{} + bFiles := map[string][]byte{} + collectFiles := func(root string, m map[string][]byte) { + _ = filepath.WalkDir(root, func(path string, d fs.DirEntry, _ error) error { + if d == nil || d.IsDir() { + return nil + } + rel, _ := filepath.Rel(root, path) + data, err := os.ReadFile(path) + if err != nil { + return nil + } + m[rel] = data + return nil + }) + } + collectFiles(a, aFiles) + collectFiles(b, bFiles) + if len(aFiles) != len(bFiles) { + return false + } + for k, va := range aFiles { + vb, ok := bFiles[k] + if !ok || string(va) != string(vb) { + return false + } + } + return true +} + +// SkillIntegrator handles integration of SKILL.md-based packages. +type SkillIntegrator struct { + mu sync.Mutex + nativeSkillSessionOwners map[string]string +} + +// New returns a new SkillIntegrator. +func New() *SkillIntegrator { + return &SkillIntegrator{ + nativeSkillSessionOwners: map[string]string{}, + } +} + +// allKnownTargets returns a slice of all known target profiles. +func allKnownTargets() []*targets.TargetProfile { + out := make([]*targets.TargetProfile, 0, len(targets.KnownTargets)) + for _, t := range targets.KnownTargets { + out = append(out, t) + } + return out +} + +// FindInstructionFiles returns all .instructions.md files from .apm/instructions/. 
func FindInstructionFiles(packagePath string) []string {
	return listWithSuffix(filepath.Join(packagePath, ".apm", "instructions"), ".instructions.md")
}

// FindAgentFiles returns all .agent.md files from .apm/agents/.
func FindAgentFiles(packagePath string) []string {
	return listWithSuffix(filepath.Join(packagePath, ".apm", "agents"), ".agent.md")
}

// FindPromptFiles returns all .prompt.md files from package root and .apm/prompts/.
func FindPromptFiles(packagePath string) []string {
	out := listWithSuffix(packagePath, ".prompt.md")
	out = append(out, listWithSuffix(filepath.Join(packagePath, ".apm", "prompts"), ".prompt.md")...)
	return out
}

// listWithSuffix returns the non-directory entries of dir whose names end in
// suffix, as full paths. A missing or unreadable dir yields nil.
func listWithSuffix(dir, suffix string) []string {
	entries, err := os.ReadDir(dir)
	if err != nil {
		return nil
	}
	var matches []string
	for _, entry := range entries {
		if !entry.IsDir() && strings.HasSuffix(entry.Name(), suffix) {
			matches = append(matches, filepath.Join(dir, entry.Name()))
		}
	}
	return matches
}

// FindContextFiles returns all context and memory files.
func FindContextFiles(packagePath string) []string {
	// Each source directory pairs with the suffix its files must carry.
	sources := []struct {
		subdir string
		suffix string
	}{
		{".apm/context", ".context.md"},
		{".apm/memory", ".memory.md"},
	}

	var out []string
	for _, src := range sources {
		dir := filepath.Join(packagePath, src.subdir)
		entries, err := os.ReadDir(dir)
		if err != nil {
			continue // missing directory: nothing to collect
		}
		for _, entry := range entries {
			if !entry.IsDir() && strings.HasSuffix(entry.Name(), src.suffix) {
				out = append(out, filepath.Join(dir, entry.Name()))
			}
		}
	}
	return out
}

// PackageInfo is a minimal interface for package metadata used by skill integration.
type PackageInfo struct {
	InstallPath string
	PackageType string // "CLAUDE_SKILL", "HYBRID", "SKILL_BUNDLE", "MARKETPLACE_PLUGIN", "INSTRUCTIONS", "PROMPTS"
	IsVirtual   bool
	IsSubdir    bool
	UniqueKey   string
}

// shouldInstallSkill returns true for packages that should be installed as skills.
func shouldInstallSkill(pkg *PackageInfo) bool {
	switch pkg.PackageType {
	case "CLAUDE_SKILL", "HYBRID", "SKILL_BUNDLE", "MARKETPLACE_PLUGIN":
		return true
	default:
		return false
	}
}

// promoteSubSkills promotes sub-skills from .apm/skills/ to a target skills root.
+func promoteSubSkills( + subSkillsDir string, + targetSkillsRoot string, + parentName string, + ownedBy map[string]string, + managedFiles map[string]struct{}, + force bool, + nameFilter map[string]struct{}, +) (int, []string) { + entries, err := os.ReadDir(subSkillsDir) + if err != nil { + return 0, nil + } + promoted := 0 + var deployed []string + for _, e := range entries { + if !e.IsDir() { + continue + } + subPath := filepath.Join(subSkillsDir, e.Name()) + if _, err := os.Stat(filepath.Join(subPath, "SKILL.md")); err != nil { + continue + } + rawName := e.Name() + if nameFilter != nil { + if _, ok := nameFilter[rawName]; !ok { + continue + } + } + valid, _ := ValidateSkillName(rawName) + subName := rawName + if !valid { + subName = NormalizeSkillName(rawName) + } + target := filepath.Join(targetSkillsRoot, subName) + if _, err := os.Stat(target); err == nil { + if dirsEqual(subPath, target) { + promoted++ + deployed = append(deployed, target) + continue + } + relPath := filepath.Join(filepath.Base(targetSkillsRoot), subName) + isManaged := false + if managedFiles != nil { + norm := strings.ReplaceAll(relPath, "\\", "/") + _, isManaged = managedFiles[norm] + } + prevOwner := ownedBy[subName] + isSelfOverwrite := prevOwner != "" && prevOwner == parentName + if managedFiles != nil && !isManaged && !isSelfOverwrite && !force { + continue + } + _ = os.RemoveAll(target) + } + if err := os.MkdirAll(target, 0o755); err != nil { + continue + } + if _, err := copyDirSkill(subPath, target); err != nil { + continue + } + promoted++ + deployed = append(deployed, target) + } + return promoted, deployed +} + +// IntegrateNativeSkill deploys a package with a root SKILL.md to all active targets. 
+func (si *SkillIntegrator) IntegrateNativeSkill( + pkg *PackageInfo, + projectRoot string, + force bool, + managedFiles map[string]struct{}, + allTargets []*targets.TargetProfile, +) *SkillIntegrationResult { + packagePath := pkg.InstallPath + rawSkillName := filepath.Base(packagePath) + valid, _ := ValidateSkillName(rawSkillName) + skillName := rawSkillName + if !valid { + skillName = NormalizeSkillName(rawSkillName) + } + + if allTargets == nil { + allTargets = targets.ActiveTargets(projectRoot, nil) + } + skillCreated := false + skillUpdated := false + filesCopied := 0 + var allTargetPaths []string + var primarySkillMD string + + seen := map[string]bool{} + + for idx, tgt := range allTargets { + if !tgt.Supports("skills") { + continue + } + sm := tgt.Primitives["skills"] + effectiveRoot := sm.DeployRoot + if effectiveRoot == "" { + effectiveRoot = tgt.RootDir + } + targetSkillDir := filepath.Join(projectRoot, effectiveRoot, "skills", skillName) + // path security: no traversal + if strings.Contains(skillName, "..") { + continue + } + resolved, _ := filepath.EvalSymlinks(targetSkillDir) + if resolved == "" { + resolved = targetSkillDir + } + if seen[resolved] { + continue + } + seen[resolved] = true + + isPrimary := idx == 0 + if isPrimary { + if _, err := os.Stat(targetSkillDir); os.IsNotExist(err) { + skillCreated = true + } else { + skillUpdated = true + } + primarySkillMD = filepath.Join(targetSkillDir, "SKILL.md") + } + + _ = os.RemoveAll(targetSkillDir) + _ = os.MkdirAll(filepath.Dir(targetSkillDir), 0o755) + n, _ := copyDirSkill(packagePath, targetSkillDir) + allTargetPaths = append(allTargetPaths, targetSkillDir) + if isPrimary { + filesCopied = n + } + + // Promote sub-skills + subSkillsDir := filepath.Join(packagePath, ".apm", "skills") + targetSkillsRoot := filepath.Join(projectRoot, effectiveRoot, "skills") + _, subDeployed := promoteSubSkills(subSkillsDir, targetSkillsRoot, skillName, nil, managedFiles, force, nil) + allTargetPaths = 
append(allTargetPaths, subDeployed...) + _ = subDeployed + } + + si.mu.Lock() + if pkg.UniqueKey != "" { + si.nativeSkillSessionOwners[skillName] = pkg.UniqueKey + } + si.mu.Unlock() + + primaryRoot := filepath.Join(projectRoot, ".github", "skills") + subSkillsCount := 0 + for _, p := range allTargetPaths { + if filepath.Dir(p) == primaryRoot && filepath.Base(p) != skillName { + subSkillsCount++ + } + } + + return &SkillIntegrationResult{ + SkillCreated: skillCreated, + SkillUpdated: skillUpdated, + SkillSkipped: false, + SkillPath: primarySkillMD, + ReferencesCopied: filesCopied, + SubSkillsPromoted: subSkillsCount, + TargetPaths: allTargetPaths, + } +} + +// IntegrateSkillBundle promotes every skill in a root-level skills/ directory. +func (si *SkillIntegrator) IntegrateSkillBundle( + pkg *PackageInfo, + projectRoot string, + skillsDir string, + force bool, + managedFiles map[string]struct{}, + allTargets []*targets.TargetProfile, + nameFilter map[string]struct{}, +) *SkillIntegrationResult { + if allTargets == nil { + allTargets = targets.ActiveTargets(projectRoot, nil) + } + parentName := filepath.Base(pkg.InstallPath) + totalPromoted := 0 + var allDeployed []string + anyCreated := false + seen := map[string]bool{} + + for idx, tgt := range allTargets { + if !tgt.Supports("skills") { + continue + } + sm := tgt.Primitives["skills"] + effectiveRoot := sm.DeployRoot + if effectiveRoot == "" { + effectiveRoot = tgt.RootDir + } + targetSkillsRoot := filepath.Join(projectRoot, effectiveRoot, "skills") + resolved, _ := filepath.EvalSymlinks(targetSkillsRoot) + if resolved == "" { + resolved = targetSkillsRoot + } + if seen[resolved] { + continue + } + seen[resolved] = true + _ = os.MkdirAll(targetSkillsRoot, 0o755) + + isPrimary := idx == 0 + n, deployed := promoteSubSkills(skillsDir, targetSkillsRoot, parentName, nil, managedFiles, force, nameFilter) + if isPrimary { + totalPromoted = n + if n > 0 { + anyCreated = true + } + } + allDeployed = append(allDeployed, 
deployed...) + } + + return &SkillIntegrationResult{ + SkillCreated: anyCreated, + SkillSkipped: false, + SubSkillsPromoted: totalPromoted, + TargetPaths: allDeployed, + } +} + +// PromoteSubSkillsStandalone promotes sub-skills for non-skill packages. +func (si *SkillIntegrator) PromoteSubSkillsStandalone( + pkg *PackageInfo, + projectRoot string, + force bool, + managedFiles map[string]struct{}, + allTargets []*targets.TargetProfile, +) (int, []string) { + subSkillsDir := filepath.Join(pkg.InstallPath, ".apm", "skills") + if _, err := os.Stat(subSkillsDir); err != nil { + return 0, nil + } + if allTargets == nil { + allTargets = targets.ActiveTargets(projectRoot, nil) + } + parentName := filepath.Base(pkg.InstallPath) + count := 0 + var allDeployed []string + seen := map[string]bool{} + + for idx, tgt := range allTargets { + if !tgt.Supports("skills") { + continue + } + sm := tgt.Primitives["skills"] + effectiveRoot := sm.DeployRoot + if effectiveRoot == "" { + effectiveRoot = tgt.RootDir + } + targetSkillsRoot := filepath.Join(projectRoot, effectiveRoot, "skills") + resolved, _ := filepath.EvalSymlinks(targetSkillsRoot) + if resolved == "" { + resolved = targetSkillsRoot + } + if seen[resolved] { + continue + } + seen[resolved] = true + _ = os.MkdirAll(targetSkillsRoot, 0o755) + + isPrimary := idx == 0 + n, deployed := promoteSubSkills(subSkillsDir, targetSkillsRoot, parentName, nil, managedFiles, force, nil) + if isPrimary { + count = n + } + allDeployed = append(allDeployed, deployed...) + } + return count, allDeployed +} + +// IntegratePackageSkill is the main entry point for skill integration. 
+func (si *SkillIntegrator) IntegratePackageSkill( + pkg *PackageInfo, + projectRoot string, + force bool, + managedFiles map[string]struct{}, + allTargets []*targets.TargetProfile, + skillSubset []string, +) *SkillIntegrationResult { + if !shouldInstallSkill(pkg) { + subCount, subDeployed := si.PromoteSubSkillsStandalone(pkg, projectRoot, force, managedFiles, allTargets) + return &SkillIntegrationResult{ + SkillSkipped: true, + SubSkillsPromoted: subCount, + TargetPaths: subDeployed, + } + } + + if pkg.IsVirtual && !pkg.IsSubdir { + return &SkillIntegrationResult{SkillSkipped: true} + } + + sourceSkillMD := filepath.Join(pkg.InstallPath, "SKILL.md") + if _, err := os.Stat(sourceSkillMD); err == nil { + return si.IntegrateNativeSkill(pkg, projectRoot, force, managedFiles, allTargets) + } + + // Check for SKILL_BUNDLE + rootSkillsDir := filepath.Join(pkg.InstallPath, "skills") + if info, err := os.Stat(rootSkillsDir); err == nil && info.IsDir() { + var nameFilter map[string]struct{} + if len(skillSubset) > 0 { + nameFilter = make(map[string]struct{}, len(skillSubset)) + for _, s := range skillSubset { + nameFilter[s] = struct{}{} + } + } + hasSkill := false + entries, _ := os.ReadDir(rootSkillsDir) + for _, e := range entries { + if e.IsDir() { + if _, err := os.Stat(filepath.Join(rootSkillsDir, e.Name(), "SKILL.md")); err == nil { + hasSkill = true + break + } + } + } + if hasSkill { + return si.IntegrateSkillBundle(pkg, projectRoot, rootSkillsDir, force, managedFiles, allTargets, nameFilter) + } + } + + subCount, subDeployed := si.PromoteSubSkillsStandalone(pkg, projectRoot, force, managedFiles, allTargets) + return &SkillIntegrationResult{ + SkillSkipped: true, + SubSkillsPromoted: subCount, + TargetPaths: subDeployed, + } +} + +// SyncStats holds cleanup statistics. +type SyncStats struct { + FilesRemoved int + Errors int +} + +// SyncIntegration removes orphaned skill directories. 
+func (si *SkillIntegrator) SyncIntegration( + installedSkillNames map[string]struct{}, + projectRoot string, + managedFiles map[string]struct{}, + allTargets []*targets.TargetProfile, +) SyncStats { + if allTargets == nil { + allTargets = allKnownTargets() + } + var stats SyncStats + + if managedFiles != nil { + skillPrefixes := skillPrefixList(allTargets) + projectResolved, _ := filepath.EvalSymlinks(projectRoot) + if projectResolved == "" { + projectResolved = projectRoot + } + for relPath := range managedFiles { + norm := strings.ReplaceAll(relPath, "\\", "/") + if strings.Contains(norm, "..") { + continue + } + if !hasAnyPrefix(norm, skillPrefixes) { + continue + } + target := filepath.Join(projectRoot, relPath) + if _, err := os.Stat(target); err != nil { + continue + } + info, err := os.Lstat(target) + if err != nil { + continue + } + if info.IsDir() { + if err := os.RemoveAll(target); err != nil { + stats.Errors++ + } else { + stats.FilesRemoved++ + } + } else { + if err := os.Remove(target); err != nil { + stats.Errors++ + } else { + stats.FilesRemoved++ + } + } + } + return stats + } + + // Legacy: npm-style orphan detection + seen := map[string]bool{} + for _, tgt := range allTargets { + if !tgt.Supports("skills") { + continue + } + sm := tgt.Primitives["skills"] + effectiveRoot := sm.DeployRoot + if effectiveRoot == "" { + effectiveRoot = tgt.RootDir + } + skillsDir := filepath.Join(projectRoot, effectiveRoot, "skills") + resolved, _ := filepath.EvalSymlinks(skillsDir) + if resolved == "" { + resolved = skillsDir + } + if seen[resolved] { + continue + } + seen[resolved] = true + entries, err := os.ReadDir(skillsDir) + if err != nil { + continue + } + for _, e := range entries { + if !e.IsDir() { + continue + } + if _, ok := installedSkillNames[e.Name()]; ok { + continue + } + target := filepath.Join(skillsDir, e.Name()) + if err := os.RemoveAll(target); err != nil { + stats.Errors++ + } else { + stats.FilesRemoved++ + } + } + } + return stats +} + +func 
skillPrefixList(allTargets []*targets.TargetProfile) []string { + var out []string + for _, tgt := range allTargets { + if !tgt.Supports("skills") { + continue + } + sm := tgt.Primitives["skills"] + effectiveRoot := sm.DeployRoot + if effectiveRoot == "" { + effectiveRoot = tgt.RootDir + } + out = append(out, effectiveRoot+"/skills/") + } + return out +} + +func hasAnyPrefix(s string, prefixes []string) bool { + for _, p := range prefixes { + if strings.HasPrefix(s, p) { + return true + } + } + return false +} diff --git a/internal/integration/targets/targets.go b/internal/integration/targets/targets.go new file mode 100644 index 0000000..14f386d --- /dev/null +++ b/internal/integration/targets/targets.go @@ -0,0 +1,471 @@ +// Package targets defines the registry of known integration target profiles +// (Copilot, Claude, Cursor, etc.) and helpers for target resolution. +// +// Migrated from src/apm_cli/integration/targets.py +package targets + +import ( +"os" +"path/filepath" +"strings" +) + +// PrimitiveMapping describes where a single primitive type is deployed. +type PrimitiveMapping struct { +Subdir string // subdirectory under target root +Extension string // file extension or suffix +FormatID string // opaque transformer tag +DeployRoot string // optional root override (empty = use target root) +} + +// TargetProfile describes capabilities and layout of a single target tool. +type TargetProfile struct { +Name string +RootDir string +Primitives map[string]PrimitiveMapping + +AutoCreate bool +DetectByDir bool + +UserSupported interface{} // bool or "partial" +UserRootDir string +UnsupportedUserPrimitives []string +RequiresFlag string +GeneratedFiles []string +PackPrefixes []string +CompileFamily string +HooksConfigDisplay string + +// Set by ForScope for dynamic-root targets. +ResolvedDeployRoot string +} + +// Prefix returns the path prefix for this target (e.g. ".github/"). 
+func (t *TargetProfile) Prefix() string { +return t.RootDir + "/" +} + +// EffectivePackPrefixes returns the path prefixes used by pack-time filtering. +func (t *TargetProfile) EffectivePackPrefixes() []string { +if len(t.PackPrefixes) > 0 { +return t.PackPrefixes +} +return []string{t.Prefix()} +} + +// Supports returns true if this target accepts the primitive. +func (t *TargetProfile) Supports(primitive string) bool { +_, ok := t.Primitives[primitive] +return ok +} + +// EffectiveRoot returns the root directory for the given scope. +func (t *TargetProfile) EffectiveRoot(userScope bool) string { +if userScope && t.UserRootDir != "" { +return t.UserRootDir +} +return t.RootDir +} + +// SupportsAtUserScope returns true if the primitive can be deployed at user scope. +func (t *TargetProfile) SupportsAtUserScope(primitive string) bool { +if t.UserSupported == false || t.UserSupported == nil { +return false +} +for _, u := range t.UnsupportedUserPrimitives { +if u == primitive { +return false +} +} +return t.Supports(primitive) +} + +// DeployPath returns the filesystem path for deployment. +func (t *TargetProfile) DeployPath(projectRoot string, parts ...string) string { +if t.ResolvedDeployRoot != "" { +base := t.ResolvedDeployRoot +if len(parts) > 0 { +return filepath.Join(append([]string{base}, parts...)...) +} +return base +} +base := filepath.Join(projectRoot, t.RootDir) +if len(parts) > 0 { +return filepath.Join(append([]string{base}, parts...)...) +} +return base +} + +// ForScope returns a scope-resolved copy of this profile. +// Returns nil if the target does not support user scope. 
+func (t *TargetProfile) ForScope(userScope bool) *TargetProfile { +if !userScope { +cp := *t +return &cp +} + +// Check user_supported +switch v := t.UserSupported.(type) { +case bool: +if !v { +return nil +} +case string: +if v != "partial" { +return nil +} +case nil: +return nil +} + +cp := *t +newRoot := t.UserRootDir +if newRoot == "" { +newRoot = t.RootDir +} + +// Claude Code honors CLAUDE_CONFIG_DIR +if t.Name == "claude" { +if env := strings.TrimSpace(os.Getenv("CLAUDE_CONFIG_DIR")); env != "" { +home, _ := os.UserHomeDir() +abs := filepath.Clean(env) +if rel, err := filepath.Rel(home, abs); err == nil && !strings.HasPrefix(rel, "..") { +newRoot = filepath.ToSlash(rel) +} else { +newRoot = abs +} +} +} + +cp.RootDir = newRoot + +// Filter unsupported user primitives +if len(t.UnsupportedUserPrimitives) > 0 { +filtered := make(map[string]PrimitiveMapping) +unsup := make(map[string]bool, len(t.UnsupportedUserPrimitives)) +for _, u := range t.UnsupportedUserPrimitives { +unsup[u] = true +} +for k, v := range t.Primitives { +if !unsup[k] { +filtered[k] = v +} +} +cp.Primitives = filtered +} + +return &cp +} + +// ShouldUseLegacySkillPaths returns true when APM_LEGACY_SKILL_PATHS is set. +func ShouldUseLegacySkillPaths() bool { +val := strings.ToLower(strings.TrimSpace(os.Getenv("APM_LEGACY_SKILL_PATHS"))) +return val == "1" || val == "true" || val == "yes" +} + +// ApplyLegacySkillPaths resets deploy_root on every skills primitive. 
+func ApplyLegacySkillPaths(profiles []*TargetProfile) []*TargetProfile { +result := make([]*TargetProfile, len(profiles)) +for i, p := range profiles { +if pm, ok := p.Primitives["skills"]; ok && pm.DeployRoot != "" { +cp := *p +prims := make(map[string]PrimitiveMapping, len(p.Primitives)) +for k, v := range p.Primitives { +prims[k] = v +} +pm.DeployRoot = "" +prims["skills"] = pm +cp.Primitives = prims +result[i] = &cp +} else { +result[i] = p +} +} +return result +} + +// KnownTargets is the registry of all known integration targets. +var KnownTargets = map[string]*TargetProfile{ +"copilot": { +Name: "copilot", +RootDir: ".github", +Primitives: map[string]PrimitiveMapping{ +"instructions": {Subdir: "instructions", Extension: ".instructions.md", FormatID: "github_instructions"}, +"prompts": {Subdir: "prompts", Extension: ".prompt.md", FormatID: "github_prompt"}, +"agents": {Subdir: "agents", Extension: ".agent.md", FormatID: "github_agent"}, +"skills": {Subdir: "skills", Extension: "/SKILL.md", FormatID: "skill_standard", DeployRoot: ".agents"}, +"hooks": {Subdir: "hooks", Extension: ".json", FormatID: "github_hooks"}, +}, +AutoCreate: true, +DetectByDir: true, +UserSupported: "partial", +UserRootDir: ".copilot", +UnsupportedUserPrimitives: []string{"prompts", "instructions"}, +GeneratedFiles: []string{"copilot-instructions.md"}, +CompileFamily: "vscode", +}, +"claude": { +Name: "claude", +RootDir: ".claude", +Primitives: map[string]PrimitiveMapping{ +"instructions": {Subdir: "rules", Extension: ".md", FormatID: "claude_rules"}, +"agents": {Subdir: "agents", Extension: ".md", FormatID: "claude_agent"}, +"commands": {Subdir: "commands", Extension: ".md", FormatID: "claude_command"}, +"skills": {Subdir: "skills", Extension: "/SKILL.md", FormatID: "skill_standard"}, +"hooks": {Subdir: "hooks", Extension: ".json", FormatID: "claude_hooks"}, +}, +AutoCreate: false, +DetectByDir: true, +UserSupported: true, +CompileFamily: "claude", +HooksConfigDisplay: 
".claude/settings.json", +}, +"cursor": { +Name: "cursor", +RootDir: ".cursor", +Primitives: map[string]PrimitiveMapping{ +"instructions": {Subdir: "rules", Extension: ".mdc", FormatID: "cursor_rules"}, +"agents": {Subdir: "agents", Extension: ".md", FormatID: "cursor_agent"}, +"commands": {Subdir: "commands", Extension: ".md", FormatID: "claude_command"}, +"skills": {Subdir: "skills", Extension: "/SKILL.md", FormatID: "skill_standard", DeployRoot: ".agents"}, +"hooks": {Subdir: "hooks", Extension: ".json", FormatID: "cursor_hooks"}, +}, +AutoCreate: false, +DetectByDir: true, +UserSupported: "partial", +UserRootDir: ".cursor", +UnsupportedUserPrimitives: []string{"instructions"}, +CompileFamily: "agents", +HooksConfigDisplay: ".cursor/hooks.json", +}, +"opencode": { +Name: "opencode", +RootDir: ".opencode", +Primitives: map[string]PrimitiveMapping{ +"agents": {Subdir: "agents", Extension: ".md", FormatID: "opencode_agent"}, +"commands": {Subdir: "commands", Extension: ".md", FormatID: "opencode_command"}, +"skills": {Subdir: "skills", Extension: "/SKILL.md", FormatID: "skill_standard", DeployRoot: ".agents"}, +}, +AutoCreate: false, +DetectByDir: true, +UserSupported: "partial", +UserRootDir: ".config/opencode", +UnsupportedUserPrimitives: []string{"hooks"}, +CompileFamily: "agents", +}, +"gemini": { +Name: "gemini", +RootDir: ".gemini", +Primitives: map[string]PrimitiveMapping{ +"commands": {Subdir: "commands", Extension: ".toml", FormatID: "gemini_command"}, +"skills": {Subdir: "skills", Extension: "/SKILL.md", FormatID: "skill_standard", DeployRoot: ".agents"}, +"hooks": {Subdir: "hooks", Extension: ".json", FormatID: "gemini_hooks"}, +}, +AutoCreate: false, +DetectByDir: true, +UserSupported: true, +UserRootDir: ".gemini", +CompileFamily: "gemini", +HooksConfigDisplay: ".gemini/settings.json", +}, +"codex": { +Name: "codex", +RootDir: ".codex", +Primitives: map[string]PrimitiveMapping{ +"agents": {Subdir: "agents", Extension: ".toml", FormatID: "codex_agent"}, 
+"skills": {Subdir: "skills", Extension: "/SKILL.md", FormatID: "skill_standard", DeployRoot: ".agents"}, +"hooks": {Subdir: "", Extension: "hooks.json", FormatID: "codex_hooks"}, +}, +AutoCreate: false, +DetectByDir: true, +UserSupported: "partial", +PackPrefixes: []string{".codex/", ".agents/"}, +CompileFamily: "agents", +HooksConfigDisplay: ".codex/hooks.json", +}, +"windsurf": { +Name: "windsurf", +RootDir: ".windsurf", +Primitives: map[string]PrimitiveMapping{ +"instructions": {Subdir: "rules", Extension: ".md", FormatID: "windsurf_rules"}, +"agents": {Subdir: "skills", Extension: "/SKILL.md", FormatID: "windsurf_agent_skill"}, +"skills": {Subdir: "skills", Extension: "/SKILL.md", FormatID: "skill_standard"}, +"commands": {Subdir: "workflows", Extension: ".md", FormatID: "windsurf_workflow"}, +"hooks": {Subdir: "", Extension: "hooks.json", FormatID: "windsurf_hooks"}, +}, +AutoCreate: false, +DetectByDir: true, +UserSupported: "partial", +UserRootDir: ".codeium/windsurf", +UnsupportedUserPrimitives: []string{"instructions"}, +CompileFamily: "agents", +HooksConfigDisplay: ".windsurf/hooks.json", +}, +"agent-skills": { +Name: "agent-skills", +RootDir: ".agents", +Primitives: map[string]PrimitiveMapping{ +"skills": {Subdir: "skills", Extension: "/SKILL.md", FormatID: "skill_standard"}, +}, +AutoCreate: true, +DetectByDir: false, +UserSupported: true, +UserRootDir: ".agents", +}, +"copilot-cowork": { +Name: "copilot-cowork", +RootDir: "copilot-cowork", +Primitives: map[string]PrimitiveMapping{ +"skills": {Subdir: "skills", Extension: "/SKILL.md", FormatID: "skill_standard"}, +}, +AutoCreate: false, +DetectByDir: false, +UserSupported: true, +RequiresFlag: "copilot_cowork", +}, +} + +// GetIntegrationPrefixes returns all known target root prefixes. 
+func GetIntegrationPrefixes(profiles []*TargetProfile) []string { +source := profiles +if source == nil { +for _, p := range KnownTargets { +source = append(source, p) +} +} +seen := make(map[string]bool) +var prefixes []string +for _, t := range source { +// Dynamic-root targets (cowork) use cowork:// prefix +if t.RequiresFlag == "copilot_cowork" { +const coworkPrefix = "cowork://" +if !seen[coworkPrefix] { +seen[coworkPrefix] = true +prefixes = append(prefixes, coworkPrefix) +} +continue +} +if !seen[t.Prefix()] { +seen[t.Prefix()] = true +prefixes = append(prefixes, t.Prefix()) +} +for _, m := range t.Primitives { +if m.DeployRoot != "" { +dp := m.DeployRoot + "/" +if !seen[dp] { +seen[dp] = true +prefixes = append(prefixes, dp) +} +} +} +} +return prefixes +} + +// ActiveTargets returns the target profiles that should be deployed into projectRoot. +// Resolution order: explicit target -> directory detection -> fallback (copilot). +func ActiveTargets(projectRoot string, explicitTargets []string) []*TargetProfile { +if len(explicitTargets) > 0 { +profiles := make([]*TargetProfile, 0) +seen := make(map[string]bool) +for _, t := range explicitTargets { +canonical := t +if t == "vscode" || t == "agents" { +canonical = "copilot" +} +if canonical == "all" { +var all []*TargetProfile +for _, p := range KnownTargets { +if p.Name != "agent-skills" && p.Name != "copilot-cowork" { +all = append(all, p) +} +} +return all +} +if p, ok := KnownTargets[canonical]; ok && !seen[canonical] { +seen[canonical] = true +profiles = append(profiles, p) +} +} +return profiles +} + +// Auto-detect by directory presence +var detected []*TargetProfile +for _, p := range KnownTargets { +if p.DetectByDir { +if fi, err := os.Stat(filepath.Join(projectRoot, p.RootDir)); err == nil && fi.IsDir() { +detected = append(detected, p) +} +} +} +if len(detected) > 0 { +return detected +} +return []*TargetProfile{KnownTargets["copilot"]} +} + +// ResolveTargets returns scope-resolved target profiles. 
func ResolveTargets(projectRoot string, userScope bool, explicitTargets []string) []*TargetProfile {
	var raw []*TargetProfile
	if userScope {
		raw = activeTargetsUserScope(explicitTargets)
	} else {
		raw = ActiveTargets(projectRoot, explicitTargets)
	}
	// ForScope may return nil for targets that do not support the requested
	// scope; those are silently dropped from the result.
	resolved := make([]*TargetProfile, 0, len(raw))
	for _, t := range raw {
		scoped := t.ForScope(userScope)
		if scoped != nil {
			resolved = append(resolved, scoped)
		}
	}
	return resolved
}

// activeTargetsUserScope is the user-scope counterpart of ActiveTargets:
// explicit targets are filtered to user-supported profiles, otherwise
// detection runs against directories under the user's home dir, with a
// final fallback to the copilot profile.
//
// NOTE(review): UserSupported is compared against true, false, and the
// string "partial", so it appears to be an interface-typed tri-state field
// migrated from Python -- confirm against the TargetProfile declaration.
func activeTargetsUserScope(explicitTargets []string) []*TargetProfile {
	home, _ := os.UserHomeDir()

	if len(explicitTargets) > 0 {
		profiles := make([]*TargetProfile, 0)
		seen := make(map[string]bool)
		for _, t := range explicitTargets {
			canonical := t
			if t == "vscode" || t == "agents" {
				canonical = "copilot"
			}
			if canonical == "all" {
				// "all" at user scope: any target with a non-nil,
				// non-false UserSupported value, minus copilot-cowork.
				var all []*TargetProfile
				for _, p := range KnownTargets {
					if p.UserSupported != nil && p.UserSupported != false && p.Name != "copilot-cowork" {
						all = append(all, p)
					}
				}
				return all
			}
			if p, ok := KnownTargets[canonical]; ok {
				us := p.UserSupported
				if (us == true || us == "partial") && !seen[canonical] {
					seen[canonical] = true
					profiles = append(profiles, p)
				}
			}
		}
		return profiles
	}

	// Auto-detect: user-supported targets whose effective root exists
	// under the home directory.
	var detected []*TargetProfile
	for _, p := range KnownTargets {
		us := p.UserSupported
		if (us == true || us == "partial") && p.DetectByDir {
			root := p.EffectiveRoot(true)
			if fi, err := os.Stat(filepath.Join(home, root)); err == nil && fi.IsDir() {
				detected = append(detected, p)
			}
		}
	}
	if len(detected) > 0 {
		return detected
	}
	return []*TargetProfile{KnownTargets["copilot"]}
}
diff --git a/internal/integration/targets/targets_test.go b/internal/integration/targets/targets_test.go
new file mode 100644
index 0000000..4a89d43
--- /dev/null
+++ b/internal/integration/targets/targets_test.go
@@ -0,0 +1,109 @@
package targets

import (
	"testing"
)

func TestKnownTargetsRegistered(t *testing.T) {
	expected := []string{"copilot", "claude", "cursor", "opencode", "gemini", "codex",
		"windsurf", "agent-skills", "copilot-cowork"}
	for _, name := range expected {
		if _, ok := KnownTargets[name]; !ok {
			t.Errorf("missing target %q", name)
		}
	}
}

func TestTargetPrefix(t *testing.T) {
	tgt := KnownTargets["copilot"]
	if got := tgt.Prefix(); got != ".github/" {
		t.Errorf("expected .github/, got %s", got)
	}
}

func TestTargetSupports(t *testing.T) {
	tgt := KnownTargets["copilot"]
	if !tgt.Supports("skills") {
		t.Error("copilot should support skills")
	}
	if tgt.Supports("nonexistent") {
		t.Error("copilot should not support nonexistent")
	}
}

func TestForScopeProjectScope(t *testing.T) {
	tgt := KnownTargets["copilot"]
	scoped := tgt.ForScope(false)
	if scoped == nil {
		t.Fatal("ForScope(false) returned nil")
	}
	if scoped.RootDir != ".github" {
		t.Errorf("expected .github, got %s", scoped.RootDir)
	}
}

func TestForScopeUserScopeCopilot(t *testing.T) {
	tgt := KnownTargets["copilot"]
	scoped := tgt.ForScope(true)
	if scoped == nil {
		t.Fatal("ForScope(true) returned nil")
	}
	if scoped.RootDir != ".copilot" {
		t.Errorf("expected .copilot, got %s", scoped.RootDir)
	}
	// prompts and instructions should be filtered out
	if scoped.Supports("prompts") {
		t.Error("prompts should be filtered at user scope")
	}
	if scoped.Supports("instructions") {
		t.Error("instructions should be filtered at user scope")
	}
	if !scoped.Supports("skills") {
		t.Error("skills should remain at user scope")
	}
}

func TestForScopeNoUserSupport(t *testing.T) {
	tgt := &TargetProfile{
		Name:          "fake",
		RootDir:       ".fake",
		UserSupported: false,
		Primitives:    map[string]PrimitiveMapping{},
	}
	if scoped := tgt.ForScope(true); scoped != nil {
		t.Error("expected nil for unsupported user scope")
	}
}

func TestApplyLegacySkillPaths(t *testing.T) {
	profiles := []*TargetProfile{KnownTargets["copilot"], KnownTargets["claude"]}
	result := ApplyLegacySkillPaths(profiles)
	for _, p := range result {
		if pm, ok := p.Primitives["skills"]; ok {
			if pm.DeployRoot != "" {
				t.Errorf("target %s: expected empty deploy_root after legacy, got %s", p.Name, pm.DeployRoot)
			}
		}
	}
}

func TestGetIntegrationPrefixes(t *testing.T) {
	prefixes := GetIntegrationPrefixes(nil)
	found := false
	for _, p := range prefixes {
		if p == ".github/" {
			found = true
			break
		}
	}
	if !found {
		t.Error("expected .github/ in prefixes")
	}
}

func TestActiveTargetsFallback(t *testing.T) {
	// Non-existent project root -> should fallback to copilot
	targets := ActiveTargets("/nonexistent/path", nil)
	if len(targets) != 1 || targets[0].Name != "copilot" {
		t.Errorf("expected fallback to copilot, got %v", targets)
	}
}
diff --git a/internal/marketplace/builder/builder.go b/internal/marketplace/builder/builder.go
new file mode 100644
index 0000000..591da71
--- /dev/null
+++ b/internal/marketplace/builder/builder.go
@@ -0,0 +1,860 @@
// Package builder provides the MarketplaceBuilder: load, resolve, compose, and write marketplace.json.
// Migrated from src/apm_cli/marketplace/builder.py.
package builder

import (
	"encoding/json"
	"errors"
	"fmt"
	"path/filepath"
	"regexp"
	"strings"
	"sync"

	"github.com/githubnext/apm/internal/marketplace/mkio"
	"github.com/githubnext/apm/internal/marketplace/refresolver"
	"github.com/githubnext/apm/internal/marketplace/semver"
	"github.com/githubnext/apm/internal/marketplace/tagpattern"
	"github.com/githubnext/apm/internal/marketplace/ymlschema"

	// NOTE(review): "os" and "path" belong in the stdlib group above
	// (goimports would merge them).
	"os"
	"path"
)

// BuildDiagnostic is a structured diagnostic emitted during marketplace.json composition.
type BuildDiagnostic struct {
	Level   string // "warning" | "verbose"
	Message string
}

// ResolvedPackage is a package entry after ref resolution.
type ResolvedPackage struct {
	Name             string
	SourceRepo       string // "owner/repo" only
	Subdir           string // APM-only (for git-subdir source object)
	Ref              string // resolved tag name, e.g. "v1.2.0"
	SHA              string // 40-char git SHA
	RequestedVersion string // original APM-only range (for diagnostics)
	Tags             []string
	IsPrerelease     bool // True if the resolved ref was a prerelease semver
}

// ResolveResult is the result of resolving package refs in a marketplace build.
type ResolveResult struct {
	Entries []ResolvedPackage
	Errors  [][2]string // (package name, error message) pairs
}

// OK returns true when every package resolved without error.
func (r ResolveResult) OK() bool { return len(r.Errors) == 0 }

// BuildReport summarizes a build run.
type BuildReport struct {
	Resolved       []ResolvedPackage
	Errors         [][2]string
	Warnings       []string
	Diagnostics    []BuildDiagnostic
	UnchangedCount int
	AddedCount     int
	UpdatedCount   int
	RemovedCount   int
	OutputPath     string
	DryRun         bool
}

// BuildOptions holds configuration knobs for MarketplaceBuilder.
type BuildOptions struct {
	Concurrency       int
	TimeoutSeconds    float64
	IncludePrerelease bool
	AllowHead         bool
	ContinueOnError   bool
	Offline           bool
	OutputOverride    string
	DryRun            bool
}

// DefaultBuildOptions returns sensible defaults.
func DefaultBuildOptions() BuildOptions {
	return BuildOptions{
		Concurrency:    8,
		TimeoutSeconds: 10.0,
	}
}

// sha40RE matches a 40-char hex SHA.
var sha40RE = regexp.MustCompile(`^[0-9a-f]{40}$`)

// versionRangeChars are chars that indicate a range constraint rather than a display version.
// versionRangeChars are chars that indicate a range constraint rather than a display version.
var versionRangeChars = []byte{'^', '~', '>', '<', '='}

// isDisplayVersion reports whether version looks like a concrete, displayable
// version (e.g. "1.2.0") rather than a semver range constraint.
//
// Rejected: empty or whitespace-only strings, leading range operators
// (^ ~ > < =), embedded spaces or '*', and a trailing ".x"/".X" wildcard
// segment.
func isDisplayVersion(version string) bool {
	v := strings.TrimSpace(version)
	// Checking the trimmed value also covers whitespace-only input, which
	// previously slipped past the empty check and panicked on v[0].
	if v == "" {
		return false
	}
	for _, c := range versionRangeChars {
		if v[0] == c {
			return false
		}
	}
	if strings.ContainsAny(v, " *") {
		return false
	}
	// strings.Split never returns an empty slice, so only the trailing
	// segment needs the wildcard check.
	parts := strings.Split(v, ".")
	if strings.ToLower(parts[len(parts)-1]) == "x" {
		return false
	}
	return true
}

// normalizeSourcePath strips leading "./" segments and surrounding slashes.
// It deliberately avoids strings.TrimLeft(p, "./"), whose cutset semantics
// mangled paths such as "./.claude/x" (-> "claude/x") and "../x" (-> "x").
func normalizeSourcePath(p string) string {
	for strings.HasPrefix(p, "./") {
		p = p[2:]
	}
	p = strings.TrimLeft(p, "/")
	p = strings.TrimRight(p, "/")
	if p == "." {
		p = ""
	}
	return p
}

// subtractPluginRoot removes pluginRoot prefix from a local source path.
// On success it returns "./<rel>"; it errors when src is not inside
// pluginRoot, when the subtraction would yield an empty path, or when the
// result contains ".." traversal segments.
func subtractPluginRoot(src, pluginRoot string) (string, error) {
	normSrc := normalizeSourcePath(src)
	normRoot := normalizeSourcePath(pluginRoot)

	var rel string
	switch {
	case normRoot == "":
		// Empty (or ".") pluginRoot: nothing to strip.
		rel = normSrc
	case normSrc == normRoot:
		rel = ""
	case strings.HasPrefix(normSrc, normRoot+"/"):
		// Match only on a path-component boundary: "plugins2/x" must NOT
		// match root "plugins" (the previous bare HasPrefix accepted it).
		rel = normSrc[len(normRoot)+1:]
	default:
		return "", fmt.Errorf("source '%s' does not start with pluginRoot '%s'", src, pluginRoot)
	}

	rel = strings.TrimLeft(rel, "/")
	if rel == "" || rel == "." {
		return "", fmt.Errorf("subtracting pluginRoot '%s' from source '%s' yields empty path", pluginRoot, src)
	}
	for _, seg := range strings.Split(rel, "/") {
		if seg == ".." {
			return "", fmt.Errorf("pluginRoot subtraction produced path with traversal: '%s'", rel)
		}
	}
	return "./" + rel, nil
}

// BuildError is raised on build failures.
type BuildError struct {
	Msg     string // human-readable failure description
	Package string // package name the failure is attributed to, if any
}

// Error implements the error interface.
func (e *BuildError) Error() string { return e.Msg }

// HeadNotAllowedError is raised when a branch ref is resolved without allow_head.
+type HeadNotAllowedError struct { + Package string + Ref string +} + +func (e *HeadNotAllowedError) Error() string { + return fmt.Sprintf("package '%s': ref '%s' is a branch head; use allow_head to allow it", e.Package, e.Ref) +} + +// RefNotFoundError is raised when a ref cannot be found on the remote. +type RefNotFoundError struct { + Package string + Ref string + OwnerRepo string +} + +func (e *RefNotFoundError) Error() string { + return fmt.Sprintf("package '%s': ref '%s' not found on remote '%s'", e.Package, e.Ref, e.OwnerRepo) +} + +// NoMatchingVersionError is raised when no tag satisfies the semver range. +type NoMatchingVersionError struct { + Package string + VersionRange string + Detail string +} + +func (e *NoMatchingVersionError) Error() string { + return fmt.Sprintf("package '%s': no tag satisfies '%s' (%s)", e.Package, e.VersionRange, e.Detail) +} + +// MarketplaceBuilder loads, resolves, composes, and writes marketplace.json. +type MarketplaceBuilder struct { + ymlPath string + projectRoot string + options BuildOptions + yml *ymlschema.MarketplaceConfig + resolver *refresolver.RefResolver + githubToken string + host string + authResolved bool + + composeWarnings []string + composeDiagnostics []BuildDiagnostic +} + +// New constructs a MarketplaceBuilder for the given marketplace.yml path. +func New(marketplaceYMLPath string, options BuildOptions) *MarketplaceBuilder { + return &MarketplaceBuilder{ + ymlPath: marketplaceYMLPath, + projectRoot: filepath.Dir(marketplaceYMLPath), + options: options, + host: "github.com", + } +} + +// FromConfig constructs a builder from an already-loaded MarketplaceConfig. 
func FromConfig(config *ymlschema.MarketplaceConfig, projectRoot string, options BuildOptions) *MarketplaceBuilder {
	b := &MarketplaceBuilder{
		ymlPath:     filepath.Join(projectRoot, "apm.yml"),
		projectRoot: projectRoot,
		options:     options,
		yml:         config,
		host:        "github.com",
	}
	return b
}

// loadYML lazily loads and caches the marketplace config from ymlPath.
// Any basename other than "apm.yml" is treated as legacy format.
func (b *MarketplaceBuilder) loadYML() (*ymlschema.MarketplaceConfig, error) {
	if b.yml != nil {
		return b.yml, nil
	}
	isLegacy := path.Base(b.ymlPath) != "apm.yml"
	cfg, err := ymlschema.LoadFromFile(b.ymlPath, isLegacy)
	if err != nil {
		return nil, err
	}
	b.yml = cfg
	return b.yml, nil
}

// ensureAuth resolves the GitHub token once per builder. Offline mode skips
// token resolution entirely.
func (b *MarketplaceBuilder) ensureAuth() {
	if b.authResolved {
		return
	}
	if b.options.Offline {
		b.authResolved = true
		return
	}
	// Resolve GitHub token from env
	// (first non-empty wins, in this precedence order).
	for _, envVar := range []string{"GITHUB_APM_PAT", "GITHUB_TOKEN", "GH_TOKEN"} {
		if t := os.Getenv(envVar); t != "" {
			b.githubToken = t
			break
		}
	}
	b.authResolved = true
}

// getResolver lazily constructs and caches the ref resolver, resolving auth
// first so the token is available to it.
func (b *MarketplaceBuilder) getResolver() *refresolver.RefResolver {
	if b.resolver == nil {
		b.ensureAuth()
		b.resolver = refresolver.New(b.options.TimeoutSeconds, b.options.Offline, b.host, b.githubToken)
	}
	return b.resolver
}

// outputPath returns the destination for marketplace.json: the explicit
// override when set, otherwise yml.Output joined under the project root
// (rejected if it escapes the root).
func (b *MarketplaceBuilder) outputPath(yml *ymlschema.MarketplaceConfig) (string, error) {
	if b.options.OutputOverride != "" {
		return b.options.OutputOverride, nil
	}
	outputPath := filepath.Join(b.projectRoot, yml.Output)
	// containment guard
	// NOTE(review): HasPrefix(rel, "..") also rejects sibling names like
	// "..foo"; conservative but confirm that is acceptable.
	rel, err := filepath.Rel(b.projectRoot, outputPath)
	if err != nil || strings.HasPrefix(rel, "..") {
		return "", &BuildError{Msg: fmt.Sprintf("output path '%s' escapes project root", outputPath)}
	}
	return outputPath, nil
}

// stripRefPrefix removes refs/tags/ or refs/heads/ prefix.
func stripRefPrefix(refname string) string {
	if strings.HasPrefix(refname, "refs/tags/") {
		return refname[len("refs/tags/"):]
	}
	if strings.HasPrefix(refname, "refs/heads/") {
		return refname[len("refs/heads/"):]
	}
	return refname
}

// resolveExplicitRef resolves an entry with an explicit ref: field.
// Resolution cascade: literal 40-hex SHA (no network) -> tag short name ->
// full refname -> branch short name (gated by AllowHead) -> HEAD guard ->
// RefNotFoundError.
func (b *MarketplaceBuilder) resolveExplicitRef(entry ymlschema.PackageEntry, resolver *refresolver.RefResolver) (ResolvedPackage, error) {
	refText := entry.Ref
	ownerRepo := entry.Source

	if sha40RE.MatchString(refText) {
		// Literal SHA: used as both Ref and SHA without a remote lookup.
		sv, _ := semver.Parse(strings.TrimLeft(refText, "vV"))
		isPrerelease := sv.Prerelease != ""
		return ResolvedPackage{
			Name:             entry.Name,
			SourceRepo:       ownerRepo,
			Subdir:           entry.Subdir,
			Ref:              refText,
			SHA:              refText,
			RequestedVersion: entry.Version,
			Tags:             entry.Tags,
			IsPrerelease:     isPrerelease,
		}, nil
	}

	refs, err := resolver.ListRemoteRefs(ownerRepo)
	if err != nil {
		return ResolvedPackage{}, &BuildError{Msg: err.Error(), Package: entry.Name}
	}

	// Try as tag first
	for _, rr := range refs {
		if !strings.HasPrefix(rr.Name, "refs/tags/") {
			continue
		}
		tagName := stripRefPrefix(rr.Name)
		if tagName == refText {
			sv, _ := semver.Parse(strings.TrimLeft(tagName, "vV"))
			return ResolvedPackage{
				Name:             entry.Name,
				SourceRepo:       ownerRepo,
				Subdir:           entry.Subdir,
				Ref:              tagName,
				SHA:              rr.SHA,
				RequestedVersion: entry.Version,
				Tags:             entry.Tags,
				IsPrerelease:     sv.Prerelease != "",
			}, nil
		}
	}

	// Try as full refname
	for _, rr := range refs {
		if rr.Name == refText {
			short := stripRefPrefix(rr.Name)
			isBranch := strings.HasPrefix(rr.Name, "refs/heads/")
			if isBranch && !b.options.AllowHead {
				return ResolvedPackage{}, &HeadNotAllowedError{Package: entry.Name, Ref: short}
			}
			sv, _ := semver.Parse(strings.TrimLeft(short, "vV"))
			return ResolvedPackage{
				Name:             entry.Name,
				SourceRepo:       ownerRepo,
				Subdir:           entry.Subdir,
				Ref:              short,
				SHA:              rr.SHA,
				RequestedVersion: entry.Version,
				Tags:             entry.Tags,
				IsPrerelease:     sv.Prerelease != "",
			}, nil
		}
	}

	// Try as branch name
	for _, rr := range refs {
		if rr.Name == "refs/heads/"+refText {
			if !b.options.AllowHead {
				return ResolvedPackage{}, &HeadNotAllowedError{Package: entry.Name, Ref: refText}
			}
			return ResolvedPackage{
				Name:             entry.Name,
				SourceRepo:       ownerRepo,
				Subdir:           entry.Subdir,
				Ref:              refText,
				SHA:              rr.SHA,
				RequestedVersion: entry.Version,
				Tags:             entry.Tags,
				IsPrerelease:     false,
			}, nil
		}
	}

	// A bare "HEAD" request without AllowHead gets the head-specific error
	// rather than a generic not-found.
	if strings.ToUpper(refText) == "HEAD" && !b.options.AllowHead {
		return ResolvedPackage{}, &HeadNotAllowedError{Package: entry.Name, Ref: "HEAD"}
	}
	return ResolvedPackage{}, &RefNotFoundError{Package: entry.Name, Ref: refText, OwnerRepo: ownerRepo}
}

// resolveVersionRange resolves an entry using its version: semver range.
// Tag pattern precedence: entry override -> build-level default -> "v{version}".
func (b *MarketplaceBuilder) resolveVersionRange(entry ymlschema.PackageEntry, resolver *refresolver.RefResolver, yml *ymlschema.MarketplaceConfig) (ResolvedPackage, error) {
	versionRange := entry.Version
	ownerRepo := entry.Source

	pattern := entry.TagPattern
	if pattern == "" {
		pattern = yml.Build.TagPattern
	}
	if pattern == "" {
		pattern = "v{version}"
	}

	tagRx, err := tagpattern.BuildTagRegex(pattern)
	if err != nil {
		return ResolvedPackage{}, &BuildError{Msg: fmt.Sprintf("invalid tag pattern '%s': %v", pattern, err), Package: entry.Name}
	}

	refs, err := resolver.ListRemoteRefs(ownerRepo)
	if err != nil {
		return ResolvedPackage{}, &BuildError{Msg: err.Error(), Package: entry.Name}
	}

	type candidate struct {
		sv      semver.SemVer
		tagName string
		sha     string
	}
	var candidates []candidate

	// Collect every tag that matches the pattern, parses as semver,
	// passes the prerelease filter, and satisfies the range.
	for _, rr := range refs {
		if !strings.HasPrefix(rr.Name, "refs/tags/") {
			continue
		}
		tagName := rr.Name[len("refs/tags/"):]
		versionStr, ok := tagpattern.ExtractVersion(tagRx, tagName)
		if !ok {
			continue
		}
		sv, err := semver.Parse(versionStr)
		if err != nil {
			continue
		}
		includePrerelease := entry.IncludePrerelease || b.options.IncludePrerelease
		if sv.Prerelease != "" && !includePrerelease {
			continue
		}
		if semver.SatisfiesRange(sv, versionRange) {
			candidates = append(candidates, candidate{sv: sv, tagName: tagName, sha: rr.SHA})
		}
	}

	if len(candidates) == 0 {
		return ResolvedPackage{}, &NoMatchingVersionError{
			Package:      entry.Name,
			VersionRange: versionRange,
			Detail:       fmt.Sprintf("pattern='%s', remote='%s'", pattern, ownerRepo),
		}
	}

	// Pick highest
	best := candidates[0]
	for _, c := range candidates[1:] {
		if c.sv.Compare(best.sv) > 0 {
			best = c
		}
	}

	return ResolvedPackage{
		Name:             entry.Name,
		SourceRepo:       ownerRepo,
		Subdir:           entry.Subdir,
		Ref:              best.tagName,
		SHA:              best.sha,
		RequestedVersion: versionRange,
		Tags:             entry.Tags,
		IsPrerelease:     best.sv.Prerelease != "",
	}, nil
}

// resolveEntry resolves a single package entry to a concrete tag + SHA.
// Local entries resolve without any network access; remote entries prefer
// an explicit ref: over a version: range.
func (b *MarketplaceBuilder) resolveEntry(entry ymlschema.PackageEntry, yml *ymlschema.MarketplaceConfig) (ResolvedPackage, error) {
	if entry.IsLocal {
		return ResolvedPackage{
			Name:             entry.Name,
			SourceRepo:       "",
			Subdir:           entry.Source,
			Ref:              "",
			SHA:              "",
			RequestedVersion: entry.Version,
			Tags:             entry.Tags,
			IsPrerelease:     false,
		}, nil
	}
	resolver := b.getResolver()
	if entry.Ref != "" {
		return b.resolveExplicitRef(entry, resolver)
	}
	return b.resolveVersionRange(entry, resolver, yml)
}

// Resolve resolves every entry concurrently.
func (b *MarketplaceBuilder) Resolve() (ResolveResult, error) {
	yml, err := b.loadYML()
	if err != nil {
		return ResolveResult{}, err
	}
	entries := yml.Packages
	if len(entries) == 0 {
		return ResolveResult{}, nil
	}

	// Eagerly create the resolver before spawning goroutines
	// (avoids racy lazy initialization inside the workers).
	b.getResolver()

	type indexedResult struct {
		idx     int
		pkg     ResolvedPackage
		errPair [2]string
		hasErr  bool
	}

	// Concurrency is bounded by a semaphore channel; a non-positive
	// setting falls back to 8 slots.
	sem := make(chan struct{}, b.options.Concurrency)
	if b.options.Concurrency <= 0 {
		sem = make(chan struct{}, 8)
	}

	resultCh := make(chan indexedResult, len(entries))
	var wg sync.WaitGroup

	for i, entry := range entries {
		wg.Add(1)
		go func(idx int, e ymlschema.PackageEntry) {
			defer wg.Done()
			sem <- struct{}{}
			defer func() { <-sem }()

			pkg, resolveErr := b.resolveEntry(e, yml)
			if resolveErr != nil {
				// NOTE(review): both branches below send the identical
				// result, so the errors.As classification is currently a
				// dead distinction.
				var buildErr *BuildError
				var headErr *HeadNotAllowedError
				var refErr *RefNotFoundError
				var noMatchErr *NoMatchingVersionError
				if errors.As(resolveErr, &buildErr) || errors.As(resolveErr, &headErr) ||
					errors.As(resolveErr, &refErr) || errors.As(resolveErr, &noMatchErr) {
					resultCh <- indexedResult{idx: idx, errPair: [2]string{e.Name, resolveErr.Error()}, hasErr: true}
					return
				}
				resultCh <- indexedResult{idx: idx, errPair: [2]string{e.Name, resolveErr.Error()}, hasErr: true}
				return
			}
			resultCh <- indexedResult{idx: idx, pkg: pkg}
		}(i, entry)
	}

	go func() {
		wg.Wait()
		close(resultCh)
	}()

	results := make(map[int]ResolvedPackage)
	var errs [][2]string
	var firstErr error

	// Drain the whole channel even when failing fast, so every goroutine
	// can finish sending before we return.
	for r := range resultCh {
		if r.hasErr {
			errs = append(errs, r.errPair)
			if !b.options.ContinueOnError && firstErr == nil {
				firstErr = fmt.Errorf("error resolving '%s': %s", r.errPair[0], r.errPair[1])
			}
		} else {
			results[r.idx] = r.pkg
		}
	}

	if firstErr != nil {
		return ResolveResult{}, firstErr
	}

	// Re-establish the original entry order; errored entries are simply
	// absent from Entries (and listed in Errors).
	ordered := make([]ResolvedPackage, 0, len(results))
	for idx := range entries {
		if pkg, ok := results[idx]; ok {
			ordered = append(ordered, pkg)
		}
	}

	return ResolveResult{Entries: ordered, Errors: errs}, nil
}

// ComposeMarketplaceJSON produces an Anthropic-compliant marketplace.json dict.
func (b *MarketplaceBuilder) ComposeMarketplaceJSON(resolved []ResolvedPackage) (map[string]interface{}, error) {
	yml, err := b.loadYML()
	if err != nil {
		return nil, err
	}

	entryByName := make(map[string]*ymlschema.PackageEntry)
	for i := range yml.Packages {
		entryByName[yml.Packages[i].Name] = &yml.Packages[i]
	}

	doc := make(map[string]interface{})
	doc["name"] = yml.Name
	if yml.DescriptionOverridden && yml.Description != "" {
		doc["description"] = yml.Description
	}
	if yml.VersionOverridden && yml.Version != "" {
		doc["version"] = yml.Version
	}

	ownerDict := make(map[string]interface{})
	ownerDict["name"] = yml.Owner.Name
	if yml.Owner.Email != "" {
		ownerDict["email"] = yml.Owner.Email
	}
	if yml.Owner.URL != "" {
		ownerDict["url"] = yml.Owner.URL
	}
	doc["owner"] = ownerDict

	if len(yml.Metadata) > 0 {
		doc["metadata"] = yml.Metadata
	}

	var plugins []interface{}
	var diagnostics []BuildDiagnostic
	pluginRoot := ""
	if m, ok := yml.Metadata["pluginRoot"]; ok {
		if s, ok := m.(string); ok {
			pluginRoot = s
		}
	}
	stripCount := 0
	overrideCount := 0

	for _, pkg := range resolved {
		plugin := make(map[string]interface{})
		plugin["name"] = pkg.Name

		entry := entryByName[pkg.Name]
		isLocal := entry != nil && entry.IsLocal

		if isLocal {
			if entry.Description != "" {
				plugin["description"] = entry.Description
			}
			if entry.Version != "" {
				plugin["version"] = entry.Version
			}
		} else {
			if entry != nil && entry.Description != "" {
				plugin["description"] = entry.Description
			}
			if entry != nil && isDisplayVersion(entry.Version) {
				plugin["version"] = entry.Version
			} else if pkg.Ref != "" && isDisplayVersion(pkg.Ref) {
				// Fallback: use resolved ref as display version if applicable
				// NOTE(review): this branch body is empty, so no fallback
				// version is actually emitted -- confirm intent.
			}
		}

		if entry != nil && len(entry.Author) > 0 {
			plugin["author"] = entry.Author
		}
		if entry != nil && entry.License != "" {
			plugin["license"] = entry.License
		}
		if entry != nil && entry.Repository != "" {
			plugin["repository"] = entry.Repository
		}
		if len(pkg.Tags) > 0 {
			plugin["tags"] = pkg.Tags
		}
		if isLocal && entry != nil && entry.Homepage != "" {
			plugin["homepage"] = entry.Homepage
		}

		// source
		if isLocal {
			sourceValue := entry.Source
			if pluginRoot != "" {
				stripped, err := subtractPluginRoot(entry.Source, pluginRoot)
				if err != nil {
					// W1: source outside pluginRoot -- emit as-is
					diagnostics = append(diagnostics, BuildDiagnostic{
						Level:   "warning",
						Message: fmt.Sprintf("[!] Package '%s': source '%s' is outside pluginRoot '%s' -- emitted as-is", pkg.Name, entry.Source, pluginRoot),
					})
				} else {
					sourceValue = stripped
					stripCount++
					diagnostics = append(diagnostics, BuildDiagnostic{
						Level:   "verbose",
						Message: fmt.Sprintf("[i] Package '%s': stripped pluginRoot -- '%s' -> '%s'", pkg.Name, entry.Source, sourceValue),
					})
				}
			}
			plugin["source"] = sourceValue
		} else {
			srcObj := make(map[string]interface{})
			if pkg.Subdir != "" {
				srcObj["source"] = "git-subdir"
				srcObj["url"] = pkg.SourceRepo
				srcObj["path"] = pkg.Subdir
			} else {
				srcObj["source"] = "github"
				srcObj["repo"] = pkg.SourceRepo
			}
			if pkg.Ref != "" {
				srcObj["ref"] = pkg.Ref
			}
			if pkg.SHA != "" {
				srcObj["sha"] = pkg.SHA
			}
			plugin["source"] = srcObj
		}

		plugins = append(plugins, plugin)
	}

	// NOTE(review): overrideCount is declared but never incremented.
	_ = overrideCount
	_ = stripCount

	// Build verbose summary
	if pluginRoot != "" && stripCount > 0 {
		diagnostics = append(diagnostics, BuildDiagnostic{
			Level:   "verbose",
			Message: fmt.Sprintf("pluginRoot: stripped from %d local source(s)", stripCount),
		})
	}

	// Duplicate name check
	var buildWarnings []string
	seenNames := make(map[string]string)
	for _, p := range plugins {
		pm := p.(map[string]interface{})
		pname := pm["name"].(string)
		srcLabel := "?"
		if src, ok := pm["source"]; ok {
			switch s := src.(type) {
			case string:
				srcLabel = s
			case map[string]interface{}:
				if v, ok := s["path"]; ok {
					srcLabel = fmt.Sprintf("%v", v)
				} else if v, ok := s["repo"]; ok {
					srcLabel = fmt.Sprintf("%v", v)
				}
			}
		}
		if prev, exists := seenNames[pname]; exists {
			buildWarnings = append(buildWarnings, fmt.Sprintf("Duplicate package name '%s': '%s' and '%s'. Consumers will see duplicate entries in browse.", pname, prev, srcLabel))
		} else {
			seenNames[pname] = srcLabel
		}
	}

	b.composeWarnings = buildWarnings
	b.composeDiagnostics = diagnostics
	doc["plugins"] = plugins
	return doc, nil
}

// pluginSHAs maps plugin name -> pinned SHA (or local source string).
type pluginSHAs map[string]string

// extractPluginSHAs pulls the per-plugin pin out of a marketplace.json
// document: string sources are used verbatim; object sources use "sha"
// (falling back to "commit").
func extractPluginSHAs(data map[string]interface{}) pluginSHAs {
	out := make(pluginSHAs)
	rawPlugins, _ := data["plugins"].([]interface{})
	for _, p := range rawPlugins {
		pm, ok := p.(map[string]interface{})
		if !ok {
			continue
		}
		name, _ := pm["name"].(string)
		sha := ""
		switch s := pm["source"].(type) {
		case string:
			sha = s
		case map[string]interface{}:
			if v, ok := s["sha"].(string); ok {
				sha = v
			} else if v, ok := s["commit"].(string); ok {
				sha = v
			}
		}
		out[name] = sha
	}
	return out
}

// computeDiff compares old and new documents by plugin name/SHA and returns
// unchanged/added/updated/removed counts. A nil oldJSON counts everything
// as added.
func computeDiff(oldJSON, newJSON map[string]interface{}) (unchanged, added, updated, removed int) {
	if oldJSON == nil {
		return 0, len(extractPluginSHAs(newJSON)), 0, 0
	}
	oldPlugins := extractPluginSHAs(oldJSON)
	newPlugins := extractPluginSHAs(newJSON)

	for name, sha := range newPlugins {
		if _, exists := oldPlugins[name]; !exists {
			added++
		} else if oldPlugins[name] == sha {
			unchanged++
		} else {
			updated++
		}
	}
	for name := range oldPlugins {
		if _, exists := newPlugins[name]; !exists {
			removed++
		}
	}
	return
}

// serializeJSON renders the document as indented JSON with a trailing newline.
func serializeJSON(data map[string]interface{}) ([]byte, error) {
	b, err := json.MarshalIndent(data, "", "  ")
	if err != nil {
		return nil, err
	}
	return append(b, '\n'), nil
}

// loadExistingJSON best-effort reads a previous marketplace.json; any read
// or parse failure yields nil (treated as "no previous output").
func loadExistingJSON(p string) map[string]interface{} {
	data, err := os.ReadFile(p)
	if err != nil {
		return nil
	}
	var doc map[string]interface{}
	if err := json.Unmarshal(data, &doc); err != nil {
		return nil
	}
	return doc
}

// Build runs the full pipeline: load -> resolve -> compose -> write.
func (b *MarketplaceBuilder) Build() (BuildReport, error) {
	result, err := b.Resolve()
	if err != nil {
		return BuildReport{}, err
	}

	newJSON, err := b.ComposeMarketplaceJSON(result.Entries)
	if err != nil {
		return BuildReport{}, err
	}

	buildWarnings := b.composeWarnings
	buildDiagnostics := b.composeDiagnostics

	yml, err := b.loadYML()
	if err != nil {
		return BuildReport{}, err
	}
	outPath, err := b.outputPath(yml)
	if err != nil {
		return BuildReport{}, err
	}

	oldJSON := loadExistingJSON(outPath)
	unchanged, added, updated, removed := computeDiff(oldJSON, newJSON)

	if !b.options.DryRun {
		if err := os.MkdirAll(filepath.Dir(outPath), 0o755); err != nil {
			return BuildReport{}, err
		}
		content, err := serializeJSON(newJSON)
		if err != nil {
			return BuildReport{}, err
		}
		if err := mkio.AtomicWrite(outPath, content); err != nil {
			return BuildReport{}, err
		}
	}

	// NOTE(review): the resolver is closed only on this success path (early
	// error returns skip it), and b.resolver is not reset to nil afterwards
	// -- confirm RefResolver's reuse-after-Close semantics.
	if b.resolver != nil {
		b.resolver.Close()
	}

	return BuildReport{
		Resolved:       result.Entries,
		Errors:         result.Errors,
		Warnings:       buildWarnings,
		Diagnostics:    buildDiagnostics,
		UnchangedCount: unchanged,
		AddedCount:     added,
		UpdatedCount:   updated,
		RemovedCount:   removed,
		OutputPath:     outPath,
		DryRun:         b.options.DryRun,
	}, nil
}
diff --git a/internal/marketplace/gitstderr/gitstderr.go b/internal/marketplace/gitstderr/gitstderr.go
new file mode 100644
index 0000000..1f58ba7
--- /dev/null
+++ b/internal/marketplace/gitstderr/gitstderr.go
@@ -0,0 +1,186 @@
// Package gitstderr translates git stderr into actionable, ASCII-only error messages.
+// +// Callers pass captured stderr text, an optional exit code, and context +// (operation name, remote). This package classifies the failure into one +// of four known modes and returns a structured TranslatedGitError with a +// one-line summary, an actionable hint, and the (truncated) raw stderr. +// +// No subprocess, network, filesystem, or logging side effects -- this is +// a pure function package. +package gitstderr + +import ( +"fmt" +"strings" +) + +const ( +rawMaxLen = 500 +summaryMaxLen = 80 +) + +// GitErrorKind enumerates known git failure modes. +type GitErrorKind int + +const ( +// KindAuth indicates an authentication failure. +KindAuth GitErrorKind = iota +// KindNotFound indicates a ref or repository not found failure. +KindNotFound +// KindTimeout indicates a network timeout or connectivity failure. +KindTimeout +// KindUnknown indicates an unclassified failure. +KindUnknown +) + +// String returns the value string for GitErrorKind. +func (k GitErrorKind) String() string { +switch k { +case KindAuth: +return "auth" +case KindNotFound: +return "not_found" +case KindTimeout: +return "timeout" +default: +return "unknown" +} +} + +// TranslatedGitError is the structured result of translating git stderr. 
+type TranslatedGitError struct { +Kind GitErrorKind +Summary string +Hint string +Raw string +} + +var authPatterns = []string{ +"authentication failed", +"invalid credentials", +"could not read password", +"permission denied (publickey)", +"403 forbidden", +"401 unauthorized", +"fatal: authentication", +"remote: write access", +"please make sure you have the correct access rights", +"the requested url returned error: 401", +"the requested url returned error: 403", +} + +var notFoundPatterns = []string{ +"repository not found", +"does not appear to be a git repository", +"not a valid ref", +"couldn't find remote ref", +"could not resolve", +"the requested url returned error: 404", +"no such ref", +"unknown ref", +} + +var timeoutPatterns = []string{ +"operation timed out", +"connection timed out", +"could not resolve host", +"connection refused", +"network is unreachable", +"temporary failure in name resolution", +"ssl_read: connection reset", +"early eof", +"rpc failed", +} + +func truncateRaw(stderr string) string { +if len(stderr) <= rawMaxLen { +return stderr +} +return stderr[:rawMaxLen] + "... (truncated)" +} + +func classify(stderrLower string) GitErrorKind { +for _, p := range authPatterns { +if strings.Contains(stderrLower, p) { +return KindAuth +} +} +for _, p := range notFoundPatterns { +if strings.Contains(stderrLower, p) { +// "could not resolve host" is a DNS/network issue, not not-found. 
+if p == "could not resolve" && strings.Contains(stderrLower, "could not resolve host") { +continue +} +return KindNotFound +} +} +for _, p := range timeoutPatterns { +if strings.Contains(stderrLower, p) { +return KindTimeout +} +} +return KindUnknown +} + +func buildSummary(kind GitErrorKind, operation string, exitCode *int) string { +var text string +switch kind { +case KindAuth: +text = fmt.Sprintf("Git authentication failed during %s.", operation) +case KindNotFound: +text = fmt.Sprintf("Git ref or repository not found during %s.", operation) +case KindTimeout: +text = fmt.Sprintf("Git network timeout during %s.", operation) +default: +if exitCode != nil { +text = fmt.Sprintf("Git failed during %s (exit %d).", operation, *exitCode) +} else { +text = fmt.Sprintf("Git failed during %s.", operation) +} +} +if len(text) > summaryMaxLen { +text = text[:summaryMaxLen-3] + "..." +} +return text +} + +func buildHint(kind GitErrorKind, operation string, remote string) string { +switch kind { +case KindAuth: +return "Check your GITHUB_TOKEN / gh auth / SSH key. Run 'apm marketplace doctor' to diagnose." +case KindNotFound: +remoteLabel := "the remote" +if remote != "" { +remoteLabel = "'" + remote + "'" +} +return fmt.Sprintf("Verify the remote %s exists and the ref is spelled correctly.", remoteLabel) +case KindTimeout: +return "Network issue contacting the remote. Retry or check your connection." +default: +return fmt.Sprintf("Git failed during %s. See raw stderr above.", operation) +} +} + +// Options configures a Translate call. +type Options struct { +// ExitCode is the optional exit code from git. Pass nil if unknown. +ExitCode *int +// Operation names the git operation (e.g. "ls-remote"). Defaults to "git operation". +Operation string +// Remote is the optional remote name or URL for the hint. +Remote string +} + +// Translate classifies git stderr text into a known failure mode and produces an actionable hint. 
+func Translate(stderr string, opts Options) TranslatedGitError { +if opts.Operation == "" { +opts.Operation = "git operation" +} +kind := classify(strings.ToLower(stderr)) +return TranslatedGitError{ +Kind: kind, +Summary: buildSummary(kind, opts.Operation, opts.ExitCode), +Hint: buildHint(kind, opts.Operation, opts.Remote), +Raw: truncateRaw(stderr), +} +} diff --git a/internal/marketplace/gitstderr/gitstderr_test.go b/internal/marketplace/gitstderr/gitstderr_test.go new file mode 100644 index 0000000..bf1d8d9 --- /dev/null +++ b/internal/marketplace/gitstderr/gitstderr_test.go @@ -0,0 +1,58 @@ +package gitstderr_test + +import ( +"testing" + +"github.com/githubnext/apm/internal/marketplace/gitstderr" +) + +func TestTranslate_Auth(t *testing.T) { +r := gitstderr.Translate("fatal: authentication failed for 'https://github.com/acme/tools'", +gitstderr.Options{Operation: "ls-remote", Remote: "acme/tools"}) +if r.Kind != gitstderr.KindAuth { +t.Fatalf("expected KindAuth, got %s", r.Kind) +} +if r.Summary == "" || r.Hint == "" { +t.Fatal("expected non-empty summary and hint") +} +} + +func TestTranslate_NotFound(t *testing.T) { +r := gitstderr.Translate("ERROR: Repository not found.", gitstderr.Options{Operation: "clone"}) +if r.Kind != gitstderr.KindNotFound { +t.Fatalf("expected KindNotFound, got %s", r.Kind) +} +} + +func TestTranslate_Timeout(t *testing.T) { +r := gitstderr.Translate("fatal: unable to connect to github.com: connection timed out", +gitstderr.Options{}) +if r.Kind != gitstderr.KindTimeout { +t.Fatalf("expected KindTimeout, got %s", r.Kind) +} +} + +func TestTranslate_Unknown(t *testing.T) { +r := gitstderr.Translate("some unexpected error", gitstderr.Options{}) +if r.Kind != gitstderr.KindUnknown { +t.Fatalf("expected KindUnknown, got %s", r.Kind) +} +} + +func TestTranslate_TruncatesRaw(t *testing.T) { +long := string(make([]byte, 600)) +for i := range long { +long = long[:i] + "a" + long[i+1:] +} +r := gitstderr.Translate(long, 
gitstderr.Options{}) +if len(r.Raw) > 520 { +t.Fatalf("raw too long: %d", len(r.Raw)) +} +} + +func TestTranslate_CouldNotResolveHost_IsTimeout(t *testing.T) { +r := gitstderr.Translate("fatal: could not resolve host: github.com", gitstderr.Options{}) +if r.Kind != gitstderr.KindTimeout { +t.Fatalf("expected KindTimeout for DNS failure, got %s", r.Kind) +} +} diff --git a/internal/marketplace/gitutils/gitutils.go b/internal/marketplace/gitutils/gitutils.go new file mode 100644 index 0000000..0300b21 --- /dev/null +++ b/internal/marketplace/gitutils/gitutils.go @@ -0,0 +1,25 @@ +// Package gitutils provides shared git-related utilities for marketplace modules. +// Migrated from src/apm_cli/marketplace/_git_utils.py +package gitutils + +import "regexp" + +// tokenRE matches auth tokens in git URLs. +// Covers: https://TOKEN@host, http://TOKEN@host, and ?token=VALUE query params. +var tokenRE = regexp.MustCompile(`https?://[^@\s]*@|([?&])token=[^\s&]*`) + +// RedactToken replaces auth tokens in text with redacted placeholders. +func RedactToken(text string) string { + return tokenRE.ReplaceAllStringFunc(text, func(m string) string { + for _, r := range m { + if r == '@' { + return "https://***@" + } + } + // query-param match: preserve the leading ? or & + if len(m) > 0 && (m[0] == '?' || m[0] == '&') { + return string(m[0]) + "token=***" + } + return m + }) +} diff --git a/internal/marketplace/mkio/mkio.go b/internal/marketplace/mkio/mkio.go new file mode 100644 index 0000000..99d7a6f --- /dev/null +++ b/internal/marketplace/mkio/mkio.go @@ -0,0 +1,51 @@ +// Package mkio provides shared I/O helpers for marketplace modules. +// Migrated from src/apm_cli/marketplace/_io.py +package mkio + +import ( + "os" + "path/filepath" +) + +// AtomicWrite writes content to path atomically via tmp + rename. +// The caller sees either the complete new content or the previous +// content -- never a partial write. 
+func AtomicWrite(path string, content []byte) error { + dir := filepath.Dir(path) + ext := filepath.Ext(path) + tmpPath := path[:len(path)-len(ext)] + ext + ".tmp" + + f, err := os.OpenFile(tmpPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o644) + if err != nil { + return err + } + + _, writeErr := f.Write(content) + syncErr := f.Sync() + closeErr := f.Close() + + if writeErr != nil { + os.Remove(tmpPath) + return writeErr + } + if syncErr != nil { + os.Remove(tmpPath) + return syncErr + } + if closeErr != nil { + os.Remove(tmpPath) + return closeErr + } + + _ = dir // dir used implicitly via tmpPath construction + if err := os.Rename(tmpPath, path); err != nil { + os.Remove(tmpPath) + return err + } + return nil +} + +// AtomicWriteString writes string content to path atomically. +func AtomicWriteString(path, content string) error { + return AtomicWrite(path, []byte(content)) +} diff --git a/internal/marketplace/refresolver/refresolver.go b/internal/marketplace/refresolver/refresolver.go new file mode 100644 index 0000000..5a12ab6 --- /dev/null +++ b/internal/marketplace/refresolver/refresolver.go @@ -0,0 +1,300 @@ +// Package refresolver provides concurrent git ls-remote with in-memory ref caching. +// Migrated from src/apm_cli/marketplace/ref_resolver.py. +package refresolver + +import ( + "bufio" + "bytes" + "context" + "fmt" + "os" + "os/exec" + "regexp" + "strings" + "sync" + "time" + + "github.com/githubnext/apm/internal/marketplace/gitstderr" + "github.com/githubnext/apm/internal/marketplace/gitutils" + "github.com/githubnext/apm/internal/utils/githubhost" +) + +// RemoteRef is a single ref returned by git ls-remote. +type RemoteRef struct { + Name string // e.g. "refs/tags/v1.2.0" or "refs/heads/main" + SHA string // 40-char hex SHA +} + +var shaRE = regexp.MustCompile(`^[0-9a-f]{40}$`) + +// DefaultTTL is the default cache TTL (5 minutes). 
+const DefaultTTL = 5 * time.Minute + +type cacheEntry struct { + refs []RemoteRef + timestamp time.Time +} + +// RefCache is an in-memory cache keyed on "owner/repo". +type RefCache struct { + mu sync.Mutex + store map[string]*cacheEntry + ttl time.Duration +} + +// NewRefCache creates a RefCache with the given TTL. +func NewRefCache(ttl time.Duration) *RefCache { + return &RefCache{store: make(map[string]*cacheEntry), ttl: ttl} +} + +// Get returns cached refs or nil on miss/expiry. +func (c *RefCache) Get(ownerRepo string) []RemoteRef { + c.mu.Lock() + defer c.mu.Unlock() + e := c.store[ownerRepo] + if e == nil { + return nil + } + if time.Since(e.timestamp) > c.ttl { + delete(c.store, ownerRepo) + return nil + } + out := make([]RemoteRef, len(e.refs)) + copy(out, e.refs) + return out +} + +// Put stores refs for ownerRepo. +func (c *RefCache) Put(ownerRepo string, refs []RemoteRef) { + c.mu.Lock() + defer c.mu.Unlock() + cp := make([]RemoteRef, len(refs)) + copy(cp, refs) + c.store[ownerRepo] = &cacheEntry{refs: cp, timestamp: time.Now()} +} + +// Clear drops all entries. +func (c *RefCache) Clear() { + c.mu.Lock() + defer c.mu.Unlock() + c.store = make(map[string]*cacheEntry) +} + +// Len returns the number of cached entries. +func (c *RefCache) Len() int { + c.mu.Lock() + defer c.mu.Unlock() + return len(c.store) +} + +// GitLsRemoteError is raised when git ls-remote fails. +type GitLsRemoteError struct { + Package string + Summary string + Hint string +} + +func (e *GitLsRemoteError) Error() string { + if e.Hint != "" { + return e.Summary + " " + e.Hint + } + return e.Summary +} + +// OfflineMissError is raised in offline mode when the cache has no entry. +type OfflineMissError struct { + Package string + Remote string +} + +func (e *OfflineMissError) Error() string { + return fmt.Sprintf("offline mode: no cached refs for remote '%s'", e.Remote) +} + +// RefResolver runs git ls-remote and caches the results. 
+type RefResolver struct { + timeoutSeconds float64 + offline bool + host string + token string + cache *RefCache + mu sync.Mutex + remoteLocks map[string]*sync.Mutex +} + +// New creates a RefResolver. +func New(timeoutSeconds float64, offline bool, host, token string) *RefResolver { + if host == "" { + host = githubhost.DefaultHost() + } + if host == "" { + host = "github.com" + } + return &RefResolver{ + timeoutSeconds: timeoutSeconds, + offline: offline, + host: host, + token: token, + cache: NewRefCache(DefaultTTL), + remoteLocks: make(map[string]*sync.Mutex), + } +} + +func (r *RefResolver) remoteLock(ownerRepo string) *sync.Mutex { + r.mu.Lock() + defer r.mu.Unlock() + if _, ok := r.remoteLocks[ownerRepo]; !ok { + r.remoteLocks[ownerRepo] = &sync.Mutex{} + } + return r.remoteLocks[ownerRepo] +} + +// buildHTTPSCloneURL constructs an authenticated HTTPS clone URL. +func buildHTTPSCloneURL(host, ownerRepo, token string) string { + base := fmt.Sprintf("https://%s/%s.git", host, ownerRepo) + if token != "" { + base = fmt.Sprintf("https://x-access-token:%s@%s/%s.git", token, host, ownerRepo) + } + return base +} + +// parseLsRemoteOutput parses git ls-remote stdout into RemoteRefs. +func parseLsRemoteOutput(output string) []RemoteRef { + var refs []RemoteRef + scanner := bufio.NewScanner(strings.NewReader(output)) + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + if line == "" { + continue + } + parts := strings.SplitN(line, "\t", 2) + if len(parts) != 2 { + continue + } + sha, refname := strings.TrimSpace(parts[0]), strings.TrimSpace(parts[1]) + if !shaRE.MatchString(sha) { + continue + } + if strings.HasSuffix(refname, "^{}") { + continue + } + refs = append(refs, RemoteRef{Name: refname, SHA: sha}) + } + return refs +} + +// ListRemoteRefs fetches all tags and heads from the configured Git host. 
+func (r *RefResolver) ListRemoteRefs(ownerRepo string) ([]RemoteRef, error) { + lock := r.remoteLock(ownerRepo) + lock.Lock() + defer lock.Unlock() + + if cached := r.cache.Get(ownerRepo); cached != nil { + return cached, nil + } + + if r.offline { + return nil, &OfflineMissError{Remote: ownerRepo} + } + + url := buildHTTPSCloneURL(r.host, ownerRepo, r.token) + timeout := time.Duration(r.timeoutSeconds * float64(time.Second)) + + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + + cmd := exec.CommandContext(ctx, "git", "ls-remote", "--tags", "--heads", url) + cmd.Env = append(os.Environ(), "GIT_TERMINAL_PROMPT=0", "GIT_ASKPASS=echo") + + var stdout, stderr bytes.Buffer + cmd.Stdout = &stdout + cmd.Stderr = &stderr + + runErr := cmd.Run() + if ctx.Err() == context.DeadlineExceeded { + return nil, &GitLsRemoteError{ + Summary: fmt.Sprintf("git ls-remote timed out after %.0fs for '%s'.", r.timeoutSeconds, ownerRepo), + Hint: "Increase --timeout or check your network connection.", + } + } + if runErr != nil { + stderrStr := gitutils.RedactToken(stderr.String()) + exitCode := -1 + if exitErr, ok := runErr.(*exec.ExitError); ok { + exitCode = exitErr.ExitCode() + } + translated := gitstderr.Translate(stderrStr, gitstderr.Options{ + ExitCode: &exitCode, + Operation: "ls-remote", + Remote: ownerRepo, + }) + return nil, &GitLsRemoteError{ + Summary: translated.Summary, + Hint: translated.Hint, + } + } + + refs := parseLsRemoteOutput(stdout.String()) + r.cache.Put(ownerRepo, refs) + return refs, nil +} + +// ResolveRefSHA resolves a single ref to its concrete SHA. 
+func (r *RefResolver) ResolveRefSHA(ownerRepo, ref string) (string, error) { + if ref == "" { + ref = "HEAD" + } + url := buildHTTPSCloneURL(r.host, ownerRepo, r.token) + timeout := time.Duration(r.timeoutSeconds * float64(time.Second)) + + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + + cmd := exec.CommandContext(ctx, "git", "ls-remote", url, ref) + cmd.Env = append(os.Environ(), "GIT_TERMINAL_PROMPT=0", "GIT_ASKPASS=echo") + + var stdout, stderr bytes.Buffer + cmd.Stdout = &stdout + cmd.Stderr = &stderr + + runErr := cmd.Run() + if ctx.Err() == context.DeadlineExceeded { + return "", &GitLsRemoteError{ + Summary: fmt.Sprintf("git ls-remote timed out after %.0fs for '%s'.", r.timeoutSeconds, ownerRepo), + Hint: "Increase --timeout or check your network connection.", + } + } + if runErr != nil { + stderrStr := gitutils.RedactToken(stderr.String()) + exitCode := -1 + if exitErr, ok := runErr.(*exec.ExitError); ok { + exitCode = exitErr.ExitCode() + } + translated := gitstderr.Translate(stderrStr, gitstderr.Options{ + ExitCode: &exitCode, + Operation: "ls-remote", + Remote: ownerRepo, + }) + return "", &GitLsRemoteError{ + Summary: translated.Summary, + Hint: translated.Hint, + } + } + + refs := parseLsRemoteOutput(stdout.String()) + if len(refs) == 0 { + return "", &GitLsRemoteError{ + Summary: fmt.Sprintf("Ref '%s' not found on remote '%s'.", ref, ownerRepo), + Hint: "Check that the ref exists and you have access to the repository.", + } + } + return refs[0].SHA, nil +} + +// Close releases resources. 
+func (r *RefResolver) Close() { + r.cache.Clear() + r.mu.Lock() + r.remoteLocks = make(map[string]*sync.Mutex) + r.mu.Unlock() +} diff --git a/internal/marketplace/registry/registry.go b/internal/marketplace/registry/registry.go new file mode 100644 index 0000000..5afd176 --- /dev/null +++ b/internal/marketplace/registry/registry.go @@ -0,0 +1,236 @@ +// Package registry manages registered marketplaces stored in ~/.apm/marketplaces.json. +package registry + +import ( +"encoding/json" +"fmt" +"os" +"path/filepath" +"sort" +"strings" +"sync" +) + +const marketplacesFilename = "marketplaces.json" + +// MarketplaceSource represents a registered marketplace. +type MarketplaceSource struct { +Name string `json:"name"` +URL string `json:"url"` +// Additional fields are preserved via Extra. +Extra map[string]interface{} `json:"-"` +} + +// FromDict creates a MarketplaceSource from a JSON-decoded map. +func FromDict(m map[string]interface{}) (MarketplaceSource, error) { +name, ok := m["name"].(string) +if !ok || name == "" { +return MarketplaceSource{}, fmt.Errorf("missing or invalid 'name' field") +} +url, _ := m["url"].(string) +extra := make(map[string]interface{}) +for k, v := range m { +if k != "name" && k != "url" { +extra[k] = v +} +} +return MarketplaceSource{Name: name, URL: url, Extra: extra}, nil +} + +// ToDict converts a MarketplaceSource to a JSON-serializable map. +func (s MarketplaceSource) ToDict() map[string]interface{} { +m := make(map[string]interface{}, len(s.Extra)+2) +for k, v := range s.Extra { +m[k] = v +} +m["name"] = s.Name +m["url"] = s.URL +return m +} + +// Registry manages the marketplace list file. +type Registry struct { +configDir func() string +mu sync.Mutex +cache []MarketplaceSource +cacheValid bool +} + +// New creates a Registry that stores files in the directory returned by configDir. 
+func New(configDir func() string) *Registry { +return &Registry{configDir: configDir} +} + +func (r *Registry) path() string { +return filepath.Join(r.configDir(), marketplacesFilename) +} + +func (r *Registry) ensureFile() (string, error) { +dir := r.configDir() +if err := os.MkdirAll(dir, 0o755); err != nil { +return "", err +} +p := r.path() +if _, err := os.Stat(p); os.IsNotExist(err) { +data, _ := json.MarshalIndent(map[string]interface{}{"marketplaces": []interface{}{}}, "", " ") +if err := os.WriteFile(p, data, 0o644); err != nil { +return "", err +} +} +return p, nil +} + +func (r *Registry) invalidate() { +r.mu.Lock() +r.cacheValid = false +r.mu.Unlock() +} + +func (r *Registry) load() ([]MarketplaceSource, error) { +r.mu.Lock() +defer r.mu.Unlock() +if r.cacheValid { +out := make([]MarketplaceSource, len(r.cache)) +copy(out, r.cache) +return out, nil +} +p, err := r.ensureFile() +if err != nil { +return nil, err +} +raw, err := os.ReadFile(p) +var data map[string]interface{} +if err == nil { +_ = json.Unmarshal(raw, &data) +} +if data == nil { +data = map[string]interface{}{"marketplaces": []interface{}{}} +} +entries, _ := data["marketplaces"].([]interface{}) +var sources []MarketplaceSource +for _, e := range entries { +m, ok := e.(map[string]interface{}) +if !ok { +continue +} +src, err := FromDict(m) +if err == nil { +sources = append(sources, src) +} +} +r.cache = sources +r.cacheValid = true +out := make([]MarketplaceSource, len(sources)) +copy(out, sources) +return out, nil +} + +func (r *Registry) save(sources []MarketplaceSource) error { +p, err := r.ensureFile() +if err != nil { +return err +} +dicts := make([]interface{}, len(sources)) +for i, s := range sources { +dicts[i] = s.ToDict() +} +data := map[string]interface{}{"marketplaces": dicts} +raw, err := json.MarshalIndent(data, "", " ") +if err != nil { +return err +} +tmp := p + ".tmp" +if err := os.WriteFile(tmp, raw, 0o644); err != nil { +return err +} +if err := os.Rename(tmp, p); err 
!= nil { +return err +} +r.mu.Lock() +r.cache = make([]MarketplaceSource, len(sources)) +copy(r.cache, sources) +r.cacheValid = true +r.mu.Unlock() +return nil +} + +// GetAll returns all registered marketplaces. +func (r *Registry) GetAll() ([]MarketplaceSource, error) { +return r.load() +} + +// GetByName returns a marketplace by display name (case-insensitive). +// Returns an error if not found. +func (r *Registry) GetByName(name string) (MarketplaceSource, error) { +lower := strings.ToLower(name) +sources, err := r.load() +if err != nil { +return MarketplaceSource{}, err +} +for _, s := range sources { +if strings.ToLower(s.Name) == lower { +return s, nil +} +} +return MarketplaceSource{}, fmt.Errorf("marketplace not found: %s", name) +} + +// Add registers a marketplace, replacing any existing entry with the same name. +func (r *Registry) Add(source MarketplaceSource) error { +sources, err := r.load() +if err != nil { +return err +} +lower := strings.ToLower(source.Name) +var filtered []MarketplaceSource +for _, s := range sources { +if strings.ToLower(s.Name) != lower { +filtered = append(filtered, s) +} +} +filtered = append(filtered, source) +return r.save(filtered) +} + +// Remove removes a marketplace by name. +// Returns an error if not found. +func (r *Registry) Remove(name string) error { +sources, err := r.load() +if err != nil { +return err +} +lower := strings.ToLower(name) +var filtered []MarketplaceSource +for _, s := range sources { +if strings.ToLower(s.Name) != lower { +filtered = append(filtered, s) +} +} +if len(filtered) == len(sources) { +return fmt.Errorf("marketplace not found: %s", name) +} +return r.save(filtered) +} + +// Names returns a sorted list of registered marketplace names. 
+func (r *Registry) Names() ([]string, error) { +sources, err := r.load() +if err != nil { +return nil, err +} +names := make([]string, len(sources)) +for i, s := range sources { +names[i] = s.Name +} +sort.Strings(names) +return names, nil +} + +// Count returns the number of registered marketplaces. +func (r *Registry) Count() (int, error) { +sources, err := r.load() +if err != nil { +return 0, err +} +return len(sources), nil +} diff --git a/internal/marketplace/ymlschema/ymlschema.go b/internal/marketplace/ymlschema/ymlschema.go new file mode 100644 index 0000000..672648f --- /dev/null +++ b/internal/marketplace/ymlschema/ymlschema.go @@ -0,0 +1,438 @@ +// Package ymlschema provides dataclasses, loader, and validation for +// marketplace authoring config. +// Migrated from src/apm_cli/marketplace/yml_schema.py. +package ymlschema + +import ( + "bufio" + "fmt" + "os" + "regexp" + "sort" + "strings" +) + +// Errors + +// MarketplaceYmlError is raised on marketplace YAML validation failures. +type MarketplaceYmlError struct { + Msg string +} + +func (e *MarketplaceYmlError) Error() string { return e.Msg } + +func mErr(format string, args ...interface{}) *MarketplaceYmlError { + return &MarketplaceYmlError{Msg: fmt.Sprintf(format, args...)} +} + +// Regex patterns +var ( + semverRE = regexp.MustCompile(`^\d+\.\d+\.\d+(?:-[0-9A-Za-z-]+(?:\.[0-9A-Za-z-]+)*)?(?:\+[0-9A-Za-z-]+(?:\.[0-9A-Za-z-]+)*)?$`) + sourceRE = regexp.MustCompile(`^(?:[^/]+/[^/]+|\./.*)$`) + localSourceRE = regexp.MustCompile(`^\.\/`) +) + +const ( + maxTagsCount = 50 + maxTagLength = 100 +) + +var tagPlaceholders = []string{"{version}", "{name}"} + +// MarketplaceOwner is the owner block of marketplace.yml. +type MarketplaceOwner struct { + Name string + Email string + URL string +} + +// MarketplaceBuild is the APM-only build configuration block. +type MarketplaceBuild struct { + TagPattern string +} + +// PackageEntry is a single entry in the packages list. 
+type PackageEntry struct { + Name string + Source string + Subdir string + Version string + Ref string + TagPattern string + IncludePrerelease bool + Description string + Homepage string + Tags []string + Author map[string]string // {name, email?, url?} + License string + Repository string + IsLocal bool +} + +// MarketplaceConfig is the parsed marketplace configuration. +type MarketplaceConfig struct { + Name string + Description string + Version string + Owner MarketplaceOwner + Output string + Metadata map[string]interface{} + Build MarketplaceBuild + Packages []PackageEntry + SourcePath string + IsLegacy bool + NameOverridden bool + DescriptionOverridden bool + VersionOverridden bool +} + +// parseSimpleYAML is a minimal line-by-line YAML parser for flat string values. +// Returns top-level key->value pairs (no nesting). Values are trimmed and unquoted. +func parseSimpleYAML(content string) map[string]string { + result := map[string]string{} + scanner := bufio.NewScanner(strings.NewReader(content)) + for scanner.Scan() { + line := scanner.Text() + if strings.HasPrefix(line, "#") || strings.HasPrefix(line, " ") || strings.HasPrefix(line, "\t") { + continue + } + idx := strings.Index(line, ":") + if idx <= 0 { + continue + } + key := strings.TrimSpace(line[:idx]) + val := strings.TrimSpace(line[idx+1:]) + val = strings.Trim(val, "\"'") + if val != "" && !strings.HasPrefix(val, "{") && !strings.HasPrefix(val, "[") && !strings.HasPrefix(val, "-") { + result[key] = val + } + } + return result +} + +func validateSemver(version, context string) error { + if !semverRE.MatchString(version) { + return mErr("'%s' value '%s' is not valid semver (expected x.y.z)", context, version) + } + return nil +} + +func validateTagPattern(pattern, context string) error { + for _, ph := range tagPlaceholders { + if strings.Contains(pattern, ph) { + return nil + } + } + return mErr("'%s' must contain at least one of %s, got '%s'", context, strings.Join(tagPlaceholders, ", "), pattern) 
+} + +func validatePathSegments(path string) error { + parts := strings.Split(path, "/") + for _, p := range parts { + if p == ".." { + return fmt.Errorf("path traversal detected in: %s", path) + } + } + return nil +} + +func parseOwner(raw map[string]interface{}) (MarketplaceOwner, error) { + name, ok := raw["name"].(string) + if !ok || strings.TrimSpace(name) == "" { + return MarketplaceOwner{}, mErr("'owner.name' is required and must be a non-empty string") + } + owner := MarketplaceOwner{Name: strings.TrimSpace(name)} + if email, ok := raw["email"].(string); ok { + owner.Email = strings.TrimSpace(email) + } + if url, ok := raw["url"].(string); ok { + owner.URL = strings.TrimSpace(url) + } + return owner, nil +} + +func parseBuild(raw interface{}) (MarketplaceBuild, error) { + if raw == nil { + return MarketplaceBuild{TagPattern: "v{version}"}, nil + } + m, ok := raw.(map[string]interface{}) + if !ok { + return MarketplaceBuild{}, mErr("'build' must be a mapping") + } + tagPattern := "v{version}" + if tp, ok := m["tagPattern"].(string); ok && strings.TrimSpace(tp) != "" { + tagPattern = strings.TrimSpace(tp) + } + if err := validateTagPattern(tagPattern, "build.tagPattern"); err != nil { + return MarketplaceBuild{}, err + } + return MarketplaceBuild{TagPattern: tagPattern}, nil +} + +func getStr(m map[string]interface{}, key string) (string, bool) { + v, ok := m[key] + if !ok || v == nil { + return "", false + } + s, ok := v.(string) + return s, ok +} + +func requireStr(m map[string]interface{}, key, context string) (string, error) { + s, ok := getStr(m, key) + if !ok || strings.TrimSpace(s) == "" { + path := key + if context != "" { + path = context + "." 
+ key + } + return "", mErr("'%s' is required", path) + } + return strings.TrimSpace(s), nil +} + +func checkUnknownKeys(data map[string]interface{}, permitted map[string]bool, context string) error { + var unknown []string + for k := range data { + if !permitted[k] { + unknown = append(unknown, k) + } + } + if len(unknown) > 0 { + sort.Strings(unknown) + var perm []string + for k := range permitted { + perm = append(perm, k) + } + sort.Strings(perm) + return mErr("Unknown key(s) in %s: %s. Permitted keys: %s", context, strings.Join(unknown, ", "), strings.Join(perm, ", ")) + } + return nil +} + +var packageEntryKeys = map[string]bool{ + "name": true, "source": true, "subdir": true, "version": true, "ref": true, + "tag_pattern": true, "include_prerelease": true, "description": true, + "homepage": true, "tags": true, "author": true, "license": true, + "repository": true, "keywords": true, +} + +var apmMarketplaceKeys = map[string]bool{ + "name": true, "description": true, "version": true, "owner": true, + "output": true, "metadata": true, "build": true, "packages": true, +} + +func parsePackageEntry(raw interface{}, index int) (PackageEntry, error) { + m, ok := raw.(map[string]interface{}) + if !ok { + // Try map[interface{}]interface{} (some YAML parsers) + if mi, ok2 := raw.(map[interface{}]interface{}); ok2 { + m = make(map[string]interface{}) + for k, v := range mi { + m[fmt.Sprint(k)] = v + } + } else { + return PackageEntry{}, mErr("packages[%d] must be a mapping", index) + } + } + if err := checkUnknownKeys(m, packageEntryKeys, fmt.Sprintf("packages[%d]", index)); err != nil { + return PackageEntry{}, err + } + + name, err := requireStr(m, "name", fmt.Sprintf("packages[%d]", index)) + if err != nil { + return PackageEntry{}, err + } + source, err := requireStr(m, "source", fmt.Sprintf("packages[%d]", index)) + if err != nil { + return PackageEntry{}, err + } + if !sourceRE.MatchString(source) { + return PackageEntry{}, mErr("'packages[%d].source' must match 
'/' or './' shape, got '%s'", index, source) + } + isLocal := localSourceRE.MatchString(source) + + entry := PackageEntry{Name: name, Source: source, IsLocal: isLocal} + + if v, ok := getStr(m, "subdir"); ok && strings.TrimSpace(v) != "" { + entry.Subdir = strings.TrimSpace(v) + } + if v, ok := getStr(m, "version"); ok && strings.TrimSpace(v) != "" { + entry.Version = strings.TrimSpace(v) + } + if v, ok := getStr(m, "ref"); ok && strings.TrimSpace(v) != "" { + entry.Ref = strings.TrimSpace(v) + } + if !isLocal && entry.Version == "" && entry.Ref == "" { + return PackageEntry{}, mErr("packages[%d] ('%s'): remote packages require at least one of 'version' or 'ref'", index, name) + } + if v, ok := getStr(m, "tag_pattern"); ok && strings.TrimSpace(v) != "" { + tp := strings.TrimSpace(v) + if err := validateTagPattern(tp, fmt.Sprintf("packages[%d].tag_pattern", index)); err != nil { + return PackageEntry{}, err + } + entry.TagPattern = tp + } + if v, ok := m["include_prerelease"].(bool); ok { + entry.IncludePrerelease = v + } + if v, ok := getStr(m, "description"); ok { + entry.Description = strings.TrimSpace(v) + } + if v, ok := getStr(m, "homepage"); ok { + entry.Homepage = strings.TrimSpace(v) + } + if v, ok := getStr(m, "license"); ok { + entry.License = strings.TrimSpace(v) + } + if v, ok := getStr(m, "repository"); ok { + entry.Repository = strings.TrimSpace(v) + } + + // Tags + keywords merge + var tags []string + if rawTags, ok := m["tags"].([]interface{}); ok { + for _, t := range rawTags { + if s, ok := t.(string); ok { + tags = append(tags, s) + } + } + } + if rawKW, ok := m["keywords"].([]interface{}); ok { + seen := map[string]bool{} + for _, t := range tags { + seen[t] = true + } + for _, t := range rawKW { + if s, ok := t.(string); ok && !seen[s] { + tags = append(tags, s) + seen[s] = true + } + } + } + if len(tags) > maxTagsCount { + tags = tags[:maxTagsCount] + } + for i, t := range tags { + if len(t) > maxTagLength { + tags[i] = t[:maxTagLength] + } + 
} + entry.Tags = tags + + // Author + if rawAuthor, ok := m["author"]; ok && rawAuthor != nil { + switch a := rawAuthor.(type) { + case string: + n := strings.TrimSpace(a) + if n == "" { + return PackageEntry{}, mErr("'packages[%d].author' must be a non-empty string or object with 'name'", index) + } + entry.Author = map[string]string{"name": n} + case map[string]interface{}: + n, ok := getStr(a, "name") + if !ok || strings.TrimSpace(n) == "" { + return PackageEntry{}, mErr("'packages[%d].author.name' is required", index) + } + auth := map[string]string{"name": strings.TrimSpace(n)} + for _, k := range []string{"email", "url"} { + if v, ok := getStr(a, k); ok && strings.TrimSpace(v) != "" { + auth[k] = strings.TrimSpace(v) + } + } + entry.Author = auth + } + } + + return entry, nil +} + +// LoadFromFile loads a MarketplaceConfig from a file path. +// It reads the file as raw text and uses a minimal parser. +func LoadFromFile(path string, isLegacy bool) (*MarketplaceConfig, error) { + content, err := os.ReadFile(path) + if err != nil { + return nil, mErr("Cannot read '%s': %v", path, err) + } + + // Use simple key-value extraction for top-level scalars + flat := parseSimpleYAML(string(content)) + + cfg := &MarketplaceConfig{ + SourcePath: path, + IsLegacy: isLegacy, + Build: MarketplaceBuild{TagPattern: "v{version}"}, + Output: ".claude-plugin/marketplace.json", + } + if isLegacy { + cfg.Output = "marketplace.json" + } + + if v := flat["name"]; v != "" { + cfg.Name = v + cfg.NameOverridden = isLegacy + } + if v := flat["description"]; v != "" { + cfg.Description = v + cfg.DescriptionOverridden = isLegacy + } + if v := flat["version"]; v != "" { + cfg.Version = v + cfg.VersionOverridden = isLegacy + if cfg.Version != "" { + if err := validateSemver(cfg.Version, "version"); err != nil { + return nil, err + } + } + } + if v := flat["output"]; v != "" { + cfg.Output = v + if err := validatePathSegments(cfg.Output); err != nil { + return nil, mErr("invalid output path: 
%v", err) + } + } + + // Owner (required) + ownerName := flat["owner.name"] + if ownerName == "" { + // Try to extract owner.name from nested YAML manually + ownerName = extractNestedValue(string(content), "owner", "name") + } + if ownerName == "" { + return nil, mErr("'owner' is required") + } + cfg.Owner = MarketplaceOwner{ + Name: ownerName, + Email: extractNestedValue(string(content), "owner", "email"), + URL: extractNestedValue(string(content), "owner", "url"), + } + + return cfg, nil +} + +// extractNestedValue extracts a value from a 2-level YAML structure without +// a full YAML parser. Used for simple cases like owner.name. +func extractNestedValue(content, parent, key string) string { + inParent := false + scanner := bufio.NewScanner(strings.NewReader(content)) + for scanner.Scan() { + line := scanner.Text() + if !strings.HasPrefix(line, " ") && !strings.HasPrefix(line, "\t") { + // Top-level line + trimmed := strings.TrimSpace(line) + if strings.HasPrefix(trimmed, parent+":") { + inParent = true + } else if trimmed != "" { + inParent = false + } + continue + } + if inParent { + trimmed := strings.TrimSpace(line) + if strings.HasPrefix(trimmed, key+":") { + val := strings.TrimSpace(trimmed[len(key)+1:]) + return strings.Trim(val, "\"'") + } + } + } + return "" +} diff --git a/internal/models/apmpackage/apmpackage.go b/internal/models/apmpackage/apmpackage.go new file mode 100644 index 0000000..08ac93c --- /dev/null +++ b/internal/models/apmpackage/apmpackage.go @@ -0,0 +1,163 @@ +// Package apmpackage provides the APMPackage and PackageInfo data models. +// Migrated from src/apm_cli/models/apm_package.py. +package apmpackage + +import ( + "bufio" + "fmt" + "os" + "path/filepath" + "strings" +) + +// PackageContentType represents the content type of a package. 
+type PackageContentType int + +const ( + ContentTypeInstructions PackageContentType = iota + ContentTypeSkill + ContentTypeHybrid + ContentTypePrompts +) + +// String returns the string representation of a PackageContentType. +func (t PackageContentType) String() string { + switch t { + case ContentTypeInstructions: + return "instructions" + case ContentTypeSkill: + return "skill" + case ContentTypeHybrid: + return "hybrid" + case ContentTypePrompts: + return "prompts" + default: + return "unknown" + } +} + +// ParseContentType parses a string content type. +func ParseContentType(s string) (PackageContentType, error) { + switch strings.ToLower(s) { + case "instructions": + return ContentTypeInstructions, nil + case "skill": + return ContentTypeSkill, nil + case "hybrid": + return ContentTypeHybrid, nil + case "prompts": + return ContentTypePrompts, nil + default: + return 0, fmt.Errorf("unknown content type: %s", s) + } +} + +// APMPackage represents an APM package with metadata. +type APMPackage struct { + Name string + Version string + Description string + Author string + License string + Source string + ResolvedCommit string + Dependencies map[string][]interface{} + DevDependencies map[string][]interface{} + Scripts map[string]string + PackagePath string + SourcePath string + Target interface{} // string or []string + Type *PackageContentType + Includes interface{} // string "auto" or []string +} + +// PackageInfo contains information about a downloaded/installed package. +type PackageInfo struct { + Package *APMPackage + InstallPath string + InstalledAt string + PackageType string // "APM_PACKAGE", "CLAUDE_SKILL", or "HYBRID" +} + +// GetPrimitivesPath returns the path to the .apm directory for this package. +func (p *PackageInfo) GetPrimitivesPath() string { + return filepath.Join(p.InstallPath, ".apm") +} + +// HasPrimitives checks if the package has any primitives. 
+func (p *PackageInfo) HasPrimitives() bool { + apmDir := p.GetPrimitivesPath() + for _, pt := range []string{"instructions", "chatmodes", "contexts", "prompts", "hooks"} { + dir := filepath.Join(apmDir, pt) + if entries, err := os.ReadDir(dir); err == nil && len(entries) > 0 { + return true + } + } + hooksDir := filepath.Join(p.InstallPath, "hooks") + if entries, err := os.ReadDir(hooksDir); err == nil { + for _, e := range entries { + if strings.HasSuffix(e.Name(), ".json") { + return true + } + } + } + return false +} + +// LoadFromApmYml loads basic package metadata from an apm.yml file. +// This is a lightweight loader that extracts name/version/description/target +// without full dependency parsing. +func LoadFromApmYml(apmYmlPath string) (*APMPackage, error) { + f, err := os.Open(apmYmlPath) + if err != nil { + return nil, fmt.Errorf("apm.yml not found: %s", apmYmlPath) + } + defer f.Close() + + data := map[string]string{} + scanner := bufio.NewScanner(f) + for scanner.Scan() { + line := scanner.Text() + if idx := strings.Index(line, ":"); idx > 0 { + key := strings.TrimSpace(line[:idx]) + val := strings.TrimSpace(line[idx+1:]) + // Strip inline YAML quotes + val = strings.Trim(val, "\"'") + if val != "" && !strings.HasPrefix(val, "{") && !strings.HasPrefix(val, "[") { + data[key] = val + } + } + } + + name := data["name"] + version := data["version"] + if name == "" { + return nil, fmt.Errorf("missing required field 'name' in apm.yml") + } + if version == "" { + return nil, fmt.Errorf("missing required field 'version' in apm.yml") + } + + pkg := &APMPackage{ + Name: name, + Version: version, + Description: data["description"], + Author: data["author"], + License: data["license"], + PackagePath: filepath.Dir(apmYmlPath), + SourcePath: filepath.Dir(apmYmlPath), + } + + if t := data["target"]; t != "" { + pkg.Target = t + } + + if typeStr := data["type"]; typeStr != "" { + ct, err := ParseContentType(typeStr) + if err == nil { + pkg.Type = &ct + } + } + + 
return pkg, nil +} diff --git a/internal/models/depreference/depreference.go b/internal/models/depreference/depreference.go new file mode 100644 index 0000000..16473a5 --- /dev/null +++ b/internal/models/depreference/depreference.go @@ -0,0 +1,1352 @@ +// Package depreference provides the DependencyReference model -- the core +// dependency representation and parsing layer for the APM CLI. +// +// Migrated from: src/apm_cli/models/dependency/reference.py +package depreference + +import ( + "fmt" + "net/url" + "path/filepath" + "regexp" + "runtime" + "strings" + "unicode" + + "github.com/githubnext/apm/internal/utils/githubhost" + "github.com/githubnext/apm/internal/utils/pathsecurity" +) + +// defaultSchemePorts maps URI schemes to their default ports so that +// redundant explicit ports (https://host:443/...) can be stripped. +var defaultSchemePorts = map[string]int{ + "https": 443, + "http": 80, + "ssh": 22, +} + +// VirtualPackageType classifies a virtual (sub-repo) package. +type VirtualPackageType int + +const ( + VirtualPackageFile VirtualPackageType = iota // Individual file (*.prompt.md etc.) + VirtualPackageSubdirectory // Subdirectory package +) + +// virtualFileExtensions lists the file extensions recognised as virtual FILE packages. +var virtualFileExtensions = []string{ + ".prompt.md", + ".instructions.md", + ".chatmode.md", + ".agent.md", +} + +// removedCollectionExtensions lists legacy collection-manifest extensions that +// are rejected at parse time with a migration message. +var removedCollectionExtensions = []string{ + ".collection.yml", + ".collection.yaml", +} + +// gitlabVirtualRootSegments is the set of first-path segments that, on +// GitLab, often start an in-repo virtual layout. +var gitlabVirtualRootSegments = map[string]bool{ + "prompts": true, + "instructions": true, + "collections": true, +} + +// scpLikeRE matches SCP-style SSH URLs: @: +// Mirrors the Python SCP_LIKE_RE used in cache/url_normalize. 
+var scpLikeRE = regexp.MustCompile( + `^(?P[^@]+)@(?P[^:]+):(?P.+)$`, +) + +// aliasRE validates alias strings. +var aliasRE = regexp.MustCompile(`^[a-zA-Z0-9._-]+$`) + +// adoRepoRE validates org/project/repo paths for Azure DevOps. +var adoRepoRE = regexp.MustCompile(`^[a-zA-Z0-9._-]+/[a-zA-Z0-9._\- ]+/[a-zA-Z0-9._\- ]+$`) + +// DependencyReference is the central model for an APM dependency. +// +// Fields mirror the Python DependencyReference dataclass exactly. +type DependencyReference struct { + RepoURL string // e.g. "owner/repo" or "org/project/repo" for ADO + Host string // Optional host; empty means default (github.com) + Port int // Non-standard SSH/HTTPS port; 0 means default + // ExplicitScheme is the user-stated transport: "ssh", "https", "http", + // or "" for shorthand notation. + ExplicitScheme string + Reference string // e.g. "main", "v1.0.0", "abc123" + Alias string // Optional alias for the dependency + VirtualPath string // Path for virtual packages + IsVirtual bool // True if this is a virtual package + + // Azure DevOps specific fields + ADOOrganization string + ADOProject string + ADORepo string + + // Local path dependency + IsLocal bool + LocalPath string // Original local path string + + // Monorepo parent inheritance + IsParentRepoInheritance bool + + ArtifactoryPrefix string // e.g. "artifactory/github" + + // HTTP insecure dependency + IsInsecure bool + AllowInsecure bool + + // SKILL_BUNDLE subset selection + SkillSubset []string // sorted skill names, nil = all +} + +// VirtualType returns the type of virtual package, or -1 if not virtual. +func (d *DependencyReference) VirtualType() VirtualPackageType { + if !d.IsVirtual || d.VirtualPath == "" { + return -1 + } + for _, ext := range virtualFileExtensions { + if strings.HasSuffix(d.VirtualPath, ext) { + return VirtualPackageFile + } + } + return VirtualPackageSubdirectory +} + +// IsVirtualFile returns true when this is a virtual file package. 
+func (d *DependencyReference) IsVirtualFile() bool { + return d.VirtualType() == VirtualPackageFile +} + +// IsVirtualSubdirectory returns true when this is a virtual subdirectory package. +func (d *DependencyReference) IsVirtualSubdirectory() bool { + return d.VirtualType() == VirtualPackageSubdirectory +} + +// IsArtifactory returns true when this reference points to a JFrog Artifactory VCS repo. +func (d *DependencyReference) IsArtifactory() bool { + return d.ArtifactoryPrefix != "" +} + +// IsAzureDevOps returns true when this reference points to Azure DevOps. +func (d *DependencyReference) IsAzureDevOps() bool { + return d.Host != "" && githubhost.IsAzureDevOpsHostname(d.Host) +} + +// GetVirtualPackageName generates a package name for a virtual package. +// +// owner/repo/prompts/code-review.prompt.md -> repo-code-review +// owner/repo/collections/project-planning -> repo-project-planning +func (d *DependencyReference) GetVirtualPackageName() string { + if !d.IsVirtual || d.VirtualPath == "" { + parts := strings.Split(d.RepoURL, "/") + return parts[len(parts)-1] + } + repoParts := strings.Split(d.RepoURL, "/") + repoName := "package" + if len(repoParts) > 0 { + repoName = repoParts[len(repoParts)-1] + } + pathParts := strings.Split(d.VirtualPath, "/") + last := pathParts[len(pathParts)-1] + for _, ext := range virtualFileExtensions { + if strings.HasSuffix(last, ext) { + last = last[:len(last)-len(ext)] + break + } + } + return repoName + "-" + last +} + +// IsLocalPath returns true when dep_str looks like a local filesystem path. 
+func IsLocalPath(depStr string) bool { + s := strings.TrimSpace(depStr) + if strings.HasPrefix(s, "//") { + return false + } + for _, pfx := range []string{"./", "../", "/", "~/", `~\`, `.\`, `..\`} { + if strings.HasPrefix(s, pfx) { + return true + } + } + // Windows absolute path: drive letter + colon + separator + if runtime.GOOS == "windows" || (len(s) >= 3 && + ((s[0] >= 'A' && s[0] <= 'Z') || (s[0] >= 'a' && s[0] <= 'z')) && + s[1] == ':' && (s[2] == '\\' || s[2] == '/')) { + return len(s) >= 3 + } + return false +} + +// GetUniqueKey returns a key for deduplication. +func (d *DependencyReference) GetUniqueKey() string { + if d.IsLocal && d.LocalPath != "" { + return d.LocalPath + } + if d.IsVirtual && d.VirtualPath != "" { + return d.RepoURL + "/" + d.VirtualPath + } + return d.RepoURL +} + +// effectiveHost returns d.Host or the default host (github.com). +func (d *DependencyReference) effectiveHost() string { + if d.Host != "" { + return d.Host + } + return githubhost.DefaultHost() +} + +// hostLabel returns host:port or host. +func (d *DependencyReference) hostLabel() string { + h := d.effectiveHost() + if d.Port != 0 { + return fmt.Sprintf("%s:%d", h, d.Port) + } + return h +} + +// ToCanonical returns the canonical scheme-free identity string. 
+func (d *DependencyReference) ToCanonical() string { + if d.IsLocal && d.LocalPath != "" { + return d.LocalPath + } + host := d.effectiveHost() + isDefault := strings.EqualFold(host, githubhost.DefaultHost()) + hl := d.hostLabel() + + var result string + switch { + case isDefault && d.Port == 0 && d.ArtifactoryPrefix == "": + result = d.RepoURL + case d.ArtifactoryPrefix != "": + result = hl + "/" + d.ArtifactoryPrefix + "/" + d.RepoURL + default: + result = hl + "/" + d.RepoURL + } + if d.IsVirtual && d.VirtualPath != "" { + result = result + "/" + d.VirtualPath + } + if d.Reference != "" { + result = result + "#" + d.Reference + } + return result +} + +// GetIdentity returns the identity (canonical without ref/alias). +func (d *DependencyReference) GetIdentity() string { + if d.IsLocal && d.LocalPath != "" { + return d.LocalPath + } + host := d.effectiveHost() + isDefault := strings.EqualFold(host, githubhost.DefaultHost()) + hl := d.hostLabel() + + var result string + switch { + case isDefault && d.Port == 0 && d.ArtifactoryPrefix == "": + result = d.RepoURL + case d.ArtifactoryPrefix != "": + result = hl + "/" + d.ArtifactoryPrefix + "/" + d.RepoURL + default: + result = hl + "/" + d.RepoURL + } + if d.IsVirtual && d.VirtualPath != "" { + result = result + "/" + d.VirtualPath + } + return result +} + +// GetCanonicalDependencyString is host-blind (filesystem-layout) canonical string. +func (d *DependencyReference) GetCanonicalDependencyString() string { + return d.GetUniqueKey() +} + +// GetInstallPath returns the canonical filesystem path under apm_modules_dir. +func (d *DependencyReference) GetInstallPath(apmModulesDir string) (string, error) { + if d.IsLocal && d.LocalPath != "" { + pkgDirName := filepath.Base(d.LocalPath) + if pkgDirName == "" || pkgDirName == "." || pkgDirName == ".." 
{ + return "", fmt.Errorf("local path %q does not resolve to a named directory", d.LocalPath) + } + if err := pathsecurity.ValidatePathSegments(pkgDirName, "local package path", true, false); err != nil { + return "", err + } + result := filepath.Join(apmModulesDir, "_local", pkgDirName) + return pathsecurity.EnsurePathWithin(result, apmModulesDir) + } + + repoParts := strings.Split(d.RepoURL, "/") + if err := pathsecurity.ValidatePathSegments(d.RepoURL, "repo_url", false, false); err != nil { + return "", err + } + if d.VirtualPath != "" { + if err := pathsecurity.ValidatePathSegments(d.VirtualPath, "virtual_path", false, false); err != nil { + return "", err + } + } + + var result string + if d.IsVirtual { + if d.IsVirtualSubdirectory() { + if d.IsAzureDevOps() && len(repoParts) >= 3 { + result = filepath.Join(apmModulesDir, repoParts[0], repoParts[1], repoParts[2], d.VirtualPath) + } else if len(repoParts) >= 2 { + parts := append(repoParts, strings.Split(d.VirtualPath, "/")...) + result = filepath.Join(append([]string{apmModulesDir}, parts...)...) + } + } else { + pkgName := d.GetVirtualPackageName() + if d.IsAzureDevOps() && len(repoParts) >= 3 { + result = filepath.Join(apmModulesDir, repoParts[0], repoParts[1], pkgName) + } else if len(repoParts) >= 2 { + result = filepath.Join(apmModulesDir, repoParts[0], pkgName) + } + } + } else if d.IsAzureDevOps() && len(repoParts) >= 3 { + result = filepath.Join(apmModulesDir, repoParts[0], repoParts[1], repoParts[2]) + } else if len(repoParts) >= 2 { + result = filepath.Join(append([]string{apmModulesDir}, repoParts...)...) + } + + if result == "" { + result = filepath.Join(append([]string{apmModulesDir}, repoParts...)...) + } + + return pathsecurity.EnsurePathWithin(result, apmModulesDir) +} + +// ToGitHubURL converts to a full repository HTTPS URL. 
+func (d *DependencyReference) ToGitHubURL() string { + if d.IsLocal && d.LocalPath != "" { + return d.LocalPath + } + host := d.effectiveHost() + netloc := host + if d.Port != 0 { + netloc = fmt.Sprintf("%s:%d", host, d.Port) + } + scheme := "https" + if d.IsInsecure { + scheme = "http" + } + if d.IsAzureDevOps() { + proj := url.PathEscape(d.ADOProject) + repo := url.PathEscape(d.ADORepo) + return fmt.Sprintf("https://%s/%s/%s/_git/%s", netloc, d.ADOOrganization, proj, repo) + } + if d.ArtifactoryPrefix != "" { + return fmt.Sprintf("%s://%s/%s/%s", scheme, netloc, d.ArtifactoryPrefix, d.RepoURL) + } + return fmt.Sprintf("%s://%s/%s", scheme, netloc, d.RepoURL) +} + +// ToCloneURL is the same as ToGitHubURL for most purposes. +func (d *DependencyReference) ToCloneURL() string { + return d.ToGitHubURL() +} + +// GetDisplayName returns the alias, local path, virtual name, or repo URL. +func (d *DependencyReference) GetDisplayName() string { + if d.Alias != "" { + return d.Alias + } + if d.IsLocal && d.LocalPath != "" { + return d.LocalPath + } + if d.IsVirtual { + return d.GetVirtualPackageName() + } + return d.RepoURL +} + +// String returns a human-readable representation. +func (d *DependencyReference) String() string { + if d.IsLocal && d.LocalPath != "" { + return d.LocalPath + } + var result string + if d.Host != "" { + hl := d.hostLabel() + if d.ArtifactoryPrefix != "" { + result = hl + "/" + d.ArtifactoryPrefix + "/" + d.RepoURL + } else { + result = hl + "/" + d.RepoURL + } + } else { + result = d.RepoURL + } + if d.VirtualPath != "" { + result += "/" + d.VirtualPath + } + if d.Reference != "" { + result += "#" + d.Reference + } + if d.Alias != "" { + result += "@" + d.Alias + } + return result +} + +// ----- Parsing helpers ----- + +// parseSCPURL parses an SCP-shorthand SSH URL (user@host:path). +// Returns (host, port, repoURL, reference, alias, true) or ("","",…, false). 
+func parseSCPURL(depStr string) (host string, port int, repoURL, reference, alias string, ok bool) { + m := scpLikeRE.FindStringSubmatch(depStr) + if m == nil { + return + } + sshRepo := m[3] + if strings.Contains(sshRepo, "@") { + idx := strings.LastIndex(sshRepo, "@") + alias = strings.TrimSpace(sshRepo[idx+1:]) + sshRepo = sshRepo[:idx] + } + if strings.Contains(sshRepo, "#") { + idx := strings.LastIndex(sshRepo, "#") + reference = strings.TrimSpace(sshRepo[idx+1:]) + sshRepo = sshRepo[:idx] + } + if strings.HasSuffix(sshRepo, ".git") { + sshRepo = sshRepo[:len(sshRepo)-4] + } + repoURL = strings.TrimSpace(sshRepo) + if err := pathsecurity.ValidatePathSegments(repoURL, "SSH repository path", true, false); err != nil { + return + } + host = m[2] + ok = true + return +} + +// parseSSHProtocolURL parses ssh:// URLs. +func parseSSHProtocolURL(rawURL string) (host string, port int, repoURL, reference, alias string, ok bool) { + if !strings.HasPrefix(rawURL, "ssh://") { + return + } + u, err := url.Parse(rawURL) + if err != nil { + return + } + host = u.Hostname() + if p, err2 := parsePortInt(u.Port()); err2 == nil && p != 0 { + port = p + if port == defaultSchemePorts["ssh"] { + port = 0 + } + } + path := strings.TrimPrefix(u.Path, "/") + fragment := u.Fragment + if fragment != "" { + if strings.Contains(fragment, "@") { + idx := strings.LastIndex(fragment, "@") + reference = strings.TrimSpace(fragment[:idx]) + alias = strings.TrimSpace(fragment[idx+1:]) + } else { + reference = strings.TrimSpace(fragment) + } + } + if alias == "" && strings.Contains(path, "@") { + idx := strings.LastIndex(path, "@") + alias = strings.TrimSpace(path[idx+1:]) + path = path[:idx] + } + if strings.HasSuffix(path, ".git") { + path = path[:len(path)-4] + } + repoURL = strings.TrimSpace(path) + if err2 := pathsecurity.ValidatePathSegments(repoURL, "SSH repository path", true, false); err2 != nil { + return + } + ok = true + return +} + +func parsePortInt(s string) (int, error) { + if s == 
"" { + return 0, nil + } + var p int + _, err := fmt.Sscanf(s, "%d", &p) + return p, err +} + +// hasVirtualExt returns true if any segment ends in a virtual file extension. +func hasVirtualExt(segments []string) bool { + for _, seg := range segments { + for _, ext := range virtualFileExtensions { + if strings.HasSuffix(seg, ext) { + return true + } + } + } + return false +} + +// gitlabSegmentCount computes how many path segments belong to the GitLab +// project path vs the virtual package suffix. +func gitlabSegmentCount(segs []string, hasVirtExt, hasCollection bool) int { + n := len(segs) + if n < 2 { + return n + } + if hasCollection { + for i, s := range segs { + if s == "collections" && i >= 2 { + return i + } + } + return n + } + if hasVirtExt { + for i, seg := range segs { + if i >= 2 && gitlabVirtualRootSegments[seg] { + return i + } + } + if n == 3 { + return 2 + } + if n == 4 { + return 3 + } + if n >= 5 { + return 3 + } + return 2 + } + return n +} + +// detectVirtualPackage scans a dependency string for virtual package indicators. +// Returns (isVirtual, virtualPath, validatedHost, error). 
+func detectVirtualPackage(depStr string) (bool, string, string, error) { + temp := depStr + if idx := strings.LastIndex(temp, "#"); idx >= 0 { + temp = temp[:idx] + } + + lower := strings.ToLower(temp) + for _, pfx := range []string{"git@", "https://", "http://", "ssh://"} { + if strings.HasPrefix(lower, pfx) { + return false, "", "", nil + } + } + + check := temp + var validatedHost string + + if strings.Contains(check, "/") { + firstSeg := strings.SplitN(check, "/", 2)[0] + if strings.Contains(firstSeg, ".") { + testURL := "https://" + check + u, err := url.Parse(testURL) + if err == nil && u.Hostname() != "" && githubhost.IsSupportedGitHost(u.Hostname()) { + validatedHost = u.Hostname() + check = strings.SplitN(check, "/", 2)[1] + } else if err == nil { + return false, "", "", fmt.Errorf("invalid Git host: %s", firstSeg) + } + } else if strings.HasPrefix(check, "gh/") { + check = check[3:] + } + } + + pathSegments := filterEmpty(strings.Split(check, "/")) + + isADO := validatedHost != "" && githubhost.IsAzureDevOpsHostname(validatedHost) + isGenericHost := validatedHost != "" && !githubhost.IsGitHubHostname(validatedHost) && !githubhost.IsAzureDevOpsHostname(validatedHost) + isGitLabHost := validatedHost != "" && githubhost.IsGitLabHostname(validatedHost) + + if isADO { + for i, s := range pathSegments { + if s == "_git" { + pathSegments = append(pathSegments[:i], pathSegments[i+1:]...) 
+ break + } + } + } + + isArtifactory := isGenericHost && githubhost.IsArtifactoryPath(pathSegments) + + var minBaseSegments int + switch { + case isADO: + if validatedHost != "" && githubhost.IsVisualStudioLegacyHostname(validatedHost) { + minBaseSegments = 2 + } else { + minBaseSegments = 3 + } + case isArtifactory: + minBaseSegments = 4 + case isGenericHost: + hv := hasVirtualExt(pathSegments) + hc := contains(pathSegments, "collections") + if isGitLabHost { + minBaseSegments = gitlabSegmentCount(pathSegments, hv, hc) + } else if hv || hc { + minBaseSegments = 2 + } else { + minBaseSegments = len(pathSegments) + } + default: + minBaseSegments = 2 + } + + if len(pathSegments) >= minBaseSegments+1 { + vPath := strings.Join(pathSegments[minBaseSegments:], "/") + if err := pathsecurity.ValidatePathSegments(vPath, "virtual path", false, false); err != nil { + return false, "", validatedHost, err + } + for _, ext := range removedCollectionExtensions { + if strings.HasSuffix(vPath, ext) { + return false, "", validatedHost, fmt.Errorf( + ".collection.yml is no longer supported. 
Convert %q to an apm.yml with a 'dependencies' section", vPath) + } + } + for _, ext := range virtualFileExtensions { + if strings.HasSuffix(vPath, ext) { + return true, vPath, validatedHost, nil + } + } + last := vPath + if idx := strings.LastIndex(vPath, "/"); idx >= 0 { + last = vPath[idx+1:] + } + if strings.Contains(last, ".") { + return false, "", validatedHost, fmt.Errorf( + "invalid virtual package path %q: individual files must end with a recognized extension", vPath) + } + return true, vPath, validatedHost, nil + } + + return false, "", validatedHost, nil +} + +func filterEmpty(ss []string) []string { + out := ss[:0] + for _, s := range ss { + if s != "" { + out = append(out, s) + } + } + return out +} + +func contains(ss []string, s string) bool { + for _, x := range ss { + if x == s { + return true + } + } + return false +} + +// validateURLRepoPath validates and normalises the repo path from a parsed URL. +// Returns (repoURL, virtualPath, error). +func validateURLRepoPath(u *url.URL) (string, string, error) { + hostname := u.Hostname() + if !githubhost.IsSupportedGitHost(hostname) { + return "", "", fmt.Errorf("invalid Git host: %s", hostname) + } + + path := strings.TrimPrefix(u.Path, "/") + if path == "" { + return "", "", fmt.Errorf("repository path cannot be empty") + } + if strings.HasSuffix(path, ".git") { + path = path[:len(path)-4] + } + + pathParts := make([]string, 0) + for _, p := range strings.Split(path, "/") { + pathParts = append(pathParts, urlUnescape(p)) + } + // Remove _git segment (Azure DevOps) + for i, p := range pathParts { + if p == "_git" { + pathParts = append(pathParts[:i], pathParts[i+1:]...) 
+ break + } + } + + isADO := githubhost.IsAzureDevOpsHostname(hostname) + var urlVirtualPath string + + if isADO { + isVSLegacy := githubhost.IsVisualStudioLegacyHostname(hostname) + minParts := 3 + if isVSLegacy { + minParts = 2 + } + if len(pathParts) < minParts { + return "", "", fmt.Errorf("invalid Azure DevOps repository path: expected 'org/project/repo', got %q", path) + } + if len(pathParts) > minParts { + adoVirtual := strings.Join(pathParts[minParts:], "/") + if err := pathsecurity.ValidatePathSegments(adoVirtual, "virtual path", false, false); err != nil { + return "", "", err + } + for _, ext := range removedCollectionExtensions { + if strings.HasSuffix(adoVirtual, ext) { + return "", "", fmt.Errorf(".collection.yml is no longer supported for %q", adoVirtual) + } + } + isFile := false + for _, ext := range virtualFileExtensions { + if strings.HasSuffix(adoVirtual, ext) { + isFile = true + break + } + } + if !isFile { + last := adoVirtual + if idx := strings.LastIndex(adoVirtual, "/"); idx >= 0 { + last = adoVirtual[idx+1:] + } + if strings.Contains(last, ".") { + return "", "", fmt.Errorf("invalid virtual package path %q", adoVirtual) + } + } + urlVirtualPath = adoVirtual + pathParts = pathParts[:minParts] + } + if isVSLegacy { + vsOrg := strings.SplitN(hostname, ".", 2)[0] + pathParts = append([]string{vsOrg}, pathParts...) 
+ } + } else { + if len(pathParts) < 2 { + return "", "", fmt.Errorf("invalid repository path: expected at least 'user/repo', got %q", path) + } + for _, pp := range pathParts { + for _, ext := range virtualFileExtensions { + if strings.HasSuffix(pp, ext) { + return "", "", fmt.Errorf("invalid repository path %q: contains a virtual file extension; use dict format with 'path:' for virtual packages", path) + } + } + } + } + + isADOPath := githubhost.IsAzureDevOpsHostname(hostname) + allowedPattern := `^[a-zA-Z0-9._-]+$` + if isADOPath { + allowedPattern = `^[a-zA-Z0-9._\- ]+$` + } + allowedRE := regexp.MustCompile(allowedPattern) + + if err := pathsecurity.ValidatePathSegments(strings.Join(pathParts, "/"), "repository URL path", true, false); err != nil { + return "", "", err + } + for _, part := range pathParts { + if !allowedRE.MatchString(part) { + return "", "", fmt.Errorf("invalid repository path component: %s", part) + } + } + + return strings.Join(pathParts, "/"), urlVirtualPath, nil +} + +func urlUnescape(s string) string { + out, err := url.PathUnescape(s) + if err != nil { + return s + } + return out +} + +// resolveVirtualShorthandRepo strips the virtual suffix from a shorthand repo_url. +// Returns (host, repoURL). +func resolveVirtualShorthandRepo(repoURL, validatedHost, virtualPath string) (string, string) { + parts := filterEmpty(strings.Split(repoURL, "/")) + // Remove _git + for i, p := range parts { + if p == "_git" { + parts = append(parts[:i], parts[i+1:]...) 
+ break + } + } + + host := "" + if len(parts) >= 3 && githubhost.IsSupportedGitHost(parts[0]) { + host = parts[0] + if githubhost.IsAzureDevOpsHostname(parts[0]) { + if githubhost.IsVisualStudioLegacyHostname(parts[0]) { + if len(parts) >= 4 { + repoURL = strings.Join(parts[1:3], "/") + } + } else { + if len(parts) >= 5 { + repoURL = strings.Join(parts[1:4], "/") + } + } + } else if githubhost.IsArtifactoryPath(parts[1:]) { + prefix, owner, repo := githubhost.ParseArtifactoryPath(parts[1:]) + if owner != "" && repo != "" { + _ = prefix + repoURL = owner + "/" + repo + } + } else if githubhost.IsGitLabHostname(parts[0]) && virtualPath != "" { + vParts := filterEmpty(strings.Split(virtualPath, "/")) + tail := len(vParts) + if tail > 0 && len(parts) > 1+tail { + repoURL = strings.Join(parts[1:len(parts)-tail], "/") + } else { + repoURL = strings.Join(parts[1:], "/") + } + } else { + repoURL = strings.Join(parts[1:3], "/") + } + } else if len(parts) >= 2 { + if host == "" { + host = githubhost.DefaultHost() + } + if validatedHost != "" && githubhost.IsAzureDevOpsHostname(validatedHost) { + if len(parts) >= 4 { + repoURL = strings.Join(parts[:3], "/") + } + } else { + repoURL = strings.Join(parts[:2], "/") + } + } + return host, repoURL +} + +// resolveShorthandToParsedURL converts a shorthand to a *url.URL. +func resolveShorthandToParsedURL(repoURL, host string) (*url.URL, string, error) { + parts := filterEmpty(strings.Split(repoURL, "/")) + for i, p := range parts { + if p == "_git" { + parts = append(parts[:i], parts[i+1:]...) 
+ break + } + } + + var userRepo string + if len(parts) >= 3 && githubhost.IsSupportedGitHost(parts[0]) { + host = parts[0] + if githubhost.IsVisualStudioLegacyHostname(host) && len(parts) >= 3 { + userRepo = strings.Join(parts[1:3], "/") + } else if githubhost.IsAzureDevOpsHostname(host) && len(parts) >= 4 { + userRepo = strings.Join(parts[1:4], "/") + } else if !githubhost.IsGitHubHostname(host) && !githubhost.IsAzureDevOpsHostname(host) { + if githubhost.IsArtifactoryPath(parts[1:]) { + _, owner, repo := githubhost.ParseArtifactoryPath(parts[1:]) + if owner != "" && repo != "" { + userRepo = owner + "/" + repo + } else { + userRepo = strings.Join(parts[1:], "/") + } + } else { + userRepo = strings.Join(parts[1:], "/") + } + } else { + userRepo = strings.Join(parts[1:], "/") + } + } else if len(parts) >= 2 && !strings.Contains(parts[0], ".") { + if host == "" { + host = githubhost.DefaultHost() + } + if githubhost.IsAzureDevOpsHostname(host) && len(parts) >= 3 { + userRepo = strings.Join(parts[:3], "/") + } else if host != "" && !githubhost.IsGitHubHostname(host) && !githubhost.IsAzureDevOpsHostname(host) { + userRepo = strings.Join(parts, "/") + } else { + userRepo = strings.Join(parts[:2], "/") + } + } else { + return nil, "", fmt.Errorf("use 'user/repo' or 'github.com/user/repo' format") + } + + if userRepo == "" || !strings.Contains(userRepo, "/") { + return nil, "", fmt.Errorf("invalid repository format: %s", repoURL) + } + + uParts := strings.Split(userRepo, "/") + isADOHost := host != "" && githubhost.IsAzureDevOpsHostname(host) + + if isADOHost { + minADOParts := 3 + if githubhost.IsVisualStudioLegacyHostname(host) { + minADOParts = 2 + } + if len(uParts) < minADOParts { + return nil, "", fmt.Errorf("invalid Azure DevOps repository format: %s", repoURL) + } + } else if len(uParts) < 2 { + return nil, "", fmt.Errorf("invalid repository format: %s", repoURL) + } + + if err := pathsecurity.ValidatePathSegments(strings.Join(uParts, "/"), "repository path", 
false, false); err != nil { + return nil, "", err + } + + allowedPattern := `^[a-zA-Z0-9._-]+$` + if isADOHost { + allowedPattern = `^[a-zA-Z0-9._\- ]+$` + } + allowedRE := regexp.MustCompile(allowedPattern) + for _, part := range uParts { + stripped := strings.TrimSuffix(part, ".git") + if !allowedRE.MatchString(stripped) { + return nil, "", fmt.Errorf("invalid repository path component: %s", part) + } + } + + escapedParts := make([]string, len(uParts)) + for i, p := range uParts { + escapedParts[i] = url.PathEscape(p) + } + rawURL := fmt.Sprintf("https://%s/%s", host, strings.Join(escapedParts, "/")) + parsed, err := url.Parse(rawURL) + if err != nil { + return nil, "", fmt.Errorf("failed to build URL for %s: %w", repoURL, err) + } + return parsed, host, nil +} + +// parseStandardURL handles non-SSH dependency strings. +func parseStandardURL(depStr string, isVirtual bool, virtualPath, validatedHost string) ( + host string, port int, repoURL, reference, alias string, + effectiveIsVirtual bool, effectiveVirtualPath string, err error, +) { + effectiveIsVirtual = isVirtual + effectiveVirtualPath = virtualPath + + repoPart := depStr + if idx := strings.LastIndex(depStr, "#"); idx >= 0 { + repoPart = depStr[:idx] + reference = strings.TrimSpace(depStr[idx+1:]) + } + repoURL = strings.TrimSpace(repoPart) + lower := strings.ToLower(repoURL) + + if isVirtual && !strings.HasPrefix(lower, "https://") && !strings.HasPrefix(lower, "http://") { + host, repoURL = resolveVirtualShorthandRepo(repoURL, validatedHost, virtualPath) + } + + lower = strings.ToLower(repoURL) + var parsedURL *url.URL + if strings.HasPrefix(lower, "https://") || strings.HasPrefix(lower, "http://") { + parsedURL, err = url.Parse(repoURL) + if err != nil { + return + } + host = parsedURL.Hostname() + if p, e := parsePortInt(parsedURL.Port()); e == nil { + port = p + } + scheme := strings.ToLower(parsedURL.Scheme) + if port == defaultSchemePorts[scheme] { + port = 0 + } + } else { + parsedURL, host, err = 
resolveShorthandToParsedURL(repoURL, host) + if err != nil { + return + } + } + + var urlVirtualPath string + repoURL, urlVirtualPath, err = validateURLRepoPath(parsedURL) + if err != nil { + return + } + if urlVirtualPath != "" { + effectiveIsVirtual = true + effectiveVirtualPath = urlVirtualPath + } + if host == "" { + host = githubhost.DefaultHost() + } + return +} + +// validateFinalRepoFields checks the final repo_url and extracts ADO fields. +func validateFinalRepoFields(host, repoURL string) (adoOrg, adoProject, adoRepo string, err error) { + isADO := host != "" && githubhost.IsAzureDevOpsHostname(host) + if isADO { + if !adoRepoRE.MatchString(repoURL) { + err = fmt.Errorf("invalid Azure DevOps repository format: %s; expected 'org/project/repo'", repoURL) + return + } + parts := strings.SplitN(repoURL, "/", 3) + if err2 := pathsecurity.ValidatePathSegments(repoURL, "Azure DevOps repository path", false, false); err2 != nil { + err = err2 + return + } + adoOrg, adoProject, adoRepo = parts[0], parts[1], parts[2] + return + } + + segments := strings.Split(repoURL, "/") + if len(segments) < 2 { + err = fmt.Errorf("invalid repository format: %s; expected 'user/repo'", repoURL) + return + } + validRE := regexp.MustCompile(`^[a-zA-Z0-9._-]+$`) + for _, s := range segments { + if !validRE.MatchString(s) { + err = fmt.Errorf("invalid repository format: %s; contains invalid characters", repoURL) + return + } + for _, ext := range virtualFileExtensions { + if strings.HasSuffix(s, ext) { + err = fmt.Errorf("invalid repository format: %q contains a virtual file extension", repoURL) + return + } + } + } + if e := pathsecurity.ValidatePathSegments(repoURL, "repository path", false, false); e != nil { + err = e + } + return +} + +// extractArtifactoryPrefix extracts the Artifactory VCS prefix from the original dep string. 
+func extractArtifactoryPrefix(depStr, host string) string { + s := depStr + if idx := strings.Index(s, "#"); idx >= 0 { + s = s[:idx] + } + if idx := strings.Index(s, "@"); idx >= 0 { + s = s[:idx] + } + if strings.Contains(s, "://") { + s = strings.SplitN(s, "://", 2)[1] + } + s = strings.Replace(s, host+"/", "", 1) + segs := filterEmpty(strings.Split(s, "/")) + if githubhost.IsArtifactoryPath(segs) { + prefix, _, _ := githubhost.ParseArtifactoryPath(segs) + return prefix + } + return "" +} + +// Parse parses a dependency string into a DependencyReference. +// +// Supports all forms: shorthand (user/repo), FQDN, HTTPS, SSH, SCP, local paths. +func Parse(depStr string) (*DependencyReference, error) { + if strings.TrimSpace(depStr) == "" { + return nil, fmt.Errorf("empty dependency string") + } + + depStr, err := url.PathUnescape(depStr) + if err != nil { + depStr = depStr // keep original on error + } + + for _, r := range depStr { + if r < 32 && !unicode.IsSpace(r) { + return nil, fmt.Errorf("dependency string contains invalid control characters") + } + } + + // Local path detection (must run before URL/host parsing) + if IsLocalPath(depStr) { + local := strings.TrimSpace(depStr) + base := filepath.Base(local) + if base == "" || base == "." || base == ".." 
{ + return nil, fmt.Errorf("local path %q does not resolve to a named directory", local) + } + return &DependencyReference{ + RepoURL: "_local/" + base, + IsLocal: true, + LocalPath: local, + }, nil + } + + if strings.HasPrefix(depStr, "//") { + return nil, fmt.Errorf("protocol-relative URLs are not supported") + } + + // Phase 1: detect virtual packages + isVirtual, virtualPath, validatedHost, err := detectVirtualPackage(depStr) + if err != nil { + return nil, err + } + + // Phase 2: SSH parsing + var ( + host string + port int + repoURL string + reference string + alias string + explicitScheme string + ) + + if h, p, r, ref, al, ok := parseSSHProtocolURL(depStr); ok { + host, port, repoURL, reference, alias = h, p, r, ref, al + explicitScheme = "ssh" + } else if h, p, r, ref, al, ok2 := parseSCPURL(depStr); ok2 { + host, port, repoURL, reference, alias = h, p, r, ref, al + explicitScheme = "ssh" + } else { + var effectiveIsVirtual bool + var effectiveVirtualPath string + host, port, repoURL, reference, alias, effectiveIsVirtual, effectiveVirtualPath, err = + parseStandardURL(depStr, isVirtual, virtualPath, validatedHost) + if err != nil { + return nil, err + } + isVirtual = effectiveIsVirtual + virtualPath = effectiveVirtualPath + lower := strings.ToLower(strings.TrimSpace(depStr)) + if strings.HasPrefix(lower, "https://") { + explicitScheme = "https" + } else if strings.HasPrefix(lower, "http://") { + explicitScheme = "http" + } + } + + // Phase 3: validate final fields + adoOrg, adoProject, adoRepo, err := validateFinalRepoFields(host, repoURL) + if err != nil { + return nil, err + } + + if alias != "" && !aliasRE.MatchString(alias) { + return nil, fmt.Errorf("invalid alias: %s; aliases can only contain letters, numbers, dots, underscores, and hyphens", alias) + } + + isADO := host != "" && githubhost.IsAzureDevOpsHostname(host) + var artifactoryPrefix string + if host != "" && !isADO { + artifactoryPrefix = extractArtifactoryPrefix(depStr, host) + } + + 
parsedScheme := "" + if u, e := url.Parse(depStr); e == nil { + parsedScheme = strings.ToLower(u.Scheme) + } + + return &DependencyReference{ + RepoURL: repoURL, + Host: host, + Port: port, + ExplicitScheme: explicitScheme, + Reference: reference, + Alias: alias, + VirtualPath: virtualPath, + IsVirtual: isVirtual, + ADOOrganization: adoOrg, + ADOProject: adoProject, + ADORepo: adoRepo, + ArtifactoryPrefix: artifactoryPrefix, + IsInsecure: parsedScheme == "http", + IsParentRepoInheritance: false, + }, nil +} + +// Canonicalize parses raw and returns its canonical form. +func Canonicalize(raw string) (string, error) { + ref, err := Parse(raw) + if err != nil { + return "", err + } + return ref.ToCanonical(), nil +} + +// ParseFromDict parses a dict-style dependency entry (as in apm.yml). +func ParseFromDict(entry map[string]interface{}) (*DependencyReference, error) { + pathVal, hasPath := entry["path"] + gitVal, hasGit := entry["git"] + + if hasPath && !hasGit { + localStr, ok := pathVal.(string) + if !ok || strings.TrimSpace(localStr) == "" { + return nil, fmt.Errorf("'path' field must be a non-empty string") + } + localStr = strings.TrimSpace(localStr) + if !IsLocalPath(localStr) { + return nil, fmt.Errorf("object-style dependency must have a 'git' field, or 'path' must be a local filesystem path") + } + return Parse(localStr) + } + + if !hasGit { + return nil, fmt.Errorf("object-style dependency must have a 'git' or 'path' field") + } + + gitURL, ok := gitVal.(string) + if !ok || strings.TrimSpace(gitURL) == "" { + return nil, fmt.Errorf("'git' field must be a non-empty string") + } + gitURL = strings.TrimSpace(gitURL) + + // Parent repo inheritance + if gitURL == "parent" { + pathRaw, _ := entry["path"].(string) + if strings.TrimSpace(pathRaw) == "" { + return nil, fmt.Errorf("object-style dependency with git: 'parent' requires a 'path' field") + } + normPath := normalizeParentRepoPath(pathRaw) + if normPath == "" { + return nil, fmt.Errorf("'path' field must be 
a non-empty string") + } + dep := &DependencyReference{ + RepoURL: "_parent", + IsVirtual: true, + IsParentRepoInheritance: true, + VirtualPath: normPath, + } + if refRaw, ok2 := entry["ref"].(string); ok2 && strings.TrimSpace(refRaw) != "" { + dep.Reference = strings.TrimSpace(refRaw) + } + if aliasRaw, ok2 := entry["alias"].(string); ok2 && strings.TrimSpace(aliasRaw) != "" { + a := strings.TrimSpace(aliasRaw) + if !aliasRE.MatchString(a) { + return nil, fmt.Errorf("invalid alias: %s", a) + } + dep.Alias = a + } + return dep, nil + } + + dep, err := Parse(gitURL) + if err != nil { + return nil, err + } + + if allowInsecure, ok2 := entry["allow_insecure"].(bool); ok2 { + dep.AllowInsecure = allowInsecure + } + + if refRaw, ok2 := entry["ref"].(string); ok2 && strings.TrimSpace(refRaw) != "" { + dep.Reference = strings.TrimSpace(refRaw) + } + + if aliasRaw, ok2 := entry["alias"].(string); ok2 && strings.TrimSpace(aliasRaw) != "" { + a := strings.TrimSpace(aliasRaw) + if !aliasRE.MatchString(a) { + return nil, fmt.Errorf("invalid alias: %s", a) + } + dep.Alias = a + } + + if subPath, ok2 := entry["path"].(string); ok2 && strings.TrimSpace(subPath) != "" { + sp := strings.TrimSpace(strings.ReplaceAll(subPath, `\`, "/")) + sp = strings.Trim(sp, "/") + if err2 := pathsecurity.ValidatePathSegments(sp, "path", false, false); err2 != nil { + return nil, err2 + } + dep.VirtualPath = sp + dep.IsVirtual = true + } + + if skillsRaw, ok2 := entry["skills"].([]interface{}); ok2 { + if len(skillsRaw) == 0 { + return nil, fmt.Errorf("skills: must contain at least one name") + } + seen := map[string]bool{} + var validated []string + for _, s := range skillsRaw { + name, ok3 := s.(string) + if !ok3 || strings.TrimSpace(name) == "" { + return nil, fmt.Errorf("each entry in 'skills' must be a non-empty string") + } + name = strings.TrimSpace(name) + if err2 := pathsecurity.ValidatePathSegments(name, "skills/", false, false); err2 != nil { + return nil, err2 + } + if !seen[name] { + 
seen[name] = true + validated = append(validated, name) + } + } + dep.SkillSubset = sortedStrings(validated) + } + + return dep, nil +} + +func normalizeParentRepoPath(raw string) string { + s := strings.TrimSpace(raw) + s = strings.ReplaceAll(s, `\`, "/") + s = strings.Trim(s, "/") + parts := filterEmpty(strings.Split(s, "/")) + if len(parts) == 0 { + return "" + } + return strings.Join(parts, "/") +} + +func sortedStrings(ss []string) []string { + out := make([]string, len(ss)) + copy(out, ss) + // simple insertion sort (skill lists are short) + for i := 1; i < len(out); i++ { + for j := i; j > 0 && out[j] < out[j-1]; j-- { + out[j], out[j-1] = out[j-1], out[j] + } + } + return out +} + +// ToApmYMLEntry returns the value to store in apm.yml. +// Returns a string for simple deps, or a map for HTTP/skill-subset deps. +func (d *DependencyReference) ToApmYMLEntry() interface{} { + if d.IsInsecure { + host := d.effectiveHost() + entry := map[string]interface{}{ + "git": "http://" + host + "/" + d.RepoURL, + } + if d.Reference != "" { + entry["ref"] = d.Reference + } + if d.Alias != "" { + entry["alias"] = d.Alias + } + entry["allow_insecure"] = d.AllowInsecure + if len(d.SkillSubset) > 0 { + entry["skills"] = sortedStrings(d.SkillSubset) + } + return entry + } + if len(d.SkillSubset) > 0 { + entry := map[string]interface{}{ + "git": d.GetIdentity(), + } + if d.Reference != "" { + entry["ref"] = d.Reference + } + if d.Alias != "" { + entry["alias"] = d.Alias + } + entry["skills"] = sortedStrings(d.SkillSubset) + return entry + } + return d.ToCanonical() +} + +// VirtualSuffixIsInstallableShape returns true when virtualPath matches APM virtual package rules. 
+func VirtualSuffixIsInstallableShape(virtualPath string) bool { + if strings.TrimSpace(virtualPath) == "" { + return false + } + v := strings.Trim(strings.TrimSpace(virtualPath), "/") + if err := pathsecurity.ValidatePathSegments(v, "virtual path", false, false); err != nil { + return false + } + if strings.Contains(v, "/collections/") || strings.HasPrefix(v, "collections/") { + return true + } + for _, ext := range virtualFileExtensions { + if strings.HasSuffix(v, ext) { + return true + } + } + last := v + if idx := strings.LastIndex(v, "/"); idx >= 0 { + last = v[idx+1:] + } + return !strings.Contains(last, ".") +} diff --git a/internal/models/validation/validation.go b/internal/models/validation/validation.go new file mode 100644 index 0000000..798fb16 --- /dev/null +++ b/internal/models/validation/validation.go @@ -0,0 +1,580 @@ +// Package validation provides validation logic and type enums for APM packages. +// +// Mirrors src/apm_cli/models/validation.py. +package validation + +import ( + "fmt" + "os" + "path/filepath" + "regexp" + "strings" +) + +// PackageType classifies packages based on their content. +type PackageType int + +const ( + PackageTypeAPMPackage PackageType = iota // Has apm.yml + PackageTypeClaudeSkill // Has SKILL.md, no apm.yml + PackageTypeHookPackage // Has hooks/hooks.json, no apm.yml or SKILL.md + PackageTypeHybrid // Has both apm.yml and SKILL.md (root) + PackageTypeMarketplacePlugin // Has plugin.json or .claude-plugin/ + PackageTypeSkillBundle // Has skills//SKILL.md (nested) + PackageTypeInvalid // None of the above +) + +// String returns a human-readable name for the package type. 
+func (t PackageType) String() string { + switch t { + case PackageTypeAPMPackage: + return "apm_package" + case PackageTypeClaudeSkill: + return "claude_skill" + case PackageTypeHookPackage: + return "hook_package" + case PackageTypeHybrid: + return "hybrid" + case PackageTypeMarketplacePlugin: + return "marketplace_plugin" + case PackageTypeSkillBundle: + return "skill_bundle" + default: + return "invalid" + } +} + +// PackageContentType is the user-facing type field in apm.yml. +type PackageContentType int + +const ( + PackageContentTypeInstructions PackageContentType = iota // Compile to AGENTS.md only + PackageContentTypeSkill // Install as native skill only + PackageContentTypeHybrid // Both (default) + PackageContentTypePrompts // Commands/prompts only +) + +// String returns the string value of the content type. +func (t PackageContentType) String() string { + switch t { + case PackageContentTypeInstructions: + return "instructions" + case PackageContentTypeSkill: + return "skill" + case PackageContentTypeHybrid: + return "hybrid" + case PackageContentTypePrompts: + return "prompts" + default: + return "hybrid" + } +} + +// PackageContentTypeFromString parses a string into a PackageContentType. +func PackageContentTypeFromString(value string) (PackageContentType, error) { + if value == "" { + return 0, fmt.Errorf("package type cannot be empty") + } + v := strings.ToLower(strings.TrimSpace(value)) + switch v { + case "instructions": + return PackageContentTypeInstructions, nil + case "skill": + return PackageContentTypeSkill, nil + case "hybrid": + return PackageContentTypeHybrid, nil + case "prompts": + return PackageContentTypePrompts, nil + default: + return 0, fmt.Errorf("invalid package type '%s'. Valid types are: 'instructions', 'skill', 'hybrid', 'prompts'", value) + } +} + +// ValidationError enumerates types of validation errors for APM packages. 
+type ValidationError int + +const ( + ValidationErrorMissingAPMYml ValidationError = iota + ValidationErrorMissingAPMDir + ValidationErrorInvalidYmlFormat + ValidationErrorMissingRequiredField + ValidationErrorInvalidVersionFormat + ValidationErrorInvalidDependencyFormat + ValidationErrorEmptyAPMDir + ValidationErrorInvalidPrimitiveStructure +) + +// ValidationResult holds the result of APM package validation. +type ValidationResult struct { + IsValid bool + Errors []string + Warnings []string + PackageType PackageType +} + +// NewValidationResult creates an empty (valid) ValidationResult. +func NewValidationResult() *ValidationResult { + return &ValidationResult{IsValid: true} +} + +// AddError adds a validation error and marks the result as invalid. +func (r *ValidationResult) AddError(err string) { + r.Errors = append(r.Errors, err) + r.IsValid = false +} + +// AddWarning adds a validation warning. +func (r *ValidationResult) AddWarning(warning string) { + r.Warnings = append(r.Warnings, warning) +} + +// HasIssues returns true if there are any errors or warnings. +func (r *ValidationResult) HasIssues() bool { + return len(r.Errors) > 0 || len(r.Warnings) > 0 +} + +// Summary returns a human-readable summary of validation results. +func (r *ValidationResult) Summary() string { + if r.IsValid && len(r.Warnings) == 0 { + return "[+] Package is valid" + } else if r.IsValid && len(r.Warnings) > 0 { + return fmt.Sprintf("[!] Package is valid with %d warning(s)", len(r.Warnings)) + } + return fmt.Sprintf("[x] Package is invalid with %d error(s)", len(r.Errors)) +} + +// pluginDirs defines the canonical order of plugin content directories. +var pluginDirs = []string{"agents", "skills", "commands"} + +// DetectionEvidence is a snapshot of file-system signals for package classification. 
type DetectionEvidence struct {
	HasAPMYml          bool
	HasSkillMD         bool
	HasHookJSON        bool
	PluginJSONPath     string // empty if not found
	PluginDirsPresent  []string
	HasClaudePluginDir bool
	NestedSkillDirs    []string
	HasPluginManifest  bool
}

// HasPluginEvidence returns true if a real plugin manifest is present.
func (e *DetectionEvidence) HasPluginEvidence() bool {
	return e.HasPluginManifest
}

// hasHookJSON reports whether hooks/ or .apm/hooks/ contains at least one
// regular *.json file. Unreadable or absent directories contribute nothing.
func hasHookJSON(packagePath string) bool {
	candidates := []string{
		filepath.Join(packagePath, "hooks"),
		filepath.Join(packagePath, ".apm", "hooks"),
	}
	for _, candidate := range candidates {
		children, err := os.ReadDir(candidate)
		if err != nil {
			continue
		}
		for _, child := range children {
			if child.IsDir() {
				continue
			}
			if strings.HasSuffix(child.Name(), ".json") {
				return true
			}
		}
	}
	return false
}

// findPluginJSON returns the path of plugin.json in the package root, or "" when absent.
func findPluginJSON(packagePath string) string {
	candidate := filepath.Join(packagePath, "plugin.json")
	if _, err := os.Stat(candidate); err != nil {
		return ""
	}
	return candidate
}

// GatherDetectionEvidence collects all package-type signals from a directory.
+func GatherDetectionEvidence(packagePath string) *DetectionEvidence { + ev := &DetectionEvidence{} + + // Check apm.yml + if _, err := os.Stat(filepath.Join(packagePath, "apm.yml")); err == nil { + ev.HasAPMYml = true + } + + // Check SKILL.md + if _, err := os.Stat(filepath.Join(packagePath, "SKILL.md")); err == nil { + ev.HasSkillMD = true + } + + // Check hook JSON + ev.HasHookJSON = hasHookJSON(packagePath) + + // Check plugin dirs + for _, dir := range pluginDirs { + if info, err := os.Stat(filepath.Join(packagePath, dir)); err == nil && info.IsDir() { + ev.PluginDirsPresent = append(ev.PluginDirsPresent, dir) + } + } + + // Check plugin.json + ev.PluginJSONPath = findPluginJSON(packagePath) + + // Check .claude-plugin/ + if info, err := os.Stat(filepath.Join(packagePath, ".claude-plugin")); err == nil && info.IsDir() { + ev.HasClaudePluginDir = true + } + + // Plugin manifest = plugin.json OR .claude-plugin/ + ev.HasPluginManifest = ev.PluginJSONPath != "" || ev.HasClaudePluginDir + + // Nested skill dirs: directories under skills/ that contain a SKILL.md + skillsDir := filepath.Join(packagePath, "skills") + if entries, err := os.ReadDir(skillsDir); err == nil { + for _, entry := range entries { + if !entry.IsDir() { + continue + } + skillMD := filepath.Join(skillsDir, entry.Name(), "SKILL.md") + if _, err := os.Stat(skillMD); err == nil { + ev.NestedSkillDirs = append(ev.NestedSkillDirs, entry.Name()) + } + } + } + + return ev +} + +// DetectPackageType classifies a package directory into a PackageType. +// Returns (packageType, pluginJSONPath). pluginJSONPath is non-empty only +// when MARKETPLACE_PLUGIN was matched via an actual plugin.json file. +func DetectPackageType(packagePath string) (PackageType, string) { + ev := GatherDetectionEvidence(packagePath) + + // 1. Plugin manifest present -> MARKETPLACE_PLUGIN + if ev.HasPluginManifest { + return PackageTypeMarketplacePlugin, ev.PluginJSONPath + } + + // 2. 
Root SKILL.md + apm.yml -> HYBRID + if ev.HasAPMYml && ev.HasSkillMD { + return PackageTypeHybrid, "" + } + + // 3. Root SKILL.md only -> CLAUDE_SKILL + if ev.HasSkillMD { + return PackageTypeClaudeSkill, "" + } + + // 4. Nested skills//SKILL.md -> SKILL_BUNDLE + if len(ev.NestedSkillDirs) > 0 { + return PackageTypeSkillBundle, "" + } + + // 5. apm.yml present -> APM_PACKAGE or INVALID + if ev.HasAPMYml { + apmDir := filepath.Join(packagePath, ".apm") + if info, err := os.Stat(apmDir); err == nil && info.IsDir() { + return PackageTypeAPMPackage, "" + } + if apmYMLDeclaresDependencies(filepath.Join(packagePath, "apm.yml")) { + return PackageTypeAPMPackage, "" + } + return PackageTypeInvalid, "" + } + + // 6. hooks/*.json only -> HOOK_PACKAGE + if ev.HasHookJSON { + return PackageTypeHookPackage, "" + } + + // 7. Nothing recognisable -> INVALID + return PackageTypeInvalid, "" +} + +// apmYMLDeclaresDependencies returns true iff apm.yml declares at least one dependency. +func apmYMLDeclaresDependencies(apmYMLPath string) bool { + data, err := os.ReadFile(apmYMLPath) + if err != nil { + return false + } + // Simple heuristic: look for "apm:" or "mcp:" under dependencies/devDependencies + // with at least one list item. A full YAML parse is not available without external libs. + content := string(data) + // Look for a non-empty apm: or mcp: list under dependencies or devDependencies + depSection := extractYAMLSection(content, "dependencies") + devSection := extractYAMLSection(content, "devDependencies") + return hasListedDeps(depSection) || hasListedDeps(devSection) +} + +// extractYAMLSection extracts a named top-level section from simple YAML. 
// extractYAMLSection extracts a named top-level section from simple YAML,
// returning the section header line plus its indented body.
func extractYAMLSection(content, key string) string {
	var (
		collected []string
		inside    bool
	)
	header := key + ":"
	for _, raw := range strings.Split(content, "\n") {
		stripped := strings.TrimSpace(raw)
		if stripped == header || strings.HasPrefix(stripped, header+" ") {
			inside = true
			collected = append(collected, raw)
			continue
		}
		if !inside {
			continue
		}
		// A new top-level key (non-indented, non-comment, non-blank) ends the section.
		if len(raw) > 0 && raw[0] != ' ' && raw[0] != '\t' && raw[0] != '#' && stripped != "" {
			break
		}
		collected = append(collected, raw)
	}
	return strings.Join(collected, "\n")
}

// hasListedDeps checks if the section has apm: or mcp: lists with entries.
func hasListedDeps(section string) bool {
	tracking := false
	for _, raw := range strings.Split(section, "\n") {
		stripped := strings.TrimSpace(raw)
		switch {
		case stripped == "apm:" || stripped == "mcp:":
			tracking = true
		case tracking && strings.HasPrefix(stripped, "- "):
			return true
		case tracking && stripped != "" && !strings.HasPrefix(stripped, "#"):
			// Any other non-blank, non-comment line leaves the list.
			tracking = false
		}
	}
	return false
}

// semverRe matches a semantic version string (x.y.z).
var semverRe = regexp.MustCompile(`^\d+\.\d+\.\d+`)

// ValidateAPMPackage validates that a directory contains a valid APM package.
+func ValidateAPMPackage(packagePath string) *ValidationResult { + result := NewValidationResult() + + // Check if directory exists + info, err := os.Stat(packagePath) + if err != nil { + result.AddError(fmt.Sprintf("Package directory does not exist: %s", packagePath)) + return result + } + if !info.IsDir() { + result.AddError(fmt.Sprintf("Package path is not a directory: %s", packagePath)) + return result + } + + // Detect package type + pkgType, pluginJSONPath := DetectPackageType(packagePath) + result.PackageType = pkgType + + if pkgType == PackageTypeInvalid { + apmYMLPath := filepath.Join(packagePath, "apm.yml") + if _, err := os.Stat(apmYMLPath); err == nil { + apmPath := filepath.Join(packagePath, ".apm") + if apmInfo, err := os.Stat(apmPath); err == nil && !apmInfo.IsDir() { + result.AddError(".apm must be a directory") + } else { + dirName := filepath.Base(packagePath) + result.AddError(fmt.Sprintf( + "Not a valid APM package: %s has apm.yml but is missing the required .apm/ directory. "+ + "Add .apm/ with primitives (instructions, skills, etc.), "+ + "declare dependencies in apm.yml (curated aggregator), "+ + "or add skills//SKILL.md for a skill bundle.", dirName)) + } + } else { + dirName := filepath.Base(packagePath) + result.AddError(fmt.Sprintf( + "Not a valid APM package: no apm.yml, SKILL.md, hooks, or plugin structure found in %s. 
"+ + "Ensure the package has SKILL.md (skill bundle), "+ + "apm.yml + .apm/ (APM package), or plugin.json (Claude plugin) at its root.", dirName)) + } + return result + } + + switch pkgType { + case PackageTypeHookPackage: + return validateHookPackage(packagePath, result) + case PackageTypeClaudeSkill: + return validateClaudeSkill(packagePath, result) + case PackageTypeMarketplacePlugin: + return validateMarketplacePlugin(packagePath, pluginJSONPath, result) + case PackageTypeSkillBundle: + return validateSkillBundle(packagePath, result) + case PackageTypeHybrid: + return validateHybridPackage(packagePath, result) + default: + return validateAPMPackageWithYML(packagePath, result) + } +} + +func validateHookPackage(packagePath string, result *ValidationResult) *ValidationResult { + // Hook package is valid as-is -- just has hooks/*.json + return result +} + +func validateClaudeSkill(packagePath string, result *ValidationResult) *ValidationResult { + // Check SKILL.md is readable + skillMD := filepath.Join(packagePath, "SKILL.md") + if _, err := os.ReadFile(skillMD); err != nil { + result.AddError(fmt.Sprintf("Failed to read SKILL.md: %v", err)) + } + return result +} + +func validateMarketplacePlugin(packagePath, pluginJSONPath string, result *ValidationResult) *ValidationResult { + // Check plugin.json or .claude-plugin/ is present and readable + if pluginJSONPath != "" { + if _, err := os.ReadFile(pluginJSONPath); err != nil { + result.AddError(fmt.Sprintf("Failed to read plugin.json: %v", err)) + } + } + return result +} + +func validateSkillBundle(packagePath string, result *ValidationResult) *ValidationResult { + skillsDir := filepath.Join(packagePath, "skills") + entries, err := os.ReadDir(skillsDir) + if err != nil { + result.AddError(fmt.Sprintf("SKILL_BUNDLE detected but could not read skills/ directory: %v", err)) + return result + } + + var skillNames []string + for _, entry := range entries { + if !entry.IsDir() { + continue + } + name := entry.Name() + 
skillMD := filepath.Join(skillsDir, name, "SKILL.md") + if _, err := os.Stat(skillMD); err != nil { + continue + } + + // Path safety: reject traversal + if strings.Contains(name, "..") || strings.Contains(name, "/") { + result.AddError(fmt.Sprintf("Invalid skill directory name: %s", name)) + continue + } + + skillNames = append(skillNames, name) + } + + if len(skillNames) == 0 { + result.AddError(fmt.Sprintf("SKILL_BUNDLE detected but no valid skills//SKILL.md found in %s/skills/", filepath.Base(packagePath))) + return result + } + + return result +} + +func validateHybridPackage(packagePath string, result *ValidationResult) *ValidationResult { + apmDir := filepath.Join(packagePath, ".apm") + if info, err := os.Stat(apmDir); err == nil && info.IsDir() { + return validateAPMPackageWithYML(packagePath, result) + } + + // Skill-bundle path (no .apm/) + apmYMLPath := filepath.Join(packagePath, "apm.yml") + if _, err := os.Stat(apmYMLPath); err != nil { + result.AddError("HYBRID package missing apm.yml") + return result + } + + // Check SKILL.md is present + skillMD := filepath.Join(packagePath, "SKILL.md") + if _, err := os.Stat(skillMD); err != nil { + result.AddError("HYBRID package missing SKILL.md") + return result + } + + return result +} + +func validateAPMPackageWithYML(packagePath string, result *ValidationResult) *ValidationResult { + apmYMLPath := filepath.Join(packagePath, "apm.yml") + + // Parse apm.yml basic fields + data, err := os.ReadFile(apmYMLPath) + if err != nil { + result.AddError(fmt.Sprintf("Invalid apm.yml: %v", err)) + return result + } + + // Check for .apm directory + apmDir := filepath.Join(packagePath, ".apm") + apmDirInfo, apmDirErr := os.Stat(apmDir) + if apmDirErr != nil { + // No .apm/ -- check if dep-only (curated aggregator) + if apmYMLDeclaresDependencies(apmYMLPath) { + return result + } + result.AddError(fmt.Sprintf("Missing required directory: .apm/ -- "+ + "an APM package with apm.yml needs either a .apm/ directory "+ + 
"containing primitives, or dependencies declared in apm.yml. "+ + "Alternatively, add a SKILL.md to make this a skill bundle.")) + return result + } + + if !apmDirInfo.IsDir() { + result.AddError(".apm must be a directory") + return result + } + + // Check for primitives in .apm/ + primitiveTypes := []string{"instructions", "chatmodes", "contexts", "prompts"} + hasPrimitives := false + for _, pt := range primitiveTypes { + ptDir := filepath.Join(apmDir, pt) + entries, err := os.ReadDir(ptDir) + if err != nil { + continue + } + for _, e := range entries { + if !e.IsDir() && strings.HasSuffix(e.Name(), ".md") { + hasPrimitives = true + // Check for empty files + content, err := os.ReadFile(filepath.Join(ptDir, e.Name())) + if err == nil && strings.TrimSpace(string(content)) == "" { + result.AddWarning(fmt.Sprintf("Empty primitive file: .apm/%s/%s", pt, e.Name())) + } + } + } + } + + if !hasPrimitives { + hasPrimitives = hasHookJSON(packagePath) + } + + if !hasPrimitives { + result.AddWarning("No primitive files found in .apm/ directory") + } + + // Version format validation (basic semver check) + // Extract version from apm.yml content + version := extractYAMLField(string(data), "version") + if version != "" && !semverRe.MatchString(version) { + result.AddWarning(fmt.Sprintf("Version '%s' doesn't follow semantic versioning (x.y.z)", version)) + } + + return result +} + +// extractYAMLField extracts a simple scalar field value from YAML content. 
+func extractYAMLField(content, key string) string { + prefix := key + ":" + for _, line := range strings.Split(content, "\n") { + trimmed := strings.TrimSpace(line) + if strings.HasPrefix(trimmed, prefix) { + val := strings.TrimSpace(trimmed[len(prefix):]) + // Strip quotes + if len(val) >= 2 && (val[0] == '"' || val[0] == '\'') { + val = val[1 : len(val)-1] + } + return val + } + } + return "" +} diff --git a/internal/models/validation/validation_test.go b/internal/models/validation/validation_test.go new file mode 100644 index 0000000..d80732a --- /dev/null +++ b/internal/models/validation/validation_test.go @@ -0,0 +1,111 @@ +package validation_test + +import ( + "os" + "path/filepath" + "testing" + + "github.com/githubnext/apm/internal/models/validation" +) + +func TestPackageTypeString(t *testing.T) { + cases := []struct { + t validation.PackageType + want string + }{ + {validation.PackageTypeAPMPackage, "apm_package"}, + {validation.PackageTypeClaudeSkill, "claude_skill"}, + {validation.PackageTypeHookPackage, "hook_package"}, + {validation.PackageTypeHybrid, "hybrid"}, + {validation.PackageTypeMarketplacePlugin, "marketplace_plugin"}, + {validation.PackageTypeSkillBundle, "skill_bundle"}, + {validation.PackageTypeInvalid, "invalid"}, + } + for _, c := range cases { + if got := c.t.String(); got != c.want { + t.Errorf("PackageType(%d).String() = %q; want %q", c.t, got, c.want) + } + } +} + +func TestPackageContentTypeFromString(t *testing.T) { + cases := []struct { + input string + want validation.PackageContentType + wantErr bool + }{ + {"instructions", validation.PackageContentTypeInstructions, false}, + {"skill", validation.PackageContentTypeSkill, false}, + {"hybrid", validation.PackageContentTypeHybrid, false}, + {"prompts", validation.PackageContentTypePrompts, false}, + {"HYBRID", validation.PackageContentTypeHybrid, false}, + {"", 0, true}, + {"unknown", 0, true}, + } + for _, c := range cases { + got, err := 
validation.PackageContentTypeFromString(c.input) + if c.wantErr { + if err == nil { + t.Errorf("PackageContentTypeFromString(%q) expected error", c.input) + } + continue + } + if err != nil { + t.Errorf("PackageContentTypeFromString(%q) unexpected error: %v", c.input, err) + continue + } + if got != c.want { + t.Errorf("PackageContentTypeFromString(%q) = %v; want %v", c.input, got, c.want) + } + } +} + +func TestValidationResult(t *testing.T) { + r := validation.NewValidationResult() + if !r.IsValid { + t.Error("new result should be valid") + } + r.AddWarning("test warning") + if !r.IsValid { + t.Error("warning should not make invalid") + } + if !r.HasIssues() { + t.Error("has issues after warning") + } + r.AddError("test error") + if r.IsValid { + t.Error("should be invalid after error") + } + summary := r.Summary() + if summary == "" { + t.Error("summary should not be empty") + } +} + +func TestDetectPackageTypeInvalid(t *testing.T) { + dir := t.TempDir() + pt, _ := validation.DetectPackageType(dir) + if pt != validation.PackageTypeInvalid { + t.Errorf("empty dir: got %v; want invalid", pt) + } +} + +func TestDetectPackageTypeClaudeSkill(t *testing.T) { + dir := t.TempDir() + os.WriteFile(filepath.Join(dir, "SKILL.md"), []byte("---\nname: test\n---\n# Test"), 0o644) + pt, _ := validation.DetectPackageType(dir) + if pt != validation.PackageTypeClaudeSkill { + t.Errorf("skill dir: got %v; want claude_skill", pt) + } +} + +func TestDetectPackageTypeHookPackage(t *testing.T) { + dir := t.TempDir() + hooksDir := filepath.Join(dir, "hooks") + os.MkdirAll(hooksDir, 0o755) + os.WriteFile(filepath.Join(hooksDir, "hooks.json"), []byte("{}"), 0o644) + pt, _ := validation.DetectPackageType(dir) + if pt != validation.PackageTypeHookPackage { + t.Errorf("hooks dir: got %v; want hook_package", pt) + } +} diff --git a/internal/output/compilationformatter/compilationformatter.go b/internal/output/compilationformatter/compilationformatter.go new file mode 100644 index 
0000000..56abd86 --- /dev/null +++ b/internal/output/compilationformatter/compilationformatter.go @@ -0,0 +1,521 @@ +// Package compilationformatter formats compilation output for APM. +package compilationformatter + +import ( + "fmt" + "path/filepath" + "strings" +) + +// PlacementStrategy describes the optimization strategy used. +type PlacementStrategy string + +const ( + StrategySinglePoint PlacementStrategy = "Single Point" + StrategySelectiveMulti PlacementStrategy = "Selective Multi" + StrategyDistributed PlacementStrategy = "Distributed" +) + +// ProjectAnalysis holds analysis of the project structure. +type ProjectAnalysis struct { + DirectoriesScanned int + FilesAnalyzed int + FileTypesDetected []string + InstructionPatternsDetected int + MaxDepth int + ConstitutionDetected bool + ConstitutionPath string +} + +// FileTypesSummary returns a concise summary of detected file types. +func (p *ProjectAnalysis) FileTypesSummary() string { + if len(p.FileTypesDetected) == 0 { + return "none" + } + types := make([]string, 0, len(p.FileTypesDetected)) + for _, t := range p.FileTypesDetected { + types = append(types, strings.TrimPrefix(t, ".")) + } + if len(types) <= 3 { + return strings.Join(types, ", ") + } + return fmt.Sprintf("%s and %d more", strings.Join(types[:3], ", "), len(types)-3) +} + +// OptimizationDecision holds details about a placement decision for one instruction. +type OptimizationDecision struct { + Pattern string + InstructionFilePath string // file_path.name equivalent + MatchingDirectories int + TotalDirectories int + DistributionScore float64 + Strategy PlacementStrategy + PlacementDirectories []string + Reasoning string + RelevanceScore float64 +} + +// PlacementSummary summarises a single AGENTS.md file placement. +type PlacementSummary struct { + Path string + InstructionCount int + SourceCount int + Sources []string +} + +// RelativePath returns path relative to base, prefixed with "./" when at root. 
+func (s *PlacementSummary) RelativePath(base string) string { + rel, err := filepath.Rel(base, s.Path) + if err != nil { + return s.Path + } + if rel == "." { + return "." + } + return rel +} + +// OptimizationStats holds efficiency statistics. +type OptimizationStats struct { + AverageContextEfficiency float64 + PollutionImprovement *float64 + BaselineEfficiency *float64 + PlacementAccuracy *float64 + GenerationTimeMs *int + TotalAgentsFiles int + DirectoriesAnalyzed int +} + +// EfficiencyPercentage returns efficiency as a percentage. +func (s *OptimizationStats) EfficiencyPercentage() float64 { + return s.AverageContextEfficiency * 100 +} + +// EfficiencyImprovement returns efficiency improvement over baseline, if available. +func (s *OptimizationStats) EfficiencyImprovement() *float64 { + if s.BaselineEfficiency == nil { + return nil + } + v := (s.AverageContextEfficiency - *s.BaselineEfficiency) * 100 + return &v +} + +// CompilationResults holds all results from a compilation run. +type CompilationResults struct { + TargetName string + PlacementSummaries []PlacementSummary + OptimizationDecisions []OptimizationDecision + ProjectAnalysis *ProjectAnalysis + OptimizationStats OptimizationStats + Warnings []string + Errors []string + IsDryRun bool +} + +// HasIssues returns true if there are any warnings or errors. +func (r *CompilationResults) HasIssues() bool { + return len(r.Warnings) > 0 || len(r.Errors) > 0 +} + +// CompilationFormatter formats compilation output for the CLI. +type CompilationFormatter struct { + UseColor bool + targetName string +} + +// New creates a new CompilationFormatter. +func New(useColor bool) *CompilationFormatter { + return &CompilationFormatter{UseColor: useColor, targetName: "AGENTS.md"} +} + +// FormatDefault formats standard compilation output. 
+func (f *CompilationFormatter) FormatDefault(results *CompilationResults) string { + f.targetName = results.TargetName + var lines []string + + lines = append(lines, f.formatProjectDiscovery(results.ProjectAnalysis)...) + lines = append(lines, "") + lines = append(lines, f.formatOptimizationProgress(results.OptimizationDecisions, results.ProjectAnalysis)...) + lines = append(lines, "") + lines = append(lines, f.formatResultsSummary(results)...) + + if results.HasIssues() { + lines = append(lines, "") + lines = append(lines, f.formatIssues(results.Warnings, results.Errors)...) + } + + return strings.Join(lines, "\n") +} + +// FormatVerbose formats verbose compilation output with mathematical details. +func (f *CompilationFormatter) FormatVerbose(results *CompilationResults) string { + f.targetName = results.TargetName + var lines []string + + lines = append(lines, f.formatProjectDiscovery(results.ProjectAnalysis)...) + lines = append(lines, "") + lines = append(lines, f.formatOptimizationProgress(results.OptimizationDecisions, results.ProjectAnalysis)...) + lines = append(lines, "") + lines = append(lines, f.formatMathematicalAnalysis(results.OptimizationDecisions)...) + lines = append(lines, "") + lines = append(lines, f.formatCoverageExplanation(results.OptimizationStats)...) + lines = append(lines, "") + lines = append(lines, f.formatDetailedMetrics(results.OptimizationStats)...) + lines = append(lines, "") + lines = append(lines, f.formatFinalSummary(results)...) + + if results.HasIssues() { + lines = append(lines, "") + lines = append(lines, f.formatIssues(results.Warnings, results.Errors)...) + } + + return strings.Join(lines, "\n") +} + +// FormatDryRun formats dry-run output. +func (f *CompilationFormatter) FormatDryRun(results *CompilationResults) string { + f.targetName = results.TargetName + var lines []string + + lines = append(lines, f.formatProjectDiscovery(results.ProjectAnalysis)...) 
// formatOptimizationProgress renders one line per placement decision under an
// "Optimizing placements..." header. Columns: instruction pattern, source file
// name, matching-directory ratio, and either the single placement path with a
// relevance percentage or a "<n> locations" summary for multi-path placements.
func (f *CompilationFormatter) formatOptimizationProgress(decisions []OptimizationDecision, analysis *ProjectAnalysis) []string {
	lines := []string{"Optimizing placements..."}

	// A detected constitution is always reported as placed at the repo root
	// with full relevance.
	// NOTE(review): placement is hard-coded to "./AGENTS.md" rather than using
	// f.targetName -- confirm this is intended for non-default targets.
	if analysis != nil && analysis.ConstitutionDetected {
		lines = append(lines,
			fmt.Sprintf("%-25s %-15s %-10s -> %-25s (rel: 100%%)", "**", "constitution.md", "ALL", "./AGENTS.md"),
		)
	}

	for _, d := range decisions {
		// An empty pattern means the instruction applies project-wide.
		pattern := d.Pattern
		if pattern == "" {
			pattern = "(global)"
		}

		source := "unknown"
		if d.InstructionFilePath != "" {
			source = d.InstructionFilePath
		}

		ratio := fmt.Sprintf("%d/%d dirs", d.MatchingDirectories, d.TotalDirectories)

		if len(d.PlacementDirectories) == 1 {
			placement := f.getRelativeDisplayPath(d.PlacementDirectories[0])
			// A zero relevance score is treated as "not computed" and shown as 100%.
			relevance := d.RelevanceScore
			if relevance == 0 {
				relevance = 1.0
			}
			line := fmt.Sprintf("%-25s %-15s %-10s -> %-25s (rel: %.0f%%)",
				pattern, source, ratio, placement, relevance*100)
			lines = append(lines, line)
		} else {
			// Multiple placement directories: summarize the count instead of
			// listing every path.
			line := fmt.Sprintf("%-25s %-15s %-10s -> %d locations",
				pattern, source, ratio, len(d.PlacementDirectories))
			lines = append(lines, line)
		}
	}
	return lines
}
// formatFinalSummary renders the closing summary for verbose output.
// Verbose mode intentionally reuses the standard results-summary layout
// (file count, metrics, placement distribution) rather than a bespoke one.
func (f *CompilationFormatter) formatFinalSummary(results *CompilationResults) []string {
	// In verbose mode use same structure as results summary with placement distribution.
	return f.formatResultsSummary(results)
}
// formatMathematicalAnalysis renders the verbose-mode optimization analysis:
// one row per decision (pattern, distribution score, strategy, coverage
// verdict) followed by a static "Mathematical Foundation" explainer.
func (f *CompilationFormatter) formatMathematicalAnalysis(decisions []OptimizationDecision) []string {
	lines := []string{"Mathematical Optimization Analysis", ""}
	lines = append(lines, "Coverage-First Strategy Analysis:")

	for _, d := range decisions {
		// An empty pattern means the instruction applies project-wide.
		pattern := d.Pattern
		if pattern == "" {
			pattern = "(global)"
		}
		score := fmt.Sprintf("%.3f", d.DistributionScore)
		strategy := string(d.Strategy)
		// Scores below 0.7 are treated as verified coverage; higher scores fell
		// back to root placement. NOTE(review): threshold 0.7 is a magic number
		// here -- confirm it matches the optimizer's own cutoff.
		var coverage string
		if d.DistributionScore < 0.7 {
			coverage = "[+] Verified"
		} else {
			coverage = "[!] Root Fallback"
		}
		lines = append(lines, fmt.Sprintf(" %-30s %-8s %-15s %s", pattern, score, strategy, coverage))
	}

	// Static explainer text appended after the per-decision table.
	// NOTE(review): "for_allfile_matching_pattern" reads like an ASCII-ized
	// universal quantifier (forall file matching pattern) -- confirm wording.
	lines = append(lines, "",
		"Mathematical Foundation:",
		" Objective: minimize sum(context_pollution x directory_weight)",
		" Constraints: for_allfile_matching_pattern -> can_inherit_instruction",
		" Algorithm: Three-tier strategy with coverage verification",
		" Principle: Coverage guarantee takes priority over efficiency",
	)
	return lines
}
// assessEfficiency maps a context-efficiency percentage to a qualitative
// label following the guide: 80-100 Excellent, 60-80 Good, 40-60 Fair,
// 20-40 Poor, below 20 Very Poor.
func assessEfficiency(pct float64) string {
	if pct >= 80 {
		return "Excellent"
	}
	if pct >= 60 {
		return "Good"
	}
	if pct >= 40 {
		return "Fair"
	}
	if pct >= 20 {
		return "Poor"
	}
	return "Very Poor"
}

// assessPollution maps a pollution percentage to a qualitative label:
// at most 10 Excellent, at most 25 Good, at most 50 Fair, otherwise Poor.
func assessPollution(pct float64) string {
	if pct <= 10 {
		return "Excellent"
	}
	if pct <= 25 {
		return "Good"
	}
	if pct <= 50 {
		return "Fair"
	}
	return "Poor"
}
lines = append(lines, "x Error: "+e) + } + for _, w := range warnings { + if strings.Contains(w, "\n") { + wLines := strings.Split(w, "\n") + lines = append(lines, "[!] Warning: "+wLines[0]) + for _, wl := range wLines[1:] { + if strings.TrimSpace(wl) != "" { + lines = append(lines, " "+wl) + } + } + } else { + lines = append(lines, "[!] Warning: "+w) + } + } + return lines +} + +func (f *CompilationFormatter) getRelativeDisplayPath(path string) string { + rel, err := filepath.Rel(".", path) + if err != nil { + return filepath.Join(path, f.targetName) + } + if rel == "." { + return "./" + f.targetName + } + return filepath.ToSlash(filepath.Join(rel, f.targetName)) +} + +func (f *CompilationFormatter) getPlacementDescription(summary *PlacementSummary) string { + hasConstitution := false + for _, src := range summary.Sources { + if strings.Contains(src, "constitution.md") { + hasConstitution = true + break + } + } + + var parts []string + if hasConstitution { + parts = append(parts, "Constitution") + } + if summary.InstructionCount > 0 { + plural := "s" + if summary.InstructionCount == 1 { + plural = "" + } + parts = append(parts, fmt.Sprintf("%d instruction%s", summary.InstructionCount, plural)) + } + if len(parts) > 0 { + return strings.Join(parts, " and ") + } + return "content" +} diff --git a/internal/output/models/models.go b/internal/output/models/models.go new file mode 100644 index 0000000..a7db672 --- /dev/null +++ b/internal/output/models/models.go @@ -0,0 +1,158 @@ +// Package models provides data models for compilation output and results. +package models + +// PlacementStrategy represents how instructions are placed across the project. 
// PlacementStrategy represents how instructions are placed across the project.
type PlacementStrategy string

const (
	PlacementStrategySinglePoint    PlacementStrategy = "Single Point"
	PlacementStrategySelectiveMulti PlacementStrategy = "Selective Multi"
	PlacementStrategyDistributed    PlacementStrategy = "Distributed"
)

// ProjectAnalysis holds analysis of the project structure and file distribution.
type ProjectAnalysis struct {
	DirectoriesScanned          int
	FilesAnalyzed               int
	FileTypesDetected           []string
	InstructionPatternsDetected int
	MaxDepth                    int
	ConstitutionDetected        bool
	ConstitutionPath            string
}

// GetFileTypesSummary returns a concise, sorted summary of detected file types.
// All leading dots are stripped ("..go" -> "go") and empty entries dropped; at
// most three types are listed, with "and N more" appended for the remainder.
// This package is deliberately dependency-free, hence the hand-rolled sort.
func (p *ProjectAnalysis) GetFileTypesSummary() string {
	if len(p.FileTypesDetected) == 0 {
		return "none"
	}
	types := make([]string, 0, len(p.FileTypesDetected))
	for _, t := range p.FileTypesDetected {
		stripped := t
		for len(stripped) > 0 && stripped[0] == '.' {
			stripped = stripped[1:]
		}
		if stripped != "" {
			types = append(types, stripped)
		}
	}
	// Selection sort: n is tiny (number of distinct file types), so O(n^2) is fine.
	for i := 0; i < len(types); i++ {
		for j := i + 1; j < len(types); j++ {
			if types[j] < types[i] {
				types[i], types[j] = types[j], types[i]
			}
		}
	}
	if len(types) <= 3 {
		result := ""
		for i, t := range types {
			if i > 0 {
				result += ", "
			}
			result += t
		}
		return result
	}
	result := types[0] + ", " + types[1] + ", " + types[2]
	return result + " and " + itoa(len(types)-3) + " more"
}

// itoa converts an int to its decimal string without importing strconv (this
// package is deliberately import-free). Digits are appended then reversed in
// place, avoiding the quadratic per-digit prepend of the naive approach.
// Negative inputs are handled (previously they yielded ""); the one remaining
// edge case is the most negative int, whose negation overflows -- callers here
// only pass small non-negative counts.
func itoa(n int) string {
	if n == 0 {
		return "0"
	}
	neg := n < 0
	if neg {
		n = -n
	}
	buf := make([]byte, 0, 20)
	for n > 0 {
		buf = append(buf, byte('0'+n%10))
		n /= 10
	}
	if neg {
		buf = append(buf, '-')
	}
	// Digits (and sign) were emitted least-significant first; reverse in place.
	for i, j := 0, len(buf)-1; i < j; i, j = i+1, j-1 {
		buf[i], buf[j] = buf[j], buf[i]
	}
	return string(buf)
}
// OptimizationStats holds performance and efficiency statistics from optimization.
type OptimizationStats struct {
	AverageContextEfficiency float64
	PollutionImprovement     *float64
	BaselineEfficiency       *float64
	PlacementAccuracy        *float64
	GenerationTimeMs         *int
	TotalAgentsFiles         int
	DirectoriesAnalyzed      int
}

// EfficiencyImprovement calculates the relative efficiency improvement over
// the baseline as a percentage, or nil when the baseline is absent or zero
// (division would be undefined).
func (o *OptimizationStats) EfficiencyImprovement() *float64 {
	base := o.BaselineEfficiency
	if base == nil || *base == 0 {
		return nil
	}
	pct := (o.AverageContextEfficiency - *base) / *base * 100
	return &pct
}

// EfficiencyPercentage returns the average context efficiency as a percentage.
func (o *OptimizationStats) EfficiencyPercentage() float64 {
	return o.AverageContextEfficiency * 100
}
// ScriptExecutionFormatter formats script execution output as plain ASCII lines.
type ScriptExecutionFormatter struct{}

// NewScriptExecutionFormatter returns a new formatter.
func NewScriptExecutionFormatter() *ScriptExecutionFormatter {
	return &ScriptExecutionFormatter{}
}

// FormatScriptHeader formats the script execution header with parameters.
// Parameters are listed in sorted key order: Go map iteration order is
// randomized, so the previous range-over-map produced nondeterministic output.
func (f *ScriptExecutionFormatter) FormatScriptHeader(scriptName string, params map[string]string) []string {
	lines := []string{fmt.Sprintf("[>] Running script: %s", scriptName)}

	keys := make([]string, 0, len(params))
	for k := range params {
		keys = append(keys, k)
	}
	// Insertion sort keeps this package free of extra imports; the parameter
	// count is tiny, so O(n^2) is irrelevant.
	for i := 1; i < len(keys); i++ {
		for j := i; j > 0 && keys[j] < keys[j-1]; j-- {
			keys[j], keys[j-1] = keys[j-1], keys[j]
		}
	}

	for _, k := range keys {
		lines = append(lines, fmt.Sprintf(" - %s: %s", k, params[k]))
	}
	return lines
}
+func (f *ScriptExecutionFormatter) FormatCompilationProgress(promptFiles []string) []string { + if len(promptFiles) == 0 { + return nil + } + var lines []string + if len(promptFiles) == 1 { + lines = append(lines, "Compiling prompt...") + } else { + lines = append(lines, fmt.Sprintf("Compiling %d prompts...", len(promptFiles))) + } + for _, pf := range promptFiles { + lines = append(lines, fmt.Sprintf("|- %s", pf)) + } + if len(lines) > 1 { + lines[len(lines)-1] = strings.Replace(lines[len(lines)-1], "|-", "+-", 1) + } + return lines +} + +// FormatRuntimeExecution formats runtime command execution details. +func (f *ScriptExecutionFormatter) FormatRuntimeExecution(runtime, command string, contentLength int) []string { + return []string{ + fmt.Sprintf("Executing %s runtime...", runtime), + fmt.Sprintf("|- Command: %s", command), + fmt.Sprintf("+- Prompt content: %d characters", contentLength), + } +} + +// FormatContentPreview formats a content preview (plain text, no rich boxes). +func (f *ScriptExecutionFormatter) FormatContentPreview(content string, maxPreview int) []string { + if maxPreview <= 0 { + maxPreview = 200 + } + preview := content + if len(content) > maxPreview { + preview = content[:maxPreview] + "..." + } + return []string{ + "Prompt preview:", + strings.Repeat("-", 50), + preview, + strings.Repeat("-", 50), + } +} + +// FormatEnvironmentSetup formats environment setup information. +func (f *ScriptExecutionFormatter) FormatEnvironmentSetup(runtime string, envVarsSet []string) []string { + if len(envVarsSet) == 0 { + return nil + } + lines := []string{"Environment setup:"} + for _, v := range envVarsSet { + lines = append(lines, fmt.Sprintf("|- %s: configured", v)) + } + if len(lines) > 1 { + lines[len(lines)-1] = strings.Replace(lines[len(lines)-1], "|-", "+-", 1) + } + return lines +} + +// FormatExecutionSuccess formats a successful execution result. +// executionTime < 0 means not provided. 
+func (f *ScriptExecutionFormatter) FormatExecutionSuccess(runtime string, executionTime float64) []string { + msg := fmt.Sprintf("[+] %s execution completed successfully", titleCase(runtime)) + if executionTime >= 0 { + msg += fmt.Sprintf(" (%.2fs)", executionTime) + } + return []string{msg} +} + +// FormatExecutionError formats an execution error result. +func (f *ScriptExecutionFormatter) FormatExecutionError(runtime string, errorCode int, errorMsg string) []string { + lines := []string{ + fmt.Sprintf("x %s execution failed (exit code: %d)", titleCase(runtime), errorCode), + } + if errorMsg != "" { + for _, line := range strings.Split(errorMsg, "\n") { + if strings.TrimSpace(line) != "" { + lines = append(lines, " "+line) + } + } + } + return lines +} + +// FormatSubprocessDetails formats subprocess execution details. +func (f *ScriptExecutionFormatter) FormatSubprocessDetails(args []string, contentLength int) []string { + quoted := make([]string, len(args)) + for i, a := range args { + if strings.Contains(a, " ") { + quoted[i] = `"` + a + `"` + } else { + quoted[i] = a + } + } + return []string{ + "Subprocess execution:", + fmt.Sprintf("|- Args: %s", strings.Join(quoted, " ")), + fmt.Sprintf("+- Content: +%d chars appended", contentLength), + } +} + +// FormatAutoDiscoveryMessage formats the message for auto-discovered prompts. +func (f *ScriptExecutionFormatter) FormatAutoDiscoveryMessage(scriptName, promptFile, runtime string) string { + return fmt.Sprintf("[i] Auto-discovered: %s (runtime: %s)", promptFile, runtime) +} + +// titleCase capitalises the first rune of s. 
// titleCase capitalises the first rune of s. The previous implementation
// sliced the first byte (s[:1]), so a multi-byte UTF-8 first rune (e.g. "é")
// was never uppercased; converting through []rune fixes that. Invalid UTF-8
// in the first rune is replaced by U+FFFD, which is acceptable for the ASCII
// runtime names ("copilot", "codex", ...) this is called with.
func titleCase(s string) string {
	if s == "" {
		return s
	}
	runes := []rune(s)
	return strings.ToUpper(string(runes[:1])) + string(runes[1:])
}
+ Cached bool + Err string // error message if fetch failed + CacheAgeSeconds int + CacheStale bool + FetchErr string + Outcome string + RawBytesHash string // ":" of leaf bytes off the wire + ExpectedHash string // pin that was checked, if any +} + +// Found returns true when a policy was found. +func (r *PolicyFetchResult) Found() bool { return r.Policy != nil } + +// cacheEntry is an internal representation of a cached policy read. +type cacheEntry struct { + Policy *schema.ApmPolicy + Source string + AgeSeconds int + Stale bool + ChainRefs []string + Fingerprint string + RawBytesHash string +} + +// --------------------------------------------------------------------------- +// Public entry points +// --------------------------------------------------------------------------- + +// DiscoverPolicyWithChain discovers policy with full inheritance chain resolution. +// This is the shared entry point for all command sites that need chain-aware policy discovery. +func DiscoverPolicyWithChain(projectRoot string, expectedHash string) *PolicyFetchResult { + if os.Getenv("APM_POLICY_DISABLE") == "1" { + return &PolicyFetchResult{Outcome: "disabled"} + } + + // If no explicit hash, read from project apm.yml (stub -- just pass through) + if expectedHash == "" { + if pin := readProjectHashPin(projectRoot); pin != "" { + expectedHash = pin + } + } + + fetchResult := DiscoverPolicy(projectRoot, "", false, expectedHash) + + // Chain resolution if leaf has extends (stub -- not implemented in this iteration) + _ = fetchResult + return fetchResult +} + +// DiscoverPolicy discovers and loads the applicable policy for a project. +// +// Resolution order: +// 1. If policyOverride is a local file path -- load from file +// 2. If policyOverride is an https:// URL -- fetch from URL +// 3. If policyOverride is "owner/repo" or "host/owner/repo" -- fetch from repo +// 4. 
If policyOverride is "" -- auto-discover from project's git remote +func DiscoverPolicy(projectRoot, policyOverride string, noCache bool, expectedHash string) *PolicyFetchResult { + if policyOverride != "" { + // Try as local file + if info, err := os.Stat(policyOverride); err == nil && !info.IsDir() { + return loadFromFile(policyOverride, expectedHash) + } + if strings.HasPrefix(policyOverride, "http://") { + return &PolicyFetchResult{ + Err: "Refusing plaintext http:// policy URL -- use https://", + Source: "url:" + policyOverride, + Outcome: "cache_miss_fetch_fail", + } + } + if strings.HasPrefix(policyOverride, "https://") { + return fetchFromURL(policyOverride, projectRoot, noCache, expectedHash) + } + if policyOverride != "org" { + return fetchFromRepo(policyOverride, projectRoot, noCache, expectedHash) + } + } + return autoDiscover(projectRoot, noCache, expectedHash) +} + +// --------------------------------------------------------------------------- +// File loading +// --------------------------------------------------------------------------- + +func loadFromFile(path, expectedHash string) *PolicyFetchResult { + content, err := os.ReadFile(path) + if err != nil { + return &PolicyFetchResult{ + Err: fmt.Sprintf("Failed to read %s: %v", path, err), + Outcome: "cache_miss_fetch_fail", + } + } + sourceLabel := "file:" + path + + if mismatch := verifyHashPin(content, expectedHash, sourceLabel); mismatch != nil { + return mismatch + } + + policy, parseErr := parsePolicy(content) + if parseErr != nil { + return &PolicyFetchResult{ + Err: fmt.Sprintf("Invalid policy file %s: %v", path, parseErr), + Source: sourceLabel, + Outcome: "malformed", + } + } + + outcome := "found" + if isPolicyEmpty(policy) { + outcome = "empty" + } + var rawHash string + if expectedHash != "" { + rawHash = computeHashNormalized(content, expectedHash) + } + return &PolicyFetchResult{ + Policy: policy, + Source: sourceLabel, + Outcome: outcome, + RawBytesHash: rawHash, + ExpectedHash: 
expectedHash, + } +} + +// --------------------------------------------------------------------------- +// Auto-discovery +// --------------------------------------------------------------------------- + +func autoDiscover(projectRoot string, noCache bool, expectedHash string) *PolicyFetchResult { + org, host, err := extractOrgFromGitRemote(projectRoot) + if err != nil || org == "" { + return &PolicyFetchResult{ + Err: "Could not determine org from git remote", + Outcome: "no_git_remote", + } + } + repoRef := org + "/.github" + if host != "" && host != "github.com" { + repoRef = host + "/" + repoRef + } + return fetchFromRepo(repoRef, projectRoot, noCache, expectedHash) +} + +// extractOrgFromGitRemote runs git remote get-url origin and parses the org and host. +func extractOrgFromGitRemote(projectRoot string) (org, host string, err error) { + cmd := exec.Command("git", "remote", "get-url", "origin") + cmd.Dir = projectRoot + out, execErr := cmd.Output() + if execErr != nil { + return "", "", execErr + } + remoteURL := strings.TrimSpace(string(out)) + return parseRemoteURL(remoteURL) +} + +// parseRemoteURL parses a git remote URL into (org, host, error). 
func parseRemoteURL(rawURL string) (org, host string, err error) {
	// Extracts the organization (first path segment) and host from a git
	// remote URL. Handles SCP-style SSH ("user@host:path") and scheme URLs
	// ("https://host/path"); anything else is an error.
	if rawURL == "" {
		return "", "", fmt.Errorf("empty URL")
	}

	// SCP-style SSH: user@host:path
	if m := scpLikeRE.FindStringSubmatch(rawURL); len(m) > 0 {
		var hostPart, pathPart string
		// Pull named groups out of the match by position.
		for i, name := range scpLikeRE.SubexpNames() {
			switch name {
			case "host":
				hostPart = m[i]
			case "path":
				pathPart = m[i]
			}
		}
		pathPart = strings.TrimSuffix(strings.TrimRight(pathPart, "/"), ".git")
		parts := strings.Split(pathPart, "/")
		var cleaned []string
		for _, p := range parts {
			if p != "" {
				cleaned = append(cleaned, p)
			}
		}
		if len(cleaned) == 0 {
			return "", "", fmt.Errorf("cannot parse path from SCP URL")
		}
		// Azure DevOps SSH has v3/ prefix (v3/<org>/<project>/<repo>), so
		// the org is the second segment there.
		if hostPart == "ssh.dev.azure.com" && len(cleaned) >= 2 && cleaned[0] == "v3" {
			return cleaned[1], hostPart, nil
		}
		return cleaned[0], hostPart, nil
	}

	// HTTPS (or any scheme:// URL)
	if strings.Contains(rawURL, "://") {
		u, parseErr := url.Parse(rawURL)
		if parseErr != nil {
			return "", "", parseErr
		}
		h := u.Hostname()
		pathPart := strings.TrimSuffix(strings.Trim(u.Path, "/"), ".git")
		parts := strings.Split(pathPart, "/")
		var cleaned []string
		for _, p := range parts {
			if p != "" {
				cleaned = append(cleaned, p)
			}
		}
		if h != "" && len(cleaned) > 0 {
			return cleaned[0], h, nil
		}
		// Fall through to the generic error when host or path is missing.
	}
	return "", "", fmt.Errorf("could not parse remote URL: %s", rawURL)
}

// ---------------------------------------------------------------------------
// URL fetch
// ---------------------------------------------------------------------------

var httpClient = &http.Client{
	Timeout: 10 * time.Second,
	CheckRedirect: func(req *http.Request, via []*http.Request) error {
		// Do not follow redirects: ErrUseLastResponse makes the client
		// surface the raw 3xx response, which callers below reject
		// explicitly (security: prevent SSRF via redirect).
		return http.ErrUseLastResponse
	},
}

// fetchFromURL fetches and parses a policy file directly from rawURL.
// Cache behaviour: a fresh cache entry short-circuits the fetch; a stale
// entry (ce != nil, ce.Stale) is kept as a fallback for staleOrError when
// the network fetch fails.
func fetchFromURL(rawURL, projectRoot string, noCache bool, expectedHash string) *PolicyFetchResult {
	sourceLabel := "url:" + rawURL
	var ce *cacheEntry

	if !noCache {
		ce = readCacheEntry(rawURL, projectRoot, defaultCacheTTL, expectedHash)
		if ce != nil && !ce.Stale {
			outcome := "found"
			if isPolicyEmpty(ce.Policy) {
				outcome = "empty"
			}
			return &PolicyFetchResult{
				Policy:          ce.Policy,
				Source:          ce.Source,
				Cached:          true,
				CacheAgeSeconds: ce.AgeSeconds,
				Outcome:         outcome,
				RawBytesHash:    ce.RawBytesHash,
				ExpectedHash:    expectedHash,
			}
		}
	}

	resp, err := httpClient.Get(rawURL)
	var content []byte
	var fetchErrStr string
	if err != nil {
		fetchErrStr = fmt.Sprintf("Error fetching %s: %v", rawURL, err)
	} else {
		defer resp.Body.Close()
		// 404 is a definitive "absent" -- it does NOT fall back to a stale
		// cache entry, unlike the transport/HTTP errors below.
		if resp.StatusCode == 404 {
			return &PolicyFetchResult{Source: sourceLabel, Err: "404: Policy file not found", Outcome: "absent"}
		}
		if resp.StatusCode >= 300 && resp.StatusCode < 400 {
			loc := resp.Header.Get("Location")
			fetchErrStr = fmt.Sprintf("Refusing HTTP redirect (%d) from %s to %s", resp.StatusCode, rawURL, loc)
		} else if resp.StatusCode != 200 {
			fetchErrStr = fmt.Sprintf("HTTP %d fetching %s", resp.StatusCode, rawURL)
		} else {
			content, err = io.ReadAll(resp.Body)
			if err != nil {
				fetchErrStr = fmt.Sprintf("Error reading response from %s: %v", rawURL, err)
			}
		}
	}

	// Any fetch failure: serve a stale cached policy if one exists,
	// otherwise report cache_miss_fetch_fail.
	if fetchErrStr != "" {
		return staleOrError(ce, fetchErrStr, sourceLabel, "cache_miss_fetch_fail")
	}

	if gr := detectGarbage(content, rawURL, sourceLabel, ce); gr != nil {
		return gr
	}

	if mismatch := verifyHashPin(content, expectedHash, sourceLabel); mismatch != nil {
		return mismatch
	}

	policy, parseErr := parsePolicy(content)
	if parseErr != nil {
		return &PolicyFetchResult{
			Err:     fmt.Sprintf("Invalid policy from %s: %v", rawURL, parseErr),
			Source:  sourceLabel,
			Outcome: "malformed",
		}
	}

	// Cache the successfully parsed policy before returning.
	actualHash := computeHashNormalized(content, expectedHash)
	writeCache(rawURL, policy, projectRoot, []string{rawURL}, actualHash)
	outcome := "found"
	if isPolicyEmpty(policy) {
		outcome = "empty"
	}
	return &PolicyFetchResult{
		Policy:       policy,
		Source:       sourceLabel,
		Outcome:      outcome,
		RawBytesHash: actualHash,
		ExpectedHash: expectedHash,
	}
}

// ---------------------------------------------------------------------------
// Repo fetch (GitHub Contents API)
// ---------------------------------------------------------------------------

// fetchFromRepo fetches apm-policy.yml from a repo reference via the GitHub
// Contents API. Mirrors fetchFromURL's cache/garbage/pin/parse pipeline.
func fetchFromRepo(repoRef, projectRoot string, noCache bool, expectedHash string) *PolicyFetchResult {
	sourceLabel := "org:" + repoRef
	var ce *cacheEntry

	if !noCache {
		ce = readCacheEntry(repoRef, projectRoot, defaultCacheTTL, expectedHash)
		if ce != nil && !ce.Stale {
			outcome := "found"
			if isPolicyEmpty(ce.Policy) {
				outcome = "empty"
			}
			return &PolicyFetchResult{
				Policy:          ce.Policy,
				Source:          ce.Source,
				Cached:          true,
				CacheAgeSeconds: ce.AgeSeconds,
				Outcome:         outcome,
				RawBytesHash:    ce.RawBytesHash,
				ExpectedHash:    expectedHash,
			}
		}
	}

	content, fetchErr := fetchGithubContents(repoRef, "apm-policy.yml")
	if fetchErr != "" {
		// A 404 from the Contents API means the policy file simply does
		// not exist in the repo -- definitive "absent", no stale fallback.
		if strings.Contains(fetchErr, "404") {
			return &PolicyFetchResult{Source: sourceLabel, Outcome: "absent"}
		}
		return staleOrError(ce, fetchErr, sourceLabel, "cache_miss_fetch_fail")
	}
	if content == nil {
		return &PolicyFetchResult{Source: sourceLabel, Outcome: "absent"}
	}

	if gr := detectGarbage(content, repoRef, sourceLabel, ce); gr != nil {
		return gr
	}

	if mismatch := verifyHashPin(content, expectedHash, sourceLabel); mismatch != nil {
		return mismatch
	}

	policy, parseErr := parsePolicy(content)
	if parseErr != nil {
		return &PolicyFetchResult{
			Err:     fmt.Sprintf("Invalid policy in %s: %v", repoRef, parseErr),
			Source:  sourceLabel,
			Outcome: "malformed",
		}
	}

	actualHash := computeHashNormalized(content, expectedHash)
	writeCache(repoRef, policy, projectRoot, []string{repoRef}, actualHash)
	outcome := "found"
	if isPolicyEmpty(policy) {
		outcome = "empty"
	}
	return &PolicyFetchResult{
		Policy:       policy,
		Source:       sourceLabel,
		Outcome:      outcome,
		RawBytesHash: actualHash,
		ExpectedHash: expectedHash,
	}
}

// fetchGithubContents fetches apm-policy.yml from a GitHub/GHE repo via the Contents API.
// Returns (content, errString). One will be nil/"".
func fetchGithubContents(repoRef, filePath string) ([]byte, string) {
	parts := strings.Split(repoRef, "/")
	var host, owner, repo string
	switch len(parts) {
	case 2:
		// "owner/repo" defaults to github.com.
		host, owner, repo = "github.com", parts[0], parts[1]
	case 3:
		// "host/owner/repo".
		host, owner, repo = parts[0], parts[1], parts[2]
	default:
		// len(parts) is 1 or >= 4 here, so this branch effectively means
		// "4+ segments": treat everything after host/owner as the repo path.
		if len(parts) >= 3 {
			host, owner, repo = parts[0], parts[1], strings.Join(parts[2:], "/")
		} else {
			return nil, fmt.Sprintf("Invalid repo reference: %s", repoRef)
		}
	}

	var apiURL string
	if host == "github.com" {
		apiURL = fmt.Sprintf("https://api.github.com/repos/%s/%s/contents/%s", owner, repo, filePath)
	} else {
		// GitHub Enterprise uses the /api/v3 prefix on the instance host.
		apiURL = fmt.Sprintf("https://%s/api/v3/repos/%s/%s/contents/%s", host, owner, repo, filePath)
	}

	req, err := http.NewRequest("GET", apiURL, nil)
	if err != nil {
		return nil, fmt.Sprintf("Error building request for %s: %v", repoRef, err)
	}
	req.Header.Set("Accept", "application/vnd.github.v3+json")
	if token := getTokenForHost(host); token != "" {
		req.Header.Set("Authorization", "token "+token)
	}

	resp, err := httpClient.Do(req)
	if err != nil {
		return nil, fmt.Sprintf("Error fetching policy from %s: %v", repoRef, err)
	}
	defer resp.Body.Close()

	switch resp.StatusCode {
	case 404:
		return nil, "404: Policy file not found"
	case 403:
		return nil, fmt.Sprintf("403: Access denied to %s", repoRef)
	case 200:
		// continue
	default:
		// httpClient never follows redirects, so 3xx responses surface
		// here and are refused explicitly.
		if resp.StatusCode >= 300 && resp.StatusCode < 400 {
			loc := resp.Header.Get("Location")
			return nil, fmt.Sprintf("Refusing HTTP redirect (%d) from %s to %s", resp.StatusCode, apiURL, loc)
		}
		return nil, fmt.Sprintf("HTTP %d fetching policy from %s", resp.StatusCode, repoRef)
	}

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Sprintf("Error reading response from %s: %v", repoRef, err)
	}

	var data map[string]interface{}
	if err := json.Unmarshal(body, &data); err != nil {
		return nil, fmt.Sprintf("Error parsing response from %s: %v", repoRef, err)
	}

	// The Contents API returns the file body base64-encoded with embedded
	// newlines; strip them before decoding.
	if enc, ok := data["encoding"].(string); ok && enc == "base64" {
		if rawContent, ok := data["content"].(string); ok && rawContent != "" {
			cleaned := strings.ReplaceAll(rawContent, "\n", "")
			decoded, err := base64.StdEncoding.DecodeString(cleaned)
			if err != nil {
				return nil, fmt.Sprintf("Error decoding base64 content from %s: %v", repoRef, err)
			}
			return decoded, ""
		}
	}
	// Fallback: some responses carry content unencoded.
	if rawContent, ok := data["content"].(string); ok && rawContent != "" {
		return []byte(rawContent), ""
	}
	return nil, fmt.Sprintf("Unexpected response format from %s", repoRef)
}

// getTokenForHost returns a GitHub/GHE token for the given host.
// Only github.com, *.ghe.com, and the host named by $GITHUB_HOST are
// trusted with a token; other hosts get an unauthenticated request.
func getTokenForHost(host string) string {
	hostLower := strings.ToLower(host)
	isGitHub := hostLower == "github.com" || strings.HasSuffix(hostLower, ".ghe.com") ||
		(os.Getenv("GITHUB_HOST") != "" && hostLower == strings.ToLower(os.Getenv("GITHUB_HOST")))
	if !isGitHub {
		return ""
	}
	// First non-empty env var wins, in this precedence order.
	for _, env := range []string{"GITHUB_TOKEN", "GITHUB_APM_PAT", "GH_TOKEN"} {
		if t := os.Getenv(env); t != "" {
			return t
		}
	}
	return ""
}

// ---------------------------------------------------------------------------
// Hash pin verification
// ---------------------------------------------------------------------------

// verifyHashPin verifies content against an expected hash pin.
// Returns nil when verification passes or there is no pin.
// Returns a PolicyFetchResult with outcome "hash_mismatch" on failure.
+func verifyHashPin(content []byte, expectedHash, sourceLabel string) *PolicyFetchResult { + if expectedHash == "" { + return nil + } + algo, expectedHex, err := splitHashPin(expectedHash) + if err != nil { + return &PolicyFetchResult{ + Outcome: "hash_mismatch", + Source: sourceLabel, + Err: fmt.Sprintf("Policy hash mismatch from %s: invalid pin (%v)", sourceLabel, err), + ExpectedHash: expectedHash, + } + } + + var actualHex string + switch algo { + case "sha256": + h := sha256.Sum256(content) + actualHex = fmt.Sprintf("%x", h) + default: + return &PolicyFetchResult{ + Outcome: "hash_mismatch", + Source: sourceLabel, + Err: fmt.Sprintf("Unsupported hash algorithm: %s", algo), + } + } + + if actualHex != expectedHex { + return &PolicyFetchResult{ + Outcome: "hash_mismatch", + Source: sourceLabel, + Err: fmt.Sprintf("Policy hash mismatch from %s: expected %s:%s, got %s:%s", sourceLabel, algo, expectedHex, algo, actualHex), + ExpectedHash: fmt.Sprintf("%s:%s", algo, expectedHex), + RawBytesHash: fmt.Sprintf("%s:%s", algo, actualHex), + } + } + return nil +} + +// splitHashPin splits ":" into (algo, hex). +// Bare hex without prefix is treated as sha256 for backward compatibility. 
func splitHashPin(pin string) (algo, hex string, err error) {
	// Parses a hash pin of the form "algo:hexdigest". A bare digest with
	// no "algo:" prefix defaults to sha256. Only sha256 is supported; the
	// digest must be exactly 64 hex characters after lowercasing.
	raw := strings.TrimSpace(pin)
	if before, after, found := strings.Cut(raw, ":"); found {
		algo = strings.ToLower(strings.TrimSpace(before))
		hex = strings.ToLower(strings.TrimSpace(after))
	} else {
		algo = "sha256"
		hex = strings.ToLower(raw)
	}
	if algo != "sha256" {
		return "", "", fmt.Errorf("unsupported algorithm %q", algo)
	}
	if len(hex) != 64 {
		return "", "", fmt.Errorf("invalid sha256 hex (length %d)", len(hex))
	}
	// Reject non-hex characters: a 64-char garbage string (e.g. all 'z')
	// previously passed validation and could never match any real digest.
	for _, c := range hex {
		if (c < '0' || c > '9') && (c < 'a' || c > 'f') {
			return "", "", fmt.Errorf("invalid sha256 hex (bad character %q)", c)
		}
	}
	return algo, hex, nil
}

// computeHashNormalized returns the normalized "sha256:<hex>" digest of
// content. The pin (when present and valid) names the algorithm to use so
// stored hashes stay comparable with the pin; sha256 is the default and the
// only supported algorithm, so an invalid or absent pin still yields sha256.
func computeHashNormalized(content []byte, expectedHash string) string {
	algo := "sha256"
	if expectedHash != "" {
		if a, _, err := splitHashPin(expectedHash); err == nil {
			algo = a
		}
	}
	if algo != "sha256" {
		// Unreachable while splitHashPin only accepts sha256.
		return ""
	}
	sum := sha256.Sum256(content)
	return fmt.Sprintf("sha256:%x", sum)
}

// ---------------------------------------------------------------------------
// Policy parsing
// ---------------------------------------------------------------------------

// parsePolicy parses raw YAML bytes into an ApmPolicy.
// Uses a minimal line-by-line scanner tracking current section context.
+func parsePolicy(data []byte) (*schema.ApmPolicy, error) { + if len(strings.TrimSpace(string(data))) == 0 { + return &schema.ApmPolicy{}, nil + } + + p := &schema.ApmPolicy{} + lines := strings.Split(string(data), "\n") + + // Track section by top-level key and sub-key + var section, subSection, listKey string + var listTarget *[]string + + setListTarget := func(key string) { + switch { + case section == "dependencies" && key == "allow": + listTarget = &p.Deps.Allow + case section == "dependencies" && key == "deny": + listTarget = &p.Deps.Deny + case section == "dependencies" && key == "require": + listTarget = &p.Deps.Require + case section == "mcp" && key == "allow": + listTarget = &p.MCP.Allow + case section == "mcp" && key == "deny": + listTarget = &p.MCP.Deny + case section == "mcp" && subSection == "transport" && key == "allow": + listTarget = &p.MCP.Transport.Allow + case section == "compilation" && subSection == "target" && key == "allow": + listTarget = &p.Compilation.Targets.Allow + default: + listTarget = nil + } + listKey = key + _ = listKey + } + + for _, line := range lines { + trimmed := strings.TrimSpace(line) + if trimmed == "" || strings.HasPrefix(trimmed, "#") { + continue + } + + indent := 0 + for _, ch := range line { + if ch == ' ' { + indent++ + } else { + break + } + } + + if strings.HasPrefix(trimmed, "- ") { + val := strings.TrimPrefix(trimmed, "- ") + val = strings.Trim(val, "\"'") + if listTarget != nil { + *listTarget = append(*listTarget, val) + } + continue + } + + if idx := strings.Index(trimmed, ":"); idx >= 0 { + key := strings.TrimSpace(trimmed[:idx]) + val := strings.TrimSpace(trimmed[idx+1:]) + val = strings.Trim(val, "\"'") + + if indent == 0 { + // Top-level key + section = key + subSection = "" + listTarget = nil + switch key { + case "version": + p.Version = val + case "enforcement": + p.Enforcement = val + case "fetch_failure": + p.FetchFailure = val + } + } else if indent == 2 { + // Section key + subSection = "" + 
listTarget = nil + if val == "" { + subSection = key + } else { + switch { + case section == "dependencies" && key == "require_resolution": + p.Deps.RequireResolution = val + case section == "mcp" && key == "self_defined": + p.MCP.SelfDefined = val + case section == "compilation" && key == "source_attribution": + // ignore + } + setListTarget(key) + // If val is empty this is a list parent -- handled above + // If non-empty, clear listTarget (it's a scalar, not list) + if val != "" { + listTarget = nil + } + } + } else if indent == 4 { + // Sub-section key + listTarget = nil + if val == "" { + subSection = key + } else { + switch { + case section == "mcp" && subSection == "transport" && key == "allow": + // scalar allow -- no-op + case section == "compilation" && subSection == "target" && key == "enforce": + p.Compilation.Targets.Enforce = val + case section == "compilation" && subSection == "strategy" && key == "enforce": + p.Compilation.Strategy.Enforce = val + } + setListTarget(key) + if val != "" { + listTarget = nil + } + } + } + } + } + + return p, nil +} + +// isPolicyEmpty returns true when a policy has no actionable restrictions. 
func isPolicyEmpty(p *schema.ApmPolicy) bool {
	if p == nil {
		return true
	}
	// Allow-lists are tested with == nil rather than len() -- presumably a
	// nil Allow means "no allow-list configured" while an empty non-nil one
	// would still be a restriction; confirm against the Python original.
	return len(p.Deps.Deny) == 0 &&
		p.Deps.Allow == nil &&
		len(p.Deps.Require) == 0 &&
		len(p.MCP.Deny) == 0 &&
		p.MCP.Allow == nil &&
		p.MCP.Transport.Allow == nil &&
		p.Compilation.Targets.Allow == nil
}

// ---------------------------------------------------------------------------
// Cache
// ---------------------------------------------------------------------------

// cacheMeta is the sidecar metadata persisted next to each cached policy.
type cacheMeta struct {
	RepoRef       string   `json:"repo_ref"`
	CachedAt      float64  `json:"cached_at"` // Unix seconds (fractional)
	ChainRefs     []string `json:"chain_refs"`
	SchemaVersion string   `json:"schema_version"`
	Fingerprint   string   `json:"fingerprint"`
	RawBytesHash  string   `json:"raw_bytes_hash"`
}

// cacheKey derives a short, filesystem-safe key: the first 16 hex chars of
// the sha256 of the repo reference.
func cacheKey(repoRef string) string {
	h := sha256.Sum256([]byte(repoRef))
	return fmt.Sprintf("%x", h)[:16]
}

// getCacheDir returns <projectRoot>/apm_modules/<policyCacheDir>, refusing
// any candidate that resolves outside the project root (path traversal).
func getCacheDir(projectRoot string) (string, error) {
	resolved, err := filepath.Abs(projectRoot)
	if err != nil {
		return "", err
	}
	base := filepath.Join(resolved, "apm_modules")
	candidate := filepath.Join(base, policyCacheDir)
	if _, err := pathsecurity.EnsurePathWithin(candidate, resolved); err != nil {
		return "", fmt.Errorf("policy cache path %q resolves outside project root %q", candidate, resolved)
	}
	return candidate, nil
}

// readCacheEntry loads a cached policy for repoRef, or nil on any failure
// (missing files, schema mismatch, age beyond maxStaleTTL, pin mismatch,
// unreadable or unparsable content). Entries older than ttl but within
// maxStaleTTL are returned with Stale=true so callers can use them as a
// fetch-failure fallback.
func readCacheEntry(repoRef, projectRoot string, ttl int, expectedHash string) *cacheEntry {
	cacheDir, err := getCacheDir(projectRoot)
	if err != nil {
		return nil
	}
	key := cacheKey(repoRef)
	policyFile := filepath.Join(cacheDir, key+".yml")
	metaFile := filepath.Join(cacheDir, key+".meta.json")

	if _, err := os.Stat(policyFile); os.IsNotExist(err) {
		return nil
	}
	if _, err := os.Stat(metaFile); os.IsNotExist(err) {
		return nil
	}

	metaBytes, err := os.ReadFile(metaFile)
	if err != nil {
		return nil
	}
	var meta cacheMeta
	if err := json.Unmarshal(metaBytes, &meta); err != nil {
		return nil
	}
	if meta.SchemaVersion != cacheSchemaVersion {
		return nil
	}

	age := int(time.Now().Unix() - int64(meta.CachedAt))
	if age > maxStaleTTL {
		// Too old even for stale fallback.
		return nil
	}

	// Pin verification: a pinned hash must match the stored raw-bytes hash
	// exactly (normalized to "sha256:<hex>").
	if expectedHash != "" {
		ea, eh, err := splitHashPin(expectedHash)
		if err != nil {
			return nil
		}
		expectedNorm := fmt.Sprintf("%s:%s", ea, eh)
		if strings.ToLower(meta.RawBytesHash) != expectedNorm {
			return nil
		}
	}

	policyContent, err := os.ReadFile(policyFile)
	if err != nil {
		return nil
	}
	policy, err := parsePolicy(policyContent)
	if err != nil {
		return nil
	}

	// Reconstruct the source label from the ref shape (URL vs org/repo).
	source := "org:" + repoRef
	if strings.HasPrefix(repoRef, "http://") || strings.HasPrefix(repoRef, "https://") {
		source = "url:" + repoRef
	}

	return &cacheEntry{
		Policy:       policy,
		Source:       source,
		AgeSeconds:   age,
		Stale:        age > ttl,
		ChainRefs:    meta.ChainRefs,
		Fingerprint:  meta.Fingerprint,
		RawBytesHash: meta.RawBytesHash,
	}
}

// writeMu serializes cache writes within this process.
var writeMu sync.Mutex

// writeCache persists a policy and its metadata. Best-effort: all errors are
// swallowed (a failed cache write only costs a refetch later). Files are
// written to unique temp names and renamed into place for atomicity.
func writeCache(repoRef string, policy *schema.ApmPolicy, projectRoot string, chainRefs []string, rawBytesHash string) {
	cacheDir, err := getCacheDir(projectRoot)
	if err != nil {
		return
	}
	if err := os.MkdirAll(cacheDir, 0o755); err != nil {
		return
	}

	key := cacheKey(repoRef)
	policyFile := filepath.Join(cacheDir, key+".yml")
	metaFile := filepath.Join(cacheDir, key+".meta.json")

	serialized := serializePolicy(policy)
	fingerprint := fmt.Sprintf("%x", sha256.Sum256([]byte(serialized)))[:32]

	meta := cacheMeta{
		RepoRef:       repoRef,
		CachedAt:      float64(time.Now().UnixNano()) / 1e9,
		ChainRefs:     chainRefs,
		SchemaVersion: cacheSchemaVersion,
		Fingerprint:   fingerprint,
		RawBytesHash:  rawBytesHash,
	}
	metaBytes, err := json.Marshal(meta)
	if err != nil {
		return
	}

	writeMu.Lock()
	defer writeMu.Unlock()

	uid := fmt.Sprintf("%d", time.Now().UnixNano())
	tmpPolicy := policyFile + "." + uid + ".tmp"
	if err := os.WriteFile(tmpPolicy, []byte(serialized), 0o644); err == nil {
		_ = os.Rename(tmpPolicy, policyFile)
	}
	tmpMeta := metaFile + "." + uid + ".tmp"
	if err := os.WriteFile(tmpMeta, metaBytes, 0o644); err == nil {
		_ = os.Rename(tmpMeta, metaFile)
	}
}

// serializePolicy serializes an ApmPolicy to a simple YAML-like string for caching.
// NOTE(review): only version/enforcement/fetch_failure and dependencies.deny
// are serialized -- allow/require/mcp/compilation lists are lost on a cache
// round-trip. Also confirm the key indentation below matches what parsePolicy
// expects (it only recognises section keys at exactly 2 spaces).
func serializePolicy(p *schema.ApmPolicy) string {
	if p == nil {
		return ""
	}
	var sb strings.Builder
	sb.WriteString(fmt.Sprintf("version: %s\n", p.Version))
	sb.WriteString(fmt.Sprintf("enforcement: %s\n", p.Enforcement))
	sb.WriteString(fmt.Sprintf("fetch_failure: %s\n", p.FetchFailure))
	if len(p.Deps.Deny) > 0 {
		sb.WriteString("dependencies:\n")
		sb.WriteString(" deny:\n")
		for _, d := range p.Deps.Deny {
			sb.WriteString(" - " + d + "\n")
		}
	}
	return sb.String()
}

// ---------------------------------------------------------------------------
// Garbage detection
// ---------------------------------------------------------------------------

// detectGarbage screens fetched bytes for obviously-not-YAML responses.
// Returns nil when the content looks plausible; otherwise falls back to a
// cached policy (marked stale) when one exists, or reports garbage_response.
func detectGarbage(content []byte, identifier, sourceLabel string, ce *cacheEntry) *PolicyFetchResult {
	if content == nil {
		return nil
	}
	trimmed := strings.TrimSpace(string(content))
	if trimmed == "" {
		return nil
	}
	// Very basic check: a valid YAML policy starts with a known key or is a mapping
	// For garbage detection: if it starts with "<" (HTML) it's a captive portal
	if strings.HasPrefix(trimmed, "<") {
		msg := fmt.Sprintf("Response from %s is not valid YAML (possible captive portal or redirect)", identifier)
		if ce != nil {
			return &PolicyFetchResult{
				Policy:          ce.Policy,
				Source:          ce.Source,
				Cached:          true,
				CacheStale:      true,
				CacheAgeSeconds: ce.AgeSeconds,
				FetchErr:        msg,
				Outcome:         "cached_stale",
			}
		}
		return &PolicyFetchResult{
			Err:     msg,
			Source:  sourceLabel,
			FetchErr: msg,
			Outcome: "garbage_response",
		}
	}
	return nil
}

// 
--------------------------------------------------------------------------- +// Stale or error fallback +// --------------------------------------------------------------------------- + +func staleOrError(ce *cacheEntry, fetchErrMsg, sourceLabel, outcomeOnMiss string) *PolicyFetchResult { + if ce != nil { + return &PolicyFetchResult{ + Policy: ce.Policy, + Source: ce.Source, + Cached: true, + CacheStale: true, + CacheAgeSeconds: ce.AgeSeconds, + FetchErr: fetchErrMsg, + Outcome: "cached_stale", + } + } + return &PolicyFetchResult{ + Err: fetchErrMsg, + Source: sourceLabel, + FetchErr: fetchErrMsg, + Outcome: outcomeOnMiss, + } +} + +// readProjectHashPin is a stub -- returns "" if no apm.yml hash pin found. +func readProjectHashPin(projectRoot string) string { + // Full implementation would parse apm.yml policy.hash field. + // Returning "" for now -- callers pass the pin explicitly when available. + return "" +} diff --git a/internal/policy/helptext/helptext.go b/internal/policy/helptext/helptext.go new file mode 100644 index 0000000..5b446fd --- /dev/null +++ b/internal/policy/helptext/helptext.go @@ -0,0 +1,9 @@ +// Package helptext contains shared help text for policy-related CLI commands. +// Migrated from src/apm_cli/policy/_help_text.py. +package helptext + +// PolicySourceFormsHelp is the canonical user-facing description of the +// --policy / --policy-source argument formats accepted by discover_policy. +const PolicySourceFormsHelp = "Accepts: 'org' (auto-discover from your project's git remote), " + + "'owner/repo' (defaults to github.com), an https:// URL, or a " + + "local file path." diff --git a/internal/policy/outcomerouting/outcomerouting.go b/internal/policy/outcomerouting/outcomerouting.go new file mode 100644 index 0000000..1357a27 --- /dev/null +++ b/internal/policy/outcomerouting/outcomerouting.go @@ -0,0 +1,189 @@ +// Package outcomerouting is the single source of truth for the 9-outcome +// policy-discovery routing table. 
// Migrated from src/apm_cli/policy/outcome_routing.py.
package outcomerouting

import (
	"fmt"

	"github.com/githubnext/apm/internal/policy/schema"
)

// PolicyViolationError is raised when a policy demands fail-closed behaviour.
type PolicyViolationError struct {
	Message      string
	PolicySource string // label of the policy source that triggered the block
}

// Error implements the error interface.
func (e *PolicyViolationError) Error() string {
	return e.Message
}

// PolicyFetchResult holds the result of a discover_policy call.
type PolicyFetchResult struct {
	Outcome string // one of the 9 routing outcomes (e.g. "found", "absent", "hash_mismatch")
	Source  string
	Cached  bool
	// Error is the primary failure message; FetchError is consulted as a
	// fallback when Error is empty (see RouteDiscoveryOutcome).
	Error           string
	FetchError      string
	CacheAgeSeconds int
	Policy          *schema.ApmPolicy
}

// PolicyLogger is the minimal interface expected of a logger for routing.
type PolicyLogger interface {
	PolicyResolved(source string, cached bool, enforcement string, ageSeconds int)
	PolicyDiscoveryMiss(outcome string, source string, err string)
}

// outcomesHonoringFetchFailureDefault is the set of outcomes that respect the
// project-side policy.fetch_failure_default knob.
var outcomesHonoringFetchFailureDefault = map[string]bool{
	"malformed":             true,
	"cache_miss_fetch_fail": true,
	"garbage_response":      true,
	"no_git_remote":         true,
	"absent":                true,
	"empty":                 true,
}

// nonFoundLoggedOutcomes is the set of outcomes routed through the canonical
// policy_discovery_miss logger helper.
var nonFoundLoggedOutcomes = map[string]bool{
	"absent":                true,
	"no_git_remote":         true,
	"empty":                 true,
	"malformed":             true,
	"cache_miss_fetch_fail": true,
	"garbage_response":      true,
}

// RouteDiscoveryOutcome routes a PolicyFetchResult to logging and fail-closed
// decisions.
//
// Parameters:
//   - fetchResult: result of discover_policy_with_chain
//   - logger: logger implementing PolicyLogger (nil is tolerated)
//   - fetchFailureDefault: project-side policy.fetch_failure_default ("warn" or "block")
//   - raiseBlockingErrors: when true, return a PolicyViolationError for blocking outcomes
//
// Returns the effective ApmPolicy when enforcement should proceed, nil otherwise.
// When raiseBlockingErrors is true and a blocking condition is met, a non-nil error
// is returned alongside a nil policy.
func RouteDiscoveryOutcome(
	fetchResult PolicyFetchResult,
	logger PolicyLogger,
	fetchFailureDefault string,
	raiseBlockingErrors bool,
) (*schema.ApmPolicy, error) {
	outcome := fetchResult.Outcome
	source := fetchResult.Source

	// disabled: policy enforcement is off; nothing to log or enforce.
	if outcome == "disabled" {
		return nil, nil
	}

	// hash_mismatch: ALWAYS fail closed regardless of fetch_failure_default.
	if outcome == "hash_mismatch" {
		// Error takes precedence; FetchError is the fallback message.
		errStr := fetchResult.Error
		if errStr == "" {
			errStr = fetchResult.FetchError
		}
		if logger != nil {
			logger.PolicyDiscoveryMiss("hash_mismatch", source, errStr)
		}
		if raiseBlockingErrors {
			return nil, &PolicyViolationError{
				Message: fmt.Sprintf(
					"Install blocked: policy hash mismatch -- pinned policy.hash "+
						"does not match fetched policy bytes (source=%s). "+
						"Update apm.yml policy.hash or contact your org admin.",
					sourceOrUnknown(source),
				),
				PolicySource: sourceOrUnknown(source),
			}
		}
		return nil, nil
	}

	// 6 of 9 non-found outcomes route through the canonical logger helper.
	if nonFoundLoggedOutcomes[outcome] {
		errStr := fetchResult.Error
		if errStr == "" {
			errStr = fetchResult.FetchError
		}
		if logger != nil {
			logger.PolicyDiscoveryMiss(outcome, source, errStr)
		}
		// Fail closed only when the outcome honors the project knob AND the
		// project opted into blocking behaviour.
		if raiseBlockingErrors &&
			outcomesHonoringFetchFailureDefault[outcome] &&
			fetchFailureDefault == "block" {
			return nil, &PolicyViolationError{
				Message: fmt.Sprintf(
					"Install blocked: no enforceable org policy was resolved "+
						"(outcome=%s) and project apm.yml has "+
						"policy.fetch_failure_default=block (source=%s)",
					outcome,
					sourceOrUnknown(source),
				),
				PolicySource: sourceOrUnknown(source),
			}
		}
		return nil, nil
	}

	// cached_stale: log, enforce with the cached policy, potentially fail closed.
	if outcome == "cached_stale" {
		policy := fetchResult.Policy
		if logger != nil {
			if policy != nil {
				// Unset enforcement defaults to "warn".
				enforcement := policy.Enforcement
				if enforcement == "" {
					enforcement = "warn"
				}
				logger.PolicyResolved(source, true, enforcement, fetchResult.CacheAgeSeconds)
			}
			logger.PolicyDiscoveryMiss("cached_stale", source, fetchResult.FetchError)
		}
		// Here the CACHED POLICY's own fetch_failure setting decides, not
		// the project-side fetchFailureDefault.
		if raiseBlockingErrors && policy != nil {
			ff := policy.FetchFailure
			if ff == "" {
				ff = "warn"
			}
			if ff == "block" {
				return nil, &PolicyViolationError{
					Message: fmt.Sprintf(
						"Install blocked: org policy refresh failed and the cached "+
							"policy declares fetch_failure=block (source=%s)",
						sourceOrUnknown(source),
					),
					PolicySource: sourceOrUnknown(source),
				}
			}
		}
		return policy, nil
	}

	// found: normal path
	if outcome == "found" {
		policy := fetchResult.Policy
		if logger != nil && policy != nil {
			enforcement := policy.Enforcement
			if enforcement == "" {
				enforcement = "warn"
			}
			logger.PolicyResolved(source, fetchResult.Cached, enforcement, fetchResult.CacheAgeSeconds)
		}
		return policy, nil
	}

	// Defensive: unrecognised outcome -- skip enforcement.
+ return nil, nil +} + +func sourceOrUnknown(s string) string { + if s == "" { + return "unknown" + } + return s +} diff --git a/internal/policy/schema/schema.go b/internal/policy/schema/schema.go index f8016dd..556734e 100644 --- a/internal/policy/schema/schema.go +++ b/internal/policy/schema/schema.go @@ -54,6 +54,8 @@ Cache PolicyCache Deps DependencyPolicy MCP McpPolicy Compilation CompilationPolicy +Enforcement string // warn | block | off (default: warn) +FetchFailure string // warn | block (default: warn) } // DefaultDependencyPolicy returns a DependencyPolicy with sensible defaults. diff --git a/internal/primitives/discovery/discovery.go b/internal/primitives/discovery/discovery.go new file mode 100644 index 0000000..30211e4 --- /dev/null +++ b/internal/primitives/discovery/discovery.go @@ -0,0 +1,582 @@ +// Package discovery provides functionality for discovering APM primitive files. +// Migrated from src/apm_cli/primitives/discovery.py. +package discovery + +import ( + "fmt" + "os" + "path/filepath" + "sort" + "strings" + + "github.com/githubnext/apm/internal/constants" + "github.com/githubnext/apm/internal/primitives/primmodels" + "github.com/githubnext/apm/internal/primitives/primparser" + "github.com/githubnext/apm/internal/utils/exclude" + "github.com/githubnext/apm/internal/utils/paths" +) + +// PrimitiveConflict records when two primitives compete for the same name. +type PrimitiveConflict struct { + PrimitiveName string + PrimitiveType string + WinningSource string + LosingSource string + FilePath string +} + +// PrimitiveCollection holds all discovered primitives. 
+type PrimitiveCollection struct { + Chatmodes []*primmodels.Chatmode + Instructions []*primmodels.Instruction + Contexts []*primmodels.Context + Skills []*primmodels.Skill + Conflicts []PrimitiveConflict + + chatmodeIndex map[string]int + instructionIndex map[string]int + contextIndex map[string]int + skillIndex map[string]int +} + +// NewPrimitiveCollection creates an initialized PrimitiveCollection. +func NewPrimitiveCollection() *PrimitiveCollection { + return &PrimitiveCollection{ + chatmodeIndex: make(map[string]int), + instructionIndex: make(map[string]int), + contextIndex: make(map[string]int), + skillIndex: make(map[string]int), + } +} + +// AddPrimitive adds a primitive to the collection with conflict detection. +func (c *PrimitiveCollection) AddPrimitive(p primmodels.Primitive) error { + switch v := p.(type) { + case *primmodels.Chatmode: + c.addChatmode(v) + case *primmodels.Instruction: + c.addInstruction(v) + case *primmodels.Context: + c.addContext(v) + case *primmodels.Skill: + c.addSkill(v) + default: + return fmt.Errorf("unknown primitive type: %T", p) + } + return nil +} + +func (c *PrimitiveCollection) addChatmode(p *primmodels.Chatmode) { + if idx, exists := c.chatmodeIndex[p.Name]; exists { + existing := c.Chatmodes[idx] + if shouldReplace(existing.Source, p.Source) { + c.Conflicts = append(c.Conflicts, PrimitiveConflict{ + PrimitiveName: p.Name, PrimitiveType: "chatmode", + WinningSource: p.Source, LosingSource: existing.Source, + FilePath: p.FilePath, + }) + c.Chatmodes[idx] = p + } else { + c.Conflicts = append(c.Conflicts, PrimitiveConflict{ + PrimitiveName: p.Name, PrimitiveType: "chatmode", + WinningSource: existing.Source, LosingSource: p.Source, + FilePath: existing.FilePath, + }) + } + return + } + c.chatmodeIndex[p.Name] = len(c.Chatmodes) + c.Chatmodes = append(c.Chatmodes, p) +} + +func (c *PrimitiveCollection) addInstruction(p *primmodels.Instruction) { + if idx, exists := c.instructionIndex[p.Name]; exists { + existing := 
c.Instructions[idx] + if shouldReplace(existing.Source, p.Source) { + c.Conflicts = append(c.Conflicts, PrimitiveConflict{ + PrimitiveName: p.Name, PrimitiveType: "instruction", + WinningSource: p.Source, LosingSource: existing.Source, + FilePath: p.FilePath, + }) + c.Instructions[idx] = p + } else { + c.Conflicts = append(c.Conflicts, PrimitiveConflict{ + PrimitiveName: p.Name, PrimitiveType: "instruction", + WinningSource: existing.Source, LosingSource: p.Source, + FilePath: existing.FilePath, + }) + } + return + } + c.instructionIndex[p.Name] = len(c.Instructions) + c.Instructions = append(c.Instructions, p) +} + +func (c *PrimitiveCollection) addContext(p *primmodels.Context) { + if idx, exists := c.contextIndex[p.Name]; exists { + existing := c.Contexts[idx] + if shouldReplace(existing.Source, p.Source) { + c.Conflicts = append(c.Conflicts, PrimitiveConflict{ + PrimitiveName: p.Name, PrimitiveType: "context", + WinningSource: p.Source, LosingSource: existing.Source, + FilePath: p.FilePath, + }) + c.Contexts[idx] = p + } else { + c.Conflicts = append(c.Conflicts, PrimitiveConflict{ + PrimitiveName: p.Name, PrimitiveType: "context", + WinningSource: existing.Source, LosingSource: p.Source, + FilePath: existing.FilePath, + }) + } + return + } + c.contextIndex[p.Name] = len(c.Contexts) + c.Contexts = append(c.Contexts, p) +} + +func (c *PrimitiveCollection) addSkill(p *primmodels.Skill) { + if idx, exists := c.skillIndex[p.Name]; exists { + existing := c.Skills[idx] + if shouldReplace(existing.Source, p.Source) { + c.Conflicts = append(c.Conflicts, PrimitiveConflict{ + PrimitiveName: p.Name, PrimitiveType: "skill", + WinningSource: p.Source, LosingSource: existing.Source, + FilePath: p.FilePath, + }) + c.Skills[idx] = p + } else { + c.Conflicts = append(c.Conflicts, PrimitiveConflict{ + PrimitiveName: p.Name, PrimitiveType: "skill", + WinningSource: existing.Source, LosingSource: p.Source, + FilePath: existing.FilePath, + }) + } + return + } + c.skillIndex[p.Name] 
= len(c.Skills) + c.Skills = append(c.Skills, p) +} + +// shouldReplace returns true when newSource should replace existingSource. +// Local always wins over dependency; earlier dependency wins over later. +func shouldReplace(existingSource, newSource string) bool { + existingLocal := existingSource == "local" || existingSource == "" + newLocal := newSource == "local" || newSource == "" + if newLocal && !existingLocal { + return true + } + return false +} + +// Local primitive glob patterns (with recursive search via **/). +var localPrimitivePatterns = map[string][]string{ + "chatmode": { + "**/.apm/agents/*.agent.md", + "**/.github/agents/*.agent.md", + "**/*.agent.md", + "**/.apm/chatmodes/*.chatmode.md", + "**/.github/chatmodes/*.chatmode.md", + "**/*.chatmode.md", + }, + "instruction": { + "**/.apm/instructions/*.instructions.md", + "**/.github/instructions/*.instructions.md", + "**/*.instructions.md", + }, + "context": { + "**/.apm/context/*.context.md", + "**/.apm/memory/*.memory.md", + "**/.github/context/*.context.md", + "**/.github/memory/*.memory.md", + "**/*.context.md", + "**/*.memory.md", + }, +} + +// Dependency primitive patterns (for .apm directory within dependencies). +var dependencyPrimitivePatterns = map[string][]string{ + "chatmode": {"agents/*.agent.md", "chatmodes/*.chatmode.md"}, + "instruction": {"instructions/*.instructions.md"}, + "context": {"context/*.context.md", "memory/*.memory.md"}, +} + +// Dependency .github primitive patterns. +var dependencyGithubPrimitivePatterns = map[string][]string{ + "chatmode": {"agents/*.agent.md", "chatmodes/*.chatmode.md"}, + "instruction": {"instructions/*.instructions.md"}, + "context": {"context/*.context.md", "memory/*.memory.md"}, +} + +// DiscoverPrimitives finds all APM primitive files in the project. 
func DiscoverPrimitives(baseDir string, excludePatterns []string) (*PrimitiveCollection, error) {
	collection := NewPrimitiveCollection()
	// Invalid exclude patterns are silently dropped (validation error discarded).
	safePatterns, _ := exclude.ValidateExcludePatterns(excludePatterns)

	// NOTE(review): ranging over a map visits primitive kinds in random
	// order. Each kind keeps its own name index, so this looks harmless,
	// but confirm ordering never affects conflict resolution. Also unlike
	// scanLocalPrimitives, this path does not skip files under
	// apm_modules/ -- confirm the **/ patterns cannot match installed
	// dependencies here.
	for _, ptPatterns := range localPrimitivePatterns {
		files, err := FindPrimitiveFiles(baseDir, ptPatterns, safePatterns)
		if err != nil {
			// Best-effort: a failed glob for one kind does not abort discovery.
			continue
		}
		for _, fp := range files {
			prim, err := primparser.ParsePrimitiveFile(fp, "local")
			if err != nil {
				fmt.Printf("Warning: Failed to parse %s: %v\n", fp, err)
				continue
			}
			collection.AddPrimitive(prim) //nolint:errcheck
		}
	}
	discoverLocalSkill(baseDir, collection, safePatterns)
	return collection, nil
}

// DiscoverPrimitivesWithDependencies performs enhanced discovery including dependencies.
func DiscoverPrimitivesWithDependencies(baseDir string, excludePatterns []string) (*PrimitiveCollection, error) {
	collection := NewPrimitiveCollection()
	safePatterns, _ := exclude.ValidateExcludePatterns(excludePatterns)

	// Local workspace first, then the root SKILL.md, then dependencies in
	// declaration order.
	scanLocalPrimitives(baseDir, collection, safePatterns)
	discoverLocalSkill(baseDir, collection, safePatterns)
	scanDependencyPrimitives(baseDir, collection)
	return collection, nil
}

// scanLocalPrimitives scans the local .apm/ directory for primitives.
+func scanLocalPrimitives(baseDir string, collection *PrimitiveCollection, excludePatterns []string) { + for _, ptPatterns := range localPrimitivePatterns { + files, err := FindPrimitiveFiles(baseDir, ptPatterns, excludePatterns) + if err != nil { + continue + } + basePath, _ := filepath.Abs(baseDir) + apmModulesPath := filepath.Join(basePath, "apm_modules") + for _, fp := range files { + absFile, _ := filepath.Abs(fp) + if isUnderDirectory(absFile, apmModulesPath) { + continue + } + prim, err := primparser.ParsePrimitiveFile(fp, "local") + if err != nil { + fmt.Printf("Warning: Failed to parse local primitive %s: %v\n", fp, err) + continue + } + collection.AddPrimitive(prim) //nolint:errcheck + } + } +} + +// scanDependencyPrimitives scans all dependencies in apm_modules/ with priority handling. +func scanDependencyPrimitives(baseDir string, collection *PrimitiveCollection) { + apmModulesPath := filepath.Join(baseDir, "apm_modules") + info, err := os.Stat(apmModulesPath) + if err != nil || !info.IsDir() { + return + } + depOrder := getDependencyDeclarationOrder(baseDir) + for _, depName := range depOrder { + parts := strings.Split(depName, "/") + depPath := filepath.Join(append([]string{apmModulesPath}, parts...)...) + info, err := os.Stat(depPath) + if err == nil && info.IsDir() { + ScanDirectoryWithSource(depPath, collection, "dependency:"+depName) + } + } +} + +// getDependencyDeclarationOrder returns dependency installed paths in declaration order. +// Simplified: reads lockfile paths only (apm.yml parsing would need more infra). 
+func getDependencyDeclarationOrder(baseDir string) []string { + // Fallback: return directories from apm_modules sorted alphabetically + apmModulesPath := filepath.Join(baseDir, "apm_modules") + entries, err := os.ReadDir(apmModulesPath) + if err != nil { + return nil + } + var names []string + for _, e := range entries { + if e.IsDir() { + // Try two-level paths (owner/repo) + subEntries, err := os.ReadDir(filepath.Join(apmModulesPath, e.Name())) + if err != nil { + names = append(names, e.Name()) + continue + } + for _, se := range subEntries { + if se.IsDir() { + names = append(names, e.Name()+"/"+se.Name()) + } + } + } + } + return names +} + +// ScanDirectoryWithSource scans a directory for primitives with a specific source tag. +func ScanDirectoryWithSource(directory string, collection *PrimitiveCollection, source string) { + apmDir := filepath.Join(directory, ".apm") + if info, err := os.Stat(apmDir); err == nil && info.IsDir() { + scanPatterns(apmDir, dependencyPrimitivePatterns, collection, source) + } + githubDir := filepath.Join(directory, ".github") + if info, err := os.Stat(githubDir); err == nil && info.IsDir() { + scanPatterns(githubDir, dependencyGithubPrimitivePatterns, collection, source) + } + discoverSkillInDirectory(directory, collection, source) +} + +func discoverLocalSkill(baseDir string, collection *PrimitiveCollection, excludePatterns []string) { + skillPath := filepath.Join(baseDir, "SKILL.md") + info, err := os.Stat(skillPath) + if err != nil || !info.Mode().IsRegular() { + return + } + absBase, _ := filepath.Abs(baseDir) + absSkill, _ := filepath.Abs(skillPath) + if exclude.ShouldExclude(absSkill, absBase, excludePatterns) { + return + } + if !isReadable(skillPath) { + return + } + skill, err := primparser.ParseSkillFile(skillPath, "local") + if err != nil { + fmt.Printf("Warning: Failed to parse SKILL.md: %v\n", err) + return + } + collection.AddPrimitive(skill) //nolint:errcheck +} + +func discoverSkillInDirectory(directory string, 
collection *PrimitiveCollection, source string) { + skillPath := filepath.Join(directory, "SKILL.md") + if !isReadable(skillPath) { + return + } + skill, err := primparser.ParseSkillFile(skillPath, source) + if err != nil { + fmt.Printf("Warning: Failed to parse SKILL.md in %s: %v\n", directory, err) + return + } + collection.AddPrimitive(skill) //nolint:errcheck +} + +// scanPatterns walks baseDir once and matches files against all patterns. +func scanPatterns(baseDir string, patterns map[string][]string, collection *PrimitiveCollection, source string) { + info, err := os.Stat(baseDir) + if err != nil || !info.IsDir() { + return + } + // Flatten all patterns + var allPatterns []string + for _, ps := range patterns { + allPatterns = append(allPatterns, ps...) + } + + err = filepath.WalkDir(baseDir, func(fp string, d os.DirEntry, err error) error { + if err != nil { + return nil + } + if d.IsDir() { + return nil + } + rel, err := filepath.Rel(baseDir, fp) + if err != nil { + return nil + } + relFwd := strings.ReplaceAll(rel, string(filepath.Separator), "/") + if !matchesAnyPattern(relFwd, allPatterns) { + return nil + } + if !d.Type().IsRegular() { + return nil + } + if !isReadable(fp) { + return nil + } + prim, err := primparser.ParsePrimitiveFile(fp, source) + if err != nil { + fmt.Printf("Warning: Failed to parse dependency primitive %s: %v\n", fp, err) + return nil + } + collection.AddPrimitive(prim) //nolint:errcheck + return nil + }) + _ = err +} + +// FindPrimitiveFiles finds primitive files matching the given patterns. 
+func FindPrimitiveFiles(baseDir string, patterns []string, excludePatterns []string) ([]string, error) { + info, err := os.Stat(baseDir) + if err != nil || !info.IsDir() { + return nil, nil + } + basePath, err := filepath.Abs(baseDir) + if err != nil { + return nil, err + } + + var allFiles []string + + err = filepath.WalkDir(basePath, func(fp string, d os.DirEntry, err error) error { + if err != nil { + return nil + } + name := d.Name() + if d.IsDir() { + if _, skip := constants.DefaultSkipDirs[name]; skip { + return filepath.SkipDir + } + if exclude.ShouldExclude(fp, basePath, excludePatterns) { + return filepath.SkipDir + } + return nil + } + // Sort within directory is handled by WalkDir (lexical order already) + if d.Type()&os.ModeSymlink != 0 { + return nil + } + if exclude.ShouldExclude(fp, basePath, excludePatterns) { + return nil + } + rel := paths.PortableRelpath(fp, basePath) + for _, pat := range patterns { + if globMatch(rel, pat) { + allFiles = append(allFiles, fp) + break + } + } + return nil + }) + if err != nil { + return nil, err + } + + // Filter invalid + valid := make([]string, 0, len(allFiles)) + for _, fp := range allFiles { + fi, err := os.Lstat(fp) + if err != nil { + continue + } + if !fi.Mode().IsRegular() { + continue + } + if fi.Mode()&os.ModeSymlink != 0 { + continue + } + if isReadable(fp) { + valid = append(valid, fp) + } + } + sort.Strings(valid) + return valid, nil +} + +// globMatch matches a forward-slash relative path against a glob pattern. +// Segment-aware: ** matches zero or more complete path segments. 
+func globMatch(relPath, pattern string) bool { + pathParts := splitNonEmpty(relPath, "/") + patternParts := splitNonEmpty(pattern, "/") + memo := make(map[[2]int]bool) + var match func(pi, qi int) bool + match = func(pi, qi int) bool { + key := [2]int{pi, qi} + if v, ok := memo[key]; ok { + return v + } + if qi == len(patternParts) { + result := pi == len(pathParts) + memo[key] = result + return result + } + cur := patternParts[qi] + if cur == "**" { + result := match(pi, qi+1) + if !result && pi < len(pathParts) { + result = match(pi+1, qi) + } + memo[key] = result + return result + } + if pi >= len(pathParts) { + memo[key] = false + return false + } + result := fnmatchSegment(pathParts[pi], cur) && match(pi+1, qi+1) + memo[key] = result + return result + } + return match(0, 0) +} + +// fnmatchSegment matches a single path segment against a pattern. +// Supports * (any chars within segment) and ? (single char). +func fnmatchSegment(name, pattern string) bool { + for len(pattern) > 0 { + switch pattern[0] { + case '*': + if len(pattern) == 1 { + return true + } + rest := pattern[1:] + for i := 0; i <= len(name); i++ { + if fnmatchSegment(name[i:], rest) { + return true + } + } + return false + case '?': + if len(name) == 0 { + return false + } + name = name[1:] + pattern = pattern[1:] + default: + if len(name) == 0 || name[0] != pattern[0] { + return false + } + name = name[1:] + pattern = pattern[1:] + } + } + return len(name) == 0 +} + +func matchesAnyPattern(relPath string, patterns []string) bool { + for _, p := range patterns { + if globMatch(relPath, p) { + return true + } + } + return false +} + +func isUnderDirectory(filePath, directory string) bool { + rel, err := filepath.Rel(directory, filePath) + if err != nil { + return false + } + return !strings.HasPrefix(rel, "..") +} + +func isReadable(fp string) bool { + f, err := os.Open(fp) + if err != nil { + return false + } + buf := make([]byte, 1) + _, err = f.Read(buf) + f.Close() + return err == nil +} + 
+func splitNonEmpty(s, sep string) []string { + parts := strings.Split(s, sep) + result := make([]string, 0, len(parts)) + for _, p := range parts { + if p != "" { + result = append(result, p) + } + } + return result +} diff --git a/internal/primitives/primmodels/primmodels.go b/internal/primitives/primmodels/primmodels.go index 3469c28..1aeb2b1 100644 --- a/internal/primitives/primmodels/primmodels.go +++ b/internal/primitives/primmodels/primmodels.go @@ -1,6 +1,11 @@ // Package primmodels defines data models for APM primitives. package primmodels +// Primitive is the common interface for all APM primitive types. +type Primitive interface { + Validate() []string +} + // Chatmode represents a chatmode primitive. type Chatmode struct { Name string @@ -60,6 +65,14 @@ Version string Source string } +// Validate returns validation errors for a Context. +func (c *Context) Validate() []string { +if c.Content == "" { +return []string{"Empty content"} +} +return nil +} + // Skill represents a skill primitive. type Skill struct { Name string @@ -72,6 +85,11 @@ Version string Source string } +// Validate returns validation errors for a Skill. +func (s *Skill) Validate() []string { +return nil +} + // Agent represents an agent primitive. type Agent struct { Name string diff --git a/internal/primitives/primparser/primparser.go b/internal/primitives/primparser/primparser.go new file mode 100644 index 0000000..f33eb17 --- /dev/null +++ b/internal/primitives/primparser/primparser.go @@ -0,0 +1,210 @@ +// Package primparser parses APM primitive definition files. +// Migrated from src/apm_cli/primitives/parser.py. +package primparser + +import ( + "bufio" + "fmt" + "os" + "path/filepath" + "strings" + + "github.com/githubnext/apm/internal/primitives/primmodels" +) + +// ParseSkillFile parses a SKILL.md file and returns a Skill primitive. +// source is an optional identifier like "local" or "dependency:pkg". 
+func ParseSkillFile(filePath string, source string) (*primmodels.Skill, error) { + meta, content, err := parseFrontmatter(filePath) + if err != nil { + return nil, fmt.Errorf("failed to parse SKILL.md file %s: %w", filePath, err) + } + + name := meta["name"] + if name == "" { + // Derive from parent directory name. + name = filepath.Base(filepath.Dir(filePath)) + } + + return &primmodels.Skill{ + Name: name, + FilePath: filePath, + Description: meta["description"], + Content: content, + Source: source, + }, nil +} + +// ParsePrimitiveFile parses a primitive file (.chatmode.md, .instructions.md, +// .context.md, .memory.md) and returns the appropriate Primitive. +func ParsePrimitiveFile(filePath string, source string) (primmodels.Primitive, error) { + meta, content, err := parseFrontmatter(filePath) + if err != nil { + return nil, fmt.Errorf("failed to parse primitive file %s: %w", filePath, err) + } + + name := extractPrimitiveName(filePath) + base := filepath.Base(filePath) + + switch { + case strings.HasSuffix(base, ".chatmode.md") || strings.HasSuffix(base, ".agent.md"): + return parseChatmode(name, filePath, meta, content, source), nil + case strings.HasSuffix(base, ".instructions.md"): + return parseInstruction(name, filePath, meta, content, source), nil + case strings.HasSuffix(base, ".context.md") || strings.HasSuffix(base, ".memory.md") || isContextFile(filePath): + return parseContext(name, filePath, meta, content, source), nil + default: + return nil, fmt.Errorf("unknown primitive file type: %s", filePath) + } +} + +// ValidatePrimitive returns a list of validation errors for the primitive. 
+func ValidatePrimitive(p primmodels.Primitive) []string { + return p.Validate() +} + +func parseChatmode(name, filePath string, meta map[string]string, content, source string) *primmodels.Chatmode { + return &primmodels.Chatmode{ + Name: name, + FilePath: filePath, + Description: meta["description"], + ApplyTo: meta["applyTo"], + Content: content, + Author: meta["author"], + Version: meta["version"], + Source: source, + } +} + +func parseInstruction(name, filePath string, meta map[string]string, content, source string) *primmodels.Instruction { + return &primmodels.Instruction{ + Name: name, + FilePath: filePath, + Description: meta["description"], + ApplyTo: meta["applyTo"], + Content: content, + Author: meta["author"], + Version: meta["version"], + Source: source, + } +} + +func parseContext(name, filePath string, meta map[string]string, content, source string) *primmodels.Context { + return &primmodels.Context{ + Name: name, + FilePath: filePath, + Content: content, + Description: meta["description"], + Author: meta["author"], + Version: meta["version"], + Source: source, + } +} + +// extractPrimitiveName derives the primitive name from the file path following +// APM naming conventions. 
func extractPrimitiveName(filePath string) string {
	// The previous implementation resolved the absolute path and scanned it
	// for ".apm/<subdir>/" or ".github/<subdir>/" components, but both the
	// structured-directory branch and the fallback returned the identical
	// value — stripPrimExt(filepath.Base(filePath)) — so the scan was dead
	// code. The name is simply the basename minus its primitive extension.
	return stripPrimExt(filepath.Base(filePath))
}

// stripPrimExt removes a known primitive suffix (.chatmode.md, .instructions.md,
// .context.md, .memory.md, .agent.md), falling back to stripping a plain .md
// or, failing that, whatever extension the file has.
func stripPrimExt(basename string) string {
	suffixes := []string{
		".chatmode.md", ".instructions.md", ".context.md",
		".memory.md", ".agent.md",
	}
	for _, s := range suffixes {
		if strings.HasSuffix(basename, s) {
			return strings.TrimSuffix(basename, s)
		}
	}
	if strings.HasSuffix(basename, ".md") {
		return strings.TrimSuffix(basename, ".md")
	}
	ext := filepath.Ext(basename)
	return strings.TrimSuffix(basename, ext)
}

// isContextFile returns true for files directly under .apm/memory/ or .github/memory/.
func isContextFile(filePath string) bool {
	dir := filepath.Base(filepath.Dir(filePath))
	parent := filepath.Base(filepath.Dir(filepath.Dir(filePath)))
	if dir != "memory" {
		return false
	}
	return parent == ".apm" || parent == ".github"
}

// parseFrontmatter reads a file and splits YAML frontmatter (--- ... ---) from
// the body. Returns the parsed key/value pairs and the body content.
// Only flat key: value pairs are supported (no nesting or lists).
+func parseFrontmatter(filePath string) (map[string]string, string, error) { + f, err := os.Open(filePath) // #nosec G304 + if err != nil { + return nil, "", err + } + defer f.Close() + + scanner := bufio.NewScanner(f) + var lines []string + for scanner.Scan() { + lines = append(lines, scanner.Text()) + } + if err := scanner.Err(); err != nil { + return nil, "", err + } + + meta := map[string]string{} + if len(lines) == 0 { + return meta, "", nil + } + + // Check for leading frontmatter delimiter. + if strings.TrimSpace(lines[0]) != "---" { + return meta, strings.Join(lines, "\n"), nil + } + + // Find closing delimiter. + end := -1 + for i := 1; i < len(lines); i++ { + if strings.TrimSpace(lines[i]) == "---" { + end = i + break + } + } + if end == -1 { + // No closing delimiter -- treat entire file as content. + return meta, strings.Join(lines, "\n"), nil + } + + // Parse frontmatter block. + for _, line := range lines[1:end] { + idx := strings.Index(line, ":") + if idx < 0 { + continue + } + key := strings.TrimSpace(line[:idx]) + val := strings.TrimSpace(line[idx+1:]) + // Strip surrounding quotes. 
+ if len(val) >= 2 && ((val[0] == '"' && val[len(val)-1] == '"') || (val[0] == '\'' && val[len(val)-1] == '\'')) { + val = val[1 : len(val)-1] + } + meta[key] = val + } + + content := strings.Join(lines[end+1:], "\n") + return meta, strings.TrimLeft(content, "\n"), nil +} diff --git a/internal/primitives/primparser/primparser_test.go b/internal/primitives/primparser/primparser_test.go new file mode 100644 index 0000000..89a12a8 --- /dev/null +++ b/internal/primitives/primparser/primparser_test.go @@ -0,0 +1,92 @@ +package primparser_test + +import ( + "os" + "path/filepath" + "testing" + + "github.com/githubnext/apm/internal/primitives/primparser" +) + +func writeTmp(t *testing.T, content string) string { + t.Helper() + f, err := os.CreateTemp(t.TempDir(), "*.md") + if err != nil { + t.Fatal(err) + } + if _, err := f.WriteString(content); err != nil { + t.Fatal(err) + } + f.Close() + return f.Name() +} + +func TestParseFrontmatterNoFM(t *testing.T) { + path := writeTmp(t, "just content\nno frontmatter\n") + // Rename to .instructions.md so ParsePrimitiveFile picks it up. 
+ newPath := filepath.Join(filepath.Dir(path), "foo.instructions.md") + os.Rename(path, newPath) + prim, err := primparser.ParsePrimitiveFile(newPath, "local") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if prim == nil { + t.Fatal("expected non-nil primitive") + } +} + +func TestParseFrontmatterWithFM(t *testing.T) { + content := "---\nname: TestSkill\ndescription: A test skill\n---\n# Body\n" + dir := t.TempDir() + path := filepath.Join(dir, "SKILL.md") + os.WriteFile(path, []byte(content), 0o644) + skill, err := primparser.ParseSkillFile(path, "local") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if skill.Name != "TestSkill" { + t.Errorf("expected name 'TestSkill', got %q", skill.Name) + } + if skill.Description != "A test skill" { + t.Errorf("expected description 'A test skill', got %q", skill.Description) + } + if !contains(skill.Content, "# Body") { + t.Errorf("expected content to contain '# Body', got %q", skill.Content) + } +} + +func TestParseChatmode(t *testing.T) { + content := "---\ndescription: My chatmode\napplyTo: '**'\n---\nChatmode body\n" + dir := t.TempDir() + path := filepath.Join(dir, "test.chatmode.md") + os.WriteFile(path, []byte(content), 0o644) + prim, err := primparser.ParsePrimitiveFile(path, "dep:pkg") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + errs := prim.Validate() + if len(errs) != 0 { + t.Errorf("expected no validation errors, got %v", errs) + } +} + +func TestParseUnknownType(t *testing.T) { + path := writeTmp(t, "content") + _, err := primparser.ParsePrimitiveFile(path, "local") + if err == nil { + t.Fatal("expected error for unknown primitive type") + } +} + +func contains(s, sub string) bool { + return len(s) >= len(sub) && (s == sub || len(s) > 0 && containsStr(s, sub)) +} + +func containsStr(s, sub string) bool { + for i := 0; i <= len(s)-len(sub); i++ { + if s[i:i+len(sub)] == sub { + return true + } + } + return false +} diff --git 
a/internal/runtime/codexruntime/codexruntime.go b/internal/runtime/codexruntime/codexruntime.go new file mode 100644 index 0000000..7b61e57 --- /dev/null +++ b/internal/runtime/codexruntime/codexruntime.go @@ -0,0 +1,121 @@ +// Package codexruntime provides the Codex CLI runtime adapter for APM. +// Migrated from src/apm_cli/runtime/codex_runtime.py +package codexruntime + +import ( + "errors" + "os/exec" + "strings" + "time" +) + +// installCmd is the install instruction shown when codex is missing. +const installCmd = "npm i -g @openai/codex@native" + +// CodexRuntime is the APM adapter for the Codex CLI. +type CodexRuntime struct { + ModelName string +} + +// IsAvailable returns true when the codex binary is on PATH. +func IsAvailable() bool { + _, err := exec.LookPath("codex") + return err == nil +} + +// GetRuntimeName returns "codex". +func (r *CodexRuntime) GetRuntimeName() string { return "codex" } + +// New creates a CodexRuntime. +// Returns an error when the codex binary is not available. +func New(modelName string) (*CodexRuntime, error) { + if !IsAvailable() { + return nil, errors.New("Codex CLI not available. Install with: " + installCmd) + } + if modelName == "" { + modelName = "default" + } + return &CodexRuntime{ModelName: modelName}, nil +} + +// NewDefault creates a CodexRuntime with the default model. +func NewDefault() (*CodexRuntime, error) { return New("") } + +// ExecutePrompt runs the given prompt through codex exec with real-time streaming. +// Times out after 5 minutes. +func (r *CodexRuntime) ExecutePrompt(prompt string) (string, error) { + cmd := exec.Command("codex", "exec", "--skip-git-repo-check", prompt) + + out, err := runWithTimeout(cmd, 5*time.Minute) + if err != nil { + if strings.Contains(out, "OPENAI_API_KEY") { + return "", errors.New("Codex execution failed: Missing or invalid OPENAI_API_KEY. 
Please set your OpenAI API key.") + } + return "", err + } + return strings.TrimSpace(out), nil +} + +// ListAvailableModels returns a static map of available Codex models. +// Codex does not expose model listing via CLI. +func (r *CodexRuntime) ListAvailableModels() map[string]interface{} { + return map[string]interface{}{ + "codex-default": map[string]string{ + "id": "codex-default", + "provider": "codex", + "description": "Default Codex model (managed by Codex CLI)", + }, + } +} + +// GetRuntimeInfo returns metadata about this runtime adapter. +func (r *CodexRuntime) GetRuntimeInfo() map[string]interface{} { + version := "unknown" + if out, err := exec.Command("codex", "--version").Output(); err == nil { + version = strings.TrimSpace(string(out)) + } + return map[string]interface{}{ + "name": "codex", + "type": "codex_cli", + "version": version, + "capabilities": map[string]interface{}{ + "model_execution": true, + "mcp_servers": "native_support", + "configuration": "config.toml", + "sandboxing": "built_in", + }, + "description": "OpenAI Codex CLI runtime adapter", + } +} + +// String returns a human-readable representation. +func (r *CodexRuntime) String() string { + return "CodexRuntime(model=" + r.ModelName + ")" +} + +// runWithTimeout executes cmd, collecting all output, and returns it along with +// any error. The process is killed after timeout. +func runWithTimeout(cmd *exec.Cmd, timeout time.Duration) (string, error) { + var buf strings.Builder + cmd.Stdout = &buf + cmd.Stderr = &buf + + if err := cmd.Start(); err != nil { + return "", errors.New("Codex CLI not found. 
Install with: " + installCmd) + } + + done := make(chan error, 1) + go func() { done <- cmd.Wait() }() + + select { + case err := <-done: + output := buf.String() + if err != nil { + return output, errors.New("Codex execution failed: " + err.Error()) + } + return output, nil + case <-time.After(timeout): + cmd.Process.Kill() + return "", errors.New("Codex execution timed out after 5 minutes") + } +} diff --git a/internal/runtime/factory/factory.go b/internal/runtime/factory/factory.go new file mode 100644 index 0000000..3329f91 --- /dev/null +++ b/internal/runtime/factory/factory.go @@ -0,0 +1,121 @@ +// Package factory provides a factory for creating runtime adapters with auto-detection. +package factory + +import "fmt" + +// RuntimeInfo holds metadata about an available runtime. +type RuntimeInfo struct { +Name string +Available bool +Error string +} + +// RuntimeAdapter is the interface that all runtime adapters must implement. +type RuntimeAdapter interface { +GetRuntimeName() string +IsAvailable() bool +GetRuntimeInfo() RuntimeInfo +} + +// ConstructableAdapter extends RuntimeAdapter with constructors. +type ConstructableAdapter interface { +RuntimeAdapter +New(modelName string) (RuntimeAdapter, error) +NewDefault() (RuntimeAdapter, error) +} + +// Registry holds the ordered list of runtime adapter constructors. +type Registry struct { +adapters []ConstructableAdapter +} + +// NewRegistry creates a Registry with the given adapter constructors in preference order. +func NewRegistry(adapters ...ConstructableAdapter) *Registry { +return &Registry{adapters: adapters} +} + +// GetAvailableRuntimes returns metadata for all available runtimes. 
+func (r *Registry) GetAvailableRuntimes() []RuntimeInfo { +var out []RuntimeInfo +for _, a := range r.adapters { +if !a.IsAvailable() { +continue +} +info := a.GetRuntimeInfo() +info.Available = true +if info.Error != "" { +out = append(out, info) +continue +} +instance, err := a.NewDefault() +if err != nil { +out = append(out, RuntimeInfo{ +Name: a.GetRuntimeName(), +Available: true, +Error: fmt.Sprintf("Available but failed to initialize: %v", err), +}) +continue +} +info = instance.GetRuntimeInfo() +info.Available = true +out = append(out, info) +} +return out +} + +// GetRuntimeByName returns a runtime adapter by name. +// Returns an error if the runtime is not found or not available. +func (r *Registry) GetRuntimeByName(runtimeName, modelName string) (RuntimeAdapter, error) { +for _, a := range r.adapters { +if a.GetRuntimeName() != runtimeName { +continue +} +if !a.IsAvailable() { +return nil, fmt.Errorf("runtime %q is not available on this system", runtimeName) +} +if modelName != "" { +return a.New(modelName) +} +return a.NewDefault() +} +return nil, fmt.Errorf("unknown runtime: %s", runtimeName) +} + +// GetBestAvailableRuntime returns the first available runtime in preference order. +func (r *Registry) GetBestAvailableRuntime(modelName string) (RuntimeAdapter, error) { +for _, a := range r.adapters { +if !a.IsAvailable() { +continue +} +var ( +instance RuntimeAdapter +err error +) +if modelName != "" { +instance, err = a.New(modelName) +} else { +instance, err = a.NewDefault() +} +if err == nil { +return instance, nil +} +} +return nil, fmt.Errorf("no runtimes available; install at least one of: " + +"Copilot CLI (npm i -g @github/copilot), Codex CLI (npm i -g @openai/codex@native), " + +"or LLM library (pip install llm)") +} + +// CreateRuntime creates a runtime adapter with optional name and model. +// If runtimeName is empty, returns the best available runtime. 
+func (r *Registry) CreateRuntime(runtimeName, modelName string) (RuntimeAdapter, error) { +if runtimeName != "" { +return r.GetRuntimeByName(runtimeName, modelName) +} +return r.GetBestAvailableRuntime(modelName) +} + +// RuntimeExists checks if a runtime exists and is available. +func (r *Registry) RuntimeExists(runtimeName string) bool { +_, err := r.GetRuntimeByName(runtimeName, "") +return err == nil +} diff --git a/internal/runtime/llmruntime/llmruntime.go b/internal/runtime/llmruntime/llmruntime.go new file mode 100644 index 0000000..7afa7ab --- /dev/null +++ b/internal/runtime/llmruntime/llmruntime.go @@ -0,0 +1,95 @@ +// Package llmruntime provides the LLM CLI runtime adapter for APM. +// Migrated from src/apm_cli/runtime/llm_runtime.py +package llmruntime + +import ( + "errors" + "os/exec" + "strings" +) + +// LLMRuntime is the APM adapter for the llm CLI tool. +type LLMRuntime struct { + ModelName string +} + +// IsAvailable returns true when the llm binary is on PATH and responds to --version. +func IsAvailable() bool { + cmd := exec.Command("llm", "--version") + return cmd.Run() == nil +} + +// GetRuntimeName returns "llm". +func (r *LLMRuntime) GetRuntimeName() string { return "llm" } + +// New creates an LLMRuntime for the given model. +// Returns an error when the llm binary is not available. +func New(modelName string) (*LLMRuntime, error) { + if !IsAvailable() { + return nil, errors.New("llm CLI not found. Please install: pip install llm") + } + return &LLMRuntime{ModelName: modelName}, nil +} + +// NewDefault creates an LLMRuntime using the llm CLI default model. +func NewDefault() (*LLMRuntime, error) { return New("") } + +// ExecutePrompt runs the given prompt through the llm CLI and returns the response. +func (r *LLMRuntime) ExecutePrompt(prompt string) (string, error) { + args := []string{} + if r.ModelName != "" { + args = append(args, "-m", r.ModelName) + } + args = append(args, prompt) + + cmd := exec.Command("llm", args...) 
+ var buf strings.Builder + cmd.Stdout = &buf + cmd.Stderr = &buf + + if err := cmd.Run(); err != nil { + return "", errors.New("LLM execution failed: " + buf.String()) + } + return strings.TrimSpace(buf.String()), nil +} + +// ListAvailableModels returns a map of available models by querying `llm models list`. +func (r *LLMRuntime) ListAvailableModels() map[string]interface{} { + out, err := exec.Command("llm", "models", "list").Output() + if err != nil { + return map[string]interface{}{"error": "failed to list models: " + err.Error()} + } + models := map[string]interface{}{} + for _, line := range strings.Split(strings.TrimSpace(string(out)), "\n") { + line = strings.TrimSpace(line) + if line != "" { + models[line] = map[string]string{"id": line, "provider": "llm"} + } + } + return models +} + +// GetRuntimeInfo returns metadata about this runtime adapter. +func (r *LLMRuntime) GetRuntimeInfo() map[string]interface{} { + model := r.ModelName + if model == "" { + model = "default" + } + return map[string]interface{}{ + "name": "llm", + "type": "llm_library", + "current_model": model, + "capabilities": map[string]interface{}{ + "model_execution": true, + "mcp_servers": "runtime_dependent", + "configuration": "llm_commands", + "sandboxing": "runtime_dependent", + }, + "description": "LLM CLI runtime adapter", + } +} + +// String returns a human-readable representation. +func (r *LLMRuntime) String() string { + return "LLMRuntime(model=" + r.ModelName + ")" +} diff --git a/internal/security/auditreport/auditreport.go b/internal/security/auditreport/auditreport.go new file mode 100644 index 0000000..ae6d01b --- /dev/null +++ b/internal/security/auditreport/auditreport.go @@ -0,0 +1,323 @@ +// Package auditreport provides serialization helpers for apm audit results. +// Supports JSON, SARIF 2.1.0, and Markdown output formats. 
+// Migrated from src/apm_cli/security/audit_report.py +package auditreport + +import ( + "encoding/json" + "fmt" + "os" + "path/filepath" + "sort" + "strings" +) + +// ScanFinding represents a single security finding from a content scan. +type ScanFinding struct { + // Severity is "critical", "warning", or "info". + Severity string + // File is the path to the file containing the finding. + File string + // Line is the 1-based line number. + Line int + // Column is the 1-based column number. + Column int + // Codepoint is the Unicode codepoint string (e.g. "U+200B"). + Codepoint string + // Category classifies the finding type (e.g. "zero-width"). + Category string + // Description is a human-readable explanation. + Description string +} + +const ( + sarifVersion = "2.1.0" + sarifSchema = "https://docs.oasis-open.org/sarif/sarif/v2.1.0/cos02/schemas/sarif-schema-2.1.0.json" + toolName = "apm-audit" + toolInfoURI = "https://apm.github.io/apm/enterprise/security/" +) + +// severityMap maps APM severity strings to SARIF level strings. +var severityMap = map[string]string{ + "critical": "error", + "warning": "warning", + "info": "note", +} + +// RelativePathForReport normalizes a file path to a relative forward-slash path. +func RelativePathForReport(filePath string) string { + p := filepath.Clean(filePath) + if filepath.IsAbs(p) { + cwd, err := os.Getwd() + if err == nil { + rel, err2 := filepath.Rel(cwd, p) + if err2 == nil { + return filepath.ToSlash(rel) + } + } + return filepath.Base(p) + } + return strings.ReplaceAll(filePath, "\\", "/") +} + +// ruleID builds a SARIF rule ID from a finding category. +func ruleID(category string) string { + return "apm/hidden-unicode/" + category +} + +// allFindings flattens a map of findings by file into a single slice. +func allFindings(findingsByFile map[string][]ScanFinding) []ScanFinding { + var out []ScanFinding + for _, ff := range findingsByFile { + out = append(out, ff...) 
+ } + return out +} + +// FindingsToJSON converts scan findings to APM's JSON report format. +func FindingsToJSON(findingsByFile map[string][]ScanFinding, filesScanned int, exitCode int) map[string]interface{} { + all := allFindings(findingsByFile) + + critical, warning, info := 0, 0, 0 + for _, f := range all { + switch f.Severity { + case "critical": + critical++ + case "warning": + warning++ + case "info": + info++ + } + } + + items := make([]map[string]interface{}, 0, len(all)) + for _, f := range all { + items = append(items, map[string]interface{}{ + "severity": f.Severity, + "file": RelativePathForReport(f.File), + "line": f.Line, + "column": f.Column, + "codepoint": f.Codepoint, + "category": f.Category, + "description": f.Description, + }) + } + + return map[string]interface{}{ + "version": "1", + "exit_code": exitCode, + "summary": map[string]interface{}{ + "files_scanned": filesScanned, + "files_affected": len(findingsByFile), + "critical": critical, + "warning": warning, + "info": info, + }, + "findings": items, + } +} + +// FindingsToSARIF converts scan findings to SARIF 2.1.0 format. 
+func FindingsToSARIF(findingsByFile map[string][]ScanFinding, filesScanned int) map[string]interface{} { + all := allFindings(findingsByFile) + + seenRules := map[string]map[string]interface{}{} + for _, f := range all { + rid := ruleID(f.Category) + if _, exists := seenRules[rid]; !exists { + seenRules[rid] = map[string]interface{}{ + "id": rid, + "shortDescription": map[string]interface{}{ + "text": strings.Title(strings.ReplaceAll(f.Category, "-", " ")), + }, + "defaultConfiguration": map[string]interface{}{ + "level": func() string { + if v, ok := severityMap[f.Severity]; ok { + return v + } + return "note" + }(), + }, + "helpUri": toolInfoURI, + } + } + } + + rulesList := make([]interface{}, 0, len(seenRules)) + for _, r := range seenRules { + rulesList = append(rulesList, r) + } + + results := make([]interface{}, 0, len(all)) + for _, f := range all { + level := "note" + if v, ok := severityMap[f.Severity]; ok { + level = v + } + results = append(results, map[string]interface{}{ + "ruleId": ruleID(f.Category), + "level": level, + "message": map[string]interface{}{ + "text": fmt.Sprintf("%s (%s)", f.Description, f.Codepoint), + }, + "locations": []interface{}{ + map[string]interface{}{ + "physicalLocation": map[string]interface{}{ + "artifactLocation": map[string]interface{}{ + "uri": RelativePathForReport(f.File), + }, + "region": map[string]interface{}{ + "startLine": f.Line, + "startColumn": f.Column, + }, + }, + }, + }, + "properties": map[string]interface{}{ + "codepoint": f.Codepoint, + "category": f.Category, + }, + }) + } + + return map[string]interface{}{ + "$schema": sarifSchema, + "version": sarifVersion, + "runs": []interface{}{ + map[string]interface{}{ + "tool": map[string]interface{}{ + "driver": map[string]interface{}{ + "name": toolName, + "informationUri": toolInfoURI, + "rules": rulesList, + }, + }, + "results": results, + "invocations": []interface{}{ + map[string]interface{}{ + "executionSuccessful": true, + "properties": 
map[string]interface{}{
							"filesScanned": filesScanned,
						},
					},
				},
			},
		}
}

// WriteReport writes a report dict as JSON to the given path.
// Parent directories are created as needed (0755); the file is written 0644
// with a trailing newline for POSIX-friendly output.
func WriteReport(report map[string]interface{}, outputPath string) error {
	if err := os.MkdirAll(filepath.Dir(outputPath), 0o755); err != nil {
		return err
	}
	data, err := json.MarshalIndent(report, "", " ")
	if err != nil {
		return err
	}
	return os.WriteFile(outputPath, append(data, '\n'), 0o644)
}

// SerializeReport serializes a report dict to a JSON string.
// Unlike WriteReport, no trailing newline is appended.
func SerializeReport(report map[string]interface{}) (string, error) {
	data, err := json.MarshalIndent(report, "", " ")
	if err != nil {
		return "", err
	}
	return string(data), nil
}

// FindingsToMarkdown converts scan findings to GitHub-Flavored Markdown:
// a one-line summary plus a table of findings.
func FindingsToMarkdown(findingsByFile map[string][]ScanFinding, filesScanned int) string {
	all := allFindings(findingsByFile)

	// Clean run: short success message, no table.
	if len(all) == 0 {
		return fmt.Sprintf("## APM Audit Report\n\n**Clean** -- no security findings across %d files.\n", filesScanned)
	}

	critical, warning, info := 0, 0, 0
	for _, f := range all {
		switch f.Severity {
		case "critical":
			critical++
		case "warning":
			warning++
		case "info":
			info++
		}
	}
	affected := len(findingsByFile)
	total := len(all)

	// Per-severity fragments for the summary line; zero counts are omitted.
	parts := []string{}
	if critical > 0 {
		parts = append(parts, fmt.Sprintf("%d critical", critical))
	}
	if warning > 0 {
		// Only "warning" is pluralized; "critical" and "info" read as mass nouns.
		s := "s"
		if warning == 1 {
			s = ""
		}
		parts = append(parts, fmt.Sprintf("%d warning%s", warning, s))
	}
	if info > 0 {
		parts = append(parts, fmt.Sprintf("%d info", info))
	}

	countLabel := fmt.Sprintf("**%d finding", total)
	if total != 1 {
		countLabel += "s"
	}
	countLabel += "**"

	affectedStr := "files"
	if affected == 1 {
		affectedStr = "file"
	}

	summary := fmt.Sprintf("%s across %d %s (%s) | %d files scanned",
		countLabel, affected, affectedStr, strings.Join(parts, ", "), filesScanned)

	// Sort findings most-severe first; ties broken by file, then line.
	severityOrder :=
map[string]int{"critical": 0, "warning": 1, "info": 2}
	sort.SliceStable(all, func(i, j int) bool {
		si := severityOrder[all[i].Severity]
		sj := severityOrder[all[j].Severity]
		if si != sj {
			return si < sj
		}
		// NOTE(review): the tie-break compares the raw File value, but the
		// table below prints RelativePathForReport(f.File) -- row order may
		// not match the printed column for mixed absolute/relative inputs.
		// Confirm this matches the Python original's intent.
		if all[i].File != all[j].File {
			return all[i].File < all[j].File
		}
		return all[i].Line < all[j].Line
	})

	var sb strings.Builder
	sb.WriteString("## APM Audit Report\n\n")
	sb.WriteString(summary + "\n\n")
	sb.WriteString("| Severity | File | Location | Codepoint | Description |\n")
	sb.WriteString("|----------|------|----------|-----------|-------------|\n")
	for _, f := range all {
		sev := strings.ToUpper(f.Severity)
		// Escape pipes so descriptions cannot break out of the Markdown table.
		desc := strings.ReplaceAll(f.Description, "|", "\\|")
		sb.WriteString(fmt.Sprintf("| %s | `%s` | %d:%d | `%s` | %s |\n",
			sev, RelativePathForReport(f.File), f.Line, f.Column, f.Codepoint, desc))
	}
	sb.WriteString("\nRun `apm audit --strip` to remove flagged characters.\n")

	return sb.String()
}

// DetectFormatFromExtension auto-detects output format from file extension:
// ".sarif.json"/".sarif" -> "sarif", other ".json" -> "json", ".md" ->
// "markdown", anything else -> "text". The ".sarif.json" test must run before
// the plain ".json" test, since ".sarif.json" is also a ".json" suffix.
func DetectFormatFromExtension(path string) string {
	name := strings.ToLower(filepath.Base(path))
	if strings.HasSuffix(name, ".sarif.json") || strings.HasSuffix(name, ".sarif") {
		return "sarif"
	}
	if strings.HasSuffix(name, ".json") {
		return "json"
	}
	if strings.HasSuffix(name, ".md") {
		return "markdown"
	}
	return "text"
}
diff --git a/internal/updatepolicy/updatepolicy.go b/internal/updatepolicy/updatepolicy.go
new file mode 100644
index 0000000..6b181f9
--- /dev/null
+++ b/internal/updatepolicy/updatepolicy.go
@@ -0,0 +1,54 @@
// Package updatepolicy provides build-time policy for APM self-update behavior.
// Package maintainers can patch constants during build to disable self-update
// and show users a package-manager-specific update command.
package updatepolicy

// DefaultSelfUpdateDisabledMessage is the default guidance when self-update is disabled.
+const DefaultSelfUpdateDisabledMessage = "Self-update is disabled for this APM distribution. Update APM using your package manager." + +// Build-time policy values. Packagers can override these at link time via +// -ldflags "-X updatepolicy.SelfUpdateEnabled=false". +var ( + // SelfUpdateEnabled controls whether self-update is allowed. + SelfUpdateEnabled = true + // SelfUpdateDisabledMessage is shown when self-update is disabled. + SelfUpdateDisabledMessage = DefaultSelfUpdateDisabledMessage +) + +// isPrintableASCII returns true when s contains only printable ASCII characters. +func isPrintableASCII(s string) bool { + for _, c := range s { + if c < ' ' || c > '~' { + return false + } + } + return true +} + +// IsSelfUpdateEnabled returns true when this build allows self-update. +func IsSelfUpdateEnabled() bool { + return SelfUpdateEnabled +} + +// GetSelfUpdateDisabledMessage returns the guidance message shown when self-update is disabled. +func GetSelfUpdateDisabledMessage() string { + if SelfUpdateDisabledMessage == "" { + return DefaultSelfUpdateDisabledMessage + } + msg := SelfUpdateDisabledMessage + if msg == "" { + return DefaultSelfUpdateDisabledMessage + } + if !isPrintableASCII(msg) { + return DefaultSelfUpdateDisabledMessage + } + return msg +} + +// GetUpdateHintMessage returns the update hint used in startup notifications. +func GetUpdateHintMessage() string { + if IsSelfUpdateEnabled() { + return "Run apm update to upgrade" + } + return GetSelfUpdateDisabledMessage() +} diff --git a/internal/utils/githubhost/githubhost.go b/internal/utils/githubhost/githubhost.go index b05670c..ddaf01c 100644 --- a/internal/utils/githubhost/githubhost.go +++ b/internal/utils/githubhost/githubhost.go @@ -211,6 +211,35 @@ func AzureDevOpsOrgFromHostname(hostname string) string { return parts[0] } +// IsSupportedGitHost returns true for any hostname that APM recognises as a valid +// Git host: github.com, GHES, GHE.com, GitLab, Azure DevOps, or Artifactory. 
+// Any syntactically valid FQDN is accepted (self-hosted instances). +func IsSupportedGitHost(hostname string) bool { + if hostname == "" { + return false + } + h := normalizeHost(hostname) + return IsValidFQDN(h) +} + +// IsArtifactoryPath returns true when path segments start with "artifactory/". +func IsArtifactoryPath(segments []string) bool { + return len(segments) >= 4 && strings.EqualFold(segments[0], "artifactory") +} + +// ParseArtifactoryPath extracts (prefix, owner, repo) from Artifactory path segments. +// Segments are expected as ["artifactory", "", "", "", ...]. +// Returns empty strings if the segments do not match. +func ParseArtifactoryPath(segments []string) (prefix, owner, repo string) { + if !IsArtifactoryPath(segments) { + return + } + prefix = strings.Join(segments[:2], "/") + owner = segments[2] + repo = segments[3] + return +} + func normalizeHost(s string) string { s = strings.TrimSpace(s) s = strings.ToLower(s)