diff --git a/pkg/github/context_tools.go b/pkg/github/context_tools.go index 3fe622379..e8043731a 100644 --- a/pkg/github/context_tools.go +++ b/pkg/github/context_tools.go @@ -6,6 +6,7 @@ import ( "time" ghErrors "github.com/github/github-mcp-server/pkg/errors" + "github.com/github/github-mcp-server/pkg/toolsets" "github.com/github/github-mcp-server/pkg/translations" "github.com/github/github-mcp-server/pkg/utils" "github.com/google/jsonschema-go/jsonschema" @@ -36,8 +37,9 @@ type UserDetails struct { } // GetMe creates a tool to get details of the authenticated user. -func GetMe(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { - return mcp.Tool{ +func GetMe(t translations.TranslationHelperFunc) toolsets.ServerTool { + return NewTool( + mcp.Tool{ Name: "get_me", Description: t("TOOL_GET_ME_DESCRIPTION", "Get details of the authenticated GitHub user. Use this when a request is about the user's own profile for GitHub. Or when information is missing to build other tool calls."), Annotations: &mcp.ToolAnnotations{ @@ -48,50 +50,53 @@ func GetMe(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Too // OpenAI strict mode requires the properties field to be present. InputSchema: json.RawMessage(`{"type":"object","properties":{}}`), }, - mcp.ToolHandlerFor[map[string]any, any](func(ctx context.Context, _ *mcp.CallToolRequest, _ map[string]any) (*mcp.CallToolResult, any, error) { - client, err := getClient(ctx) - if err != nil { - return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil - } + func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { + return func(ctx context.Context, _ *mcp.CallToolRequest, _ map[string]any) (*mcp.CallToolResult, any, error) { + client, err := deps.GetClient(ctx) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil + } - user, res, err := client.Users.Get(ctx, "") - if err != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, - "failed to get user", - res, - err, - ), nil, err - } + user, res, err := client.Users.Get(ctx, "") + if err != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, + "failed to get user", + res, + err, + ), nil, nil + } - // Create minimal user representation instead of returning full user object - minimalUser := MinimalUser{ - Login: user.GetLogin(), - ID: user.GetID(), - ProfileURL: user.GetHTMLURL(), - AvatarURL: user.GetAvatarURL(), - Details: &UserDetails{ - Name: user.GetName(), - Company: user.GetCompany(), - Blog: user.GetBlog(), - Location: user.GetLocation(), - Email: user.GetEmail(), - Hireable: user.GetHireable(), - Bio: user.GetBio(), - TwitterUsername: user.GetTwitterUsername(), - PublicRepos: user.GetPublicRepos(), - PublicGists: user.GetPublicGists(), - Followers: user.GetFollowers(), - Following: user.GetFollowing(), - CreatedAt: user.GetCreatedAt().Time, - UpdatedAt: user.GetUpdatedAt().Time, - PrivateGists: user.GetPrivateGists(), - TotalPrivateRepos: user.GetTotalPrivateRepos(), - OwnedPrivateRepos: user.GetOwnedPrivateRepos(), - }, - } + // Create minimal user representation instead of returning full user object + minimalUser := MinimalUser{ + Login: user.GetLogin(), + ID: user.GetID(), + ProfileURL: user.GetHTMLURL(), + AvatarURL: user.GetAvatarURL(), + Details: &UserDetails{ + Name: user.GetName(), + Company: user.GetCompany(), + Blog: user.GetBlog(), + Location: user.GetLocation(), + Email: user.GetEmail(), + Hireable: user.GetHireable(), + Bio: user.GetBio(), + TwitterUsername: user.GetTwitterUsername(), + PublicRepos: user.GetPublicRepos(), + PublicGists: user.GetPublicGists(), + Followers: user.GetFollowers(), + Following: user.GetFollowing(), + CreatedAt: user.GetCreatedAt().Time, + UpdatedAt: user.GetUpdatedAt().Time, + PrivateGists: user.GetPrivateGists(), + TotalPrivateRepos: user.GetTotalPrivateRepos(), + OwnedPrivateRepos: user.GetOwnedPrivateRepos(), + }, + } - return MarshalledTextResult(minimalUser), nil, nil - }) + return MarshalledTextResult(minimalUser), nil, nil + } + }, + ) } type TeamInfo struct { @@ -105,8 +110,9 @@ type OrganizationTeams struct { Teams []TeamInfo `json:"teams"` } -func GetTeams(getClient GetClientFn, getGQLClient GetGQLClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { - return mcp.Tool{ +func GetTeams(t translations.TranslationHelperFunc) toolsets.ServerTool { + return NewTool( + mcp.Tool{ Name: "get_teams", Description: t("TOOL_GET_TEAMS_DESCRIPTION", "Get details of the teams the user is a member of. Limited to organizations accessible with current credentials"), Annotations: &mcp.ToolAnnotations{ @@ -123,84 +129,88 @@ func GetTeams(getClient GetClientFn, getGQLClient GetGQLClientFn, t translations }, }, }, - func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - user, err := OptionalParam[string](args, "user") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - - var username string - if user != "" { - username = user - } else { - client, err := getClient(ctx) + func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { + return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + user, err := OptionalParam[string](args, "user") if err != nil { - return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil + return utils.NewToolResultError(err.Error()), nil, nil } - userResp, res, err := client.Users.Get(ctx, "") + var username string + if user != "" { + username = user + } else { + client, err := deps.GetClient(ctx) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil + } + + userResp, res, err := client.Users.Get(ctx, "") + if err != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, + "failed to get user", + res, + err, + ), nil, nil + } + username = userResp.GetLogin() + } + + gqlClient, err := deps.GetGQLClient(ctx) if err != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, - "failed to get user", - res, - err, - ), nil, nil + return utils.NewToolResultErrorFromErr("failed to get GitHub GQL client", err), nil, nil } - username = userResp.GetLogin() - } - gqlClient, err := getGQLClient(ctx) - if err != nil { - return utils.NewToolResultErrorFromErr("failed to get GitHub GQL client", err), nil, nil - } + var q struct { + User struct { + Organizations struct { + Nodes []struct { + Login githubv4.String + Teams struct { + Nodes []struct { + Name githubv4.String + Slug githubv4.String + Description githubv4.String + } + } `graphql:"teams(first: 100, userLogins: [$login])"` + } + } `graphql:"organizations(first: 100)"` + } `graphql:"user(login: $login)"` + } + vars := map[string]interface{}{ + "login": githubv4.String(username), + } + if err := gqlClient.Query(ctx, &q, vars); err != nil { + return ghErrors.NewGitHubGraphQLErrorResponse(ctx, "Failed to find teams", err), nil, nil + } - var q struct { - User struct { - Organizations struct { - Nodes []struct { - Login githubv4.String - Teams struct { - Nodes []struct { - Name githubv4.String - Slug githubv4.String - Description githubv4.String - } - } `graphql:"teams(first: 100, userLogins: [$login])"` - } - } `graphql:"organizations(first: 100)"` - } `graphql:"user(login: $login)"` - } - vars := map[string]interface{}{ - "login": githubv4.String(username), - } - if err := gqlClient.Query(ctx, &q, vars); err != nil { - return ghErrors.NewGitHubGraphQLErrorResponse(ctx, "Failed to find teams", err), nil, nil - } + var organizations []OrganizationTeams + for _, org := range q.User.Organizations.Nodes { + orgTeams := OrganizationTeams{ + Org: string(org.Login), + Teams: make([]TeamInfo, 0, len(org.Teams.Nodes)), + } - var organizations []OrganizationTeams - for _, org := range q.User.Organizations.Nodes { - orgTeams := OrganizationTeams{ - Org: string(org.Login), - Teams: make([]TeamInfo, 0, len(org.Teams.Nodes)), - } + for _, team := range org.Teams.Nodes { + orgTeams.Teams = append(orgTeams.Teams, TeamInfo{ + Name: string(team.Name), + Slug: string(team.Slug), + Description: string(team.Description), + }) + } - for _, team := range org.Teams.Nodes { - orgTeams.Teams = append(orgTeams.Teams, TeamInfo{ - Name: string(team.Name), - Slug: string(team.Slug), - Description: string(team.Description), - }) + organizations = append(organizations, orgTeams) } - organizations = append(organizations, orgTeams) + return MarshalledTextResult(organizations), nil, nil } - - return MarshalledTextResult(organizations), nil, nil - } + }, + ) } -func GetTeamMembers(getGQLClient GetGQLClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { - return mcp.Tool{ +func GetTeamMembers(t translations.TranslationHelperFunc) toolsets.ServerTool { + return NewTool( + mcp.Tool{ Name: "get_team_members", Description: t("TOOL_GET_TEAM_MEMBERS_DESCRIPTION", "Get member usernames of a specific team in an organization. Limited to organizations accessible with current credentials"), Annotations: &mcp.ToolAnnotations{ @@ -222,46 +232,49 @@ func GetTeamMembers(getGQLClient GetGQLClientFn, t translations.TranslationHelpe Required: []string{"org", "team_slug"}, }, }, - func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - org, err := RequiredParam[string](args, "org") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { + return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + org, err := RequiredParam[string](args, "org") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - teamSlug, err := RequiredParam[string](args, "team_slug") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + teamSlug, err := RequiredParam[string](args, "team_slug") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - gqlClient, err := getGQLClient(ctx) - if err != nil { - return utils.NewToolResultErrorFromErr("failed to get GitHub GQL client", err), nil, nil - } + gqlClient, err := deps.GetGQLClient(ctx) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to get GitHub GQL client", err), nil, nil + } - var q struct { - Organization struct { - Team struct { - Members struct { - Nodes []struct { - Login githubv4.String - } - } `graphql:"members(first: 100)"` - } `graphql:"team(slug: $teamSlug)"` - } `graphql:"organization(login: $org)"` - } - vars := map[string]interface{}{ - "org": githubv4.String(org), - "teamSlug": githubv4.String(teamSlug), - } - if err := gqlClient.Query(ctx, &q, vars); err != nil { - return ghErrors.NewGitHubGraphQLErrorResponse(ctx, "Failed to get team members", err), nil, nil - } + var q struct { + Organization struct { + Team struct { + Members struct { + Nodes []struct { + Login githubv4.String + } + } `graphql:"members(first: 100)"` + } `graphql:"team(slug: $teamSlug)"` + } `graphql:"organization(login: $org)"` + } + vars := map[string]interface{}{ + "org": githubv4.String(org), + "teamSlug": githubv4.String(teamSlug), + } + if err := gqlClient.Query(ctx, &q, vars); err != nil { + return ghErrors.NewGitHubGraphQLErrorResponse(ctx, "Failed to get team members", err), nil, nil + } - var members []string - for _, member := range q.Organization.Team.Members.Nodes { - members = append(members, string(member.Login)) - } + var members []string + for _, member := range q.Organization.Team.Members.Nodes { + members = append(members, string(member.Login)) + } - return MarshalledTextResult(members), nil, nil - } + return MarshalledTextResult(members), nil, nil + } + }, + ) } diff --git a/pkg/github/context_tools_test.go b/pkg/github/context_tools_test.go index 96e21c233..0e28aad49 100644 --- a/pkg/github/context_tools_test.go +++ b/pkg/github/context_tools_test.go @@ -20,7 +20,8 @@ import ( func Test_GetMe(t *testing.T) { t.Parallel() - tool, _ := GetMe(nil, translations.NullTranslationHelper) + serverTool := GetMe(translations.NullTranslationHelper) + tool := serverTool.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) // Verify some basic very important properties @@ -108,21 +109,28 @@ func Test_GetMe(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { - _, handler := GetMe(tc.stubbedGetClientFn, translations.NullTranslationHelper) + deps := ToolDependencies{ + GetClient: tc.stubbedGetClientFn, + } + handler := serverTool.Handler(deps) request := createMCPRequest(tc.requestArgs) - result, _, _ := handler(context.Background(), &request, tc.requestArgs) - textContent := getTextResult(t, result) + result, err := handler(context.Background(), &request) + require.NoError(t, err) if tc.expectToolError { - assert.True(t, result.IsError, "expected tool call result to be an error") - assert.Contains(t, textContent.Text, tc.expectedToolErrMsg) + require.True(t, result.IsError, "expected tool call result to be an error") + errorContent := getErrorResult(t, result) + assert.Contains(t, errorContent.Text, tc.expectedToolErrMsg) return } + require.False(t, result.IsError) + textContent := getTextResult(t, result) + // Unmarshal and verify the result var returnedUser MinimalUser - err := json.Unmarshal([]byte(textContent.Text), &returnedUser) + err = json.Unmarshal([]byte(textContent.Text), &returnedUser) require.NoError(t, err) // Verify minimal user details @@ -145,7 +153,8 @@ func Test_GetMe(t *testing.T) { func Test_GetTeams(t *testing.T) { t.Parallel() - tool, _ := GetTeams(nil, nil, translations.NullTranslationHelper) + serverTool := GetTeams(translations.NullTranslationHelper) + tool := serverTool.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "get_teams", tool.Name) @@ -331,19 +340,26 @@ func Test_GetTeams(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { - _, handler := GetTeams(tc.stubbedGetClientFn, tc.stubbedGetGQLClientFn, translations.NullTranslationHelper) + deps := ToolDependencies{ + GetClient: tc.stubbedGetClientFn, + GetGQLClient: tc.stubbedGetGQLClientFn, + } + handler := serverTool.Handler(deps) request := createMCPRequest(tc.requestArgs) - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(context.Background(), &request) require.NoError(t, err) - textContent := getTextResult(t, result) if tc.expectToolError { - assert.True(t, result.IsError, "expected tool call result to be an error") - assert.Contains(t, textContent.Text, tc.expectedToolErrMsg) + require.True(t, result.IsError, "expected tool call result to be an error") + errorContent := getErrorResult(t, result) + assert.Contains(t, errorContent.Text, tc.expectedToolErrMsg) return } + require.False(t, result.IsError) + textContent := getTextResult(t, result) + var organizations []OrganizationTeams err = json.Unmarshal([]byte(textContent.Text), &organizations) require.NoError(t, err) @@ -372,7 +388,8 @@ func Test_GetTeams(t *testing.T) { func Test_GetTeamMembers(t *testing.T) { t.Parallel() - tool, _ := GetTeamMembers(nil, translations.NullTranslationHelper) + serverTool := GetTeamMembers(translations.NullTranslationHelper) + tool := serverTool.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "get_team_members", tool.Name) @@ -467,19 +484,25 @@ func Test_GetTeamMembers(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { - _, handler := GetTeamMembers(tc.stubbedGetGQLClientFn, translations.NullTranslationHelper) + deps := ToolDependencies{ + GetGQLClient: tc.stubbedGetGQLClientFn, + } + handler := serverTool.Handler(deps) request := createMCPRequest(tc.requestArgs) - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(context.Background(), &request) require.NoError(t, err) - textContent := getTextResult(t, result) if tc.expectToolError { - assert.True(t, result.IsError, "expected tool call result to be an error") - assert.Contains(t, textContent.Text, tc.expectedToolErrMsg) + require.True(t, result.IsError, "expected tool call result to be an error") + errorContent := getErrorResult(t, result) + assert.Contains(t, errorContent.Text, tc.expectedToolErrMsg) return } + require.False(t, result.IsError) + textContent := getTextResult(t, result) + var members []string err = json.Unmarshal([]byte(textContent.Text), &members) require.NoError(t, err) diff --git a/pkg/github/gists.go b/pkg/github/gists.go index b54553aac..baca42399 100644 --- a/pkg/github/gists.go +++ b/pkg/github/gists.go @@ -7,6 +7,7 @@ import ( "io" "net/http" + "github.com/github/github-mcp-server/pkg/toolsets" "github.com/github/github-mcp-server/pkg/translations" "github.com/github/github-mcp-server/pkg/utils" "github.com/google/go-github/v79/github" @@ -15,346 +16,350 @@ import ( ) // ListGists creates a tool to list gists for a user -func ListGists(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { - tool := mcp.Tool{ - Name: "list_gists", - Description: t("TOOL_LIST_GISTS_DESCRIPTION", "List gists for a user"), - Annotations: &mcp.ToolAnnotations{ - Title: t("TOOL_LIST_GISTS", "List Gists"), - ReadOnlyHint: true, - }, - InputSchema: WithPagination(&jsonschema.Schema{ - Type: "object", - Properties: map[string]*jsonschema.Schema{ - "username": { - Type: "string", - Description: "GitHub username (omit for authenticated user's gists)", - }, - "since": { - Type: "string", - Description: "Only gists updated after this time (ISO 8601 timestamp)", - }, +func ListGists(t translations.TranslationHelperFunc) toolsets.ServerTool { + return NewTool( + mcp.Tool{ + Name: "list_gists", + Description: t("TOOL_LIST_GISTS_DESCRIPTION", "List gists for a user"), + Annotations: &mcp.ToolAnnotations{ + Title: t("TOOL_LIST_GISTS", "List Gists"), + ReadOnlyHint: true, }, - }), - } - - handler := mcp.ToolHandlerFor[map[string]any, any](func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - username, err := OptionalParam[string](args, "username") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - - since, err := OptionalParam[string](args, "since") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - - pagination, err := OptionalPaginationParams(args) - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - - opts := &github.GistListOptions{ - ListOptions: github.ListOptions{ - Page: pagination.Page, - PerPage: pagination.PerPage, - }, - } - - // Parse since timestamp if provided - if since != "" { - sinceTime, err := parseISOTimestamp(since) - if err != nil { - return utils.NewToolResultError(fmt.Sprintf("invalid since timestamp: %v", err)), nil, nil - } - opts.Since = sinceTime - } - - client, err := getClient(ctx) - if err != nil { - return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) - } - - gists, resp, err := client.Gists.List(ctx, username, opts) - if err != nil { - return nil, nil, fmt.Errorf("failed to list gists: %w", err) - } - defer func() { _ = resp.Body.Close() }() - - if resp.StatusCode != http.StatusOK { - body, err := io.ReadAll(resp.Body) - if err != nil { - return nil, nil, fmt.Errorf("failed to read response body: %w", err) + InputSchema: WithPagination(&jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "username": { + Type: "string", + Description: "GitHub username (omit for authenticated user's gists)", + }, + "since": { + Type: "string", + Description: "Only gists updated after this time (ISO 8601 timestamp)", + }, + }, + }), + }, + func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { + return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + username, err := OptionalParam[string](args, "username") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + + since, err := OptionalParam[string](args, "since") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + + pagination, err := OptionalPaginationParams(args) + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + + opts := &github.GistListOptions{ + ListOptions: github.ListOptions{ + Page: pagination.Page, + PerPage: pagination.PerPage, + }, + } + + // Parse since timestamp if provided + if since != "" { + sinceTime, err := parseISOTimestamp(since) + if err != nil { + return utils.NewToolResultError(fmt.Sprintf("invalid since timestamp: %v", err)), nil, nil + } + opts.Since = sinceTime + } + + client, err := deps.GetClient(ctx) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil + } + + gists, resp, err := client.Gists.List(ctx, username, opts) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to list gists", err), nil, nil + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to read response body", err), nil, nil + } + return utils.NewToolResultError(fmt.Sprintf("failed to list gists: %s", string(body))), nil, nil + } + + r, err := json.Marshal(gists) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to marshal response", err), nil, nil + } + + return utils.NewToolResultText(string(r)), nil, nil } - return utils.NewToolResultError(fmt.Sprintf("failed to list gists: %s", string(body))), nil, nil - } - - r, err := json.Marshal(gists) - if err != nil { - return nil, nil, fmt.Errorf("failed to marshal response: %w", err) - } - - return utils.NewToolResultText(string(r)), nil, nil - }) - - return tool, handler + }, + ) } // GetGist creates a tool to get the content of a gist -func GetGist(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { - tool := mcp.Tool{ - Name: "get_gist", - Description: t("TOOL_GET_GIST_DESCRIPTION", "Get gist content of a particular gist, by gist ID"), - Annotations: &mcp.ToolAnnotations{ - Title: t("TOOL_GET_GIST", "Get Gist Content"), - ReadOnlyHint: true, - }, - InputSchema: &jsonschema.Schema{ - Type: "object", - Properties: map[string]*jsonschema.Schema{ - "gist_id": { - Type: "string", - Description: "The ID of the gist", +func GetGist(t translations.TranslationHelperFunc) toolsets.ServerTool { + return NewTool( + mcp.Tool{ + Name: "get_gist", + Description: t("TOOL_GET_GIST_DESCRIPTION", "Get gist content of a particular gist, by gist ID"), + Annotations: &mcp.ToolAnnotations{ + Title: t("TOOL_GET_GIST", "Get Gist Content"), + ReadOnlyHint: true, + }, + InputSchema: &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "gist_id": { + Type: "string", + Description: "The ID of the gist", + }, }, + Required: []string{"gist_id"}, }, - Required: []string{"gist_id"}, }, - } - - handler := mcp.ToolHandlerFor[map[string]any, any](func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - gistID, err := RequiredParam[string](args, "gist_id") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - - client, err := getClient(ctx) - if err != nil { - return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) - } - - gist, resp, err := client.Gists.Get(ctx, gistID) - if err != nil { - return nil, nil, fmt.Errorf("failed to get gist: %w", err) - } - defer func() { _ = resp.Body.Close() }() - - if resp.StatusCode != http.StatusOK { - body, err := io.ReadAll(resp.Body) - if err != nil { - return nil, nil, fmt.Errorf("failed to read response body: %w", err) + func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { + return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + gistID, err := RequiredParam[string](args, "gist_id") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + + client, err := deps.GetClient(ctx) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil + } + + gist, resp, err := client.Gists.Get(ctx, gistID) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to get gist", err), nil, nil + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to read response body", err), nil, nil + } + return utils.NewToolResultError(fmt.Sprintf("failed to get gist: %s", string(body))), nil, nil + } + + r, err := json.Marshal(gist) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to marshal response", err), nil, nil + } + + return utils.NewToolResultText(string(r)), nil, nil } - return utils.NewToolResultError(fmt.Sprintf("failed to get gist: %s", string(body))), nil, nil - } - - r, err := json.Marshal(gist) - if err != nil { - return nil, nil, fmt.Errorf("failed to marshal response: %w", err) - } - - return utils.NewToolResultText(string(r)), nil, nil - }) - - return tool, handler + }, + ) } // CreateGist creates a tool to create a new gist -func CreateGist(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { - tool := mcp.Tool{ - Name: "create_gist", - Description: t("TOOL_CREATE_GIST_DESCRIPTION", "Create a new gist"), - Annotations: &mcp.ToolAnnotations{ - Title: t("TOOL_CREATE_GIST", "Create Gist"), - ReadOnlyHint: false, - }, - InputSchema: &jsonschema.Schema{ - Type: "object", - Properties: map[string]*jsonschema.Schema{ - "description": { - Type: "string", - Description: "Description of the gist", - }, - "filename": { - Type: "string", - Description: "Filename for simple single-file gist creation", - }, - "content": { - Type: "string", - Description: "Content for simple single-file gist creation", - }, - "public": { - Type: "boolean", - Description: "Whether the gist is public", - Default: json.RawMessage(`false`), +func CreateGist(t translations.TranslationHelperFunc) toolsets.ServerTool { + return NewTool( + mcp.Tool{ + Name: "create_gist", + Description: t("TOOL_CREATE_GIST_DESCRIPTION", "Create a new gist"), + Annotations: &mcp.ToolAnnotations{ + Title: t("TOOL_CREATE_GIST", "Create Gist"), + ReadOnlyHint: false, + }, + InputSchema: &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "description": { + Type: "string", + Description: "Description of the gist", + }, + "filename": { + Type: "string", + Description: "Filename for simple single-file gist creation", + }, + "content": { + Type: "string", + Description: "Content for simple single-file gist creation", + }, + "public": { + Type: "boolean", + Description: "Whether the gist is public", + Default: json.RawMessage(`false`), + }, }, + Required: []string{"filename", "content"}, }, - Required: []string{"filename", "content"}, }, - } - - handler := mcp.ToolHandlerFor[map[string]any, any](func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - description, err := OptionalParam[string](args, "description") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - - filename, err := RequiredParam[string](args, "filename") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - - content, err := RequiredParam[string](args, "content") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - - public, err := OptionalParam[bool](args, "public") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - - files := make(map[github.GistFilename]github.GistFile) - files[github.GistFilename(filename)] = github.GistFile{ - Filename: github.Ptr(filename), - Content: github.Ptr(content), - } - - gist := &github.Gist{ - Files: files, - Public: github.Ptr(public), - Description: github.Ptr(description), - } - - client, err := getClient(ctx) - if err != nil { - return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) - } - - createdGist, resp, err := client.Gists.Create(ctx, gist) - if err != nil { - return nil, nil, fmt.Errorf("failed to create gist: %w", err) - } - defer func() { _ = resp.Body.Close() }() - - if resp.StatusCode != http.StatusCreated { - body, err := io.ReadAll(resp.Body) - if err != nil { - return nil, nil, fmt.Errorf("failed to read response body: %w", err) + func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { + return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + description, err := OptionalParam[string](args, "description") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + + filename, err := RequiredParam[string](args, "filename") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + + content, err := RequiredParam[string](args, "content") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + + public, err := OptionalParam[bool](args, "public") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + + files := make(map[github.GistFilename]github.GistFile) + files[github.GistFilename(filename)] = github.GistFile{ + Filename: github.Ptr(filename), + Content: github.Ptr(content), + } + + gist := &github.Gist{ + Files: files, + Public: github.Ptr(public), + Description: github.Ptr(description), + } + + client, err := deps.GetClient(ctx) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil + } + + createdGist, resp, err := client.Gists.Create(ctx, gist) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to create gist", err), nil, nil + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusCreated { + body, err := io.ReadAll(resp.Body) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to read response body", err), nil, nil + } + return utils.NewToolResultError(fmt.Sprintf("failed to create gist: %s", string(body))), nil, nil + } + + minimalResponse := MinimalResponse{ + ID: createdGist.GetID(), + URL: createdGist.GetHTMLURL(), + } + + r, err := json.Marshal(minimalResponse) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to marshal response", err), nil, nil + } + + return utils.NewToolResultText(string(r)), nil, nil } - return utils.NewToolResultError(fmt.Sprintf("failed to create gist: %s", string(body))), nil, nil - } - - minimalResponse := MinimalResponse{ - ID: createdGist.GetID(), - URL: createdGist.GetHTMLURL(), - } - - r, err := json.Marshal(minimalResponse) - if err != nil { - return nil, nil, fmt.Errorf("failed to marshal response: %w", err) - } - - return utils.NewToolResultText(string(r)), nil, nil - }) - - return tool, handler + }, + ) } // UpdateGist creates a tool to edit an existing gist -func UpdateGist(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { - tool := mcp.Tool{ - Name: "update_gist", - Description: t("TOOL_UPDATE_GIST_DESCRIPTION", "Update an existing gist"), - Annotations: &mcp.ToolAnnotations{ - Title: t("TOOL_UPDATE_GIST", "Update Gist"), - ReadOnlyHint: false, - }, - InputSchema: &jsonschema.Schema{ - Type: "object", - Properties: map[string]*jsonschema.Schema{ - "gist_id": { - Type: "string", - Description: "ID of the gist to update", - }, - "description": { - Type: "string", - Description: "Updated description of the gist", - }, - "filename": { - Type: "string", - Description: "Filename to update or create", - }, - "content": { - Type: "string", - Description: "Content for the file", +func UpdateGist(t translations.TranslationHelperFunc) toolsets.ServerTool { + return NewTool( + mcp.Tool{ + Name: "update_gist", + Description: t("TOOL_UPDATE_GIST_DESCRIPTION", "Update an existing gist"), + Annotations: &mcp.ToolAnnotations{ + Title: t("TOOL_UPDATE_GIST", "Update Gist"), + ReadOnlyHint: false, + }, + InputSchema: &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "gist_id": { + Type: "string", + Description: "ID of the gist to update", + }, + "description": { + Type: "string", + Description: "Updated description of the gist", + }, + "filename": { + Type: "string", + Description: "Filename to update or create", + }, + "content": { + Type: "string", + Description: "Content for the file", + }, }, + Required: []string{"gist_id", "filename", "content"}, }, - Required: []string{"gist_id", "filename", "content"}, }, - } - - handler := mcp.ToolHandlerFor[map[string]any, any](func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - gistID, err := RequiredParam[string](args, "gist_id") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - - description, err := OptionalParam[string](args, "description") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - - filename, err := RequiredParam[string](args, "filename") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - - content, err := RequiredParam[string](args, "content") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - - files := make(map[github.GistFilename]github.GistFile) - files[github.GistFilename(filename)] = github.GistFile{ - Filename: github.Ptr(filename), - Content: github.Ptr(content), - } - - gist := &github.Gist{ - Files: files, - Description: github.Ptr(description), - } - - client, err := getClient(ctx) - if err != nil { - return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) - } - - updatedGist, resp, err := client.Gists.Edit(ctx, gistID, gist) - if err != nil { - return nil, nil, fmt.Errorf("failed to update gist: %w", err) - } - defer func() { _ = resp.Body.Close() }() - - if resp.StatusCode != http.StatusOK { - body, err := io.ReadAll(resp.Body) - if err != nil { - return nil, nil, fmt.Errorf("failed to read response body: %w", err) + func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { + return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + gistID, err := RequiredParam[string](args, "gist_id") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + + description, err := OptionalParam[string](args, "description") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + + filename, err := RequiredParam[string](args, "filename") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + + content, err := RequiredParam[string](args, "content") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + + files := make(map[github.GistFilename]github.GistFile) + files[github.GistFilename(filename)] = github.GistFile{ + Filename: github.Ptr(filename), + Content: github.Ptr(content), + } + + gist := &github.Gist{ + Files: files, + Description: github.Ptr(description), + } + + client, err := deps.GetClient(ctx) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil + } + + updatedGist, resp, err := client.Gists.Edit(ctx, gistID, gist) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to update gist", err), nil, nil + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to read response body", err), nil, nil + } + return utils.NewToolResultError(fmt.Sprintf("failed to update gist: %s", string(body))), nil, nil + } + + minimalResponse := MinimalResponse{ + ID: updatedGist.GetID(), + URL: updatedGist.GetHTMLURL(), + } + + r, err := json.Marshal(minimalResponse) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to marshal response", err), nil, nil + } + + return utils.NewToolResultText(string(r)), nil, nil } - return utils.NewToolResultError(fmt.Sprintf("failed to update gist: %s", string(body))), nil, nil - } - - minimalResponse := MinimalResponse{ - ID: updatedGist.GetID(), - URL: updatedGist.GetHTMLURL(), - } - - r, err := json.Marshal(minimalResponse) - if err != nil { - return nil, nil, fmt.Errorf("failed to marshal response: %w", err) - } - - return utils.NewToolResultText(string(r)), nil, nil - }) - - return tool, handler + }, + ) } diff --git a/pkg/github/gists_test.go b/pkg/github/gists_test.go index f0f62f420..44b294eb6 100644 --- a/pkg/github/gists_test.go +++ b/pkg/github/gists_test.go @@ -18,8 +18,8 @@ import ( func Test_ListGists(t *testing.T) { // Verify tool definition - mockClient := github.NewClient(nil) - tool, _ := ListGists(stubGetClientFn(mockClient), translations.NullTranslationHelper) + serverTool := ListGists(translations.NullTranslationHelper) + tool := serverTool.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) @@ -158,28 +158,27 @@ func Test_ListGists(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := ListGists(stubGetClientFn(client), translations.NullTranslationHelper) + deps := ToolDependencies{ + GetClient: stubGetClientFn(client), + } + handler := serverTool.Handler(deps) // Create call request request := createMCPRequest(tc.requestArgs) // Call handler - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(context.Background(), &request) + require.NoError(t, err) // Verify results if tc.expectError { - if err != nil { - assert.Contains(t, err.Error(), tc.expectedErrMsg) - } else { - // For errors returned as part of the result, not as an error - assert.NotNil(t, result) - textContent := getTextResult(t, result) - assert.Contains(t, textContent.Text, tc.expectedErrMsg) - } + require.True(t, result.IsError) + errorContent := getErrorResult(t, result) + assert.Contains(t, errorContent.Text, tc.expectedErrMsg) return } - require.NoError(t, err) + require.False(t, result.IsError) // Parse the result and get the text content if no error textContent := getTextResult(t, result) @@ -202,8 +201,8 @@ func Test_ListGists(t *testing.T) { func Test_GetGist(t *testing.T) { // Verify tool definition - mockClient := github.NewClient(nil) - tool, _ := GetGist(stubGetClientFn(mockClient), translations.NullTranslationHelper) + serverTool := GetGist(translations.NullTranslationHelper) + tool := serverTool.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) @@ -276,28 +275,27 @@ func Test_GetGist(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := GetGist(stubGetClientFn(client), translations.NullTranslationHelper) + deps := ToolDependencies{ + GetClient: stubGetClientFn(client), + } + handler := serverTool.Handler(deps) // Create call request request := createMCPRequest(tc.requestArgs) // Call handler - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(context.Background(), &request) + require.NoError(t, err) // Verify results if tc.expectError { - if err != nil { - assert.Contains(t, err.Error(), tc.expectedErrMsg) - } else { - // For errors returned as part of the result, not as an error - assert.NotNil(t, result) - textContent := getTextResult(t, result) - assert.Contains(t, textContent.Text, tc.expectedErrMsg) - } + require.True(t, result.IsError) + errorContent := getErrorResult(t, result) + assert.Contains(t, errorContent.Text, tc.expectedErrMsg) return } - require.NoError(t, err) + require.False(t, result.IsError) // Parse the result and get the text content if no error textContent := getTextResult(t, result) @@ -317,8 +315,8 @@ func Test_GetGist(t *testing.T) { func Test_CreateGist(t *testing.T) { // Verify tool definition - mockClient := github.NewClient(nil) - tool, _ := CreateGist(stubGetClientFn(mockClient), translations.NullTranslationHelper) + serverTool := CreateGist(translations.NullTranslationHelper) + tool := serverTool.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) @@ -423,28 +421,27 @@ func Test_CreateGist(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := CreateGist(stubGetClientFn(client), translations.NullTranslationHelper) + deps := ToolDependencies{ + GetClient: stubGetClientFn(client), + } + handler := serverTool.Handler(deps) // Create call request request := createMCPRequest(tc.requestArgs) // Call handler - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(context.Background(), &request) + require.NoError(t, err) // Verify results if tc.expectError { - if err != nil { - assert.Contains(t, err.Error(), tc.expectedErrMsg) - } else { - // For errors returned as part of the result, not as an error - assert.NotNil(t, result) - textContent := getTextResult(t, result) - assert.Contains(t, textContent.Text, tc.expectedErrMsg) - } + require.True(t, result.IsError) + errorContent := getErrorResult(t, result) + assert.Contains(t, errorContent.Text, tc.expectedErrMsg) return } - require.NoError(t, err) + require.False(t, result.IsError) assert.NotNil(t, result) // Parse the result and get the text content @@ -462,8 +459,8 @@ func Test_CreateGist(t *testing.T) { func Test_UpdateGist(t *testing.T) { // Verify tool definition - mockClient := github.NewClient(nil) - tool, _ := UpdateGist(stubGetClientFn(mockClient), translations.NullTranslationHelper) + serverTool := UpdateGist(translations.NullTranslationHelper) + tool := serverTool.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) @@ -583,28 +580,27 @@ func Test_UpdateGist(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := UpdateGist(stubGetClientFn(client), translations.NullTranslationHelper) + deps := ToolDependencies{ + GetClient: stubGetClientFn(client), + } + handler := serverTool.Handler(deps) // Create call request request := createMCPRequest(tc.requestArgs) // Call handler - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(context.Background(), &request) + require.NoError(t, err) // Verify results if tc.expectError { - if err != nil { - assert.Contains(t, err.Error(), tc.expectedErrMsg) - } else { - // For errors returned as part of the result, not as an error - assert.NotNil(t, result) - textContent := getTextResult(t, result) - assert.Contains(t, textContent.Text, tc.expectedErrMsg) - } + require.True(t, result.IsError) + errorContent := getErrorResult(t, result) + assert.Contains(t, errorContent.Text, tc.expectedErrMsg) return } - require.NoError(t, err) + require.False(t, result.IsError) assert.NotNil(t, result) // Parse the result and get the text content diff --git a/pkg/github/tools.go b/pkg/github/tools.go index dff1ca02e..d9008a912 100644 --- a/pkg/github/tools.go +++ b/pkg/github/tools.go @@ -334,20 +334,20 @@ func DefaultToolsetGroup(readOnly bool, getClient GetClientFn, getGQLClient GetG contextTools := toolsets.NewToolset(ToolsetMetadataContext.ID, ToolsetMetadataContext.Description). SetDependencies(deps). AddReadTools( - toolsets.NewServerToolLegacy(GetMe(getClient, t)), - toolsets.NewServerToolLegacy(GetTeams(getClient, getGQLClient, t)), - toolsets.NewServerToolLegacy(GetTeamMembers(getGQLClient, t)), + GetMe(t), + GetTeams(t), + GetTeamMembers(t), ) gists := toolsets.NewToolset(ToolsetMetadataGists.ID, ToolsetMetadataGists.Description). SetDependencies(deps). AddReadTools( - toolsets.NewServerToolLegacy(ListGists(getClient, t)), - toolsets.NewServerToolLegacy(GetGist(getClient, t)), + ListGists(t), + GetGist(t), ). AddWriteTools( - toolsets.NewServerToolLegacy(CreateGist(getClient, t)), - toolsets.NewServerToolLegacy(UpdateGist(getClient, t)), + CreateGist(t), + UpdateGist(t), ) projects := toolsets.NewToolset(ToolsetMetadataProjects.ID, ToolsetMetadataProjects.Description).