diff --git a/README.md b/README.md index f0c1a7401..631336567 100644 --- a/README.md +++ b/README.md @@ -1052,8 +1052,18 @@ The following sets of tools are available: - `repo`: Repository name (string, required) - `title`: PR title (string, required) +- **get_prs_reviewed_by** - Get PRs reviewed by user + - **Required OAuth Scopes**: `repo` + - `owner`: Repository owner (string, required) + - `page`: Page number for pagination (min 1) (number, optional) + - `perPage`: Results per page for pagination (min 1, max 100) (number, optional) + - `repo`: Repository name (string, required) + - `reviewer`: GitHub username of the reviewer (string, required) + - `state`: PR state filter: open, closed, or all (string, optional) + - **list_pull_requests** - List pull requests - **Required OAuth Scopes**: `repo` + - `author`: Filter by PR author username (client-side filter) (string, optional) - `base`: Filter by base branch (string, optional) - `direction`: Sort direction (string, optional) - `head`: Filter by head user/org and branch (string, optional) @@ -1075,8 +1085,8 @@ The following sets of tools are available: - **pull_request_read** - Get details for a single pull request - **Required OAuth Scopes**: `repo` - - `method`: Action to specify what pull request data needs to be retrieved from GitHub. - Possible options: + - `method`: Action to specify what pull request data needs to be retrieved from GitHub. + Possible options: 1. get - Get details of a specific pull request. 2. get_diff - Get the diff of a pull request. 3. get_status - Get status of a head commit in a pull request. This reflects status of builds and checks. diff --git a/pkg/github/pullrequests.go b/pkg/github/pullrequests.go index a11fe29a5..1a074a3bc 100644 --- a/pkg/github/pullrequests.go +++ b/pkg/github/pullrequests.go @@ -6,6 +6,7 @@ import ( "fmt" "io" "net/http" + "strings" "github.com/go-viper/mapstructure/v2" "github.com/google/go-github/v79/github" @@ -29,8 +30,8 @@ func PullRequestRead(t translations.TranslationHelperFunc) inventory.ServerTool Properties: map[string]*jsonschema.Schema{ "method": { Type: "string", - Description: `Action to specify what pull request data needs to be retrieved from GitHub. -Possible options: + Description: `Action to specify what pull request data needs to be retrieved from GitHub. +Possible options: 1. get - Get details of a specific pull request. 2. get_diff - Get the diff of a pull request. 3. get_status - Get status of a head commit in a pull request. This reflects status of builds and checks. @@ -1046,6 +1047,10 @@ func ListPullRequests(t translations.TranslationHelperFunc) inventory.ServerTool Description: "Sort direction", Enum: []any{"asc", "desc"}, }, + "author": { + Type: "string", + Description: "Filter by PR author username (client-side filter)", + }, }, Required: []string{"owner", "repo"}, } @@ -1055,7 +1060,7 @@ func ListPullRequests(t translations.TranslationHelperFunc) inventory.ServerTool ToolsetMetadataPullRequests, mcp.Tool{ Name: "list_pull_requests", - Description: t("TOOL_LIST_PULL_REQUESTS_DESCRIPTION", "List pull requests in a GitHub repository. If the user specifies an author, then DO NOT use this tool and use the search_pull_requests tool instead."), + Description: t("TOOL_LIST_PULL_REQUESTS_DESCRIPTION", "List pull requests in a GitHub repository. If the user specifies an author, then DO NOT use this tool and use the search_pull_requests tool instead. If you receive a 422 error from search_pull_requests, then use the get_prs_reviewed_by tool (to list by reviewer), or this tool with the author parameter (for filtering by author) depending on what you need."), Annotations: &mcp.ToolAnnotations{ Title: t("TOOL_LIST_PULL_REQUESTS_USER_TITLE", "List pull requests"), ReadOnlyHint: true, @@ -1092,6 +1097,10 @@ func ListPullRequests(t translations.TranslationHelperFunc) inventory.ServerTool if err != nil { return utils.NewToolResultError(err.Error()), nil, nil } + author, err := OptionalParam[string](args, "author") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } pagination, err := OptionalPaginationParams(args) if err != nil { return utils.NewToolResultError(err.Error()), nil, nil @@ -1131,6 +1140,18 @@ func ListPullRequests(t translations.TranslationHelperFunc) inventory.ServerTool return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to list pull requests", resp, bodyBytes), nil, nil } + // Filter by author if specified (client-side filtering) + if author != "" { + filtered := make([]*github.PullRequest, 0) + for _, pr := range prs { + if pr != nil && pr.User != nil && pr.User.Login != nil && + strings.EqualFold(*pr.User.Login, author) { + filtered = append(filtered, pr) + } + } + prs = filtered + } + // sanitize title/body on each PR for _, pr := range prs { if pr == nil { @@ -1153,6 +1174,154 @@ func ListPullRequests(t translations.TranslationHelperFunc) inventory.ServerTool }) } +// GetPRsReviewedBy creates a tool for finding PRs reviewed by a specific user +func GetPRsReviewedBy(t translations.TranslationHelperFunc) inventory.ServerTool { + schema := &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "owner": { + Type: "string", + Description: "Repository owner", + }, + "repo": { + Type: "string", + Description: "Repository name", + }, + "reviewer": { + Type: "string", + Description: "GitHub username of the reviewer", + }, + "state": { + Type: "string", + Description: "PR state filter: open, closed, or all", + Enum: []any{"open", "closed", "all"}, + }, + }, + Required: []string{"owner", "repo", "reviewer"}, + } + WithPagination(schema) + + return NewTool( + ToolsetMetadataPullRequests, + mcp.Tool{ + Name: "get_prs_reviewed_by", + Description: t("TOOL_GET_PRS_REVIEWED_BY_DESCRIPTION", + "Find PRs reviewed by a user. Use this tool if you receive a 422 error when using the search_pull_requests tool."), + Annotations: &mcp.ToolAnnotations{ + Title: t("TOOL_GET_PRS_REVIEWED_BY_TITLE", "Get PRs reviewed by user"), + ReadOnlyHint: true, + }, + InputSchema: schema, + }, + []scopes.Scope{scopes.Repo}, + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + owner, err := RequiredParam[string](args, "owner") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + repo, err := RequiredParam[string](args, "repo") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + reviewer, err := RequiredParam[string](args, "reviewer") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + state, err := OptionalParam[string](args, "state") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + if state == "" { + state = "all" + } + + opts := &github.PullRequestListOptions{ + State: state, + ListOptions: github.ListOptions{ + PerPage: 100, + }, + } + + pagination, err := OptionalPaginationParams(args) + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + if pagination.Page > 0 { + opts.ListOptions.Page = pagination.Page + } + if pagination.PerPage > 0 { + opts.ListOptions.PerPage = pagination.PerPage + } + + client, err := deps.GetClient(ctx) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil + } + + // List all PRs + prs, resp, err := client.PullRequests.List(ctx, owner, repo, opts) + if err != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, + "failed to list pull requests", + resp, + err, + ), nil, nil + } + defer func() { _ = resp.Body.Close() }() + + // Filter PRs by reviewer + var reviewedPRs []*github.PullRequest + for _, pr := range prs { + if pr.Number == nil { + continue + } + reviews, _, err := client.PullRequests.ListReviews(ctx, owner, repo, *pr.Number, nil) + if err != nil { + continue // Skip PRs we can't get reviews for + } + for _, review := range reviews { + if review.User != nil && review.User.Login != nil && + strings.EqualFold(*review.User.Login, reviewer) { + reviewedPRs = append(reviewedPRs, pr) + break + } + } + } + + // Sanitize the results + sanitized := make([]map[string]any, 0, len(reviewedPRs)) + for _, pr := range reviewedPRs { + if pr.Title != nil { + pr.Title = github.Ptr(sanitize.Sanitize(*pr.Title)) + } + if pr.Body != nil { + pr.Body = github.Ptr(sanitize.Sanitize(*pr.Body)) + } + sanitized = append(sanitized, map[string]any{ + "number": pr.GetNumber(), + "title": pr.GetTitle(), + "state": pr.GetState(), + "html_url": pr.GetHTMLURL(), + "user": pr.GetUser().GetLogin(), + "draft": pr.GetDraft(), + }) + } + + result := map[string]any{ + "pull_requests": sanitized, + "total_count": len(sanitized), + } + + r, err := json.Marshal(result) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to marshal response", err), nil, nil + } + + return utils.NewToolResultText(string(r)), nil, nil + }, + ) +} + // MergePullRequest creates a tool to merge a pull request. func MergePullRequest(t translations.TranslationHelperFunc) inventory.ServerTool { schema := &jsonschema.Schema{ @@ -1310,7 +1479,7 @@ func SearchPullRequests(t translations.TranslationHelperFunc) inventory.ServerTo ToolsetMetadataPullRequests, mcp.Tool{ Name: "search_pull_requests", - Description: t("TOOL_SEARCH_PULL_REQUESTS_DESCRIPTION", "Search for pull requests in GitHub repositories using issues search syntax already scoped to is:pr"), + Description: t("TOOL_SEARCH_PULL_REQUESTS_DESCRIPTION", "Search for pull requests in GitHub repositories using issues search syntax already scoped to is:pr. If you receive a 422 error, then use the get_prs_reviewed_by tool instead."), Annotations: &mcp.ToolAnnotations{ Title: t("TOOL_SEARCH_PULL_REQUESTS_USER_TITLE", "Search pull requests"), ReadOnlyHint: true, diff --git a/pkg/github/pullrequests_test.go b/pkg/github/pullrequests_test.go index 61a4ad7f1..4543b4440 100644 --- a/pkg/github/pullrequests_test.go +++ b/pkg/github/pullrequests_test.go @@ -685,6 +685,185 @@ func Test_ListPullRequests(t *testing.T) { } } +func Test_GetPRsReviewedBy(t *testing.T) { + // Verify tool definition once + serverTool := GetPRsReviewedBy(translations.NullTranslationHelper) + tool := serverTool.Tool + require.NoError(t, toolsnaps.Test(tool.Name, tool)) + + assert.Equal(t, "get_prs_reviewed_by", tool.Name) + assert.NotEmpty(t, tool.Description) + schema := tool.InputSchema.(*jsonschema.Schema) + assert.Contains(t, schema.Properties, "owner") + assert.Contains(t, schema.Properties, "repo") + assert.Contains(t, schema.Properties, "reviewer") + assert.Contains(t, schema.Properties, "state") + assert.Contains(t, schema.Properties, "perPage") + assert.Contains(t, schema.Properties, "page") + assert.ElementsMatch(t, schema.Required, []string{"owner", "repo", "reviewer"}) + + // Setup mock PRs + mockPRs := []*github.PullRequest{ + { + Number: github.Ptr(42), + Title: github.Ptr("First PR"), + State: github.Ptr("open"), + HTMLURL: github.Ptr("https://github.com/owner/repo/pull/42"), + User: &github.User{Login: github.Ptr("author1")}, + }, + { + Number: github.Ptr(43), + Title: github.Ptr("Second PR"), + State: github.Ptr("closed"), + HTMLURL: github.Ptr("https://github.com/owner/repo/pull/43"), + User: &github.User{Login: github.Ptr("author2")}, + }, + { + Number: github.Ptr(44), + Title: github.Ptr("Third PR"), + State: github.Ptr("open"), + HTMLURL: github.Ptr("https://github.com/owner/repo/pull/44"), + User: &github.User{Login: github.Ptr("author3")}, + }, + } + + // Mock reviews for each PR + reviewsPR42 := []*github.PullRequestReview{ + {ID: github.Ptr(int64(1)), User: &github.User{Login: github.Ptr("reviewer1")}}, + {ID: github.Ptr(int64(2)), User: &github.User{Login: github.Ptr("reviewer2")}}, + } + reviewsPR43 := []*github.PullRequestReview{ + {ID: github.Ptr(int64(3)), User: &github.User{Login: github.Ptr("reviewer2")}}, + } + reviewsPR44 := []*github.PullRequestReview{ + {ID: github.Ptr(int64(4)), User: &github.User{Login: github.Ptr("reviewer1")}}, + } + + tests := []struct { + name string + mockedClient *http.Client + requestArgs map[string]interface{} + expectError bool + expectedCount int + expectedErrMsg string + }{ + { + name: "find PRs reviewed by reviewer1", + mockedClient: MockHTTPClientWithHandlers(map[string]http.HandlerFunc{ + GetReposPullsByOwnerByRepo: mockResponse(t, http.StatusOK, mockPRs), + "GET /repos/{owner}/{repo}/pulls/42/reviews": mockResponse(t, http.StatusOK, reviewsPR42), + "GET /repos/{owner}/{repo}/pulls/43/reviews": mockResponse(t, http.StatusOK, reviewsPR43), + "GET /repos/{owner}/{repo}/pulls/44/reviews": mockResponse(t, http.StatusOK, reviewsPR44), + }), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "reviewer": "reviewer1", + }, + expectError: false, + expectedCount: 2, // PR 42 and 44 + }, + { + name: "find PRs reviewed by reviewer2", + mockedClient: MockHTTPClientWithHandlers(map[string]http.HandlerFunc{ + GetReposPullsByOwnerByRepo: mockResponse(t, http.StatusOK, mockPRs), + "GET /repos/{owner}/{repo}/pulls/42/reviews": mockResponse(t, http.StatusOK, reviewsPR42), + "GET /repos/{owner}/{repo}/pulls/43/reviews": mockResponse(t, http.StatusOK, reviewsPR43), + "GET /repos/{owner}/{repo}/pulls/44/reviews": mockResponse(t, http.StatusOK, reviewsPR44), + }), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "reviewer": "reviewer2", + }, + expectError: false, + expectedCount: 2, // PR 42 and 43 + }, + { + name: "case insensitive reviewer matching", + mockedClient: MockHTTPClientWithHandlers(map[string]http.HandlerFunc{ + GetReposPullsByOwnerByRepo: mockResponse(t, http.StatusOK, mockPRs), + "GET /repos/{owner}/{repo}/pulls/42/reviews": mockResponse(t, http.StatusOK, reviewsPR42), + "GET /repos/{owner}/{repo}/pulls/43/reviews": mockResponse(t, http.StatusOK, reviewsPR43), + "GET /repos/{owner}/{repo}/pulls/44/reviews": mockResponse(t, http.StatusOK, reviewsPR44), + }), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "reviewer": "REVIEWER1", + }, + expectError: false, + expectedCount: 2, + }, + { + name: "no PRs reviewed by unknown reviewer", + mockedClient: MockHTTPClientWithHandlers(map[string]http.HandlerFunc{ + GetReposPullsByOwnerByRepo: mockResponse(t, http.StatusOK, mockPRs), + "GET /repos/{owner}/{repo}/pulls/42/reviews": mockResponse(t, http.StatusOK, reviewsPR42), + "GET /repos/{owner}/{repo}/pulls/43/reviews": mockResponse(t, http.StatusOK, reviewsPR43), + "GET /repos/{owner}/{repo}/pulls/44/reviews": mockResponse(t, http.StatusOK, reviewsPR44), + }), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "reviewer": "unknownuser", + }, + expectError: false, + expectedCount: 0, + }, + { + name: "listing PRs fails", + mockedClient: MockHTTPClientWithHandlers(map[string]http.HandlerFunc{ + GetReposPullsByOwnerByRepo: func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusNotFound) + _, _ = w.Write([]byte(`{"message": "Not Found"}`)) + }, + }), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "nonexistent", + "reviewer": "reviewer1", + }, + expectError: true, + expectedErrMsg: "failed to list pull requests", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + client := github.NewClient(tc.mockedClient) + deps := BaseDeps{ + Client: client, + } + handler := serverTool.Handler(deps) + + request := createMCPRequest(tc.requestArgs) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) + + if tc.expectError { + require.NoError(t, err) + require.True(t, result.IsError) + errorContent := getErrorResult(t, result) + assert.Contains(t, errorContent.Text, tc.expectedErrMsg) + return + } + + require.NoError(t, err) + require.False(t, result.IsError) + + textContent := getTextResult(t, result) + var response map[string]interface{} + err = json.Unmarshal([]byte(textContent.Text), &response) + require.NoError(t, err) + + assert.Equal(t, float64(tc.expectedCount), response["total_count"]) + pullRequests, ok := response["pull_requests"].([]interface{}) + require.True(t, ok) + assert.Len(t, pullRequests, tc.expectedCount) + }) + } +} + func Test_MergePullRequest(t *testing.T) { // Verify tool definition once serverTool := MergePullRequest(translations.NullTranslationHelper) @@ -799,6 +978,120 @@ func Test_MergePullRequest(t *testing.T) { } } +func Test_ListPullRequests_AuthorFilter(t *testing.T) { + // Setup mock PRs from multiple authors + mockPRs := []*github.PullRequest{ + { + Number: github.Ptr(42), + Title: github.Ptr("PR by user1"), + State: github.Ptr("open"), + HTMLURL: github.Ptr("https://github.com/owner/repo/pull/42"), + User: &github.User{ + Login: github.Ptr("user1"), + }, + }, + { + Number: github.Ptr(43), + Title: github.Ptr("PR by user2"), + State: github.Ptr("open"), + HTMLURL: github.Ptr("https://github.com/owner/repo/pull/43"), + User: &github.User{ + Login: github.Ptr("user2"), + }, + }, + { + Number: github.Ptr(44), + Title: github.Ptr("Another PR by user1"), + State: github.Ptr("closed"), + HTMLURL: github.Ptr("https://github.com/owner/repo/pull/44"), + User: &github.User{ + Login: github.Ptr("user1"), + }, + }, + } + + tests := []struct { + name string + author string + expectedCount int + expectedNumbers []int + }{ + { + name: "filter by user1 returns 2 PRs", + author: "user1", + expectedCount: 2, + expectedNumbers: []int{42, 44}, + }, + { + name: "filter by user2 returns 1 PR", + author: "user2", + expectedCount: 1, + expectedNumbers: []int{43}, + }, + { + name: "filter by USER1 (case insensitive) returns 2 PRs", + author: "USER1", + expectedCount: 2, + expectedNumbers: []int{42, 44}, + }, + { + name: "filter by nonexistent user returns 0 PRs", + author: "nonexistent", + expectedCount: 0, + expectedNumbers: []int{}, + }, + { + name: "no author filter returns all PRs", + author: "", + expectedCount: 3, + expectedNumbers: []int{42, 43, 44}, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + mockedClient := MockHTTPClientWithHandlers(map[string]http.HandlerFunc{ + GetReposPullsByOwnerByRepo: mockResponse(t, http.StatusOK, mockPRs), + }) + + client := github.NewClient(mockedClient) + serverTool := ListPullRequests(translations.NullTranslationHelper) + deps := BaseDeps{ + Client: client, + } + handler := serverTool.Handler(deps) + + requestArgs := map[string]interface{}{ + "owner": "owner", + "repo": "repo", + } + if tc.author != "" { + requestArgs["author"] = tc.author + } + + request := createMCPRequest(requestArgs) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) + + require.NoError(t, err) + require.False(t, result.IsError) + + textContent := getTextResult(t, result) + var returnedPRs []*github.PullRequest + err = json.Unmarshal([]byte(textContent.Text), &returnedPRs) + require.NoError(t, err) + + assert.Len(t, returnedPRs, tc.expectedCount) + + // Verify the expected PR numbers are returned + returnedNumbers := make([]int, len(returnedPRs)) + for i, pr := range returnedPRs { + returnedNumbers[i] = *pr.Number + } + assert.ElementsMatch(t, tc.expectedNumbers, returnedNumbers) + }) + } +} + func Test_SearchPullRequests(t *testing.T) { serverTool := SearchPullRequests(translations.NullTranslationHelper) tool := serverTool.Tool diff --git a/pkg/github/tools.go b/pkg/github/tools.go index 676976140..45ef5a0da 100644 --- a/pkg/github/tools.go +++ b/pkg/github/tools.go @@ -205,6 +205,7 @@ func AllTools(t translations.TranslationHelperFunc) []inventory.ServerTool { // Pull request tools PullRequestRead(t), ListPullRequests(t), + GetPRsReviewedBy(t), SearchPullRequests(t), MergePullRequest(t), UpdatePullRequestBranch(t),