From b29546e6e564b1eb179928c62ef0e581bb3e1dcb Mon Sep 17 00:00:00 2001 From: Sam Morrow Date: Sat, 13 Dec 2025 12:47:23 +0100 Subject: [PATCH 1/2] refactor(search): migrate search tools to new ServerTool pattern Migrate search.go tools (SearchRepositories, SearchCode, SearchUsers, SearchOrgs) to use the new NewTool helper and ToolDependencies pattern. - Functions now take only TranslationHelperFunc and return ServerTool - Handler generation uses ToolDependencies for typed access to clients - Update tools.go call sites to remove getClient parameter - Update tests to use new Handler(deps) pattern This demonstrates the migration pattern for additional tool files. Co-authored-by: Adam Holt --- pkg/github/search.go | 351 ++++++++++++++++++++------------------ pkg/github/search_test.go | 52 ++++-- pkg/github/tools.go | 8 +- 3 files changed, 223 insertions(+), 188 deletions(-) diff --git a/pkg/github/search.go b/pkg/github/search.go index cffd0bf15..eaaf49369 100644 --- a/pkg/github/search.go +++ b/pkg/github/search.go @@ -8,6 +8,7 @@ import ( "net/http" ghErrors "github.com/github/github-mcp-server/pkg/errors" + "github.com/github/github-mcp-server/pkg/toolsets" "github.com/github/github-mcp-server/pkg/translations" "github.com/github/github-mcp-server/pkg/utils" "github.com/google/go-github/v79/github" @@ -16,7 +17,7 @@ import ( ) // SearchRepositories creates a tool to search for GitHub repositories. -func SearchRepositories(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { +func SearchRepositories(t translations.TranslationHelperFunc) toolsets.ServerTool { schema := &jsonschema.Schema{ Type: "object", Properties: map[string]*jsonschema.Schema{ @@ -44,7 +45,8 @@ func SearchRepositories(getClient GetClientFn, t translations.TranslationHelperF } WithPagination(schema) - return mcp.Tool{ + return NewTool( + mcp.Tool{ Name: "search_repositories", Description: t("TOOL_SEARCH_REPOSITORIES_DESCRIPTION", "Find GitHub repositories by name, description, readme, topics, or other metadata. Perfect for discovering projects, finding examples, or locating specific repositories across GitHub."), Annotations: &mcp.ToolAnnotations{ @@ -53,115 +55,118 @@ func SearchRepositories(getClient GetClientFn, t translations.TranslationHelperF }, InputSchema: schema, }, - func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - query, err := RequiredParam[string](args, "query") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - sort, err := OptionalParam[string](args, "sort") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - order, err := OptionalParam[string](args, "order") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - pagination, err := OptionalPaginationParams(args) - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - minimalOutput, err := OptionalBoolParamWithDefault(args, "minimal_output", true) - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - opts := &github.SearchOptions{ - Sort: sort, - Order: order, - ListOptions: github.ListOptions{ - Page: pagination.Page, - PerPage: pagination.PerPage, - }, - } - - client, err := getClient(ctx) - if err != nil { - return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil - } - result, resp, err := client.Search.Repositories(ctx, query, opts) - if err != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, - fmt.Sprintf("failed to search repositories with query '%s'", query), - resp, - err, - ), nil, nil - } - defer func() { _ = resp.Body.Close() }() + func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { + return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + query, err := RequiredParam[string](args, "query") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + sort, err := OptionalParam[string](args, "sort") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + order, err := OptionalParam[string](args, "order") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + pagination, err := OptionalPaginationParams(args) + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + minimalOutput, err := OptionalBoolParamWithDefault(args, "minimal_output", true) + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + opts := &github.SearchOptions{ + Sort: sort, + Order: order, + ListOptions: github.ListOptions{ + Page: pagination.Page, + PerPage: pagination.PerPage, + }, + } - if resp.StatusCode != http.StatusOK { - body, err := io.ReadAll(resp.Body) + client, err := deps.GetClient(ctx) if err != nil { - return utils.NewToolResultErrorFromErr("failed to read response body", err), nil, nil + return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil } - return utils.NewToolResultError(fmt.Sprintf("failed to search repositories: %s", string(body))), nil, nil - } + result, resp, err := client.Search.Repositories(ctx, query, opts) + if err != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, + fmt.Sprintf("failed to search repositories with query '%s'", query), + resp, + err, + ), nil, nil + } + defer func() { _ = resp.Body.Close() }() - // Return either minimal or full response based on parameter - var r []byte - if minimalOutput { - minimalRepos := make([]MinimalRepository, 0, len(result.Repositories)) - for _, repo := range result.Repositories { - minimalRepo := MinimalRepository{ - ID: repo.GetID(), - Name: repo.GetName(), - FullName: repo.GetFullName(), - Description: repo.GetDescription(), - HTMLURL: repo.GetHTMLURL(), - Language: repo.GetLanguage(), - Stars: repo.GetStargazersCount(), - Forks: repo.GetForksCount(), - OpenIssues: repo.GetOpenIssuesCount(), - Private: repo.GetPrivate(), - Fork: repo.GetFork(), - Archived: repo.GetArchived(), - DefaultBranch: repo.GetDefaultBranch(), + if resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to read response body", err), nil, nil } + return utils.NewToolResultError(fmt.Sprintf("failed to search repositories: %s", string(body))), nil, nil + } - if repo.UpdatedAt != nil { - minimalRepo.UpdatedAt = repo.UpdatedAt.Format("2006-01-02T15:04:05Z") - } - if repo.CreatedAt != nil { - minimalRepo.CreatedAt = repo.CreatedAt.Format("2006-01-02T15:04:05Z") - } - if repo.Topics != nil { - minimalRepo.Topics = repo.Topics + // Return either minimal or full response based on parameter + var r []byte + if minimalOutput { + minimalRepos := make([]MinimalRepository, 0, len(result.Repositories)) + for _, repo := range result.Repositories { + minimalRepo := MinimalRepository{ + ID: repo.GetID(), + Name: repo.GetName(), + FullName: repo.GetFullName(), + Description: repo.GetDescription(), + HTMLURL: repo.GetHTMLURL(), + Language: repo.GetLanguage(), + Stars: repo.GetStargazersCount(), + Forks: repo.GetForksCount(), + OpenIssues: repo.GetOpenIssuesCount(), + Private: repo.GetPrivate(), + Fork: repo.GetFork(), + Archived: repo.GetArchived(), + DefaultBranch: repo.GetDefaultBranch(), + } + + if repo.UpdatedAt != nil { + minimalRepo.UpdatedAt = repo.UpdatedAt.Format("2006-01-02T15:04:05Z") + } + if repo.CreatedAt != nil { + minimalRepo.CreatedAt = repo.CreatedAt.Format("2006-01-02T15:04:05Z") + } + if repo.Topics != nil { + minimalRepo.Topics = repo.Topics + } + + minimalRepos = append(minimalRepos, minimalRepo) } - minimalRepos = append(minimalRepos, minimalRepo) - } + minimalResult := &MinimalSearchRepositoriesResult{ + TotalCount: result.GetTotal(), + IncompleteResults: result.GetIncompleteResults(), + Items: minimalRepos, + } - minimalResult := &MinimalSearchRepositoriesResult{ - TotalCount: result.GetTotal(), - IncompleteResults: result.GetIncompleteResults(), - Items: minimalRepos, + r, err = json.Marshal(minimalResult) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to marshal minimal response", err), nil, nil + } + } else { + r, err = json.Marshal(result) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to marshal full response", err), nil, nil + } } - r, err = json.Marshal(minimalResult) - if err != nil { - return utils.NewToolResultErrorFromErr("failed to marshal minimal response", err), nil, nil - } - } else { - r, err = json.Marshal(result) - if err != nil { - return utils.NewToolResultErrorFromErr("failed to marshal full response", err), nil, nil - } + return utils.NewToolResultText(string(r)), nil, nil } - - return utils.NewToolResultText(string(r)), nil, nil - } + }, + ) } // SearchCode creates a tool to search for code across GitHub repositories. -func SearchCode(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { +func SearchCode(t translations.TranslationHelperFunc) toolsets.ServerTool { schema := &jsonschema.Schema{ Type: "object", Properties: map[string]*jsonschema.Schema{ @@ -183,7 +188,8 @@ func SearchCode(getClient GetClientFn, t translations.TranslationHelperFunc) (mc } WithPagination(schema) - return mcp.Tool{ + return NewTool( + mcp.Tool{ Name: "search_code", Description: t("TOOL_SEARCH_CODE_DESCRIPTION", "Fast and precise code search across ALL GitHub repositories using GitHub's native search engine. Best for finding exact symbols, functions, classes, or specific code patterns."), Annotations: &mcp.ToolAnnotations{ @@ -192,66 +198,69 @@ func SearchCode(getClient GetClientFn, t translations.TranslationHelperFunc) (mc }, InputSchema: schema, }, - func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - query, err := RequiredParam[string](args, "query") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - sort, err := OptionalParam[string](args, "sort") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - order, err := OptionalParam[string](args, "order") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - pagination, err := OptionalPaginationParams(args) - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { + return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + query, err := RequiredParam[string](args, "query") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + sort, err := OptionalParam[string](args, "sort") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + order, err := OptionalParam[string](args, "order") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + pagination, err := OptionalPaginationParams(args) + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - opts := &github.SearchOptions{ - Sort: sort, - Order: order, - ListOptions: github.ListOptions{ - PerPage: pagination.PerPage, - Page: pagination.Page, - }, - } + opts := &github.SearchOptions{ + Sort: sort, + Order: order, + ListOptions: github.ListOptions{ + PerPage: pagination.PerPage, + Page: pagination.Page, + }, + } - client, err := getClient(ctx) - if err != nil { - return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil - } + client, err := deps.GetClient(ctx) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil + } - result, resp, err := client.Search.Code(ctx, query, opts) - if err != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, - fmt.Sprintf("failed to search code with query '%s'", query), - resp, - err, - ), nil, nil - } - defer func() { _ = resp.Body.Close() }() + result, resp, err := client.Search.Code(ctx, query, opts) + if err != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, + fmt.Sprintf("failed to search code with query '%s'", query), + resp, + err, + ), nil, nil + } + defer func() { _ = resp.Body.Close() }() - if resp.StatusCode != http.StatusOK { - body, err := io.ReadAll(resp.Body) + if resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to read response body", err), nil, nil + } + return utils.NewToolResultError(fmt.Sprintf("failed to search code: %s", string(body))), nil, nil + } + + r, err := json.Marshal(result) if err != nil { - return utils.NewToolResultErrorFromErr("failed to read response body", err), nil, nil + return utils.NewToolResultErrorFromErr("failed to marshal response", err), nil, nil } - return utils.NewToolResultError(fmt.Sprintf("failed to search code: %s", string(body))), nil, nil - } - r, err := json.Marshal(result) - if err != nil { - return utils.NewToolResultErrorFromErr("failed to marshal response", err), nil, nil + return utils.NewToolResultText(string(r)), nil, nil } - - return utils.NewToolResultText(string(r)), nil, nil - } + }, + ) } -func userOrOrgHandler(accountType string, getClient GetClientFn) mcp.ToolHandlerFor[map[string]any, any] { +func userOrOrgHandler(accountType string, deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { query, err := RequiredParam[string](args, "query") if err != nil { @@ -279,7 +288,7 @@ func userOrOrgHandler(accountType string, getClient GetClientFn) mcp.ToolHandler }, } - client, err := getClient(ctx) + client, err := deps.GetClient(ctx) if err != nil { return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil } @@ -340,7 +349,7 @@ func userOrOrgHandler(accountType string, getClient GetClientFn) mcp.ToolHandler } // SearchUsers creates a tool to search for GitHub users. -func SearchUsers(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { +func SearchUsers(t translations.TranslationHelperFunc) toolsets.ServerTool { schema := &jsonschema.Schema{ Type: "object", Properties: map[string]*jsonschema.Schema{ @@ -363,19 +372,24 @@ func SearchUsers(getClient GetClientFn, t translations.TranslationHelperFunc) (m } WithPagination(schema) - return mcp.Tool{ - Name: "search_users", - Description: t("TOOL_SEARCH_USERS_DESCRIPTION", "Find GitHub users by username, real name, or other profile information. Useful for locating developers, contributors, or team members."), - Annotations: &mcp.ToolAnnotations{ - Title: t("TOOL_SEARCH_USERS_USER_TITLE", "Search users"), - ReadOnlyHint: true, + return NewTool( + mcp.Tool{ + Name: "search_users", + Description: t("TOOL_SEARCH_USERS_DESCRIPTION", "Find GitHub users by username, real name, or other profile information. Useful for locating developers, contributors, or team members."), + Annotations: &mcp.ToolAnnotations{ + Title: t("TOOL_SEARCH_USERS_USER_TITLE", "Search users"), + ReadOnlyHint: true, + }, + InputSchema: schema, + }, + func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { + return userOrOrgHandler("user", deps) }, - InputSchema: schema, - }, userOrOrgHandler("user", getClient) + ) } // SearchOrgs creates a tool to search for GitHub organizations. -func SearchOrgs(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { +func SearchOrgs(t translations.TranslationHelperFunc) toolsets.ServerTool { schema := &jsonschema.Schema{ Type: "object", Properties: map[string]*jsonschema.Schema{ @@ -398,13 +412,18 @@ func SearchOrgs(getClient GetClientFn, t translations.TranslationHelperFunc) (mc } WithPagination(schema) - return mcp.Tool{ - Name: "search_orgs", - Description: t("TOOL_SEARCH_ORGS_DESCRIPTION", "Find GitHub organizations by name, location, or other organization metadata. Ideal for discovering companies, open source foundations, or teams."), - Annotations: &mcp.ToolAnnotations{ - Title: t("TOOL_SEARCH_ORGS_USER_TITLE", "Search organizations"), - ReadOnlyHint: true, + return NewTool( + mcp.Tool{ + Name: "search_orgs", + Description: t("TOOL_SEARCH_ORGS_DESCRIPTION", "Find GitHub organizations by name, location, or other organization metadata. Ideal for discovering companies, open source foundations, or teams."), + Annotations: &mcp.ToolAnnotations{ + Title: t("TOOL_SEARCH_ORGS_USER_TITLE", "Search organizations"), + ReadOnlyHint: true, + }, + InputSchema: schema, + }, + func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { + return userOrOrgHandler("org", deps) }, - InputSchema: schema, - }, userOrOrgHandler("org", getClient) + ) } diff --git a/pkg/github/search_test.go b/pkg/github/search_test.go index 0b923edcd..41d12df1b 100644 --- a/pkg/github/search_test.go +++ b/pkg/github/search_test.go @@ -17,8 +17,8 @@ import ( func Test_SearchRepositories(t *testing.T) { // Verify tool definition once - mockClient := github.NewClient(nil) - tool, _ := SearchRepositories(stubGetClientFn(mockClient), translations.NullTranslationHelper) + serverTool := SearchRepositories(translations.NullTranslationHelper) + tool := serverTool.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "search_repositories", tool.Name) @@ -134,13 +134,16 @@ func Test_SearchRepositories(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := SearchRepositories(stubGetClientFn(client), translations.NullTranslationHelper) + deps := ToolDependencies{ + GetClient: stubGetClientFn(client), + } + handler := serverTool.Handler(deps) // Create call request request := createMCPRequest(tc.requestArgs) // Call handler - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(context.Background(), &request) // Verify results if tc.expectError { @@ -205,7 +208,11 @@ func Test_SearchRepositories_FullOutput(t *testing.T) { ) client := github.NewClient(mockedClient) - _, handlerTest := SearchRepositories(stubGetClientFn(client), translations.NullTranslationHelper) + serverTool := SearchRepositories(translations.NullTranslationHelper) + deps := ToolDependencies{ + GetClient: stubGetClientFn(client), + } + handler := serverTool.Handler(deps) args := map[string]interface{}{ "query": "golang test", @@ -214,7 +221,7 @@ func Test_SearchRepositories_FullOutput(t *testing.T) { request := createMCPRequest(args) - result, _, err := handlerTest(context.Background(), &request, args) + result, err := handler(context.Background(), &request) require.NoError(t, err) require.False(t, result.IsError) @@ -236,8 +243,8 @@ func Test_SearchRepositories_FullOutput(t *testing.T) { func Test_SearchCode(t *testing.T) { // Verify tool definition once - mockClient := github.NewClient(nil) - tool, _ := SearchCode(stubGetClientFn(mockClient), translations.NullTranslationHelper) + serverTool := SearchCode(translations.NullTranslationHelper) + tool := serverTool.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "search_code", tool.Name) @@ -351,13 +358,16 @@ func Test_SearchCode(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := SearchCode(stubGetClientFn(client), translations.NullTranslationHelper) + deps := ToolDependencies{ + GetClient: stubGetClientFn(client), + } + handler := serverTool.Handler(deps) // Create call request request := createMCPRequest(tc.requestArgs) // Call handler - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(context.Background(), &request) // Verify results if tc.expectError { @@ -394,8 +404,8 @@ func Test_SearchCode(t *testing.T) { func Test_SearchUsers(t *testing.T) { // Verify tool definition once - mockClient := github.NewClient(nil) - tool, _ := SearchUsers(stubGetClientFn(mockClient), translations.NullTranslationHelper) + serverTool := SearchUsers(translations.NullTranslationHelper) + tool := serverTool.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "search_users", tool.Name) @@ -548,13 +558,16 @@ func Test_SearchUsers(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := SearchUsers(stubGetClientFn(client), translations.NullTranslationHelper) + deps := ToolDependencies{ + GetClient: stubGetClientFn(client), + } + handler := serverTool.Handler(deps) // Create call request request := createMCPRequest(tc.requestArgs) // Call handler - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(context.Background(), &request) // Verify results if tc.expectError { @@ -592,8 +605,8 @@ func Test_SearchUsers(t *testing.T) { func Test_SearchOrgs(t *testing.T) { // Verify tool definition once - mockClient := github.NewClient(nil) - tool, _ := SearchOrgs(stubGetClientFn(mockClient), translations.NullTranslationHelper) + serverTool := SearchOrgs(translations.NullTranslationHelper) + tool := serverTool.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) @@ -720,13 +733,16 @@ func Test_SearchOrgs(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := SearchOrgs(stubGetClientFn(client), translations.NullTranslationHelper) + deps := ToolDependencies{ + GetClient: stubGetClientFn(client), + } + handler := serverTool.Handler(deps) // Create call request request := createMCPRequest(tc.requestArgs) // Call handler - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(context.Background(), &request) // Verify results if tc.expectError { diff --git a/pkg/github/tools.go b/pkg/github/tools.go index efabfc92f..dff1ca02e 100644 --- a/pkg/github/tools.go +++ b/pkg/github/tools.go @@ -179,10 +179,10 @@ func DefaultToolsetGroup(readOnly bool, getClient GetClientFn, getGQLClient GetG repos := toolsets.NewToolset(ToolsetMetadataRepos.ID, ToolsetMetadataRepos.Description). SetDependencies(deps). AddReadTools( - toolsets.NewServerToolLegacy(SearchRepositories(getClient, t)), + SearchRepositories(t), toolsets.NewServerToolLegacy(GetFileContents(getClient, getRawClient, t)), toolsets.NewServerToolLegacy(ListCommits(getClient, t)), - toolsets.NewServerToolLegacy(SearchCode(getClient, t)), + SearchCode(t), toolsets.NewServerToolLegacy(GetCommit(getClient, t)), toolsets.NewServerToolLegacy(ListBranches(getClient, t)), toolsets.NewServerToolLegacy(ListTags(getClient, t)), @@ -232,12 +232,12 @@ func DefaultToolsetGroup(readOnly bool, getClient GetClientFn, getGQLClient GetG users := toolsets.NewToolset(ToolsetMetadataUsers.ID, ToolsetMetadataUsers.Description). SetDependencies(deps). AddReadTools( - toolsets.NewServerToolLegacy(SearchUsers(getClient, t)), + SearchUsers(t), ) orgs := toolsets.NewToolset(ToolsetMetadataOrgs.ID, ToolsetMetadataOrgs.Description). SetDependencies(deps). AddReadTools( - toolsets.NewServerToolLegacy(SearchOrgs(getClient, t)), + SearchOrgs(t), ) pullRequests := toolsets.NewToolset(ToolsetMetadataPullRequests.ID, ToolsetMetadataPullRequests.Description). SetDependencies(deps). From b863da9dcd92d47c4da16140c03dfce2ac31df00 Mon Sep 17 00:00:00 2001 From: Sam Morrow Date: Sat, 13 Dec 2025 13:00:29 +0100 Subject: [PATCH 2/2] Migrate context_tools to new ServerTool pattern Convert GetMe, GetTeams, and GetTeamMembers to use the new typed dependency injection pattern: - Functions now take only translations helper, return toolsets.ServerTool - Handler is generated lazily via deps.GetClient/deps.GetGQLClient - Tests updated to use serverTool.Handler(deps) pattern - Fixed error return pattern to return nil for Go error (via result.IsError) Co-authored-by: Adam Holt --- pkg/github/context_tools.go | 305 ++++++++++++++++--------------- pkg/github/context_tools_test.go | 61 +++++-- pkg/github/tools.go | 6 +- 3 files changed, 204 insertions(+), 168 deletions(-) diff --git a/pkg/github/context_tools.go b/pkg/github/context_tools.go index 3fe622379..e8043731a 100644 --- a/pkg/github/context_tools.go +++ b/pkg/github/context_tools.go @@ -6,6 +6,7 @@ import ( "time" ghErrors "github.com/github/github-mcp-server/pkg/errors" + "github.com/github/github-mcp-server/pkg/toolsets" "github.com/github/github-mcp-server/pkg/translations" "github.com/github/github-mcp-server/pkg/utils" "github.com/google/jsonschema-go/jsonschema" @@ -36,8 +37,9 @@ type UserDetails struct { } // GetMe creates a tool to get details of the authenticated user. -func GetMe(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { - return mcp.Tool{ +func GetMe(t translations.TranslationHelperFunc) toolsets.ServerTool { + return NewTool( + mcp.Tool{ Name: "get_me", Description: t("TOOL_GET_ME_DESCRIPTION", "Get details of the authenticated GitHub user. Use this when a request is about the user's own profile for GitHub. Or when information is missing to build other tool calls."), Annotations: &mcp.ToolAnnotations{ @@ -48,50 +50,53 @@ func GetMe(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Too // OpenAI strict mode requires the properties field to be present. InputSchema: json.RawMessage(`{"type":"object","properties":{}}`), }, - mcp.ToolHandlerFor[map[string]any, any](func(ctx context.Context, _ *mcp.CallToolRequest, _ map[string]any) (*mcp.CallToolResult, any, error) { - client, err := getClient(ctx) - if err != nil { - return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil - } + func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { + return func(ctx context.Context, _ *mcp.CallToolRequest, _ map[string]any) (*mcp.CallToolResult, any, error) { + client, err := deps.GetClient(ctx) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil + } - user, res, err := client.Users.Get(ctx, "") - if err != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, - "failed to get user", - res, - err, - ), nil, err - } + user, res, err := client.Users.Get(ctx, "") + if err != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, + "failed to get user", + res, + err, + ), nil, nil + } - // Create minimal user representation instead of returning full user object - minimalUser := MinimalUser{ - Login: user.GetLogin(), - ID: user.GetID(), - ProfileURL: user.GetHTMLURL(), - AvatarURL: user.GetAvatarURL(), - Details: &UserDetails{ - Name: user.GetName(), - Company: user.GetCompany(), - Blog: user.GetBlog(), - Location: user.GetLocation(), - Email: user.GetEmail(), - Hireable: user.GetHireable(), - Bio: user.GetBio(), - TwitterUsername: user.GetTwitterUsername(), - PublicRepos: user.GetPublicRepos(), - PublicGists: user.GetPublicGists(), - Followers: user.GetFollowers(), - Following: user.GetFollowing(), - CreatedAt: user.GetCreatedAt().Time, - UpdatedAt: user.GetUpdatedAt().Time, - PrivateGists: user.GetPrivateGists(), - TotalPrivateRepos: user.GetTotalPrivateRepos(), - OwnedPrivateRepos: user.GetOwnedPrivateRepos(), - }, - } + // Create minimal user representation instead of returning full user object + minimalUser := MinimalUser{ + Login: user.GetLogin(), + ID: user.GetID(), + ProfileURL: user.GetHTMLURL(), + AvatarURL: user.GetAvatarURL(), + Details: &UserDetails{ + Name: user.GetName(), + Company: user.GetCompany(), + Blog: user.GetBlog(), + Location: user.GetLocation(), + Email: user.GetEmail(), + Hireable: user.GetHireable(), + Bio: user.GetBio(), + TwitterUsername: user.GetTwitterUsername(), + PublicRepos: user.GetPublicRepos(), + PublicGists: user.GetPublicGists(), + Followers: user.GetFollowers(), + Following: user.GetFollowing(), + CreatedAt: user.GetCreatedAt().Time, + UpdatedAt: user.GetUpdatedAt().Time, + PrivateGists: user.GetPrivateGists(), + TotalPrivateRepos: user.GetTotalPrivateRepos(), + OwnedPrivateRepos: user.GetOwnedPrivateRepos(), + }, + } - return MarshalledTextResult(minimalUser), nil, nil - }) + return MarshalledTextResult(minimalUser), nil, nil + } + }, + ) } type TeamInfo struct { @@ -105,8 +110,9 @@ type OrganizationTeams struct { Teams []TeamInfo `json:"teams"` } -func GetTeams(getClient GetClientFn, getGQLClient GetGQLClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { - return mcp.Tool{ +func GetTeams(t translations.TranslationHelperFunc) toolsets.ServerTool { + return NewTool( + mcp.Tool{ Name: "get_teams", Description: t("TOOL_GET_TEAMS_DESCRIPTION", "Get details of the teams the user is a member of. Limited to organizations accessible with current credentials"), Annotations: &mcp.ToolAnnotations{ @@ -123,84 +129,88 @@ func GetTeams(getClient GetClientFn, getGQLClient GetGQLClientFn, t translations }, }, }, - func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - user, err := OptionalParam[string](args, "user") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - - var username string - if user != "" { - username = user - } else { - client, err := getClient(ctx) + func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { + return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + user, err := OptionalParam[string](args, "user") if err != nil { - return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil + return utils.NewToolResultError(err.Error()), nil, nil } - userResp, res, err := client.Users.Get(ctx, "") + var username string + if user != "" { + username = user + } else { + client, err := deps.GetClient(ctx) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil + } + + userResp, res, err := client.Users.Get(ctx, "") + if err != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, + "failed to get user", + res, + err, + ), nil, nil + } + username = userResp.GetLogin() + } + + gqlClient, err := deps.GetGQLClient(ctx) if err != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, - "failed to get user", - res, - err, - ), nil, nil + return utils.NewToolResultErrorFromErr("failed to get GitHub GQL client", err), nil, nil } - username = userResp.GetLogin() - } - gqlClient, err := getGQLClient(ctx) - if err != nil { - return utils.NewToolResultErrorFromErr("failed to get GitHub GQL client", err), nil, nil - } + var q struct { + User struct { + Organizations struct { + Nodes []struct { + Login githubv4.String + Teams struct { + Nodes []struct { + Name githubv4.String + Slug githubv4.String + Description githubv4.String + } + } `graphql:"teams(first: 100, userLogins: [$login])"` + } + } `graphql:"organizations(first: 100)"` + } `graphql:"user(login: $login)"` + } + vars := map[string]interface{}{ + "login": githubv4.String(username), + } + if err := gqlClient.Query(ctx, &q, vars); err != nil { + return ghErrors.NewGitHubGraphQLErrorResponse(ctx, "Failed to find teams", err), nil, nil + } - var q struct { - User struct { - Organizations struct { - Nodes []struct { - Login githubv4.String - Teams struct { - Nodes []struct { - Name githubv4.String - Slug githubv4.String - Description githubv4.String - } - } `graphql:"teams(first: 100, userLogins: [$login])"` - } - } `graphql:"organizations(first: 100)"` - } `graphql:"user(login: $login)"` - } - vars := map[string]interface{}{ - "login": githubv4.String(username), - } - if err := gqlClient.Query(ctx, &q, vars); err != nil { - return ghErrors.NewGitHubGraphQLErrorResponse(ctx, "Failed to find teams", err), nil, nil - } + var organizations []OrganizationTeams + for _, org := range q.User.Organizations.Nodes { + orgTeams := OrganizationTeams{ + Org: string(org.Login), + Teams: make([]TeamInfo, 0, len(org.Teams.Nodes)), + } - var organizations []OrganizationTeams - for _, org := range q.User.Organizations.Nodes { - orgTeams := OrganizationTeams{ - Org: string(org.Login), - Teams: make([]TeamInfo, 0, len(org.Teams.Nodes)), - } + for _, team := range org.Teams.Nodes { + orgTeams.Teams = append(orgTeams.Teams, TeamInfo{ + Name: string(team.Name), + Slug: string(team.Slug), + Description: string(team.Description), + }) + } - for _, team := range org.Teams.Nodes { - orgTeams.Teams = append(orgTeams.Teams, TeamInfo{ - Name: string(team.Name), - Slug: string(team.Slug), - Description: string(team.Description), - }) + organizations = append(organizations, orgTeams) } - organizations = append(organizations, orgTeams) + return MarshalledTextResult(organizations), nil, nil } - - return MarshalledTextResult(organizations), nil, nil - } + }, + ) } -func GetTeamMembers(getGQLClient GetGQLClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { - return mcp.Tool{ +func GetTeamMembers(t translations.TranslationHelperFunc) toolsets.ServerTool { + return NewTool( + mcp.Tool{ Name: "get_team_members", Description: t("TOOL_GET_TEAM_MEMBERS_DESCRIPTION", "Get member usernames of a specific team in an organization. Limited to organizations accessible with current credentials"), Annotations: &mcp.ToolAnnotations{ @@ -222,46 +232,49 @@ func GetTeamMembers(getGQLClient GetGQLClientFn, t translations.TranslationHelpe Required: []string{"org", "team_slug"}, }, }, - func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - org, err := RequiredParam[string](args, "org") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { + return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + org, err := RequiredParam[string](args, "org") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - teamSlug, err := RequiredParam[string](args, "team_slug") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + teamSlug, err := RequiredParam[string](args, "team_slug") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - gqlClient, err := getGQLClient(ctx) - if err != nil { - return utils.NewToolResultErrorFromErr("failed to get GitHub GQL client", err), nil, nil - } + gqlClient, err := deps.GetGQLClient(ctx) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to get GitHub GQL client", err), nil, nil + } - var q struct { - Organization struct { - Team struct { - Members struct { - Nodes []struct { - Login githubv4.String - } - } `graphql:"members(first: 100)"` - } `graphql:"team(slug: $teamSlug)"` - } `graphql:"organization(login: $org)"` - } - vars := map[string]interface{}{ - "org": githubv4.String(org), - "teamSlug": githubv4.String(teamSlug), - } - if err := gqlClient.Query(ctx, &q, vars); err != nil { - return ghErrors.NewGitHubGraphQLErrorResponse(ctx, "Failed to get team members", err), nil, nil - } + var q struct { + Organization struct { + Team struct { + Members struct { + Nodes []struct { + Login githubv4.String + } + } `graphql:"members(first: 100)"` + } `graphql:"team(slug: $teamSlug)"` + } `graphql:"organization(login: $org)"` + } + vars := map[string]interface{}{ + "org": githubv4.String(org), + "teamSlug": githubv4.String(teamSlug), + } + if err := gqlClient.Query(ctx, &q, vars); err != nil { + return ghErrors.NewGitHubGraphQLErrorResponse(ctx, "Failed to get team members", err), nil, nil + } - var members []string - for _, member := range q.Organization.Team.Members.Nodes { - members = append(members, string(member.Login)) - } + var members []string + for _, member := range q.Organization.Team.Members.Nodes { + members = append(members, string(member.Login)) + } - return MarshalledTextResult(members), nil, nil - } + return MarshalledTextResult(members), nil, nil + } + }, + ) } diff --git a/pkg/github/context_tools_test.go b/pkg/github/context_tools_test.go index 96e21c233..0e28aad49 100644 --- a/pkg/github/context_tools_test.go +++ b/pkg/github/context_tools_test.go @@ -20,7 +20,8 @@ import ( func Test_GetMe(t *testing.T) { t.Parallel() - tool, _ := GetMe(nil, translations.NullTranslationHelper) + serverTool := GetMe(translations.NullTranslationHelper) + tool := serverTool.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) // Verify some basic very important properties @@ -108,21 +109,28 @@ func Test_GetMe(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { - _, handler := GetMe(tc.stubbedGetClientFn, translations.NullTranslationHelper) + deps := ToolDependencies{ + GetClient: tc.stubbedGetClientFn, + } + handler := serverTool.Handler(deps) request := createMCPRequest(tc.requestArgs) - result, _, _ := handler(context.Background(), &request, tc.requestArgs) - textContent := getTextResult(t, result) + result, err := handler(context.Background(), &request) + require.NoError(t, err) if tc.expectToolError { - assert.True(t, result.IsError, "expected tool call result to be an error") - assert.Contains(t, textContent.Text, tc.expectedToolErrMsg) + require.True(t, result.IsError, "expected tool call result to be an error") + errorContent := getErrorResult(t, result) + assert.Contains(t, errorContent.Text, tc.expectedToolErrMsg) return } + require.False(t, result.IsError) + textContent := getTextResult(t, result) + // Unmarshal and verify the result var returnedUser MinimalUser - err := json.Unmarshal([]byte(textContent.Text), &returnedUser) + err = json.Unmarshal([]byte(textContent.Text), &returnedUser) require.NoError(t, err) // Verify minimal user details @@ -145,7 +153,8 @@ func Test_GetMe(t *testing.T) { func Test_GetTeams(t *testing.T) { t.Parallel() - tool, _ := GetTeams(nil, nil, translations.NullTranslationHelper) + serverTool := GetTeams(translations.NullTranslationHelper) + tool := serverTool.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "get_teams", tool.Name) @@ -331,19 +340,26 @@ func Test_GetTeams(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { - _, handler := GetTeams(tc.stubbedGetClientFn, tc.stubbedGetGQLClientFn, translations.NullTranslationHelper) + deps := ToolDependencies{ + GetClient: tc.stubbedGetClientFn, + GetGQLClient: tc.stubbedGetGQLClientFn, + } + handler := serverTool.Handler(deps) request := createMCPRequest(tc.requestArgs) - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(context.Background(), &request) require.NoError(t, err) - textContent := getTextResult(t, result) if tc.expectToolError { - assert.True(t, result.IsError, "expected tool call result to be an error") - assert.Contains(t, textContent.Text, tc.expectedToolErrMsg) + require.True(t, result.IsError, "expected tool call result to be an error") + errorContent := getErrorResult(t, result) + assert.Contains(t, errorContent.Text, tc.expectedToolErrMsg) return } + require.False(t, result.IsError) + textContent := getTextResult(t, result) + var organizations []OrganizationTeams err = json.Unmarshal([]byte(textContent.Text), &organizations) require.NoError(t, err) @@ -372,7 +388,8 @@ func Test_GetTeams(t *testing.T) { func Test_GetTeamMembers(t *testing.T) { t.Parallel() - tool, _ := GetTeamMembers(nil, translations.NullTranslationHelper) + serverTool := GetTeamMembers(translations.NullTranslationHelper) + tool := serverTool.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "get_team_members", tool.Name) @@ -467,19 +484,25 @@ func Test_GetTeamMembers(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { - _, handler := GetTeamMembers(tc.stubbedGetGQLClientFn, translations.NullTranslationHelper) + deps := ToolDependencies{ + GetGQLClient: tc.stubbedGetGQLClientFn, + } + handler := serverTool.Handler(deps) request := createMCPRequest(tc.requestArgs) - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(context.Background(), &request) require.NoError(t, err) - textContent := getTextResult(t, result) if tc.expectToolError { - assert.True(t, result.IsError, "expected tool call result to be an error") - assert.Contains(t, textContent.Text, tc.expectedToolErrMsg) + require.True(t, result.IsError, "expected tool call result to be an error") + errorContent := getErrorResult(t, result) + assert.Contains(t, errorContent.Text, tc.expectedToolErrMsg) return } + require.False(t, result.IsError) + textContent := getTextResult(t, result) + var members []string err = json.Unmarshal([]byte(textContent.Text), &members) require.NoError(t, err) diff --git a/pkg/github/tools.go b/pkg/github/tools.go index dff1ca02e..51fcad7f6 100644 --- a/pkg/github/tools.go +++ b/pkg/github/tools.go @@ -334,9 +334,9 @@ func DefaultToolsetGroup(readOnly bool, getClient GetClientFn, getGQLClient GetG contextTools := toolsets.NewToolset(ToolsetMetadataContext.ID, ToolsetMetadataContext.Description). SetDependencies(deps). AddReadTools( - toolsets.NewServerToolLegacy(GetMe(getClient, t)), - toolsets.NewServerToolLegacy(GetTeams(getClient, getGQLClient, t)), - toolsets.NewServerToolLegacy(GetTeamMembers(getGQLClient, t)), + GetMe(t), + GetTeams(t), + GetTeamMembers(t), ) gists := toolsets.NewToolset(ToolsetMetadataGists.ID, ToolsetMetadataGists.Description).