diff --git a/pkg/github/__toolsnaps__/create_discussion.snap b/pkg/github/__toolsnaps__/create_discussion.snap new file mode 100644 index 000000000..9d188c61b --- /dev/null +++ b/pkg/github/__toolsnaps__/create_discussion.snap @@ -0,0 +1,38 @@ +{ + "annotations": { + "title": "Create discussion" + }, + "description": "Create a new discussion in a repository or organisation.", + "inputSchema": { + "properties": { + "body": { + "description": "Discussion body text in markdown format", + "type": "string" + }, + "categoryId": { + "description": "Category ID where the discussion should be created (obtainable via list_discussion_categories)", + "type": "string" + }, + "owner": { + "description": "Repository owner", + "type": "string" + }, + "repo": { + "description": "Repository name. If not provided, the discussion will be created at the organisation level.", + "type": "string" + }, + "title": { + "description": "Discussion title", + "type": "string" + } + }, + "required": [ + "owner", + "categoryId", + "title", + "body" + ], + "type": "object" + }, + "name": "create_discussion" +} \ No newline at end of file diff --git a/pkg/github/discussions.go b/pkg/github/discussions.go index c03670818..fd72c76f0 100644 --- a/pkg/github/discussions.go +++ b/pkg/github/discussions.go @@ -507,6 +507,140 @@ func GetDiscussionComments(t translations.TranslationHelperFunc) inventory.Serve ) } +// getDiscussionRepositoryID fetches the repository ID needed for createDiscussion mutation +func getDiscussionRepositoryID(ctx context.Context, client *githubv4.Client, owner, repo string) (githubv4.ID, error) { + var repoQuery struct { + Repository struct { + ID githubv4.ID + } `graphql:"repository(owner: $owner, name: $repo)"` + } + vars := map[string]any{ + "owner": githubv4.String(owner), + "repo": githubv4.String(repo), + } + if err := client.Query(ctx, &repoQuery, vars); err != nil { + return "", err + } + return repoQuery.Repository.ID, nil +} + +func CreateDiscussion(t translations.TranslationHelperFunc) inventory.ServerTool { + return NewTool( + ToolsetMetadataDiscussions, + mcp.Tool{ + Name: "create_discussion", + Description: t("TOOL_CREATE_DISCUSSION_DESCRIPTION", "Create a new discussion in a repository or organisation."), + Annotations: &mcp.ToolAnnotations{ + Title: t("TOOL_CREATE_DISCUSSION_USER_TITLE", "Create discussion"), + }, + InputSchema: &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "owner": { + Type: "string", + Description: "Repository owner", + }, + "repo": { + Type: "string", + Description: "Repository name. If not provided, the discussion will be created at the organisation level.", + }, + "categoryId": { + Type: "string", + Description: "Category ID where the discussion should be created (obtainable via list_discussion_categories)", + }, + "title": { + Type: "string", + Description: "Discussion title", + }, + "body": { + Type: "string", + Description: "Discussion body text in markdown format", + }, + }, + Required: []string{"owner", "categoryId", "title", "body"}, + }, + }, + []scopes.Scope{scopes.Repo}, + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + owner, err := RequiredParam[string](args, "owner") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + repo, err := OptionalParam[string](args, "repo") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + // when not provided, default to the .github repository + // this will create the discussion at the organisation level + if repo == "" { + repo = ".github" + } + + categoryID, err := RequiredParam[string](args, "categoryId") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + + title, err := RequiredParam[string](args, "title") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + + body, err := RequiredParam[string](args, "body") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + + client, err := deps.GetGQLClient(ctx) + if err != nil { + return utils.NewToolResultError(fmt.Sprintf("failed to get GitHub GQL client: %v", err)), nil, nil + } + + // Get repository ID first + repoID, err := getDiscussionRepositoryID(ctx, client, owner, repo) + if err != nil { + return utils.NewToolResultError(fmt.Sprintf("failed to get repository ID: %v", err)), nil, nil + } + + // Define the mutation + var mutation struct { + CreateDiscussion struct { + Discussion struct { + ID githubv4.ID + Number githubv4.Int + URL githubv4.String + } + } `graphql:"createDiscussion(input: $input)"` + } + + input := githubv4.CreateDiscussionInput{ + RepositoryID: repoID, + CategoryID: githubv4.ID(categoryID), + Title: githubv4.String(title), + Body: githubv4.String(body), + } + + if err := client.Mutate(ctx, &mutation, input, nil); err != nil { + return utils.NewToolResultError(fmt.Sprintf("failed to create discussion: %v", err)), nil, nil + } + + // Build response + response := map[string]interface{}{ + "id": fmt.Sprint(mutation.CreateDiscussion.Discussion.ID), + "number": int(mutation.CreateDiscussion.Discussion.Number), + "url": string(mutation.CreateDiscussion.Discussion.URL), + } + + out, err := json.Marshal(response) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal discussion: %w", err) + } + + return utils.NewToolResultText(string(out)), nil, nil + }, + ) +} + func ListDiscussionCategories(t translations.TranslationHelperFunc) inventory.ServerTool { return NewTool( ToolsetMetadataDiscussions, diff --git a/pkg/github/discussions_test.go b/pkg/github/discussions_test.go index 0ec998280..1c2b8dbb4 100644 --- a/pkg/github/discussions_test.go +++ b/pkg/github/discussions_test.go @@ -819,3 +819,167 @@ func Test_ListDiscussionCategories(t *testing.T) { }) } } + +func Test_CreateDiscussion(t *testing.T) { + t.Parallel() + + toolDef := CreateDiscussion(translations.NullTranslationHelper) + tool := toolDef.Tool + require.NoError(t, toolsnaps.Test(tool.Name, tool)) + + assert.Equal(t, "create_discussion", tool.Name) + assert.NotEmpty(t, tool.Description) + assert.Contains(t, tool.Description, "Create") + + // Verify tool schema with type assertion + schema, ok := tool.InputSchema.(*jsonschema.Schema) + require.True(t, ok, "InputSchema should be *jsonschema.Schema") + assert.Equal(t, "object", schema.Type) + assert.Contains(t, schema.Properties, "owner") + assert.Contains(t, schema.Properties, "repo") + assert.Contains(t, schema.Properties, "categoryId") + assert.Contains(t, schema.Properties, "title") + assert.Contains(t, schema.Properties, "body") + assert.ElementsMatch(t, schema.Required, []string{"owner", "categoryId", "title", "body"}) + + // Query for getting repository ID + qGetRepoID := struct { + Repository struct { + ID githubv4.ID + } `graphql:"repository(owner: $owner, name: $repo)"` + }{} + + // Mutation for creating discussion + qCreateDiscussion := struct { + CreateDiscussion struct { + Discussion struct { + ID githubv4.ID + Number githubv4.Int + URL githubv4.String + } + } `graphql:"createDiscussion(input: $input)"` + }{} + + tests := []struct { + name string + reqParams map[string]any + repoVars map[string]any + repoResponse githubv4mock.GQLResponse + mutInput githubv4.CreateDiscussionInput + mutResponse githubv4mock.GQLResponse + expectError bool + expectedID string + expectedNum int + expectedURL string + }{ + { + name: "successful discussion creation", + reqParams: map[string]any{ + "owner": "test-owner", + "repo": "test-repo", + "categoryId": "cat-123", + "title": "Test Discussion", + "body": "This is the body of the test discussion", + }, + repoVars: map[string]any{ + "owner": githubv4.String("test-owner"), + "repo": githubv4.String("test-repo"), + }, + repoResponse: githubv4mock.DataResponse(map[string]any{ + "repository": map[string]any{ + "id": "repo-id-123", + }, + }), + mutInput: githubv4.CreateDiscussionInput{ + RepositoryID: githubv4.ID("repo-id-123"), + CategoryID: githubv4.ID("cat-123"), + Title: githubv4.String("Test Discussion"), + Body: githubv4.String("This is the body of the test discussion"), + }, + mutResponse: githubv4mock.DataResponse(map[string]any{ + "createDiscussion": map[string]any{ + "discussion": map[string]any{ + "id": "disc-123", + "number": 42, + "url": "https://github.com/test-owner/test-repo/discussions/42", + }, + }, + }), + expectError: false, + expectedID: "disc-123", + expectedNum: 42, + expectedURL: "https://github.com/test-owner/test-repo/discussions/42", + }, + { + name: "org level discussion (no repo specified)", + reqParams: map[string]any{ + "owner": "test-org", + "categoryId": "cat-456", + "title": "Org Discussion", + "body": "An org-level discussion body", + }, + repoVars: map[string]any{ + "owner": githubv4.String("test-org"), + "repo": githubv4.String(".github"), + }, + repoResponse: githubv4mock.DataResponse(map[string]any{ + "repository": map[string]any{ + "id": "org-repo-id", + }, + }), + mutInput: githubv4.CreateDiscussionInput{ + RepositoryID: githubv4.ID("org-repo-id"), + CategoryID: githubv4.ID("cat-456"), + Title: githubv4.String("Org Discussion"), + Body: githubv4.String("An org-level discussion body"), + }, + mutResponse: githubv4mock.DataResponse(map[string]any{ + "createDiscussion": map[string]any{ + "discussion": map[string]any{ + "id": "org-disc-1", + "number": 1, + "url": "https://github.com/test-org/.github/discussions/1", + }, + }, + }), + expectError: false, + expectedID: "org-disc-1", + expectedNum: 1, + expectedURL: "https://github.com/test-org/.github/discussions/1", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Create matchers for the sequence of GraphQL calls + repoMatcher := githubv4mock.NewQueryMatcher(qGetRepoID, tc.repoVars, tc.repoResponse) + mutMatcher := githubv4mock.NewMutationMatcher(qCreateDiscussion, tc.mutInput, nil, tc.mutResponse) + httpClient := githubv4mock.NewMockedHTTPClient(repoMatcher, mutMatcher) + gqlClient := githubv4.NewClient(httpClient) + + deps := BaseDeps{GQLClient: gqlClient} + handler := toolDef.Handler(deps) + + req := createMCPRequest(tc.reqParams) + res, err := handler(ContextWithDeps(context.Background(), deps), &req) + + if tc.expectError { + require.True(t, res.IsError) + return + } + require.NoError(t, err) + + text := getTextResult(t, res).Text + + var response struct { + ID string `json:"id"` + Number int `json:"number"` + URL string `json:"url"` + } + require.NoError(t, json.Unmarshal([]byte(text), &response)) + assert.Equal(t, tc.expectedID, response.ID) + assert.Equal(t, tc.expectedNum, response.Number) + assert.Equal(t, tc.expectedURL, response.URL) + }) + } +}