diff --git a/pkg/github/issues.go b/pkg/github/issues.go index ec83e4efa..142bdd421 100644 --- a/pkg/github/issues.go +++ b/pkg/github/issues.go @@ -12,6 +12,7 @@ import ( ghErrors "github.com/github/github-mcp-server/pkg/errors" "github.com/github/github-mcp-server/pkg/lockdown" "github.com/github/github-mcp-server/pkg/sanitize" + "github.com/github/github-mcp-server/pkg/toolsets" "github.com/github/github-mcp-server/pkg/translations" "github.com/github/github-mcp-server/pkg/utils" "github.com/go-viper/mapstructure/v2" @@ -229,7 +230,7 @@ func fragmentToIssue(fragment IssueFragment) *github.Issue { } // IssueRead creates a tool to get details of a specific issue in a GitHub repository. -func IssueRead(getClient GetClientFn, getGQLClient GetGQLClientFn, cache *lockdown.RepoAccessCache, t translations.TranslationHelperFunc, flags FeatureFlags) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { +func IssueRead(t translations.TranslationHelperFunc) toolsets.ServerTool { schema := &jsonschema.Schema{ Type: "object", Properties: map[string]*jsonschema.Schema{ @@ -261,7 +262,8 @@ Options are: } WithPagination(schema) - return mcp.Tool{ + return NewTool( + mcp.Tool{ Name: "issue_read", Description: t("TOOL_ISSUE_READ_DESCRIPTION", "Get information about a specific issue in a GitHub repository."), Annotations: &mcp.ToolAnnotations{ @@ -270,57 +272,59 @@ Options are: }, InputSchema: schema, }, - func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - method, err := RequiredParam[string](args, "method") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { + return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + method, err := RequiredParam[string](args, "method") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - owner, err := RequiredParam[string](args, "owner") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - repo, err := RequiredParam[string](args, "repo") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - issueNumber, err := RequiredInt(args, "issue_number") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + owner, err := RequiredParam[string](args, "owner") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + repo, err := RequiredParam[string](args, "repo") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + issueNumber, err := RequiredInt(args, "issue_number") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - pagination, err := OptionalPaginationParams(args) - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + pagination, err := OptionalPaginationParams(args) + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - client, err := getClient(ctx) - if err != nil { - return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil - } + client, err := deps.GetClient(ctx) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil + } - gqlClient, err := getGQLClient(ctx) - if err != nil { - return utils.NewToolResultErrorFromErr("failed to get GitHub graphql client", err), nil, nil - } + gqlClient, err := deps.GetGQLClient(ctx) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to get GitHub graphql client", err), nil, nil + } - switch method { - case "get": - result, err := GetIssue(ctx, client, cache, owner, repo, issueNumber, flags) - return result, nil, err - case "get_comments": - result, err := GetIssueComments(ctx, client, cache, owner, repo, issueNumber, pagination, flags) - return result, nil, err - case "get_sub_issues": - result, err := GetSubIssues(ctx, client, cache, owner, repo, issueNumber, pagination, flags) - return result, nil, err - case "get_labels": - result, err := GetIssueLabels(ctx, gqlClient, owner, repo, issueNumber) - return result, nil, err - default: - return utils.NewToolResultError(fmt.Sprintf("unknown method: %s", method)), nil, nil + switch method { + case "get": + result, err := GetIssue(ctx, client, deps.RepoAccessCache, owner, repo, issueNumber, deps.Flags) + return result, nil, err + case "get_comments": + result, err := GetIssueComments(ctx, client, deps.RepoAccessCache, owner, repo, issueNumber, pagination, deps.Flags) + return result, nil, err + case "get_sub_issues": + result, err := GetSubIssues(ctx, client, deps.RepoAccessCache, owner, repo, issueNumber, pagination, deps.Flags) + return result, nil, err + case "get_labels": + result, err := GetIssueLabels(ctx, gqlClient, owner, repo, issueNumber) + return result, nil, err + default: + return utils.NewToolResultError(fmt.Sprintf("unknown method: %s", method)), nil, nil + } } - } + }) } func GetIssue(ctx context.Context, client *github.Client, cache *lockdown.RepoAccessCache, owner string, repo string, issueNumber int, flags FeatureFlags) (*mcp.CallToolResult, error) { @@ -540,8 +544,9 @@ func GetIssueLabels(ctx context.Context, client *githubv4.Client, owner string, } // ListIssueTypes creates a tool to list defined issue types for an organization. This can be used to understand supported issue type values for creating or updating issues. -func ListIssueTypes(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { - return mcp.Tool{ +func ListIssueTypes(t translations.TranslationHelperFunc) toolsets.ServerTool { + return NewTool( + mcp.Tool{ Name: "list_issue_types", Description: t("TOOL_LIST_ISSUE_TYPES_FOR_ORG", "List supported issue types for repository owner (organization)."), Annotations: &mcp.ToolAnnotations{ @@ -559,42 +564,45 @@ func ListIssueTypes(getClient GetClientFn, t translations.TranslationHelperFunc) Required: []string{"owner"}, }, }, - func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - owner, err := RequiredParam[string](args, "owner") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { + return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + owner, err := RequiredParam[string](args, "owner") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - client, err := getClient(ctx) - if err != nil { - return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil - } - issueTypes, resp, err := client.Organizations.ListIssueTypes(ctx, owner) - if err != nil { - return utils.NewToolResultErrorFromErr("failed to list issue types", err), nil, nil - } - defer func() { _ = resp.Body.Close() }() + client, err := deps.GetClient(ctx) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil + } + issueTypes, resp, err := client.Organizations.ListIssueTypes(ctx, owner) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to list issue types", err), nil, nil + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to read response body", err), nil, nil + } + return utils.NewToolResultError(fmt.Sprintf("failed to list issue types: %s", string(body))), nil, nil + } - if resp.StatusCode != http.StatusOK { - body, err := io.ReadAll(resp.Body) + r, err := json.Marshal(issueTypes) if err != nil { - return utils.NewToolResultErrorFromErr("failed to read response body", err), nil, nil + return utils.NewToolResultErrorFromErr("failed to marshal issue types", err), nil, nil } - return utils.NewToolResultError(fmt.Sprintf("failed to list issue types: %s", string(body))), nil, nil - } - r, err := json.Marshal(issueTypes) - if err != nil { - return utils.NewToolResultErrorFromErr("failed to marshal issue types", err), nil, nil + return utils.NewToolResultText(string(r)), nil, nil } - - return utils.NewToolResultText(string(r)), nil, nil - } + }) } // AddIssueComment creates a tool to add a comment to an issue. -func AddIssueComment(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { - return mcp.Tool{ +func AddIssueComment(t translations.TranslationHelperFunc) toolsets.ServerTool { + return NewTool( + mcp.Tool{ Name: "add_issue_comment", Description: t("TOOL_ADD_ISSUE_COMMENT_DESCRIPTION", "Add a comment to a specific issue in a GitHub repository. Use this tool to add comments to pull requests as well (in this case pass pull request number as issue_number), but only if user is not asking specifically to add review comments."), Annotations: &mcp.ToolAnnotations{ @@ -624,58 +632,61 @@ func AddIssueComment(getClient GetClientFn, t translations.TranslationHelperFunc Required: []string{"owner", "repo", "issue_number", "body"}, }, }, - func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - owner, err := RequiredParam[string](args, "owner") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - repo, err := RequiredParam[string](args, "repo") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - issueNumber, err := RequiredInt(args, "issue_number") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - body, err := RequiredParam[string](args, "body") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { + return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + owner, err := RequiredParam[string](args, "owner") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + repo, err := RequiredParam[string](args, "repo") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + issueNumber, err := RequiredInt(args, "issue_number") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + body, err := RequiredParam[string](args, "body") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - comment := &github.IssueComment{ - Body: github.Ptr(body), - } + comment := &github.IssueComment{ + Body: github.Ptr(body), + } - client, err := getClient(ctx) - if err != nil { - return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil - } - createdComment, resp, err := client.Issues.CreateComment(ctx, owner, repo, issueNumber, comment) - if err != nil { - return utils.NewToolResultErrorFromErr("failed to create comment", err), nil, nil - } - defer func() { _ = resp.Body.Close() }() + client, err := deps.GetClient(ctx) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil + } + createdComment, resp, err := client.Issues.CreateComment(ctx, owner, repo, issueNumber, comment) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to create comment", err), nil, nil + } + defer func() { _ = resp.Body.Close() }() - if resp.StatusCode != http.StatusCreated { - body, err := io.ReadAll(resp.Body) + if resp.StatusCode != http.StatusCreated { + body, err := io.ReadAll(resp.Body) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to read response body", err), nil, nil + } + return utils.NewToolResultError(fmt.Sprintf("failed to create comment: %s", string(body))), nil, nil + } + + r, err := json.Marshal(createdComment) if err != nil { - return utils.NewToolResultErrorFromErr("failed to read response body", err), nil, nil + return utils.NewToolResultErrorFromErr("failed to marshal response", err), nil, nil } - return utils.NewToolResultError(fmt.Sprintf("failed to create comment: %s", string(body))), nil, nil - } - r, err := json.Marshal(createdComment) - if err != nil { - return utils.NewToolResultErrorFromErr("failed to marshal response", err), nil, nil + return utils.NewToolResultText(string(r)), nil, nil } - - return utils.NewToolResultText(string(r)), nil, nil - } + }) } // SubIssueWrite creates a tool to add a sub-issue to a parent issue. -func SubIssueWrite(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { - return mcp.Tool{ +func SubIssueWrite(t translations.TranslationHelperFunc) toolsets.ServerTool { + return NewTool( + mcp.Tool{ Name: "sub_issue_write", Description: t("TOOL_SUB_ISSUE_WRITE_DESCRIPTION", "Add a sub-issue to a parent issue in a GitHub repository."), Annotations: &mcp.ToolAnnotations{ @@ -726,62 +737,64 @@ Options are: Required: []string{"method", "owner", "repo", "issue_number", "sub_issue_id"}, }, }, - func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - method, err := RequiredParam[string](args, "method") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { + return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + method, err := RequiredParam[string](args, "method") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - owner, err := RequiredParam[string](args, "owner") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - repo, err := RequiredParam[string](args, "repo") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - issueNumber, err := RequiredInt(args, "issue_number") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - subIssueID, err := RequiredInt(args, "sub_issue_id") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - replaceParent, err := OptionalParam[bool](args, "replace_parent") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - afterID, err := OptionalIntParam(args, "after_id") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - beforeID, err := OptionalIntParam(args, "before_id") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + owner, err := RequiredParam[string](args, "owner") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + repo, err := RequiredParam[string](args, "repo") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + issueNumber, err := RequiredInt(args, "issue_number") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + subIssueID, err := RequiredInt(args, "sub_issue_id") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + replaceParent, err := OptionalParam[bool](args, "replace_parent") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + afterID, err := OptionalIntParam(args, "after_id") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + beforeID, err := OptionalIntParam(args, "before_id") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - client, err := getClient(ctx) - if err != nil { - return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil - } + client, err := deps.GetClient(ctx) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil + } - switch strings.ToLower(method) { - case "add": - result, err := AddSubIssue(ctx, client, owner, repo, issueNumber, subIssueID, replaceParent) - return result, nil, err - case "remove": - // Call the remove sub-issue function - result, err := RemoveSubIssue(ctx, client, owner, repo, issueNumber, subIssueID) - return result, nil, err - case "reprioritize": - // Call the reprioritize sub-issue function - result, err := ReprioritizeSubIssue(ctx, client, owner, repo, issueNumber, subIssueID, afterID, beforeID) - return result, nil, err - default: - return utils.NewToolResultError(fmt.Sprintf("unknown method: %s", method)), nil, nil + switch strings.ToLower(method) { + case "add": + result, err := AddSubIssue(ctx, client, owner, repo, issueNumber, subIssueID, replaceParent) + return result, nil, err + case "remove": + // Call the remove sub-issue function + result, err := RemoveSubIssue(ctx, client, owner, repo, issueNumber, subIssueID) + return result, nil, err + case "reprioritize": + // Call the reprioritize sub-issue function + result, err := ReprioritizeSubIssue(ctx, client, owner, repo, issueNumber, subIssueID, afterID, beforeID) + return result, nil, err + default: + return utils.NewToolResultError(fmt.Sprintf("unknown method: %s", method)), nil, nil + } } - } + }) } func AddSubIssue(ctx context.Context, client *github.Client, owner string, repo string, issueNumber int, subIssueID int, replaceParent bool) (*mcp.CallToolResult, error) { @@ -899,7 +912,7 @@ func ReprioritizeSubIssue(ctx context.Context, client *github.Client, owner stri } // SearchIssues creates a tool to search for issues. -func SearchIssues(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { +func SearchIssues(t translations.TranslationHelperFunc) toolsets.ServerTool { schema := &jsonschema.Schema{ Type: "object", Properties: map[string]*jsonschema.Schema{ @@ -942,7 +955,8 @@ func SearchIssues(getClient GetClientFn, t translations.TranslationHelperFunc) ( } WithPagination(schema) - return mcp.Tool{ + return NewTool( + mcp.Tool{ Name: "search_issues", Description: t("TOOL_SEARCH_ISSUES_DESCRIPTION", "Search for issues in GitHub repositories using issues search syntax already scoped to is:issue"), Annotations: &mcp.ToolAnnotations{ @@ -951,15 +965,18 @@ func SearchIssues(getClient GetClientFn, t translations.TranslationHelperFunc) ( }, InputSchema: schema, }, - func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - result, err := searchHandler(ctx, getClient, args, "issue", "failed to search issues") - return result, nil, err - } + func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { + return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + result, err := searchHandler(ctx, deps.GetClient, args, "issue", "failed to search issues") + return result, nil, err + } + }) } // IssueWrite creates a tool to create a new or update an existing issue in a GitHub repository. -func IssueWrite(getClient GetClientFn, getGQLClient GetGQLClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { - return mcp.Tool{ +func IssueWrite(t translations.TranslationHelperFunc) toolsets.ServerTool { + return NewTool( + mcp.Tool{ Name: "issue_write", Description: t("TOOL_ISSUE_WRITE_DESCRIPTION", "Create a new or update an existing issue in a GitHub repository."), Annotations: &mcp.ToolAnnotations{ @@ -1038,104 +1055,106 @@ Options are: Required: []string{"method", "owner", "repo"}, }, }, - func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - method, err := RequiredParam[string](args, "method") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - - owner, err := RequiredParam[string](args, "owner") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - repo, err := RequiredParam[string](args, "repo") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - title, err := OptionalParam[string](args, "title") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { + return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + method, err := RequiredParam[string](args, "method") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - // Optional parameters - body, err := OptionalParam[string](args, "body") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + owner, err := RequiredParam[string](args, "owner") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + repo, err := RequiredParam[string](args, "repo") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + title, err := OptionalParam[string](args, "title") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - // Get assignees - assignees, err := OptionalStringArrayParam(args, "assignees") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + // Optional parameters + body, err := OptionalParam[string](args, "body") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - // Get labels - labels, err := OptionalStringArrayParam(args, "labels") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + // Get assignees + assignees, err := OptionalStringArrayParam(args, "assignees") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - // Get optional milestone - milestone, err := OptionalIntParam(args, "milestone") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + // Get labels + labels, err := OptionalStringArrayParam(args, "labels") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - var milestoneNum int - if milestone != 0 { - milestoneNum = milestone - } + // Get optional milestone + milestone, err := OptionalIntParam(args, "milestone") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - // Get optional type - issueType, err := OptionalParam[string](args, "type") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + var milestoneNum int + if milestone != 0 { + milestoneNum = milestone + } - // Handle state, state_reason and duplicateOf parameters - state, err := OptionalParam[string](args, "state") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + // Get optional type + issueType, err := OptionalParam[string](args, "type") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - stateReason, err := OptionalParam[string](args, "state_reason") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + // Handle state, state_reason and duplicateOf parameters + state, err := OptionalParam[string](args, "state") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - duplicateOf, err := OptionalIntParam(args, "duplicate_of") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - if duplicateOf != 0 && stateReason != "duplicate" { - return utils.NewToolResultError("duplicate_of can only be used when state_reason is 'duplicate'"), nil, nil - } + stateReason, err := OptionalParam[string](args, "state_reason") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - client, err := getClient(ctx) - if err != nil { - return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil - } + duplicateOf, err := OptionalIntParam(args, "duplicate_of") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + if duplicateOf != 0 && stateReason != "duplicate" { + return utils.NewToolResultError("duplicate_of can only be used when state_reason is 'duplicate'"), nil, nil + } - gqlClient, err := getGQLClient(ctx) - if err != nil { - return utils.NewToolResultErrorFromErr("failed to get GraphQL client", err), nil, nil - } + client, err := deps.GetClient(ctx) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil + } - switch method { - case "create": - result, err := CreateIssue(ctx, client, owner, repo, title, body, assignees, labels, milestoneNum, issueType) - return result, nil, err - case "update": - issueNumber, err := RequiredInt(args, "issue_number") + gqlClient, err := deps.GetGQLClient(ctx) if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil + return utils.NewToolResultErrorFromErr("failed to get GraphQL client", err), nil, nil + } + + switch method { + case "create": + result, err := CreateIssue(ctx, client, owner, repo, title, body, assignees, labels, milestoneNum, issueType) + return result, nil, err + case "update": + issueNumber, err := RequiredInt(args, "issue_number") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + result, err := UpdateIssue(ctx, client, gqlClient, owner, repo, issueNumber, title, body, assignees, labels, milestoneNum, issueType, state, stateReason, duplicateOf) + return result, nil, err + default: + return utils.NewToolResultError("invalid method, must be either 'create' or 'update'"), nil, nil } - result, err := UpdateIssue(ctx, client, gqlClient, owner, repo, issueNumber, title, body, assignees, labels, milestoneNum, issueType, state, stateReason, duplicateOf) - return result, nil, err - default: - return utils.NewToolResultError("invalid method, must be either 'create' or 'update'"), nil, nil } - } + }) } func CreateIssue(ctx context.Context, client *github.Client, owner string, repo string, title string, body string, assignees []string, labels []string, milestoneNum int, issueType string) (*mcp.CallToolResult, error) { @@ -1313,7 +1332,7 @@ func UpdateIssue(ctx context.Context, client *github.Client, gqlClient *githubv4 } // ListIssues creates a tool to list and filter repository issues -func ListIssues(getGQLClient GetGQLClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { +func ListIssues(t translations.TranslationHelperFunc) toolsets.ServerTool { schema := &jsonschema.Schema{ Type: "object", Properties: map[string]*jsonschema.Schema{ @@ -1356,7 +1375,8 @@ func ListIssues(getGQLClient GetGQLClientFn, t translations.TranslationHelperFun } WithCursorPagination(schema) - return mcp.Tool{ + return NewTool( + mcp.Tool{ Name: "list_issues", Description: t("TOOL_LIST_ISSUES_DESCRIPTION", "List issues in a GitHub repository. For pagination, use the 'endCursor' from the previous response's 'pageInfo' in the 'after' parameter."), Annotations: &mcp.ToolAnnotations{ @@ -1365,186 +1385,188 @@ func ListIssues(getGQLClient GetGQLClientFn, t translations.TranslationHelperFun }, InputSchema: schema, }, - func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - owner, err := RequiredParam[string](args, "owner") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - repo, err := RequiredParam[string](args, "repo") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - - // Set optional parameters if provided - state, err := OptionalParam[string](args, "state") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { + return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + owner, err := RequiredParam[string](args, "owner") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + repo, err := RequiredParam[string](args, "repo") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - // Normalize and filter by state - state = strings.ToUpper(state) - var states []githubv4.IssueState + // Set optional parameters if provided + state, err := OptionalParam[string](args, "state") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - switch state { - case "OPEN", "CLOSED": - states = []githubv4.IssueState{githubv4.IssueState(state)} - default: - states = []githubv4.IssueState{githubv4.IssueStateOpen, githubv4.IssueStateClosed} - } + // Normalize and filter by state + state = strings.ToUpper(state) + var states []githubv4.IssueState - // Get labels - labels, err := OptionalStringArrayParam(args, "labels") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + switch state { + case "OPEN", "CLOSED": + states = []githubv4.IssueState{githubv4.IssueState(state)} + default: + states = []githubv4.IssueState{githubv4.IssueStateOpen, githubv4.IssueStateClosed} + } - orderBy, err := OptionalParam[string](args, "orderBy") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + // Get labels + labels, err := OptionalStringArrayParam(args, "labels") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - direction, err := OptionalParam[string](args, "direction") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + orderBy, err := OptionalParam[string](args, "orderBy") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - // Normalize and validate orderBy - orderBy = strings.ToUpper(orderBy) - switch orderBy { - case "CREATED_AT", "UPDATED_AT", "COMMENTS": - // Valid, keep as is - default: - orderBy = "CREATED_AT" - } + direction, err := OptionalParam[string](args, "direction") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - // Normalize and validate direction - direction = strings.ToUpper(direction) - switch direction { - case "ASC", "DESC": - // Valid, keep as is - default: - direction = "DESC" - } + // Normalize and validate orderBy + orderBy = strings.ToUpper(orderBy) + switch orderBy { + case "CREATED_AT", "UPDATED_AT", "COMMENTS": + // Valid, keep as is + default: + orderBy = "CREATED_AT" + } - since, err := OptionalParam[string](args, "since") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + // Normalize and validate direction + direction = strings.ToUpper(direction) + switch direction { + case "ASC", "DESC": + // Valid, keep as is + default: + direction = "DESC" + } - // There are two optional parameters: since and labels. - var sinceTime time.Time - var hasSince bool - if since != "" { - sinceTime, err = parseISOTimestamp(since) + since, err := OptionalParam[string](args, "since") if err != nil { - return utils.NewToolResultError(fmt.Sprintf("failed to list issues: %s", err.Error())), nil, nil + return utils.NewToolResultError(err.Error()), nil, nil } - hasSince = true - } - hasLabels := len(labels) > 0 - // Get pagination parameters and convert to GraphQL format - pagination, err := OptionalCursorPaginationParams(args) - if err != nil { - return nil, nil, err - } + // There are two optional parameters: since and labels. + var sinceTime time.Time + var hasSince bool + if since != "" { + sinceTime, err = parseISOTimestamp(since) + if err != nil { + return utils.NewToolResultError(fmt.Sprintf("failed to list issues: %s", err.Error())), nil, nil + } + hasSince = true + } + hasLabels := len(labels) > 0 - // Check if someone tried to use page-based pagination instead of cursor-based - if _, pageProvided := args["page"]; pageProvided { - return utils.NewToolResultError("This tool uses cursor-based pagination. Use the 'after' parameter with the 'endCursor' value from the previous response instead of 'page'."), nil, nil - } + // Get pagination parameters and convert to GraphQL format + pagination, err := OptionalCursorPaginationParams(args) + if err != nil { + return nil, nil, err + } - // Check if pagination parameters were explicitly provided - _, perPageProvided := args["perPage"] - paginationExplicit := perPageProvided + // Check if someone tried to use page-based pagination instead of cursor-based + if _, pageProvided := args["page"]; pageProvided { + return utils.NewToolResultError("This tool uses cursor-based pagination. Use the 'after' parameter with the 'endCursor' value from the previous response instead of 'page'."), nil, nil + } - paginationParams, err := pagination.ToGraphQLParams() - if err != nil { - return nil, nil, err - } + // Check if pagination parameters were explicitly provided + _, perPageProvided := args["perPage"] + paginationExplicit := perPageProvided - // Use default of 30 if pagination was not explicitly provided - if !paginationExplicit { - defaultFirst := int32(DefaultGraphQLPageSize) - paginationParams.First = &defaultFirst - } + paginationParams, err := pagination.ToGraphQLParams() + if err != nil { + return nil, nil, err + } - client, err := getGQLClient(ctx) - if err != nil { - return utils.NewToolResultError(fmt.Sprintf("failed to get GitHub GQL client: %v", err)), nil, nil - } + // Use default of 30 if pagination was not explicitly provided + if !paginationExplicit { + defaultFirst := int32(DefaultGraphQLPageSize) + paginationParams.First = &defaultFirst + } - vars := map[string]interface{}{ - "owner": githubv4.String(owner), - "repo": githubv4.String(repo), - "states": states, - "orderBy": githubv4.IssueOrderField(orderBy), - "direction": githubv4.OrderDirection(direction), - "first": githubv4.Int(*paginationParams.First), - } + client, err := deps.GetGQLClient(ctx) + if err != nil { + return utils.NewToolResultError(fmt.Sprintf("failed to get GitHub GQL client: %v", err)), nil, nil + } - if paginationParams.After != nil { - vars["after"] = githubv4.String(*paginationParams.After) - } else { - // Used within query, therefore must be set to nil and provided as $after - vars["after"] = (*githubv4.String)(nil) - } + vars := map[string]interface{}{ + "owner": githubv4.String(owner), + "repo": githubv4.String(repo), + "states": states, + "orderBy": githubv4.IssueOrderField(orderBy), + "direction": githubv4.OrderDirection(direction), + "first": githubv4.Int(*paginationParams.First), + } - // Ensure optional parameters are set - if hasLabels { - // Use query with labels filtering - convert string labels to githubv4.String slice - labelStrings := make([]githubv4.String, len(labels)) - for i, label := range labels { - labelStrings[i] = githubv4.String(label) + if paginationParams.After != nil { + vars["after"] = githubv4.String(*paginationParams.After) + } else { + // Used within query, therefore must be set to nil and provided as $after + vars["after"] = (*githubv4.String)(nil) } - vars["labels"] = labelStrings - } - if hasSince { - vars["since"] = githubv4.DateTime{Time: sinceTime} - } + // Ensure optional parameters are set + if hasLabels { + // Use query with labels filtering - convert string labels to githubv4.String slice + labelStrings := make([]githubv4.String, len(labels)) + for i, label := range labels { + labelStrings[i] = githubv4.String(label) + } + vars["labels"] = labelStrings + } - issueQuery := getIssueQueryType(hasLabels, hasSince) - if err := client.Query(ctx, issueQuery, vars); err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + if hasSince { + vars["since"] = githubv4.DateTime{Time: sinceTime} + } - // Extract and convert all issue nodes using the common interface - var issues []*github.Issue - var pageInfo struct { - HasNextPage githubv4.Boolean - HasPreviousPage githubv4.Boolean - StartCursor githubv4.String - EndCursor githubv4.String - } - var totalCount int + issueQuery := getIssueQueryType(hasLabels, hasSince) + if err := client.Query(ctx, issueQuery, vars); err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - if queryResult, ok := issueQuery.(IssueQueryResult); ok { - fragment := queryResult.GetIssueFragment() - for _, issue := range fragment.Nodes { - issues = append(issues, fragmentToIssue(issue)) + // Extract and convert all issue nodes using the common interface + var issues []*github.Issue + var pageInfo struct { + HasNextPage githubv4.Boolean + HasPreviousPage githubv4.Boolean + StartCursor githubv4.String + EndCursor githubv4.String } - pageInfo = fragment.PageInfo - totalCount = fragment.TotalCount - } + var totalCount int - // Create response with issues - response := map[string]interface{}{ - "issues": issues, - "pageInfo": map[string]interface{}{ - "hasNextPage": pageInfo.HasNextPage, - "hasPreviousPage": pageInfo.HasPreviousPage, - "startCursor": string(pageInfo.StartCursor), - "endCursor": string(pageInfo.EndCursor), - }, - "totalCount": totalCount, - } - out, err := json.Marshal(response) - if err != nil { - return nil, nil, fmt.Errorf("failed to marshal issues: %w", err) + if queryResult, ok := issueQuery.(IssueQueryResult); ok { + fragment := queryResult.GetIssueFragment() + for _, issue := range fragment.Nodes { + issues = append(issues, fragmentToIssue(issue)) + } + pageInfo = fragment.PageInfo + totalCount = fragment.TotalCount + } + + // Create response with issues + response := map[string]interface{}{ + "issues": issues, + "pageInfo": map[string]interface{}{ + "hasNextPage": pageInfo.HasNextPage, + "hasPreviousPage": pageInfo.HasPreviousPage, + "startCursor": string(pageInfo.StartCursor), + "endCursor": string(pageInfo.EndCursor), + }, + "totalCount": totalCount, + } + out, err := json.Marshal(response) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal issues: %w", err) + } + return utils.NewToolResultText(string(out)), nil, nil } - return utils.NewToolResultText(string(out)), nil, nil - } + }) } // mvpDescription is an MVP idea for generating tool descriptions from structured data in a shared format. @@ -1577,7 +1599,7 @@ func (d *mvpDescription) String() string { return sb.String() } -func AssignCopilotToIssue(getGQLClient GetGQLClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { +func AssignCopilotToIssue(t translations.TranslationHelperFunc) toolsets.ServerTool { description := mvpDescription{ summary: "Assign Copilot to a specific issue in a GitHub repository.", outcomes: []string{ @@ -1588,7 +1610,8 @@ func AssignCopilotToIssue(getGQLClient GetGQLClientFn, t translations.Translatio }, } - return mcp.Tool{ + return NewTool( + mcp.Tool{ Name: "assign_copilot_to_issue", Description: t("TOOL_ASSIGN_COPILOT_TO_ISSUE_DESCRIPTION", description.String()), Annotations: &mcp.ToolAnnotations{ @@ -1615,132 +1638,134 @@ func AssignCopilotToIssue(getGQLClient GetGQLClientFn, t translations.Translatio Required: []string{"owner", "repo", "issueNumber"}, }, }, - func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - var params struct { - Owner string - Repo string - IssueNumber int32 - } - if err := mapstructure.Decode(args, ¶ms); err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - - client, err := getGQLClient(ctx) - if err != nil { - return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) - } + func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { + return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + var params struct { + Owner string + Repo string + IssueNumber int32 + } + if err := mapstructure.Decode(args, ¶ms); err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - // Firstly, we try to find the copilot bot in the suggested actors for the repository. - // Although as I write this, we would expect copilot to be at the top of the list, in future, maybe - // it will not be on the first page of responses, thus we will keep paginating until we find it. - type botAssignee struct { - ID githubv4.ID - Login string - TypeName string `graphql:"__typename"` - } + client, err := deps.GetGQLClient(ctx) + if err != nil { + return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) + } - type suggestedActorsQuery struct { - Repository struct { - SuggestedActors struct { - Nodes []struct { - Bot botAssignee `graphql:"... on Bot"` - } - PageInfo struct { - HasNextPage bool - EndCursor string - } - } `graphql:"suggestedActors(first: 100, after: $endCursor, capabilities: CAN_BE_ASSIGNED)"` - } `graphql:"repository(owner: $owner, name: $name)"` - } + // Firstly, we try to find the copilot bot in the suggested actors for the repository. + // Although as I write this, we would expect copilot to be at the top of the list, in future, maybe + // it will not be on the first page of responses, thus we will keep paginating until we find it. + type botAssignee struct { + ID githubv4.ID + Login string + TypeName string `graphql:"__typename"` + } - variables := map[string]any{ - "owner": githubv4.String(params.Owner), - "name": githubv4.String(params.Repo), - "endCursor": (*githubv4.String)(nil), - } + type suggestedActorsQuery struct { + Repository struct { + SuggestedActors struct { + Nodes []struct { + Bot botAssignee `graphql:"... on Bot"` + } + PageInfo struct { + HasNextPage bool + EndCursor string + } + } `graphql:"suggestedActors(first: 100, after: $endCursor, capabilities: CAN_BE_ASSIGNED)"` + } `graphql:"repository(owner: $owner, name: $name)"` + } - var copilotAssignee *botAssignee - for { - var query suggestedActorsQuery - err := client.Query(ctx, &query, variables) - if err != nil { - return nil, nil, err + variables := map[string]any{ + "owner": githubv4.String(params.Owner), + "name": githubv4.String(params.Repo), + "endCursor": (*githubv4.String)(nil), } - // Iterate all the returned nodes looking for the copilot bot, which is supposed to have the - // same name on each host. We need this in order to get the ID for later assignment. - for _, node := range query.Repository.SuggestedActors.Nodes { - if node.Bot.Login == "copilot-swe-agent" { - copilotAssignee = &node.Bot + var copilotAssignee *botAssignee + for { + var query suggestedActorsQuery + err := client.Query(ctx, &query, variables) + if err != nil { + return nil, nil, err + } + + // Iterate all the returned nodes looking for the copilot bot, which is supposed to have the + // same name on each host. We need this in order to get the ID for later assignment. + for _, node := range query.Repository.SuggestedActors.Nodes { + if node.Bot.Login == "copilot-swe-agent" { + copilotAssignee = &node.Bot + break + } + } + + if !query.Repository.SuggestedActors.PageInfo.HasNextPage { break } + variables["endCursor"] = githubv4.String(query.Repository.SuggestedActors.PageInfo.EndCursor) } - if !query.Repository.SuggestedActors.PageInfo.HasNextPage { - break + // If we didn't find the copilot bot, we can't proceed any further. + if copilotAssignee == nil { + // The e2e tests depend upon this specific message to skip the test. + return utils.NewToolResultError("copilot isn't available as an assignee for this issue. Please inform the user to visit https://docs.github.com/en/copilot/using-github-copilot/using-copilot-coding-agent-to-work-on-tasks/about-assigning-tasks-to-copilot for more information."), nil, nil } - variables["endCursor"] = githubv4.String(query.Repository.SuggestedActors.PageInfo.EndCursor) - } - // If we didn't find the copilot bot, we can't proceed any further. - if copilotAssignee == nil { - // The e2e tests depend upon this specific message to skip the test. - return utils.NewToolResultError("copilot isn't available as an assignee for this issue. Please inform the user to visit https://docs.github.com/en/copilot/using-github-copilot/using-copilot-coding-agent-to-work-on-tasks/about-assigning-tasks-to-copilot for more information."), nil, nil - } + // Next let's get the GQL Node ID and current assignees for this issue because the only way to + // assign copilot is to use replaceActorsForAssignable which requires the full list. + var getIssueQuery struct { + Repository struct { + Issue struct { + ID githubv4.ID + Assignees struct { + Nodes []struct { + ID githubv4.ID + } + } `graphql:"assignees(first: 100)"` + } `graphql:"issue(number: $number)"` + } `graphql:"repository(owner: $owner, name: $name)"` + } - // Next let's get the GQL Node ID and current assignees for this issue because the only way to - // assign copilot is to use replaceActorsForAssignable which requires the full list. - var getIssueQuery struct { - Repository struct { - Issue struct { - ID githubv4.ID - Assignees struct { - Nodes []struct { - ID githubv4.ID - } - } `graphql:"assignees(first: 100)"` - } `graphql:"issue(number: $number)"` - } `graphql:"repository(owner: $owner, name: $name)"` - } + variables = map[string]any{ + "owner": githubv4.String(params.Owner), + "name": githubv4.String(params.Repo), + "number": githubv4.Int(params.IssueNumber), + } - variables = map[string]any{ - "owner": githubv4.String(params.Owner), - "name": githubv4.String(params.Repo), - "number": githubv4.Int(params.IssueNumber), - } + if err := client.Query(ctx, &getIssueQuery, variables); err != nil { + return utils.NewToolResultError(fmt.Sprintf("failed to get issue ID: %v", err)), nil, nil + } - if err := client.Query(ctx, &getIssueQuery, variables); err != nil { - return utils.NewToolResultError(fmt.Sprintf("failed to get issue ID: %v", err)), nil, nil - } + // Finally, do the assignment. Just for reference, assigning copilot to an issue that it is already + // assigned to seems to have no impact (which is a good thing). + var assignCopilotMutation struct { + ReplaceActorsForAssignable struct { + Typename string `graphql:"__typename"` // Not required but we need a selector or GQL errors + } `graphql:"replaceActorsForAssignable(input: $input)"` + } - // Finally, do the assignment. Just for reference, assigning copilot to an issue that it is already - // assigned to seems to have no impact (which is a good thing). - var assignCopilotMutation struct { - ReplaceActorsForAssignable struct { - Typename string `graphql:"__typename"` // Not required but we need a selector or GQL errors - } `graphql:"replaceActorsForAssignable(input: $input)"` - } + actorIDs := make([]githubv4.ID, len(getIssueQuery.Repository.Issue.Assignees.Nodes)+1) + for i, node := range getIssueQuery.Repository.Issue.Assignees.Nodes { + actorIDs[i] = node.ID + } + actorIDs[len(getIssueQuery.Repository.Issue.Assignees.Nodes)] = copilotAssignee.ID + + if err := client.Mutate( + ctx, + &assignCopilotMutation, + ReplaceActorsForAssignableInput{ + AssignableID: getIssueQuery.Repository.Issue.ID, + ActorIDs: actorIDs, + }, + nil, + ); err != nil { + return nil, nil, fmt.Errorf("failed to replace actors for assignable: %w", err) + } - actorIDs := make([]githubv4.ID, len(getIssueQuery.Repository.Issue.Assignees.Nodes)+1) - for i, node := range getIssueQuery.Repository.Issue.Assignees.Nodes { - actorIDs[i] = node.ID + return utils.NewToolResultText("successfully assigned copilot to issue"), nil, nil } - actorIDs[len(getIssueQuery.Repository.Issue.Assignees.Nodes)] = copilotAssignee.ID - - if err := client.Mutate( - ctx, - &assignCopilotMutation, - ReplaceActorsForAssignableInput{ - AssignableID: getIssueQuery.Repository.Issue.ID, - ActorIDs: actorIDs, - }, - nil, - ); err != nil { - return nil, nil, fmt.Errorf("failed to replace actors for assignable: %w", err) - } - - return utils.NewToolResultText("successfully assigned copilot to issue"), nil, nil - } + }) } type ReplaceActorsForAssignableInput struct { diff --git a/pkg/github/issues_test.go b/pkg/github/issues_test.go index c4454624b..c832f031a 100644 --- a/pkg/github/issues_test.go +++ b/pkg/github/issues_test.go @@ -122,9 +122,8 @@ func toString(v any) string { func Test_GetIssue(t *testing.T) { // Verify tool definition once - mockClient := github.NewClient(nil) - defaultGQLClient := githubv4.NewClient(nil) - tool, _ := IssueRead(stubGetClientFn(mockClient), stubGetGQLClientFn(defaultGQLClient), repoAccessCache, translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) + serverTool := IssueRead(translations.NullTranslationHelper) + tool := serverTool.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "issue_read", tool.Name) @@ -327,14 +326,20 @@ func Test_GetIssue(t *testing.T) { gqlClient = githubv4.NewClient(tc.gqlHTTPClient) cache = stubRepoAccessCache(gqlClient, 15*time.Minute) } else { - gqlClient = defaultGQLClient + gqlClient = githubv4.NewClient(nil) } flags := stubFeatureFlags(map[string]bool{"lockdown-mode": tc.lockdownEnabled}) - _, handler := IssueRead(stubGetClientFn(client), stubGetGQLClientFn(gqlClient), cache, translations.NullTranslationHelper, flags) + deps := ToolDependencies{ + GetClient: stubGetClientFn(client), + GetGQLClient: stubGetGQLClientFn(gqlClient), + RepoAccessCache: cache, + Flags: flags, + } + handler := serverTool.Handler(deps) request := createMCPRequest(tc.requestArgs) - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(context.Background(), &request) if tc.expectHandlerError { require.Error(t, err) @@ -368,8 +373,8 @@ func Test_GetIssue(t *testing.T) { func Test_AddIssueComment(t *testing.T) { // Verify tool definition once - mockClient := github.NewClient(nil) - tool, _ := AddIssueComment(stubGetClientFn(mockClient), translations.NullTranslationHelper) + serverTool := AddIssueComment(translations.NullTranslationHelper) + tool := serverTool.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "add_issue_comment", tool.Name) @@ -442,13 +447,16 @@ func Test_AddIssueComment(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := AddIssueComment(stubGetClientFn(client), translations.NullTranslationHelper) + deps := ToolDependencies{ + GetClient: stubGetClientFn(client), + } + handler := serverTool.Handler(deps) // Create call request request := createMCPRequest(tc.requestArgs) // Call handler - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(context.Background(), &request) // Verify results if tc.expectError { @@ -483,8 +491,8 @@ func Test_AddIssueComment(t *testing.T) { func Test_SearchIssues(t *testing.T) { // Verify tool definition once - mockClient := github.NewClient(nil) - tool, _ := SearchIssues(stubGetClientFn(mockClient), translations.NullTranslationHelper) + serverTool := SearchIssues(translations.NullTranslationHelper) + tool := serverTool.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "search_issues", tool.Name) @@ -773,13 +781,16 @@ func Test_SearchIssues(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := SearchIssues(stubGetClientFn(client), translations.NullTranslationHelper) + deps := ToolDependencies{ + GetClient: stubGetClientFn(client), + } + handler := serverTool.Handler(deps) // Create call request request := createMCPRequest(tc.requestArgs) // Call handler - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(context.Background(), &request) // Verify results if tc.expectError { @@ -817,9 +828,8 @@ func Test_SearchIssues(t *testing.T) { func Test_CreateIssue(t *testing.T) { // Verify tool definition once - mockClient := github.NewClient(nil) - mockGQLClient := githubv4.NewClient(nil) - tool, _ := IssueWrite(stubGetClientFn(mockClient), stubGetGQLClientFn(mockGQLClient), translations.NullTranslationHelper) + serverTool := IssueWrite(translations.NullTranslationHelper) + tool := serverTool.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "issue_write", tool.Name) @@ -942,13 +952,17 @@ func Test_CreateIssue(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) gqlClient := githubv4.NewClient(nil) - _, handler := IssueWrite(stubGetClientFn(client), stubGetGQLClientFn(gqlClient), translations.NullTranslationHelper) + deps := ToolDependencies{ + GetClient: stubGetClientFn(client), + GetGQLClient: stubGetGQLClientFn(gqlClient), + } + handler := serverTool.Handler(deps) // Create call request request := createMCPRequest(tc.requestArgs) // Call handler - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(context.Background(), &request) // Verify results if tc.expectError { @@ -979,8 +993,8 @@ func Test_CreateIssue(t *testing.T) { func Test_ListIssues(t *testing.T) { // Verify tool definition - mockClient := githubv4.NewClient(nil) - tool, _ := ListIssues(stubGetGQLClientFn(mockClient), translations.NullTranslationHelper) + serverTool := ListIssues(translations.NullTranslationHelper) + tool := serverTool.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "list_issues", tool.Name) @@ -1254,10 +1268,13 @@ func Test_ListIssues(t *testing.T) { } gqlClient := githubv4.NewClient(httpClient) - _, handler := ListIssues(stubGetGQLClientFn(gqlClient), translations.NullTranslationHelper) + deps := ToolDependencies{ + GetGQLClient: stubGetGQLClientFn(gqlClient), + } + handler := serverTool.Handler(deps) req := createMCPRequest(tc.reqParams) - res, _, err := handler(context.Background(), &req, tc.reqParams) + res, err := handler(context.Background(), &req) text := getTextResult(t, res).Text if tc.expectError { @@ -1300,9 +1317,8 @@ func Test_ListIssues(t *testing.T) { func Test_UpdateIssue(t *testing.T) { // Verify tool definition - mockClient := github.NewClient(nil) - mockGQLClient := githubv4.NewClient(nil) - tool, _ := IssueWrite(stubGetClientFn(mockClient), stubGetGQLClientFn(mockGQLClient), translations.NullTranslationHelper) + serverTool := IssueWrite(translations.NullTranslationHelper) + tool := serverTool.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "issue_write", tool.Name) @@ -1753,13 +1769,17 @@ func Test_UpdateIssue(t *testing.T) { // Setup clients with mocks restClient := github.NewClient(tc.mockedRESTClient) gqlClient := githubv4.NewClient(tc.mockedGQLClient) - _, handler := IssueWrite(stubGetClientFn(restClient), stubGetGQLClientFn(gqlClient), translations.NullTranslationHelper) + deps := ToolDependencies{ + GetClient: stubGetClientFn(restClient), + GetGQLClient: stubGetGQLClientFn(gqlClient), + } + handler := serverTool.Handler(deps) // Create call request request := createMCPRequest(tc.requestArgs) // Call handler - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(context.Background(), &request) // Verify results if tc.expectError || tc.expectedErrMsg != "" { @@ -1844,9 +1864,8 @@ func Test_ParseISOTimestamp(t *testing.T) { func Test_GetIssueComments(t *testing.T) { // Verify tool definition once - mockClient := github.NewClient(nil) - gqlClient := githubv4.NewClient(nil) - tool, _ := IssueRead(stubGetClientFn(mockClient), stubGetGQLClientFn(gqlClient), stubRepoAccessCache(gqlClient, 15*time.Minute), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) + serverTool := IssueRead(translations.NullTranslationHelper) + tool := serverTool.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "issue_read", tool.Name) @@ -1997,13 +2016,19 @@ func Test_GetIssueComments(t *testing.T) { } cache := stubRepoAccessCache(gqlClient, 15*time.Minute) flags := stubFeatureFlags(map[string]bool{"lockdown-mode": tc.lockdownEnabled}) - _, handler := IssueRead(stubGetClientFn(client), stubGetGQLClientFn(gqlClient), cache, translations.NullTranslationHelper, flags) + deps := ToolDependencies{ + GetClient: stubGetClientFn(client), + GetGQLClient: stubGetGQLClientFn(gqlClient), + RepoAccessCache: cache, + Flags: flags, + } + handler := serverTool.Handler(deps) // Create call request request := createMCPRequest(tc.requestArgs) // Call handler - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(context.Background(), &request) // Verify results if tc.expectError { @@ -2035,9 +2060,8 @@ func Test_GetIssueLabels(t *testing.T) { t.Parallel() // Verify tool definition - mockGQClient := githubv4.NewClient(nil) - mockClient := github.NewClient(nil) - tool, _ := IssueRead(stubGetClientFn(mockClient), stubGetGQLClientFn(mockGQClient), stubRepoAccessCache(mockGQClient, 15*time.Minute), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) + serverTool := IssueRead(translations.NullTranslationHelper) + tool := serverTool.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "issue_read", tool.Name) @@ -2112,10 +2136,16 @@ func Test_GetIssueLabels(t *testing.T) { t.Run(tc.name, func(t *testing.T) { gqlClient := githubv4.NewClient(tc.mockedClient) client := github.NewClient(nil) - _, handler := IssueRead(stubGetClientFn(client), stubGetGQLClientFn(gqlClient), stubRepoAccessCache(gqlClient, 15*time.Minute), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) + deps := ToolDependencies{ + GetClient: stubGetClientFn(client), + GetGQLClient: stubGetGQLClientFn(gqlClient), + RepoAccessCache: stubRepoAccessCache(gqlClient, 15*time.Minute), + Flags: stubFeatureFlags(map[string]bool{"lockdown-mode": false}), + } + handler := serverTool.Handler(deps) request := createMCPRequest(tc.requestArgs) - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(context.Background(), &request) require.NoError(t, err) assert.NotNil(t, result) @@ -2137,8 +2167,8 @@ func TestAssignCopilotToIssue(t *testing.T) { t.Parallel() // Verify tool definition - mockClient := githubv4.NewClient(nil) - tool, _ := AssignCopilotToIssue(stubGetGQLClientFn(mockClient), translations.NullTranslationHelper) + serverTool := AssignCopilotToIssue(translations.NullTranslationHelper) + tool := serverTool.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "assign_copilot_to_issue", tool.Name) @@ -2530,13 +2560,16 @@ func TestAssignCopilotToIssue(t *testing.T) { t.Parallel() // Setup client with mock client := githubv4.NewClient(tc.mockedClient) - _, handler := AssignCopilotToIssue(stubGetGQLClientFn(client), translations.NullTranslationHelper) + deps := ToolDependencies{ + GetGQLClient: stubGetGQLClientFn(client), + } + handler := serverTool.Handler(deps) // Create call request request := createMCPRequest(tc.requestArgs) // Call handler - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(context.Background(), &request) require.NoError(t, err) textContent := getTextResult(t, result) @@ -2555,8 +2588,8 @@ func TestAssignCopilotToIssue(t *testing.T) { func Test_AddSubIssue(t *testing.T) { // Verify tool definition once - mockClient := github.NewClient(nil) - tool, _ := SubIssueWrite(stubGetClientFn(mockClient), translations.NullTranslationHelper) + serverTool := SubIssueWrite(translations.NullTranslationHelper) + tool := serverTool.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "sub_issue_write", tool.Name) @@ -2758,13 +2791,16 @@ func Test_AddSubIssue(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := SubIssueWrite(stubGetClientFn(client), translations.NullTranslationHelper) + deps := ToolDependencies{ + GetClient: stubGetClientFn(client), + } + handler := serverTool.Handler(deps) // Create call request request := createMCPRequest(tc.requestArgs) // Call handler - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(context.Background(), &request) // Verify results if tc.expectError { @@ -2801,9 +2837,8 @@ func Test_AddSubIssue(t *testing.T) { func Test_GetSubIssues(t *testing.T) { // Verify tool definition once - mockClient := github.NewClient(nil) - gqlClient := githubv4.NewClient(nil) - tool, _ := IssueRead(stubGetClientFn(mockClient), stubGetGQLClientFn(gqlClient), stubRepoAccessCache(gqlClient, 15*time.Minute), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) + serverTool := IssueRead(translations.NullTranslationHelper) + tool := serverTool.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "issue_read", tool.Name) @@ -3000,13 +3035,19 @@ func Test_GetSubIssues(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) gqlClient := githubv4.NewClient(nil) - _, handler := IssueRead(stubGetClientFn(client), stubGetGQLClientFn(gqlClient), stubRepoAccessCache(gqlClient, 15*time.Minute), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) + deps := ToolDependencies{ + GetClient: stubGetClientFn(client), + GetGQLClient: stubGetGQLClientFn(gqlClient), + RepoAccessCache: stubRepoAccessCache(gqlClient, 15*time.Minute), + Flags: stubFeatureFlags(map[string]bool{"lockdown-mode": false}), + } + handler := serverTool.Handler(deps) // Create call request request := createMCPRequest(tc.requestArgs) // Call handler - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(context.Background(), &request) // Verify results if tc.expectError { @@ -3052,8 +3093,8 @@ func Test_GetSubIssues(t *testing.T) { func Test_RemoveSubIssue(t *testing.T) { // Verify tool definition once - mockClient := github.NewClient(nil) - tool, _ := SubIssueWrite(stubGetClientFn(mockClient), translations.NullTranslationHelper) + serverTool := SubIssueWrite(translations.NullTranslationHelper) + tool := serverTool.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "sub_issue_write", tool.Name) @@ -3234,13 +3275,16 @@ func Test_RemoveSubIssue(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := SubIssueWrite(stubGetClientFn(client), translations.NullTranslationHelper) + deps := ToolDependencies{ + GetClient: stubGetClientFn(client), + } + handler := serverTool.Handler(deps) // Create call request request := createMCPRequest(tc.requestArgs) // Call handler - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(context.Background(), &request) // Verify results if tc.expectError { @@ -3277,8 +3321,8 @@ func Test_RemoveSubIssue(t *testing.T) { func Test_ReprioritizeSubIssue(t *testing.T) { // Verify tool definition once - mockClient := github.NewClient(nil) - tool, _ := SubIssueWrite(stubGetClientFn(mockClient), translations.NullTranslationHelper) + serverTool := SubIssueWrite(translations.NullTranslationHelper) + tool := serverTool.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "sub_issue_write", tool.Name) @@ -3520,13 +3564,16 @@ func Test_ReprioritizeSubIssue(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := SubIssueWrite(stubGetClientFn(client), translations.NullTranslationHelper) + deps := ToolDependencies{ + GetClient: stubGetClientFn(client), + } + handler := serverTool.Handler(deps) // Create call request request := createMCPRequest(tc.requestArgs) // Call handler - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(context.Background(), &request) // Verify results if tc.expectError { @@ -3563,8 +3610,8 @@ func Test_ReprioritizeSubIssue(t *testing.T) { func Test_ListIssueTypes(t *testing.T) { // Verify tool definition once - mockClient := github.NewClient(nil) - tool, _ := ListIssueTypes(stubGetClientFn(mockClient), translations.NullTranslationHelper) + serverTool := ListIssueTypes(translations.NullTranslationHelper) + tool := serverTool.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "list_issue_types", tool.Name) @@ -3651,13 +3698,16 @@ func Test_ListIssueTypes(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := ListIssueTypes(stubGetClientFn(client), translations.NullTranslationHelper) + deps := ToolDependencies{ + GetClient: stubGetClientFn(client), + } + handler := serverTool.Handler(deps) // Create call request request := createMCPRequest(tc.requestArgs) // Call handler - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(context.Background(), &request) // Verify results if tc.expectError { diff --git a/pkg/github/pullrequests.go b/pkg/github/pullrequests.go index 661384529..bfe870775 100644 --- a/pkg/github/pullrequests.go +++ b/pkg/github/pullrequests.go @@ -16,12 +16,13 @@ import ( ghErrors "github.com/github/github-mcp-server/pkg/errors" "github.com/github/github-mcp-server/pkg/lockdown" "github.com/github/github-mcp-server/pkg/sanitize" + "github.com/github/github-mcp-server/pkg/toolsets" "github.com/github/github-mcp-server/pkg/translations" "github.com/github/github-mcp-server/pkg/utils" ) // PullRequestRead creates a tool to get details of a specific pull request. -func PullRequestRead(getClient GetClientFn, cache *lockdown.RepoAccessCache, t translations.TranslationHelperFunc, flags FeatureFlags) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { +func PullRequestRead(t translations.TranslationHelperFunc) toolsets.ServerTool { schema := &jsonschema.Schema{ Type: "object", Properties: map[string]*jsonschema.Schema{ @@ -56,7 +57,8 @@ Possible options: } WithPagination(schema) - return mcp.Tool{ + return NewTool( + mcp.Tool{ Name: "pull_request_read", Description: t("TOOL_PULL_REQUEST_READ_DESCRIPTION", "Get information on a specific pull request in GitHub repository."), Annotations: &mcp.ToolAnnotations{ @@ -65,60 +67,62 @@ Possible options: }, InputSchema: schema, }, - func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - method, err := RequiredParam[string](args, "method") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { + return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + method, err := RequiredParam[string](args, "method") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - owner, err := RequiredParam[string](args, "owner") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - repo, err := RequiredParam[string](args, "repo") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - pullNumber, err := RequiredInt(args, "pullNumber") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - pagination, err := OptionalPaginationParams(args) - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + owner, err := RequiredParam[string](args, "owner") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + repo, err := RequiredParam[string](args, "repo") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + pullNumber, err := RequiredInt(args, "pullNumber") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + pagination, err := OptionalPaginationParams(args) + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - client, err := getClient(ctx) - if err != nil { - return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil - } + client, err := deps.GetClient(ctx) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil + } - switch method { - case "get": - result, err := GetPullRequest(ctx, client, cache, owner, repo, pullNumber, flags) - return result, nil, err - case "get_diff": - result, err := GetPullRequestDiff(ctx, client, owner, repo, pullNumber) - return result, nil, err - case "get_status": - result, err := GetPullRequestStatus(ctx, client, owner, repo, pullNumber) - return result, nil, err - case "get_files": - result, err := GetPullRequestFiles(ctx, client, owner, repo, pullNumber, pagination) - return result, nil, err - case "get_review_comments": - result, err := GetPullRequestReviewComments(ctx, client, cache, owner, repo, pullNumber, pagination, flags) - return result, nil, err - case "get_reviews": - result, err := GetPullRequestReviews(ctx, client, cache, owner, repo, pullNumber, flags) - return result, nil, err - case "get_comments": - result, err := GetIssueComments(ctx, client, cache, owner, repo, pullNumber, pagination, flags) - return result, nil, err - default: - return utils.NewToolResultError(fmt.Sprintf("unknown method: %s", method)), nil, nil + switch method { + case "get": + result, err := GetPullRequest(ctx, client, deps.RepoAccessCache, owner, repo, pullNumber, deps.Flags) + return result, nil, err + case "get_diff": + result, err := GetPullRequestDiff(ctx, client, owner, repo, pullNumber) + return result, nil, err + case "get_status": + result, err := GetPullRequestStatus(ctx, client, owner, repo, pullNumber) + return result, nil, err + case "get_files": + result, err := GetPullRequestFiles(ctx, client, owner, repo, pullNumber, pagination) + return result, nil, err + case "get_review_comments": + result, err := GetPullRequestReviewComments(ctx, client, deps.RepoAccessCache, owner, repo, pullNumber, pagination, deps.Flags) + return result, nil, err + case "get_reviews": + result, err := GetPullRequestReviews(ctx, client, deps.RepoAccessCache, owner, repo, pullNumber, deps.Flags) + return result, nil, err + case "get_comments": + result, err := GetIssueComments(ctx, client, deps.RepoAccessCache, owner, repo, pullNumber, pagination, deps.Flags) + return result, nil, err + default: + return utils.NewToolResultError(fmt.Sprintf("unknown method: %s", method)), nil, nil + } } - } + }) } func GetPullRequest(ctx context.Context, client *github.Client, cache *lockdown.RepoAccessCache, owner, repo string, pullNumber int, ff FeatureFlags) (*mcp.CallToolResult, error) { @@ -385,7 +389,7 @@ func GetPullRequestReviews(ctx context.Context, client *github.Client, cache *lo } // CreatePullRequest creates a tool to create a new pull request. -func CreatePullRequest(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { +func CreatePullRequest(t translations.TranslationHelperFunc) toolsets.ServerTool { schema := &jsonschema.Schema{ Type: "object", Properties: map[string]*jsonschema.Schema{ @@ -425,7 +429,8 @@ func CreatePullRequest(getClient GetClientFn, t translations.TranslationHelperFu Required: []string{"owner", "repo", "title", "head", "base"}, } - return mcp.Tool{ + return NewTool( + mcp.Tool{ Name: "create_pull_request", Description: t("TOOL_CREATE_PULL_REQUEST_DESCRIPTION", "Create a new pull request in a GitHub repository."), Annotations: &mcp.ToolAnnotations{ @@ -434,95 +439,97 @@ func CreatePullRequest(getClient GetClientFn, t translations.TranslationHelperFu }, InputSchema: schema, }, - func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - owner, err := RequiredParam[string](args, "owner") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - repo, err := RequiredParam[string](args, "repo") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - title, err := RequiredParam[string](args, "title") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - head, err := RequiredParam[string](args, "head") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - base, err := RequiredParam[string](args, "base") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - - body, err := OptionalParam[string](args, "body") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { + return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + owner, err := RequiredParam[string](args, "owner") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + repo, err := RequiredParam[string](args, "repo") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + title, err := RequiredParam[string](args, "title") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + head, err := RequiredParam[string](args, "head") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + base, err := RequiredParam[string](args, "base") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - draft, err := OptionalParam[bool](args, "draft") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + body, err := OptionalParam[string](args, "body") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - maintainerCanModify, err := OptionalParam[bool](args, "maintainer_can_modify") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + draft, err := OptionalParam[bool](args, "draft") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - newPR := &github.NewPullRequest{ - Title: github.Ptr(title), - Head: github.Ptr(head), - Base: github.Ptr(base), - } + maintainerCanModify, err := OptionalParam[bool](args, "maintainer_can_modify") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - if body != "" { - newPR.Body = github.Ptr(body) - } + newPR := &github.NewPullRequest{ + Title: github.Ptr(title), + Head: github.Ptr(head), + Base: github.Ptr(base), + } - newPR.Draft = github.Ptr(draft) - newPR.MaintainerCanModify = github.Ptr(maintainerCanModify) + if body != "" { + newPR.Body = github.Ptr(body) + } - client, err := getClient(ctx) - if err != nil { - return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil - } - pr, resp, err := client.PullRequests.Create(ctx, owner, repo, newPR) - if err != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, - "failed to create pull request", - resp, - err, - ), nil, nil - } - defer func() { _ = resp.Body.Close() }() + newPR.Draft = github.Ptr(draft) + newPR.MaintainerCanModify = github.Ptr(maintainerCanModify) - if resp.StatusCode != http.StatusCreated { - bodyBytes, err := io.ReadAll(resp.Body) + client, err := deps.GetClient(ctx) if err != nil { - return utils.NewToolResultErrorFromErr("failed to read response body", err), nil, nil + return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil } - return utils.NewToolResultError(fmt.Sprintf("failed to create pull request: %s", string(bodyBytes))), nil, nil - } + pr, resp, err := client.PullRequests.Create(ctx, owner, repo, newPR) + if err != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, + "failed to create pull request", + resp, + err, + ), nil, nil + } + defer func() { _ = resp.Body.Close() }() - // Return minimal response with just essential information - minimalResponse := MinimalResponse{ - ID: fmt.Sprintf("%d", pr.GetID()), - URL: pr.GetHTMLURL(), - } + if resp.StatusCode != http.StatusCreated { + bodyBytes, err := io.ReadAll(resp.Body) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to read response body", err), nil, nil + } + return utils.NewToolResultError(fmt.Sprintf("failed to create pull request: %s", string(bodyBytes))), nil, nil + } - r, err := json.Marshal(minimalResponse) - if err != nil { - return utils.NewToolResultErrorFromErr("failed to marshal response", err), nil, nil - } + // Return minimal response with just essential information + minimalResponse := MinimalResponse{ + ID: fmt.Sprintf("%d", pr.GetID()), + URL: pr.GetHTMLURL(), + } - return utils.NewToolResultText(string(r)), nil, nil - } + r, err := json.Marshal(minimalResponse) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to marshal response", err), nil, nil + } + + return utils.NewToolResultText(string(r)), nil, nil + } + }) } // UpdatePullRequest creates a tool to update an existing pull request. -func UpdatePullRequest(getClient GetClientFn, getGQLClient GetGQLClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { +func UpdatePullRequest(t translations.TranslationHelperFunc) toolsets.ServerTool { schema := &jsonschema.Schema{ Type: "object", Properties: map[string]*jsonschema.Schema{ @@ -574,7 +581,8 @@ func UpdatePullRequest(getClient GetClientFn, getGQLClient GetGQLClientFn, t tra Required: []string{"owner", "repo", "pullNumber"}, } - return mcp.Tool{ + return NewTool( + mcp.Tool{ Name: "update_pull_request", Description: t("TOOL_UPDATE_PULL_REQUEST_DESCRIPTION", "Update an existing pull request in a GitHub repository."), Annotations: &mcp.ToolAnnotations{ @@ -583,188 +591,214 @@ func UpdatePullRequest(getClient GetClientFn, getGQLClient GetGQLClientFn, t tra }, InputSchema: schema, }, - func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - owner, err := RequiredParam[string](args, "owner") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - repo, err := RequiredParam[string](args, "repo") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - pullNumber, err := RequiredInt(args, "pullNumber") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - - _, draftProvided := args["draft"] - var draftValue bool - if draftProvided { - draftValue, err = OptionalParam[bool](args, "draft") + func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { + return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + owner, err := RequiredParam[string](args, "owner") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + repo, err := RequiredParam[string](args, "repo") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + pullNumber, err := RequiredInt(args, "pullNumber") if err != nil { return utils.NewToolResultError(err.Error()), nil, nil } - } - - update := &github.PullRequest{} - restUpdateNeeded := false - if title, ok, err := OptionalParamOK[string](args, "title"); err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } else if ok { - update.Title = github.Ptr(title) - restUpdateNeeded = true - } + _, draftProvided := args["draft"] + var draftValue bool + if draftProvided { + draftValue, err = OptionalParam[bool](args, "draft") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + } - if body, ok, err := OptionalParamOK[string](args, "body"); err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } else if ok { - update.Body = github.Ptr(body) - restUpdateNeeded = true - } + update := &github.PullRequest{} + restUpdateNeeded := false - if state, ok, err := OptionalParamOK[string](args, "state"); err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } else if ok { - update.State = github.Ptr(state) - restUpdateNeeded = true - } + if title, ok, err := OptionalParamOK[string](args, "title"); err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } else if ok { + update.Title = github.Ptr(title) + restUpdateNeeded = true + } - if base, ok, err := OptionalParamOK[string](args, "base"); err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } else if ok { - update.Base = &github.PullRequestBranch{Ref: github.Ptr(base)} - restUpdateNeeded = true - } + if body, ok, err := OptionalParamOK[string](args, "body"); err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } else if ok { + update.Body = github.Ptr(body) + restUpdateNeeded = true + } - if maintainerCanModify, ok, err := OptionalParamOK[bool](args, "maintainer_can_modify"); err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } else if ok { - update.MaintainerCanModify = github.Ptr(maintainerCanModify) - restUpdateNeeded = true - } + if state, ok, err := OptionalParamOK[string](args, "state"); err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } else if ok { + update.State = github.Ptr(state) + restUpdateNeeded = true + } - // Handle reviewers separately - reviewers, err := OptionalStringArrayParam(args, "reviewers") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + if base, ok, err := OptionalParamOK[string](args, "base"); err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } else if ok { + update.Base = &github.PullRequestBranch{Ref: github.Ptr(base)} + restUpdateNeeded = true + } - // If no updates, no draft change, and no reviewers, return error early - if !restUpdateNeeded && !draftProvided && len(reviewers) == 0 { - return utils.NewToolResultError("No update parameters provided."), nil, nil - } + if maintainerCanModify, ok, err := OptionalParamOK[bool](args, "maintainer_can_modify"); err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } else if ok { + update.MaintainerCanModify = github.Ptr(maintainerCanModify) + restUpdateNeeded = true + } - // Handle REST API updates (title, body, state, base, maintainer_can_modify) - if restUpdateNeeded { - client, err := getClient(ctx) + // Handle reviewers separately + reviewers, err := OptionalStringArrayParam(args, "reviewers") if err != nil { - return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil + return utils.NewToolResultError(err.Error()), nil, nil } - _, resp, err := client.PullRequests.Edit(ctx, owner, repo, pullNumber, update) - if err != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, - "failed to update pull request", - resp, - err, - ), nil, nil + // If no updates, no draft change, and no reviewers, return error early + if !restUpdateNeeded && !draftProvided && len(reviewers) == 0 { + return utils.NewToolResultError("No update parameters provided."), nil, nil } - defer func() { _ = resp.Body.Close() }() - if resp.StatusCode != http.StatusOK { - bodyBytes, err := io.ReadAll(resp.Body) + // Handle REST API updates (title, body, state, base, maintainer_can_modify) + if restUpdateNeeded { + client, err := deps.GetClient(ctx) if err != nil { - return utils.NewToolResultErrorFromErr("failed to read response body", err), nil, nil + return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil } - return utils.NewToolResultError(fmt.Sprintf("failed to update pull request: %s", string(bodyBytes))), nil, nil - } - } - // Handle draft status changes using GraphQL - if draftProvided { - gqlClient, err := getGQLClient(ctx) - if err != nil { - return utils.NewToolResultErrorFromErr("failed to get GitHub GraphQL client", err), nil, nil - } + _, resp, err := client.PullRequests.Edit(ctx, owner, repo, pullNumber, update) + if err != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, + "failed to update pull request", + resp, + err, + ), nil, nil + } + defer func() { _ = resp.Body.Close() }() - var prQuery struct { - Repository struct { - PullRequest struct { - ID githubv4.ID - IsDraft githubv4.Boolean - } `graphql:"pullRequest(number: $prNum)"` - } `graphql:"repository(owner: $owner, name: $repo)"` + if resp.StatusCode != http.StatusOK { + bodyBytes, err := io.ReadAll(resp.Body) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to read response body", err), nil, nil + } + return utils.NewToolResultError(fmt.Sprintf("failed to update pull request: %s", string(bodyBytes))), nil, nil + } } - err = gqlClient.Query(ctx, &prQuery, map[string]interface{}{ - "owner": githubv4.String(owner), - "repo": githubv4.String(repo), - "prNum": githubv4.Int(pullNumber), // #nosec G115 - pull request numbers are always small positive integers - }) - if err != nil { - return ghErrors.NewGitHubGraphQLErrorResponse(ctx, "Failed to find pull request", err), nil, nil - } + // Handle draft status changes using GraphQL + if draftProvided { + gqlClient, err := deps.GetGQLClient(ctx) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to get GitHub GraphQL client", err), nil, nil + } - currentIsDraft := bool(prQuery.Repository.PullRequest.IsDraft) + var prQuery struct { + Repository struct { + PullRequest struct { + ID githubv4.ID + IsDraft githubv4.Boolean + } `graphql:"pullRequest(number: $prNum)"` + } `graphql:"repository(owner: $owner, name: $repo)"` + } - if currentIsDraft != draftValue { - if draftValue { - // Convert to draft - var mutation struct { - ConvertPullRequestToDraft struct { - PullRequest struct { - ID githubv4.ID - IsDraft githubv4.Boolean - } - } `graphql:"convertPullRequestToDraft(input: $input)"` - } + err = gqlClient.Query(ctx, &prQuery, map[string]interface{}{ + "owner": githubv4.String(owner), + "repo": githubv4.String(repo), + "prNum": githubv4.Int(pullNumber), // #nosec G115 - pull request numbers are always small positive integers + }) + if err != nil { + return ghErrors.NewGitHubGraphQLErrorResponse(ctx, "Failed to find pull request", err), nil, nil + } - err = gqlClient.Mutate(ctx, &mutation, githubv4.ConvertPullRequestToDraftInput{ - PullRequestID: prQuery.Repository.PullRequest.ID, - }, nil) - if err != nil { - return ghErrors.NewGitHubGraphQLErrorResponse(ctx, "Failed to convert pull request to draft", err), nil, nil + currentIsDraft := bool(prQuery.Repository.PullRequest.IsDraft) + + if currentIsDraft != draftValue { + if draftValue { + // Convert to draft + var mutation struct { + ConvertPullRequestToDraft struct { + PullRequest struct { + ID githubv4.ID + IsDraft githubv4.Boolean + } + } `graphql:"convertPullRequestToDraft(input: $input)"` + } + + err = gqlClient.Mutate(ctx, &mutation, githubv4.ConvertPullRequestToDraftInput{ + PullRequestID: prQuery.Repository.PullRequest.ID, + }, nil) + if err != nil { + return ghErrors.NewGitHubGraphQLErrorResponse(ctx, "Failed to convert pull request to draft", err), nil, nil + } + } else { + // Mark as ready for review + var mutation struct { + MarkPullRequestReadyForReview struct { + PullRequest struct { + ID githubv4.ID + IsDraft githubv4.Boolean + } + } `graphql:"markPullRequestReadyForReview(input: $input)"` + } + + err = gqlClient.Mutate(ctx, &mutation, githubv4.MarkPullRequestReadyForReviewInput{ + PullRequestID: prQuery.Repository.PullRequest.ID, + }, nil) + if err != nil { + return ghErrors.NewGitHubGraphQLErrorResponse(ctx, "Failed to mark pull request ready for review", err), nil, nil + } } - } else { - // Mark as ready for review - var mutation struct { - MarkPullRequestReadyForReview struct { - PullRequest struct { - ID githubv4.ID - IsDraft githubv4.Boolean - } - } `graphql:"markPullRequestReadyForReview(input: $input)"` + } + } + + // Handle reviewer requests + if len(reviewers) > 0 { + client, err := deps.GetClient(ctx) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil + } + + reviewersRequest := github.ReviewersRequest{ + Reviewers: reviewers, + } + + _, resp, err := client.PullRequests.RequestReviewers(ctx, owner, repo, pullNumber, reviewersRequest) + if err != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, + "failed to request reviewers", + resp, + err, + ), nil, nil + } + defer func() { + if resp != nil && resp.Body != nil { + _ = resp.Body.Close() } + }() - err = gqlClient.Mutate(ctx, &mutation, githubv4.MarkPullRequestReadyForReviewInput{ - PullRequestID: prQuery.Repository.PullRequest.ID, - }, nil) + if resp.StatusCode != http.StatusCreated && resp.StatusCode != http.StatusOK { + bodyBytes, err := io.ReadAll(resp.Body) if err != nil { - return ghErrors.NewGitHubGraphQLErrorResponse(ctx, "Failed to mark pull request ready for review", err), nil, nil + return utils.NewToolResultErrorFromErr("failed to read response body", err), nil, nil } + return utils.NewToolResultError(fmt.Sprintf("failed to request reviewers: %s", string(bodyBytes))), nil, nil } } - } - // Handle reviewer requests - if len(reviewers) > 0 { - client, err := getClient(ctx) + // Get the final state of the PR to return + client, err := deps.GetClient(ctx) if err != nil { return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil } - reviewersRequest := github.ReviewersRequest{ - Reviewers: reviewers, - } - - _, resp, err := client.PullRequests.RequestReviewers(ctx, owner, repo, pullNumber, reviewersRequest) + finalPR, resp, err := client.PullRequests.Get(ctx, owner, repo, pullNumber) if err != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, - "failed to request reviewers", - resp, - err, - ), nil, nil + return ghErrors.NewGitHubAPIErrorResponse(ctx, "Failed to get pull request", resp, err), nil, nil } defer func() { if resp != nil && resp.Body != nil { @@ -772,48 +806,24 @@ func UpdatePullRequest(getClient GetClientFn, getGQLClient GetGQLClientFn, t tra } }() - if resp.StatusCode != http.StatusCreated && resp.StatusCode != http.StatusOK { - bodyBytes, err := io.ReadAll(resp.Body) - if err != nil { - return utils.NewToolResultErrorFromErr("failed to read response body", err), nil, nil - } - return utils.NewToolResultError(fmt.Sprintf("failed to request reviewers: %s", string(bodyBytes))), nil, nil + // Return minimal response with just essential information + minimalResponse := MinimalResponse{ + ID: fmt.Sprintf("%d", finalPR.GetID()), + URL: finalPR.GetHTMLURL(), } - } - // Get the final state of the PR to return - client, err := getClient(ctx) - if err != nil { - return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil - } - - finalPR, resp, err := client.PullRequests.Get(ctx, owner, repo, pullNumber) - if err != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, "Failed to get pull request", resp, err), nil, nil - } - defer func() { - if resp != nil && resp.Body != nil { - _ = resp.Body.Close() + r, err := json.Marshal(minimalResponse) + if err != nil { + return utils.NewToolResultErrorFromErr("Failed to marshal response", err), nil, nil } - }() - // Return minimal response with just essential information - minimalResponse := MinimalResponse{ - ID: fmt.Sprintf("%d", finalPR.GetID()), - URL: finalPR.GetHTMLURL(), + return utils.NewToolResultText(string(r)), nil, nil } - - r, err := json.Marshal(minimalResponse) - if err != nil { - return utils.NewToolResultErrorFromErr("Failed to marshal response", err), nil, nil - } - - return utils.NewToolResultText(string(r)), nil, nil - } + }) } // ListPullRequests creates a tool to list and filter repository pull requests. -func ListPullRequests(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { +func ListPullRequests(t translations.TranslationHelperFunc) toolsets.ServerTool { schema := &jsonschema.Schema{ Type: "object", Properties: map[string]*jsonschema.Schema{ @@ -853,7 +863,8 @@ func ListPullRequests(getClient GetClientFn, t translations.TranslationHelperFun } WithPagination(schema) - return mcp.Tool{ + return NewTool( + mcp.Tool{ Name: "list_pull_requests", Description: t("TOOL_LIST_PULL_REQUESTS_DESCRIPTION", "List pull requests in a GitHub repository. If the user specifies an author, then DO NOT use this tool and use the search_pull_requests tool instead."), Annotations: &mcp.ToolAnnotations{ @@ -862,98 +873,100 @@ func ListPullRequests(getClient GetClientFn, t translations.TranslationHelperFun }, InputSchema: schema, }, - func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - owner, err := RequiredParam[string](args, "owner") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - repo, err := RequiredParam[string](args, "repo") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - state, err := OptionalParam[string](args, "state") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - head, err := OptionalParam[string](args, "head") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - base, err := OptionalParam[string](args, "base") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - sort, err := OptionalParam[string](args, "sort") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - direction, err := OptionalParam[string](args, "direction") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - pagination, err := OptionalPaginationParams(args) - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - - opts := &github.PullRequestListOptions{ - State: state, - Head: head, - Base: base, - Sort: sort, - Direction: direction, - ListOptions: github.ListOptions{ - PerPage: pagination.PerPage, - Page: pagination.Page, - }, - } + func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { + return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + owner, err := RequiredParam[string](args, "owner") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + repo, err := RequiredParam[string](args, "repo") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + state, err := OptionalParam[string](args, "state") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + head, err := OptionalParam[string](args, "head") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + base, err := OptionalParam[string](args, "base") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + sort, err := OptionalParam[string](args, "sort") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + direction, err := OptionalParam[string](args, "direction") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + pagination, err := OptionalPaginationParams(args) + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - client, err := getClient(ctx) - if err != nil { - return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil - } - prs, resp, err := client.PullRequests.List(ctx, owner, repo, opts) - if err != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, - "failed to list pull requests", - resp, - err, - ), nil, nil - } - defer func() { _ = resp.Body.Close() }() + opts := &github.PullRequestListOptions{ + State: state, + Head: head, + Base: base, + Sort: sort, + Direction: direction, + ListOptions: github.ListOptions{ + PerPage: pagination.PerPage, + Page: pagination.Page, + }, + } - if resp.StatusCode != http.StatusOK { - bodyBytes, err := io.ReadAll(resp.Body) + client, err := deps.GetClient(ctx) if err != nil { - return utils.NewToolResultErrorFromErr("failed to read response body", err), nil, nil + return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil } - return utils.NewToolResultError(fmt.Sprintf("failed to list pull requests: %s", string(bodyBytes))), nil, nil - } + prs, resp, err := client.PullRequests.List(ctx, owner, repo, opts) + if err != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, + "failed to list pull requests", + resp, + err, + ), nil, nil + } + defer func() { _ = resp.Body.Close() }() - // sanitize title/body on each PR - for _, pr := range prs { - if pr == nil { - continue + if resp.StatusCode != http.StatusOK { + bodyBytes, err := io.ReadAll(resp.Body) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to read response body", err), nil, nil + } + return utils.NewToolResultError(fmt.Sprintf("failed to list pull requests: %s", string(bodyBytes))), nil, nil } - if pr.Title != nil { - pr.Title = github.Ptr(sanitize.Sanitize(*pr.Title)) + + // sanitize title/body on each PR + for _, pr := range prs { + if pr == nil { + continue + } + if pr.Title != nil { + pr.Title = github.Ptr(sanitize.Sanitize(*pr.Title)) + } + if pr.Body != nil { + pr.Body = github.Ptr(sanitize.Sanitize(*pr.Body)) + } } - if pr.Body != nil { - pr.Body = github.Ptr(sanitize.Sanitize(*pr.Body)) + + r, err := json.Marshal(prs) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to marshal response", err), nil, nil } - } - r, err := json.Marshal(prs) - if err != nil { - return utils.NewToolResultErrorFromErr("failed to marshal response", err), nil, nil + return utils.NewToolResultText(string(r)), nil, nil } - - return utils.NewToolResultText(string(r)), nil, nil - } + }) } // MergePullRequest creates a tool to merge a pull request. -func MergePullRequest(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { +func MergePullRequest(t translations.TranslationHelperFunc) toolsets.ServerTool { schema := &jsonschema.Schema{ Type: "object", Properties: map[string]*jsonschema.Schema{ @@ -986,7 +999,8 @@ func MergePullRequest(getClient GetClientFn, t translations.TranslationHelperFun Required: []string{"owner", "repo", "pullNumber"}, } - return mcp.Tool{ + return NewTool( + mcp.Tool{ Name: "merge_pull_request", Description: t("TOOL_MERGE_PULL_REQUEST_DESCRIPTION", "Merge a pull request in a GitHub repository."), Annotations: &mcp.ToolAnnotations{ @@ -995,70 +1009,72 @@ func MergePullRequest(getClient GetClientFn, t translations.TranslationHelperFun }, InputSchema: schema, }, - func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - owner, err := RequiredParam[string](args, "owner") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - repo, err := RequiredParam[string](args, "repo") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - pullNumber, err := RequiredInt(args, "pullNumber") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - commitTitle, err := OptionalParam[string](args, "commit_title") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - commitMessage, err := OptionalParam[string](args, "commit_message") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - mergeMethod, err := OptionalParam[string](args, "merge_method") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { + return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + owner, err := RequiredParam[string](args, "owner") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + repo, err := RequiredParam[string](args, "repo") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + pullNumber, err := RequiredInt(args, "pullNumber") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + commitTitle, err := OptionalParam[string](args, "commit_title") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + commitMessage, err := OptionalParam[string](args, "commit_message") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + mergeMethod, err := OptionalParam[string](args, "merge_method") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - options := &github.PullRequestOptions{ - CommitTitle: commitTitle, - MergeMethod: mergeMethod, - } + options := &github.PullRequestOptions{ + CommitTitle: commitTitle, + MergeMethod: mergeMethod, + } - client, err := getClient(ctx) - if err != nil { - return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil - } - result, resp, err := client.PullRequests.Merge(ctx, owner, repo, pullNumber, commitMessage, options) - if err != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, - "failed to merge pull request", - resp, - err, - ), nil, nil - } - defer func() { _ = resp.Body.Close() }() + client, err := deps.GetClient(ctx) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil + } + result, resp, err := client.PullRequests.Merge(ctx, owner, repo, pullNumber, commitMessage, options) + if err != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, + "failed to merge pull request", + resp, + err, + ), nil, nil + } + defer func() { _ = resp.Body.Close() }() - if resp.StatusCode != http.StatusOK { - bodyBytes, err := io.ReadAll(resp.Body) + if resp.StatusCode != http.StatusOK { + bodyBytes, err := io.ReadAll(resp.Body) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to read response body", err), nil, nil + } + return utils.NewToolResultError(fmt.Sprintf("failed to merge pull request: %s", string(bodyBytes))), nil, nil + } + + r, err := json.Marshal(result) if err != nil { - return utils.NewToolResultErrorFromErr("failed to read response body", err), nil, nil + return utils.NewToolResultErrorFromErr("failed to marshal response", err), nil, nil } - return utils.NewToolResultError(fmt.Sprintf("failed to merge pull request: %s", string(bodyBytes))), nil, nil - } - r, err := json.Marshal(result) - if err != nil { - return utils.NewToolResultErrorFromErr("failed to marshal response", err), nil, nil + return utils.NewToolResultText(string(r)), nil, nil } - - return utils.NewToolResultText(string(r)), nil, nil - } + }) } // SearchPullRequests creates a tool to search for pull requests. -func SearchPullRequests(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { +func SearchPullRequests(t translations.TranslationHelperFunc) toolsets.ServerTool { schema := &jsonschema.Schema{ Type: "object", Properties: map[string]*jsonschema.Schema{ @@ -1101,7 +1117,8 @@ func SearchPullRequests(getClient GetClientFn, t translations.TranslationHelperF } WithPagination(schema) - return mcp.Tool{ + return NewTool( + mcp.Tool{ Name: "search_pull_requests", Description: t("TOOL_SEARCH_PULL_REQUESTS_DESCRIPTION", "Search for pull requests in GitHub repositories using issues search syntax already scoped to is:pr"), Annotations: &mcp.ToolAnnotations{ @@ -1110,14 +1127,16 @@ func SearchPullRequests(getClient GetClientFn, t translations.TranslationHelperF }, InputSchema: schema, }, - func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - result, err := searchHandler(ctx, getClient, args, "pr", "failed to search pull requests") - return result, nil, err - } + func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { + return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + result, err := searchHandler(ctx, deps.GetClient, args, "pr", "failed to search pull requests") + return result, nil, err + } + }) } // UpdatePullRequestBranch creates a tool to update a pull request branch with the latest changes from the base branch. -func UpdatePullRequestBranch(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { +func UpdatePullRequestBranch(t translations.TranslationHelperFunc) toolsets.ServerTool { schema := &jsonschema.Schema{ Type: "object", Properties: map[string]*jsonschema.Schema{ @@ -1141,7 +1160,8 @@ func UpdatePullRequestBranch(getClient GetClientFn, t translations.TranslationHe Required: []string{"owner", "repo", "pullNumber"}, } - return mcp.Tool{ + return NewTool( + mcp.Tool{ Name: "update_pull_request_branch", Description: t("TOOL_UPDATE_PULL_REQUEST_BRANCH_DESCRIPTION", "Update the branch of a pull request with the latest changes from the base branch."), Annotations: &mcp.ToolAnnotations{ @@ -1150,62 +1170,64 @@ func UpdatePullRequestBranch(getClient GetClientFn, t translations.TranslationHe }, InputSchema: schema, }, - func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - owner, err := RequiredParam[string](args, "owner") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - repo, err := RequiredParam[string](args, "repo") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - pullNumber, err := RequiredInt(args, "pullNumber") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - expectedHeadSHA, err := OptionalParam[string](args, "expectedHeadSha") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - opts := &github.PullRequestBranchUpdateOptions{} - if expectedHeadSHA != "" { - opts.ExpectedHeadSHA = github.Ptr(expectedHeadSHA) - } + func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { + return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + owner, err := RequiredParam[string](args, "owner") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + repo, err := RequiredParam[string](args, "repo") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + pullNumber, err := RequiredInt(args, "pullNumber") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + expectedHeadSHA, err := OptionalParam[string](args, "expectedHeadSha") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + opts := &github.PullRequestBranchUpdateOptions{} + if expectedHeadSHA != "" { + opts.ExpectedHeadSHA = github.Ptr(expectedHeadSHA) + } - client, err := getClient(ctx) - if err != nil { - return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil - } - result, resp, err := client.PullRequests.UpdateBranch(ctx, owner, repo, pullNumber, opts) - if err != nil { - // Check if it's an acceptedError. An acceptedError indicates that the update is in progress, - // and it's not a real error. - if resp != nil && resp.StatusCode == http.StatusAccepted && isAcceptedError(err) { - return utils.NewToolResultText("Pull request branch update is in progress"), nil, nil - } - return ghErrors.NewGitHubAPIErrorResponse(ctx, - "failed to update pull request branch", - resp, - err, - ), nil, nil - } - defer func() { _ = resp.Body.Close() }() + client, err := deps.GetClient(ctx) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil + } + result, resp, err := client.PullRequests.UpdateBranch(ctx, owner, repo, pullNumber, opts) + if err != nil { + // Check if it's an acceptedError. An acceptedError indicates that the update is in progress, + // and it's not a real error. + if resp != nil && resp.StatusCode == http.StatusAccepted && isAcceptedError(err) { + return utils.NewToolResultText("Pull request branch update is in progress"), nil, nil + } + return ghErrors.NewGitHubAPIErrorResponse(ctx, + "failed to update pull request branch", + resp, + err, + ), nil, nil + } + defer func() { _ = resp.Body.Close() }() - if resp.StatusCode != http.StatusAccepted { - bodyBytes, err := io.ReadAll(resp.Body) + if resp.StatusCode != http.StatusAccepted { + bodyBytes, err := io.ReadAll(resp.Body) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to read response body", err), nil, nil + } + return utils.NewToolResultError(fmt.Sprintf("failed to update pull request branch: %s", string(bodyBytes))), nil, nil + } + + r, err := json.Marshal(result) if err != nil { - return utils.NewToolResultErrorFromErr("failed to read response body", err), nil, nil + return utils.NewToolResultErrorFromErr("failed to marshal response", err), nil, nil } - return utils.NewToolResultError(fmt.Sprintf("failed to update pull request branch: %s", string(bodyBytes))), nil, nil - } - r, err := json.Marshal(result) - if err != nil { - return utils.NewToolResultErrorFromErr("failed to marshal response", err), nil, nil + return utils.NewToolResultText(string(r)), nil, nil } - - return utils.NewToolResultText(string(r)), nil, nil - } + }) } type PullRequestReviewWriteParams struct { @@ -1218,7 +1240,7 @@ type PullRequestReviewWriteParams struct { CommitID *string } -func PullRequestReviewWrite(getGQLClient GetGQLClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { +func PullRequestReviewWrite(t translations.TranslationHelperFunc) toolsets.ServerTool { schema := &jsonschema.Schema{ Type: "object", Properties: map[string]*jsonschema.Schema{ @@ -1259,7 +1281,8 @@ func PullRequestReviewWrite(getGQLClient GetGQLClientFn, t translations.Translat Required: []string{"method", "owner", "repo", "pullNumber"}, } - return mcp.Tool{ + return NewTool( + mcp.Tool{ Name: "pull_request_review_write", Description: t("TOOL_PULL_REQUEST_REVIEW_WRITE_DESCRIPTION", `Create and/or submit, delete review of a pull request. @@ -1274,32 +1297,34 @@ Available methods: }, InputSchema: schema, }, - func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - var params PullRequestReviewWriteParams - if err := mapstructure.Decode(args, ¶ms); err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { + return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + var params PullRequestReviewWriteParams + if err := mapstructure.Decode(args, ¶ms); err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - // Given our owner, repo and PR number, lookup the GQL ID of the PR. - client, err := getGQLClient(ctx) - if err != nil { - return utils.NewToolResultError(fmt.Sprintf("failed to get GitHub GQL client: %v", err)), nil, nil - } + // Given our owner, repo and PR number, lookup the GQL ID of the PR. + client, err := deps.GetGQLClient(ctx) + if err != nil { + return utils.NewToolResultError(fmt.Sprintf("failed to get GitHub GQL client: %v", err)), nil, nil + } - switch params.Method { - case "create": - result, err := CreatePullRequestReview(ctx, client, params) - return result, nil, err - case "submit_pending": - result, err := SubmitPendingPullRequestReview(ctx, client, params) - return result, nil, err - case "delete_pending": - result, err := DeletePendingPullRequestReview(ctx, client, params) - return result, nil, err - default: - return utils.NewToolResultError(fmt.Sprintf("unknown method: %s", params.Method)), nil, nil + switch params.Method { + case "create": + result, err := CreatePullRequestReview(ctx, client, params) + return result, nil, err + case "submit_pending": + result, err := SubmitPendingPullRequestReview(ctx, client, params) + return result, nil, err + case "delete_pending": + result, err := DeletePendingPullRequestReview(ctx, client, params) + return result, nil, err + default: + return utils.NewToolResultError(fmt.Sprintf("unknown method: %s", params.Method)), nil, nil + } } - } + }) } func CreatePullRequestReview(ctx context.Context, client *githubv4.Client, params PullRequestReviewWriteParams) (*mcp.CallToolResult, error) { @@ -1526,7 +1551,7 @@ func DeletePendingPullRequestReview(ctx context.Context, client *githubv4.Client } // AddCommentToPendingReview creates a tool to add a comment to a pull request review. -func AddCommentToPendingReview(getGQLClient GetGQLClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { +func AddCommentToPendingReview(t translations.TranslationHelperFunc) toolsets.ServerTool { schema := &jsonschema.Schema{ Type: "object", Properties: map[string]*jsonschema.Schema{ @@ -1586,7 +1611,8 @@ func AddCommentToPendingReview(getGQLClient GetGQLClientFn, t translations.Trans Required: []string{"owner", "repo", "pullNumber", "path", "body", "subjectType"}, } - return mcp.Tool{ + return NewTool( + mcp.Tool{ Name: "add_comment_to_pending_review", Description: t("TOOL_ADD_COMMENT_TO_PENDING_REVIEW_DESCRIPTION", "Add review comment to the requester's latest pending pull request review. A pending review needs to already exist to call this (check with the user if not sure)."), Annotations: &mcp.ToolAnnotations{ @@ -1595,127 +1621,129 @@ func AddCommentToPendingReview(getGQLClient GetGQLClientFn, t translations.Trans }, InputSchema: schema, }, - func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - var params struct { - Owner string - Repo string - PullNumber int32 - Path string - Body string - SubjectType string - Line *int32 - Side *string - StartLine *int32 - StartSide *string - } - if err := mapstructure.Decode(args, ¶ms); err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { + return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + var params struct { + Owner string + Repo string + PullNumber int32 + Path string + Body string + SubjectType string + Line *int32 + Side *string + StartLine *int32 + StartSide *string + } + if err := mapstructure.Decode(args, ¶ms); err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - client, err := getGQLClient(ctx) - if err != nil { - return utils.NewToolResultErrorFromErr("failed to get GitHub GQL client", err), nil, nil - } + client, err := deps.GetGQLClient(ctx) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to get GitHub GQL client", err), nil, nil + } - // First we'll get the current user - var getViewerQuery struct { - Viewer struct { - Login githubv4.String + // First we'll get the current user + var getViewerQuery struct { + Viewer struct { + Login githubv4.String + } } - } - if err := client.Query(ctx, &getViewerQuery, nil); err != nil { - return ghErrors.NewGitHubGraphQLErrorResponse(ctx, - "failed to get current user", - err, - ), nil, nil - } + if err := client.Query(ctx, &getViewerQuery, nil); err != nil { + return ghErrors.NewGitHubGraphQLErrorResponse(ctx, + "failed to get current user", + err, + ), nil, nil + } - var getLatestReviewForViewerQuery struct { - Repository struct { - PullRequest struct { - Reviews struct { - Nodes []struct { - ID githubv4.ID - State githubv4.PullRequestReviewState - URL githubv4.URI - } - } `graphql:"reviews(first: 1, author: $author)"` - } `graphql:"pullRequest(number: $prNum)"` - } `graphql:"repository(owner: $owner, name: $name)"` - } + var getLatestReviewForViewerQuery struct { + Repository struct { + PullRequest struct { + Reviews struct { + Nodes []struct { + ID githubv4.ID + State githubv4.PullRequestReviewState + URL githubv4.URI + } + } `graphql:"reviews(first: 1, author: $author)"` + } `graphql:"pullRequest(number: $prNum)"` + } `graphql:"repository(owner: $owner, name: $name)"` + } - vars := map[string]any{ - "author": githubv4.String(getViewerQuery.Viewer.Login), - "owner": githubv4.String(params.Owner), - "name": githubv4.String(params.Repo), - "prNum": githubv4.Int(params.PullNumber), - } + vars := map[string]any{ + "author": githubv4.String(getViewerQuery.Viewer.Login), + "owner": githubv4.String(params.Owner), + "name": githubv4.String(params.Repo), + "prNum": githubv4.Int(params.PullNumber), + } - if err := client.Query(context.Background(), &getLatestReviewForViewerQuery, vars); err != nil { - return ghErrors.NewGitHubGraphQLErrorResponse(ctx, - "failed to get latest review for current user", - err, - ), nil, nil - } + if err := client.Query(context.Background(), &getLatestReviewForViewerQuery, vars); err != nil { + return ghErrors.NewGitHubGraphQLErrorResponse(ctx, + "failed to get latest review for current user", + err, + ), nil, nil + } - // Validate there is one review and the state is pending - if len(getLatestReviewForViewerQuery.Repository.PullRequest.Reviews.Nodes) == 0 { - return utils.NewToolResultError("No pending review found for the viewer"), nil, nil - } + // Validate there is one review and the state is pending + if len(getLatestReviewForViewerQuery.Repository.PullRequest.Reviews.Nodes) == 0 { + return utils.NewToolResultError("No pending review found for the viewer"), nil, nil + } - review := getLatestReviewForViewerQuery.Repository.PullRequest.Reviews.Nodes[0] - if review.State != githubv4.PullRequestReviewStatePending { - errText := fmt.Sprintf("The latest review, found at %s is not pending", review.URL) - return utils.NewToolResultError(errText), nil, nil - } + review := getLatestReviewForViewerQuery.Repository.PullRequest.Reviews.Nodes[0] + if review.State != githubv4.PullRequestReviewStatePending { + errText := fmt.Sprintf("The latest review, found at %s is not pending", review.URL) + return utils.NewToolResultError(errText), nil, nil + } - // Then we can create a new review thread comment on the review. - var addPullRequestReviewThreadMutation struct { - AddPullRequestReviewThread struct { - Thread struct { - ID githubv4.ID // We don't need this, but a selector is required or GQL complains. - } - } `graphql:"addPullRequestReviewThread(input: $input)"` - } + // Then we can create a new review thread comment on the review. + var addPullRequestReviewThreadMutation struct { + AddPullRequestReviewThread struct { + Thread struct { + ID githubv4.ID // We don't need this, but a selector is required or GQL complains. + } + } `graphql:"addPullRequestReviewThread(input: $input)"` + } - if err := client.Mutate( - ctx, - &addPullRequestReviewThreadMutation, - githubv4.AddPullRequestReviewThreadInput{ - Path: githubv4.String(params.Path), - Body: githubv4.String(params.Body), - SubjectType: newGQLStringlikePtr[githubv4.PullRequestReviewThreadSubjectType](¶ms.SubjectType), - Line: newGQLIntPtr(params.Line), - Side: newGQLStringlikePtr[githubv4.DiffSide](params.Side), - StartLine: newGQLIntPtr(params.StartLine), - StartSide: newGQLStringlikePtr[githubv4.DiffSide](params.StartSide), - PullRequestReviewID: &review.ID, - }, - nil, - ); err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + if err := client.Mutate( + ctx, + &addPullRequestReviewThreadMutation, + githubv4.AddPullRequestReviewThreadInput{ + Path: githubv4.String(params.Path), + Body: githubv4.String(params.Body), + SubjectType: newGQLStringlikePtr[githubv4.PullRequestReviewThreadSubjectType](¶ms.SubjectType), + Line: newGQLIntPtr(params.Line), + Side: newGQLStringlikePtr[githubv4.DiffSide](params.Side), + StartLine: newGQLIntPtr(params.StartLine), + StartSide: newGQLStringlikePtr[githubv4.DiffSide](params.StartSide), + PullRequestReviewID: &review.ID, + }, + nil, + ); err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - if addPullRequestReviewThreadMutation.AddPullRequestReviewThread.Thread.ID == nil { - return utils.NewToolResultError(`Failed to add comment to pending review. Possible reasons: + if addPullRequestReviewThreadMutation.AddPullRequestReviewThread.Thread.ID == nil { + return utils.NewToolResultError(`Failed to add comment to pending review. Possible reasons: - The line number doesn't exist in the pull request diff - The file path is incorrect - The side (LEFT/RIGHT) is invalid for the specified line `), nil, nil - } + } - // Return nothing interesting, just indicate success for the time being. - // In future, we may want to return the review ID, but for the moment, we're not leaking - // API implementation details to the LLM. - return utils.NewToolResultText("pull request review comment successfully added to pending review"), nil, nil - } + // Return nothing interesting, just indicate success for the time being. + // In future, we may want to return the review ID, but for the moment, we're not leaking + // API implementation details to the LLM. + return utils.NewToolResultText("pull request review comment successfully added to pending review"), nil, nil + } + }) } // RequestCopilotReview creates a tool to request a Copilot review for a pull request. // Note that this tool will not work on GHES where this feature is unsupported. In future, we should not expose this // tool if the configured host does not support it. -func RequestCopilotReview(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { +func RequestCopilotReview(t translations.TranslationHelperFunc) toolsets.ServerTool { schema := &jsonschema.Schema{ Type: "object", Properties: map[string]*jsonschema.Schema{ @@ -1735,7 +1763,8 @@ func RequestCopilotReview(getClient GetClientFn, t translations.TranslationHelpe Required: []string{"owner", "repo", "pullNumber"}, } - return mcp.Tool{ + return NewTool( + mcp.Tool{ Name: "request_copilot_review", Description: t("TOOL_REQUEST_COPILOT_REVIEW_DESCRIPTION", "Request a GitHub Copilot code review for a pull request. Use this for automated feedback on pull requests, usually before requesting a human reviewer."), Annotations: &mcp.ToolAnnotations{ @@ -1744,57 +1773,59 @@ func RequestCopilotReview(getClient GetClientFn, t translations.TranslationHelpe }, InputSchema: schema, }, - func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - owner, err := RequiredParam[string](args, "owner") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - - repo, err := RequiredParam[string](args, "repo") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { + return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + owner, err := RequiredParam[string](args, "owner") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - pullNumber, err := RequiredInt(args, "pullNumber") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + repo, err := RequiredParam[string](args, "repo") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - client, err := getClient(ctx) - if err != nil { - return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil - } + pullNumber, err := RequiredInt(args, "pullNumber") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - _, resp, err := client.PullRequests.RequestReviewers( - ctx, - owner, - repo, - pullNumber, - github.ReviewersRequest{ - // The login name of the copilot reviewer bot - Reviewers: []string{"copilot-pull-request-reviewer[bot]"}, - }, - ) - if err != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, - "failed to request copilot review", - resp, - err, - ), nil, nil - } - defer func() { _ = resp.Body.Close() }() + client, err := deps.GetClient(ctx) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil + } - if resp.StatusCode != http.StatusCreated { - bodyBytes, err := io.ReadAll(resp.Body) + _, resp, err := client.PullRequests.RequestReviewers( + ctx, + owner, + repo, + pullNumber, + github.ReviewersRequest{ + // The login name of the copilot reviewer bot + Reviewers: []string{"copilot-pull-request-reviewer[bot]"}, + }, + ) if err != nil { - return utils.NewToolResultErrorFromErr("failed to read response body", err), nil, nil + return ghErrors.NewGitHubAPIErrorResponse(ctx, + "failed to request copilot review", + resp, + err, + ), nil, nil } - return utils.NewToolResultError(fmt.Sprintf("failed to request copilot review: %s", string(bodyBytes))), nil, nil - } + defer func() { _ = resp.Body.Close() }() - // Return nothing on success, as there's not much value in returning the Pull Request itself - return utils.NewToolResultText(""), nil, nil - } + if resp.StatusCode != http.StatusCreated { + bodyBytes, err := io.ReadAll(resp.Body) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to read response body", err), nil, nil + } + return utils.NewToolResultError(fmt.Sprintf("failed to request copilot review: %s", string(bodyBytes))), nil, nil + } + + // Return nothing on success, as there's not much value in returning the Pull Request itself + return utils.NewToolResultText(""), nil, nil + } + }) } // newGQLString like takes something that approximates a string (of which there are many types in shurcooL/githubv4) diff --git a/pkg/github/pullrequests_test.go b/pkg/github/pullrequests_test.go index 94313d4e3..7531edf6d 100644 --- a/pkg/github/pullrequests_test.go +++ b/pkg/github/pullrequests_test.go @@ -21,8 +21,8 @@ import ( func Test_GetPullRequest(t *testing.T) { // Verify tool definition once - mockClient := github.NewClient(nil) - tool, _ := PullRequestRead(stubGetClientFn(mockClient), stubRepoAccessCache(githubv4.NewClient(githubv4mock.NewMockedHTTPClient()), 5*time.Minute), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) + serverTool := PullRequestRead(translations.NullTranslationHelper) + tool := serverTool.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "pull_request_read", tool.Name) @@ -104,13 +104,20 @@ func Test_GetPullRequest(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := PullRequestRead(stubGetClientFn(client), stubRepoAccessCache(githubv4.NewClient(githubv4mock.NewMockedHTTPClient()), 5*time.Minute), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) + gqlClient := githubv4.NewClient(githubv4mock.NewMockedHTTPClient()) + deps := ToolDependencies{ + GetClient: stubGetClientFn(client), + GetGQLClient: stubGetGQLClientFn(gqlClient), + RepoAccessCache: stubRepoAccessCache(gqlClient, 5*time.Minute), + Flags: stubFeatureFlags(map[string]bool{"lockdown-mode": false}), + } + handler := serverTool.Handler(deps) // Create call request request := createMCPRequest(tc.requestArgs) // Call handler - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(context.Background(), &request) // Verify results if tc.expectError { @@ -141,8 +148,8 @@ func Test_GetPullRequest(t *testing.T) { func Test_UpdatePullRequest(t *testing.T) { // Verify tool definition once - mockClient := github.NewClient(nil) - tool, _ := UpdatePullRequest(stubGetClientFn(mockClient), stubGetGQLClientFn(githubv4.NewClient(nil)), translations.NullTranslationHelper) + serverTool := UpdatePullRequest(translations.NullTranslationHelper) + tool := serverTool.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "update_pull_request", tool.Name) @@ -362,13 +369,18 @@ func Test_UpdatePullRequest(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := UpdatePullRequest(stubGetClientFn(client), stubGetGQLClientFn(githubv4.NewClient(nil)), translations.NullTranslationHelper) + gqlClient := githubv4.NewClient(nil) + deps := ToolDependencies{ + GetClient: stubGetClientFn(client), + GetGQLClient: stubGetGQLClientFn(gqlClient), + } + handler := serverTool.Handler(deps) // Create call request request := createMCPRequest(tc.requestArgs) // Call handler - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(context.Background(), &request) // Verify results if tc.expectError || tc.expectedErrMsg != "" { @@ -548,11 +560,16 @@ func Test_UpdatePullRequest_Draft(t *testing.T) { )) gqlClient := githubv4.NewClient(tc.mockedClient) - _, handler := UpdatePullRequest(stubGetClientFn(restClient), stubGetGQLClientFn(gqlClient), translations.NullTranslationHelper) + serverTool := UpdatePullRequest(translations.NullTranslationHelper) + deps := ToolDependencies{ + GetClient: stubGetClientFn(restClient), + GetGQLClient: stubGetGQLClientFn(gqlClient), + } + handler := serverTool.Handler(deps) request := createMCPRequest(tc.requestArgs) - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(context.Background(), &request) if tc.expectError || tc.expectedErrMsg != "" { require.NoError(t, err) @@ -580,8 +597,8 @@ func Test_UpdatePullRequest_Draft(t *testing.T) { func Test_ListPullRequests(t *testing.T) { // Verify tool definition once - mockClient := github.NewClient(nil) - tool, _ := ListPullRequests(stubGetClientFn(mockClient), translations.NullTranslationHelper) + serverTool := ListPullRequests(translations.NullTranslationHelper) + tool := serverTool.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "list_pull_requests", tool.Name) @@ -675,13 +692,17 @@ func Test_ListPullRequests(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := ListPullRequests(stubGetClientFn(client), translations.NullTranslationHelper) + serverTool := ListPullRequests(translations.NullTranslationHelper) + deps := ToolDependencies{ + GetClient: stubGetClientFn(client), + } + handler := serverTool.Handler(deps) // Create call request request := createMCPRequest(tc.requestArgs) // Call handler - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(context.Background(), &request) // Verify results if tc.expectError { @@ -715,8 +736,8 @@ func Test_ListPullRequests(t *testing.T) { func Test_MergePullRequest(t *testing.T) { // Verify tool definition once - mockClient := github.NewClient(nil) - tool, _ := MergePullRequest(stubGetClientFn(mockClient), translations.NullTranslationHelper) + serverTool := MergePullRequest(translations.NullTranslationHelper) + tool := serverTool.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "merge_pull_request", tool.Name) @@ -795,13 +816,17 @@ func Test_MergePullRequest(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := MergePullRequest(stubGetClientFn(client), translations.NullTranslationHelper) + serverTool := MergePullRequest(translations.NullTranslationHelper) + deps := ToolDependencies{ + GetClient: stubGetClientFn(client), + } + handler := serverTool.Handler(deps) // Create call request request := createMCPRequest(tc.requestArgs) // Call handler - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(context.Background(), &request) // Verify results if tc.expectError { @@ -830,8 +855,8 @@ func Test_MergePullRequest(t *testing.T) { } func Test_SearchPullRequests(t *testing.T) { - mockClient := github.NewClient(nil) - tool, _ := SearchPullRequests(stubGetClientFn(mockClient), translations.NullTranslationHelper) + serverTool := SearchPullRequests(translations.NullTranslationHelper) + tool := serverTool.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "search_pull_requests", tool.Name) @@ -1097,13 +1122,17 @@ func Test_SearchPullRequests(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := SearchPullRequests(stubGetClientFn(client), translations.NullTranslationHelper) + serverTool := SearchPullRequests(translations.NullTranslationHelper) + deps := ToolDependencies{ + GetClient: stubGetClientFn(client), + } + handler := serverTool.Handler(deps) // Create call request request := createMCPRequest(tc.requestArgs) // Call handler - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(context.Background(), &request) // Verify results if tc.expectError { @@ -1141,8 +1170,8 @@ func Test_SearchPullRequests(t *testing.T) { func Test_GetPullRequestFiles(t *testing.T) { // Verify tool definition once - mockClient := github.NewClient(nil) - tool, _ := PullRequestRead(stubGetClientFn(mockClient), stubRepoAccessCache(githubv4.NewClient(githubv4mock.NewMockedHTTPClient()), 5*time.Minute), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) + serverTool := PullRequestRead(translations.NullTranslationHelper) + tool := serverTool.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "pull_request_read", tool.Name) @@ -1246,13 +1275,19 @@ func Test_GetPullRequestFiles(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := PullRequestRead(stubGetClientFn(client), stubRepoAccessCache(githubv4.NewClient(githubv4mock.NewMockedHTTPClient()), 5*time.Minute), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) + serverTool := PullRequestRead(translations.NullTranslationHelper) + deps := ToolDependencies{ + GetClient: stubGetClientFn(client), + RepoAccessCache: stubRepoAccessCache(githubv4.NewClient(githubv4mock.NewMockedHTTPClient()), 5*time.Minute), + Flags: stubFeatureFlags(map[string]bool{"lockdown-mode": false}), + } + handler := serverTool.Handler(deps) // Create call request request := createMCPRequest(tc.requestArgs) // Call handler - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(context.Background(), &request) // Verify results if tc.expectError { @@ -1286,8 +1321,8 @@ func Test_GetPullRequestFiles(t *testing.T) { func Test_GetPullRequestStatus(t *testing.T) { // Verify tool definition once - mockClient := github.NewClient(nil) - tool, _ := PullRequestRead(stubGetClientFn(mockClient), stubRepoAccessCache(githubv4.NewClient(githubv4mock.NewMockedHTTPClient()), 5*time.Minute), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) + serverTool := PullRequestRead(translations.NullTranslationHelper) + tool := serverTool.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "pull_request_read", tool.Name) @@ -1415,13 +1450,19 @@ func Test_GetPullRequestStatus(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := PullRequestRead(stubGetClientFn(client), stubRepoAccessCache(githubv4.NewClient(nil), 5*time.Minute), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) + serverTool := PullRequestRead(translations.NullTranslationHelper) + deps := ToolDependencies{ + GetClient: stubGetClientFn(client), + RepoAccessCache: stubRepoAccessCache(githubv4.NewClient(nil), 5*time.Minute), + Flags: stubFeatureFlags(map[string]bool{"lockdown-mode": false}), + } + handler := serverTool.Handler(deps) // Create call request request := createMCPRequest(tc.requestArgs) // Call handler - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(context.Background(), &request) // Verify results if tc.expectError { @@ -1456,8 +1497,8 @@ func Test_GetPullRequestStatus(t *testing.T) { func Test_UpdatePullRequestBranch(t *testing.T) { // Verify tool definition once - mockClient := github.NewClient(nil) - tool, _ := UpdatePullRequestBranch(stubGetClientFn(mockClient), translations.NullTranslationHelper) + serverTool := UpdatePullRequestBranch(translations.NullTranslationHelper) + tool := serverTool.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "update_pull_request_branch", tool.Name) @@ -1547,13 +1588,17 @@ func Test_UpdatePullRequestBranch(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := UpdatePullRequestBranch(stubGetClientFn(client), translations.NullTranslationHelper) + serverTool := UpdatePullRequestBranch(translations.NullTranslationHelper) + deps := ToolDependencies{ + GetClient: stubGetClientFn(client), + } + handler := serverTool.Handler(deps) // Create call request request := createMCPRequest(tc.requestArgs) // Call handler - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(context.Background(), &request) // Verify results if tc.expectError { @@ -1577,8 +1622,8 @@ func Test_UpdatePullRequestBranch(t *testing.T) { func Test_GetPullRequestComments(t *testing.T) { // Verify tool definition once - mockClient := github.NewClient(nil) - tool, _ := PullRequestRead(stubGetClientFn(mockClient), stubRepoAccessCache(githubv4.NewClient(nil), 5*time.Minute), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) + serverTool := PullRequestRead(translations.NullTranslationHelper) + tool := serverTool.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "pull_request_read", tool.Name) @@ -1717,13 +1762,19 @@ func Test_GetPullRequestComments(t *testing.T) { } cache := stubRepoAccessCache(gqlClient, 5*time.Minute) flags := stubFeatureFlags(map[string]bool{"lockdown-mode": tc.lockdownEnabled}) - _, handler := PullRequestRead(stubGetClientFn(client), cache, translations.NullTranslationHelper, flags) + serverTool := PullRequestRead(translations.NullTranslationHelper) + deps := ToolDependencies{ + GetClient: stubGetClientFn(client), + RepoAccessCache: cache, + Flags: flags, + } + handler := serverTool.Handler(deps) // Create call request request := createMCPRequest(tc.requestArgs) // Call handler - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(context.Background(), &request) // Verify results if tc.expectError { @@ -1760,8 +1811,8 @@ func Test_GetPullRequestComments(t *testing.T) { func Test_GetPullRequestReviews(t *testing.T) { // Verify tool definition once - mockClient := github.NewClient(nil) - tool, _ := PullRequestRead(stubGetClientFn(mockClient), stubRepoAccessCache(githubv4.NewClient(nil), 5*time.Minute), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) + serverTool := PullRequestRead(translations.NullTranslationHelper) + tool := serverTool.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "pull_request_read", tool.Name) @@ -1899,13 +1950,19 @@ func Test_GetPullRequestReviews(t *testing.T) { } cache := stubRepoAccessCache(gqlClient, 5*time.Minute) flags := stubFeatureFlags(map[string]bool{"lockdown-mode": tc.lockdownEnabled}) - _, handler := PullRequestRead(stubGetClientFn(client), cache, translations.NullTranslationHelper, flags) + serverTool := PullRequestRead(translations.NullTranslationHelper) + deps := ToolDependencies{ + GetClient: stubGetClientFn(client), + RepoAccessCache: cache, + Flags: flags, + } + handler := serverTool.Handler(deps) // Create call request request := createMCPRequest(tc.requestArgs) // Call handler - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(context.Background(), &request) // Verify results if tc.expectError { @@ -1942,8 +1999,8 @@ func Test_GetPullRequestReviews(t *testing.T) { func Test_CreatePullRequest(t *testing.T) { // Verify tool definition once - mockClient := github.NewClient(nil) - tool, _ := CreatePullRequest(stubGetClientFn(mockClient), translations.NullTranslationHelper) + serverTool := CreatePullRequest(translations.NullTranslationHelper) + tool := serverTool.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "create_pull_request", tool.Name) @@ -2057,13 +2114,17 @@ func Test_CreatePullRequest(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := CreatePullRequest(stubGetClientFn(client), translations.NullTranslationHelper) + serverTool := CreatePullRequest(translations.NullTranslationHelper) + deps := ToolDependencies{ + GetClient: stubGetClientFn(client), + } + handler := serverTool.Handler(deps) // Create call request request := createMCPRequest(tc.requestArgs) // Call handler - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(context.Background(), &request) // Verify results if tc.expectError { @@ -2096,8 +2157,8 @@ func TestCreateAndSubmitPullRequestReview(t *testing.T) { t.Parallel() // Verify tool definition once - mockClient := githubv4.NewClient(nil) - tool, _ := PullRequestReviewWrite(stubGetGQLClientFn(mockClient), translations.NullTranslationHelper) + serverTool := PullRequestReviewWrite(translations.NullTranslationHelper) + tool := serverTool.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "pull_request_review_write", tool.Name) @@ -2269,13 +2330,17 @@ func TestCreateAndSubmitPullRequestReview(t *testing.T) { // Setup client with mock client := githubv4.NewClient(tc.mockedClient) - _, handler := PullRequestReviewWrite(stubGetGQLClientFn(client), translations.NullTranslationHelper) + serverTool := PullRequestReviewWrite(translations.NullTranslationHelper) + deps := ToolDependencies{ + GetGQLClient: stubGetGQLClientFn(client), + } + handler := serverTool.Handler(deps) // Create call request request := createMCPRequest(tc.requestArgs) // Call handler - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(context.Background(), &request) require.NoError(t, err) textContent := getTextResult(t, result) @@ -2295,8 +2360,8 @@ func TestCreateAndSubmitPullRequestReview(t *testing.T) { func Test_RequestCopilotReview(t *testing.T) { t.Parallel() - mockClient := github.NewClient(nil) - tool, _ := RequestCopilotReview(stubGetClientFn(mockClient), translations.NullTranslationHelper) + serverTool := RequestCopilotReview(translations.NullTranslationHelper) + tool := serverTool.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "request_copilot_review", tool.Name) @@ -2381,11 +2446,15 @@ func Test_RequestCopilotReview(t *testing.T) { t.Parallel() client := github.NewClient(tc.mockedClient) - _, handler := RequestCopilotReview(stubGetClientFn(client), translations.NullTranslationHelper) + serverTool := RequestCopilotReview(translations.NullTranslationHelper) + deps := ToolDependencies{ + GetClient: stubGetClientFn(client), + } + handler := serverTool.Handler(deps) request := createMCPRequest(tc.requestArgs) - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(context.Background(), &request) if tc.expectError { require.NoError(t, err) @@ -2410,8 +2479,8 @@ func TestCreatePendingPullRequestReview(t *testing.T) { t.Parallel() // Verify tool definition once - mockClient := githubv4.NewClient(nil) - tool, _ := PullRequestReviewWrite(stubGetGQLClientFn(mockClient), translations.NullTranslationHelper) + serverTool := PullRequestReviewWrite(translations.NullTranslationHelper) + tool := serverTool.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "pull_request_review_write", tool.Name) @@ -2571,13 +2640,17 @@ func TestCreatePendingPullRequestReview(t *testing.T) { // Setup client with mock client := githubv4.NewClient(tc.mockedClient) - _, handler := PullRequestReviewWrite(stubGetGQLClientFn(client), translations.NullTranslationHelper) + serverTool := PullRequestReviewWrite(translations.NullTranslationHelper) + deps := ToolDependencies{ + GetGQLClient: stubGetGQLClientFn(client), + } + handler := serverTool.Handler(deps) // Create call request request := createMCPRequest(tc.requestArgs) // Call handler - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(context.Background(), &request) require.NoError(t, err) textContent := getTextResult(t, result) @@ -2598,8 +2671,8 @@ func TestAddPullRequestReviewCommentToPendingReview(t *testing.T) { t.Parallel() // Verify tool definition once - mockClient := githubv4.NewClient(nil) - tool, _ := AddCommentToPendingReview(stubGetGQLClientFn(mockClient), translations.NullTranslationHelper) + serverTool := AddCommentToPendingReview(translations.NullTranslationHelper) + tool := serverTool.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "add_comment_to_pending_review", tool.Name) @@ -2750,13 +2823,17 @@ func TestAddPullRequestReviewCommentToPendingReview(t *testing.T) { // Setup client with mock client := githubv4.NewClient(tc.mockedClient) - _, handler := AddCommentToPendingReview(stubGetGQLClientFn(client), translations.NullTranslationHelper) + serverTool := AddCommentToPendingReview(translations.NullTranslationHelper) + deps := ToolDependencies{ + GetGQLClient: stubGetGQLClientFn(client), + } + handler := serverTool.Handler(deps) // Create call request request := createMCPRequest(tc.requestArgs) // Call handler - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(context.Background(), &request) require.NoError(t, err) textContent := getTextResult(t, result) @@ -2777,8 +2854,8 @@ func TestSubmitPendingPullRequestReview(t *testing.T) { t.Parallel() // Verify tool definition once - mockClient := githubv4.NewClient(nil) - tool, _ := PullRequestReviewWrite(stubGetGQLClientFn(mockClient), translations.NullTranslationHelper) + serverTool := PullRequestReviewWrite(translations.NullTranslationHelper) + tool := serverTool.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "pull_request_review_write", tool.Name) @@ -2851,13 +2928,17 @@ func TestSubmitPendingPullRequestReview(t *testing.T) { // Setup client with mock client := githubv4.NewClient(tc.mockedClient) - _, handler := PullRequestReviewWrite(stubGetGQLClientFn(client), translations.NullTranslationHelper) + serverTool := PullRequestReviewWrite(translations.NullTranslationHelper) + deps := ToolDependencies{ + GetGQLClient: stubGetGQLClientFn(client), + } + handler := serverTool.Handler(deps) // Create call request request := createMCPRequest(tc.requestArgs) // Call handler - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(context.Background(), &request) require.NoError(t, err) textContent := getTextResult(t, result) @@ -2878,8 +2959,8 @@ func TestDeletePendingPullRequestReview(t *testing.T) { t.Parallel() // Verify tool definition once - mockClient := githubv4.NewClient(nil) - tool, _ := PullRequestReviewWrite(stubGetGQLClientFn(mockClient), translations.NullTranslationHelper) + serverTool := PullRequestReviewWrite(translations.NullTranslationHelper) + tool := serverTool.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "pull_request_review_write", tool.Name) @@ -2946,13 +3027,17 @@ func TestDeletePendingPullRequestReview(t *testing.T) { // Setup client with mock client := githubv4.NewClient(tc.mockedClient) - _, handler := PullRequestReviewWrite(stubGetGQLClientFn(client), translations.NullTranslationHelper) + serverTool := PullRequestReviewWrite(translations.NullTranslationHelper) + deps := ToolDependencies{ + GetGQLClient: stubGetGQLClientFn(client), + } + handler := serverTool.Handler(deps) // Create call request request := createMCPRequest(tc.requestArgs) // Call handler - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(context.Background(), &request) require.NoError(t, err) textContent := getTextResult(t, result) @@ -2973,8 +3058,8 @@ func TestGetPullRequestDiff(t *testing.T) { t.Parallel() // Verify tool definition once - mockClient := github.NewClient(nil) - tool, _ := PullRequestRead(stubGetClientFn(mockClient), stubRepoAccessCache(githubv4.NewClient(nil), 5*time.Minute), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) + serverTool := PullRequestRead(translations.NullTranslationHelper) + tool := serverTool.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "pull_request_read", tool.Name) @@ -3033,13 +3118,19 @@ index 5d6e7b2..8a4f5c3 100644 // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := PullRequestRead(stubGetClientFn(client), stubRepoAccessCache(githubv4.NewClient(nil), 5*time.Minute), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) + serverTool := PullRequestRead(translations.NullTranslationHelper) + deps := ToolDependencies{ + GetClient: stubGetClientFn(client), + RepoAccessCache: stubRepoAccessCache(githubv4.NewClient(nil), 5*time.Minute), + Flags: stubFeatureFlags(map[string]bool{"lockdown-mode": false}), + } + handler := serverTool.Handler(deps) // Create call request request := createMCPRequest(tc.requestArgs) // Call handler - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(context.Background(), &request) require.NoError(t, err) textContent := getTextResult(t, result) diff --git a/pkg/github/tools.go b/pkg/github/tools.go index 3bd210495..ba53f22af 100644 --- a/pkg/github/tools.go +++ b/pkg/github/tools.go @@ -214,17 +214,17 @@ func DefaultToolsetGroup(readOnly bool, getClient GetClientFn, getGQLClient GetG issues := toolsets.NewToolset(ToolsetMetadataIssues.ID, ToolsetMetadataIssues.Description). SetDependencies(deps). AddReadTools( - toolsets.NewServerToolLegacy(IssueRead(getClient, getGQLClient, cache, t, flags)), - toolsets.NewServerToolLegacy(SearchIssues(getClient, t)), - toolsets.NewServerToolLegacy(ListIssues(getGQLClient, t)), - toolsets.NewServerToolLegacy(ListIssueTypes(getClient, t)), + IssueRead(t), + SearchIssues(t), + ListIssues(t), + ListIssueTypes(t), toolsets.NewServerToolLegacy(GetLabel(getGQLClient, t)), ). AddWriteTools( - toolsets.NewServerToolLegacy(IssueWrite(getClient, getGQLClient, t)), - toolsets.NewServerToolLegacy(AddIssueComment(getClient, t)), - toolsets.NewServerToolLegacy(AssignCopilotToIssue(getGQLClient, t)), - toolsets.NewServerToolLegacy(SubIssueWrite(getClient, t)), + IssueWrite(t), + AddIssueComment(t), + AssignCopilotToIssue(t), + SubIssueWrite(t), ).AddPrompts( toolsets.NewServerPrompt(AssignCodingAgentPrompt(t)), toolsets.NewServerPrompt(IssueToFixWorkflowPrompt(t)), @@ -242,19 +242,19 @@ func DefaultToolsetGroup(readOnly bool, getClient GetClientFn, getGQLClient GetG pullRequests := toolsets.NewToolset(ToolsetMetadataPullRequests.ID, ToolsetMetadataPullRequests.Description). SetDependencies(deps). AddReadTools( - toolsets.NewServerToolLegacy(PullRequestRead(getClient, cache, t, flags)), - toolsets.NewServerToolLegacy(ListPullRequests(getClient, t)), - toolsets.NewServerToolLegacy(SearchPullRequests(getClient, t)), + PullRequestRead(t), + ListPullRequests(t), + SearchPullRequests(t), ). AddWriteTools( - toolsets.NewServerToolLegacy(MergePullRequest(getClient, t)), - toolsets.NewServerToolLegacy(UpdatePullRequestBranch(getClient, t)), - toolsets.NewServerToolLegacy(CreatePullRequest(getClient, t)), - toolsets.NewServerToolLegacy(UpdatePullRequest(getClient, getGQLClient, t)), - toolsets.NewServerToolLegacy(RequestCopilotReview(getClient, t)), + MergePullRequest(t), + UpdatePullRequestBranch(t), + CreatePullRequest(t), + UpdatePullRequest(t), + RequestCopilotReview(t), // Reviews - toolsets.NewServerToolLegacy(PullRequestReviewWrite(getGQLClient, t)), - toolsets.NewServerToolLegacy(AddCommentToPendingReview(getGQLClient, t)), + PullRequestReviewWrite(t), + AddCommentToPendingReview(t), ) codeSecurity := toolsets.NewToolset(ToolsetMetadataCodeSecurity.ID, ToolsetMetadataCodeSecurity.Description). SetDependencies(deps).