From d8a485e6f60f0b4649842ce8840b94b4612730bd Mon Sep 17 00:00:00 2001 From: Sam Morrow Date: Sat, 13 Dec 2025 18:40:23 +0100 Subject: [PATCH] refactor(discussions): migrate to NewTool pattern Co-authored-by: Adam Holt --- pkg/github/discussions.go | 659 +++++++++++++++++---------------- pkg/github/discussions_test.go | 70 ++-- pkg/github/tools.go | 8 +- 3 files changed, 380 insertions(+), 357 deletions(-) diff --git a/pkg/github/discussions.go b/pkg/github/discussions.go index 8a5019701..94f7f6f1b 100644 --- a/pkg/github/discussions.go +++ b/pkg/github/discussions.go @@ -5,6 +5,7 @@ import ( "encoding/json" "fmt" + "github.com/github/github-mcp-server/pkg/toolsets" "github.com/github/github-mcp-server/pkg/translations" "github.com/github/github-mcp-server/pkg/utils" "github.com/go-viper/mapstructure/v2" @@ -121,8 +122,9 @@ func getQueryType(useOrdering bool, categoryID *githubv4.ID) any { return &BasicNoOrder{} } -func ListDiscussions(getGQLClient GetGQLClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { - return mcp.Tool{ +func ListDiscussions(t translations.TranslationHelperFunc) toolsets.ServerTool { + return NewTool( + mcp.Tool{ Name: "list_discussions", Description: t("TOOL_LIST_DISCUSSIONS_DESCRIPTION", "List discussions for a repository or organisation."), Annotations: &mcp.ToolAnnotations{ @@ -158,120 +160,124 @@ func ListDiscussions(getGQLClient GetGQLClientFn, t translations.TranslationHelp Required: []string{"owner"}, }), }, - func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - owner, err := RequiredParam[string](args, "owner") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - repo, err := OptionalParam[string](args, "repo") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - // when not provided, default to the .github repository - // this will query discussions at the organisation level - if repo == "" { - repo = ".github" - } + func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { + return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + owner, err := RequiredParam[string](args, "owner") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + repo, err := OptionalParam[string](args, "repo") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + // when not provided, default to the .github repository + // this will query discussions at the organisation level + if repo == "" { + repo = ".github" + } - category, err := OptionalParam[string](args, "category") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + category, err := OptionalParam[string](args, "category") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - orderBy, err := OptionalParam[string](args, "orderBy") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + orderBy, err := OptionalParam[string](args, "orderBy") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - direction, err := OptionalParam[string](args, "direction") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + direction, err := OptionalParam[string](args, "direction") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - // Get pagination parameters and convert to GraphQL format - pagination, err := OptionalCursorPaginationParams(args) - if err != nil { - return nil, nil, err - } - paginationParams, err := pagination.ToGraphQLParams() - if err != nil { - return nil, nil, err - } + // Get pagination parameters and convert to GraphQL format + pagination, err := OptionalCursorPaginationParams(args) + if err != nil { + return nil, nil, err + } + paginationParams, err := pagination.ToGraphQLParams() + if err != nil { + return nil, nil, err + } - client, err := getGQLClient(ctx) - if err != nil { - return utils.NewToolResultError(fmt.Sprintf("failed to get GitHub GQL client: %v", err)), nil, nil - } + client, err := deps.GetGQLClient(ctx) + if err != nil { + return utils.NewToolResultError(fmt.Sprintf("failed to get GitHub GQL client: %v", err)), nil, nil + } - var categoryID *githubv4.ID - if category != "" { - id := githubv4.ID(category) - categoryID = &id - } + var categoryID *githubv4.ID + if category != "" { + id := githubv4.ID(category) + categoryID = &id + } - vars := map[string]interface{}{ - "owner": githubv4.String(owner), - "repo": githubv4.String(repo), - "first": githubv4.Int(*paginationParams.First), - } - if paginationParams.After != nil { - vars["after"] = githubv4.String(*paginationParams.After) - } else { - vars["after"] = (*githubv4.String)(nil) - } + vars := map[string]interface{}{ + "owner": githubv4.String(owner), + "repo": githubv4.String(repo), + "first": githubv4.Int(*paginationParams.First), + } + if paginationParams.After != nil { + vars["after"] = githubv4.String(*paginationParams.After) + } else { + vars["after"] = (*githubv4.String)(nil) + } - // this is an extra check in case the tool description is misinterpreted, because - // we shouldn't use ordering unless both a 'field' and 'direction' are provided - useOrdering := orderBy != "" && direction != "" - if useOrdering { - vars["orderByField"] = githubv4.DiscussionOrderField(orderBy) - vars["orderByDirection"] = githubv4.OrderDirection(direction) - } + // this is an extra check in case the tool description is misinterpreted, because + // we shouldn't use ordering unless both a 'field' and 'direction' are provided + useOrdering := orderBy != "" && direction != "" + if useOrdering { + vars["orderByField"] = githubv4.DiscussionOrderField(orderBy) + vars["orderByDirection"] = githubv4.OrderDirection(direction) + } - if categoryID != nil { - vars["categoryId"] = *categoryID - } + if categoryID != nil { + vars["categoryId"] = *categoryID + } - discussionQuery := getQueryType(useOrdering, categoryID) - if err := client.Query(ctx, discussionQuery, vars); err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + discussionQuery := getQueryType(useOrdering, categoryID) + if err := client.Query(ctx, discussionQuery, vars); err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - // Extract and convert all discussion nodes using the common interface - var discussions []*github.Discussion - var pageInfo PageInfoFragment - var totalCount githubv4.Int - if queryResult, ok := discussionQuery.(DiscussionQueryResult); ok { - fragment := queryResult.GetDiscussionFragment() - for _, node := range fragment.Nodes { - discussions = append(discussions, fragmentToDiscussion(node)) - } - pageInfo = fragment.PageInfo - totalCount = fragment.TotalCount - } + // Extract and convert all discussion nodes using the common interface + var discussions []*github.Discussion + var pageInfo PageInfoFragment + var totalCount githubv4.Int + if queryResult, ok := discussionQuery.(DiscussionQueryResult); ok { + fragment := queryResult.GetDiscussionFragment() + for _, node := range fragment.Nodes { + discussions = append(discussions, fragmentToDiscussion(node)) + } + pageInfo = fragment.PageInfo + totalCount = fragment.TotalCount + } - // Create response with pagination info - response := map[string]interface{}{ - "discussions": discussions, - "pageInfo": map[string]interface{}{ - "hasNextPage": pageInfo.HasNextPage, - "hasPreviousPage": pageInfo.HasPreviousPage, - "startCursor": string(pageInfo.StartCursor), - "endCursor": string(pageInfo.EndCursor), - }, - "totalCount": totalCount, - } + // Create response with pagination info + response := map[string]interface{}{ + "discussions": discussions, + "pageInfo": map[string]interface{}{ + "hasNextPage": pageInfo.HasNextPage, + "hasPreviousPage": pageInfo.HasPreviousPage, + "startCursor": string(pageInfo.StartCursor), + "endCursor": string(pageInfo.EndCursor), + }, + "totalCount": totalCount, + } - out, err := json.Marshal(response) - if err != nil { - return nil, nil, fmt.Errorf("failed to marshal discussions: %w", err) + out, err := json.Marshal(response) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal discussions: %w", err) + } + return utils.NewToolResultText(string(out)), nil, nil } - return utils.NewToolResultText(string(out)), nil, nil - } + }, + ) } -func GetDiscussion(getGQLClient GetGQLClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { - return mcp.Tool{ +func GetDiscussion(t translations.TranslationHelperFunc) toolsets.ServerTool { + return NewTool( + mcp.Tool{ Name: "get_discussion", Description: t("TOOL_GET_DISCUSSION_DESCRIPTION", "Get a specific discussion by ID"), Annotations: &mcp.ToolAnnotations{ @@ -297,81 +303,85 @@ func GetDiscussion(getGQLClient GetGQLClientFn, t translations.TranslationHelper Required: []string{"owner", "repo", "discussionNumber"}, }, }, - func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - // Decode params - var params struct { - Owner string - Repo string - DiscussionNumber int32 - } - if err := mapstructure.Decode(args, ¶ms); err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - client, err := getGQLClient(ctx) - if err != nil { - return utils.NewToolResultError(fmt.Sprintf("failed to get GitHub GQL client: %v", err)), nil, nil - } + func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { + return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + // Decode params + var params struct { + Owner string + Repo string + DiscussionNumber int32 + } + if err := mapstructure.Decode(args, ¶ms); err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + client, err := deps.GetGQLClient(ctx) + if err != nil { + return utils.NewToolResultError(fmt.Sprintf("failed to get GitHub GQL client: %v", err)), nil, nil + } - var q struct { - Repository struct { - Discussion struct { - Number githubv4.Int - Title githubv4.String - Body githubv4.String - CreatedAt githubv4.DateTime - Closed githubv4.Boolean - IsAnswered githubv4.Boolean - AnswerChosenAt *githubv4.DateTime - URL githubv4.String `graphql:"url"` - Category struct { - Name githubv4.String - } `graphql:"category"` - } `graphql:"discussion(number: $discussionNumber)"` - } `graphql:"repository(owner: $owner, name: $repo)"` - } - vars := map[string]interface{}{ - "owner": githubv4.String(params.Owner), - "repo": githubv4.String(params.Repo), - "discussionNumber": githubv4.Int(params.DiscussionNumber), - } - if err := client.Query(ctx, &q, vars); err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - d := q.Repository.Discussion - - // Build response as map to include fields not present in go-github's Discussion struct. - // The go-github library's Discussion type lacks isAnswered and answerChosenAt fields, - // so we use map[string]interface{} for the response (consistent with other functions - // like ListDiscussions and GetDiscussionComments). - response := map[string]interface{}{ - "number": int(d.Number), - "title": string(d.Title), - "body": string(d.Body), - "url": string(d.URL), - "closed": bool(d.Closed), - "isAnswered": bool(d.IsAnswered), - "createdAt": d.CreatedAt.Time, - "category": map[string]interface{}{ - "name": string(d.Category.Name), - }, - } + var q struct { + Repository struct { + Discussion struct { + Number githubv4.Int + Title githubv4.String + Body githubv4.String + CreatedAt githubv4.DateTime + Closed githubv4.Boolean + IsAnswered githubv4.Boolean + AnswerChosenAt *githubv4.DateTime + URL githubv4.String `graphql:"url"` + Category struct { + Name githubv4.String + } `graphql:"category"` + } `graphql:"discussion(number: $discussionNumber)"` + } `graphql:"repository(owner: $owner, name: $repo)"` + } + vars := map[string]interface{}{ + "owner": githubv4.String(params.Owner), + "repo": githubv4.String(params.Repo), + "discussionNumber": githubv4.Int(params.DiscussionNumber), + } + if err := client.Query(ctx, &q, vars); err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + d := q.Repository.Discussion + + // Build response as map to include fields not present in go-github's Discussion struct. + // The go-github library's Discussion type lacks isAnswered and answerChosenAt fields, + // so we use map[string]interface{} for the response (consistent with other functions + // like ListDiscussions and GetDiscussionComments). + response := map[string]interface{}{ + "number": int(d.Number), + "title": string(d.Title), + "body": string(d.Body), + "url": string(d.URL), + "closed": bool(d.Closed), + "isAnswered": bool(d.IsAnswered), + "createdAt": d.CreatedAt.Time, + "category": map[string]interface{}{ + "name": string(d.Category.Name), + }, + } - // Add optional timestamp fields if present - if d.AnswerChosenAt != nil { - response["answerChosenAt"] = d.AnswerChosenAt.Time - } + // Add optional timestamp fields if present + if d.AnswerChosenAt != nil { + response["answerChosenAt"] = d.AnswerChosenAt.Time + } - out, err := json.Marshal(response) - if err != nil { - return nil, nil, fmt.Errorf("failed to marshal discussion: %w", err) - } + out, err := json.Marshal(response) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal discussion: %w", err) + } - return utils.NewToolResultText(string(out)), nil, nil - } + return utils.NewToolResultText(string(out)), nil, nil + } + }, + ) } -func GetDiscussionComments(getGQLClient GetGQLClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { - return mcp.Tool{ +func GetDiscussionComments(t translations.TranslationHelperFunc) toolsets.ServerTool { + return NewTool( + mcp.Tool{ Name: "get_discussion_comments", Description: t("TOOL_GET_DISCUSSION_COMMENTS_DESCRIPTION", "Get comments from a discussion"), Annotations: &mcp.ToolAnnotations{ @@ -397,104 +407,108 @@ func GetDiscussionComments(getGQLClient GetGQLClientFn, t translations.Translati Required: []string{"owner", "repo", "discussionNumber"}, }), }, - func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - // Decode params - var params struct { - Owner string - Repo string - DiscussionNumber int32 - } - if err := mapstructure.Decode(args, ¶ms); err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { + return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + // Decode params + var params struct { + Owner string + Repo string + DiscussionNumber int32 + } + if err := mapstructure.Decode(args, ¶ms); err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - // Get pagination parameters and convert to GraphQL format - pagination, err := OptionalCursorPaginationParams(args) - if err != nil { - return nil, nil, err - } + // Get pagination parameters and convert to GraphQL format + pagination, err := OptionalCursorPaginationParams(args) + if err != nil { + return nil, nil, err + } - // Check if pagination parameters were explicitly provided - _, perPageProvided := args["perPage"] - paginationExplicit := perPageProvided + // Check if pagination parameters were explicitly provided + _, perPageProvided := args["perPage"] + paginationExplicit := perPageProvided - paginationParams, err := pagination.ToGraphQLParams() - if err != nil { - return nil, nil, err - } + paginationParams, err := pagination.ToGraphQLParams() + if err != nil { + return nil, nil, err + } - // Use default of 30 if pagination was not explicitly provided - if !paginationExplicit { - defaultFirst := int32(DefaultGraphQLPageSize) - paginationParams.First = &defaultFirst - } + // Use default of 30 if pagination was not explicitly provided + if !paginationExplicit { + defaultFirst := int32(DefaultGraphQLPageSize) + paginationParams.First = &defaultFirst + } - client, err := getGQLClient(ctx) - if err != nil { - return utils.NewToolResultError(fmt.Sprintf("failed to get GitHub GQL client: %v", err)), nil, nil - } + client, err := deps.GetGQLClient(ctx) + if err != nil { + return utils.NewToolResultError(fmt.Sprintf("failed to get GitHub GQL client: %v", err)), nil, nil + } - var q struct { - Repository struct { - Discussion struct { - Comments struct { - Nodes []struct { - Body githubv4.String - } - PageInfo struct { - HasNextPage githubv4.Boolean - HasPreviousPage githubv4.Boolean - StartCursor githubv4.String - EndCursor githubv4.String - } - TotalCount int - } `graphql:"comments(first: $first, after: $after)"` - } `graphql:"discussion(number: $discussionNumber)"` - } `graphql:"repository(owner: $owner, name: $repo)"` - } - vars := map[string]interface{}{ - "owner": githubv4.String(params.Owner), - "repo": githubv4.String(params.Repo), - "discussionNumber": githubv4.Int(params.DiscussionNumber), - "first": githubv4.Int(*paginationParams.First), - } - if paginationParams.After != nil { - vars["after"] = githubv4.String(*paginationParams.After) - } else { - vars["after"] = (*githubv4.String)(nil) - } - if err := client.Query(ctx, &q, vars); err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + var q struct { + Repository struct { + Discussion struct { + Comments struct { + Nodes []struct { + Body githubv4.String + } + PageInfo struct { + HasNextPage githubv4.Boolean + HasPreviousPage githubv4.Boolean + StartCursor githubv4.String + EndCursor githubv4.String + } + TotalCount int + } `graphql:"comments(first: $first, after: $after)"` + } `graphql:"discussion(number: $discussionNumber)"` + } `graphql:"repository(owner: $owner, name: $repo)"` + } + vars := map[string]interface{}{ + "owner": githubv4.String(params.Owner), + "repo": githubv4.String(params.Repo), + "discussionNumber": githubv4.Int(params.DiscussionNumber), + "first": githubv4.Int(*paginationParams.First), + } + if paginationParams.After != nil { + vars["after"] = githubv4.String(*paginationParams.After) + } else { + vars["after"] = (*githubv4.String)(nil) + } + if err := client.Query(ctx, &q, vars); err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - var comments []*github.IssueComment - for _, c := range q.Repository.Discussion.Comments.Nodes { - comments = append(comments, &github.IssueComment{Body: github.Ptr(string(c.Body))}) - } + var comments []*github.IssueComment + for _, c := range q.Repository.Discussion.Comments.Nodes { + comments = append(comments, &github.IssueComment{Body: github.Ptr(string(c.Body))}) + } - // Create response with pagination info - response := map[string]interface{}{ - "comments": comments, - "pageInfo": map[string]interface{}{ - "hasNextPage": q.Repository.Discussion.Comments.PageInfo.HasNextPage, - "hasPreviousPage": q.Repository.Discussion.Comments.PageInfo.HasPreviousPage, - "startCursor": string(q.Repository.Discussion.Comments.PageInfo.StartCursor), - "endCursor": string(q.Repository.Discussion.Comments.PageInfo.EndCursor), - }, - "totalCount": q.Repository.Discussion.Comments.TotalCount, - } + // Create response with pagination info + response := map[string]interface{}{ + "comments": comments, + "pageInfo": map[string]interface{}{ + "hasNextPage": q.Repository.Discussion.Comments.PageInfo.HasNextPage, + "hasPreviousPage": q.Repository.Discussion.Comments.PageInfo.HasPreviousPage, + "startCursor": string(q.Repository.Discussion.Comments.PageInfo.StartCursor), + "endCursor": string(q.Repository.Discussion.Comments.PageInfo.EndCursor), + }, + "totalCount": q.Repository.Discussion.Comments.TotalCount, + } - out, err := json.Marshal(response) - if err != nil { - return nil, nil, fmt.Errorf("failed to marshal comments: %w", err) - } + out, err := json.Marshal(response) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal comments: %w", err) + } - return utils.NewToolResultText(string(out)), nil, nil - } + return utils.NewToolResultText(string(out)), nil, nil + } + }, + ) } -func ListDiscussionCategories(getGQLClient GetGQLClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { - return mcp.Tool{ +func ListDiscussionCategories(t translations.TranslationHelperFunc) toolsets.ServerTool { + return NewTool( + mcp.Tool{ Name: "list_discussion_categories", Description: t("TOOL_LIST_DISCUSSION_CATEGORIES_DESCRIPTION", "List discussion categories with their id and name, for a repository or organisation."), Annotations: &mcp.ToolAnnotations{ @@ -516,76 +530,79 @@ func ListDiscussionCategories(getGQLClient GetGQLClientFn, t translations.Transl Required: []string{"owner"}, }, }, - func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - owner, err := RequiredParam[string](args, "owner") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - repo, err := OptionalParam[string](args, "repo") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - // when not provided, default to the .github repository - // this will query discussion categories at the organisation level - if repo == "" { - repo = ".github" - } + func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { + return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + owner, err := RequiredParam[string](args, "owner") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + repo, err := OptionalParam[string](args, "repo") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + // when not provided, default to the .github repository + // this will query discussion categories at the organisation level + if repo == "" { + repo = ".github" + } - client, err := getGQLClient(ctx) - if err != nil { - return utils.NewToolResultError(fmt.Sprintf("failed to get GitHub GQL client: %v", err)), nil, nil - } + client, err := deps.GetGQLClient(ctx) + if err != nil { + return utils.NewToolResultError(fmt.Sprintf("failed to get GitHub GQL client: %v", err)), nil, nil + } - var q struct { - Repository struct { - DiscussionCategories struct { - Nodes []struct { - ID githubv4.ID - Name githubv4.String - } - PageInfo struct { - HasNextPage githubv4.Boolean - HasPreviousPage githubv4.Boolean - StartCursor githubv4.String - EndCursor githubv4.String - } - TotalCount int - } `graphql:"discussionCategories(first: $first)"` - } `graphql:"repository(owner: $owner, name: $repo)"` - } - vars := map[string]interface{}{ - "owner": githubv4.String(owner), - "repo": githubv4.String(repo), - "first": githubv4.Int(25), - } - if err := client.Query(ctx, &q, vars); err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + var q struct { + Repository struct { + DiscussionCategories struct { + Nodes []struct { + ID githubv4.ID + Name githubv4.String + } + PageInfo struct { + HasNextPage githubv4.Boolean + HasPreviousPage githubv4.Boolean + StartCursor githubv4.String + EndCursor githubv4.String + } + TotalCount int + } `graphql:"discussionCategories(first: $first)"` + } `graphql:"repository(owner: $owner, name: $repo)"` + } + vars := map[string]interface{}{ + "owner": githubv4.String(owner), + "repo": githubv4.String(repo), + "first": githubv4.Int(25), + } + if err := client.Query(ctx, &q, vars); err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - var categories []map[string]string - for _, c := range q.Repository.DiscussionCategories.Nodes { - categories = append(categories, map[string]string{ - "id": fmt.Sprint(c.ID), - "name": string(c.Name), - }) - } + var categories []map[string]string + for _, c := range q.Repository.DiscussionCategories.Nodes { + categories = append(categories, map[string]string{ + "id": fmt.Sprint(c.ID), + "name": string(c.Name), + }) + } - // Create response with pagination info - response := map[string]interface{}{ - "categories": categories, - "pageInfo": map[string]interface{}{ - "hasNextPage": q.Repository.DiscussionCategories.PageInfo.HasNextPage, - "hasPreviousPage": q.Repository.DiscussionCategories.PageInfo.HasPreviousPage, - "startCursor": string(q.Repository.DiscussionCategories.PageInfo.StartCursor), - "endCursor": string(q.Repository.DiscussionCategories.PageInfo.EndCursor), - }, - "totalCount": q.Repository.DiscussionCategories.TotalCount, - } + // Create response with pagination info + response := map[string]interface{}{ + "categories": categories, + "pageInfo": map[string]interface{}{ + "hasNextPage": q.Repository.DiscussionCategories.PageInfo.HasNextPage, + "hasPreviousPage": q.Repository.DiscussionCategories.PageInfo.HasPreviousPage, + "startCursor": string(q.Repository.DiscussionCategories.PageInfo.StartCursor), + "endCursor": string(q.Repository.DiscussionCategories.PageInfo.EndCursor), + }, + "totalCount": q.Repository.DiscussionCategories.TotalCount, + } - out, err := json.Marshal(response) - if err != nil { - return nil, nil, fmt.Errorf("failed to marshal discussion categories: %w", err) + out, err := json.Marshal(response) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal discussion categories: %w", err) + } + return utils.NewToolResultText(string(out)), nil, nil } - return utils.NewToolResultText(string(out)), nil, nil - } + }, + ) } diff --git a/pkg/github/discussions_test.go b/pkg/github/discussions_test.go index 1a73d523e..758c82200 100644 --- a/pkg/github/discussions_test.go +++ b/pkg/github/discussions_test.go @@ -213,13 +213,13 @@ var ( ) func Test_ListDiscussions(t *testing.T) { - mockClient := githubv4.NewClient(nil) - toolDef, _ := ListDiscussions(stubGetGQLClientFn(mockClient), translations.NullTranslationHelper) - require.NoError(t, toolsnaps.Test(toolDef.Name, toolDef)) + toolDef := ListDiscussions(translations.NullTranslationHelper) + tool := toolDef.Tool + require.NoError(t, toolsnaps.Test(tool.Name, tool)) - assert.Equal(t, "list_discussions", toolDef.Name) - assert.NotEmpty(t, toolDef.Description) - schema, ok := toolDef.InputSchema.(*jsonschema.Schema) + assert.Equal(t, "list_discussions", tool.Name) + assert.NotEmpty(t, tool.Description) + schema, ok := tool.InputSchema.(*jsonschema.Schema) require.True(t, ok, "InputSchema should be *jsonschema.Schema") assert.Contains(t, schema.Properties, "owner") assert.Contains(t, schema.Properties, "repo") @@ -447,10 +447,11 @@ func Test_ListDiscussions(t *testing.T) { } gqlClient := githubv4.NewClient(httpClient) - _, handler := ListDiscussions(stubGetGQLClientFn(gqlClient), translations.NullTranslationHelper) + deps := ToolDependencies{GetGQLClient: stubGetGQLClientFn(gqlClient)} + handler := toolDef.Handler(deps) req := createMCPRequest(tc.reqParams) - res, _, err := handler(context.Background(), &req, tc.reqParams) + res, err := handler(context.Background(), &req) text := getTextResult(t, res).Text if tc.expectError { @@ -494,12 +495,13 @@ func Test_ListDiscussions(t *testing.T) { func Test_GetDiscussion(t *testing.T) { // Verify tool definition and schema - toolDef, _ := GetDiscussion(nil, translations.NullTranslationHelper) - require.NoError(t, toolsnaps.Test(toolDef.Name, toolDef)) + toolDef := GetDiscussion(translations.NullTranslationHelper) + tool := toolDef.Tool + require.NoError(t, toolsnaps.Test(tool.Name, tool)) - assert.Equal(t, "get_discussion", toolDef.Name) - assert.NotEmpty(t, toolDef.Description) - schema, ok := toolDef.InputSchema.(*jsonschema.Schema) + assert.Equal(t, "get_discussion", tool.Name) + assert.NotEmpty(t, tool.Description) + schema, ok := tool.InputSchema.(*jsonschema.Schema) require.True(t, ok, "InputSchema should be *jsonschema.Schema") assert.Contains(t, schema.Properties, "owner") assert.Contains(t, schema.Properties, "repo") @@ -557,11 +559,12 @@ func Test_GetDiscussion(t *testing.T) { matcher := githubv4mock.NewQueryMatcher(qGetDiscussion, vars, tc.response) httpClient := githubv4mock.NewMockedHTTPClient(matcher) gqlClient := githubv4.NewClient(httpClient) - _, handler := GetDiscussion(stubGetGQLClientFn(gqlClient), translations.NullTranslationHelper) + deps := ToolDependencies{GetGQLClient: stubGetGQLClientFn(gqlClient)} + handler := toolDef.Handler(deps) reqParams := map[string]interface{}{"owner": "owner", "repo": "repo", "discussionNumber": int32(1)} req := createMCPRequest(reqParams) - res, _, err := handler(context.Background(), &req, reqParams) + res, err := handler(context.Background(), &req) text := getTextResult(t, res).Text if tc.expectError { @@ -589,12 +592,13 @@ func Test_GetDiscussion(t *testing.T) { func Test_GetDiscussionComments(t *testing.T) { // Verify tool definition and schema - toolDef, _ := GetDiscussionComments(nil, translations.NullTranslationHelper) - require.NoError(t, toolsnaps.Test(toolDef.Name, toolDef)) + toolDef := GetDiscussionComments(translations.NullTranslationHelper) + tool := toolDef.Tool + require.NoError(t, toolsnaps.Test(tool.Name, tool)) - assert.Equal(t, "get_discussion_comments", toolDef.Name) - assert.NotEmpty(t, toolDef.Description) - schema, ok := toolDef.InputSchema.(*jsonschema.Schema) + assert.Equal(t, "get_discussion_comments", tool.Name) + assert.NotEmpty(t, tool.Description) + schema, ok := tool.InputSchema.(*jsonschema.Schema) require.True(t, ok, "InputSchema should be *jsonschema.Schema") assert.Contains(t, schema.Properties, "owner") assert.Contains(t, schema.Properties, "repo") @@ -635,7 +639,8 @@ func Test_GetDiscussionComments(t *testing.T) { matcher := githubv4mock.NewQueryMatcher(qGetComments, vars, mockResponse) httpClient := githubv4mock.NewMockedHTTPClient(matcher) gqlClient := githubv4.NewClient(httpClient) - _, handler := GetDiscussionComments(stubGetGQLClientFn(gqlClient), translations.NullTranslationHelper) + deps := ToolDependencies{GetGQLClient: stubGetGQLClientFn(gqlClient)} + handler := toolDef.Handler(deps) reqParams := map[string]interface{}{ "owner": "owner", @@ -644,7 +649,7 @@ func Test_GetDiscussionComments(t *testing.T) { } request := createMCPRequest(reqParams) - result, _, err := handler(context.Background(), &request, reqParams) + result, err := handler(context.Background(), &request) require.NoError(t, err) textContent := getTextResult(t, result) @@ -671,14 +676,14 @@ func Test_GetDiscussionComments(t *testing.T) { } func Test_ListDiscussionCategories(t *testing.T) { - mockClient := githubv4.NewClient(nil) - toolDef, _ := ListDiscussionCategories(stubGetGQLClientFn(mockClient), translations.NullTranslationHelper) - require.NoError(t, toolsnaps.Test(toolDef.Name, toolDef)) - - assert.Equal(t, "list_discussion_categories", toolDef.Name) - assert.NotEmpty(t, toolDef.Description) - assert.Contains(t, toolDef.Description, "or organisation") - schema, ok := toolDef.InputSchema.(*jsonschema.Schema) + toolDef := ListDiscussionCategories(translations.NullTranslationHelper) + tool := toolDef.Tool + require.NoError(t, toolsnaps.Test(tool.Name, tool)) + + assert.Equal(t, "list_discussion_categories", tool.Name) + assert.NotEmpty(t, tool.Description) + assert.Contains(t, tool.Description, "or organisation") + schema, ok := tool.InputSchema.(*jsonschema.Schema) require.True(t, ok, "InputSchema should be *jsonschema.Schema") assert.Contains(t, schema.Properties, "owner") assert.Contains(t, schema.Properties, "repo") @@ -786,10 +791,11 @@ func Test_ListDiscussionCategories(t *testing.T) { httpClient := githubv4mock.NewMockedHTTPClient(matcher) gqlClient := githubv4.NewClient(httpClient) - _, handler := ListDiscussionCategories(stubGetGQLClientFn(gqlClient), translations.NullTranslationHelper) + deps := ToolDependencies{GetGQLClient: stubGetGQLClientFn(gqlClient)} + handler := toolDef.Handler(deps) req := createMCPRequest(tc.reqParams) - res, _, err := handler(context.Background(), &req, tc.reqParams) + res, err := handler(context.Background(), &req) text := getTextResult(t, res).Text if tc.expectError { diff --git a/pkg/github/tools.go b/pkg/github/tools.go index 947c727c2..d4f473724 100644 --- a/pkg/github/tools.go +++ b/pkg/github/tools.go @@ -291,10 +291,10 @@ func DefaultToolsetGroup(readOnly bool, getClient GetClientFn, getGQLClient GetG discussions := toolsets.NewToolset(ToolsetMetadataDiscussions.ID, ToolsetMetadataDiscussions.Description). SetDependencies(deps). AddReadTools( - toolsets.NewServerToolLegacy(ListDiscussions(getGQLClient, t)), - toolsets.NewServerToolLegacy(GetDiscussion(getGQLClient, t)), - toolsets.NewServerToolLegacy(GetDiscussionComments(getGQLClient, t)), - toolsets.NewServerToolLegacy(ListDiscussionCategories(getGQLClient, t)), + ListDiscussions(t), + GetDiscussion(t), + GetDiscussionComments(t), + ListDiscussionCategories(t), ) actions := toolsets.NewToolset(ToolsetMetadataActions.ID, ToolsetMetadataActions.Description).