From d29e73b0ba0a79513071a79a2984ddab0b495b6b Mon Sep 17 00:00:00 2001 From: Sam Morrow Date: Sat, 13 Dec 2025 16:31:14 +0100 Subject: [PATCH] refactor(git): migrate GetRepositoryTree to NewTool pattern --- pkg/github/git.go | 228 ++++++++++++++++---------------- pkg/github/git_test.go | 19 +-- pkg/github/repositories_test.go | 178 ------------------------- pkg/github/tools.go | 2 +- 4 files changed, 126 insertions(+), 301 deletions(-) diff --git a/pkg/github/git.go b/pkg/github/git.go index c2a839132..c5fdded7c 100644 --- a/pkg/github/git.go +++ b/pkg/github/git.go @@ -7,6 +7,7 @@ import ( "strings" ghErrors "github.com/github/github-mcp-server/pkg/errors" + "github.com/github/github-mcp-server/pkg/toolsets" "github.com/github/github-mcp-server/pkg/translations" "github.com/github/github-mcp-server/pkg/utils" "github.com/google/go-github/v79/github" @@ -37,140 +38,139 @@ type TreeResponse struct { } // GetRepositoryTree creates a tool to get the tree structure of a GitHub repository. -func GetRepositoryTree(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { - tool := mcp.Tool{ - Name: "get_repository_tree", - Description: t("TOOL_GET_REPOSITORY_TREE_DESCRIPTION", "Get the tree structure (files and directories) of a GitHub repository at a specific ref or SHA"), - Annotations: &mcp.ToolAnnotations{ - Title: t("TOOL_GET_REPOSITORY_TREE_USER_TITLE", "Get repository tree"), - ReadOnlyHint: true, - }, - InputSchema: &jsonschema.Schema{ - Type: "object", - Properties: map[string]*jsonschema.Schema{ - "owner": { - Type: "string", - Description: "Repository owner (username or organization)", - }, - "repo": { - Type: "string", - Description: "Repository name", - }, - "tree_sha": { - Type: "string", - Description: "The SHA1 value or ref (branch or tag) name of the tree. Defaults to the repository's default branch", - }, - "recursive": { - Type: "boolean", - Description: "Setting this parameter to true returns the objects or subtrees referenced by the tree. Default is false", - Default: json.RawMessage(`false`), - }, - "path_filter": { - Type: "string", - Description: "Optional path prefix to filter the tree results (e.g., 'src/' to only show files in the src directory)", +func GetRepositoryTree(t translations.TranslationHelperFunc) toolsets.ServerTool { + return NewTool( + mcp.Tool{ + Name: "get_repository_tree", + Description: t("TOOL_GET_REPOSITORY_TREE_DESCRIPTION", "Get the tree structure (files and directories) of a GitHub repository at a specific ref or SHA"), + Annotations: &mcp.ToolAnnotations{ + Title: t("TOOL_GET_REPOSITORY_TREE_USER_TITLE", "Get repository tree"), + ReadOnlyHint: true, + }, + InputSchema: &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "owner": { + Type: "string", + Description: "Repository owner (username or organization)", + }, + "repo": { + Type: "string", + Description: "Repository name", + }, + "tree_sha": { + Type: "string", + Description: "The SHA1 value or ref (branch or tag) name of the tree. Defaults to the repository's default branch", + }, + "recursive": { + Type: "boolean", + Description: "Setting this parameter to true returns the objects or subtrees referenced by the tree. Default is false", + Default: json.RawMessage(`false`), + }, + "path_filter": { + Type: "string", + Description: "Optional path prefix to filter the tree results (e.g., 'src/' to only show files in the src directory)", + }, }, + Required: []string{"owner", "repo"}, }, - Required: []string{"owner", "repo"}, }, - } + func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { + return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + owner, err := RequiredParam[string](args, "owner") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + repo, err := RequiredParam[string](args, "repo") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + treeSHA, err := OptionalParam[string](args, "tree_sha") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + recursive, err := OptionalBoolParamWithDefault(args, "recursive", false) + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + pathFilter, err := OptionalParam[string](args, "path_filter") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - handler := mcp.ToolHandlerFor[map[string]any, any]( - func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - owner, err := RequiredParam[string](args, "owner") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - repo, err := RequiredParam[string](args, "repo") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - treeSHA, err := OptionalParam[string](args, "tree_sha") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - recursive, err := OptionalBoolParamWithDefault(args, "recursive", false) - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - pathFilter, err := OptionalParam[string](args, "path_filter") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + client, err := deps.GetClient(ctx) + if err != nil { + return utils.NewToolResultError("failed to get GitHub client"), nil, nil + } - client, err := getClient(ctx) - if err != nil { - return utils.NewToolResultError("failed to get GitHub client"), nil, nil - } + // If no tree_sha is provided, use the repository's default branch + if treeSHA == "" { + repoInfo, repoResp, err := client.Repositories.Get(ctx, owner, repo) + if err != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, + "failed to get repository info", + repoResp, + err, + ), nil, nil + } + treeSHA = *repoInfo.DefaultBranch + } - // If no tree_sha is provided, use the repository's default branch - if treeSHA == "" { - repoInfo, repoResp, err := client.Repositories.Get(ctx, owner, repo) + // Get the tree using the GitHub Git Tree API + tree, resp, err := client.Git.GetTree(ctx, owner, repo, treeSHA, recursive) if err != nil { return ghErrors.NewGitHubAPIErrorResponse(ctx, - "failed to get repository info", - repoResp, + "failed to get repository tree", + resp, err, ), nil, nil } - treeSHA = *repoInfo.DefaultBranch - } + defer func() { _ = resp.Body.Close() }() - // Get the tree using the GitHub Git Tree API - tree, resp, err := client.Git.GetTree(ctx, owner, repo, treeSHA, recursive) - if err != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, - "failed to get repository tree", - resp, - err, - ), nil, nil - } - defer func() { _ = resp.Body.Close() }() - - // Filter tree entries if path_filter is provided - var filteredEntries []*github.TreeEntry - if pathFilter != "" { - for _, entry := range tree.Entries { - if strings.HasPrefix(entry.GetPath(), pathFilter) { - filteredEntries = append(filteredEntries, entry) + // Filter tree entries if path_filter is provided + var filteredEntries []*github.TreeEntry + if pathFilter != "" { + for _, entry := range tree.Entries { + if strings.HasPrefix(entry.GetPath(), pathFilter) { + filteredEntries = append(filteredEntries, entry) + } } + } else { + filteredEntries = tree.Entries } - } else { - filteredEntries = tree.Entries - } - treeEntries := make([]TreeEntryResponse, len(filteredEntries)) - for i, entry := range filteredEntries { - treeEntries[i] = TreeEntryResponse{ - Path: entry.GetPath(), - Type: entry.GetType(), - Mode: entry.GetMode(), - SHA: entry.GetSHA(), - URL: entry.GetURL(), + treeEntries := make([]TreeEntryResponse, len(filteredEntries)) + for i, entry := range filteredEntries { + treeEntries[i] = TreeEntryResponse{ + Path: entry.GetPath(), + Type: entry.GetType(), + Mode: entry.GetMode(), + SHA: entry.GetSHA(), + URL: entry.GetURL(), + } + if entry.Size != nil { + treeEntries[i].Size = entry.Size + } } - if entry.Size != nil { - treeEntries[i].Size = entry.Size + + response := TreeResponse{ + SHA: *tree.SHA, + Truncated: *tree.Truncated, + Tree: treeEntries, + TreeSHA: treeSHA, + Owner: owner, + Repo: repo, + Recursive: recursive, + Count: len(filteredEntries), } - } - response := TreeResponse{ - SHA: *tree.SHA, - Truncated: *tree.Truncated, - Tree: treeEntries, - TreeSHA: treeSHA, - Owner: owner, - Repo: repo, - Recursive: recursive, - Count: len(filteredEntries), - } + r, err := json.Marshal(response) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal response: %w", err) + } - r, err := json.Marshal(response) - if err != nil { - return nil, nil, fmt.Errorf("failed to marshal response: %w", err) + return utils.NewToolResultText(string(r)), nil, nil } - - return utils.NewToolResultText(string(r)), nil, nil }, ) - - return tool, handler } diff --git a/pkg/github/git_test.go b/pkg/github/git_test.go index 66cbccd6e..69442e312 100644 --- a/pkg/github/git_test.go +++ b/pkg/github/git_test.go @@ -18,15 +18,14 @@ import ( func Test_GetRepositoryTree(t *testing.T) { // Verify tool definition once - mockClient := github.NewClient(nil) - tool, _ := GetRepositoryTree(stubGetClientFn(mockClient), translations.NullTranslationHelper) - require.NoError(t, toolsnaps.Test(tool.Name, tool)) + toolDef := GetRepositoryTree(translations.NullTranslationHelper) + require.NoError(t, toolsnaps.Test(toolDef.Tool.Name, toolDef.Tool)) - assert.Equal(t, "get_repository_tree", tool.Name) - assert.NotEmpty(t, tool.Description) + assert.Equal(t, "get_repository_tree", toolDef.Tool.Name) + assert.NotEmpty(t, toolDef.Tool.Description) // Type assert the InputSchema to access its properties - inputSchema, ok := tool.InputSchema.(*jsonschema.Schema) + inputSchema, ok := toolDef.Tool.InputSchema.(*jsonschema.Schema) require.True(t, ok, "expected InputSchema to be *jsonschema.Schema") assert.Contains(t, inputSchema.Properties, "owner") assert.Contains(t, inputSchema.Properties, "repo") @@ -148,12 +147,16 @@ func Test_GetRepositoryTree(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { - _, handler := GetRepositoryTree(stubGetClientFromHTTPFn(tc.mockedClient), translations.NullTranslationHelper) + client := github.NewClient(tc.mockedClient) + deps := ToolDependencies{ + GetClient: stubGetClientFn(client), + } + handler := toolDef.Handler(deps) // Create the tool request request := createMCPRequest(tc.requestArgs) - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(context.Background(), &request) if tc.expectError { require.NoError(t, err) diff --git a/pkg/github/repositories_test.go b/pkg/github/repositories_test.go index b4ccd3603..949686d92 100644 --- a/pkg/github/repositories_test.go +++ b/pkg/github/repositories_test.go @@ -3348,181 +3348,3 @@ func Test_UnstarRepository(t *testing.T) { }) } } - -func Test_RepositoriesGetRepositoryTree(t *testing.T) { - // Verify tool definition once - mockClient := github.NewClient(nil) - tool, _ := GetRepositoryTree(stubGetClientFn(mockClient), translations.NullTranslationHelper) - require.NoError(t, toolsnaps.Test(tool.Name, tool)) - - schema, ok := tool.InputSchema.(*jsonschema.Schema) - require.True(t, ok, "InputSchema should be *jsonschema.Schema") - - assert.Equal(t, "get_repository_tree", tool.Name) - assert.NotEmpty(t, tool.Description) - assert.Contains(t, schema.Properties, "owner") - assert.Contains(t, schema.Properties, "repo") - assert.Contains(t, schema.Properties, "tree_sha") - assert.Contains(t, schema.Properties, "recursive") - assert.Contains(t, schema.Properties, "path_filter") - assert.ElementsMatch(t, schema.Required, []string{"owner", "repo"}) - - // Setup mock data - mockRepo := &github.Repository{ - DefaultBranch: github.Ptr("main"), - } - mockTree := &github.Tree{ - SHA: github.Ptr("abc123"), - Truncated: github.Ptr(false), - Entries: []*github.TreeEntry{ - { - Path: github.Ptr("README.md"), - Mode: github.Ptr("100644"), - Type: github.Ptr("blob"), - SHA: github.Ptr("file1sha"), - Size: github.Ptr(123), - URL: github.Ptr("https://api.github.com/repos/owner/repo/git/blobs/file1sha"), - }, - { - Path: github.Ptr("src/main.go"), - Mode: github.Ptr("100644"), - Type: github.Ptr("blob"), - SHA: github.Ptr("file2sha"), - Size: github.Ptr(456), - URL: github.Ptr("https://api.github.com/repos/owner/repo/git/blobs/file2sha"), - }, - }, - } - - tests := []struct { - name string - mockedClient *http.Client - requestArgs map[string]interface{} - expectError bool - expectedErrMsg string - }{ - { - name: "successfully get repository tree", - mockedClient: mock.NewMockedHTTPClient( - mock.WithRequestMatchHandler( - mock.GetReposByOwnerByRepo, - mockResponse(t, http.StatusOK, mockRepo), - ), - mock.WithRequestMatchHandler( - mock.GetReposGitTreesByOwnerByRepoByTreeSha, - mockResponse(t, http.StatusOK, mockTree), - ), - ), - requestArgs: map[string]interface{}{ - "owner": "owner", - "repo": "repo", - }, - }, - { - name: "successfully get repository tree with path filter", - mockedClient: mock.NewMockedHTTPClient( - mock.WithRequestMatchHandler( - mock.GetReposByOwnerByRepo, - mockResponse(t, http.StatusOK, mockRepo), - ), - mock.WithRequestMatchHandler( - mock.GetReposGitTreesByOwnerByRepoByTreeSha, - mockResponse(t, http.StatusOK, mockTree), - ), - ), - requestArgs: map[string]interface{}{ - "owner": "owner", - "repo": "repo", - "path_filter": "src/", - }, - }, - { - name: "repository not found", - mockedClient: mock.NewMockedHTTPClient( - mock.WithRequestMatchHandler( - mock.GetReposByOwnerByRepo, - http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - w.WriteHeader(http.StatusNotFound) - _, _ = w.Write([]byte(`{"message": "Not Found"}`)) - }), - ), - ), - requestArgs: map[string]interface{}{ - "owner": "owner", - "repo": "nonexistent", - }, - expectError: true, - expectedErrMsg: "failed to get repository info", - }, - { - name: "tree not found", - mockedClient: mock.NewMockedHTTPClient( - mock.WithRequestMatchHandler( - mock.GetReposByOwnerByRepo, - mockResponse(t, http.StatusOK, mockRepo), - ), - mock.WithRequestMatchHandler( - mock.GetReposGitTreesByOwnerByRepoByTreeSha, - http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - w.WriteHeader(http.StatusNotFound) - _, _ = w.Write([]byte(`{"message": "Not Found"}`)) - }), - ), - ), - requestArgs: map[string]interface{}{ - "owner": "owner", - "repo": "repo", - }, - expectError: true, - expectedErrMsg: "failed to get repository tree", - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - _, handler := GetRepositoryTree(stubGetClientFromHTTPFn(tc.mockedClient), translations.NullTranslationHelper) - - // Create the tool request - request := createMCPRequest(tc.requestArgs) - - result, _, err := handler(context.Background(), &request, tc.requestArgs) - - if tc.expectError { - require.NoError(t, err) - require.True(t, result.IsError) - errorContent := getErrorResult(t, result) - assert.Contains(t, errorContent.Text, tc.expectedErrMsg) - } else { - require.NoError(t, err) - require.False(t, result.IsError) - - // Parse the result and get the text content - textContent := getTextResult(t, result) - - // Parse the JSON response - var treeResponse map[string]interface{} - err := json.Unmarshal([]byte(textContent.Text), &treeResponse) - require.NoError(t, err) - - // Verify response structure - assert.Equal(t, "owner", treeResponse["owner"]) - assert.Equal(t, "repo", treeResponse["repo"]) - assert.Contains(t, treeResponse, "tree") - assert.Contains(t, treeResponse, "count") - assert.Contains(t, treeResponse, "sha") - assert.Contains(t, treeResponse, "truncated") - - // Check filtering if path_filter was provided - if pathFilter, exists := tc.requestArgs["path_filter"]; exists { - tree := treeResponse["tree"].([]interface{}) - for _, entry := range tree { - entryMap := entry.(map[string]interface{}) - path := entryMap["path"].(string) - assert.True(t, strings.HasPrefix(path, pathFilter.(string)), - "Path %s should start with filter %s", path, pathFilter) - } - } - } - }) - } -} diff --git a/pkg/github/tools.go b/pkg/github/tools.go index 8e811c9bf..1be7e6151 100644 --- a/pkg/github/tools.go +++ b/pkg/github/tools.go @@ -209,7 +209,7 @@ func DefaultToolsetGroup(readOnly bool, getClient GetClientFn, getGQLClient GetG git := toolsets.NewToolset(ToolsetMetadataGit.ID, ToolsetMetadataGit.Description). SetDependencies(deps). AddReadTools( - toolsets.NewServerToolLegacy(GetRepositoryTree(getClient, t)), + GetRepositoryTree(t), ) issues := toolsets.NewToolset(ToolsetMetadataIssues.ID, ToolsetMetadataIssues.Description). SetDependencies(deps).