From 09ecc6648bbf97a9cb615839f35841497b372bda Mon Sep 17 00:00:00 2001 From: Sam Morrow Date: Sat, 13 Dec 2025 12:16:32 +0100 Subject: [PATCH 1/3] refactor: separate ServerTool into own file with HandlerFunc pattern - Extract ServerTool struct into pkg/toolsets/server_tool.go - Add ToolDependencies struct for passing common dependencies to handlers - HandlerFunc allows lazy handler generation from Tool definitions - NewServerTool for new dependency-based tools - NewServerToolLegacy for backward compatibility with existing handlers - Update toolsets.go to store and pass dependencies - Update all call sites to use NewServerToolLegacy Co-authored-by: Adam Holt <4619+omgitsads@users.noreply.github.com> --- internal/ghmcp/server.go | 3 +- pkg/github/dynamic_tools.go | 4 +- pkg/toolsets/server_tool.go | 121 ++++++++++++++++++++++++++++++++++ pkg/toolsets/toolsets.go | 44 +++++-------- pkg/toolsets/toolsets_test.go | 26 ++++++-- 5 files changed, 158 insertions(+), 40 deletions(-) create mode 100644 pkg/toolsets/server_tool.go diff --git a/internal/ghmcp/server.go b/internal/ghmcp/server.go index 41f9016a2..b210ece0d 100644 --- a/internal/ghmcp/server.go +++ b/internal/ghmcp/server.go @@ -18,6 +18,7 @@ import ( "github.com/github/github-mcp-server/pkg/lockdown" mcplog "github.com/github/github-mcp-server/pkg/log" "github.com/github/github-mcp-server/pkg/raw" + "github.com/github/github-mcp-server/pkg/toolsets" "github.com/github/github-mcp-server/pkg/translations" gogithub "github.com/google/go-github/v79/github" "github.com/modelcontextprotocol/go-sdk/mcp" @@ -181,7 +182,7 @@ func NewMCPServer(cfg MCPServerConfig) (*mcp.Server, error) { enabledTools, _ = tsg.ResolveToolAliases(enabledTools) // Register the specified tools (additive to any toolsets already enabled) - err = tsg.RegisterSpecificTools(ghServer, enabledTools, cfg.ReadOnly) + err = tsg.RegisterSpecificTools(ghServer, enabledTools, cfg.ReadOnly, toolsets.ToolDependencies{}) if err != nil { return nil, fmt.Errorf("failed to register tools: %w", err) } diff --git a/pkg/github/dynamic_tools.go b/pkg/github/dynamic_tools.go index c65510246..75d74f676 100644 --- a/pkg/github/dynamic_tools.go +++ b/pkg/github/dynamic_tools.go @@ -61,9 +61,7 @@ func EnableToolset(s *mcp.Server, toolsetGroup *toolsets.ToolsetGroup, t transla // // Send notification to all initialized sessions // s.sendNotificationToAllClients("notifications/tools/list_changed", nil) - for _, serverTool := range toolset.GetActiveTools() { - serverTool.RegisterFunc(s) - } + toolset.RegisterTools(s) return utils.NewToolResultText(fmt.Sprintf("Toolset %s enabled", toolsetName)), nil, nil }) diff --git a/pkg/toolsets/server_tool.go b/pkg/toolsets/server_tool.go new file mode 100644 index 000000000..93076ed1c --- /dev/null +++ b/pkg/toolsets/server_tool.go @@ -0,0 +1,121 @@ +package toolsets + +import ( + "context" + "encoding/json" + + "github.com/modelcontextprotocol/go-sdk/mcp" +) + +// HandlerFunc is a function that takes dependencies and returns an MCP tool handler. +// This allows tools to be defined statically while their handlers are generated +// on-demand with the appropriate dependencies. +type HandlerFunc func(deps ToolDependencies) mcp.ToolHandler + +// ToolDependencies contains all dependencies that tool handlers might need. +// Fields are pointers/interfaces so they can be nil when not needed by a specific tool. +type ToolDependencies struct { + // GetClient returns a GitHub REST API client + GetClient any // func(context.Context) (*github.Client, error) + + // GetGQLClient returns a GitHub GraphQL client + GetGQLClient any // func(context.Context) (*githubv4.Client, error) + + // GetRawClient returns a raw HTTP client for GitHub + GetRawClient any // raw.GetRawClientFn + + // RepoAccessCache is the lockdown mode repo access cache + RepoAccessCache any // *lockdown.RepoAccessCache + + // T is the translation helper function + T any // translations.TranslationHelperFunc + + // Flags are feature flags + Flags any // FeatureFlags + + // ContentWindowSize is the size of the content window for log truncation + ContentWindowSize int +} + +// ServerTool represents an MCP tool with a handler generator function. +// The tool definition is static, while the handler is generated on-demand +// when the tool is registered with a server. +type ServerTool struct { + // Tool is the MCP tool definition containing name, description, schema, etc. + Tool mcp.Tool + + // HandlerFunc generates the handler when given dependencies. + // This allows tools to be passed around without handlers being set up, + // and handlers are only created when needed. + HandlerFunc HandlerFunc +} + +// Handler returns a tool handler by calling HandlerFunc with the given dependencies. +func (st *ServerTool) Handler(deps ToolDependencies) mcp.ToolHandler { + if st.HandlerFunc == nil { + return nil + } + return st.HandlerFunc(deps) +} + +// RegisterFunc registers the tool with the server using the provided dependencies. +func (st *ServerTool) RegisterFunc(s *mcp.Server, deps ToolDependencies) { + handler := st.Handler(deps) + s.AddTool(&st.Tool, handler) +} + +// NewServerTool creates a ServerTool from a tool definition and a typed handler function. +// The handler function takes dependencies and returns a typed handler. +func NewServerTool[In any, Out any](tool mcp.Tool, handlerFn func(deps ToolDependencies) mcp.ToolHandlerFor[In, Out]) ServerTool { + return ServerTool{ + Tool: tool, + HandlerFunc: func(deps ToolDependencies) mcp.ToolHandler { + typedHandler := handlerFn(deps) + return func(ctx context.Context, req *mcp.CallToolRequest) (*mcp.CallToolResult, error) { + var arguments In + if err := json.Unmarshal(req.Params.Arguments, &arguments); err != nil { + return nil, err + } + resp, _, err := typedHandler(ctx, req, arguments) + return resp, err + } + }, + } +} + +// NewServerToolFromHandler creates a ServerTool from a tool definition and a raw handler function. +// Use this when you have a handler that already conforms to mcp.ToolHandler. +func NewServerToolFromHandler(tool mcp.Tool, handlerFn func(deps ToolDependencies) mcp.ToolHandler) ServerTool { + return ServerTool{Tool: tool, HandlerFunc: handlerFn} +} + +// NewServerToolLegacy creates a ServerTool from a tool definition and an already-bound typed handler. +// This is for backward compatibility during the refactor - the handler doesn't use ToolDependencies. +// Deprecated: Use NewServerTool instead for new code. +func NewServerToolLegacy[In any, Out any](tool mcp.Tool, handler mcp.ToolHandlerFor[In, Out]) ServerTool { + return ServerTool{ + Tool: tool, + HandlerFunc: func(_ ToolDependencies) mcp.ToolHandler { + return func(ctx context.Context, req *mcp.CallToolRequest) (*mcp.CallToolResult, error) { + var arguments In + if err := json.Unmarshal(req.Params.Arguments, &arguments); err != nil { + return nil, err + } + resp, _, err := handler(ctx, req, arguments) + return resp, err + } + }, + } +} + +// NewServerToolFromHandlerLegacy creates a ServerTool from a tool definition and an already-bound raw handler. +// This is for backward compatibility during the refactor - the handler doesn't use ToolDependencies. +// Deprecated: Use NewServerToolFromHandler instead for new code. +func NewServerToolFromHandlerLegacy(tool mcp.Tool, handler mcp.ToolHandler) ServerTool { + return ServerTool{ + Tool: tool, + HandlerFunc: func(_ ToolDependencies) mcp.ToolHandler { + return handler + }, + } +} diff --git a/pkg/toolsets/toolsets.go b/pkg/toolsets/toolsets.go index d96b5fb50..8105c3283 100644 --- a/pkg/toolsets/toolsets.go +++ b/pkg/toolsets/toolsets.go @@ -1,8 +1,6 @@ package toolsets import ( - "context" - "encoding/json" "fmt" "os" "strings" @@ -32,27 +30,7 @@ func NewToolsetDoesNotExistError(name string) *ToolsetDoesNotExistError { return &ToolsetDoesNotExistError{Name: name} } -type ServerTool struct { - Tool mcp.Tool - RegisterFunc func(s *mcp.Server) -} - -func NewServerTool[In any, Out any](tool mcp.Tool, handler mcp.ToolHandlerFor[In, Out]) ServerTool { - return ServerTool{Tool: tool, RegisterFunc: func(s *mcp.Server) { - th := func(ctx context.Context, req *mcp.CallToolRequest) (*mcp.CallToolResult, error) { - var arguments In - if err := json.Unmarshal(req.Params.Arguments, &arguments); err != nil { - return nil, err - } - - resp, _, err := handler(ctx, req, arguments) - - return resp, err - } - - s.AddTool(&tool, th) - }} -} +// ServerTool is defined in server_tool.go type ServerResourceTemplate struct { Template mcp.ResourceTemplate @@ -86,6 +64,8 @@ type Toolset struct { readOnly bool writeTools []ServerTool readTools []ServerTool + // deps holds the dependencies for tool handlers + deps ToolDependencies // resources are not tools, but the community seems to be moving towards namespaces as a broader concept // and in order to have multiple servers running concurrently, we want to avoid overlapping resources too. resourceTemplates []ServerResourceTemplate @@ -114,16 +94,22 @@ func (t *Toolset) RegisterTools(s *mcp.Server) { if !t.Enabled { return } - for _, tool := range t.readTools { - tool.RegisterFunc(s) + for i := range t.readTools { + t.readTools[i].RegisterFunc(s, t.deps) } if !t.readOnly { - for _, tool := range t.writeTools { - tool.RegisterFunc(s) + for i := range t.writeTools { + t.writeTools[i].RegisterFunc(s, t.deps) } } } +// SetDependencies sets the dependencies for this toolset's tool handlers. +func (t *Toolset) SetDependencies(deps ToolDependencies) *Toolset { + t.deps = deps + return t +} + func (t *Toolset) AddResourceTemplates(templates ...ServerResourceTemplate) *Toolset { t.resourceTemplates = append(t.resourceTemplates, templates...) return t @@ -358,7 +344,7 @@ func (tg *ToolsetGroup) FindToolByName(toolName string) (*ServerTool, string, er // RegisterSpecificTools registers only the specified tools. // Respects read-only mode (skips write tools if readOnly=true). // Returns error if any tool is not found. -func (tg *ToolsetGroup) RegisterSpecificTools(s *mcp.Server, toolNames []string, readOnly bool) error { +func (tg *ToolsetGroup) RegisterSpecificTools(s *mcp.Server, toolNames []string, readOnly bool, deps ToolDependencies) error { var skippedTools []string for _, toolName := range toolNames { tool, _, err := tg.FindToolByName(toolName) @@ -373,7 +359,7 @@ func (tg *ToolsetGroup) RegisterSpecificTools(s *mcp.Server, toolNames []string, } // Register the tool - tool.RegisterFunc(s) + tool.RegisterFunc(s, deps) } // Log skipped write tools if any diff --git a/pkg/toolsets/toolsets_test.go b/pkg/toolsets/toolsets_test.go index 6362aad0e..b618a04dd 100644 --- a/pkg/toolsets/toolsets_test.go +++ b/pkg/toolsets/toolsets_test.go @@ -1,6 +1,8 @@ package toolsets import ( + "context" + "encoding/json" "errors" "testing" @@ -9,15 +11,20 @@ import ( // mockTool creates a minimal ServerTool for testing func mockTool(name string, readOnly bool) ServerTool { - return ServerTool{ - Tool: mcp.Tool{ + return NewServerToolFromHandler( + mcp.Tool{ Name: name, Annotations: &mcp.ToolAnnotations{ ReadOnlyHint: readOnly, }, + InputSchema: json.RawMessage(`{"type":"object","properties":{}}`), }, - RegisterFunc: func(_ *mcp.Server) {}, - } + func(_ ToolDependencies) mcp.ToolHandler { + return func(_ context.Context, _ *mcp.CallToolRequest) (*mcp.CallToolResult, error) { + return nil, nil + } + }, + ) } func TestNewToolsetGroupIsEmptyWithoutEverythingOn(t *testing.T) { @@ -375,20 +382,25 @@ func TestRegisterSpecificTools(t *testing.T) { toolset.writeTools = append(toolset.writeTools, mockTool("issue_write", false)) tsg.AddToolset(toolset) + deps := ToolDependencies{} + + // Create a real server for testing + server := mcp.NewServer(&mcp.Implementation{Name: "test", Version: "1.0.0"}, nil) + // Test registering with canonical names - err := tsg.RegisterSpecificTools(nil, []string{"issue_read"}, false) + err := tsg.RegisterSpecificTools(server, []string{"issue_read"}, false, deps) if err != nil { t.Errorf("expected no error registering tool, got %v", err) } // Test registering write tool in read-only mode (should skip but not error) - err = tsg.RegisterSpecificTools(nil, []string{"issue_write"}, true) + err = tsg.RegisterSpecificTools(server, []string{"issue_write"}, true, deps) if err != nil { t.Errorf("expected no error when skipping write tool in read-only mode, got %v", err) } // Test registering non-existent tool (should error) - err = tsg.RegisterSpecificTools(nil, []string{"nonexistent"}, false) + err = tsg.RegisterSpecificTools(server, []string{"nonexistent"}, false, deps) if err == nil { t.Error("expected error for non-existent tool") } From 187a8b640a9a89b846d3300cd1f2ec464ce5cb92 Mon Sep 17 00:00:00 2001 From: Sam Morrow Date: Sat, 13 Dec 2025 12:16:37 +0100 Subject: [PATCH 2/3] Wire ToolDependencies through toolsets - Move ToolDependencies to pkg/github/dependencies.go with proper types - Use 'any' in toolsets package to avoid circular dependencies - Add NewTool/NewToolFromHandler helpers that isolate type assertion - Tool implementations will be fully typed with no assertions scattered - Infrastructure ready for incremental tool migration --- internal/ghmcp/server.go | 14 ++- pkg/github/dependencies.go | 53 ++++++++ pkg/github/tools.go | 228 +++++++++++++++++++--------------- pkg/toolsets/server_tool.go | 50 +++----- pkg/toolsets/toolsets.go | 9 +- pkg/toolsets/toolsets_test.go | 5 +- 6 files changed, 216 insertions(+), 143 deletions(-) create mode 100644 pkg/github/dependencies.go diff --git a/internal/ghmcp/server.go b/internal/ghmcp/server.go index b210ece0d..c0f4e25e7 100644 --- a/internal/ghmcp/server.go +++ b/internal/ghmcp/server.go @@ -18,7 +18,6 @@ import ( "github.com/github/github-mcp-server/pkg/lockdown" mcplog "github.com/github/github-mcp-server/pkg/log" "github.com/github/github-mcp-server/pkg/raw" - "github.com/github/github-mcp-server/pkg/toolsets" "github.com/github/github-mcp-server/pkg/translations" gogithub "github.com/google/go-github/v79/github" "github.com/modelcontextprotocol/go-sdk/mcp" @@ -152,6 +151,17 @@ func NewMCPServer(cfg MCPServerConfig) (*mcp.Server, error) { ghServer.AddReceivingMiddleware(addGitHubAPIErrorToContext) ghServer.AddReceivingMiddleware(addUserAgentsMiddleware(cfg, restClient, gqlHTTPClient)) + // Create the dependencies struct for tool handlers + deps := github.ToolDependencies{ + GetClient: getClient, + GetGQLClient: getGQLClient, + GetRawClient: getRawClient, + RepoAccessCache: repoAccessCache, + T: cfg.Translator, + Flags: github.FeatureFlags{LockdownMode: cfg.LockdownMode}, + ContentWindowSize: cfg.ContentWindowSize, + } + // Create default toolsets tsg := github.DefaultToolsetGroup( cfg.ReadOnly, @@ -182,7 +192,7 @@ func NewMCPServer(cfg MCPServerConfig) (*mcp.Server, error) { enabledTools, _ = tsg.ResolveToolAliases(enabledTools) // Register the specified tools (additive to any toolsets already enabled) - err = tsg.RegisterSpecificTools(ghServer, enabledTools, cfg.ReadOnly, toolsets.ToolDependencies{}) + err = tsg.RegisterSpecificTools(ghServer, enabledTools, cfg.ReadOnly, deps) if err != nil { return nil, fmt.Errorf("failed to register tools: %w", err) } diff --git a/pkg/github/dependencies.go b/pkg/github/dependencies.go new file mode 100644 index 000000000..3124f6bd0 --- /dev/null +++ b/pkg/github/dependencies.go @@ -0,0 +1,53 @@ +package github + +import ( + "github.com/github/github-mcp-server/pkg/lockdown" + "github.com/github/github-mcp-server/pkg/raw" + "github.com/github/github-mcp-server/pkg/toolsets" + "github.com/github/github-mcp-server/pkg/translations" + "github.com/modelcontextprotocol/go-sdk/mcp" +) + +// ToolDependencies contains all dependencies that tool handlers might need. +// This is a properly-typed struct that lives in pkg/github to avoid circular +// dependencies. The toolsets package uses `any` for deps and tool handlers +// type-assert to this struct. +type ToolDependencies struct { + // GetClient returns a GitHub REST API client + GetClient GetClientFn + + // GetGQLClient returns a GitHub GraphQL client + GetGQLClient GetGQLClientFn + + // GetRawClient returns a raw HTTP client for GitHub + GetRawClient raw.GetRawClientFn + + // RepoAccessCache is the lockdown mode repo access cache + RepoAccessCache *lockdown.RepoAccessCache + + // T is the translation helper function + T translations.TranslationHelperFunc + + // Flags are feature flags + Flags FeatureFlags + + // ContentWindowSize is the size of the content window for log truncation + ContentWindowSize int +} + +// NewTool creates a ServerTool with fully-typed ToolDependencies. +// This helper isolates the type assertion from `any` to `ToolDependencies`, +// so tool implementations remain fully typed without assertions scattered throughout. +func NewTool[In, Out any](tool mcp.Tool, handler func(deps ToolDependencies) mcp.ToolHandlerFor[In, Out]) toolsets.ServerTool { + return toolsets.NewServerTool(tool, func(d any) mcp.ToolHandlerFor[In, Out] { + return handler(d.(ToolDependencies)) + }) +} + +// NewToolFromHandler creates a ServerTool with fully-typed ToolDependencies +// for handlers that conform to mcp.ToolHandler directly. +func NewToolFromHandler(tool mcp.Tool, handler func(deps ToolDependencies) mcp.ToolHandler) toolsets.ServerTool { + return toolsets.NewServerToolFromHandler(tool, func(d any) mcp.ToolHandler { + return handler(d.(ToolDependencies)) + }) +} diff --git a/pkg/github/tools.go b/pkg/github/tools.go index f21a9ae5b..efabfc92f 100644 --- a/pkg/github/tools.go +++ b/pkg/github/tools.go @@ -163,29 +163,41 @@ func GetDefaultToolsetIDs() []string { func DefaultToolsetGroup(readOnly bool, getClient GetClientFn, getGQLClient GetGQLClientFn, getRawClient raw.GetRawClientFn, t translations.TranslationHelperFunc, contentWindowSize int, flags FeatureFlags, cache *lockdown.RepoAccessCache) *toolsets.ToolsetGroup { tsg := toolsets.NewToolsetGroup(readOnly) + // Create the dependencies struct that will be passed to all tool handlers + deps := ToolDependencies{ + GetClient: getClient, + GetGQLClient: getGQLClient, + GetRawClient: getRawClient, + RepoAccessCache: cache, + T: t, + Flags: flags, + ContentWindowSize: contentWindowSize, + } + // Define all available features with their default state (disabled) // Create toolsets repos := toolsets.NewToolset(ToolsetMetadataRepos.ID, ToolsetMetadataRepos.Description). + SetDependencies(deps). AddReadTools( - toolsets.NewServerTool(SearchRepositories(getClient, t)), - toolsets.NewServerTool(GetFileContents(getClient, getRawClient, t)), - toolsets.NewServerTool(ListCommits(getClient, t)), - toolsets.NewServerTool(SearchCode(getClient, t)), - toolsets.NewServerTool(GetCommit(getClient, t)), - toolsets.NewServerTool(ListBranches(getClient, t)), - toolsets.NewServerTool(ListTags(getClient, t)), - toolsets.NewServerTool(GetTag(getClient, t)), - toolsets.NewServerTool(ListReleases(getClient, t)), - toolsets.NewServerTool(GetLatestRelease(getClient, t)), - toolsets.NewServerTool(GetReleaseByTag(getClient, t)), + toolsets.NewServerToolLegacy(SearchRepositories(getClient, t)), + toolsets.NewServerToolLegacy(GetFileContents(getClient, getRawClient, t)), + toolsets.NewServerToolLegacy(ListCommits(getClient, t)), + toolsets.NewServerToolLegacy(SearchCode(getClient, t)), + toolsets.NewServerToolLegacy(GetCommit(getClient, t)), + toolsets.NewServerToolLegacy(ListBranches(getClient, t)), + toolsets.NewServerToolLegacy(ListTags(getClient, t)), + toolsets.NewServerToolLegacy(GetTag(getClient, t)), + toolsets.NewServerToolLegacy(ListReleases(getClient, t)), + toolsets.NewServerToolLegacy(GetLatestRelease(getClient, t)), + toolsets.NewServerToolLegacy(GetReleaseByTag(getClient, t)), ). AddWriteTools( - toolsets.NewServerTool(CreateOrUpdateFile(getClient, t)), - toolsets.NewServerTool(CreateRepository(getClient, t)), - toolsets.NewServerTool(ForkRepository(getClient, t)), - toolsets.NewServerTool(CreateBranch(getClient, t)), - toolsets.NewServerTool(PushFiles(getClient, t)), - toolsets.NewServerTool(DeleteFile(getClient, t)), + toolsets.NewServerToolLegacy(CreateOrUpdateFile(getClient, t)), + toolsets.NewServerToolLegacy(CreateRepository(getClient, t)), + toolsets.NewServerToolLegacy(ForkRepository(getClient, t)), + toolsets.NewServerToolLegacy(CreateBranch(getClient, t)), + toolsets.NewServerToolLegacy(PushFiles(getClient, t)), + toolsets.NewServerToolLegacy(DeleteFile(getClient, t)), ). AddResourceTemplates( toolsets.NewServerResourceTemplate(GetRepositoryResourceContent(getClient, getRawClient, t)), @@ -195,166 +207,184 @@ func DefaultToolsetGroup(readOnly bool, getClient GetClientFn, getGQLClient GetG toolsets.NewServerResourceTemplate(GetRepositoryResourcePrContent(getClient, getRawClient, t)), ) git := toolsets.NewToolset(ToolsetMetadataGit.ID, ToolsetMetadataGit.Description). + SetDependencies(deps). AddReadTools( - toolsets.NewServerTool(GetRepositoryTree(getClient, t)), + toolsets.NewServerToolLegacy(GetRepositoryTree(getClient, t)), ) issues := toolsets.NewToolset(ToolsetMetadataIssues.ID, ToolsetMetadataIssues.Description). + SetDependencies(deps). AddReadTools( - toolsets.NewServerTool(IssueRead(getClient, getGQLClient, cache, t, flags)), - toolsets.NewServerTool(SearchIssues(getClient, t)), - toolsets.NewServerTool(ListIssues(getGQLClient, t)), - toolsets.NewServerTool(ListIssueTypes(getClient, t)), - toolsets.NewServerTool(GetLabel(getGQLClient, t)), + toolsets.NewServerToolLegacy(IssueRead(getClient, getGQLClient, cache, t, flags)), + toolsets.NewServerToolLegacy(SearchIssues(getClient, t)), + toolsets.NewServerToolLegacy(ListIssues(getGQLClient, t)), + toolsets.NewServerToolLegacy(ListIssueTypes(getClient, t)), + toolsets.NewServerToolLegacy(GetLabel(getGQLClient, t)), ). AddWriteTools( - toolsets.NewServerTool(IssueWrite(getClient, getGQLClient, t)), - toolsets.NewServerTool(AddIssueComment(getClient, t)), - toolsets.NewServerTool(AssignCopilotToIssue(getGQLClient, t)), - toolsets.NewServerTool(SubIssueWrite(getClient, t)), + toolsets.NewServerToolLegacy(IssueWrite(getClient, getGQLClient, t)), + toolsets.NewServerToolLegacy(AddIssueComment(getClient, t)), + toolsets.NewServerToolLegacy(AssignCopilotToIssue(getGQLClient, t)), + toolsets.NewServerToolLegacy(SubIssueWrite(getClient, t)), ).AddPrompts( toolsets.NewServerPrompt(AssignCodingAgentPrompt(t)), toolsets.NewServerPrompt(IssueToFixWorkflowPrompt(t)), ) users := toolsets.NewToolset(ToolsetMetadataUsers.ID, ToolsetMetadataUsers.Description). + SetDependencies(deps). AddReadTools( - toolsets.NewServerTool(SearchUsers(getClient, t)), + toolsets.NewServerToolLegacy(SearchUsers(getClient, t)), ) orgs := toolsets.NewToolset(ToolsetMetadataOrgs.ID, ToolsetMetadataOrgs.Description). + SetDependencies(deps). AddReadTools( - toolsets.NewServerTool(SearchOrgs(getClient, t)), + toolsets.NewServerToolLegacy(SearchOrgs(getClient, t)), ) pullRequests := toolsets.NewToolset(ToolsetMetadataPullRequests.ID, ToolsetMetadataPullRequests.Description). + SetDependencies(deps). AddReadTools( - toolsets.NewServerTool(PullRequestRead(getClient, cache, t, flags)), - toolsets.NewServerTool(ListPullRequests(getClient, t)), - toolsets.NewServerTool(SearchPullRequests(getClient, t)), + toolsets.NewServerToolLegacy(PullRequestRead(getClient, cache, t, flags)), + toolsets.NewServerToolLegacy(ListPullRequests(getClient, t)), + toolsets.NewServerToolLegacy(SearchPullRequests(getClient, t)), ). AddWriteTools( - toolsets.NewServerTool(MergePullRequest(getClient, t)), - toolsets.NewServerTool(UpdatePullRequestBranch(getClient, t)), - toolsets.NewServerTool(CreatePullRequest(getClient, t)), - toolsets.NewServerTool(UpdatePullRequest(getClient, getGQLClient, t)), - toolsets.NewServerTool(RequestCopilotReview(getClient, t)), + toolsets.NewServerToolLegacy(MergePullRequest(getClient, t)), + toolsets.NewServerToolLegacy(UpdatePullRequestBranch(getClient, t)), + toolsets.NewServerToolLegacy(CreatePullRequest(getClient, t)), + toolsets.NewServerToolLegacy(UpdatePullRequest(getClient, getGQLClient, t)), + toolsets.NewServerToolLegacy(RequestCopilotReview(getClient, t)), // Reviews - toolsets.NewServerTool(PullRequestReviewWrite(getGQLClient, t)), - toolsets.NewServerTool(AddCommentToPendingReview(getGQLClient, t)), + toolsets.NewServerToolLegacy(PullRequestReviewWrite(getGQLClient, t)), + toolsets.NewServerToolLegacy(AddCommentToPendingReview(getGQLClient, t)), ) codeSecurity := toolsets.NewToolset(ToolsetMetadataCodeSecurity.ID, ToolsetMetadataCodeSecurity.Description). + SetDependencies(deps). AddReadTools( - toolsets.NewServerTool(GetCodeScanningAlert(getClient, t)), - toolsets.NewServerTool(ListCodeScanningAlerts(getClient, t)), + toolsets.NewServerToolLegacy(GetCodeScanningAlert(getClient, t)), + toolsets.NewServerToolLegacy(ListCodeScanningAlerts(getClient, t)), ) secretProtection := toolsets.NewToolset(ToolsetMetadataSecretProtection.ID, ToolsetMetadataSecretProtection.Description). + SetDependencies(deps). AddReadTools( - toolsets.NewServerTool(GetSecretScanningAlert(getClient, t)), - toolsets.NewServerTool(ListSecretScanningAlerts(getClient, t)), + toolsets.NewServerToolLegacy(GetSecretScanningAlert(getClient, t)), + toolsets.NewServerToolLegacy(ListSecretScanningAlerts(getClient, t)), ) dependabot := toolsets.NewToolset(ToolsetMetadataDependabot.ID, ToolsetMetadataDependabot.Description). + SetDependencies(deps). AddReadTools( - toolsets.NewServerTool(GetDependabotAlert(getClient, t)), - toolsets.NewServerTool(ListDependabotAlerts(getClient, t)), + toolsets.NewServerToolLegacy(GetDependabotAlert(getClient, t)), + toolsets.NewServerToolLegacy(ListDependabotAlerts(getClient, t)), ) notifications := toolsets.NewToolset(ToolsetMetadataNotifications.ID, ToolsetMetadataNotifications.Description). + SetDependencies(deps). AddReadTools( - toolsets.NewServerTool(ListNotifications(getClient, t)), - toolsets.NewServerTool(GetNotificationDetails(getClient, t)), + toolsets.NewServerToolLegacy(ListNotifications(getClient, t)), + toolsets.NewServerToolLegacy(GetNotificationDetails(getClient, t)), ). AddWriteTools( - toolsets.NewServerTool(DismissNotification(getClient, t)), - toolsets.NewServerTool(MarkAllNotificationsRead(getClient, t)), - toolsets.NewServerTool(ManageNotificationSubscription(getClient, t)), - toolsets.NewServerTool(ManageRepositoryNotificationSubscription(getClient, t)), + toolsets.NewServerToolLegacy(DismissNotification(getClient, t)), + toolsets.NewServerToolLegacy(MarkAllNotificationsRead(getClient, t)), + toolsets.NewServerToolLegacy(ManageNotificationSubscription(getClient, t)), + toolsets.NewServerToolLegacy(ManageRepositoryNotificationSubscription(getClient, t)), ) discussions := toolsets.NewToolset(ToolsetMetadataDiscussions.ID, ToolsetMetadataDiscussions.Description). + SetDependencies(deps). AddReadTools( - toolsets.NewServerTool(ListDiscussions(getGQLClient, t)), - toolsets.NewServerTool(GetDiscussion(getGQLClient, t)), - toolsets.NewServerTool(GetDiscussionComments(getGQLClient, t)), - toolsets.NewServerTool(ListDiscussionCategories(getGQLClient, t)), + toolsets.NewServerToolLegacy(ListDiscussions(getGQLClient, t)), + toolsets.NewServerToolLegacy(GetDiscussion(getGQLClient, t)), + toolsets.NewServerToolLegacy(GetDiscussionComments(getGQLClient, t)), + toolsets.NewServerToolLegacy(ListDiscussionCategories(getGQLClient, t)), ) actions := toolsets.NewToolset(ToolsetMetadataActions.ID, ToolsetMetadataActions.Description). + SetDependencies(deps). AddReadTools( - toolsets.NewServerTool(ListWorkflows(getClient, t)), - toolsets.NewServerTool(ListWorkflowRuns(getClient, t)), - toolsets.NewServerTool(GetWorkflowRun(getClient, t)), - toolsets.NewServerTool(GetWorkflowRunLogs(getClient, t)), - toolsets.NewServerTool(ListWorkflowJobs(getClient, t)), - toolsets.NewServerTool(GetJobLogs(getClient, t, contentWindowSize)), - toolsets.NewServerTool(ListWorkflowRunArtifacts(getClient, t)), - toolsets.NewServerTool(DownloadWorkflowRunArtifact(getClient, t)), - toolsets.NewServerTool(GetWorkflowRunUsage(getClient, t)), + toolsets.NewServerToolLegacy(ListWorkflows(getClient, t)), + toolsets.NewServerToolLegacy(ListWorkflowRuns(getClient, t)), + toolsets.NewServerToolLegacy(GetWorkflowRun(getClient, t)), + toolsets.NewServerToolLegacy(GetWorkflowRunLogs(getClient, t)), + toolsets.NewServerToolLegacy(ListWorkflowJobs(getClient, t)), + toolsets.NewServerToolLegacy(GetJobLogs(getClient, t, contentWindowSize)), + toolsets.NewServerToolLegacy(ListWorkflowRunArtifacts(getClient, t)), + toolsets.NewServerToolLegacy(DownloadWorkflowRunArtifact(getClient, t)), + toolsets.NewServerToolLegacy(GetWorkflowRunUsage(getClient, t)), ). AddWriteTools( - toolsets.NewServerTool(RunWorkflow(getClient, t)), - toolsets.NewServerTool(RerunWorkflowRun(getClient, t)), - toolsets.NewServerTool(RerunFailedJobs(getClient, t)), - toolsets.NewServerTool(CancelWorkflowRun(getClient, t)), - toolsets.NewServerTool(DeleteWorkflowRunLogs(getClient, t)), + toolsets.NewServerToolLegacy(RunWorkflow(getClient, t)), + toolsets.NewServerToolLegacy(RerunWorkflowRun(getClient, t)), + toolsets.NewServerToolLegacy(RerunFailedJobs(getClient, t)), + toolsets.NewServerToolLegacy(CancelWorkflowRun(getClient, t)), + toolsets.NewServerToolLegacy(DeleteWorkflowRunLogs(getClient, t)), ) securityAdvisories := toolsets.NewToolset(ToolsetMetadataSecurityAdvisories.ID, ToolsetMetadataSecurityAdvisories.Description). + SetDependencies(deps). AddReadTools( - toolsets.NewServerTool(ListGlobalSecurityAdvisories(getClient, t)), - toolsets.NewServerTool(GetGlobalSecurityAdvisory(getClient, t)), - toolsets.NewServerTool(ListRepositorySecurityAdvisories(getClient, t)), - toolsets.NewServerTool(ListOrgRepositorySecurityAdvisories(getClient, t)), + toolsets.NewServerToolLegacy(ListGlobalSecurityAdvisories(getClient, t)), + toolsets.NewServerToolLegacy(GetGlobalSecurityAdvisory(getClient, t)), + toolsets.NewServerToolLegacy(ListRepositorySecurityAdvisories(getClient, t)), + toolsets.NewServerToolLegacy(ListOrgRepositorySecurityAdvisories(getClient, t)), ) // // Keep experiments alive so the system doesn't error out when it's always enabled - experiments := toolsets.NewToolset(ToolsetMetadataExperiments.ID, ToolsetMetadataExperiments.Description) + experiments := toolsets.NewToolset(ToolsetMetadataExperiments.ID, ToolsetMetadataExperiments.Description). + SetDependencies(deps) contextTools := toolsets.NewToolset(ToolsetMetadataContext.ID, ToolsetMetadataContext.Description). + SetDependencies(deps). AddReadTools( - toolsets.NewServerTool(GetMe(getClient, t)), - toolsets.NewServerTool(GetTeams(getClient, getGQLClient, t)), - toolsets.NewServerTool(GetTeamMembers(getGQLClient, t)), + toolsets.NewServerToolLegacy(GetMe(getClient, t)), + toolsets.NewServerToolLegacy(GetTeams(getClient, getGQLClient, t)), + toolsets.NewServerToolLegacy(GetTeamMembers(getGQLClient, t)), ) gists := toolsets.NewToolset(ToolsetMetadataGists.ID, ToolsetMetadataGists.Description). + SetDependencies(deps). AddReadTools( - toolsets.NewServerTool(ListGists(getClient, t)), - toolsets.NewServerTool(GetGist(getClient, t)), + toolsets.NewServerToolLegacy(ListGists(getClient, t)), + toolsets.NewServerToolLegacy(GetGist(getClient, t)), ). AddWriteTools( - toolsets.NewServerTool(CreateGist(getClient, t)), - toolsets.NewServerTool(UpdateGist(getClient, t)), + toolsets.NewServerToolLegacy(CreateGist(getClient, t)), + toolsets.NewServerToolLegacy(UpdateGist(getClient, t)), ) projects := toolsets.NewToolset(ToolsetMetadataProjects.ID, ToolsetMetadataProjects.Description). + SetDependencies(deps). AddReadTools( - toolsets.NewServerTool(ListProjects(getClient, t)), - toolsets.NewServerTool(GetProject(getClient, t)), - toolsets.NewServerTool(ListProjectFields(getClient, t)), - toolsets.NewServerTool(GetProjectField(getClient, t)), - toolsets.NewServerTool(ListProjectItems(getClient, t)), - toolsets.NewServerTool(GetProjectItem(getClient, t)), + toolsets.NewServerToolLegacy(ListProjects(getClient, t)), + toolsets.NewServerToolLegacy(GetProject(getClient, t)), + toolsets.NewServerToolLegacy(ListProjectFields(getClient, t)), + toolsets.NewServerToolLegacy(GetProjectField(getClient, t)), + toolsets.NewServerToolLegacy(ListProjectItems(getClient, t)), + toolsets.NewServerToolLegacy(GetProjectItem(getClient, t)), ). AddWriteTools( - toolsets.NewServerTool(AddProjectItem(getClient, t)), - toolsets.NewServerTool(DeleteProjectItem(getClient, t)), - toolsets.NewServerTool(UpdateProjectItem(getClient, t)), + toolsets.NewServerToolLegacy(AddProjectItem(getClient, t)), + toolsets.NewServerToolLegacy(DeleteProjectItem(getClient, t)), + toolsets.NewServerToolLegacy(UpdateProjectItem(getClient, t)), ) stargazers := toolsets.NewToolset(ToolsetMetadataStargazers.ID, ToolsetMetadataStargazers.Description). + SetDependencies(deps). AddReadTools( - toolsets.NewServerTool(ListStarredRepositories(getClient, t)), + toolsets.NewServerToolLegacy(ListStarredRepositories(getClient, t)), ). AddWriteTools( - toolsets.NewServerTool(StarRepository(getClient, t)), - toolsets.NewServerTool(UnstarRepository(getClient, t)), + toolsets.NewServerToolLegacy(StarRepository(getClient, t)), + toolsets.NewServerToolLegacy(UnstarRepository(getClient, t)), ) labels := toolsets.NewToolset(ToolsetLabels.ID, ToolsetLabels.Description). + SetDependencies(deps). AddReadTools( // get - toolsets.NewServerTool(GetLabel(getGQLClient, t)), + toolsets.NewServerToolLegacy(GetLabel(getGQLClient, t)), // list labels on repo or issue - toolsets.NewServerTool(ListLabels(getGQLClient, t)), + toolsets.NewServerToolLegacy(ListLabels(getGQLClient, t)), ). AddWriteTools( // create or update - toolsets.NewServerTool(LabelWrite(getGQLClient, t)), + toolsets.NewServerToolLegacy(LabelWrite(getGQLClient, t)), ) // Add toolsets to the group @@ -391,9 +421,9 @@ func InitDynamicToolset(s *mcp.Server, tsg *toolsets.ToolsetGroup, t translation // Need to add the dynamic toolset last so it can be used to enable other toolsets dynamicToolSelection := toolsets.NewToolset(ToolsetMetadataDynamic.ID, ToolsetMetadataDynamic.Description). AddReadTools( - toolsets.NewServerTool(ListAvailableToolsets(tsg, t)), - toolsets.NewServerTool(GetToolsetsTools(tsg, t)), - toolsets.NewServerTool(EnableToolset(s, tsg, t)), + toolsets.NewServerToolLegacy(ListAvailableToolsets(tsg, t)), + toolsets.NewServerToolLegacy(GetToolsetsTools(tsg, t)), + toolsets.NewServerToolLegacy(EnableToolset(s, tsg, t)), ) dynamicToolSelection.Enabled = true diff --git a/pkg/toolsets/server_tool.go b/pkg/toolsets/server_tool.go index 93076ed1c..3e3e5d9f8 100644 --- a/pkg/toolsets/server_tool.go +++ b/pkg/toolsets/server_tool.go @@ -10,32 +10,9 @@ import ( // HandlerFunc is a function that takes dependencies and returns an MCP tool handler. // This allows tools to be defined statically while their handlers are generated // on-demand with the appropriate dependencies. -type HandlerFunc func(deps ToolDependencies) mcp.ToolHandler - -// ToolDependencies contains all dependencies that tool handlers might need. -// Fields are pointers/interfaces so they can be nil when not needed by a specific tool. -type ToolDependencies struct { - // GetClient returns a GitHub REST API client - GetClient any // func(context.Context) (*github.Client, error) - - // GetGQLClient returns a GitHub GraphQL client - GetGQLClient any // func(context.Context) (*githubv4.Client, error) - - // GetRawClient returns a raw HTTP client for GitHub - GetRawClient any // raw.GetRawClientFn - - // RepoAccessCache is the lockdown mode repo access cache - RepoAccessCache any // *lockdown.RepoAccessCache - - // T is the translation helper function - T any // translations.TranslationHelperFunc - - // Flags are feature flags - Flags any // FeatureFlags - - // ContentWindowSize is the size of the content window for log truncation - ContentWindowSize int -} +// The deps parameter is typed as `any` to avoid circular dependencies - callers +// should define their own typed dependencies struct and type-assert as needed. +type HandlerFunc func(deps any) mcp.ToolHandler // ServerTool represents an MCP tool with a handler generator function. // The tool definition is static, while the handler is generated on-demand @@ -51,7 +28,7 @@ type ServerTool struct { } // Handler returns a tool handler by calling HandlerFunc with the given dependencies. -func (st *ServerTool) Handler(deps ToolDependencies) mcp.ToolHandler { +func (st *ServerTool) Handler(deps any) mcp.ToolHandler { if st.HandlerFunc == nil { return nil } @@ -59,17 +36,18 @@ func (st *ServerTool) Handler(deps ToolDependencies) mcp.ToolHandler { } // RegisterFunc registers the tool with the server using the provided dependencies. -func (st *ServerTool) RegisterFunc(s *mcp.Server, deps ToolDependencies) { +func (st *ServerTool) RegisterFunc(s *mcp.Server, deps any) { handler := st.Handler(deps) s.AddTool(&st.Tool, handler) } // NewServerTool creates a ServerTool from a tool definition and a typed handler function. -// The handler function takes dependencies and returns a typed handler. -func NewServerTool[In any, Out any](tool mcp.Tool, handlerFn func(deps ToolDependencies) mcp.ToolHandlerFor[In, Out]) ServerTool { +// The handler function takes dependencies (as any) and returns a typed handler. +// Callers should type-assert deps to their typed dependencies struct. +func NewServerTool[In any, Out any](tool mcp.Tool, handlerFn func(deps any) mcp.ToolHandlerFor[In, Out]) ServerTool { return ServerTool{ Tool: tool, - HandlerFunc: func(deps ToolDependencies) mcp.ToolHandler { + HandlerFunc: func(deps any) mcp.ToolHandler { typedHandler := handlerFn(deps) return func(ctx context.Context, req *mcp.CallToolRequest) (*mcp.CallToolResult, error) { var arguments In @@ -85,17 +63,17 @@ func NewServerTool[In any, Out any](tool mcp.Tool, handlerFn func(deps ToolDepen // NewServerToolFromHandler creates a ServerTool from a tool definition and a raw handler function. // Use this when you have a handler that already conforms to mcp.ToolHandler. -func NewServerToolFromHandler(tool mcp.Tool, handlerFn func(deps ToolDependencies) mcp.ToolHandler) ServerTool { +func NewServerToolFromHandler(tool mcp.Tool, handlerFn func(deps any) mcp.ToolHandler) ServerTool { return ServerTool{Tool: tool, HandlerFunc: handlerFn} } // NewServerToolLegacy creates a ServerTool from a tool definition and an already-bound typed handler. -// This is for backward compatibility during the refactor - the handler doesn't use ToolDependencies. +// This is for backward compatibility during the refactor - the handler doesn't use dependencies. // Deprecated: Use NewServerTool instead for new code. func NewServerToolLegacy[In any, Out any](tool mcp.Tool, handler mcp.ToolHandlerFor[In, Out]) ServerTool { return ServerTool{ Tool: tool, - HandlerFunc: func(_ ToolDependencies) mcp.ToolHandler { + HandlerFunc: func(_ any) mcp.ToolHandler { return func(ctx context.Context, req *mcp.CallToolRequest) (*mcp.CallToolResult, error) { var arguments In if err := json.Unmarshal(req.Params.Arguments, &arguments); err != nil { @@ -109,12 +87,12 @@ func NewServerToolLegacy[In any, Out any](tool mcp.Tool, handler mcp.ToolHandler } // NewServerToolFromHandlerLegacy creates a ServerTool from a tool definition and an already-bound raw handler. -// This is for backward compatibility during the refactor - the handler doesn't use ToolDependencies. +// This is for backward compatibility during the refactor - the handler doesn't use dependencies. // Deprecated: Use NewServerToolFromHandler instead for new code. func NewServerToolFromHandlerLegacy(tool mcp.Tool, handler mcp.ToolHandler) ServerTool { return ServerTool{ Tool: tool, - HandlerFunc: func(_ ToolDependencies) mcp.ToolHandler { + HandlerFunc: func(_ any) mcp.ToolHandler { return handler }, } diff --git a/pkg/toolsets/toolsets.go b/pkg/toolsets/toolsets.go index 8105c3283..8502328d5 100644 --- a/pkg/toolsets/toolsets.go +++ b/pkg/toolsets/toolsets.go @@ -64,8 +64,8 @@ type Toolset struct { readOnly bool writeTools []ServerTool readTools []ServerTool - // deps holds the dependencies for tool handlers - deps ToolDependencies + // deps holds the dependencies for tool handlers (typed as any to avoid circular deps) + deps any // resources are not tools, but the community seems to be moving towards namespaces as a broader concept // and in order to have multiple servers running concurrently, we want to avoid overlapping resources too. resourceTemplates []ServerResourceTemplate @@ -105,7 +105,8 @@ func (t *Toolset) RegisterTools(s *mcp.Server) { } // SetDependencies sets the dependencies for this toolset's tool handlers. -func (t *Toolset) SetDependencies(deps ToolDependencies) *Toolset { +// The deps parameter is typed as `any` to avoid circular dependencies between packages. +func (t *Toolset) SetDependencies(deps any) *Toolset { t.deps = deps return t } @@ -344,7 +345,7 @@ func (tg *ToolsetGroup) FindToolByName(toolName string) (*ServerTool, string, er // RegisterSpecificTools registers only the specified tools. // Respects read-only mode (skips write tools if readOnly=true). // Returns error if any tool is not found. -func (tg *ToolsetGroup) RegisterSpecificTools(s *mcp.Server, toolNames []string, readOnly bool, deps ToolDependencies) error { +func (tg *ToolsetGroup) RegisterSpecificTools(s *mcp.Server, toolNames []string, readOnly bool, deps any) error { var skippedTools []string for _, toolName := range toolNames { tool, _, err := tg.FindToolByName(toolName) diff --git a/pkg/toolsets/toolsets_test.go b/pkg/toolsets/toolsets_test.go index b618a04dd..66be5cba2 100644 --- a/pkg/toolsets/toolsets_test.go +++ b/pkg/toolsets/toolsets_test.go @@ -19,7 +19,7 @@ func mockTool(name string, readOnly bool) ServerTool { }, InputSchema: json.RawMessage(`{"type":"object","properties":{}}`), }, - func(_ ToolDependencies) mcp.ToolHandler { + func(_ any) mcp.ToolHandler { return func(_ context.Context, _ *mcp.CallToolRequest) (*mcp.CallToolResult, error) { return nil, nil } @@ -382,7 +382,8 @@ func TestRegisterSpecificTools(t *testing.T) { toolset.writeTools = append(toolset.writeTools, mockTool("issue_write", false)) tsg.AddToolset(toolset) - deps := ToolDependencies{} + // deps is typed as any in toolsets package (to avoid circular deps) + var deps any // Create a real server for testing server := mcp.NewServer(&mcp.Implementation{Name: "test", Version: "1.0.0"}, nil) From 43acefe0865552f263a147910341fefe7cdb3b30 Mon Sep 17 00:00:00 2001 From: Sam Morrow Date: Sat, 13 Dec 2025 12:47:23 +0100 Subject: [PATCH 3/3] refactor(search): migrate search tools to new ServerTool pattern Migrate search.go tools (SearchRepositories, SearchCode, SearchUsers, SearchOrgs) to use the new NewTool helper and ToolDependencies pattern. - Functions now take only TranslationHelperFunc and return ServerTool - Handler generation uses ToolDependencies for typed access to clients - Update tools.go call sites to remove getClient parameter - Update tests to use new Handler(deps) pattern This demonstrates the migration pattern for additional tool files. Co-authored-by: Adam Holt --- pkg/github/search.go | 351 ++++++++++++++++++++------------------ pkg/github/search_test.go | 52 ++++-- pkg/github/tools.go | 8 +- 3 files changed, 223 insertions(+), 188 deletions(-) diff --git a/pkg/github/search.go b/pkg/github/search.go index cffd0bf15..eaaf49369 100644 --- a/pkg/github/search.go +++ b/pkg/github/search.go @@ -8,6 +8,7 @@ import ( "net/http" ghErrors "github.com/github/github-mcp-server/pkg/errors" + "github.com/github/github-mcp-server/pkg/toolsets" "github.com/github/github-mcp-server/pkg/translations" "github.com/github/github-mcp-server/pkg/utils" "github.com/google/go-github/v79/github" @@ -16,7 +17,7 @@ import ( ) // SearchRepositories creates a tool to search for GitHub repositories. -func SearchRepositories(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { +func SearchRepositories(t translations.TranslationHelperFunc) toolsets.ServerTool { schema := &jsonschema.Schema{ Type: "object", Properties: map[string]*jsonschema.Schema{ @@ -44,7 +45,8 @@ func SearchRepositories(getClient GetClientFn, t translations.TranslationHelperF } WithPagination(schema) - return mcp.Tool{ + return NewTool( + mcp.Tool{ Name: "search_repositories", Description: t("TOOL_SEARCH_REPOSITORIES_DESCRIPTION", "Find GitHub repositories by name, description, readme, topics, or other metadata. Perfect for discovering projects, finding examples, or locating specific repositories across GitHub."), Annotations: &mcp.ToolAnnotations{ @@ -53,115 +55,118 @@ func SearchRepositories(getClient GetClientFn, t translations.TranslationHelperF }, InputSchema: schema, }, - func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - query, err := RequiredParam[string](args, "query") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - sort, err := OptionalParam[string](args, "sort") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - order, err := OptionalParam[string](args, "order") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - pagination, err := OptionalPaginationParams(args) - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - minimalOutput, err := OptionalBoolParamWithDefault(args, "minimal_output", true) - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - opts := &github.SearchOptions{ - Sort: sort, - Order: order, - ListOptions: github.ListOptions{ - Page: pagination.Page, - PerPage: pagination.PerPage, - }, - } - - client, err := getClient(ctx) - if err != nil { - return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil - } - result, resp, err := client.Search.Repositories(ctx, query, opts) - if err != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, - fmt.Sprintf("failed to search repositories with query '%s'", query), - resp, - err, - ), nil, nil - } - defer func() { _ = resp.Body.Close() }() + func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { + return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + query, err := RequiredParam[string](args, "query") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + sort, err := OptionalParam[string](args, "sort") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + order, err := OptionalParam[string](args, "order") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + pagination, err := OptionalPaginationParams(args) + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + minimalOutput, err := OptionalBoolParamWithDefault(args, "minimal_output", true) + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + opts := &github.SearchOptions{ + Sort: sort, + Order: order, + ListOptions: github.ListOptions{ + Page: pagination.Page, + PerPage: pagination.PerPage, + }, + } - if resp.StatusCode != http.StatusOK { - body, err := io.ReadAll(resp.Body) + client, err := deps.GetClient(ctx) if err != nil { - return utils.NewToolResultErrorFromErr("failed to read response body", err), nil, nil + return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil } - return utils.NewToolResultError(fmt.Sprintf("failed to search repositories: %s", string(body))), nil, nil - } + result, resp, err := client.Search.Repositories(ctx, query, opts) + if err != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, + fmt.Sprintf("failed to search repositories with query '%s'", query), + resp, + err, + ), nil, nil + } + defer func() { _ = resp.Body.Close() }() - // Return either minimal or full response based on parameter - var r []byte - if minimalOutput { - minimalRepos := make([]MinimalRepository, 0, len(result.Repositories)) - for _, repo := range result.Repositories { - minimalRepo := MinimalRepository{ - ID: repo.GetID(), - Name: repo.GetName(), - FullName: repo.GetFullName(), - Description: repo.GetDescription(), - HTMLURL: repo.GetHTMLURL(), - Language: repo.GetLanguage(), - Stars: repo.GetStargazersCount(), - Forks: repo.GetForksCount(), - OpenIssues: repo.GetOpenIssuesCount(), - Private: repo.GetPrivate(), - Fork: repo.GetFork(), - Archived: repo.GetArchived(), - DefaultBranch: repo.GetDefaultBranch(), + if resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to read response body", err), nil, nil } + return utils.NewToolResultError(fmt.Sprintf("failed to search repositories: %s", string(body))), nil, nil + } - if repo.UpdatedAt != nil { - minimalRepo.UpdatedAt = repo.UpdatedAt.Format("2006-01-02T15:04:05Z") - } - if repo.CreatedAt != nil { - minimalRepo.CreatedAt = repo.CreatedAt.Format("2006-01-02T15:04:05Z") - } - if repo.Topics != nil { - minimalRepo.Topics = repo.Topics + // Return either minimal or full response based on parameter + var r []byte + if minimalOutput { + minimalRepos := make([]MinimalRepository, 0, len(result.Repositories)) + for _, repo := range result.Repositories { + minimalRepo := MinimalRepository{ + ID: repo.GetID(), + Name: repo.GetName(), + FullName: repo.GetFullName(), + Description: repo.GetDescription(), + HTMLURL: repo.GetHTMLURL(), + Language: repo.GetLanguage(), + Stars: repo.GetStargazersCount(), + Forks: repo.GetForksCount(), + OpenIssues: repo.GetOpenIssuesCount(), + Private: repo.GetPrivate(), + Fork: repo.GetFork(), + Archived: repo.GetArchived(), + DefaultBranch: repo.GetDefaultBranch(), + } + + if repo.UpdatedAt != nil { + minimalRepo.UpdatedAt = repo.UpdatedAt.Format("2006-01-02T15:04:05Z") + } + if repo.CreatedAt != nil { + minimalRepo.CreatedAt = repo.CreatedAt.Format("2006-01-02T15:04:05Z") + } + if repo.Topics != nil { + minimalRepo.Topics = repo.Topics + } + + minimalRepos = append(minimalRepos, minimalRepo) } - minimalRepos = append(minimalRepos, minimalRepo) - } + minimalResult := &MinimalSearchRepositoriesResult{ + TotalCount: result.GetTotal(), + IncompleteResults: result.GetIncompleteResults(), + Items: minimalRepos, + } - minimalResult := &MinimalSearchRepositoriesResult{ - TotalCount: result.GetTotal(), - IncompleteResults: result.GetIncompleteResults(), - Items: minimalRepos, + r, err = json.Marshal(minimalResult) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to marshal minimal response", err), nil, nil + } + } else { + r, err = json.Marshal(result) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to marshal full response", err), nil, nil + } } - r, err = json.Marshal(minimalResult) - if err != nil { - return utils.NewToolResultErrorFromErr("failed to marshal minimal response", err), nil, nil - } - } else { - r, err = json.Marshal(result) - if err != nil { - return utils.NewToolResultErrorFromErr("failed to marshal full response", err), nil, nil - } + return utils.NewToolResultText(string(r)), nil, nil } - - return utils.NewToolResultText(string(r)), nil, nil - } + }, + ) } // SearchCode creates a tool to search for code across GitHub repositories. -func SearchCode(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { +func SearchCode(t translations.TranslationHelperFunc) toolsets.ServerTool { schema := &jsonschema.Schema{ Type: "object", Properties: map[string]*jsonschema.Schema{ @@ -183,7 +188,8 @@ func SearchCode(getClient GetClientFn, t translations.TranslationHelperFunc) (mc } WithPagination(schema) - return mcp.Tool{ + return NewTool( + mcp.Tool{ Name: "search_code", Description: t("TOOL_SEARCH_CODE_DESCRIPTION", "Fast and precise code search across ALL GitHub repositories using GitHub's native search engine. Best for finding exact symbols, functions, classes, or specific code patterns."), Annotations: &mcp.ToolAnnotations{ @@ -192,66 +198,69 @@ func SearchCode(getClient GetClientFn, t translations.TranslationHelperFunc) (mc }, InputSchema: schema, }, - func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - query, err := RequiredParam[string](args, "query") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - sort, err := OptionalParam[string](args, "sort") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - order, err := OptionalParam[string](args, "order") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - pagination, err := OptionalPaginationParams(args) - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { + return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + query, err := RequiredParam[string](args, "query") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + sort, err := OptionalParam[string](args, "sort") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + order, err := OptionalParam[string](args, "order") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + pagination, err := OptionalPaginationParams(args) + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - opts := &github.SearchOptions{ - Sort: sort, - Order: order, - ListOptions: github.ListOptions{ - PerPage: pagination.PerPage, - Page: pagination.Page, - }, - } + opts := &github.SearchOptions{ + Sort: sort, + Order: order, + ListOptions: github.ListOptions{ + PerPage: pagination.PerPage, + Page: pagination.Page, + }, + } - client, err := getClient(ctx) - if err != nil { - return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil - } + client, err := deps.GetClient(ctx) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil + } - result, resp, err := client.Search.Code(ctx, query, opts) - if err != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, - fmt.Sprintf("failed to search code with query '%s'", query), - resp, - err, - ), nil, nil - } - defer func() { _ = resp.Body.Close() }() + result, resp, err := client.Search.Code(ctx, query, opts) + if err != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, + fmt.Sprintf("failed to search code with query '%s'", query), + resp, + err, + ), nil, nil + } + defer func() { _ = resp.Body.Close() }() - if resp.StatusCode != http.StatusOK { - body, err := io.ReadAll(resp.Body) + if resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to read response body", err), nil, nil + } + return utils.NewToolResultError(fmt.Sprintf("failed to search code: %s", string(body))), nil, nil + } + + r, err := json.Marshal(result) if err != nil { - return utils.NewToolResultErrorFromErr("failed to read response body", err), nil, nil + return utils.NewToolResultErrorFromErr("failed to marshal response", err), nil, nil } - return utils.NewToolResultError(fmt.Sprintf("failed to search code: %s", string(body))), nil, nil - } - r, err := json.Marshal(result) - if err != nil { - return utils.NewToolResultErrorFromErr("failed to marshal response", err), nil, nil + return utils.NewToolResultText(string(r)), nil, nil } - - return utils.NewToolResultText(string(r)), nil, nil - } + }, + ) } -func userOrOrgHandler(accountType string, getClient GetClientFn) mcp.ToolHandlerFor[map[string]any, any] { +func userOrOrgHandler(accountType string, deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { query, err := RequiredParam[string](args, "query") if err != nil { @@ -279,7 +288,7 @@ func userOrOrgHandler(accountType string, getClient GetClientFn) mcp.ToolHandler }, } - client, err := getClient(ctx) + client, err := deps.GetClient(ctx) if err != nil { return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil } @@ -340,7 +349,7 @@ func userOrOrgHandler(accountType string, getClient GetClientFn) mcp.ToolHandler } // SearchUsers creates a tool to search for GitHub users. -func SearchUsers(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { +func SearchUsers(t translations.TranslationHelperFunc) toolsets.ServerTool { schema := &jsonschema.Schema{ Type: "object", Properties: map[string]*jsonschema.Schema{ @@ -363,19 +372,24 @@ func SearchUsers(getClient GetClientFn, t translations.TranslationHelperFunc) (m } WithPagination(schema) - return mcp.Tool{ - Name: "search_users", - Description: t("TOOL_SEARCH_USERS_DESCRIPTION", "Find GitHub users by username, real name, or other profile information. Useful for locating developers, contributors, or team members."), - Annotations: &mcp.ToolAnnotations{ - Title: t("TOOL_SEARCH_USERS_USER_TITLE", "Search users"), - ReadOnlyHint: true, + return NewTool( + mcp.Tool{ + Name: "search_users", + Description: t("TOOL_SEARCH_USERS_DESCRIPTION", "Find GitHub users by username, real name, or other profile information. Useful for locating developers, contributors, or team members."), + Annotations: &mcp.ToolAnnotations{ + Title: t("TOOL_SEARCH_USERS_USER_TITLE", "Search users"), + ReadOnlyHint: true, + }, + InputSchema: schema, + }, + func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { + return userOrOrgHandler("user", deps) }, - InputSchema: schema, - }, userOrOrgHandler("user", getClient) + ) } // SearchOrgs creates a tool to search for GitHub organizations. -func SearchOrgs(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { +func SearchOrgs(t translations.TranslationHelperFunc) toolsets.ServerTool { schema := &jsonschema.Schema{ Type: "object", Properties: map[string]*jsonschema.Schema{ @@ -398,13 +412,18 @@ func SearchOrgs(getClient GetClientFn, t translations.TranslationHelperFunc) (mc } WithPagination(schema) - return mcp.Tool{ - Name: "search_orgs", - Description: t("TOOL_SEARCH_ORGS_DESCRIPTION", "Find GitHub organizations by name, location, or other organization metadata. Ideal for discovering companies, open source foundations, or teams."), - Annotations: &mcp.ToolAnnotations{ - Title: t("TOOL_SEARCH_ORGS_USER_TITLE", "Search organizations"), - ReadOnlyHint: true, + return NewTool( + mcp.Tool{ + Name: "search_orgs", + Description: t("TOOL_SEARCH_ORGS_DESCRIPTION", "Find GitHub organizations by name, location, or other organization metadata. Ideal for discovering companies, open source foundations, or teams."), + Annotations: &mcp.ToolAnnotations{ + Title: t("TOOL_SEARCH_ORGS_USER_TITLE", "Search organizations"), + ReadOnlyHint: true, + }, + InputSchema: schema, + }, + func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { + return userOrOrgHandler("org", deps) }, - InputSchema: schema, - }, userOrOrgHandler("org", getClient) + ) } diff --git a/pkg/github/search_test.go b/pkg/github/search_test.go index 0b923edcd..41d12df1b 100644 --- a/pkg/github/search_test.go +++ b/pkg/github/search_test.go @@ -17,8 +17,8 @@ import ( func Test_SearchRepositories(t *testing.T) { // Verify tool definition once - mockClient := github.NewClient(nil) - tool, _ := SearchRepositories(stubGetClientFn(mockClient), translations.NullTranslationHelper) + serverTool := SearchRepositories(translations.NullTranslationHelper) + tool := serverTool.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "search_repositories", tool.Name) @@ -134,13 +134,16 @@ func Test_SearchRepositories(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := SearchRepositories(stubGetClientFn(client), translations.NullTranslationHelper) + deps := ToolDependencies{ + GetClient: stubGetClientFn(client), + } + handler := serverTool.Handler(deps) // Create call request request := createMCPRequest(tc.requestArgs) // Call handler - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(context.Background(), &request) // Verify results if tc.expectError { @@ -205,7 +208,11 @@ func Test_SearchRepositories_FullOutput(t *testing.T) { ) client := github.NewClient(mockedClient) - _, handlerTest := SearchRepositories(stubGetClientFn(client), translations.NullTranslationHelper) + serverTool := SearchRepositories(translations.NullTranslationHelper) + deps := ToolDependencies{ + GetClient: stubGetClientFn(client), + } + handler := serverTool.Handler(deps) args := map[string]interface{}{ "query": "golang test", @@ -214,7 +221,7 @@ func Test_SearchRepositories_FullOutput(t *testing.T) { request := createMCPRequest(args) - result, _, err := handlerTest(context.Background(), &request, args) + result, err := handler(context.Background(), &request) require.NoError(t, err) require.False(t, result.IsError) @@ -236,8 +243,8 @@ func Test_SearchRepositories_FullOutput(t *testing.T) { func Test_SearchCode(t *testing.T) { // Verify tool definition once - mockClient := github.NewClient(nil) - tool, _ := SearchCode(stubGetClientFn(mockClient), translations.NullTranslationHelper) + serverTool := SearchCode(translations.NullTranslationHelper) + tool := serverTool.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "search_code", tool.Name) @@ -351,13 +358,16 @@ func Test_SearchCode(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := SearchCode(stubGetClientFn(client), translations.NullTranslationHelper) + deps := ToolDependencies{ + GetClient: stubGetClientFn(client), + } + handler := serverTool.Handler(deps) // Create call request request := createMCPRequest(tc.requestArgs) // Call handler - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(context.Background(), &request) // Verify results if tc.expectError { @@ -394,8 +404,8 @@ func Test_SearchCode(t *testing.T) { func Test_SearchUsers(t *testing.T) { // Verify tool definition once - mockClient := github.NewClient(nil) - tool, _ := SearchUsers(stubGetClientFn(mockClient), translations.NullTranslationHelper) + serverTool := SearchUsers(translations.NullTranslationHelper) + tool := serverTool.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "search_users", tool.Name) @@ -548,13 +558,16 @@ func Test_SearchUsers(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := SearchUsers(stubGetClientFn(client), translations.NullTranslationHelper) + deps := ToolDependencies{ + GetClient: stubGetClientFn(client), + } + handler := serverTool.Handler(deps) // Create call request request := createMCPRequest(tc.requestArgs) // Call handler - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(context.Background(), &request) // Verify results if tc.expectError { @@ -592,8 +605,8 @@ func Test_SearchUsers(t *testing.T) { func Test_SearchOrgs(t *testing.T) { // Verify tool definition once - mockClient := github.NewClient(nil) - tool, _ := SearchOrgs(stubGetClientFn(mockClient), translations.NullTranslationHelper) + serverTool := SearchOrgs(translations.NullTranslationHelper) + tool := serverTool.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) @@ -720,13 +733,16 @@ func Test_SearchOrgs(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := SearchOrgs(stubGetClientFn(client), translations.NullTranslationHelper) + deps := ToolDependencies{ + GetClient: stubGetClientFn(client), + } + handler := serverTool.Handler(deps) // Create call request request := createMCPRequest(tc.requestArgs) // Call handler - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(context.Background(), &request) // Verify results if tc.expectError { diff --git a/pkg/github/tools.go b/pkg/github/tools.go index efabfc92f..dff1ca02e 100644 --- a/pkg/github/tools.go +++ b/pkg/github/tools.go @@ -179,10 +179,10 @@ func DefaultToolsetGroup(readOnly bool, getClient GetClientFn, getGQLClient GetG repos := toolsets.NewToolset(ToolsetMetadataRepos.ID, ToolsetMetadataRepos.Description). SetDependencies(deps). AddReadTools( - toolsets.NewServerToolLegacy(SearchRepositories(getClient, t)), + SearchRepositories(t), toolsets.NewServerToolLegacy(GetFileContents(getClient, getRawClient, t)), toolsets.NewServerToolLegacy(ListCommits(getClient, t)), - toolsets.NewServerToolLegacy(SearchCode(getClient, t)), + SearchCode(t), toolsets.NewServerToolLegacy(GetCommit(getClient, t)), toolsets.NewServerToolLegacy(ListBranches(getClient, t)), toolsets.NewServerToolLegacy(ListTags(getClient, t)), @@ -232,12 +232,12 @@ func DefaultToolsetGroup(readOnly bool, getClient GetClientFn, getGQLClient GetG users := toolsets.NewToolset(ToolsetMetadataUsers.ID, ToolsetMetadataUsers.Description). SetDependencies(deps). AddReadTools( - toolsets.NewServerToolLegacy(SearchUsers(getClient, t)), + SearchUsers(t), ) orgs := toolsets.NewToolset(ToolsetMetadataOrgs.ID, ToolsetMetadataOrgs.Description). SetDependencies(deps). AddReadTools( - toolsets.NewServerToolLegacy(SearchOrgs(getClient, t)), + SearchOrgs(t), ) pullRequests := toolsets.NewToolset(ToolsetMetadataPullRequests.ID, ToolsetMetadataPullRequests.Description). SetDependencies(deps).