Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
228 changes: 114 additions & 114 deletions pkg/github/git.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"strings"

ghErrors "github.com/github/github-mcp-server/pkg/errors"
"github.com/github/github-mcp-server/pkg/toolsets"
"github.com/github/github-mcp-server/pkg/translations"
"github.com/github/github-mcp-server/pkg/utils"
"github.com/google/go-github/v79/github"
Expand Down Expand Up @@ -37,140 +38,139 @@ type TreeResponse struct {
}

// GetRepositoryTree creates a tool to get the tree structure of a GitHub repository.
func GetRepositoryTree(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) {
tool := mcp.Tool{
Name: "get_repository_tree",
Description: t("TOOL_GET_REPOSITORY_TREE_DESCRIPTION", "Get the tree structure (files and directories) of a GitHub repository at a specific ref or SHA"),
Annotations: &mcp.ToolAnnotations{
Title: t("TOOL_GET_REPOSITORY_TREE_USER_TITLE", "Get repository tree"),
ReadOnlyHint: true,
},
InputSchema: &jsonschema.Schema{
Type: "object",
Properties: map[string]*jsonschema.Schema{
"owner": {
Type: "string",
Description: "Repository owner (username or organization)",
},
"repo": {
Type: "string",
Description: "Repository name",
},
"tree_sha": {
Type: "string",
Description: "The SHA1 value or ref (branch or tag) name of the tree. Defaults to the repository's default branch",
},
"recursive": {
Type: "boolean",
Description: "Setting this parameter to true returns the objects or subtrees referenced by the tree. Default is false",
Default: json.RawMessage(`false`),
},
"path_filter": {
Type: "string",
Description: "Optional path prefix to filter the tree results (e.g., 'src/' to only show files in the src directory)",
func GetRepositoryTree(t translations.TranslationHelperFunc) toolsets.ServerTool {
return NewTool(
mcp.Tool{
Name: "get_repository_tree",
Description: t("TOOL_GET_REPOSITORY_TREE_DESCRIPTION", "Get the tree structure (files and directories) of a GitHub repository at a specific ref or SHA"),
Annotations: &mcp.ToolAnnotations{
Title: t("TOOL_GET_REPOSITORY_TREE_USER_TITLE", "Get repository tree"),
ReadOnlyHint: true,
},
InputSchema: &jsonschema.Schema{
Type: "object",
Properties: map[string]*jsonschema.Schema{
"owner": {
Type: "string",
Description: "Repository owner (username or organization)",
},
"repo": {
Type: "string",
Description: "Repository name",
},
"tree_sha": {
Type: "string",
Description: "The SHA1 value or ref (branch or tag) name of the tree. Defaults to the repository's default branch",
},
"recursive": {
Type: "boolean",
Description: "Setting this parameter to true returns the objects or subtrees referenced by the tree. Default is false",
Default: json.RawMessage(`false`),
},
"path_filter": {
Type: "string",
Description: "Optional path prefix to filter the tree results (e.g., 'src/' to only show files in the src directory)",
},
},
Required: []string{"owner", "repo"},
},
Required: []string{"owner", "repo"},
},
}
func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] {
return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) {
owner, err := RequiredParam[string](args, "owner")
if err != nil {
return utils.NewToolResultError(err.Error()), nil, nil
}
repo, err := RequiredParam[string](args, "repo")
if err != nil {
return utils.NewToolResultError(err.Error()), nil, nil
}
treeSHA, err := OptionalParam[string](args, "tree_sha")
if err != nil {
return utils.NewToolResultError(err.Error()), nil, nil
}
recursive, err := OptionalBoolParamWithDefault(args, "recursive", false)
if err != nil {
return utils.NewToolResultError(err.Error()), nil, nil
}
pathFilter, err := OptionalParam[string](args, "path_filter")
if err != nil {
return utils.NewToolResultError(err.Error()), nil, nil
}

handler := mcp.ToolHandlerFor[map[string]any, any](
func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) {
owner, err := RequiredParam[string](args, "owner")
if err != nil {
return utils.NewToolResultError(err.Error()), nil, nil
}
repo, err := RequiredParam[string](args, "repo")
if err != nil {
return utils.NewToolResultError(err.Error()), nil, nil
}
treeSHA, err := OptionalParam[string](args, "tree_sha")
if err != nil {
return utils.NewToolResultError(err.Error()), nil, nil
}
recursive, err := OptionalBoolParamWithDefault(args, "recursive", false)
if err != nil {
return utils.NewToolResultError(err.Error()), nil, nil
}
pathFilter, err := OptionalParam[string](args, "path_filter")
if err != nil {
return utils.NewToolResultError(err.Error()), nil, nil
}
client, err := deps.GetClient(ctx)
if err != nil {
return utils.NewToolResultError("failed to get GitHub client"), nil, nil
}

client, err := getClient(ctx)
if err != nil {
return utils.NewToolResultError("failed to get GitHub client"), nil, nil
}
// If no tree_sha is provided, use the repository's default branch
if treeSHA == "" {
repoInfo, repoResp, err := client.Repositories.Get(ctx, owner, repo)
if err != nil {
return ghErrors.NewGitHubAPIErrorResponse(ctx,
"failed to get repository info",
repoResp,
err,
), nil, nil
}
treeSHA = *repoInfo.DefaultBranch
}

// If no tree_sha is provided, use the repository's default branch
if treeSHA == "" {
repoInfo, repoResp, err := client.Repositories.Get(ctx, owner, repo)
// Get the tree using the GitHub Git Tree API
tree, resp, err := client.Git.GetTree(ctx, owner, repo, treeSHA, recursive)
if err != nil {
return ghErrors.NewGitHubAPIErrorResponse(ctx,
"failed to get repository info",
repoResp,
"failed to get repository tree",
resp,
err,
), nil, nil
}
treeSHA = *repoInfo.DefaultBranch
}
defer func() { _ = resp.Body.Close() }()

// Get the tree using the GitHub Git Tree API
tree, resp, err := client.Git.GetTree(ctx, owner, repo, treeSHA, recursive)
if err != nil {
return ghErrors.NewGitHubAPIErrorResponse(ctx,
"failed to get repository tree",
resp,
err,
), nil, nil
}
defer func() { _ = resp.Body.Close() }()

// Filter tree entries if path_filter is provided
var filteredEntries []*github.TreeEntry
if pathFilter != "" {
for _, entry := range tree.Entries {
if strings.HasPrefix(entry.GetPath(), pathFilter) {
filteredEntries = append(filteredEntries, entry)
// Filter tree entries if path_filter is provided
var filteredEntries []*github.TreeEntry
if pathFilter != "" {
for _, entry := range tree.Entries {
if strings.HasPrefix(entry.GetPath(), pathFilter) {
filteredEntries = append(filteredEntries, entry)
}
}
} else {
filteredEntries = tree.Entries
}
} else {
filteredEntries = tree.Entries
}

treeEntries := make([]TreeEntryResponse, len(filteredEntries))
for i, entry := range filteredEntries {
treeEntries[i] = TreeEntryResponse{
Path: entry.GetPath(),
Type: entry.GetType(),
Mode: entry.GetMode(),
SHA: entry.GetSHA(),
URL: entry.GetURL(),
treeEntries := make([]TreeEntryResponse, len(filteredEntries))
for i, entry := range filteredEntries {
treeEntries[i] = TreeEntryResponse{
Path: entry.GetPath(),
Type: entry.GetType(),
Mode: entry.GetMode(),
SHA: entry.GetSHA(),
URL: entry.GetURL(),
}
if entry.Size != nil {
treeEntries[i].Size = entry.Size
}
}
if entry.Size != nil {
treeEntries[i].Size = entry.Size

response := TreeResponse{
SHA: *tree.SHA,
Truncated: *tree.Truncated,
Tree: treeEntries,
TreeSHA: treeSHA,
Owner: owner,
Repo: repo,
Recursive: recursive,
Count: len(filteredEntries),
}
}

response := TreeResponse{
SHA: *tree.SHA,
Truncated: *tree.Truncated,
Tree: treeEntries,
TreeSHA: treeSHA,
Owner: owner,
Repo: repo,
Recursive: recursive,
Count: len(filteredEntries),
}
r, err := json.Marshal(response)
if err != nil {
return nil, nil, fmt.Errorf("failed to marshal response: %w", err)
}

r, err := json.Marshal(response)
if err != nil {
return nil, nil, fmt.Errorf("failed to marshal response: %w", err)
return utils.NewToolResultText(string(r)), nil, nil
}

return utils.NewToolResultText(string(r)), nil, nil
},
)

return tool, handler
}
19 changes: 11 additions & 8 deletions pkg/github/git_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,14 @@ import (

func Test_GetRepositoryTree(t *testing.T) {
// Verify tool definition once
mockClient := github.NewClient(nil)
tool, _ := GetRepositoryTree(stubGetClientFn(mockClient), translations.NullTranslationHelper)
require.NoError(t, toolsnaps.Test(tool.Name, tool))
toolDef := GetRepositoryTree(translations.NullTranslationHelper)
require.NoError(t, toolsnaps.Test(toolDef.Tool.Name, toolDef.Tool))

assert.Equal(t, "get_repository_tree", tool.Name)
assert.NotEmpty(t, tool.Description)
assert.Equal(t, "get_repository_tree", toolDef.Tool.Name)
assert.NotEmpty(t, toolDef.Tool.Description)

// Type assert the InputSchema to access its properties
inputSchema, ok := tool.InputSchema.(*jsonschema.Schema)
inputSchema, ok := toolDef.Tool.InputSchema.(*jsonschema.Schema)
require.True(t, ok, "expected InputSchema to be *jsonschema.Schema")
assert.Contains(t, inputSchema.Properties, "owner")
assert.Contains(t, inputSchema.Properties, "repo")
Expand Down Expand Up @@ -148,12 +147,16 @@ func Test_GetRepositoryTree(t *testing.T) {

for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
_, handler := GetRepositoryTree(stubGetClientFromHTTPFn(tc.mockedClient), translations.NullTranslationHelper)
client := github.NewClient(tc.mockedClient)
deps := ToolDependencies{
GetClient: stubGetClientFn(client),
}
handler := toolDef.Handler(deps)

// Create the tool request
request := createMCPRequest(tc.requestArgs)

result, _, err := handler(context.Background(), &request, tc.requestArgs)
result, err := handler(context.Background(), &request)

if tc.expectError {
require.NoError(t, err)
Expand Down
Loading