diff --git a/cmd/publisher/commands/status.go b/cmd/publisher/commands/status.go new file mode 100644 index 000000000..6310d5874 --- /dev/null +++ b/cmd/publisher/commands/status.go @@ -0,0 +1,246 @@ +package commands + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "flag" + "fmt" + "io" + "net/http" + "net/url" + "os" + "path/filepath" + "strings" +) + +// StatusUpdateRequest represents the request body for status update endpoints +type StatusUpdateRequest struct { + Status string `json:"status"` + StatusMessage *string `json:"statusMessage,omitempty"` + AlternativeURL *string `json:"alternativeUrl,omitempty"` + NewName *string `json:"newName,omitempty"` +} + +// AllVersionsStatusResponse represents the response from the all-versions status endpoint +type AllVersionsStatusResponse struct { + UpdatedCount int `json:"updatedCount"` +} + +func StatusCommand(args []string) error { + // Parse command flags + fs := flag.NewFlagSet("status", flag.ExitOnError) + status := fs.String("status", "", "New status: active, deprecated, or yanked (required)") + message := fs.String("message", "", "Optional status message explaining the change") + alternativeURL := fs.String("alternative-url", "", "Optional URL to alternative/replacement server") + newName := fs.String("new-name", "", "Optional new server name when server has been renamed") + allVersions := fs.Bool("all-versions", false, "Apply status change to all versions of the server") + + if err := fs.Parse(args); err != nil { + return err + } + + // Validate required arguments + if *status == "" { + return errors.New("--status flag is required (active, deprecated, or yanked)") + } + + // Validate status value + validStatuses := map[string]bool{"active": true, "deprecated": true, "yanked": true} + if !validStatuses[*status] { + return fmt.Errorf("invalid status '%s'. Must be one of: active, deprecated, yanked", *status) + } + + // Get server name from positional args + remainingArgs := fs.Args() + if len(remainingArgs) < 1 { + return errors.New("server name is required\n\nUsage: mcp-publisher status [version] --status [flags]") + } + + serverName := remainingArgs[0] + var version string + + // Get version if provided (required unless --all-versions is set) + if !*allVersions { + if len(remainingArgs) < 2 { + return errors.New("version is required unless --all-versions flag is set\n\nUsage: mcp-publisher status --status [flags]") + } + version = remainingArgs[1] + } + + // Validate new-name parameter constraints + if *newName != "" { + // Validation: new-name requires deprecated or yanked status + if *status != "deprecated" && *status != "yanked" { + return errors.New("--new-name can only be used with --status deprecated or --status yanked") + } + // Validation: new-name requires --all-versions flag + if !*allVersions { + return errors.New("--new-name requires --all-versions flag") + } + } + + // Load saved token + homeDir, err := os.UserHomeDir() + if err != nil { + return fmt.Errorf("failed to get home directory: %w", err) + } + + tokenPath := filepath.Join(homeDir, TokenFileName) + tokenData, err := os.ReadFile(tokenPath) + if err != nil { + if os.IsNotExist(err) { + return errors.New("not authenticated. Run 'mcp-publisher login ' first") + } + return fmt.Errorf("failed to read token: %w", err) + } + + var tokenInfo map[string]string + if err := json.Unmarshal(tokenData, &tokenInfo); err != nil { + return fmt.Errorf("invalid token data: %w", err) + } + + token := tokenInfo["token"] + registryURL := tokenInfo["registry"] + if registryURL == "" { + registryURL = DefaultRegistryURL + } + + // Update status + if *allVersions { + return updateAllVersionsStatus(registryURL, serverName, *status, *message, *alternativeURL, *newName, token) + } + return updateVersionStatus(registryURL, serverName, version, *status, *message, *alternativeURL, *newName, token) +} + +func updateVersionStatus(registryURL, serverName, version, status, statusMessage, alternativeURL, newName, token string) error { + _, _ = fmt.Fprintf(os.Stdout, "Updating %s version %s to status: %s\n", serverName, version, status) + + if err := updateServerStatus(registryURL, serverName, version, status, statusMessage, alternativeURL, newName, token); err != nil { + return fmt.Errorf("failed to update status: %w", err) + } + + _, _ = fmt.Fprintln(os.Stdout, "✓ Successfully updated status") + return nil +} + +func updateAllVersionsStatus(registryURL, serverName, status, statusMessage, alternativeURL, newName, token string) error { + _, _ = fmt.Fprintf(os.Stdout, "Updating all versions of %s to status: %s\n", serverName, status) + + if !strings.HasSuffix(registryURL, "/") { + registryURL += "/" + } + + // Build the request body + requestBody := StatusUpdateRequest{ + Status: status, + } + if statusMessage != "" { + requestBody.StatusMessage = &statusMessage + } + if alternativeURL != "" { + requestBody.AlternativeURL = &alternativeURL + } + if newName != "" { + requestBody.NewName = &newName + } + + jsonData, err := json.Marshal(requestBody) + if err != nil { + return fmt.Errorf("error serializing request: %w", err) + } + + // URL encode the server name + encodedServerName := url.PathEscape(serverName) + statusURL := registryURL + "v0/servers/" + encodedServerName + "/status" + + req, err := http.NewRequestWithContext(context.Background(), http.MethodPatch, statusURL, bytes.NewBuffer(jsonData)) + if err != nil { + return fmt.Errorf("error creating request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+token) + + client := &http.Client{} + resp, err := client.Do(req) + if err != nil { + return fmt.Errorf("error sending request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return fmt.Errorf("error reading response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("server returned status %d: %s", resp.StatusCode, body) + } + + // Parse response to get updated count + var response AllVersionsStatusResponse + if err := json.Unmarshal(body, &response); err != nil { + // If we can't parse the response, just report success + _, _ = fmt.Fprintln(os.Stdout, "✓ Successfully updated all versions") + return nil + } + + _, _ = fmt.Fprintf(os.Stdout, "✓ Successfully updated %d version(s)\n", response.UpdatedCount) + return nil +} + +func updateServerStatus(registryURL, serverName, version, status, statusMessage, alternativeURL, newName, token string) error { + if !strings.HasSuffix(registryURL, "/") { + registryURL += "/" + } + + // Build the request body + requestBody := StatusUpdateRequest{ + Status: status, + } + if statusMessage != "" { + requestBody.StatusMessage = &statusMessage + } + if alternativeURL != "" { + requestBody.AlternativeURL = &alternativeURL + } + if newName != "" { + requestBody.NewName = &newName + } + + jsonData, err := json.Marshal(requestBody) + if err != nil { + return fmt.Errorf("error serializing request: %w", err) + } + + // URL encode the server name and version + encodedServerName := url.PathEscape(serverName) + encodedVersion := url.PathEscape(version) + statusURL := registryURL + "v0/servers/" + encodedServerName + "/versions/" + encodedVersion + "/status" + + req, err := http.NewRequestWithContext(context.Background(), http.MethodPatch, statusURL, bytes.NewBuffer(jsonData)) + if err != nil { + return fmt.Errorf("error creating request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+token) + + client := &http.Client{} + resp, err := client.Do(req) + if err != nil { + return fmt.Errorf("error sending request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return fmt.Errorf("error reading response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("server returned status %d: %s", resp.StatusCode, body) + } + + return nil +} diff --git a/cmd/publisher/commands/status_test.go b/cmd/publisher/commands/status_test.go new file mode 100644 index 000000000..bcbeed943 --- /dev/null +++ b/cmd/publisher/commands/status_test.go @@ -0,0 +1,285 @@ +package commands_test + +import ( + "strings" + "testing" + + "github.com/modelcontextprotocol/registry/cmd/publisher/commands" +) + +func TestStatusCommand_Validation(t *testing.T) { + tests := []struct { + name string + args []string + expectError bool + errorSubstr string + }{ + { + name: "missing --status flag", + args: []string{"io.github.user/my-server", "1.0.0"}, + expectError: true, + errorSubstr: "--status flag is required", + }, + { + name: "invalid status value", + args: []string{"--status", "invalid", "io.github.user/my-server", "1.0.0"}, + expectError: true, + errorSubstr: "invalid status 'invalid'", + }, + { + name: "missing server name", + args: []string{"--status", "deprecated"}, + expectError: true, + errorSubstr: "server name is required", + }, + { + name: "missing version without --all-versions", + args: []string{"--status", "deprecated", "io.github.user/my-server"}, + expectError: true, + errorSubstr: "version is required unless --all-versions", + }, + { + name: "valid args passes validation", + args: []string{"--status", "deprecated", "io.github.user/my-server", "1.0.0"}, + expectError: false, + }, + { + name: "valid args with --all-versions passes validation", + args: []string{"--status", "deprecated", "--all-versions", "io.github.user/my-server"}, + expectError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := commands.StatusCommand(tt.args) + + if tt.expectError { + if err == nil { + t.Errorf("Expected error but got none") + return + } + if !strings.Contains(err.Error(), tt.errorSubstr) { + t.Errorf("Expected error containing '%s', got: %v", tt.errorSubstr, err) + } + } else if err != nil { + // For valid args, we expect it to pass validation + // It may fail later at auth or API level, which is acceptable + // Just check it's not a validation error + if strings.Contains(err.Error(), "invalid status") || + strings.Contains(err.Error(), "server name is required") || + strings.Contains(err.Error(), "version is required unless") || + strings.Contains(err.Error(), "--status flag is required") { + t.Errorf("Validation failed unexpectedly: %v", err) + } + } + }) + } +} + +func TestStatusCommand_ServerNameValidation(t *testing.T) { + tests := []struct { + name string + serverName string + }{ + { + name: "valid github server name", + serverName: "io.github.user/my-server", + }, + { + name: "valid domain server name", + serverName: "com.example/my-server", + }, + { + name: "server name with dashes", + serverName: "io.github.user/my-cool-server", + }, + { + name: "server name with underscores", + serverName: "io.github.user/my_server", + }, + { + name: "server name with dots", + serverName: "io.github.user/my.server", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + args := []string{"--status", "deprecated", tt.serverName, "1.0.0"} + err := commands.StatusCommand(args) + + // Should pass validation (server name format is not validated by CLI) + if err != nil && strings.Contains(err.Error(), "server name is required") { + t.Errorf("Server name '%s' was rejected", tt.serverName) + } + }) + } +} + +func TestStatusCommand_VersionValidation(t *testing.T) { + tests := []struct { + name string + version string + }{ + { + name: "semver version", + version: "1.0.0", + }, + { + name: "semver with patch", + version: "1.2.3", + }, + { + name: "semver with prerelease", + version: "1.0.0-alpha", + }, + { + name: "semver with build metadata", + version: "1.0.0+20130313144700", + }, + { + name: "semver with prerelease and build", + version: "1.0.0-beta.1+exp.sha.5114f85", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + args := []string{"--status", "deprecated", "io.github.user/my-server", tt.version} + err := commands.StatusCommand(args) + + // Should pass validation (version format is not validated by CLI) + if err != nil && strings.Contains(err.Error(), "version is required") { + t.Errorf("Version '%s' was rejected", tt.version) + } + }) + } +} + +func TestStatusCommand_AllVersionsFlag(t *testing.T) { + tests := []struct { + name string + args []string + expectError bool + errorSubstr string + }{ + { + name: "all-versions without version arg passes validation", + args: []string{"--status", "deprecated", "--all-versions", "io.github.user/my-server"}, + expectError: false, + }, + { + name: "all-versions with extra version arg still works", + args: []string{"--status", "deprecated", "--all-versions", "io.github.user/my-server", "1.0.0"}, + expectError: false, + }, + { + name: "missing server name with all-versions", + args: []string{"--status", "deprecated", "--all-versions"}, + expectError: true, + errorSubstr: "server name is required", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := commands.StatusCommand(tt.args) + + if tt.expectError { + if err == nil { + t.Errorf("Expected error but got none") + return + } + if !strings.Contains(err.Error(), tt.errorSubstr) { + t.Errorf("Expected error containing '%s', got: %v", tt.errorSubstr, err) + } + } else if err != nil { + // Should pass validation + // Just check it's not a validation error + if strings.Contains(err.Error(), "invalid status") || + strings.Contains(err.Error(), "server name is required") || + strings.Contains(err.Error(), "version is required unless") || + strings.Contains(err.Error(), "--status flag is required") { + t.Errorf("Validation failed unexpectedly: %v", err) + } + } + }) + } +} + +func TestStatusCommand_FlagCombinations(t *testing.T) { + tests := []struct { + name string + args []string + }{ + { + name: "status with message", + args: []string{"--status", "deprecated", "--message", "Please upgrade to v2", "io.github.user/my-server", "1.0.0"}, + }, + { + name: "status with alternative-url", + args: []string{"--status", "deprecated", "--alternative-url", "https://github.com/user/new-server", "io.github.user/my-server", "1.0.0"}, + }, + { + name: "status with both message and alternative-url", + args: []string{"--status", "deprecated", "--message", "Deprecated", "--alternative-url", "https://example.com", "io.github.user/my-server", "1.0.0"}, + }, + { + name: "active status with message and url (CLI accepts, server validates)", + args: []string{"--status", "active", "--message", "Should be ignored", "--alternative-url", "https://example.com", "io.github.user/my-server", "1.0.0"}, + }, + { + name: "all-versions with message", + args: []string{"--status", "deprecated", "--all-versions", "--message", "All versions deprecated", "io.github.user/my-server"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := commands.StatusCommand(tt.args) + // All these should pass CLI validation + // They may fail at auth or API level which is acceptable + if err != nil { + // Just check it's not a validation error we can detect + if strings.Contains(err.Error(), "invalid status") || + strings.Contains(err.Error(), "server name is required") || + strings.Contains(err.Error(), "version is required unless") || + strings.Contains(err.Error(), "--status flag is required") { + t.Errorf("Validation failed unexpectedly: %v", err) + } + } + }) + } +} + +func TestStatusCommand_MissingStatus(t *testing.T) { + // Test various ways status flag can be missing + tests := []struct { + name string + args []string + }{ + { + name: "no status flag at all", + args: []string{"io.github.user/my-server", "1.0.0"}, + }, + { + name: "empty status value", + args: []string{"--status", "", "io.github.user/my-server", "1.0.0"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := commands.StatusCommand(tt.args) + + if err == nil { + t.Errorf("Expected error for missing status but got none") + return + } + if !strings.Contains(err.Error(), "--status flag is required") { + t.Errorf("Expected '--status flag is required' error, got: %v", err) + } + }) + } +} diff --git a/cmd/publisher/main.go b/cmd/publisher/main.go index 1b265d725..e71c5bd2d 100644 --- a/cmd/publisher/main.go +++ b/cmd/publisher/main.go @@ -43,6 +43,8 @@ func main() { err = commands.LogoutCommand() case "publish": err = commands.PublishCommand(os.Args[2:]) + case "status": + err = commands.StatusCommand(os.Args[2:]) case "validate": err = commands.ValidateCommand(os.Args[2:]) case "--version", "-v", "version": @@ -73,6 +75,7 @@ func printUsage() { _, _ = fmt.Fprintln(os.Stdout, " login Authenticate with the registry") _, _ = fmt.Fprintln(os.Stdout, " logout Clear saved authentication") _, _ = fmt.Fprintln(os.Stdout, " publish Publish server.json to the registry") + _, _ = fmt.Fprintln(os.Stdout, " status Update the status of a server version") _, _ = fmt.Fprintln(os.Stdout, " validate Validate server.json without publishing") _, _ = fmt.Fprintln(os.Stdout) _, _ = fmt.Fprintln(os.Stdout, "Use 'mcp-publisher --help' for more information about a command.") @@ -128,6 +131,46 @@ func printCommandHelp(command string) { _, _ = fmt.Fprintln(os.Stdout) _, _ = fmt.Fprintln(os.Stdout, "You must be logged in before publishing. Run 'mcp-publisher login' first.") + case "status": + _, _ = fmt.Fprintln(os.Stdout, "Update the status of a server version") + _, _ = fmt.Fprintln(os.Stdout) + _, _ = fmt.Fprintln(os.Stdout, "Usage:") + _, _ = fmt.Fprintln(os.Stdout, " mcp-publisher status [version] --status [flags]") + _, _ = fmt.Fprintln(os.Stdout) + _, _ = fmt.Fprintln(os.Stdout, "Arguments:") + _, _ = fmt.Fprintln(os.Stdout, " server-name Full server name (e.g., io.github.user/my-server)") + _, _ = fmt.Fprintln(os.Stdout, " version Server version to update (required unless --all-versions is set)") + _, _ = fmt.Fprintln(os.Stdout) + _, _ = fmt.Fprintln(os.Stdout, "Flags:") + _, _ = fmt.Fprintln(os.Stdout, " --status string New status: active, deprecated, or yanked (required)") + _, _ = fmt.Fprintln(os.Stdout, " --message string Optional message explaining the status change") + _, _ = fmt.Fprintln(os.Stdout, " --alternative-url string Optional URL to alternative/replacement server") + _, _ = fmt.Fprintln(os.Stdout, " --new-name string Optional new server name when server has been renamed") + _, _ = fmt.Fprintln(os.Stdout, " --all-versions Apply status change to all versions of the server") + _, _ = fmt.Fprintln(os.Stdout) + _, _ = fmt.Fprintln(os.Stdout, "Examples:") + _, _ = fmt.Fprintln(os.Stdout, " # Deprecate a specific version") + _, _ = fmt.Fprintln(os.Stdout, " mcp-publisher status io.github.user/my-server 1.0.0 --status deprecated \\") + _, _ = fmt.Fprintln(os.Stdout, " --message \"Please upgrade to 2.0.0\"") + _, _ = fmt.Fprintln(os.Stdout) + _, _ = fmt.Fprintln(os.Stdout, " # Yank a version with security issues") + _, _ = fmt.Fprintln(os.Stdout, " mcp-publisher status io.github.user/my-server 1.0.0 --status yanked \\") + _, _ = fmt.Fprintln(os.Stdout, " --message \"Critical security vulnerability\"") + _, _ = fmt.Fprintln(os.Stdout) + _, _ = fmt.Fprintln(os.Stdout, " # Restore a version to active") + _, _ = fmt.Fprintln(os.Stdout, " mcp-publisher status io.github.user/my-server 1.0.0 --status active") + _, _ = fmt.Fprintln(os.Stdout) + _, _ = fmt.Fprintln(os.Stdout, " # Deprecate all versions") + _, _ = fmt.Fprintln(os.Stdout, " mcp-publisher status io.github.user/my-server --all-versions --status deprecated \\") + _, _ = fmt.Fprintln(os.Stdout, " --message \"Project archived\"") + _, _ = fmt.Fprintln(os.Stdout) + _, _ = fmt.Fprintln(os.Stdout, " # Deprecate with new name (rename)") + _, _ = fmt.Fprintln(os.Stdout, " mcp-publisher status io.github.user/my-server --all-versions --status deprecated \\") + _, _ = fmt.Fprintln(os.Stdout, " --new-name \"io.github.company/my-server\" \\") + _, _ = fmt.Fprintln(os.Stdout, " --message \"Moved to company organization\"") + _, _ = fmt.Fprintln(os.Stdout) + _, _ = fmt.Fprintln(os.Stdout, "You must be logged in before updating status. Run 'mcp-publisher login' first.") + default: fmt.Fprintf(os.Stderr, "Unknown command: %s\n", command) printUsage() diff --git a/internal/api/handlers/v0/edit.go b/internal/api/handlers/v0/edit.go index 9c0286d0c..724941d4e 100644 --- a/internal/api/handlers/v0/edit.go +++ b/internal/api/handlers/v0/edit.go @@ -8,12 +8,12 @@ import ( "strings" "github.com/danielgtaylor/huma/v2" + "github.com/modelcontextprotocol/registry/internal/auth" "github.com/modelcontextprotocol/registry/internal/config" "github.com/modelcontextprotocol/registry/internal/database" "github.com/modelcontextprotocol/registry/internal/service" apiv0 "github.com/modelcontextprotocol/registry/pkg/api/v0" - "github.com/modelcontextprotocol/registry/pkg/model" ) // EditServerInput represents the input for editing a server @@ -21,7 +21,6 @@ type EditServerInput struct { Authorization string `header:"Authorization" doc:"Registry JWT token with edit permissions" required:"true"` ServerName string `path:"serverName" doc:"URL-encoded server name" example:"com.example%2Fmy-server"` Version string `path:"version" doc:"URL-encoded version to edit" example:"1.0.0"` - Status string `query:"status" doc:"New status for the server (active, deprecated, deleted)" required:"false" enum:"active,deprecated,deleted"` Body apiv0.ServerJSON `body:""` } @@ -35,8 +34,8 @@ func RegisterEditEndpoints(api huma.API, pathPrefix string, registry service.Reg Method: http.MethodPut, Path: pathPrefix + "/servers/{serverName}/versions/{version}", Summary: "Edit MCP server", - Description: "Update a specific version of an existing MCP server (admin only).", - Tags: []string{"admin"}, + Description: "Update the configuration of a specific version of an existing MCP server. Requires edit permission for the server. Use PATCH /servers/{serverName}/versions/{version}/status to update status metadata.", + Tags: []string{"servers"}, Security: []map[string][]string{ {"bearer": {}}, }, @@ -91,28 +90,8 @@ func RegisterEditEndpoints(api huma.API, pathPrefix string, registry service.Reg return nil, huma.Error400BadRequest("Version in request body must match URL path parameter") } - // Handle status changes with proper permission validation - if input.Status != "" { - newStatus := model.Status(input.Status) - - // Prevent undeleting servers - once deleted, they stay deleted - if currentServer.Meta.Official != nil && - currentServer.Meta.Official.Status == model.StatusDeleted && - newStatus != model.StatusDeleted { - return nil, huma.Error400BadRequest("Cannot change status of deleted server. Deleted servers cannot be undeleted.") - } - - // For now, only allow status changes for admins - // Future: Implement logic to allow server authors to change active <-> deprecated - // but only admins can set to deleted - } - - // Update the server using the service - var statusPtr *string - if input.Status != "" { - statusPtr = &input.Status - } - updatedServer, err := registry.UpdateServer(ctx, serverName, version, &input.Body, statusPtr) + // Update the server using the service (no status change - use the status endpoint for that) + updatedServer, err := registry.UpdateServer(ctx, serverName, version, &input.Body, nil) if err != nil { if errors.Is(err, database.ErrNotFound) { return nil, huma.Error404NotFound("Server not found") diff --git a/internal/api/handlers/v0/edit_test.go b/internal/api/handlers/v0/edit_test.go index 66089db22..78699563f 100644 --- a/internal/api/handlers/v0/edit_test.go +++ b/internal/api/handlers/v0/edit_test.go @@ -71,25 +71,6 @@ func TestEditServerEndpoint(t *testing.T) { require.NoError(t, err) } - // Create a deleted server for undelete testing - deletedServer := &apiv0.ServerJSON{ - Schema: model.CurrentSchemaURL, - Name: "io.github.testuser/deleted-server", - Description: "Server that was deleted", - Version: "1.0.0", - Repository: &model.Repository{ - URL: "https://github.com/testuser/deleted-server", - Source: "github", - ID: "testuser/deleted-server", - }, - } - _, err = registryService.CreateServer(context.Background(), deletedServer) - require.NoError(t, err) - - // Set the server to deleted status - _, err = registryService.UpdateServer(context.Background(), deletedServer.Name, deletedServer.Version, deletedServer, stringPtr(string(model.StatusDeleted))) - require.NoError(t, err) - // Create a server with build metadata for URL encoding test buildMetadataServer := &apiv0.ServerJSON{ Schema: model.CurrentSchemaURL, @@ -112,7 +93,6 @@ func TestEditServerEndpoint(t *testing.T) { authClaims *auth.JWTClaims authHeader string requestBody apiv0.ServerJSON - statusParam string expectedStatus int expectedError string checkResult func(*testing.T, *apiv0.ServerResponse) @@ -148,31 +128,6 @@ func TestEditServerEndpoint(t *testing.T) { assert.NotNil(t, resp.Meta.Official) }, }, - { - name: "successful edit with status change", - serverName: "io.github.testuser/editable-server", - version: "1.0.0", - authClaims: &auth.JWTClaims{ - AuthMethod: auth.MethodGitHubAT, - AuthMethodSubject: "testuser", - Permissions: []auth.Permission{ - {Action: auth.PermissionActionEdit, ResourcePattern: "io.github.testuser/*"}, - }, - }, - requestBody: apiv0.ServerJSON{ - Schema: model.CurrentSchemaURL, - Name: "io.github.testuser/editable-server", - Description: "Server with status change", - Version: "1.0.0", - }, - statusParam: "deprecated", - expectedStatus: http.StatusOK, - checkResult: func(t *testing.T, resp *apiv0.ServerResponse) { - t.Helper() - assert.Equal(t, "Server with status change", resp.Server.Description) - assert.Equal(t, model.StatusDeprecated, resp.Meta.Official.Status) - }, - }, { name: "missing authorization header", serverName: "io.github.testuser/editable-server", @@ -310,27 +265,6 @@ func TestEditServerEndpoint(t *testing.T) { expectedStatus: http.StatusBadRequest, expectedError: "Version in request body must match URL path parameter", }, - { - name: "attempt to undelete server should fail", - serverName: "io.github.testuser/deleted-server", - version: "1.0.0", - authClaims: &auth.JWTClaims{ - AuthMethod: auth.MethodGitHubAT, - AuthMethodSubject: "testuser", - Permissions: []auth.Permission{ - {Action: auth.PermissionActionEdit, ResourcePattern: "io.github.testuser/*"}, - }, - }, - requestBody: apiv0.ServerJSON{ - Schema: model.CurrentSchemaURL, - Name: "io.github.testuser/deleted-server", - Description: "Trying to undelete server", - Version: "1.0.0", - }, - statusParam: "active", // Trying to change from deleted to active - expectedStatus: http.StatusBadRequest, - expectedError: "Cannot change status of deleted server", - }, { name: "successful edit of version with build metadata (URL encoded)", serverName: "io.github.testuser/build-metadata-server", @@ -381,9 +315,6 @@ func TestEditServerEndpoint(t *testing.T) { encodedServerName := url.PathEscape(tc.serverName) encodedVersion := url.PathEscape(tc.version) requestURL := "/v0/servers/" + encodedServerName + "/versions/" + encodedVersion - if tc.statusParam != "" { - requestURL += "?status=" + tc.statusParam - } req := httptest.NewRequest(http.MethodPut, requestURL, bytes.NewReader(requestBody)) req.Header.Set("Content-Type", "application/json") @@ -404,6 +335,9 @@ func TestEditServerEndpoint(t *testing.T) { mux.ServeHTTP(w, req) // Check response + if tc.expectedStatus != w.Code { + t.Logf("Response body: %s", w.Body.String()) + } assert.Equal(t, tc.expectedStatus, w.Code) if tc.expectedError != "" { @@ -437,12 +371,10 @@ func TestEditServerEndpointEdgeCases(t *testing.T) { testServers := []struct { name string version string - status model.Status }{ - {"com.example/active-server", "1.0.0", model.StatusActive}, - {"com.example/deprecated-server", "1.0.0", model.StatusDeprecated}, - {"com.example/multi-version-server", "1.0.0", model.StatusActive}, - {"com.example/multi-version-server", "2.0.0", model.StatusActive}, + {"com.example/active-server", "1.0.0"}, + {"com.example/multi-version-server", "1.0.0"}, + {"com.example/multi-version-server", "2.0.0"}, } for _, server := range testServers { @@ -453,17 +385,6 @@ func TestEditServerEndpointEdgeCases(t *testing.T) { Version: server.version, }) require.NoError(t, err) - - // Set specific status if not active - if server.status != model.StatusActive { - _, err = registryService.UpdateServer(context.Background(), server.name, server.version, &apiv0.ServerJSON{ - Schema: model.CurrentSchemaURL, - Name: server.name, - Description: "Test server for editing", - Version: server.version, - }, stringPtr(string(server.status))) - require.NoError(t, err) - } } // Create API @@ -471,95 +392,6 @@ func TestEditServerEndpointEdgeCases(t *testing.T) { api := humago.New(mux, huma.DefaultConfig("Test API", "1.0.0")) v0.RegisterEditEndpoints(api, "/v0", registryService, cfg) - t.Run("status transitions", func(t *testing.T) { - tests := []struct { - name string - serverName string - version string - fromStatus string - toStatus string - expectedStatus int - expectedError string - }{ - { - name: "active to deprecated", - serverName: "com.example/active-server", - version: "1.0.0", - toStatus: "deprecated", - expectedStatus: http.StatusOK, - }, - { - name: "deprecated to active", - serverName: "com.example/deprecated-server", - version: "1.0.0", - toStatus: "active", - expectedStatus: http.StatusOK, - }, - { - name: "active to deleted", - serverName: "com.example/active-server", - version: "1.0.0", - toStatus: "deleted", - expectedStatus: http.StatusOK, - }, - { - name: "invalid status", - serverName: "com.example/active-server", - version: "1.0.0", - toStatus: "invalid_status", - expectedStatus: http.StatusUnprocessableEntity, - expectedError: "validation failed", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - requestBody := apiv0.ServerJSON{ - Schema: model.CurrentSchemaURL, - Name: tt.serverName, - Description: "Status transition test", - Version: tt.version, - } - - bodyBytes, err := json.Marshal(requestBody) - require.NoError(t, err) - - encodedName := url.PathEscape(tt.serverName) - requestURL := "/v0/servers/" + encodedName + "/versions/" + tt.version + "?status=" + tt.toStatus - - req := httptest.NewRequest(http.MethodPut, requestURL, bytes.NewReader(bodyBytes)) - req.Header.Set("Content-Type", "application/json") - - // Generate admin token - jwtManager := auth.NewJWTManager(cfg) - tokenResponse, err := jwtManager.GenerateTokenResponse(context.Background(), auth.JWTClaims{ - AuthMethod: auth.MethodNone, - Permissions: []auth.Permission{ - {Action: auth.PermissionActionEdit, ResourcePattern: "*"}, - }, - }) - require.NoError(t, err) - req.Header.Set("Authorization", "Bearer "+tokenResponse.RegistryToken) - - w := httptest.NewRecorder() - mux.ServeHTTP(w, req) - - assert.Equal(t, tt.expectedStatus, w.Code) - - if tt.expectedError != "" { - assert.Contains(t, w.Body.String(), tt.expectedError) - } - - if tt.expectedStatus == http.StatusOK { - var response apiv0.ServerResponse - err := json.NewDecoder(w.Body).Decode(&response) - require.NoError(t, err) - assert.Equal(t, model.Status(tt.toStatus), response.Meta.Official.Status) - } - }) - } - }) - t.Run("URL encoding edge cases", func(t *testing.T) { // Create server with special characters specialServerName := "io.dots.and-dashes/server_with_underscores" @@ -655,9 +487,66 @@ func TestEditServerEndpointEdgeCases(t *testing.T) { require.NoError(t, err) assert.NotEqual(t, "Updated v1.0.0 specifically", otherVersion.Server.Description) }) -} -// Helper function -func stringPtr(s string) *string { - return &s + t.Run("edit preserves status metadata", func(t *testing.T) { + // Create a server and set it to deprecated using the service directly + deprecatedServer := &apiv0.ServerJSON{ + Schema: model.CurrentSchemaURL, + Name: "com.example/deprecated-for-edit-test", + Description: "Server to test edit preserves status", + Version: "1.0.0", + } + _, err := registryService.CreateServer(context.Background(), deprecatedServer) + require.NoError(t, err) + + // Set to deprecated status using UpdateServerStatus + statusMsg := "This server is deprecated" + _, err = registryService.UpdateServerStatus(context.Background(), deprecatedServer.Name, deprecatedServer.Version, &service.StatusChangeRequest{ + NewStatus: model.StatusDeprecated, + StatusMessage: &statusMsg, + }) + require.NoError(t, err) + + // Now edit the server description + requestBody := apiv0.ServerJSON{ + Schema: model.CurrentSchemaURL, + Name: "com.example/deprecated-for-edit-test", + Description: "Updated description but status should remain deprecated", + Version: "1.0.0", + } + + bodyBytes, err := json.Marshal(requestBody) + require.NoError(t, err) + + encodedName := url.PathEscape("com.example/deprecated-for-edit-test") + requestURL := "/v0/servers/" + encodedName + "/versions/1.0.0" + + req := httptest.NewRequest(http.MethodPut, requestURL, bytes.NewReader(bodyBytes)) + req.Header.Set("Content-Type", "application/json") + + jwtManager := auth.NewJWTManager(cfg) + tokenResponse, err := jwtManager.GenerateTokenResponse(context.Background(), auth.JWTClaims{ + AuthMethod: auth.MethodNone, + Permissions: []auth.Permission{ + {Action: auth.PermissionActionEdit, ResourcePattern: "*"}, + }, + }) + require.NoError(t, err) + req.Header.Set("Authorization", "Bearer "+tokenResponse.RegistryToken) + + w := httptest.NewRecorder() + mux.ServeHTTP(w, req) + + assert.Equal(t, http.StatusOK, w.Code) + + var response apiv0.ServerResponse + err = json.NewDecoder(w.Body).Decode(&response) + require.NoError(t, err) + assert.Equal(t, "Updated description but status should remain deprecated", response.Server.Description) + // Status should still be deprecated + assert.Equal(t, model.StatusDeprecated, response.Meta.Official.Status) + // Status message should be preserved + assert.NotNil(t, response.Meta.Official.StatusMessage) + assert.Equal(t, "This server is deprecated", *response.Meta.Official.StatusMessage) + }) } diff --git a/internal/api/handlers/v0/servers.go b/internal/api/handlers/v0/servers.go index f9f7ba5f7..bcc0bef50 100644 --- a/internal/api/handlers/v0/servers.go +++ b/internal/api/handlers/v0/servers.go @@ -18,11 +18,12 @@ const errRecordNotFound = "record not found" // ListServersInput represents the input for listing servers type ListServersInput struct { - Cursor string `query:"cursor" doc:"Pagination cursor" required:"false" example:"server-cursor-123"` - Limit int `query:"limit" doc:"Number of items per page" default:"30" minimum:"1" maximum:"100" example:"50"` - UpdatedSince string `query:"updated_since" doc:"Filter servers updated since timestamp (RFC3339 datetime)" required:"false" example:"2025-08-07T13:15:04.280Z"` - Search string `query:"search" doc:"Search servers by name (substring match)" required:"false" example:"filesystem"` - Version string `query:"version" doc:"Filter by version ('latest' for latest version, or an exact version like '1.2.3')" required:"false" example:"latest"` + Cursor string `query:"cursor" doc:"Pagination cursor" required:"false" example:"server-cursor-123"` + Limit int `query:"limit" doc:"Number of items per page" default:"30" minimum:"1" maximum:"100" example:"50"` + UpdatedSince string `query:"updated_since" doc:"Filter servers updated since timestamp (RFC3339 datetime)" required:"false" example:"2025-08-07T13:15:04.280Z"` + Search string `query:"search" doc:"Search servers by name (substring match)" required:"false" example:"filesystem"` + Version string `query:"version" doc:"Filter by version ('latest' for latest version, or an exact version like '1.2.3')" required:"false" example:"latest"` + IncludeYanked bool `query:"include_yanked" doc:"Include yanked servers in results (default: false, but always true when updated_since is provided)" required:"false" default:"false"` } // ServerDetailInput represents the input for getting server details @@ -82,6 +83,15 @@ func RegisterServersEndpoints(api huma.API, pathPrefix string, registry service. } } + // Handle include_yanked parameter + // When updated_since is provided, always include yanked for incremental sync + if filter.UpdatedSince != nil { + includeYanked := true + filter.IncludeYanked = &includeYanked + } else { + filter.IncludeYanked = &input.IncludeYanked + } + // Get paginated results with filtering servers, nextCursor, err := registry.ListServers(ctx, filter, input.Cursor, input.Limit) if err != nil { diff --git a/internal/api/handlers/v0/servers_test.go b/internal/api/handlers/v0/servers_test.go index e0d19a031..bff677d8c 100644 --- a/internal/api/handlers/v0/servers_test.go +++ b/internal/api/handlers/v0/servers_test.go @@ -392,6 +392,117 @@ func TestGetAllVersionsEndpoint(t *testing.T) { } } +func TestListServersYankedFiltering(t *testing.T) { + ctx := context.Background() + registryService := service.NewRegistryService(database.NewTestDB(t), config.NewConfig()) + + // Setup test data: 2 active servers and 1 yanked server + _, err := registryService.CreateServer(ctx, &apiv0.ServerJSON{ + Schema: model.CurrentSchemaURL, + Name: "com.example/active-server-1", + Description: "Active server 1", + Version: "1.0.0", + }) + require.NoError(t, err) + + _, err = registryService.CreateServer(ctx, &apiv0.ServerJSON{ + Schema: model.CurrentSchemaURL, + Name: "com.example/active-server-2", + Description: "Active server 2", + Version: "1.0.0", + }) + require.NoError(t, err) + + _, err = registryService.CreateServer(ctx, &apiv0.ServerJSON{ + Schema: model.CurrentSchemaURL, + Name: "com.example/yanked-server", + Description: "Yanked server", + Version: "1.0.0", + }) + require.NoError(t, err) + + // Yank the third server + _, err = registryService.UpdateServerStatus(ctx, "com.example/yanked-server", "1.0.0", &service.StatusChangeRequest{ + NewStatus: model.StatusYanked, + }) + require.NoError(t, err) + + // Create API + mux := http.NewServeMux() + api := humago.New(mux, huma.DefaultConfig("Test API", "1.0.0")) + v0.RegisterServersEndpoints(api, "/v0", registryService) + + tests := []struct { + name string + queryParams string + expectedStatus int + expectedCount int + checkYanked bool // whether yanked server should be in results + }{ + { + name: "default excludes yanked servers", + queryParams: "", + expectedStatus: http.StatusOK, + expectedCount: 2, + checkYanked: false, + }, + { + name: "include_yanked=false excludes yanked servers", + queryParams: "?include_yanked=false", + expectedStatus: http.StatusOK, + expectedCount: 2, + checkYanked: false, + }, + { + name: "include_yanked=true includes yanked servers", + queryParams: "?include_yanked=true", + expectedStatus: http.StatusOK, + expectedCount: 3, + checkYanked: true, + }, + { + name: "updated_since always includes yanked servers", + queryParams: "?updated_since=1990-01-01T00:00:00Z", + expectedStatus: http.StatusOK, + expectedCount: 3, + checkYanked: true, + }, + { + name: "updated_since overrides include_yanked=false", + queryParams: "?updated_since=1990-01-01T00:00:00Z&include_yanked=false", + expectedStatus: http.StatusOK, + expectedCount: 3, + checkYanked: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/v0/servers"+tt.queryParams, nil) + w := httptest.NewRecorder() + + mux.ServeHTTP(w, req) + + assert.Equal(t, tt.expectedStatus, w.Code) + + var resp apiv0.ServerListResponse + err := json.NewDecoder(w.Body).Decode(&resp) + assert.NoError(t, err) + assert.Len(t, resp.Servers, tt.expectedCount) + + // Check if yanked server is in results + hasYanked := false + for _, server := range resp.Servers { + if server.Server.Name == "com.example/yanked-server" { + hasYanked = true + assert.Equal(t, model.StatusYanked, server.Meta.Official.Status) + } + } + assert.Equal(t, tt.checkYanked, hasYanked, "Yanked server presence mismatch") + }) + } +} + func TestServersEndpointEdgeCases(t *testing.T) { ctx := context.Background() registryService := service.NewRegistryService(database.NewTestDB(t), config.NewConfig()) @@ -514,7 +625,7 @@ func TestServersEndpointEdgeCases(t *testing.T) { assert.NotNil(t, server.Meta) assert.NotNil(t, server.Meta.Official) assert.NotZero(t, server.Meta.Official.PublishedAt) - assert.Contains(t, []model.Status{model.StatusActive, model.StatusDeprecated, model.StatusDeleted}, server.Meta.Official.Status) + assert.Contains(t, []model.Status{model.StatusActive, model.StatusDeprecated, model.StatusYanked}, server.Meta.Official.Status) } }) } diff --git a/internal/api/handlers/v0/status.go b/internal/api/handlers/v0/status.go new file mode 100644 index 000000000..c813c7381 --- /dev/null +++ b/internal/api/handlers/v0/status.go @@ -0,0 +1,342 @@ +package v0 + +import ( + "context" + "errors" + "fmt" + "net/http" + "net/url" + "strings" + + "github.com/danielgtaylor/huma/v2" + + "github.com/modelcontextprotocol/registry/internal/auth" + "github.com/modelcontextprotocol/registry/internal/config" + "github.com/modelcontextprotocol/registry/internal/database" + "github.com/modelcontextprotocol/registry/internal/service" + apiv0 "github.com/modelcontextprotocol/registry/pkg/api/v0" + "github.com/modelcontextprotocol/registry/pkg/model" +) + +// UpdateServerStatusBody represents the request body for updating server status +type UpdateServerStatusBody struct { + Status string `json:"status" required:"true" enum:"active,deprecated,yanked" doc:"New server lifecycle status"` + StatusMessage *string `json:"statusMessage,omitempty" maxLength:"500" doc:"Optional message explaining the status change (e.g., reason for deprecation)"` + AlternativeURL *string `json:"alternativeUrl,omitempty" format:"uri" doc:"Optional URL to an alternative/replacement server for deprecated or yanked servers"` + NewName *string `json:"newName,omitempty" doc:"Optional new server name when server has been renamed/moved"` +} + +// UpdateServerStatusInput represents the input for updating server status +type UpdateServerStatusInput struct { + Authorization string `header:"Authorization" doc:"Registry JWT token with edit permissions" required:"true"` + ServerName string `path:"serverName" doc:"URL-encoded server name" example:"com.example%2Fmy-server"` + Version string `path:"version" doc:"URL-encoded version to update" example:"1.0.0"` + Body UpdateServerStatusBody `body:""` +} + +// validateStatusTransition validates if the status transition is allowed +func validateStatusTransition(currentServer *apiv0.ServerResponse, newStatus model.Status, body UpdateServerStatusBody) error { + if currentServer.Meta.Official == nil { + return nil + } + + currentStatus := currentServer.Meta.Official.Status + isSameStatus := currentStatus == newStatus + + // Reject same-status requests with no metadata updates (pointless no-op) + if isSameStatus && !hasMetadataFieldsToUpdate(body) { + return huma.Error400BadRequest(fmt.Sprintf("No changes to apply: status is already %s", currentStatus)) + } + + // Reject invalid status transitions (e.g., invalid status values) + if !isSameStatus && !isValidStatusTransition(currentStatus, newStatus) { + return huma.Error400BadRequest(fmt.Sprintf("Invalid status transition from %s to %s", currentStatus, newStatus)) + } + + return nil +} + +// RegisterStatusEndpoints registers the status update endpoint with a custom path prefix +func RegisterStatusEndpoints(api huma.API, pathPrefix string, registry service.RegistryService, cfg *config.Config) { + jwtManager := auth.NewJWTManager(cfg) + + // Update server status endpoint + huma.Register(api, huma.Operation{ + OperationID: "update-server-status" + strings.ReplaceAll(pathPrefix, "/", "-"), + Method: http.MethodPatch, + Path: pathPrefix + "/servers/{serverName}/versions/{version}/status", + Summary: "Update MCP server status", + Description: "Update the status metadata of a specific version of an MCP server. Requires edit permission for the server. This endpoint allows changing status, status message, alternative URL, and new name without requiring the full server configuration.", + Tags: []string{"servers"}, + Security: []map[string][]string{ + {"bearer": {}}, + }, + }, func(ctx context.Context, input *UpdateServerStatusInput) (*Response[apiv0.ServerResponse], error) { + // Extract bearer token + const bearerPrefix = "Bearer " + authHeader := input.Authorization + if len(authHeader) < len(bearerPrefix) || !strings.EqualFold(authHeader[:len(bearerPrefix)], bearerPrefix) { + return nil, huma.Error401Unauthorized("Invalid Authorization header format. Expected 'Bearer '") + } + token := authHeader[len(bearerPrefix):] + + // Validate Registry JWT token + claims, err := jwtManager.ValidateToken(ctx, token) + if err != nil { + return nil, huma.Error401Unauthorized("Invalid or expired Registry JWT token", err) + } + + // URL-decode the server name + serverName, err := url.PathUnescape(input.ServerName) + if err != nil { + return nil, huma.Error400BadRequest("Invalid server name encoding", err) + } + + // URL-decode the version + version, err := url.PathUnescape(input.Version) + if err != nil { + return nil, huma.Error400BadRequest("Invalid version encoding", err) + } + + newStatus := model.Status(input.Body.Status) + + // Get all versions - gives us both the server data and version count in one DB call + allVersions, err := registry.GetAllVersionsByServerName(ctx, serverName) + if err != nil { + if errors.Is(err, database.ErrNotFound) { + return nil, huma.Error404NotFound("Server not found") + } + return nil, huma.Error500InternalServerError("Failed to get server versions", err) + } + + // Check newName validation first (before version check) for better error messages + hasNewName := input.Body.NewName != nil && *input.Body.NewName != "" + if hasNewName && len(allVersions) > 1 { + return nil, huma.Error400BadRequest("new_name cannot be used with single version endpoint when server has multiple versions. Use the all-versions endpoint instead: PATCH /servers/{serverName}/status") + } + + // Find the requested version + var currentServer *apiv0.ServerResponse + for _, v := range allVersions { + if v.Server.Version == version { + currentServer = v + break + } + } + if currentServer == nil { + return nil, huma.Error404NotFound("Server version not found") + } + + // Verify edit permissions for this server + if !jwtManager.HasPermission(currentServer.Server.Name, auth.PermissionActionEdit, claims.Permissions) { + return nil, huma.Error403Forbidden("You do not have edit permissions for this server") + } + + // Validate newName if provided + if hasNewName { + if err := validateNewNameForStatus(ctx, input.Body, serverName, registry, jwtManager, claims); err != nil { + return nil, err + } + } + + // Validate status transition is allowed + if err := validateStatusTransition(currentServer, newStatus, input.Body); err != nil { + return nil, err + } + + // Build status change request + statusChange := buildStatusChangeRequestFromBody(input.Body) + + // Update the server status using the service + updatedServer, err := registry.UpdateServerStatus(ctx, serverName, version, statusChange) + if err != nil { + if errors.Is(err, database.ErrNotFound) { + return nil, huma.Error404NotFound("Server not found") + } + return nil, huma.Error400BadRequest("Failed to update server status", err) + } + + return &Response[apiv0.ServerResponse]{ + Body: *updatedServer, + }, nil + }) +} + +// validateNewNameForStatus validates the new_name parameter for the status endpoint +func validateNewNameForStatus(ctx context.Context, body UpdateServerStatusBody, currentServerName string, registry service.RegistryService, jwtManager *auth.JWTManager, claims *auth.JWTClaims) error { + newName := *body.NewName + + // Validation: new_name cannot be the same as the current server name + if newName == currentServerName { + return huma.Error400BadRequest("new_name cannot be the same as the current server name") + } + + // Validation: new_name can only be used with deprecated or yanked status + if body.Status != string(model.StatusDeprecated) && body.Status != string(model.StatusYanked) { + return huma.Error400BadRequest("new_name can only be used with deprecated or yanked status") + } + + // Validation: Check that the new server exists + newServer, err := registry.GetServerByName(ctx, newName) + if err != nil { + if errors.Is(err, database.ErrNotFound) { + return huma.Error400BadRequest(fmt.Sprintf("New server '%s' does not exist in the registry", newName)) + } + return huma.Error500InternalServerError("Failed to validate new server name", err) + } + + // Validation: Check that the user has publish permissions for the new server + if !jwtManager.HasPermission(newServer.Server.Name, auth.PermissionActionPublish, claims.Permissions) { + return huma.Error403Forbidden(fmt.Sprintf("You do not have permissions for the new server '%s'", newName)) + } + + return nil +} + +// buildStatusChangeRequestFromBody constructs a StatusChangeRequest from the request body +func buildStatusChangeRequestFromBody(body UpdateServerStatusBody) *service.StatusChangeRequest { + var statusMessage *string + var alternativeURL *string + var newName *string + + newStatus := model.Status(body.Status) + + // When transitioning to active status, clear status_message, alternative_url, and new_name + if newStatus != model.StatusActive { + statusMessage = body.StatusMessage + alternativeURL = body.AlternativeURL + newName = body.NewName + } + + return &service.StatusChangeRequest{ + NewStatus: newStatus, + StatusMessage: statusMessage, + AlternativeURL: alternativeURL, + NewName: newName, + } +} + +// hasMetadataFieldsToUpdate checks if any metadata fields (statusMessage, alternativeURL, newName) are being updated +func hasMetadataFieldsToUpdate(body UpdateServerStatusBody) bool { + return (body.StatusMessage != nil && *body.StatusMessage != "") || + (body.AlternativeURL != nil && *body.AlternativeURL != "") || + (body.NewName != nil && *body.NewName != "") +} + +// UpdateAllVersionsStatusInput represents the input for updating all versions' status +type UpdateAllVersionsStatusInput struct { + Authorization string `header:"Authorization" doc:"Registry JWT token with edit permissions" required:"true"` + ServerName string `path:"serverName" doc:"URL-encoded server name" example:"com.example%2Fmy-server"` + Body UpdateServerStatusBody `body:""` +} + +// UpdateAllVersionsStatusResponse represents the response for updating all versions' status +type UpdateAllVersionsStatusResponse struct { + UpdatedCount int `json:"updatedCount" doc:"Number of versions updated"` + Servers []apiv0.ServerResponse `json:"servers" doc:"List of all updated server versions"` +} + +// RegisterAllVersionsStatusEndpoints registers the all-versions status update endpoint +func RegisterAllVersionsStatusEndpoints(api huma.API, pathPrefix string, registry service.RegistryService, cfg *config.Config) { + jwtManager := auth.NewJWTManager(cfg) + + // Update all versions status endpoint + huma.Register(api, huma.Operation{ + OperationID: "update-server-all-versions-status" + strings.ReplaceAll(pathPrefix, "/", "-"), + Method: http.MethodPatch, + Path: pathPrefix + "/servers/{serverName}/status", + Summary: "Update status for all versions of an MCP server", + Description: "Update the status metadata of all versions of an MCP server in a single transaction. Requires edit permission for the server. Either all versions are updated or none on failure.", + Tags: []string{"servers"}, + Security: []map[string][]string{ + {"bearer": {}}, + }, + }, func(ctx context.Context, input *UpdateAllVersionsStatusInput) (*Response[UpdateAllVersionsStatusResponse], error) { + // Extract bearer token + const bearerPrefix = "Bearer " + authHeader := input.Authorization + if len(authHeader) < len(bearerPrefix) || !strings.EqualFold(authHeader[:len(bearerPrefix)], bearerPrefix) { + return nil, huma.Error401Unauthorized("Invalid Authorization header format. Expected 'Bearer '") + } + token := authHeader[len(bearerPrefix):] + + // Validate Registry JWT token + claims, err := jwtManager.ValidateToken(ctx, token) + if err != nil { + return nil, huma.Error401Unauthorized("Invalid or expired Registry JWT token", err) + } + + // URL-decode the server name + serverName, err := url.PathUnescape(input.ServerName) + if err != nil { + return nil, huma.Error400BadRequest("Invalid server name encoding", err) + } + + // Get any version to verify server exists and check permissions + currentServer, err := registry.GetServerByName(ctx, serverName) + if err != nil { + if errors.Is(err, database.ErrNotFound) { + return nil, huma.Error404NotFound("Server not found") + } + return nil, huma.Error500InternalServerError("Failed to get server", err) + } + + // Verify edit permissions for this server + if !jwtManager.HasPermission(currentServer.Server.Name, auth.PermissionActionEdit, claims.Permissions) { + return nil, huma.Error403Forbidden("You do not have edit permissions for this server") + } + + // Validate newName if provided + if input.Body.NewName != nil && *input.Body.NewName != "" { + if err := validateNewNameForStatus(ctx, input.Body, serverName, registry, jwtManager, claims); err != nil { + return nil, err + } + } + + // Build status change request + statusChange := buildStatusChangeRequestFromBody(input.Body) + + // Update all versions' status using the service + updatedServers, err := registry.UpdateAllVersionsStatus(ctx, serverName, statusChange) + if err != nil { + if errors.Is(err, database.ErrNotFound) { + return nil, huma.Error404NotFound("Server not found") + } + return nil, huma.Error400BadRequest("Failed to update server status", err) + } + + // Convert to response format + servers := make([]apiv0.ServerResponse, len(updatedServers)) + for i, s := range updatedServers { + servers[i] = *s + } + + return &Response[UpdateAllVersionsStatusResponse]{ + Body: UpdateAllVersionsStatusResponse{ + UpdatedCount: len(servers), + Servers: servers, + }, + }, nil + }) +} + +// isValidStatusTransition checks if a status transition is allowed +// Allowed transitions: +// - active ↔ deprecated ↔ yanked (all bidirectional transitions allowed) +// - Same status transitions are NOT allowed (no-op) +func isValidStatusTransition(currentStatus, newStatus model.Status) bool { + // Same status transition is not allowed (no-op) + if currentStatus == newStatus { + return false + } + + // All transitions between active, deprecated, and yanked are allowed + validStatuses := map[model.Status]bool{ + model.StatusActive: true, + model.StatusDeprecated: true, + model.StatusYanked: true, + } + + // Both current and new status must be valid + return validStatuses[currentStatus] && validStatuses[newStatus] +} diff --git a/internal/api/handlers/v0/status_test.go b/internal/api/handlers/v0/status_test.go new file mode 100644 index 000000000..3b763ef45 --- /dev/null +++ b/internal/api/handlers/v0/status_test.go @@ -0,0 +1,1023 @@ +package v0_test + +import ( + "bytes" + "context" + "crypto/ed25519" + "crypto/rand" + "encoding/hex" + "encoding/json" + "net/http" + "net/http/httptest" + "net/url" + "testing" + + "github.com/danielgtaylor/huma/v2" + "github.com/danielgtaylor/huma/v2/adapters/humago" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + v0 "github.com/modelcontextprotocol/registry/internal/api/handlers/v0" + "github.com/modelcontextprotocol/registry/internal/auth" + "github.com/modelcontextprotocol/registry/internal/config" + "github.com/modelcontextprotocol/registry/internal/database" + "github.com/modelcontextprotocol/registry/internal/service" + apiv0 "github.com/modelcontextprotocol/registry/pkg/api/v0" + "github.com/modelcontextprotocol/registry/pkg/model" +) + +func TestUpdateServerStatusEndpoint(t *testing.T) { + // Create test config + testSeed := make([]byte, ed25519.SeedSize) + _, err := rand.Read(testSeed) + require.NoError(t, err) + cfg := &config.Config{ + JWTPrivateKey: hex.EncodeToString(testSeed), + EnableRegistryValidation: false, + } + + // Create registry service and test data + registryService := service.NewRegistryService(database.NewTestDB(t), cfg) + + // Create test servers for different scenarios + testServers := map[string]*apiv0.ServerJSON{ + "active": { + Schema: model.CurrentSchemaURL, + Name: "io.github.testuser/active-server", + Description: "Server in active status", + Version: "1.0.0", + Repository: &model.Repository{ + URL: "https://github.com/testuser/active-server", + Source: "github", + ID: "testuser/active-server", + }, + }, + "deprecated": { + Schema: model.CurrentSchemaURL, + Name: "io.github.testuser/deprecated-server", + Description: "Server in deprecated status", + Version: "1.0.0", + Repository: &model.Repository{ + URL: "https://github.com/testuser/deprecated-server", + Source: "github", + ID: "testuser/deprecated-server", + }, + }, + "yanked": { + Schema: model.CurrentSchemaURL, + Name: "io.github.testuser/yanked-server", + Description: "Server in yanked status", + Version: "1.0.0", + Repository: &model.Repository{ + URL: "https://github.com/testuser/yanked-server", + Source: "github", + ID: "testuser/yanked-server", + }, + }, + "other": { + Schema: model.CurrentSchemaURL, + Name: "io.github.otheruser/other-server", + Description: "Server owned by another user", + Version: "1.0.0", + Repository: &model.Repository{ + URL: "https://github.com/otheruser/other-server", + Source: "github", + ID: "otheruser/other-server", + }, + }, + "newserver": { + Schema: model.CurrentSchemaURL, + Name: "io.github.testuser/new-server", + Description: "New server for renaming tests", + Version: "1.0.0", + Repository: &model.Repository{ + URL: "https://github.com/testuser/new-server", + Source: "github", + ID: "testuser/new-server", + }, + }, + "newname-active-test": { + Schema: model.CurrentSchemaURL, + Name: "io.github.testuser/newname-active-test", + Description: "Server for testing newName with active status", + Version: "1.0.0", + Repository: &model.Repository{ + URL: "https://github.com/testuser/newname-active-test", + Source: "github", + ID: "testuser/newname-active-test", + }, + }, + "multi-version": { + Schema: model.CurrentSchemaURL, + Name: "io.github.testuser/multi-version-server", + Description: "Server with multiple versions for testing", + Version: "1.0.0", + Repository: &model.Repository{ + URL: "https://github.com/testuser/multi-version-server", + Source: "github", + ID: "testuser/multi-version-server", + }, + }, + } + + // Create the test servers + for _, server := range testServers { + _, err := registryService.CreateServer(context.Background(), server) + require.NoError(t, err) + } + + // Set deprecated server to deprecated status + _, err = registryService.UpdateServerStatus(context.Background(), testServers["deprecated"].Name, testServers["deprecated"].Version, &service.StatusChangeRequest{ + NewStatus: model.StatusDeprecated, + }) + require.NoError(t, err) + + // Set yanked server to yanked status + _, err = registryService.UpdateServerStatus(context.Background(), testServers["yanked"].Name, testServers["yanked"].Version, &service.StatusChangeRequest{ + NewStatus: model.StatusYanked, + }) + require.NoError(t, err) + + // Set newname-active-test server to deprecated for the newName rejection test + _, err = registryService.UpdateServerStatus(context.Background(), testServers["newname-active-test"].Name, testServers["newname-active-test"].Version, &service.StatusChangeRequest{ + NewStatus: model.StatusDeprecated, + }) + require.NoError(t, err) + + // Add a second version to multi-version server + multiVersionV2 := &apiv0.ServerJSON{ + Schema: model.CurrentSchemaURL, + Name: "io.github.testuser/multi-version-server", + Description: "Server with multiple versions for testing", + Version: "2.0.0", + Repository: &model.Repository{ + URL: "https://github.com/testuser/multi-version-server", + Source: "github", + ID: "testuser/multi-version-server", + }, + } + _, err = registryService.CreateServer(context.Background(), multiVersionV2) + require.NoError(t, err) + + testCases := []struct { + name string + serverName string + version string + authClaims *auth.JWTClaims + authHeader string + requestBody v0.UpdateServerStatusBody + expectedStatus int + expectedError string + checkResult func(*testing.T, *apiv0.ServerResponse) + }{ + { + name: "successful status change from active to deprecated", + serverName: "io.github.testuser/active-server", + version: "1.0.0", + authClaims: &auth.JWTClaims{ + AuthMethod: auth.MethodGitHubAT, + AuthMethodSubject: "testuser", + Permissions: []auth.Permission{ + {Action: auth.PermissionActionEdit, ResourcePattern: "io.github.testuser/*"}, + }, + }, + requestBody: v0.UpdateServerStatusBody{ + Status: "deprecated", + }, + expectedStatus: http.StatusOK, + checkResult: func(t *testing.T, resp *apiv0.ServerResponse) { + t.Helper() + assert.Equal(t, model.StatusDeprecated, resp.Meta.Official.Status) + }, + }, + { + name: "successful status change with message and alternative URL", + serverName: "io.github.testuser/active-server", + version: "1.0.0", + authClaims: &auth.JWTClaims{ + AuthMethod: auth.MethodGitHubAT, + AuthMethodSubject: "testuser", + Permissions: []auth.Permission{ + {Action: auth.PermissionActionEdit, ResourcePattern: "io.github.testuser/*"}, + }, + }, + requestBody: v0.UpdateServerStatusBody{ + Status: "yanked", + StatusMessage: strPtr("Security vulnerability discovered"), + AlternativeURL: strPtr("https://example.com/patched-version"), + }, + expectedStatus: http.StatusOK, + checkResult: func(t *testing.T, resp *apiv0.ServerResponse) { + t.Helper() + assert.Equal(t, model.StatusYanked, resp.Meta.Official.Status) + assert.NotNil(t, resp.Meta.Official.StatusMessage) + assert.Equal(t, "Security vulnerability discovered", *resp.Meta.Official.StatusMessage) + assert.NotNil(t, resp.Meta.Official.AlternativeURL) + assert.Equal(t, "https://example.com/patched-version", *resp.Meta.Official.AlternativeURL) + }, + }, + { + name: "successful unyank from yanked to active", + serverName: "io.github.testuser/yanked-server", + version: "1.0.0", + authClaims: &auth.JWTClaims{ + AuthMethod: auth.MethodGitHubAT, + AuthMethodSubject: "testuser", + Permissions: []auth.Permission{ + {Action: auth.PermissionActionEdit, ResourcePattern: "io.github.testuser/*"}, + }, + }, + requestBody: v0.UpdateServerStatusBody{ + Status: "active", + }, + expectedStatus: http.StatusOK, + checkResult: func(t *testing.T, resp *apiv0.ServerResponse) { + t.Helper() + assert.Equal(t, model.StatusActive, resp.Meta.Official.Status) + // Status message and alternative URL should be cleared when transitioning to active + assert.Nil(t, resp.Meta.Official.StatusMessage) + assert.Nil(t, resp.Meta.Official.AlternativeURL) + }, + }, + { + name: "successful undeprecate from deprecated to active", + serverName: "io.github.testuser/deprecated-server", + version: "1.0.0", + authClaims: &auth.JWTClaims{ + AuthMethod: auth.MethodGitHubAT, + AuthMethodSubject: "testuser", + Permissions: []auth.Permission{ + {Action: auth.PermissionActionEdit, ResourcePattern: "io.github.testuser/*"}, + }, + }, + requestBody: v0.UpdateServerStatusBody{ + Status: "active", + }, + expectedStatus: http.StatusOK, + checkResult: func(t *testing.T, resp *apiv0.ServerResponse) { + t.Helper() + assert.Equal(t, model.StatusActive, resp.Meta.Official.Status) + }, + }, + { + name: "missing authorization header", + serverName: "io.github.testuser/active-server", + version: "1.0.0", + authHeader: "", + requestBody: v0.UpdateServerStatusBody{Status: "deprecated"}, + expectedStatus: http.StatusUnprocessableEntity, + expectedError: "required header parameter is missing", + }, + { + name: "invalid authorization header format", + serverName: "io.github.testuser/active-server", + version: "1.0.0", + authHeader: "InvalidFormat token123", + requestBody: v0.UpdateServerStatusBody{ + Status: "deprecated", + }, + expectedStatus: http.StatusUnauthorized, + expectedError: "Invalid Authorization header format", + }, + { + name: "invalid token", + serverName: "io.github.testuser/active-server", + version: "1.0.0", + authHeader: "Bearer invalid-token", + requestBody: v0.UpdateServerStatusBody{ + Status: "deprecated", + }, + expectedStatus: http.StatusUnauthorized, + expectedError: "Invalid or expired Registry JWT token", + }, + { + name: "permission denied - no edit permissions", + serverName: "io.github.testuser/active-server", + version: "1.0.0", + authClaims: &auth.JWTClaims{ + AuthMethod: auth.MethodGitHubAT, + AuthMethodSubject: "testuser", + Permissions: []auth.Permission{ + {Action: auth.PermissionActionPublish, ResourcePattern: "io.github.testuser/*"}, + }, + }, + requestBody: v0.UpdateServerStatusBody{ + Status: "deprecated", + }, + expectedStatus: http.StatusForbidden, + expectedError: "You do not have edit permissions", + }, + { + name: "permission denied - wrong namespace", + serverName: "io.github.otheruser/other-server", + version: "1.0.0", + authClaims: &auth.JWTClaims{ + AuthMethod: auth.MethodGitHubAT, + AuthMethodSubject: "testuser", + Permissions: []auth.Permission{ + {Action: auth.PermissionActionEdit, ResourcePattern: "io.github.testuser/*"}, + }, + }, + requestBody: v0.UpdateServerStatusBody{ + Status: "deprecated", + }, + expectedStatus: http.StatusForbidden, + expectedError: "You do not have edit permissions", + }, + { + name: "server not found", + serverName: "io.github.testuser/non-existent", + version: "1.0.0", + authClaims: &auth.JWTClaims{ + AuthMethod: auth.MethodGitHubAT, + AuthMethodSubject: "testuser", + Permissions: []auth.Permission{ + {Action: auth.PermissionActionEdit, ResourcePattern: "io.github.testuser/*"}, + }, + }, + requestBody: v0.UpdateServerStatusBody{ + Status: "deprecated", + }, + expectedStatus: http.StatusNotFound, + expectedError: "Server not found", + }, + { + name: "newName with valid deprecated status and permissions", + serverName: "io.github.testuser/active-server", + version: "1.0.0", + authClaims: &auth.JWTClaims{ + AuthMethod: auth.MethodGitHubAT, + AuthMethodSubject: "testuser", + Permissions: []auth.Permission{ + {Action: auth.PermissionActionEdit, ResourcePattern: "io.github.testuser/*"}, + {Action: auth.PermissionActionPublish, ResourcePattern: "io.github.testuser/new-server"}, + }, + }, + requestBody: v0.UpdateServerStatusBody{ + Status: "deprecated", + StatusMessage: strPtr("Moved to new server"), + NewName: strPtr("io.github.testuser/new-server"), + }, + expectedStatus: http.StatusOK, + checkResult: func(t *testing.T, resp *apiv0.ServerResponse) { + t.Helper() + assert.Equal(t, model.StatusDeprecated, resp.Meta.Official.Status) + assert.NotNil(t, resp.Meta.Official.NewName) + assert.Equal(t, "io.github.testuser/new-server", *resp.Meta.Official.NewName) + }, + }, + { + name: "newName rejected when transitioning to active", + serverName: "io.github.testuser/newname-active-test", + version: "1.0.0", + authClaims: &auth.JWTClaims{ + AuthMethod: auth.MethodGitHubAT, + AuthMethodSubject: "testuser", + Permissions: []auth.Permission{ + {Action: auth.PermissionActionEdit, ResourcePattern: "io.github.testuser/*"}, + }, + }, + requestBody: v0.UpdateServerStatusBody{ + Status: "active", + NewName: strPtr("io.github.testuser/new-server"), + }, + expectedStatus: http.StatusBadRequest, + expectedError: "new_name can only be used with deprecated or yanked status", + }, + { + name: "newName rejected when same as current server name", + serverName: "io.github.testuser/active-server", + version: "1.0.0", + authClaims: &auth.JWTClaims{ + AuthMethod: auth.MethodGitHubAT, + AuthMethodSubject: "testuser", + Permissions: []auth.Permission{ + {Action: auth.PermissionActionEdit, ResourcePattern: "io.github.testuser/*"}, + }, + }, + requestBody: v0.UpdateServerStatusBody{ + Status: "deprecated", + NewName: strPtr("io.github.testuser/active-server"), + }, + expectedStatus: http.StatusBadRequest, + expectedError: "new_name cannot be the same as the current server name", + }, + { + name: "newName with non-existent server", + serverName: "io.github.testuser/deprecated-server", + version: "1.0.0", + authClaims: &auth.JWTClaims{ + AuthMethod: auth.MethodGitHubAT, + AuthMethodSubject: "testuser", + Permissions: []auth.Permission{ + {Action: auth.PermissionActionEdit, ResourcePattern: "io.github.testuser/*"}, + }, + }, + requestBody: v0.UpdateServerStatusBody{ + Status: "yanked", + NewName: strPtr("io.github.testuser/non-existent-server"), + }, + expectedStatus: http.StatusBadRequest, + expectedError: "New server 'io.github.testuser/non-existent-server' does not exist in the registry", + }, + { + name: "newName with server belonging to different user without permissions", + serverName: "io.github.testuser/deprecated-server", + version: "1.0.0", + authClaims: &auth.JWTClaims{ + AuthMethod: auth.MethodGitHubAT, + AuthMethodSubject: "testuser", + Permissions: []auth.Permission{ + {Action: auth.PermissionActionEdit, ResourcePattern: "io.github.testuser/*"}, + }, + }, + requestBody: v0.UpdateServerStatusBody{ + Status: "yanked", + NewName: strPtr("io.github.otheruser/other-server"), + }, + expectedStatus: http.StatusForbidden, + expectedError: "You do not have permissions for the new server 'io.github.otheruser/other-server'", + }, + { + name: "same status transition allowed when updating newName", + serverName: "io.github.testuser/deprecated-server", + version: "1.0.0", + authClaims: &auth.JWTClaims{ + AuthMethod: auth.MethodGitHubAT, + AuthMethodSubject: "testuser", + Permissions: []auth.Permission{ + {Action: auth.PermissionActionEdit, ResourcePattern: "io.github.testuser/*"}, + {Action: auth.PermissionActionPublish, ResourcePattern: "io.github.testuser/new-server"}, + }, + }, + requestBody: v0.UpdateServerStatusBody{ + Status: "deprecated", + NewName: strPtr("io.github.testuser/new-server"), + }, + expectedStatus: http.StatusOK, + checkResult: func(t *testing.T, resp *apiv0.ServerResponse) { + t.Helper() + assert.Equal(t, model.StatusDeprecated, resp.Meta.Official.Status) + assert.NotNil(t, resp.Meta.Official.NewName) + assert.Equal(t, "io.github.testuser/new-server", *resp.Meta.Official.NewName) + }, + }, + { + name: "same status transition allowed when updating statusMessage", + serverName: "io.github.testuser/deprecated-server", + version: "1.0.0", + authClaims: &auth.JWTClaims{ + AuthMethod: auth.MethodGitHubAT, + AuthMethodSubject: "testuser", + Permissions: []auth.Permission{ + {Action: auth.PermissionActionEdit, ResourcePattern: "io.github.testuser/*"}, + }, + }, + requestBody: v0.UpdateServerStatusBody{ + Status: "deprecated", + StatusMessage: strPtr("Updated deprecation message"), + }, + expectedStatus: http.StatusOK, + checkResult: func(t *testing.T, resp *apiv0.ServerResponse) { + t.Helper() + assert.Equal(t, model.StatusDeprecated, resp.Meta.Official.Status) + assert.NotNil(t, resp.Meta.Official.StatusMessage) + assert.Equal(t, "Updated deprecation message", *resp.Meta.Official.StatusMessage) + }, + }, + { + name: "newName rejected for single version when server has multiple versions", + serverName: "io.github.testuser/multi-version-server", + version: "1.0.0", + authClaims: &auth.JWTClaims{ + AuthMethod: auth.MethodGitHubAT, + AuthMethodSubject: "testuser", + Permissions: []auth.Permission{ + {Action: auth.PermissionActionEdit, ResourcePattern: "io.github.testuser/*"}, + {Action: auth.PermissionActionPublish, ResourcePattern: "io.github.testuser/new-server"}, + }, + }, + requestBody: v0.UpdateServerStatusBody{ + Status: "deprecated", + NewName: strPtr("io.github.testuser/new-server"), + }, + expectedStatus: http.StatusBadRequest, + expectedError: "new_name cannot be used with single version endpoint when server has multiple versions", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // Create Huma API + mux := http.NewServeMux() + api := humago.New(mux, huma.DefaultConfig("Test API", "1.0.0")) + + // Register status endpoints + v0.RegisterStatusEndpoints(api, "/v0", registryService, cfg) + + // Create request body + requestBody, err := json.Marshal(tc.requestBody) + require.NoError(t, err) + + // Create request URL with proper encoding + encodedServerName := url.PathEscape(tc.serverName) + encodedVersion := url.PathEscape(tc.version) + requestURL := "/v0/servers/" + encodedServerName + "/versions/" + encodedVersion + "/status" + + req := httptest.NewRequest(http.MethodPatch, requestURL, bytes.NewReader(requestBody)) + req.Header.Set("Content-Type", "application/json") + + // Set authorization header + if tc.authHeader != "" { + req.Header.Set("Authorization", tc.authHeader) + } else if tc.authClaims != nil { + // Generate valid JWT token + jwtManager := auth.NewJWTManager(cfg) + tokenResponse, err := jwtManager.GenerateTokenResponse(context.Background(), *tc.authClaims) + require.NoError(t, err) + req.Header.Set("Authorization", "Bearer "+tokenResponse.RegistryToken) + } + + // Create response recorder and execute request + w := httptest.NewRecorder() + mux.ServeHTTP(w, req) + + // Check response + if tc.expectedStatus != w.Code { + t.Logf("Response body: %s", w.Body.String()) + } + assert.Equal(t, tc.expectedStatus, w.Code) + + if tc.expectedError != "" { + assert.Contains(t, w.Body.String(), tc.expectedError) + } + + if tc.expectedStatus == http.StatusOK && tc.checkResult != nil { + var response apiv0.ServerResponse + err := json.NewDecoder(w.Body).Decode(&response) + require.NoError(t, err) + tc.checkResult(t, &response) + } + }) + } +} + +func TestUpdateServerStatusEndpointSameStatusTransition(t *testing.T) { + // Create test config + testSeed := make([]byte, ed25519.SeedSize) + _, err := rand.Read(testSeed) + require.NoError(t, err) + cfg := &config.Config{ + JWTPrivateKey: hex.EncodeToString(testSeed), + EnableRegistryValidation: false, + } + + // Create registry service + registryService := service.NewRegistryService(database.NewTestDB(t), cfg) + + // Create an active server + activeServer := &apiv0.ServerJSON{ + Schema: model.CurrentSchemaURL, + Name: "io.github.testuser/same-status-test", + Description: "Server for same status transition test", + Version: "1.0.0", + } + _, err = registryService.CreateServer(context.Background(), activeServer) + require.NoError(t, err) + + // Create Huma API + mux := http.NewServeMux() + api := humago.New(mux, huma.DefaultConfig("Test API", "1.0.0")) + v0.RegisterStatusEndpoints(api, "/v0", registryService, cfg) + + // Try to transition from active to active (should fail) + requestBody := v0.UpdateServerStatusBody{ + Status: "active", + } + bodyBytes, err := json.Marshal(requestBody) + require.NoError(t, err) + + encodedName := url.PathEscape(activeServer.Name) + requestURL := "/v0/servers/" + encodedName + "/versions/1.0.0/status" + + req := httptest.NewRequest(http.MethodPatch, requestURL, bytes.NewReader(bodyBytes)) + req.Header.Set("Content-Type", "application/json") + + // Generate admin token + jwtManager := auth.NewJWTManager(cfg) + tokenResponse, err := jwtManager.GenerateTokenResponse(context.Background(), auth.JWTClaims{ + AuthMethod: auth.MethodNone, + Permissions: []auth.Permission{ + {Action: auth.PermissionActionEdit, ResourcePattern: "*"}, + }, + }) + require.NoError(t, err) + req.Header.Set("Authorization", "Bearer "+tokenResponse.RegistryToken) + + w := httptest.NewRecorder() + mux.ServeHTTP(w, req) + + assert.Equal(t, http.StatusBadRequest, w.Code) + assert.Contains(t, w.Body.String(), "No changes to apply: status is already active") +} + +func TestUpdateServerStatusEndpointURLEncoding(t *testing.T) { + // Create test config + testSeed := make([]byte, ed25519.SeedSize) + _, err := rand.Read(testSeed) + require.NoError(t, err) + cfg := &config.Config{ + JWTPrivateKey: hex.EncodeToString(testSeed), + EnableRegistryValidation: false, + } + + // Create registry service + registryService := service.NewRegistryService(database.NewTestDB(t), cfg) + + // Create a server with build metadata version + buildMetadataServer := &apiv0.ServerJSON{ + Schema: model.CurrentSchemaURL, + Name: "io.github.testuser/build-metadata-server", + Description: "Server with build metadata version", + Version: "1.0.0+20130313144700", + } + _, err = registryService.CreateServer(context.Background(), buildMetadataServer) + require.NoError(t, err) + + // Create Huma API + mux := http.NewServeMux() + api := humago.New(mux, huma.DefaultConfig("Test API", "1.0.0")) + v0.RegisterStatusEndpoints(api, "/v0", registryService, cfg) + + // Update status with URL-encoded version + requestBody := v0.UpdateServerStatusBody{ + Status: "deprecated", + StatusMessage: strPtr("Testing URL encoding"), + } + bodyBytes, err := json.Marshal(requestBody) + require.NoError(t, err) + + encodedName := url.PathEscape(buildMetadataServer.Name) + encodedVersion := url.PathEscape(buildMetadataServer.Version) + requestURL := "/v0/servers/" + encodedName + "/versions/" + encodedVersion + "/status" + + req := httptest.NewRequest(http.MethodPatch, requestURL, bytes.NewReader(bodyBytes)) + req.Header.Set("Content-Type", "application/json") + + // Generate admin token + jwtManager := auth.NewJWTManager(cfg) + tokenResponse, err := jwtManager.GenerateTokenResponse(context.Background(), auth.JWTClaims{ + AuthMethod: auth.MethodNone, + Permissions: []auth.Permission{ + {Action: auth.PermissionActionEdit, ResourcePattern: "*"}, + }, + }) + require.NoError(t, err) + req.Header.Set("Authorization", "Bearer "+tokenResponse.RegistryToken) + + w := httptest.NewRecorder() + mux.ServeHTTP(w, req) + + assert.Equal(t, http.StatusOK, w.Code) + + var response apiv0.ServerResponse + err = json.NewDecoder(w.Body).Decode(&response) + require.NoError(t, err) + assert.Equal(t, model.StatusDeprecated, response.Meta.Official.Status) + assert.Equal(t, "1.0.0+20130313144700", response.Server.Version) +} + +func TestUpdateAllVersionsStatusEndpoint(t *testing.T) { + // Create test config + testSeed := make([]byte, ed25519.SeedSize) + _, err := rand.Read(testSeed) + require.NoError(t, err) + cfg := &config.Config{ + JWTPrivateKey: hex.EncodeToString(testSeed), + EnableRegistryValidation: false, + } + + // Create registry service and test data + registryService := service.NewRegistryService(database.NewTestDB(t), cfg) + + // Create a server with multiple versions + multiVersionServer := &apiv0.ServerJSON{ + Schema: model.CurrentSchemaURL, + Name: "io.github.testuser/multi-version-server", + Description: "Server with multiple versions", + Version: "1.0.0", + Repository: &model.Repository{ + URL: "https://github.com/testuser/multi-version-server", + Source: "github", + ID: "testuser/multi-version-server", + }, + } + _, err = registryService.CreateServer(context.Background(), multiVersionServer) + require.NoError(t, err) + + // Add more versions + multiVersionServer.Version = "1.1.0" + _, err = registryService.CreateServer(context.Background(), multiVersionServer) + require.NoError(t, err) + + multiVersionServer.Version = "2.0.0" + _, err = registryService.CreateServer(context.Background(), multiVersionServer) + require.NoError(t, err) + + // Create a new server to use for newName tests + newServer := &apiv0.ServerJSON{ + Schema: model.CurrentSchemaURL, + Name: "io.github.testuser/new-target-server", + Description: "New target server for renaming", + Version: "1.0.0", + Repository: &model.Repository{ + URL: "https://github.com/testuser/new-target-server", + Source: "github", + ID: "testuser/new-target-server", + }, + } + _, err = registryService.CreateServer(context.Background(), newServer) + require.NoError(t, err) + + // Create other user's server + otherServer := &apiv0.ServerJSON{ + Schema: model.CurrentSchemaURL, + Name: "io.github.otheruser/other-server", + Description: "Server owned by another user", + Version: "1.0.0", + Repository: &model.Repository{ + URL: "https://github.com/otheruser/other-server", + Source: "github", + ID: "otheruser/other-server", + }, + } + _, err = registryService.CreateServer(context.Background(), otherServer) + require.NoError(t, err) + + testCases := []struct { + name string + serverName string + authClaims *auth.JWTClaims + authHeader string + requestBody v0.UpdateServerStatusBody + expectedStatus int + expectedError string + checkResult func(*testing.T, *v0.UpdateAllVersionsStatusResponse) + }{ + { + name: "successful deprecation of all versions", + serverName: "io.github.testuser/multi-version-server", + authClaims: &auth.JWTClaims{ + AuthMethod: auth.MethodGitHubAT, + AuthMethodSubject: "testuser", + Permissions: []auth.Permission{ + {Action: auth.PermissionActionEdit, ResourcePattern: "io.github.testuser/*"}, + }, + }, + requestBody: v0.UpdateServerStatusBody{ + Status: "deprecated", + StatusMessage: strPtr("This server is deprecated"), + }, + expectedStatus: http.StatusOK, + checkResult: func(t *testing.T, resp *v0.UpdateAllVersionsStatusResponse) { + t.Helper() + assert.Equal(t, 3, resp.UpdatedCount) + assert.Len(t, resp.Servers, 3) + for _, server := range resp.Servers { + assert.Equal(t, model.StatusDeprecated, server.Meta.Official.Status) + assert.NotNil(t, server.Meta.Official.StatusMessage) + assert.Equal(t, "This server is deprecated", *server.Meta.Official.StatusMessage) + } + }, + }, + { + name: "successful yank of all versions with newName", + serverName: "io.github.testuser/multi-version-server", + authClaims: &auth.JWTClaims{ + AuthMethod: auth.MethodGitHubAT, + AuthMethodSubject: "testuser", + Permissions: []auth.Permission{ + {Action: auth.PermissionActionEdit, ResourcePattern: "io.github.testuser/*"}, + {Action: auth.PermissionActionPublish, ResourcePattern: "io.github.testuser/new-target-server"}, + }, + }, + requestBody: v0.UpdateServerStatusBody{ + Status: "yanked", + StatusMessage: strPtr("Moved to new server"), + NewName: strPtr("io.github.testuser/new-target-server"), + }, + expectedStatus: http.StatusOK, + checkResult: func(t *testing.T, resp *v0.UpdateAllVersionsStatusResponse) { + t.Helper() + assert.Equal(t, 3, resp.UpdatedCount) + for _, server := range resp.Servers { + assert.Equal(t, model.StatusYanked, server.Meta.Official.Status) + assert.NotNil(t, server.Meta.Official.NewName) + assert.Equal(t, "io.github.testuser/new-target-server", *server.Meta.Official.NewName) + } + }, + }, + { + name: "successful reactivation of all versions", + serverName: "io.github.testuser/multi-version-server", + authClaims: &auth.JWTClaims{ + AuthMethod: auth.MethodGitHubAT, + AuthMethodSubject: "testuser", + Permissions: []auth.Permission{ + {Action: auth.PermissionActionEdit, ResourcePattern: "io.github.testuser/*"}, + }, + }, + requestBody: v0.UpdateServerStatusBody{ + Status: "active", + }, + expectedStatus: http.StatusOK, + checkResult: func(t *testing.T, resp *v0.UpdateAllVersionsStatusResponse) { + t.Helper() + assert.Equal(t, 3, resp.UpdatedCount) + for _, server := range resp.Servers { + assert.Equal(t, model.StatusActive, server.Meta.Official.Status) + // Status message and newName should be cleared when transitioning to active + assert.Nil(t, server.Meta.Official.StatusMessage) + assert.Nil(t, server.Meta.Official.NewName) + } + }, + }, + { + name: "missing authorization header", + serverName: "io.github.testuser/multi-version-server", + authHeader: "", + requestBody: v0.UpdateServerStatusBody{Status: "deprecated"}, + expectedStatus: http.StatusUnprocessableEntity, + expectedError: "required header parameter is missing", + }, + { + name: "invalid authorization header format", + serverName: "io.github.testuser/multi-version-server", + authHeader: "InvalidFormat token123", + requestBody: v0.UpdateServerStatusBody{ + Status: "deprecated", + }, + expectedStatus: http.StatusUnauthorized, + expectedError: "Invalid Authorization header format", + }, + { + name: "permission denied - no edit permissions", + serverName: "io.github.testuser/multi-version-server", + authClaims: &auth.JWTClaims{ + AuthMethod: auth.MethodGitHubAT, + AuthMethodSubject: "testuser", + Permissions: []auth.Permission{ + {Action: auth.PermissionActionPublish, ResourcePattern: "io.github.testuser/*"}, + }, + }, + requestBody: v0.UpdateServerStatusBody{ + Status: "deprecated", + }, + expectedStatus: http.StatusForbidden, + expectedError: "You do not have edit permissions", + }, + { + name: "server not found", + serverName: "io.github.testuser/non-existent", + authClaims: &auth.JWTClaims{ + AuthMethod: auth.MethodGitHubAT, + AuthMethodSubject: "testuser", + Permissions: []auth.Permission{ + {Action: auth.PermissionActionEdit, ResourcePattern: "io.github.testuser/*"}, + }, + }, + requestBody: v0.UpdateServerStatusBody{ + Status: "deprecated", + }, + expectedStatus: http.StatusNotFound, + expectedError: "Server not found", + }, + { + name: "newName rejected when transitioning to active", + serverName: "io.github.testuser/multi-version-server", + authClaims: &auth.JWTClaims{ + AuthMethod: auth.MethodGitHubAT, + AuthMethodSubject: "testuser", + Permissions: []auth.Permission{ + {Action: auth.PermissionActionEdit, ResourcePattern: "io.github.testuser/*"}, + }, + }, + requestBody: v0.UpdateServerStatusBody{ + Status: "active", + NewName: strPtr("io.github.testuser/new-target-server"), + }, + expectedStatus: http.StatusBadRequest, + expectedError: "new_name can only be used with deprecated or yanked status", + }, + { + name: "newName rejected when same as current server name", + serverName: "io.github.testuser/multi-version-server", + authClaims: &auth.JWTClaims{ + AuthMethod: auth.MethodGitHubAT, + AuthMethodSubject: "testuser", + Permissions: []auth.Permission{ + {Action: auth.PermissionActionEdit, ResourcePattern: "io.github.testuser/*"}, + }, + }, + requestBody: v0.UpdateServerStatusBody{ + Status: "deprecated", + NewName: strPtr("io.github.testuser/multi-version-server"), + }, + expectedStatus: http.StatusBadRequest, + expectedError: "new_name cannot be the same as the current server name", + }, + { + name: "newName with non-existent server", + serverName: "io.github.testuser/multi-version-server", + authClaims: &auth.JWTClaims{ + AuthMethod: auth.MethodGitHubAT, + AuthMethodSubject: "testuser", + Permissions: []auth.Permission{ + {Action: auth.PermissionActionEdit, ResourcePattern: "io.github.testuser/*"}, + }, + }, + requestBody: v0.UpdateServerStatusBody{ + Status: "deprecated", + NewName: strPtr("io.github.testuser/non-existent-server"), + }, + expectedStatus: http.StatusBadRequest, + expectedError: "New server 'io.github.testuser/non-existent-server' does not exist in the registry", + }, + { + name: "newName with server belonging to different user without permissions", + serverName: "io.github.testuser/multi-version-server", + authClaims: &auth.JWTClaims{ + AuthMethod: auth.MethodGitHubAT, + AuthMethodSubject: "testuser", + Permissions: []auth.Permission{ + {Action: auth.PermissionActionEdit, ResourcePattern: "io.github.testuser/*"}, + }, + }, + requestBody: v0.UpdateServerStatusBody{ + Status: "deprecated", + NewName: strPtr("io.github.otheruser/other-server"), + }, + expectedStatus: http.StatusForbidden, + expectedError: "You do not have permissions for the new server 'io.github.otheruser/other-server'", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // Create Huma API + mux := http.NewServeMux() + api := humago.New(mux, huma.DefaultConfig("Test API", "1.0.0")) + + // Register all-versions status endpoints + v0.RegisterAllVersionsStatusEndpoints(api, "/v0", registryService, cfg) + + // Create request body + requestBody, err := json.Marshal(tc.requestBody) + require.NoError(t, err) + + // Create request URL with proper encoding + encodedServerName := url.PathEscape(tc.serverName) + requestURL := "/v0/servers/" + encodedServerName + "/status" + + req := httptest.NewRequest(http.MethodPatch, requestURL, bytes.NewReader(requestBody)) + req.Header.Set("Content-Type", "application/json") + + // Set authorization header + if tc.authHeader != "" { + req.Header.Set("Authorization", tc.authHeader) + } else if tc.authClaims != nil { + // Generate valid JWT token + jwtManager := auth.NewJWTManager(cfg) + tokenResponse, err := jwtManager.GenerateTokenResponse(context.Background(), *tc.authClaims) + require.NoError(t, err) + req.Header.Set("Authorization", "Bearer "+tokenResponse.RegistryToken) + } + + // Create response recorder and execute request + w := httptest.NewRecorder() + mux.ServeHTTP(w, req) + + // Check response + if tc.expectedStatus != w.Code { + t.Logf("Response body: %s", w.Body.String()) + } + assert.Equal(t, tc.expectedStatus, w.Code) + + if tc.expectedError != "" { + assert.Contains(t, w.Body.String(), tc.expectedError) + } + + if tc.expectedStatus == http.StatusOK && tc.checkResult != nil { + var response v0.UpdateAllVersionsStatusResponse + err := json.NewDecoder(w.Body).Decode(&response) + require.NoError(t, err) + tc.checkResult(t, &response) + } + }) + } +} + +// strPtr is a helper function to create a pointer to a string +func strPtr(s string) *string { + return &s +} diff --git a/internal/api/router/v0.go b/internal/api/router/v0.go index eae032ddb..3a8425b5d 100644 --- a/internal/api/router/v0.go +++ b/internal/api/router/v0.go @@ -19,6 +19,8 @@ func RegisterV0Routes( v0.RegisterVersionEndpoint(api, "/v0", versionInfo) v0.RegisterServersEndpoints(api, "/v0", registry) v0.RegisterEditEndpoints(api, "/v0", registry, cfg) + v0.RegisterStatusEndpoints(api, "/v0", registry, cfg) + v0.RegisterAllVersionsStatusEndpoints(api, "/v0", registry, cfg) v0auth.RegisterAuthEndpoints(api, "/v0", cfg) v0.RegisterPublishEndpoint(api, "/v0", registry, cfg) } @@ -31,6 +33,8 @@ func RegisterV0_1Routes( v0.RegisterVersionEndpoint(api, "/v0.1", versionInfo) v0.RegisterServersEndpoints(api, "/v0.1", registry) v0.RegisterEditEndpoints(api, "/v0.1", registry, cfg) + v0.RegisterStatusEndpoints(api, "/v0.1", registry, cfg) + v0.RegisterAllVersionsStatusEndpoints(api, "/v0.1", registry, cfg) v0auth.RegisterAuthEndpoints(api, "/v0.1", cfg) v0.RegisterPublishEndpoint(api, "/v0.1", registry, cfg) } diff --git a/internal/auth/jwt.go b/internal/auth/jwt.go index 4558666cb..adb47aa27 100644 --- a/internal/auth/jwt.go +++ b/internal/auth/jwt.go @@ -9,6 +9,7 @@ import ( "time" "github.com/golang-jwt/jwt/v5" + "github.com/modelcontextprotocol/registry/internal/config" ) @@ -17,7 +18,9 @@ type PermissionAction string const ( PermissionActionPublish PermissionAction = "publish" - // Intended for admins taking moderation actions only, at least for now + // PermissionActionEdit allows editing server configuration and status. + // Can be scoped to a namespace (e.g., "io.github.username/*") for server owners, + // or global ("*") for admins. PermissionActionEdit PermissionAction = "edit" ) @@ -128,7 +131,6 @@ func (j *JWTManager) ValidateToken(_ context.Context, tokenString string) (*JWTC jwt.WithValidMethods([]string{"EdDSA"}), jwt.WithExpirationRequired(), ) - // Validate token if err != nil { return nil, fmt.Errorf("failed to parse token: %w", err) diff --git a/internal/database/database.go b/internal/database/database.go index ebae55d7f..3cc6313e8 100644 --- a/internal/database/database.go +++ b/internal/database/database.go @@ -7,6 +7,7 @@ import ( "github.com/jackc/pgx/v5" apiv0 "github.com/modelcontextprotocol/registry/pkg/api/v0" + "github.com/modelcontextprotocol/registry/pkg/model" ) // Common database errors @@ -27,6 +28,7 @@ type ServerFilter struct { SubstringName *string // for substring search on name Version *string // for exact version matching IsLatest *bool // for filtering latest versions only + IncludeYanked *bool // for including yanked packages in results (default: exclude) } // Database defines the interface for database operations @@ -36,7 +38,9 @@ type Database interface { // UpdateServer updates an existing server record UpdateServer(ctx context.Context, tx pgx.Tx, serverName, version string, serverJSON *apiv0.ServerJSON) (*apiv0.ServerResponse, error) // SetServerStatus updates the status of a specific server version - SetServerStatus(ctx context.Context, tx pgx.Tx, serverName, version string, status string) (*apiv0.ServerResponse, error) + SetServerStatus(ctx context.Context, tx pgx.Tx, serverName, version string, status model.Status, statusMessage, alternativeURL, newName *string) (*apiv0.ServerResponse, error) + // SetAllVersionsStatus updates the status of all versions of a server in a single query + SetAllVersionsStatus(ctx context.Context, tx pgx.Tx, serverName string, status model.Status, statusMessage, alternativeURL, newName *string) ([]*apiv0.ServerResponse, error) // ListServers retrieve server entries with optional filtering ListServers(ctx context.Context, tx pgx.Tx, filter *ServerFilter, cursor string, limit int) ([]*apiv0.ServerResponse, string, error) // GetServerByName retrieve a single server by its name diff --git a/internal/database/migrations/013_add_status_fields_and_change_deleted_to_yanked.sql b/internal/database/migrations/013_add_status_fields_and_change_deleted_to_yanked.sql new file mode 100644 index 000000000..cb1d7f7d5 --- /dev/null +++ b/internal/database/migrations/013_add_status_fields_and_change_deleted_to_yanked.sql @@ -0,0 +1,37 @@ +-- Add status management fields and change 'deleted' status to 'yanked' +-- This migration adds support for deprecation and yanking features + +BEGIN; + +-- Add new columns for status management +ALTER TABLE servers ADD COLUMN status_changed_at TIMESTAMP WITH TIME ZONE; +ALTER TABLE servers ADD COLUMN status_message TEXT; +ALTER TABLE servers ADD COLUMN alternative_url TEXT; +ALTER TABLE servers ADD COLUMN new_name VARCHAR(255); + +-- Initialize status_changed_at with published_at for existing records +UPDATE servers SET status_changed_at = published_at WHERE status_changed_at IS NULL; + +-- Make status_changed_at NOT NULL now that all records have values +ALTER TABLE servers ALTER COLUMN status_changed_at SET NOT NULL; + +-- Update constraint to include 'yanked' status and remove 'deleted' +ALTER TABLE servers DROP CONSTRAINT check_status_valid; +ALTER TABLE servers ADD CONSTRAINT check_status_valid +CHECK (status IN ('active', 'deprecated', 'yanked')); + +-- Change existing 'deleted' status to 'yanked' +UPDATE servers SET status = 'yanked' WHERE status = 'deleted'; + +-- Add index for new_name lookups +CREATE INDEX idx_servers_new_name ON servers(new_name); + +-- Validation: new_name must be a valid server name format if set +ALTER TABLE servers ADD CONSTRAINT check_new_name_format + CHECK (new_name IS NULL OR new_name ~ '^[a-zA-Z0-9.-]+/[a-zA-Z0-9._-]+$'); + +-- Constraint: status_changed_at must be >= published_at +ALTER TABLE servers ADD CONSTRAINT check_status_changed_at_after_published + CHECK (status_changed_at >= published_at); + +COMMIT; diff --git a/internal/database/postgres.go b/internal/database/postgres.go index b47ef74f2..7591f1681 100644 --- a/internal/database/postgres.go +++ b/internal/database/postgres.go @@ -79,6 +79,73 @@ func NewPostgreSQL(ctx context.Context, connectionURI string) (*PostgreSQL, erro }, nil } +// buildFilterConditions constructs WHERE clause conditions from a ServerFilter +func buildFilterConditions(filter *ServerFilter, argIndex int) ([]string, []any, int) { + var conditions []string + var args []any + + if filter == nil { + return conditions, args, argIndex + } + + if filter.Name != nil { + conditions = append(conditions, fmt.Sprintf("server_name = $%d", argIndex)) + args = append(args, *filter.Name) + argIndex++ + } + if filter.RemoteURL != nil { + conditions = append(conditions, fmt.Sprintf("EXISTS (SELECT 1 FROM jsonb_array_elements(value->'remotes') AS remote WHERE remote->>'url' = $%d)", argIndex)) + args = append(args, *filter.RemoteURL) + argIndex++ + } + if filter.UpdatedSince != nil { + conditions = append(conditions, fmt.Sprintf("updated_at > $%d", argIndex)) + args = append(args, *filter.UpdatedSince) + argIndex++ + } + if filter.SubstringName != nil { + conditions = append(conditions, fmt.Sprintf("server_name ILIKE $%d", argIndex)) + args = append(args, "%"+*filter.SubstringName+"%") + argIndex++ + } + if filter.Version != nil { + conditions = append(conditions, fmt.Sprintf("version = $%d", argIndex)) + args = append(args, *filter.Version) + argIndex++ + } + if filter.IsLatest != nil { + conditions = append(conditions, fmt.Sprintf("is_latest = $%d", argIndex)) + args = append(args, *filter.IsLatest) + argIndex++ + } + if filter.IncludeYanked == nil || !*filter.IncludeYanked { + conditions = append(conditions, "status != 'yanked'") + } + + return conditions, args, argIndex +} + +// addCursorCondition adds pagination cursor condition to WHERE clause +func addCursorCondition(cursor string, argIndex int) (string, []any, int) { + if cursor == "" { + return "", nil, argIndex + } + + // Parse cursor format: "serverName:version" + parts := strings.SplitN(cursor, ":", 2) + if len(parts) == 2 { + cursorServerName := parts[0] + cursorVersion := parts[1] + // Use compound condition: (server_name > cursor_name) OR (server_name = cursor_name AND version > cursor_version) + condition := fmt.Sprintf("(server_name > $%d OR (server_name = $%d AND version > $%d))", argIndex, argIndex+1, argIndex+2) + return condition, []any{cursorServerName, cursorServerName, cursorVersion}, argIndex + 3 + } + + // Fallback for malformed cursor - treat as server name only for backwards compatibility + condition := fmt.Sprintf("server_name > $%d", argIndex) + return condition, []any{cursor}, argIndex + 1 +} + func (db *PostgreSQL) ListServers( ctx context.Context, tx pgx.Tx, @@ -94,64 +161,17 @@ func (db *PostgreSQL) ListServers( return nil, "", ctx.Err() } - // Build WHERE clause for filtering using dedicated columns - var whereConditions []string - args := []any{} + // Build WHERE clause conditions argIndex := 1 + whereConditions, args, argIndex := buildFilterConditions(filter, argIndex) - // Add filters using dedicated columns for better performance - if filter != nil { - if filter.Name != nil { - whereConditions = append(whereConditions, fmt.Sprintf("server_name = $%d", argIndex)) - args = append(args, *filter.Name) - argIndex++ - } - if filter.RemoteURL != nil { - whereConditions = append(whereConditions, fmt.Sprintf("EXISTS (SELECT 1 FROM jsonb_array_elements(value->'remotes') AS remote WHERE remote->>'url' = $%d)", argIndex)) - args = append(args, *filter.RemoteURL) - argIndex++ - } - if filter.UpdatedSince != nil { - whereConditions = append(whereConditions, fmt.Sprintf("updated_at > $%d", argIndex)) - args = append(args, *filter.UpdatedSince) - argIndex++ - } - if filter.SubstringName != nil { - whereConditions = append(whereConditions, fmt.Sprintf("server_name ILIKE $%d", argIndex)) - args = append(args, "%"+*filter.SubstringName+"%") - argIndex++ - } - if filter.Version != nil { - whereConditions = append(whereConditions, fmt.Sprintf("version = $%d", argIndex)) - args = append(args, *filter.Version) - argIndex++ - } - if filter.IsLatest != nil { - whereConditions = append(whereConditions, fmt.Sprintf("is_latest = $%d", argIndex)) - args = append(args, *filter.IsLatest) - argIndex++ - } - } - - // Add cursor pagination using compound serverName:version cursor - if cursor != "" { - // Parse cursor format: "serverName:version" - parts := strings.SplitN(cursor, ":", 2) - if len(parts) == 2 { - cursorServerName := parts[0] - cursorVersion := parts[1] - - // Use compound condition: (server_name > cursor_name) OR (server_name = cursor_name AND version > cursor_version) - whereConditions = append(whereConditions, fmt.Sprintf("(server_name > $%d OR (server_name = $%d AND version > $%d))", argIndex, argIndex+1, argIndex+2)) - args = append(args, cursorServerName, cursorServerName, cursorVersion) - argIndex += 3 - } else { - // Fallback for malformed cursor - treat as server name only for backwards compatibility - whereConditions = append(whereConditions, fmt.Sprintf("server_name > $%d", argIndex)) - args = append(args, cursor) - argIndex++ - } + // Add cursor pagination + cursorCondition, cursorArgs, argIndex := addCursorCondition(cursor, argIndex) + if cursorCondition != "" { + whereConditions = append(whereConditions, cursorCondition) + args = append(args, cursorArgs...) } + _ = argIndex // Silence unused variable warning // Build the WHERE clause whereClause := "" @@ -161,7 +181,7 @@ func (db *PostgreSQL) ListServers( // Query servers table with hybrid column/JSON data query := fmt.Sprintf(` - SELECT server_name, version, status, published_at, updated_at, is_latest, value + SELECT server_name, version, status, status_changed_at, status_message, alternative_url, new_name, published_at, updated_at, is_latest, value FROM servers %s ORDER BY server_name, version @@ -178,11 +198,12 @@ func (db *PostgreSQL) ListServers( var results []*apiv0.ServerResponse for rows.Next() { var serverName, version, status string - var publishedAt, updatedAt time.Time + var statusChangedAt, publishedAt, updatedAt time.Time + var statusMessage, alternativeURL, newName *string var isLatest bool var valueJSON []byte - err := rows.Scan(&serverName, &version, &status, &publishedAt, &updatedAt, &isLatest, &valueJSON) + err := rows.Scan(&serverName, &version, &status, &statusChangedAt, &statusMessage, &alternativeURL, &newName, &publishedAt, &updatedAt, &isLatest, &valueJSON) if err != nil { return nil, "", fmt.Errorf("failed to scan server row: %w", err) } @@ -198,10 +219,14 @@ func (db *PostgreSQL) ListServers( Server: serverJSON, Meta: apiv0.ResponseMeta{ Official: &apiv0.RegistryExtensions{ - Status: model.Status(status), - PublishedAt: publishedAt, - UpdatedAt: updatedAt, - IsLatest: isLatest, + Status: model.Status(status), + StatusChangedAt: statusChangedAt, + StatusMessage: statusMessage, + AlternativeURL: alternativeURL, + NewName: newName, + PublishedAt: publishedAt, + UpdatedAt: updatedAt, + IsLatest: isLatest, }, }, } @@ -230,7 +255,7 @@ func (db *PostgreSQL) GetServerByName(ctx context.Context, tx pgx.Tx, serverName } query := ` - SELECT server_name, version, status, published_at, updated_at, is_latest, value + SELECT server_name, version, status, status_changed_at, status_message, alternative_url, new_name, published_at, updated_at, is_latest, value FROM servers WHERE server_name = $1 AND is_latest = true ORDER BY published_at DESC @@ -238,11 +263,12 @@ func (db *PostgreSQL) GetServerByName(ctx context.Context, tx pgx.Tx, serverName ` var name, version, status string - var publishedAt, updatedAt time.Time + var statusChangedAt, publishedAt, updatedAt time.Time + var statusMessage, alternativeURL, newName *string var isLatest bool var valueJSON []byte - err := db.getExecutor(tx).QueryRow(ctx, query, serverName).Scan(&name, &version, &status, &publishedAt, &updatedAt, &isLatest, &valueJSON) + err := db.getExecutor(tx).QueryRow(ctx, query, serverName).Scan(&name, &version, &status, &statusChangedAt, &statusMessage, &alternativeURL, &newName, &publishedAt, &updatedAt, &isLatest, &valueJSON) if err != nil { if errors.Is(err, pgx.ErrNoRows) { return nil, ErrNotFound @@ -261,10 +287,14 @@ func (db *PostgreSQL) GetServerByName(ctx context.Context, tx pgx.Tx, serverName Server: serverJSON, Meta: apiv0.ResponseMeta{ Official: &apiv0.RegistryExtensions{ - Status: model.Status(status), - PublishedAt: publishedAt, - UpdatedAt: updatedAt, - IsLatest: isLatest, + Status: model.Status(status), + StatusChangedAt: statusChangedAt, + StatusMessage: statusMessage, + AlternativeURL: alternativeURL, + NewName: newName, + PublishedAt: publishedAt, + UpdatedAt: updatedAt, + IsLatest: isLatest, }, }, } @@ -279,18 +309,19 @@ func (db *PostgreSQL) GetServerByNameAndVersion(ctx context.Context, tx pgx.Tx, } query := ` - SELECT server_name, version, status, published_at, updated_at, is_latest, value + SELECT server_name, version, status, status_changed_at, status_message, alternative_url, new_name, published_at, updated_at, is_latest, value FROM servers WHERE server_name = $1 AND version = $2 LIMIT 1 ` var name, vers, status string - var publishedAt, updatedAt time.Time + var statusChangedAt, publishedAt, updatedAt time.Time + var statusMessage, alternativeURL, newName *string var isLatest bool var valueJSON []byte - err := db.getExecutor(tx).QueryRow(ctx, query, serverName, version).Scan(&name, &vers, &status, &publishedAt, &updatedAt, &isLatest, &valueJSON) + err := db.getExecutor(tx).QueryRow(ctx, query, serverName, version).Scan(&name, &vers, &status, &statusChangedAt, &statusMessage, &alternativeURL, &newName, &publishedAt, &updatedAt, &isLatest, &valueJSON) if err != nil { if errors.Is(err, pgx.ErrNoRows) { return nil, ErrNotFound @@ -309,10 +340,14 @@ func (db *PostgreSQL) GetServerByNameAndVersion(ctx context.Context, tx pgx.Tx, Server: serverJSON, Meta: apiv0.ResponseMeta{ Official: &apiv0.RegistryExtensions{ - Status: model.Status(status), - PublishedAt: publishedAt, - UpdatedAt: updatedAt, - IsLatest: isLatest, + Status: model.Status(status), + StatusChangedAt: statusChangedAt, + StatusMessage: statusMessage, + AlternativeURL: alternativeURL, + NewName: newName, + PublishedAt: publishedAt, + UpdatedAt: updatedAt, + IsLatest: isLatest, }, }, } @@ -327,7 +362,7 @@ func (db *PostgreSQL) GetAllVersionsByServerName(ctx context.Context, tx pgx.Tx, } query := ` - SELECT server_name, version, status, published_at, updated_at, is_latest, value + SELECT server_name, version, status, status_changed_at, status_message, alternative_url, new_name, published_at, updated_at, is_latest, value FROM servers WHERE server_name = $1 ORDER BY published_at DESC @@ -342,11 +377,12 @@ func (db *PostgreSQL) GetAllVersionsByServerName(ctx context.Context, tx pgx.Tx, var results []*apiv0.ServerResponse for rows.Next() { var name, version, status string - var publishedAt, updatedAt time.Time + var statusChangedAt, publishedAt, updatedAt time.Time + var statusMessage, alternativeURL, newName *string var isLatest bool var valueJSON []byte - err := rows.Scan(&name, &version, &status, &publishedAt, &updatedAt, &isLatest, &valueJSON) + err := rows.Scan(&name, &version, &status, &statusChangedAt, &statusMessage, &alternativeURL, &newName, &publishedAt, &updatedAt, &isLatest, &valueJSON) if err != nil { return nil, fmt.Errorf("failed to scan server row: %w", err) } @@ -362,10 +398,14 @@ func (db *PostgreSQL) GetAllVersionsByServerName(ctx context.Context, tx pgx.Tx, Server: serverJSON, Meta: apiv0.ResponseMeta{ Official: &apiv0.RegistryExtensions{ - Status: model.Status(status), - PublishedAt: publishedAt, - UpdatedAt: updatedAt, - IsLatest: isLatest, + Status: model.Status(status), + StatusChangedAt: statusChangedAt, + StatusMessage: statusMessage, + AlternativeURL: alternativeURL, + NewName: newName, + PublishedAt: publishedAt, + UpdatedAt: updatedAt, + IsLatest: isLatest, }, }, } @@ -407,20 +447,22 @@ func (db *PostgreSQL) CreateServer(ctx context.Context, tx pgx.Tx, serverJSON *a // Insert the new server version using composite primary key insertQuery := ` - INSERT INTO servers (server_name, version, status, published_at, updated_at, is_latest, value) - VALUES ($1, $2, $3, $4, $5, $6, $7) + INSERT INTO servers (server_name, version, status, status_changed_at, status_message, alternative_url, published_at, updated_at, is_latest, value) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) ` _, err = db.getExecutor(tx).Exec(ctx, insertQuery, serverJSON.Name, serverJSON.Version, string(officialMeta.Status), + officialMeta.StatusChangedAt, + officialMeta.StatusMessage, + officialMeta.AlternativeURL, officialMeta.PublishedAt, officialMeta.UpdatedAt, officialMeta.IsLatest, valueJSON, ) - if err != nil { return nil, fmt.Errorf("failed to insert server: %w", err) } @@ -463,14 +505,15 @@ func (db *PostgreSQL) UpdateServer(ctx context.Context, tx pgx.Tx, serverName, v UPDATE servers SET value = $1, updated_at = NOW() WHERE server_name = $2 AND version = $3 - RETURNING server_name, version, status, published_at, updated_at, is_latest + RETURNING server_name, version, status, status_changed_at, status_message, alternative_url, new_name, published_at, updated_at, is_latest ` var name, vers, status string - var publishedAt, updatedAt time.Time + var statusChangedAt, publishedAt, updatedAt time.Time + var statusMessage, alternativeURL, newName *string var isLatest bool - err = db.getExecutor(tx).QueryRow(ctx, query, valueJSON, serverName, version).Scan(&name, &vers, &status, &publishedAt, &updatedAt, &isLatest) + err = db.getExecutor(tx).QueryRow(ctx, query, valueJSON, serverName, version).Scan(&name, &vers, &status, &statusChangedAt, &statusMessage, &alternativeURL, &newName, &publishedAt, &updatedAt, &isLatest) if err != nil { if errors.Is(err, pgx.ErrNoRows) { return nil, ErrNotFound @@ -483,10 +526,14 @@ func (db *PostgreSQL) UpdateServer(ctx context.Context, tx pgx.Tx, serverName, v Server: *serverJSON, Meta: apiv0.ResponseMeta{ Official: &apiv0.RegistryExtensions{ - Status: model.Status(status), - PublishedAt: publishedAt, - UpdatedAt: updatedAt, - IsLatest: isLatest, + Status: model.Status(status), + StatusChangedAt: statusChangedAt, + StatusMessage: statusMessage, + AlternativeURL: alternativeURL, + NewName: newName, + PublishedAt: publishedAt, + UpdatedAt: updatedAt, + IsLatest: isLatest, }, }, } @@ -495,25 +542,26 @@ func (db *PostgreSQL) UpdateServer(ctx context.Context, tx pgx.Tx, serverName, v } // SetServerStatus updates the status of a specific server version -func (db *PostgreSQL) SetServerStatus(ctx context.Context, tx pgx.Tx, serverName, version string, status string) (*apiv0.ServerResponse, error) { +func (db *PostgreSQL) SetServerStatus(ctx context.Context, tx pgx.Tx, serverName, version string, status model.Status, statusMessage, alternativeURL, newName *string) (*apiv0.ServerResponse, error) { if ctx.Err() != nil { return nil, ctx.Err() } - // Update the status column + // Update the status and related fields query := ` UPDATE servers - SET status = $1, updated_at = NOW() + SET status = $1, status_changed_at = NOW(), updated_at = NOW(), status_message = $4, alternative_url = $5, new_name = $6 WHERE server_name = $2 AND version = $3 - RETURNING server_name, version, status, value, published_at, updated_at, is_latest + RETURNING server_name, version, status, value, published_at, updated_at, is_latest, status_changed_at, status_message, alternative_url, new_name ` var name, vers, currentStatus string - var publishedAt, updatedAt time.Time + var publishedAt, updatedAt, statusChangedAt time.Time var isLatest bool var valueJSON []byte + var resultStatusMessage, resultAlternativeURL, resultNewName *string - err := db.getExecutor(tx).QueryRow(ctx, query, status, serverName, version).Scan(&name, &vers, ¤tStatus, &valueJSON, &publishedAt, &updatedAt, &isLatest) + err := db.getExecutor(tx).QueryRow(ctx, query, string(status), serverName, version, statusMessage, alternativeURL, newName).Scan(&name, &vers, ¤tStatus, &valueJSON, &publishedAt, &updatedAt, &isLatest, &statusChangedAt, &resultStatusMessage, &resultAlternativeURL, &resultNewName) if err != nil { if errors.Is(err, pgx.ErrNoRows) { return nil, ErrNotFound @@ -532,10 +580,14 @@ func (db *PostgreSQL) SetServerStatus(ctx context.Context, tx pgx.Tx, serverName Server: serverJSON, Meta: apiv0.ResponseMeta{ Official: &apiv0.RegistryExtensions{ - Status: model.Status(currentStatus), - PublishedAt: publishedAt, - UpdatedAt: updatedAt, - IsLatest: isLatest, + Status: model.Status(currentStatus), + StatusChangedAt: statusChangedAt, + StatusMessage: resultStatusMessage, + AlternativeURL: resultAlternativeURL, + NewName: resultNewName, + PublishedAt: publishedAt, + UpdatedAt: updatedAt, + IsLatest: isLatest, }, }, } @@ -543,6 +595,73 @@ func (db *PostgreSQL) SetServerStatus(ctx context.Context, tx pgx.Tx, serverName return serverResponse, nil } +// SetAllVersionsStatus updates the status of all versions of a server in a single query +func (db *PostgreSQL) SetAllVersionsStatus(ctx context.Context, tx pgx.Tx, serverName string, status model.Status, statusMessage, alternativeURL, newName *string) ([]*apiv0.ServerResponse, error) { + if ctx.Err() != nil { + return nil, ctx.Err() + } + + // Update the status and related fields for all versions + query := ` + UPDATE servers + SET status = $1, status_changed_at = NOW(), updated_at = NOW(), status_message = $2, alternative_url = $3, new_name = $4 + WHERE server_name = $5 + RETURNING server_name, version, status, value, published_at, updated_at, is_latest, status_changed_at, status_message, alternative_url, new_name + ` + + rows, err := db.getExecutor(tx).Query(ctx, query, string(status), statusMessage, alternativeURL, newName, serverName) + if err != nil { + return nil, fmt.Errorf("failed to update all server versions status: %w", err) + } + defer rows.Close() + + var results []*apiv0.ServerResponse + for rows.Next() { + var name, vers, currentStatus string + var publishedAt, updatedAt, statusChangedAt time.Time + var isLatest bool + var valueJSON []byte + var resultStatusMessage, resultAlternativeURL, resultNewName *string + + if err := rows.Scan(&name, &vers, ¤tStatus, &valueJSON, &publishedAt, &updatedAt, &isLatest, &statusChangedAt, &resultStatusMessage, &resultAlternativeURL, &resultNewName); err != nil { + return nil, fmt.Errorf("failed to scan server row: %w", err) + } + + // Unmarshal the JSON data + var serverJSON apiv0.ServerJSON + if err := json.Unmarshal(valueJSON, &serverJSON); err != nil { + return nil, fmt.Errorf("failed to unmarshal server JSON: %w", err) + } + + serverResponse := &apiv0.ServerResponse{ + Server: serverJSON, + Meta: apiv0.ResponseMeta{ + Official: &apiv0.RegistryExtensions{ + Status: model.Status(currentStatus), + StatusChangedAt: statusChangedAt, + StatusMessage: resultStatusMessage, + AlternativeURL: resultAlternativeURL, + NewName: resultNewName, + PublishedAt: publishedAt, + UpdatedAt: updatedAt, + IsLatest: isLatest, + }, + }, + } + results = append(results, serverResponse) + } + + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("error iterating server rows: %w", err) + } + + if len(results) == 0 { + return nil, ErrNotFound + } + + return results, nil +} + // InTransaction executes a function within a database transaction func (db *PostgreSQL) InTransaction(ctx context.Context, fn func(ctx context.Context, tx pgx.Tx) error) error { if ctx.Err() != nil { @@ -615,7 +734,7 @@ func (db *PostgreSQL) GetCurrentLatestVersion(ctx context.Context, tx pgx.Tx, se executor := db.getExecutor(tx) query := ` - SELECT server_name, version, status, value, published_at, updated_at, is_latest + SELECT server_name, version, status, status_changed_at, status_message, alternative_url, new_name, published_at, updated_at, is_latest, value FROM servers WHERE server_name = $1 AND is_latest = true ` @@ -623,11 +742,12 @@ func (db *PostgreSQL) GetCurrentLatestVersion(ctx context.Context, tx pgx.Tx, se row := executor.QueryRow(ctx, query, serverName) var name, version, status string - var publishedAt, updatedAt time.Time + var statusChangedAt, publishedAt, updatedAt time.Time + var statusMessage, alternativeURL, newName *string var isLatest bool var jsonValue []byte - err := row.Scan(&name, &version, &status, &jsonValue, &publishedAt, &updatedAt, &isLatest) + err := row.Scan(&name, &version, &status, &statusChangedAt, &statusMessage, &alternativeURL, &newName, &publishedAt, &updatedAt, &isLatest, &jsonValue) if err != nil { if errors.Is(err, pgx.ErrNoRows) { return nil, ErrNotFound @@ -646,9 +766,14 @@ func (db *PostgreSQL) GetCurrentLatestVersion(ctx context.Context, tx pgx.Tx, se Server: serverJSON, Meta: apiv0.ResponseMeta{ Official: &apiv0.RegistryExtensions{ - PublishedAt: publishedAt, - UpdatedAt: updatedAt, - IsLatest: isLatest, + Status: model.Status(status), + StatusChangedAt: statusChangedAt, + StatusMessage: statusMessage, + AlternativeURL: alternativeURL, + NewName: newName, + PublishedAt: publishedAt, + UpdatedAt: updatedAt, + IsLatest: isLatest, }, }, } diff --git a/internal/database/postgres_test.go b/internal/database/postgres_test.go index ee90fdc75..a152e4d8a 100644 --- a/internal/database/postgres_test.go +++ b/internal/database/postgres_test.go @@ -14,9 +14,12 @@ import ( "github.com/stretchr/testify/require" ) +const testVersion100 = "1.0.0" + func TestPostgreSQL_CreateServer(t *testing.T) { db := database.NewTestDB(t) ctx := context.Background() + timeNow := time.Now() tests := []struct { name string @@ -36,10 +39,11 @@ func TestPostgreSQL_CreateServer(t *testing.T) { }, }, officialMeta: &apiv0.RegistryExtensions{ - Status: model.StatusActive, - PublishedAt: time.Now(), - UpdatedAt: time.Now(), - IsLatest: true, + Status: model.StatusActive, + StatusChangedAt: timeNow, + PublishedAt: timeNow, + UpdatedAt: timeNow, + IsLatest: true, }, expectError: false, }, @@ -51,10 +55,11 @@ func TestPostgreSQL_CreateServer(t *testing.T) { Version: "1.0.0", }, officialMeta: &apiv0.RegistryExtensions{ - Status: model.StatusActive, - PublishedAt: time.Now(), - UpdatedAt: time.Now(), - IsLatest: true, + Status: model.StatusActive, + StatusChangedAt: timeNow, + PublishedAt: timeNow, + UpdatedAt: timeNow, + IsLatest: true, }, expectError: true, // Note: Expecting generic database error for constraint violation @@ -94,6 +99,7 @@ func TestPostgreSQL_CreateServer(t *testing.T) { func TestPostgreSQL_GetServerByName(t *testing.T) { db := database.NewTestDB(t) ctx := context.Background() + timeNow := time.Now() // Setup test data serverJSON := &apiv0.ServerJSON{ @@ -102,10 +108,11 @@ func TestPostgreSQL_GetServerByName(t *testing.T) { Version: "1.0.0", } officialMeta := &apiv0.RegistryExtensions{ - Status: model.StatusActive, - PublishedAt: time.Now(), - UpdatedAt: time.Now(), - IsLatest: true, + Status: model.StatusActive, + StatusChangedAt: timeNow, + PublishedAt: timeNow, + UpdatedAt: timeNow, + IsLatest: true, } // Create the server @@ -153,6 +160,7 @@ func TestPostgreSQL_GetServerByName(t *testing.T) { func TestPostgreSQL_GetServerByNameAndVersion(t *testing.T) { db := database.NewTestDB(t) ctx := context.Background() + timeNow := time.Now() // Setup test data with multiple versions serverName := "com.example/version-test-server" @@ -165,10 +173,11 @@ func TestPostgreSQL_GetServerByNameAndVersion(t *testing.T) { Version: version, } officialMeta := &apiv0.RegistryExtensions{ - Status: model.StatusActive, - PublishedAt: time.Now(), - UpdatedAt: time.Now(), - IsLatest: i == len(versions)-1, // Only last version is latest + Status: model.StatusActive, + StatusChangedAt: timeNow, + PublishedAt: timeNow, + UpdatedAt: timeNow, + IsLatest: i == len(versions)-1, // Only last version is latest } _, err := db.CreateServer(ctx, nil, serverJSON, officialMeta) @@ -274,10 +283,11 @@ func TestPostgreSQL_ListServers(t *testing.T) { }, } officialMeta := &apiv0.RegistryExtensions{ - Status: server.status, - PublishedAt: server.publishedAt, - UpdatedAt: server.publishedAt, - IsLatest: server.isLatest, + Status: server.status, + StatusChangedAt: server.publishedAt, + PublishedAt: server.publishedAt, + UpdatedAt: server.publishedAt, + IsLatest: server.isLatest, } _, err := db.CreateServer(ctx, nil, serverJSON, officialMeta) @@ -398,20 +408,22 @@ func TestPostgreSQL_ListServers(t *testing.T) { func TestPostgreSQL_UpdateServer(t *testing.T) { db := database.NewTestDB(t) ctx := context.Background() + timeNow := time.Now() // Setup test data serverName := "com.example/update-test-server" - version := "1.0.0" + version := testVersion100 serverJSON := &apiv0.ServerJSON{ Name: serverName, Description: "Original description", Version: version, } officialMeta := &apiv0.RegistryExtensions{ - Status: model.StatusActive, - PublishedAt: time.Now(), - UpdatedAt: time.Now(), - IsLatest: true, + Status: model.StatusActive, + StatusChangedAt: timeNow, + PublishedAt: timeNow, + UpdatedAt: timeNow, + IsLatest: true, } _, err := db.CreateServer(ctx, nil, serverJSON, officialMeta) @@ -441,11 +453,11 @@ func TestPostgreSQL_UpdateServer(t *testing.T) { { name: "update non-existent server", serverName: "com.example/non-existent", - version: "1.0.0", + version: testVersion100, updatedServer: &apiv0.ServerJSON{ Name: "com.example/non-existent", Description: "Should fail", - Version: "1.0.0", + Version: testVersion100, }, expectError: true, errorType: database.ErrNotFound, @@ -476,20 +488,22 @@ func TestPostgreSQL_UpdateServer(t *testing.T) { func TestPostgreSQL_SetServerStatus(t *testing.T) { db := database.NewTestDB(t) ctx := context.Background() + timeNow := time.Now() // Setup test data serverName := "com.example/status-test-server" - version := "1.0.0" + version := testVersion100 serverJSON := &apiv0.ServerJSON{ Name: serverName, Description: "A server for status testing", Version: version, } officialMeta := &apiv0.RegistryExtensions{ - Status: model.StatusActive, - PublishedAt: time.Now(), - UpdatedAt: time.Now(), - IsLatest: true, + Status: model.StatusActive, + StatusChangedAt: timeNow, + PublishedAt: timeNow, + UpdatedAt: timeNow, + IsLatest: true, } _, err := db.CreateServer(ctx, nil, serverJSON, officialMeta) @@ -528,7 +542,7 @@ func TestPostgreSQL_SetServerStatus(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - result, err := db.SetServerStatus(ctx, nil, tt.serverName, tt.version, tt.newStatus) + result, err := db.SetServerStatus(ctx, nil, tt.serverName, tt.version, model.Status(tt.newStatus), nil, nil, nil) if tt.expectError { assert.Error(t, err) @@ -551,6 +565,7 @@ func TestPostgreSQL_TransactionHandling(t *testing.T) { ctx := context.Background() t.Run("successful transaction", func(t *testing.T) { + timeNow := time.Now() err := db.InTransaction(ctx, func(ctx context.Context, tx pgx.Tx) error { serverJSON := &apiv0.ServerJSON{ Name: "com.example/transaction-success", @@ -558,10 +573,11 @@ func TestPostgreSQL_TransactionHandling(t *testing.T) { Version: "1.0.0", } officialMeta := &apiv0.RegistryExtensions{ - Status: model.StatusActive, - PublishedAt: time.Now(), - UpdatedAt: time.Now(), - IsLatest: true, + Status: model.StatusActive, + StatusChangedAt: timeNow, + PublishedAt: timeNow, + UpdatedAt: timeNow, + IsLatest: true, } _, err := db.CreateServer(ctx, tx, serverJSON, officialMeta) @@ -577,6 +593,7 @@ func TestPostgreSQL_TransactionHandling(t *testing.T) { }) t.Run("failed transaction rollback", func(t *testing.T) { + timeNow := time.Now() err := db.InTransaction(ctx, func(ctx context.Context, tx pgx.Tx) error { serverJSON := &apiv0.ServerJSON{ Name: "com.example/transaction-rollback", @@ -584,10 +601,11 @@ func TestPostgreSQL_TransactionHandling(t *testing.T) { Version: "1.0.0", } officialMeta := &apiv0.RegistryExtensions{ - Status: model.StatusActive, - PublishedAt: time.Now(), - UpdatedAt: time.Now(), - IsLatest: true, + Status: model.StatusActive, + StatusChangedAt: timeNow, + PublishedAt: timeNow, + UpdatedAt: timeNow, + IsLatest: true, } _, err := db.CreateServer(ctx, tx, serverJSON, officialMeta) @@ -624,6 +642,7 @@ func TestPostgreSQL_ConcurrencyAndLocking(t *testing.T) { // Launch two concurrent publish operations for i := 0; i < 2; i++ { go func(version string) { + timeNow := time.Now() err := db.InTransaction(ctx, func(ctx context.Context, tx pgx.Tx) error { // Acquire lock if err := db.AcquirePublishLock(ctx, tx, serverName); err != nil { @@ -639,10 +658,11 @@ func TestPostgreSQL_ConcurrencyAndLocking(t *testing.T) { Version: version, } officialMeta := &apiv0.RegistryExtensions{ - Status: model.StatusActive, - PublishedAt: time.Now(), - UpdatedAt: time.Now(), - IsLatest: true, + Status: model.StatusActive, + StatusChangedAt: timeNow, + PublishedAt: timeNow, + UpdatedAt: timeNow, + IsLatest: true, } _, err := db.CreateServer(ctx, tx, serverJSON, officialMeta) @@ -681,6 +701,7 @@ func TestPostgreSQL_ConcurrencyAndLocking(t *testing.T) { func TestPostgreSQL_HelperMethods(t *testing.T) { db := database.NewTestDB(t) ctx := context.Background() + timeNow := time.Now() serverName := "com.example/helper-test-server" @@ -693,10 +714,11 @@ func TestPostgreSQL_HelperMethods(t *testing.T) { Version: version, } officialMeta := &apiv0.RegistryExtensions{ - Status: model.StatusActive, - PublishedAt: time.Now(), - UpdatedAt: time.Now(), - IsLatest: version == "2.0.0", + Status: model.StatusActive, + StatusChangedAt: timeNow, + PublishedAt: timeNow, + UpdatedAt: timeNow, + IsLatest: version == "2.0.0", } _, err := db.CreateServer(ctx, nil, serverJSON, officialMeta) @@ -762,6 +784,7 @@ func TestPostgreSQL_HelperMethods(t *testing.T) { func TestPostgreSQL_EdgeCases(t *testing.T) { db := database.NewTestDB(t) ctx := context.Background() + timeNow := time.Now() t.Run("input validation", func(t *testing.T) { // Test nil inputs @@ -783,10 +806,11 @@ func TestPostgreSQL_EdgeCases(t *testing.T) { Version: "1.0.0", } officialMeta := &apiv0.RegistryExtensions{ - Status: model.StatusActive, - PublishedAt: time.Now(), - UpdatedAt: time.Now(), - IsLatest: true, + Status: model.StatusActive, + StatusChangedAt: timeNow, + PublishedAt: timeNow, + UpdatedAt: timeNow, + IsLatest: true, } _, err := db.CreateServer(ctx, nil, invalidServer, officialMeta) @@ -821,10 +845,11 @@ func TestPostgreSQL_EdgeCases(t *testing.T) { {Type: "streamable-http", URL: "https://complex.example.com/mcp"}, }, }, &apiv0.RegistryExtensions{ - Status: model.StatusActive, - PublishedAt: testTime, - UpdatedAt: testTime, - IsLatest: true, + Status: model.StatusActive, + StatusChangedAt: testTime, + PublishedAt: testTime, + UpdatedAt: testTime, + IsLatest: true, }) require.NoError(t, err) @@ -852,31 +877,137 @@ func TestPostgreSQL_EdgeCases(t *testing.T) { Description: "Status transition test", Version: version, }, &apiv0.RegistryExtensions{ - Status: model.StatusActive, - PublishedAt: time.Now(), - UpdatedAt: time.Now(), - IsLatest: true, + Status: model.StatusActive, + StatusChangedAt: timeNow, + PublishedAt: timeNow, + UpdatedAt: timeNow, + IsLatest: true, }) require.NoError(t, err) // Test all valid status transitions statuses := []string{ string(model.StatusDeprecated), - string(model.StatusDeleted), + string(model.StatusYanked), string(model.StatusActive), // Can transition back } for _, status := range statuses { - result, err := db.SetServerStatus(ctx, nil, serverName, version, status) + result, err := db.SetServerStatus(ctx, nil, serverName, version, model.Status(status), nil, nil, nil) assert.NoError(t, err, "Should allow transition to %s", status) assert.Equal(t, model.Status(status), result.Meta.Official.Status) } }) + + // Test status transitions with additional fields + t.Run("status transitions with message and alternative URL", func(t *testing.T) { + testServerName := "com.example/status-fields-test" + testVersion := "1.0.0" + + // Create a test server + _, err := db.CreateServer(ctx, nil, &apiv0.ServerJSON{ + Name: testServerName, + Description: "Status fields test", + Version: testVersion, + }, &apiv0.RegistryExtensions{ + Status: model.StatusActive, + StatusChangedAt: timeNow, + PublishedAt: timeNow, + UpdatedAt: timeNow, + IsLatest: true, + }) + require.NoError(t, err) + + statusMessage := "This server has been deprecated. Please use the new version." + alternativeURL := "https://example.com/new-server" + + // Test setting status with message and alternative URL + result, err := db.SetServerStatus(ctx, nil, testServerName, testVersion, model.StatusDeprecated, &statusMessage, &alternativeURL, nil) + assert.NoError(t, err) + assert.Equal(t, model.StatusDeprecated, result.Meta.Official.Status) + assert.NotNil(t, result.Meta.Official.StatusMessage) + assert.Equal(t, statusMessage, *result.Meta.Official.StatusMessage) + assert.NotNil(t, result.Meta.Official.AlternativeURL) + assert.Equal(t, alternativeURL, *result.Meta.Official.AlternativeURL) + assert.NotZero(t, result.Meta.Official.StatusChangedAt) + + // Test clearing status message and alternative URL + result, err = db.SetServerStatus(ctx, nil, testServerName, testVersion, model.StatusActive, nil, nil, nil) + assert.NoError(t, err) + assert.Equal(t, model.StatusActive, result.Meta.Official.Status) + assert.Nil(t, result.Meta.Official.StatusMessage) + assert.Nil(t, result.Meta.Official.AlternativeURL) + }) + + // Test comprehensive status transitions as per user requirements + t.Run("comprehensive status transitions", func(t *testing.T) { + testServerName := "com.example/comprehensive-transitions-test" + testVersion := "1.0.0" + + // Create a test server in active status + _, err := db.CreateServer(ctx, nil, &apiv0.ServerJSON{ + Name: testServerName, + Description: "Comprehensive transitions test", + Version: testVersion, + }, &apiv0.RegistryExtensions{ + Status: model.StatusActive, + StatusChangedAt: timeNow, + PublishedAt: timeNow, + UpdatedAt: timeNow, + IsLatest: true, + }) + require.NoError(t, err) + + // Define all valid transitions based on user requirements + transitionTests := []struct { + name string + fromStatus model.Status + toStatus model.Status + description string + }{ + // Active ↔ Deprecated + {"active to deprecated", model.StatusActive, model.StatusDeprecated, "Deprecating active server"}, + {"deprecated to active", model.StatusDeprecated, model.StatusActive, "Reactivating deprecated server"}, + + // Active ↔ Yanked + {"active to yanked", model.StatusActive, model.StatusYanked, "Yanking active server"}, + {"yanked to active", model.StatusYanked, model.StatusActive, "Unyanking to active server"}, + + // Deprecated ↔ Yanked + {"deprecated to yanked", model.StatusDeprecated, model.StatusYanked, "Yanking deprecated server"}, + {"yanked to deprecated", model.StatusYanked, model.StatusDeprecated, "Moving yanked to deprecated"}, + } + + for _, tt := range transitionTests { + t.Run(tt.name, func(t *testing.T) { + // First ensure the server is in the expected starting status + if tt.fromStatus != model.StatusActive { + _, err := db.SetServerStatus(ctx, nil, testServerName, testVersion, tt.fromStatus, nil, nil, nil) + require.NoError(t, err, "failed to set initial status to %s", tt.fromStatus) + } + + // Verify the server is in the expected starting status + currentServer, err := db.GetServerByNameAndVersion(ctx, nil, testServerName, testVersion) + require.NoError(t, err) + assert.Equal(t, tt.fromStatus, currentServer.Meta.Official.Status, "server should be in %s status before transition", tt.fromStatus) + + // Perform the transition + result, err := db.SetServerStatus(ctx, nil, testServerName, testVersion, tt.toStatus, &tt.description, nil, nil) + assert.NoError(t, err, "should allow transition from %s to %s", tt.fromStatus, tt.toStatus) + assert.NotNil(t, result) + assert.Equal(t, tt.toStatus, result.Meta.Official.Status, "status should be %s after transition", tt.toStatus) + assert.NotNil(t, result.Meta.Official.StatusMessage) + assert.Equal(t, tt.description, *result.Meta.Official.StatusMessage) + assert.NotZero(t, result.Meta.Official.StatusChangedAt) + }) + } + }) } func TestPostgreSQL_PerformanceScenarios(t *testing.T) { db := database.NewTestDB(t) ctx := context.Background() + timeNow := time.Now() t.Run("many versions management", func(t *testing.T) { serverName := "com.example/many-versions-server" @@ -889,10 +1020,11 @@ func TestPostgreSQL_PerformanceScenarios(t *testing.T) { Description: fmt.Sprintf("Version %d", i), Version: fmt.Sprintf("1.0.%d", i), }, &apiv0.RegistryExtensions{ - Status: model.StatusActive, - PublishedAt: time.Now(), - UpdatedAt: time.Now(), - IsLatest: i == versionCount-1, // Only last one is latest + Status: model.StatusActive, + StatusChangedAt: timeNow, + PublishedAt: timeNow, + UpdatedAt: timeNow, + IsLatest: i == versionCount-1, // Only last one is latest }) require.NoError(t, err) } @@ -926,10 +1058,11 @@ func TestPostgreSQL_PerformanceScenarios(t *testing.T) { Description: "Pagination test server", Version: "1.0.0", }, &apiv0.RegistryExtensions{ - Status: model.StatusActive, - PublishedAt: time.Now(), - UpdatedAt: time.Now(), - IsLatest: true, + Status: model.StatusActive, + StatusChangedAt: timeNow, + PublishedAt: timeNow, + UpdatedAt: timeNow, + IsLatest: true, }) require.NoError(t, err) } @@ -955,6 +1088,697 @@ func TestPostgreSQL_PerformanceScenarios(t *testing.T) { }) } +func TestPostgreSQL_NewStatusFields(t *testing.T) { + db := database.NewTestDB(t) + ctx := context.Background() + timeNow := time.Now() + + t.Run("status_changed_at field functionality", func(t *testing.T) { + serverJSON := &apiv0.ServerJSON{ + Name: "com.example/status-changed-at-test", + Description: "Test server for status_changed_at field", + Version: "1.0.0", + } + + // Create server with specific status_changed_at + officialMeta := &apiv0.RegistryExtensions{ + Status: model.StatusActive, + StatusChangedAt: timeNow, + PublishedAt: timeNow, + UpdatedAt: timeNow, + IsLatest: true, + } + + _, err := db.CreateServer(ctx, nil, serverJSON, officialMeta) + require.NoError(t, err) + + // Retrieve and verify status_changed_at + result, err := db.GetServerByNameAndVersion(ctx, nil, serverJSON.Name, serverJSON.Version) + require.NoError(t, err) + assert.NotNil(t, result.Meta.Official) + assert.Equal(t, timeNow.Unix(), result.Meta.Official.StatusChangedAt.Unix()) + }) + + t.Run("status_message field functionality", func(t *testing.T) { + serverJSON := &apiv0.ServerJSON{ + Name: "com.example/status-message-test", + Description: "Test server for status_message field", + Version: "1.0.0", + } + + statusMessage := "This server is deprecated due to security issues. Please migrate to v2.0.0" + officialMeta := &apiv0.RegistryExtensions{ + Status: model.StatusDeprecated, + StatusChangedAt: timeNow, + StatusMessage: &statusMessage, + PublishedAt: timeNow, + UpdatedAt: timeNow, + IsLatest: true, + } + + _, err := db.CreateServer(ctx, nil, serverJSON, officialMeta) + require.NoError(t, err) + + // Retrieve and verify status_message + result, err := db.GetServerByNameAndVersion(ctx, nil, serverJSON.Name, serverJSON.Version) + require.NoError(t, err) + assert.NotNil(t, result.Meta.Official) + assert.NotNil(t, result.Meta.Official.StatusMessage) + assert.Equal(t, statusMessage, *result.Meta.Official.StatusMessage) + }) + + t.Run("alternative_url field functionality", func(t *testing.T) { + serverJSON := &apiv0.ServerJSON{ + Name: "com.example/alternative-url-test", + Description: "Test server for alternative_url field", + Version: "1.0.0", + } + + alternativeURL := "https://github.com/example/new-server-v2" + officialMeta := &apiv0.RegistryExtensions{ + Status: model.StatusDeprecated, + StatusChangedAt: timeNow, + AlternativeURL: &alternativeURL, + PublishedAt: timeNow, + UpdatedAt: timeNow, + IsLatest: true, + } + + _, err := db.CreateServer(ctx, nil, serverJSON, officialMeta) + require.NoError(t, err) + + // Retrieve and verify alternative_url + result, err := db.GetServerByNameAndVersion(ctx, nil, serverJSON.Name, serverJSON.Version) + require.NoError(t, err) + assert.NotNil(t, result.Meta.Official) + assert.NotNil(t, result.Meta.Official.AlternativeURL) + assert.Equal(t, alternativeURL, *result.Meta.Official.AlternativeURL) + }) + + t.Run("yanked status functionality", func(t *testing.T) { + serverJSON := &apiv0.ServerJSON{ + Name: "com.example/yanked-status-test", + Description: "Test server for yanked status", + Version: "1.0.0", + } + + statusMessage := "This version has critical security vulnerabilities and has been yanked" + alternativeURL := "https://github.com/example/secure-version" + + officialMeta := &apiv0.RegistryExtensions{ + Status: model.StatusYanked, + StatusChangedAt: timeNow, + StatusMessage: &statusMessage, + AlternativeURL: &alternativeURL, + PublishedAt: timeNow, + UpdatedAt: timeNow, + IsLatest: false, // Yanked versions should not be latest + } + + _, err := db.CreateServer(ctx, nil, serverJSON, officialMeta) + require.NoError(t, err) + + // Retrieve and verify yanked status with all fields + result, err := db.GetServerByNameAndVersion(ctx, nil, serverJSON.Name, serverJSON.Version) + require.NoError(t, err) + assert.NotNil(t, result.Meta.Official) + assert.Equal(t, model.StatusYanked, result.Meta.Official.Status) + assert.NotNil(t, result.Meta.Official.StatusMessage) + assert.Equal(t, statusMessage, *result.Meta.Official.StatusMessage) + assert.NotNil(t, result.Meta.Official.AlternativeURL) + assert.Equal(t, alternativeURL, *result.Meta.Official.AlternativeURL) + assert.False(t, result.Meta.Official.IsLatest) + }) + + t.Run("nil status_message and alternative_url", func(t *testing.T) { + serverJSON := &apiv0.ServerJSON{ + Name: "com.example/nil-fields-test", + Description: "Test server for nil optional fields", + Version: "1.0.0", + } + + officialMeta := &apiv0.RegistryExtensions{ + Status: model.StatusActive, + StatusChangedAt: timeNow, + StatusMessage: nil, // Explicitly nil + AlternativeURL: nil, // Explicitly nil + PublishedAt: timeNow, + UpdatedAt: timeNow, + IsLatest: true, + } + + _, err := db.CreateServer(ctx, nil, serverJSON, officialMeta) + require.NoError(t, err) + + // Retrieve and verify nil fields are handled correctly + result, err := db.GetServerByNameAndVersion(ctx, nil, serverJSON.Name, serverJSON.Version) + require.NoError(t, err) + assert.NotNil(t, result.Meta.Official) + assert.Nil(t, result.Meta.Official.StatusMessage) + assert.Nil(t, result.Meta.Official.AlternativeURL) + }) + + t.Run("status_changed_at constraint enforcement", func(t *testing.T) { + serverJSON := &apiv0.ServerJSON{ + Name: "com.example/constraint-test", + Description: "Test server for constraint validation", + Version: "1.0.0", + } + + // Try to create server with status_changed_at before published_at (should fail) + earlierTime := timeNow.Add(-1 * time.Hour) + officialMeta := &apiv0.RegistryExtensions{ + Status: model.StatusActive, + StatusChangedAt: earlierTime, // Before published_at + PublishedAt: timeNow, + UpdatedAt: timeNow, + IsLatest: true, + } + + _, err := db.CreateServer(ctx, nil, serverJSON, officialMeta) + assert.Error(t, err) + assert.Contains(t, err.Error(), "check_status_changed_at_after_published") + }) + + t.Run("all status transitions work", func(t *testing.T) { + // Test that we can create servers with all valid status values + statuses := []model.Status{ + model.StatusActive, + model.StatusDeprecated, + model.StatusYanked, + } + + for i, status := range statuses { + serverJSON := &apiv0.ServerJSON{ + Name: fmt.Sprintf("com.example/status-test-%d", i), + Description: fmt.Sprintf("Test server for status %s", status), + Version: "1.0.0", + } + + officialMeta := &apiv0.RegistryExtensions{ + Status: status, + StatusChangedAt: timeNow, + PublishedAt: timeNow, + UpdatedAt: timeNow, + IsLatest: true, + } + + _, err := db.CreateServer(ctx, nil, serverJSON, officialMeta) + assert.NoError(t, err, "Should be able to create server with status: %s", status) + + // Verify the status was set correctly + result, err := db.GetServerByNameAndVersion(ctx, nil, serverJSON.Name, serverJSON.Version) + require.NoError(t, err) + assert.Equal(t, status, result.Meta.Official.Status) + } + }) +} + +func TestPostgreSQL_StatusFieldsInListOperations(t *testing.T) { + db := database.NewTestDB(t) + ctx := context.Background() + timeNow := time.Now() + + // Create test servers with different statuses and status fields + testServers := []struct { + name string + status model.Status + statusMessage *string + alternativeURL *string + }{ + { + name: "com.example/active-server", + status: model.StatusActive, + statusMessage: nil, + alternativeURL: nil, + }, + { + name: "com.example/deprecated-server", + status: model.StatusDeprecated, + statusMessage: stringPtr("Deprecated in favor of v2"), + alternativeURL: stringPtr("https://github.com/example/v2"), + }, + { + name: "com.example/yanked-server", + status: model.StatusYanked, + statusMessage: stringPtr("Security vulnerability found"), + alternativeURL: stringPtr("https://github.com/example/secure"), + }, + } + + // Create all test servers + for _, server := range testServers { + serverJSON := &apiv0.ServerJSON{ + Name: server.name, + Description: "Test server for list operations", + Version: "1.0.0", + } + + officialMeta := &apiv0.RegistryExtensions{ + Status: server.status, + StatusChangedAt: timeNow, + StatusMessage: server.statusMessage, + AlternativeURL: server.alternativeURL, + PublishedAt: timeNow, + UpdatedAt: timeNow, + IsLatest: true, + } + + _, err := db.CreateServer(ctx, nil, serverJSON, officialMeta) + require.NoError(t, err) + } + + t.Run("ListServers includes new status fields", func(t *testing.T) { + results, _, err := db.ListServers(ctx, nil, nil, "", 10) + require.NoError(t, err) + + // Find our test servers in the results + foundServers := make(map[string]*apiv0.ServerResponse) + for _, result := range results { + for _, testServer := range testServers { + if result.Server.Name == testServer.name { + foundServers[testServer.name] = result + } + } + } + + // Verify all test servers were found with correct status fields + for _, testServer := range testServers { + result, found := foundServers[testServer.name] + assert.True(t, found, "Server %s should be found in list results", testServer.name) + if !found { + continue + } + + assert.NotNil(t, result.Meta.Official) + assert.Equal(t, testServer.status, result.Meta.Official.Status) + assert.Equal(t, timeNow.Unix(), result.Meta.Official.StatusChangedAt.Unix()) + + if testServer.statusMessage != nil { + assert.NotNil(t, result.Meta.Official.StatusMessage) + assert.Equal(t, *testServer.statusMessage, *result.Meta.Official.StatusMessage) + } else { + assert.Nil(t, result.Meta.Official.StatusMessage) + } + + if testServer.alternativeURL != nil { + assert.NotNil(t, result.Meta.Official.AlternativeURL) + assert.Equal(t, *testServer.alternativeURL, *result.Meta.Official.AlternativeURL) + } else { + assert.Nil(t, result.Meta.Official.AlternativeURL) + } + } + }) + + t.Run("GetServerByName includes new status fields", func(t *testing.T) { + for _, testServer := range testServers { + result, err := db.GetServerByName(ctx, nil, testServer.name) + require.NoError(t, err) + + assert.NotNil(t, result.Meta.Official) + assert.Equal(t, testServer.status, result.Meta.Official.Status) + assert.Equal(t, timeNow.Unix(), result.Meta.Official.StatusChangedAt.Unix()) + + if testServer.statusMessage != nil { + assert.NotNil(t, result.Meta.Official.StatusMessage) + assert.Equal(t, *testServer.statusMessage, *result.Meta.Official.StatusMessage) + } else { + assert.Nil(t, result.Meta.Official.StatusMessage) + } + + if testServer.alternativeURL != nil { + assert.NotNil(t, result.Meta.Official.AlternativeURL) + assert.Equal(t, *testServer.alternativeURL, *result.Meta.Official.AlternativeURL) + } else { + assert.Nil(t, result.Meta.Official.AlternativeURL) + } + } + }) +} + +func TestPostgreSQL_SetAllVersionsStatus(t *testing.T) { + db := database.NewTestDB(t) + ctx := context.Background() + timeNow := time.Now() + + t.Run("update all versions status successfully", func(t *testing.T) { + serverName := "com.example/all-versions-status-test" + versions := []string{"1.0.0", "1.1.0", "2.0.0"} + + // Create multiple versions + for i, version := range versions { + serverJSON := &apiv0.ServerJSON{ + Name: serverName, + Description: "Test server for all-versions status update", + Version: version, + } + officialMeta := &apiv0.RegistryExtensions{ + Status: model.StatusActive, + StatusChangedAt: timeNow, + PublishedAt: timeNow, + UpdatedAt: timeNow, + IsLatest: i == len(versions)-1, + } + + _, err := db.CreateServer(ctx, nil, serverJSON, officialMeta) + require.NoError(t, err) + } + + // Update all versions to deprecated + statusMessage := "All versions deprecated" + alternativeURL := "https://example.com/new-server" + + results, err := db.SetAllVersionsStatus(ctx, nil, serverName, model.StatusDeprecated, &statusMessage, &alternativeURL, nil) + assert.NoError(t, err) + assert.Len(t, results, 3) + + // Verify all versions were updated + for _, result := range results { + assert.Equal(t, model.StatusDeprecated, result.Meta.Official.Status) + assert.NotNil(t, result.Meta.Official.StatusMessage) + assert.Equal(t, statusMessage, *result.Meta.Official.StatusMessage) + assert.NotNil(t, result.Meta.Official.AlternativeURL) + assert.Equal(t, alternativeURL, *result.Meta.Official.AlternativeURL) + } + + // Verify by fetching each version individually + for _, version := range versions { + server, err := db.GetServerByNameAndVersion(ctx, nil, serverName, version) + require.NoError(t, err) + assert.Equal(t, model.StatusDeprecated, server.Meta.Official.Status) + assert.NotNil(t, server.Meta.Official.StatusMessage) + assert.Equal(t, statusMessage, *server.Meta.Official.StatusMessage) + } + }) + + t.Run("update all versions with newName", func(t *testing.T) { + serverName := "com.example/all-versions-newname-test" + newServerName := "com.example/replacement-server" + + // Create the old server with multiple versions + for i, version := range []string{"1.0.0", "2.0.0"} { + serverJSON := &apiv0.ServerJSON{ + Name: serverName, + Description: "Old server", + Version: version, + } + officialMeta := &apiv0.RegistryExtensions{ + Status: model.StatusActive, + StatusChangedAt: timeNow, + PublishedAt: timeNow, + UpdatedAt: timeNow, + IsLatest: i == 1, + } + + _, err := db.CreateServer(ctx, nil, serverJSON, officialMeta) + require.NoError(t, err) + } + + // Create the replacement server (so newName validation would pass at API level) + _, err := db.CreateServer(ctx, nil, &apiv0.ServerJSON{ + Name: newServerName, + Description: "Replacement server", + Version: "1.0.0", + }, &apiv0.RegistryExtensions{ + Status: model.StatusActive, + StatusChangedAt: timeNow, + PublishedAt: timeNow, + UpdatedAt: timeNow, + IsLatest: true, + }) + require.NoError(t, err) + + // Update all versions to deprecated with newName + statusMessage := "Server renamed" + results, err := db.SetAllVersionsStatus(ctx, nil, serverName, model.StatusDeprecated, &statusMessage, nil, &newServerName) + assert.NoError(t, err) + assert.Len(t, results, 2) + + // Verify newName is set on all versions + for _, result := range results { + assert.Equal(t, model.StatusDeprecated, result.Meta.Official.Status) + assert.NotNil(t, result.Meta.Official.NewName) + assert.Equal(t, newServerName, *result.Meta.Official.NewName) + } + }) + + t.Run("update non-existent server returns error", func(t *testing.T) { + results, err := db.SetAllVersionsStatus(ctx, nil, "com.example/non-existent-server", model.StatusDeprecated, nil, nil, nil) + assert.Error(t, err) + assert.ErrorIs(t, err, database.ErrNotFound) + assert.Nil(t, results) + }) + + t.Run("update all versions to yanked", func(t *testing.T) { + serverName := "com.example/all-versions-yanked-test" + + // Create multiple versions + for i, version := range []string{"1.0.0", "1.1.0"} { + serverJSON := &apiv0.ServerJSON{ + Name: serverName, + Description: "Test server for yanking", + Version: version, + } + officialMeta := &apiv0.RegistryExtensions{ + Status: model.StatusActive, + StatusChangedAt: timeNow, + PublishedAt: timeNow, + UpdatedAt: timeNow, + IsLatest: i == 1, + } + + _, err := db.CreateServer(ctx, nil, serverJSON, officialMeta) + require.NoError(t, err) + } + + // Yank all versions + statusMessage := "Critical security vulnerability" + results, err := db.SetAllVersionsStatus(ctx, nil, serverName, model.StatusYanked, &statusMessage, nil, nil) + assert.NoError(t, err) + assert.Len(t, results, 2) + + // Verify all versions are yanked + for _, result := range results { + assert.Equal(t, model.StatusYanked, result.Meta.Official.Status) + assert.NotNil(t, result.Meta.Official.StatusMessage) + assert.Equal(t, statusMessage, *result.Meta.Official.StatusMessage) + } + }) + + t.Run("transition from yanked back to active", func(t *testing.T) { + serverName := "com.example/all-versions-reactivate-test" + + // Create server in yanked state + for i, version := range []string{"1.0.0", "2.0.0"} { + serverJSON := &apiv0.ServerJSON{ + Name: serverName, + Description: "Test server for reactivation", + Version: version, + } + officialMeta := &apiv0.RegistryExtensions{ + Status: model.StatusYanked, + StatusChangedAt: timeNow, + PublishedAt: timeNow, + UpdatedAt: timeNow, + IsLatest: i == 1, + } + + _, err := db.CreateServer(ctx, nil, serverJSON, officialMeta) + require.NoError(t, err) + } + + // Reactivate all versions + results, err := db.SetAllVersionsStatus(ctx, nil, serverName, model.StatusActive, nil, nil, nil) + assert.NoError(t, err) + assert.Len(t, results, 2) + + // Verify all versions are active and metadata is cleared + for _, result := range results { + assert.Equal(t, model.StatusActive, result.Meta.Official.Status) + assert.Nil(t, result.Meta.Official.StatusMessage) + assert.Nil(t, result.Meta.Official.AlternativeURL) + assert.Nil(t, result.Meta.Official.NewName) + } + }) +} + +func TestPostgreSQL_IncludeYankedFilter(t *testing.T) { + db := database.NewTestDB(t) + ctx := context.Background() + timeNow := time.Now() + + // Create test servers with different statuses + testServers := []struct { + name string + version string + status model.Status + }{ + { + name: "com.example/yanked-filter-active", + version: "1.0.0", + status: model.StatusActive, + }, + { + name: "com.example/yanked-filter-deprecated", + version: "1.0.0", + status: model.StatusDeprecated, + }, + { + name: "com.example/yanked-filter-yanked", + version: "1.0.0", + status: model.StatusYanked, + }, + } + + // Create all test servers + for _, server := range testServers { + serverJSON := &apiv0.ServerJSON{ + Name: server.name, + Description: "Test server for include yanked filter", + Version: server.version, + } + officialMeta := &apiv0.RegistryExtensions{ + Status: server.status, + StatusChangedAt: timeNow, + PublishedAt: timeNow, + UpdatedAt: timeNow, + IsLatest: true, + } + + _, err := db.CreateServer(ctx, nil, serverJSON, officialMeta) + require.NoError(t, err) + } + + t.Run("excludes yanked by default (nil IncludeYanked)", func(t *testing.T) { + filter := &database.ServerFilter{ + SubstringName: stringPtr("yanked-filter"), + } + + results, _, err := db.ListServers(ctx, nil, filter, "", 10) + require.NoError(t, err) + + // Should only get active and deprecated servers + assert.Len(t, results, 2) + + for _, result := range results { + assert.NotEqual(t, model.StatusYanked, result.Meta.Official.Status, + "Yanked servers should be excluded by default") + } + + // Verify we got the expected servers + names := make([]string, len(results)) + for i, r := range results { + names[i] = r.Server.Name + } + assert.Contains(t, names, "com.example/yanked-filter-active") + assert.Contains(t, names, "com.example/yanked-filter-deprecated") + }) + + t.Run("excludes yanked when IncludeYanked is false", func(t *testing.T) { + filter := &database.ServerFilter{ + SubstringName: stringPtr("yanked-filter"), + IncludeYanked: boolPtr(false), + } + + results, _, err := db.ListServers(ctx, nil, filter, "", 10) + require.NoError(t, err) + + // Should only get active and deprecated servers + assert.Len(t, results, 2) + + for _, result := range results { + assert.NotEqual(t, model.StatusYanked, result.Meta.Official.Status, + "Yanked servers should be excluded when IncludeYanked is false") + } + }) + + t.Run("includes yanked when IncludeYanked is true", func(t *testing.T) { + filter := &database.ServerFilter{ + SubstringName: stringPtr("yanked-filter"), + IncludeYanked: boolPtr(true), + } + + results, _, err := db.ListServers(ctx, nil, filter, "", 10) + require.NoError(t, err) + + // Should get all servers including yanked + assert.Len(t, results, 3) + + // Verify we got all statuses + statuses := make(map[model.Status]bool) + for _, result := range results { + statuses[result.Meta.Official.Status] = true + } + + assert.True(t, statuses[model.StatusActive], "Should include active servers") + assert.True(t, statuses[model.StatusDeprecated], "Should include deprecated servers") + assert.True(t, statuses[model.StatusYanked], "Should include yanked servers") + }) + + t.Run("combined filters with include yanked", func(t *testing.T) { + // Test that IncludeYanked works correctly with other filters + filter := &database.ServerFilter{ + SubstringName: stringPtr("yanked-filter"), + Version: stringPtr("1.0.0"), + IsLatest: boolPtr(true), + IncludeYanked: boolPtr(true), + } + + results, _, err := db.ListServers(ctx, nil, filter, "", 10) + require.NoError(t, err) + + // Should get all 3 servers (all match version and isLatest criteria) + assert.Len(t, results, 3) + }) + + t.Run("multiple versions with yanked filtering", func(t *testing.T) { + serverName := "com.example/multi-version-yanked-test" + + // Create server with multiple versions, one yanked + versionsData := []struct { + version string + status model.Status + }{ + {"1.0.0", model.StatusYanked}, // Old version, yanked + {"1.1.0", model.StatusActive}, // Current stable + {"2.0.0", model.StatusActive}, // Latest + } + + for i, v := range versionsData { + serverJSON := &apiv0.ServerJSON{ + Name: serverName, + Description: "Multi-version server", + Version: v.version, + } + officialMeta := &apiv0.RegistryExtensions{ + Status: v.status, + StatusChangedAt: timeNow, + PublishedAt: timeNow, + UpdatedAt: timeNow, + IsLatest: i == len(versionsData)-1, + } + + _, err := db.CreateServer(ctx, nil, serverJSON, officialMeta) + require.NoError(t, err) + } + + // Without IncludeYanked - should get only active versions + filter := &database.ServerFilter{ + Name: stringPtr(serverName), + IncludeYanked: boolPtr(false), + } + results, _, err := db.ListServers(ctx, nil, filter, "", 10) + require.NoError(t, err) + assert.Len(t, results, 2, "Should only get non-yanked versions") + + // With IncludeYanked - should get all versions + filter.IncludeYanked = boolPtr(true) + results, _, err = db.ListServers(ctx, nil, filter, "", 10) + require.NoError(t, err) + assert.Len(t, results, 3, "Should get all versions including yanked") + }) +} + // Helper functions for creating pointers to basic types func stringPtr(s string) *string { return &s diff --git a/internal/service/registry_service.go b/internal/service/registry_service.go index 3b41cdef6..4922997fa 100644 --- a/internal/service/registry_service.go +++ b/internal/service/registry_service.go @@ -7,6 +7,7 @@ import ( "time" "github.com/jackc/pgx/v5" + "github.com/modelcontextprotocol/registry/internal/config" "github.com/modelcontextprotocol/registry/internal/database" "github.com/modelcontextprotocol/registry/internal/validators" @@ -152,10 +153,11 @@ func (s *registryServiceImpl) createServerInTransaction(ctx context.Context, tx // Create metadata for the new server officialMeta := &apiv0.RegistryExtensions{ - Status: model.StatusActive, /* New versions are active by default */ - PublishedAt: publishTime, - UpdatedAt: publishTime, - IsLatest: isNewLatest, + Status: model.StatusActive, /* New versions are active by default */ + StatusChangedAt: publishTime, + PublishedAt: publishTime, + UpdatedAt: publishTime, + IsLatest: isNewLatest, } // Insert new server version @@ -186,29 +188,29 @@ func (s *registryServiceImpl) validateNoDuplicateRemoteURLs(ctx context.Context, } // UpdateServer updates an existing server with new details -func (s *registryServiceImpl) UpdateServer(ctx context.Context, serverName, version string, req *apiv0.ServerJSON, newStatus *string) (*apiv0.ServerResponse, error) { +func (s *registryServiceImpl) UpdateServer(ctx context.Context, serverName, version string, req *apiv0.ServerJSON, statusChange *StatusChangeRequest) (*apiv0.ServerResponse, error) { // Wrap the entire operation in a transaction return database.InTransactionT(ctx, s.db, func(ctx context.Context, tx pgx.Tx) (*apiv0.ServerResponse, error) { - return s.updateServerInTransaction(ctx, tx, serverName, version, req, newStatus) + return s.updateServerInTransaction(ctx, tx, serverName, version, req, statusChange) }) } // updateServerInTransaction contains the actual UpdateServer logic within a transaction -func (s *registryServiceImpl) updateServerInTransaction(ctx context.Context, tx pgx.Tx, serverName, version string, req *apiv0.ServerJSON, newStatus *string) (*apiv0.ServerResponse, error) { - // Get current server to check if it's deleted or being deleted +func (s *registryServiceImpl) updateServerInTransaction(ctx context.Context, tx pgx.Tx, serverName, version string, req *apiv0.ServerJSON, statusChange *StatusChangeRequest) (*apiv0.ServerResponse, error) { + // Get current server to check if it's yanked or being yanked currentServer, err := s.db.GetServerByNameAndVersion(ctx, tx, serverName, version) if err != nil { return nil, err } // Skip registry validation if: - // 1. Server is currently deleted, OR - // 2. Server is being set to deleted status - currentlyDeleted := currentServer.Meta.Official != nil && currentServer.Meta.Official.Status == model.StatusDeleted - beingDeleted := newStatus != nil && *newStatus == string(model.StatusDeleted) - skipRegistryValidation := currentlyDeleted || beingDeleted + // 1. Server is currently yanked, OR + // 2. Server is being set to yanked status + currentlyYanked := currentServer.Meta.Official != nil && currentServer.Meta.Official.Status == model.StatusYanked + beingYanked := statusChange != nil && statusChange.NewStatus == model.StatusYanked + skipRegistryValidation := currentlyYanked || beingYanked - // Validate the request, potentially skipping registry validation for deleted servers + // Validate the request, potentially skipping registry validation for yanked servers if err := validators.ValidateUpdateRequest(ctx, *req, s.cfg, skipRegistryValidation); err != nil { return nil, err } @@ -233,8 +235,8 @@ func (s *registryServiceImpl) updateServerInTransaction(ctx context.Context, tx } // Handle status change if provided - if newStatus != nil { - updatedWithStatus, err := s.db.SetServerStatus(ctx, tx, serverName, version, *newStatus) + if statusChange != nil { + updatedWithStatus, err := s.db.SetServerStatus(ctx, tx, serverName, version, statusChange.NewStatus, statusChange.StatusMessage, statusChange.AlternativeURL, statusChange.NewName) if err != nil { return nil, err } @@ -243,3 +245,47 @@ func (s *registryServiceImpl) updateServerInTransaction(ctx context.Context, tx return updatedServerResponse, nil } + +// UpdateServerStatus updates only the status metadata of a server version +func (s *registryServiceImpl) UpdateServerStatus(ctx context.Context, serverName, version string, statusChange *StatusChangeRequest) (*apiv0.ServerResponse, error) { + // Wrap the entire operation in a transaction + return database.InTransactionT(ctx, s.db, func(ctx context.Context, tx pgx.Tx) (*apiv0.ServerResponse, error) { + return s.updateServerStatusInTransaction(ctx, tx, serverName, version, statusChange) + }) +} + +// updateServerStatusInTransaction contains the actual UpdateServerStatus logic within a transaction +func (s *registryServiceImpl) updateServerStatusInTransaction(ctx context.Context, tx pgx.Tx, serverName, version string, statusChange *StatusChangeRequest) (*apiv0.ServerResponse, error) { + // Get current server to verify it exists + _, err := s.db.GetServerByNameAndVersion(ctx, tx, serverName, version) + if err != nil { + return nil, err + } + + // Acquire advisory lock to prevent concurrent edits of servers with same name + if err := s.db.AcquirePublishLock(ctx, tx, serverName); err != nil { + return nil, err + } + + // Update only the status metadata + return s.db.SetServerStatus(ctx, tx, serverName, version, statusChange.NewStatus, statusChange.StatusMessage, statusChange.AlternativeURL, statusChange.NewName) +} + +// UpdateAllVersionsStatus updates the status metadata of all versions of a server in a single transaction +func (s *registryServiceImpl) UpdateAllVersionsStatus(ctx context.Context, serverName string, statusChange *StatusChangeRequest) ([]*apiv0.ServerResponse, error) { + // Wrap the entire operation in a transaction + return database.InTransactionT(ctx, s.db, func(ctx context.Context, tx pgx.Tx) ([]*apiv0.ServerResponse, error) { + return s.updateAllVersionsStatusInTransaction(ctx, tx, serverName, statusChange) + }) +} + +// updateAllVersionsStatusInTransaction contains the actual UpdateAllVersionsStatus logic within a transaction +func (s *registryServiceImpl) updateAllVersionsStatusInTransaction(ctx context.Context, tx pgx.Tx, serverName string, statusChange *StatusChangeRequest) ([]*apiv0.ServerResponse, error) { + // Acquire advisory lock to prevent concurrent edits of servers with same name + if err := s.db.AcquirePublishLock(ctx, tx, serverName); err != nil { + return nil, err + } + + // Update all versions' status in a single database call + return s.db.SetAllVersionsStatus(ctx, tx, serverName, statusChange.NewStatus, statusChange.StatusMessage, statusChange.AlternativeURL, statusChange.NewName) +} diff --git a/internal/service/registry_service_test.go b/internal/service/registry_service_test.go index b091e1eb4..f7aa77d69 100644 --- a/internal/service/registry_service_test.go +++ b/internal/service/registry_service_test.go @@ -466,7 +466,7 @@ func TestUpdateServer(t *testing.T) { serverName string version string updatedServer *apiv0.ServerJSON - newStatus *string + statusChange *StatusChangeRequest expectError bool errorMsg string checkResult func(*testing.T, *apiv0.ServerResponse) @@ -503,7 +503,9 @@ func TestUpdateServer(t *testing.T) { Description: "Updated with status change", Version: version, }, - newStatus: stringPtr(string(model.StatusDeprecated)), + statusChange: &StatusChangeRequest{ + NewStatus: model.StatusDeprecated, + }, expectError: false, checkResult: func(t *testing.T, result *apiv0.ServerResponse) { t.Helper() @@ -528,7 +530,7 @@ func TestUpdateServer(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - result, err := service.UpdateServer(ctx, tt.serverName, tt.version, tt.updatedServer, tt.newStatus) + result, err := service.UpdateServer(ctx, tt.serverName, tt.version, tt.updatedServer, tt.statusChange) if tt.expectError { assert.Error(t, err) @@ -547,10 +549,10 @@ func TestUpdateServer(t *testing.T) { } } -func TestUpdateServer_SkipValidationForDeletedServers(t *testing.T) { +func TestUpdateServer_SkipValidationForYankedServers(t *testing.T) { ctx := context.Background() testDB := database.NewTestDB(t) - // Enable registry validation to test that it gets skipped for deleted servers + // Enable registry validation to test that it gets skipped for yanked servers service := NewRegistryService(testDB, &config.Config{EnableRegistryValidation: true}) serverName := "com.example/validation-skip-test" @@ -579,21 +581,23 @@ func TestUpdateServer_SkipValidationForDeletedServers(t *testing.T) { require.NoError(t, err, "failed to create server with validation disabled") service.(*registryServiceImpl).cfg.EnableRegistryValidation = originalConfig - // First, set server to deleted status - deletedStatus := string(model.StatusDeleted) - _, err = service.UpdateServer(ctx, serverName, version, invalidServer, &deletedStatus) - require.NoError(t, err, "should be able to set server to deleted (validation should be skipped)") + // First, set server to yanked status + yankedStatusChange := &StatusChangeRequest{ + NewStatus: model.StatusYanked, + } + _, err = service.UpdateServer(ctx, serverName, version, invalidServer, yankedStatusChange) + require.NoError(t, err, "should be able to set server to yanked (validation should be skipped)") - // Verify server is now deleted + // Verify server is now yanked updatedServer, err := service.GetServerByNameAndVersion(ctx, serverName, version) require.NoError(t, err) - assert.Equal(t, model.StatusDeleted, updatedServer.Meta.Official.Status) + assert.Equal(t, model.StatusYanked, updatedServer.Meta.Official.Status) - // Now try to update a deleted server - validation should be skipped + // Now try to update a yanked server - validation should be skipped updatedInvalidServer := &apiv0.ServerJSON{ Schema: model.CurrentSchemaURL, Name: serverName, - Description: "Updated description for deleted server", + Description: "Updated description for yanked server", Version: version, Packages: []model.Package{ { @@ -605,18 +609,18 @@ func TestUpdateServer_SkipValidationForDeletedServers(t *testing.T) { }, } - // This should succeed despite invalid packages because server is deleted + // This should succeed despite invalid packages because server is yanked result, err := service.UpdateServer(ctx, serverName, version, updatedInvalidServer, nil) - assert.NoError(t, err, "updating deleted server should skip registry validation") + assert.NoError(t, err, "updating yanked server should skip registry validation") assert.NotNil(t, result) - assert.Equal(t, "Updated description for deleted server", result.Server.Description) - assert.Equal(t, model.StatusDeleted, result.Meta.Official.Status) + assert.Equal(t, "Updated description for yanked server", result.Server.Description) + assert.Equal(t, model.StatusYanked, result.Meta.Official.Status) - // Test updating a server being set to deleted status + // Test updating a server being set to yanked status activeServer := &apiv0.ServerJSON{ Schema: model.CurrentSchemaURL, - Name: "com.example/being-deleted-test", - Description: "Server being deleted", + Name: "com.example/being-yanked-test", + Description: "Server being yanked", Version: "1.0.0", Packages: []model.Package{ { @@ -634,12 +638,14 @@ func TestUpdateServer_SkipValidationForDeletedServers(t *testing.T) { require.NoError(t, err) service.(*registryServiceImpl).cfg.EnableRegistryValidation = originalConfig - // Update server and set to deleted in same operation - should skip validation - newDeletedStatus := string(model.StatusDeleted) - result2, err := service.UpdateServer(ctx, "com.example/being-deleted-test", "1.0.0", activeServer, &newDeletedStatus) - assert.NoError(t, err, "updating server being set to deleted should skip registry validation") + // Update server and set to yanked in same operation - should skip validation + newYankedStatusChange := &StatusChangeRequest{ + NewStatus: model.StatusYanked, + } + result2, err := service.UpdateServer(ctx, "com.example/being-yanked-test", "1.0.0", activeServer, newYankedStatusChange) + assert.NoError(t, err, "updating server being set to yanked should skip registry validation") assert.NotNil(t, result2) - assert.Equal(t, model.StatusDeleted, result2.Meta.Official.Status) + assert.Equal(t, model.StatusYanked, result2.Meta.Official.Status) } func TestListServers(t *testing.T) { diff --git a/internal/service/service.go b/internal/service/service.go index 4ebd8ec8c..8ad47269c 100644 --- a/internal/service/service.go +++ b/internal/service/service.go @@ -5,8 +5,17 @@ import ( "github.com/modelcontextprotocol/registry/internal/database" apiv0 "github.com/modelcontextprotocol/registry/pkg/api/v0" + "github.com/modelcontextprotocol/registry/pkg/model" ) +// StatusChangeRequest represents a request to change a server's status +type StatusChangeRequest struct { + NewStatus model.Status `json:"newStatus"` + StatusMessage *string `json:"statusMessage,omitempty"` + AlternativeURL *string `json:"alternativeUrl,omitempty"` + NewName *string `json:"newName,omitempty"` +} + // RegistryService defines the interface for registry operations type RegistryService interface { // ListServers retrieve all servers with optional filtering @@ -20,5 +29,9 @@ type RegistryService interface { // CreateServer creates a new server version CreateServer(ctx context.Context, req *apiv0.ServerJSON) (*apiv0.ServerResponse, error) // UpdateServer updates an existing server and optionally its status - UpdateServer(ctx context.Context, serverName, version string, req *apiv0.ServerJSON, newStatus *string) (*apiv0.ServerResponse, error) + UpdateServer(ctx context.Context, serverName, version string, req *apiv0.ServerJSON, statusChange *StatusChangeRequest) (*apiv0.ServerResponse, error) + // UpdateServerStatus updates only the status metadata of a server version + UpdateServerStatus(ctx context.Context, serverName, version string, statusChange *StatusChangeRequest) (*apiv0.ServerResponse, error) + // UpdateAllVersionsStatus updates the status metadata of all versions of a server in a single transaction + UpdateAllVersionsStatus(ctx context.Context, serverName string, statusChange *StatusChangeRequest) ([]*apiv0.ServerResponse, error) } diff --git a/pkg/api/v0/types.go b/pkg/api/v0/types.go index 513e125be..f1db28d7d 100644 --- a/pkg/api/v0/types.go +++ b/pkg/api/v0/types.go @@ -7,10 +7,14 @@ import ( ) type RegistryExtensions struct { - Status model.Status `json:"status" enum:"active,deprecated,deleted" doc:"Server lifecycle status"` - PublishedAt time.Time `json:"publishedAt" format:"date-time" doc:"Timestamp when the server was first published to the registry"` - UpdatedAt time.Time `json:"updatedAt,omitempty" format:"date-time" doc:"Timestamp when the server entry was last updated"` - IsLatest bool `json:"isLatest" doc:"Whether this is the latest version of the server"` + Status model.Status `json:"status" enum:"active,deprecated,yanked" doc:"Server lifecycle status"` + StatusChangedAt time.Time `json:"statusChangedAt" format:"date-time" doc:"Timestamp when the server status was last changed"` + StatusMessage *string `json:"statusMessage,omitempty" doc:"Optional message explaining status change (e.g., deprecation reason, migration guidance)"` + AlternativeURL *string `json:"alternativeUrl,omitempty" format:"uri" doc:"Optional URL to alternative/replacement server for deprecated or yanked servers"` + NewName *string `json:"newName,omitempty" doc:"Optional new server name when server has been renamed"` + PublishedAt time.Time `json:"publishedAt" format:"date-time" doc:"Timestamp when the server was first published to the registry"` + UpdatedAt time.Time `json:"updatedAt,omitempty" format:"date-time" doc:"Timestamp when the server entry was last updated"` + IsLatest bool `json:"isLatest" doc:"Whether this is the latest version of the server"` } type ResponseMeta struct { diff --git a/pkg/model/types.go b/pkg/model/types.go index 8a7443d88..1afeb0a10 100644 --- a/pkg/model/types.go +++ b/pkg/model/types.go @@ -5,7 +5,7 @@ type Status string const ( StatusActive Status = "active" StatusDeprecated Status = "deprecated" - StatusDeleted Status = "deleted" + StatusYanked Status = "yanked" ) // Transport represents transport configuration for both Package and Remote contexts. diff --git a/scripts/mirror_data/load_production_data.go b/scripts/mirror_data/load_production_data.go index 71094467c..461394a03 100644 --- a/scripts/mirror_data/load_production_data.go +++ b/scripts/mirror_data/load_production_data.go @@ -101,7 +101,7 @@ func main() { COUNT(CASE WHEN value->>'status' = 'null' THEN 1 END) as string_null_status, COUNT(CASE WHEN value->>'status' = 'active' THEN 1 END) as active_status, COUNT(CASE WHEN value->>'status' = 'deprecated' THEN 1 END) as deprecated_status, - COUNT(CASE WHEN value->>'status' = 'deleted' THEN 1 END) as deleted_status + COUNT(CASE WHEN value->>'status' = 'yanked' THEN 1 END) as yanked_status FROM servers `) if err != nil { @@ -110,8 +110,8 @@ func main() { defer rows.Close() if rows.Next() { - var total, nullStatus, emptyStatus, stringNullStatus, activeStatus, deprecatedStatus, deletedStatus int - rows.Scan(&total, &nullStatus, &emptyStatus, &stringNullStatus, &activeStatus, &deprecatedStatus, &deletedStatus) + var total, nullStatus, emptyStatus, stringNullStatus, activeStatus, deprecatedStatus, yankedStatus int + rows.Scan(&total, &nullStatus, &emptyStatus, &stringNullStatus, &activeStatus, &deprecatedStatus, &yankedStatus) fmt.Printf(" Total servers: %d\n", total) fmt.Printf(" NULL status: %d\n", nullStatus) @@ -119,8 +119,8 @@ func main() { fmt.Printf(" 'null' string status: %d\n", stringNullStatus) fmt.Printf(" 'active' status: %d\n", activeStatus) fmt.Printf(" 'deprecated' status: %d\n", deprecatedStatus) - fmt.Printf(" 'deleted' status: %d\n", deletedStatus) - fmt.Printf(" Other/Invalid: %d\n", total-nullStatus-emptyStatus-stringNullStatus-activeStatus-deprecatedStatus-deletedStatus) + fmt.Printf(" 'yanked' status: %d\n", yankedStatus) + fmt.Printf(" Other/Invalid: %d\n", total-nullStatus-emptyStatus-stringNullStatus-activeStatus-deprecatedStatus-yankedStatus) } // Print sample servers with no status