diff --git a/pkg/github/gists.go b/pkg/github/gists.go index b54553aac..baca42399 100644 --- a/pkg/github/gists.go +++ b/pkg/github/gists.go @@ -7,6 +7,7 @@ import ( "io" "net/http" + "github.com/github/github-mcp-server/pkg/toolsets" "github.com/github/github-mcp-server/pkg/translations" "github.com/github/github-mcp-server/pkg/utils" "github.com/google/go-github/v79/github" @@ -15,346 +16,350 @@ import ( ) // ListGists creates a tool to list gists for a user -func ListGists(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { - tool := mcp.Tool{ - Name: "list_gists", - Description: t("TOOL_LIST_GISTS_DESCRIPTION", "List gists for a user"), - Annotations: &mcp.ToolAnnotations{ - Title: t("TOOL_LIST_GISTS", "List Gists"), - ReadOnlyHint: true, - }, - InputSchema: WithPagination(&jsonschema.Schema{ - Type: "object", - Properties: map[string]*jsonschema.Schema{ - "username": { - Type: "string", - Description: "GitHub username (omit for authenticated user's gists)", - }, - "since": { - Type: "string", - Description: "Only gists updated after this time (ISO 8601 timestamp)", - }, +func ListGists(t translations.TranslationHelperFunc) toolsets.ServerTool { + return NewTool( + mcp.Tool{ + Name: "list_gists", + Description: t("TOOL_LIST_GISTS_DESCRIPTION", "List gists for a user"), + Annotations: &mcp.ToolAnnotations{ + Title: t("TOOL_LIST_GISTS", "List Gists"), + ReadOnlyHint: true, }, - }), - } - - handler := mcp.ToolHandlerFor[map[string]any, any](func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - username, err := OptionalParam[string](args, "username") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - - since, err := OptionalParam[string](args, "since") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - - pagination, err := OptionalPaginationParams(args) - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - - opts := &github.GistListOptions{ - ListOptions: github.ListOptions{ - Page: pagination.Page, - PerPage: pagination.PerPage, - }, - } - - // Parse since timestamp if provided - if since != "" { - sinceTime, err := parseISOTimestamp(since) - if err != nil { - return utils.NewToolResultError(fmt.Sprintf("invalid since timestamp: %v", err)), nil, nil - } - opts.Since = sinceTime - } - - client, err := getClient(ctx) - if err != nil { - return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) - } - - gists, resp, err := client.Gists.List(ctx, username, opts) - if err != nil { - return nil, nil, fmt.Errorf("failed to list gists: %w", err) - } - defer func() { _ = resp.Body.Close() }() - - if resp.StatusCode != http.StatusOK { - body, err := io.ReadAll(resp.Body) - if err != nil { - return nil, nil, fmt.Errorf("failed to read response body: %w", err) + InputSchema: WithPagination(&jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "username": { + Type: "string", + Description: "GitHub username (omit for authenticated user's gists)", + }, + "since": { + Type: "string", + Description: "Only gists updated after this time (ISO 8601 timestamp)", + }, + }, + }), + }, + func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { + return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + username, err := OptionalParam[string](args, "username") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + + since, err := OptionalParam[string](args, "since") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + + pagination, err := OptionalPaginationParams(args) + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + + opts := &github.GistListOptions{ + ListOptions: github.ListOptions{ + Page: pagination.Page, + PerPage: pagination.PerPage, + }, + } + + // Parse since timestamp if provided + if since != "" { + sinceTime, err := parseISOTimestamp(since) + if err != nil { + return utils.NewToolResultError(fmt.Sprintf("invalid since timestamp: %v", err)), nil, nil + } + opts.Since = sinceTime + } + + client, err := deps.GetClient(ctx) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil + } + + gists, resp, err := client.Gists.List(ctx, username, opts) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to list gists", err), nil, nil + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to read response body", err), nil, nil + } + return utils.NewToolResultError(fmt.Sprintf("failed to list gists: %s", string(body))), nil, nil + } + + r, err := json.Marshal(gists) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to marshal response", err), nil, nil + } + + return utils.NewToolResultText(string(r)), nil, nil } - return utils.NewToolResultError(fmt.Sprintf("failed to list gists: %s", string(body))), nil, nil - } - - r, err := json.Marshal(gists) - if err != nil { - return nil, nil, fmt.Errorf("failed to marshal response: %w", err) - } - - return utils.NewToolResultText(string(r)), nil, nil - }) - - return tool, handler + }, + ) } // GetGist creates a tool to get the content of a gist -func GetGist(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { - tool := mcp.Tool{ - Name: "get_gist", - Description: t("TOOL_GET_GIST_DESCRIPTION", "Get gist content of a particular gist, by gist ID"), - Annotations: &mcp.ToolAnnotations{ - Title: t("TOOL_GET_GIST", "Get Gist Content"), - ReadOnlyHint: true, - }, - InputSchema: &jsonschema.Schema{ - Type: "object", - Properties: map[string]*jsonschema.Schema{ - "gist_id": { - Type: "string", - Description: "The ID of the gist", +func GetGist(t translations.TranslationHelperFunc) toolsets.ServerTool { + return NewTool( + mcp.Tool{ + Name: "get_gist", + Description: t("TOOL_GET_GIST_DESCRIPTION", "Get gist content of a particular gist, by gist ID"), + Annotations: &mcp.ToolAnnotations{ + Title: t("TOOL_GET_GIST", "Get Gist Content"), + ReadOnlyHint: true, + }, + InputSchema: &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "gist_id": { + Type: "string", + Description: "The ID of the gist", + }, }, + Required: []string{"gist_id"}, }, - Required: []string{"gist_id"}, }, - } - - handler := mcp.ToolHandlerFor[map[string]any, any](func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - gistID, err := RequiredParam[string](args, "gist_id") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - - client, err := getClient(ctx) - if err != nil { - return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) - } - - gist, resp, err := client.Gists.Get(ctx, gistID) - if err != nil { - return nil, nil, fmt.Errorf("failed to get gist: %w", err) - } - defer func() { _ = resp.Body.Close() }() - - if resp.StatusCode != http.StatusOK { - body, err := io.ReadAll(resp.Body) - if err != nil { - return nil, nil, fmt.Errorf("failed to read response body: %w", err) + func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { + return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + gistID, err := RequiredParam[string](args, "gist_id") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + + client, err := deps.GetClient(ctx) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil + } + + gist, resp, err := client.Gists.Get(ctx, gistID) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to get gist", err), nil, nil + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to read response body", err), nil, nil + } + return utils.NewToolResultError(fmt.Sprintf("failed to get gist: %s", string(body))), nil, nil + } + + r, err := json.Marshal(gist) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to marshal response", err), nil, nil + } + + return utils.NewToolResultText(string(r)), nil, nil } - return utils.NewToolResultError(fmt.Sprintf("failed to get gist: %s", string(body))), nil, nil - } - - r, err := json.Marshal(gist) - if err != nil { - return nil, nil, fmt.Errorf("failed to marshal response: %w", err) - } - - return utils.NewToolResultText(string(r)), nil, nil - }) - - return tool, handler + }, + ) } // CreateGist creates a tool to create a new gist -func CreateGist(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { - tool := mcp.Tool{ - Name: "create_gist", - Description: t("TOOL_CREATE_GIST_DESCRIPTION", "Create a new gist"), - Annotations: &mcp.ToolAnnotations{ - Title: t("TOOL_CREATE_GIST", "Create Gist"), - ReadOnlyHint: false, - }, - InputSchema: &jsonschema.Schema{ - Type: "object", - Properties: map[string]*jsonschema.Schema{ - "description": { - Type: "string", - Description: "Description of the gist", - }, - "filename": { - Type: "string", - Description: "Filename for simple single-file gist creation", - }, - "content": { - Type: "string", - Description: "Content for simple single-file gist creation", - }, - "public": { - Type: "boolean", - Description: "Whether the gist is public", - Default: json.RawMessage(`false`), +func CreateGist(t translations.TranslationHelperFunc) toolsets.ServerTool { + return NewTool( + mcp.Tool{ + Name: "create_gist", + Description: t("TOOL_CREATE_GIST_DESCRIPTION", "Create a new gist"), + Annotations: &mcp.ToolAnnotations{ + Title: t("TOOL_CREATE_GIST", "Create Gist"), + ReadOnlyHint: false, + }, + InputSchema: &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "description": { + Type: "string", + Description: "Description of the gist", + }, + "filename": { + Type: "string", + Description: "Filename for simple single-file gist creation", + }, + "content": { + Type: "string", + Description: "Content for simple single-file gist creation", + }, + "public": { + Type: "boolean", + Description: "Whether the gist is public", + Default: json.RawMessage(`false`), + }, }, + Required: []string{"filename", "content"}, }, - Required: []string{"filename", "content"}, }, - } - - handler := mcp.ToolHandlerFor[map[string]any, any](func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - description, err := OptionalParam[string](args, "description") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - - filename, err := RequiredParam[string](args, "filename") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - - content, err := RequiredParam[string](args, "content") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - - public, err := OptionalParam[bool](args, "public") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - - files := make(map[github.GistFilename]github.GistFile) - files[github.GistFilename(filename)] = github.GistFile{ - Filename: github.Ptr(filename), - Content: github.Ptr(content), - } - - gist := &github.Gist{ - Files: files, - Public: github.Ptr(public), - Description: github.Ptr(description), - } - - client, err := getClient(ctx) - if err != nil { - return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) - } - - createdGist, resp, err := client.Gists.Create(ctx, gist) - if err != nil { - return nil, nil, fmt.Errorf("failed to create gist: %w", err) - } - defer func() { _ = resp.Body.Close() }() - - if resp.StatusCode != http.StatusCreated { - body, err := io.ReadAll(resp.Body) - if err != nil { - return nil, nil, fmt.Errorf("failed to read response body: %w", err) + func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { + return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + description, err := OptionalParam[string](args, "description") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + + filename, err := RequiredParam[string](args, "filename") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + + content, err := RequiredParam[string](args, "content") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + + public, err := OptionalParam[bool](args, "public") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + + files := make(map[github.GistFilename]github.GistFile) + files[github.GistFilename(filename)] = github.GistFile{ + Filename: github.Ptr(filename), + Content: github.Ptr(content), + } + + gist := &github.Gist{ + Files: files, + Public: github.Ptr(public), + Description: github.Ptr(description), + } + + client, err := deps.GetClient(ctx) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil + } + + createdGist, resp, err := client.Gists.Create(ctx, gist) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to create gist", err), nil, nil + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusCreated { + body, err := io.ReadAll(resp.Body) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to read response body", err), nil, nil + } + return utils.NewToolResultError(fmt.Sprintf("failed to create gist: %s", string(body))), nil, nil + } + + minimalResponse := MinimalResponse{ + ID: createdGist.GetID(), + URL: createdGist.GetHTMLURL(), + } + + r, err := json.Marshal(minimalResponse) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to marshal response", err), nil, nil + } + + return utils.NewToolResultText(string(r)), nil, nil } - return utils.NewToolResultError(fmt.Sprintf("failed to create gist: %s", string(body))), nil, nil - } - - minimalResponse := MinimalResponse{ - ID: createdGist.GetID(), - URL: createdGist.GetHTMLURL(), - } - - r, err := json.Marshal(minimalResponse) - if err != nil { - return nil, nil, fmt.Errorf("failed to marshal response: %w", err) - } - - return utils.NewToolResultText(string(r)), nil, nil - }) - - return tool, handler + }, + ) } // UpdateGist creates a tool to edit an existing gist -func UpdateGist(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { - tool := mcp.Tool{ - Name: "update_gist", - Description: t("TOOL_UPDATE_GIST_DESCRIPTION", "Update an existing gist"), - Annotations: &mcp.ToolAnnotations{ - Title: t("TOOL_UPDATE_GIST", "Update Gist"), - ReadOnlyHint: false, - }, - InputSchema: &jsonschema.Schema{ - Type: "object", - Properties: map[string]*jsonschema.Schema{ - "gist_id": { - Type: "string", - Description: "ID of the gist to update", - }, - "description": { - Type: "string", - Description: "Updated description of the gist", - }, - "filename": { - Type: "string", - Description: "Filename to update or create", - }, - "content": { - Type: "string", - Description: "Content for the file", +func UpdateGist(t translations.TranslationHelperFunc) toolsets.ServerTool { + return NewTool( + mcp.Tool{ + Name: "update_gist", + Description: t("TOOL_UPDATE_GIST_DESCRIPTION", "Update an existing gist"), + Annotations: &mcp.ToolAnnotations{ + Title: t("TOOL_UPDATE_GIST", "Update Gist"), + ReadOnlyHint: false, + }, + InputSchema: &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "gist_id": { + Type: "string", + Description: "ID of the gist to update", + }, + "description": { + Type: "string", + Description: "Updated description of the gist", + }, + "filename": { + Type: "string", + Description: "Filename to update or create", + }, + "content": { + Type: "string", + Description: "Content for the file", + }, }, + Required: []string{"gist_id", "filename", "content"}, }, - Required: []string{"gist_id", "filename", "content"}, }, - } - - handler := mcp.ToolHandlerFor[map[string]any, any](func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - gistID, err := RequiredParam[string](args, "gist_id") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - - description, err := OptionalParam[string](args, "description") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - - filename, err := RequiredParam[string](args, "filename") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - - content, err := RequiredParam[string](args, "content") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - - files := make(map[github.GistFilename]github.GistFile) - files[github.GistFilename(filename)] = github.GistFile{ - Filename: github.Ptr(filename), - Content: github.Ptr(content), - } - - gist := &github.Gist{ - Files: files, - Description: github.Ptr(description), - } - - client, err := getClient(ctx) - if err != nil { - return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) - } - - updatedGist, resp, err := client.Gists.Edit(ctx, gistID, gist) - if err != nil { - return nil, nil, fmt.Errorf("failed to update gist: %w", err) - } - defer func() { _ = resp.Body.Close() }() - - if resp.StatusCode != http.StatusOK { - body, err := io.ReadAll(resp.Body) - if err != nil { - return nil, nil, fmt.Errorf("failed to read response body: %w", err) + func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { + return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + gistID, err := RequiredParam[string](args, "gist_id") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + + description, err := OptionalParam[string](args, "description") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + + filename, err := RequiredParam[string](args, "filename") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + + content, err := RequiredParam[string](args, "content") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + + files := make(map[github.GistFilename]github.GistFile) + files[github.GistFilename(filename)] = github.GistFile{ + Filename: github.Ptr(filename), + Content: github.Ptr(content), + } + + gist := &github.Gist{ + Files: files, + Description: github.Ptr(description), + } + + client, err := deps.GetClient(ctx) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil + } + + updatedGist, resp, err := client.Gists.Edit(ctx, gistID, gist) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to update gist", err), nil, nil + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to read response body", err), nil, nil + } + return utils.NewToolResultError(fmt.Sprintf("failed to update gist: %s", string(body))), nil, nil + } + + minimalResponse := MinimalResponse{ + ID: updatedGist.GetID(), + URL: updatedGist.GetHTMLURL(), + } + + r, err := json.Marshal(minimalResponse) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to marshal response", err), nil, nil + } + + return utils.NewToolResultText(string(r)), nil, nil } - return utils.NewToolResultError(fmt.Sprintf("failed to update gist: %s", string(body))), nil, nil - } - - minimalResponse := MinimalResponse{ - ID: updatedGist.GetID(), - URL: updatedGist.GetHTMLURL(), - } - - r, err := json.Marshal(minimalResponse) - if err != nil { - return nil, nil, fmt.Errorf("failed to marshal response: %w", err) - } - - return utils.NewToolResultText(string(r)), nil, nil - }) - - return tool, handler + }, + ) } diff --git a/pkg/github/gists_test.go b/pkg/github/gists_test.go index f0f62f420..44b294eb6 100644 --- a/pkg/github/gists_test.go +++ b/pkg/github/gists_test.go @@ -18,8 +18,8 @@ import ( func Test_ListGists(t *testing.T) { // Verify tool definition - mockClient := github.NewClient(nil) - tool, _ := ListGists(stubGetClientFn(mockClient), translations.NullTranslationHelper) + serverTool := ListGists(translations.NullTranslationHelper) + tool := serverTool.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) @@ -158,28 +158,27 @@ func Test_ListGists(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := ListGists(stubGetClientFn(client), translations.NullTranslationHelper) + deps := ToolDependencies{ + GetClient: stubGetClientFn(client), + } + handler := serverTool.Handler(deps) // Create call request request := createMCPRequest(tc.requestArgs) // Call handler - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(context.Background(), &request) + require.NoError(t, err) // Verify results if tc.expectError { - if err != nil { - assert.Contains(t, err.Error(), tc.expectedErrMsg) - } else { - // For errors returned as part of the result, not as an error - assert.NotNil(t, result) - textContent := getTextResult(t, result) - assert.Contains(t, textContent.Text, tc.expectedErrMsg) - } + require.True(t, result.IsError) + errorContent := getErrorResult(t, result) + assert.Contains(t, errorContent.Text, tc.expectedErrMsg) return } - require.NoError(t, err) + require.False(t, result.IsError) // Parse the result and get the text content if no error textContent := getTextResult(t, result) @@ -202,8 +201,8 @@ func Test_ListGists(t *testing.T) { func Test_GetGist(t *testing.T) { // Verify tool definition - mockClient := github.NewClient(nil) - tool, _ := GetGist(stubGetClientFn(mockClient), translations.NullTranslationHelper) + serverTool := GetGist(translations.NullTranslationHelper) + tool := serverTool.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) @@ -276,28 +275,27 @@ func Test_GetGist(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := GetGist(stubGetClientFn(client), translations.NullTranslationHelper) + deps := ToolDependencies{ + GetClient: stubGetClientFn(client), + } + handler := serverTool.Handler(deps) // Create call request request := createMCPRequest(tc.requestArgs) // Call handler - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(context.Background(), &request) + require.NoError(t, err) // Verify results if tc.expectError { - if err != nil { - assert.Contains(t, err.Error(), tc.expectedErrMsg) - } else { - // For errors returned as part of the result, not as an error - assert.NotNil(t, result) - textContent := getTextResult(t, result) - assert.Contains(t, textContent.Text, tc.expectedErrMsg) - } + require.True(t, result.IsError) + errorContent := getErrorResult(t, result) + assert.Contains(t, errorContent.Text, tc.expectedErrMsg) return } - require.NoError(t, err) + require.False(t, result.IsError) // Parse the result and get the text content if no error textContent := getTextResult(t, result) @@ -317,8 +315,8 @@ func Test_GetGist(t *testing.T) { func Test_CreateGist(t *testing.T) { // Verify tool definition - mockClient := github.NewClient(nil) - tool, _ := CreateGist(stubGetClientFn(mockClient), translations.NullTranslationHelper) + serverTool := CreateGist(translations.NullTranslationHelper) + tool := serverTool.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) @@ -423,28 +421,27 @@ func Test_CreateGist(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := CreateGist(stubGetClientFn(client), translations.NullTranslationHelper) + deps := ToolDependencies{ + GetClient: stubGetClientFn(client), + } + handler := serverTool.Handler(deps) // Create call request request := createMCPRequest(tc.requestArgs) // Call handler - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(context.Background(), &request) + require.NoError(t, err) // Verify results if tc.expectError { - if err != nil { - assert.Contains(t, err.Error(), tc.expectedErrMsg) - } else { - // For errors returned as part of the result, not as an error - assert.NotNil(t, result) - textContent := getTextResult(t, result) - assert.Contains(t, textContent.Text, tc.expectedErrMsg) - } + require.True(t, result.IsError) + errorContent := getErrorResult(t, result) + assert.Contains(t, errorContent.Text, tc.expectedErrMsg) return } - require.NoError(t, err) + require.False(t, result.IsError) assert.NotNil(t, result) // Parse the result and get the text content @@ -462,8 +459,8 @@ func Test_CreateGist(t *testing.T) { func Test_UpdateGist(t *testing.T) { // Verify tool definition - mockClient := github.NewClient(nil) - tool, _ := UpdateGist(stubGetClientFn(mockClient), translations.NullTranslationHelper) + serverTool := UpdateGist(translations.NullTranslationHelper) + tool := serverTool.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) @@ -583,28 +580,27 @@ func Test_UpdateGist(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := UpdateGist(stubGetClientFn(client), translations.NullTranslationHelper) + deps := ToolDependencies{ + GetClient: stubGetClientFn(client), + } + handler := serverTool.Handler(deps) // Create call request request := createMCPRequest(tc.requestArgs) // Call handler - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(context.Background(), &request) + require.NoError(t, err) // Verify results if tc.expectError { - if err != nil { - assert.Contains(t, err.Error(), tc.expectedErrMsg) - } else { - // For errors returned as part of the result, not as an error - assert.NotNil(t, result) - textContent := getTextResult(t, result) - assert.Contains(t, textContent.Text, tc.expectedErrMsg) - } + require.True(t, result.IsError) + errorContent := getErrorResult(t, result) + assert.Contains(t, errorContent.Text, tc.expectedErrMsg) return } - require.NoError(t, err) + require.False(t, result.IsError) assert.NotNil(t, result) // Parse the result and get the text content diff --git a/pkg/github/notifications.go b/pkg/github/notifications.go index 7f9e98f91..23d63c946 100644 --- a/pkg/github/notifications.go +++ b/pkg/github/notifications.go @@ -10,6 +10,7 @@ import ( "time" ghErrors "github.com/github/github-mcp-server/pkg/errors" + "github.com/github/github-mcp-server/pkg/toolsets" "github.com/github/github-mcp-server/pkg/translations" "github.com/github/github-mcp-server/pkg/utils" "github.com/google/go-github/v79/github" @@ -24,8 +25,9 @@ const ( ) // ListNotifications creates a tool to list notifications for the current user. -func ListNotifications(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { - return mcp.Tool{ +func ListNotifications(t translations.TranslationHelperFunc) toolsets.ServerTool { + return NewTool( + mcp.Tool{ Name: "list_notifications", Description: t("TOOL_LIST_NOTIFICATIONS_DESCRIPTION", "Lists all GitHub notifications for the authenticated user, including unread notifications, mentions, review requests, assignments, and updates on issues or pull requests. Use this tool whenever the user asks what to work on next, requests a summary of their GitHub activity, wants to see pending reviews, or needs to check for new updates or tasks. This tool is the primary way to discover actionable items, reminders, and outstanding work on GitHub. Always call this tool when asked what to work on next, what is pending, or what needs attention in GitHub."), Annotations: &mcp.ToolAnnotations{ @@ -59,106 +61,110 @@ func ListNotifications(getClient GetClientFn, t translations.TranslationHelperFu }, }), }, - mcp.ToolHandlerFor[map[string]any, any](func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - client, err := getClient(ctx) - if err != nil { - return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, err - } + func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { + return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + client, err := deps.GetClient(ctx) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil + } - filter, err := OptionalParam[string](args, "filter") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + filter, err := OptionalParam[string](args, "filter") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - since, err := OptionalParam[string](args, "since") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + since, err := OptionalParam[string](args, "since") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - before, err := OptionalParam[string](args, "before") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + before, err := OptionalParam[string](args, "before") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - owner, err := OptionalParam[string](args, "owner") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - repo, err := OptionalParam[string](args, "repo") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + owner, err := OptionalParam[string](args, "owner") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + repo, err := OptionalParam[string](args, "repo") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - paginationParams, err := OptionalPaginationParams(args) - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + paginationParams, err := OptionalPaginationParams(args) + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - // Build options - opts := &github.NotificationListOptions{ - All: filter == FilterIncludeRead, - Participating: filter == FilterOnlyParticipating, - ListOptions: github.ListOptions{ - Page: paginationParams.Page, - PerPage: paginationParams.PerPage, - }, - } + // Build options + opts := &github.NotificationListOptions{ + All: filter == FilterIncludeRead, + Participating: filter == FilterOnlyParticipating, + ListOptions: github.ListOptions{ + Page: paginationParams.Page, + PerPage: paginationParams.PerPage, + }, + } - // Parse time parameters if provided - if since != "" { - sinceTime, err := time.Parse(time.RFC3339, since) - if err != nil { - return utils.NewToolResultError(fmt.Sprintf("invalid since time format, should be RFC3339/ISO8601: %v", err)), nil, nil + // Parse time parameters if provided + if since != "" { + sinceTime, err := time.Parse(time.RFC3339, since) + if err != nil { + return utils.NewToolResultError(fmt.Sprintf("invalid since time format, should be RFC3339/ISO8601: %v", err)), nil, nil + } + opts.Since = sinceTime } - opts.Since = sinceTime - } - if before != "" { - beforeTime, err := time.Parse(time.RFC3339, before) - if err != nil { - return utils.NewToolResultError(fmt.Sprintf("invalid before time format, should be RFC3339/ISO8601: %v", err)), nil, nil + if before != "" { + beforeTime, err := time.Parse(time.RFC3339, before) + if err != nil { + return utils.NewToolResultError(fmt.Sprintf("invalid before time format, should be RFC3339/ISO8601: %v", err)), nil, nil + } + opts.Before = beforeTime } - opts.Before = beforeTime - } - var notifications []*github.Notification - var resp *github.Response + var notifications []*github.Notification + var resp *github.Response - if owner != "" && repo != "" { - notifications, resp, err = client.Activity.ListRepositoryNotifications(ctx, owner, repo, opts) - } else { - notifications, resp, err = client.Activity.ListNotifications(ctx, opts) - } - if err != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, - "failed to list notifications", - resp, - err, - ), nil, nil - } - defer func() { _ = resp.Body.Close() }() + if owner != "" && repo != "" { + notifications, resp, err = client.Activity.ListRepositoryNotifications(ctx, owner, repo, opts) + } else { + notifications, resp, err = client.Activity.ListNotifications(ctx, opts) + } + if err != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, + "failed to list notifications", + resp, + err, + ), nil, nil + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to read response body", err), nil, nil + } + return utils.NewToolResultError(fmt.Sprintf("failed to get notifications: %s", string(body))), nil, nil + } - if resp.StatusCode != http.StatusOK { - body, err := io.ReadAll(resp.Body) + // Marshal response to JSON + r, err := json.Marshal(notifications) if err != nil { - return utils.NewToolResultErrorFromErr("failed to read response body", err), nil, err + return utils.NewToolResultErrorFromErr("failed to marshal response", err), nil, nil } - return utils.NewToolResultError(fmt.Sprintf("failed to get notifications: %s", string(body))), nil, nil - } - // Marshal response to JSON - r, err := json.Marshal(notifications) - if err != nil { - return utils.NewToolResultErrorFromErr("failed to marshal response", err), nil, err + return utils.NewToolResultText(string(r)), nil, nil } - - return utils.NewToolResultText(string(r)), nil, nil - }) + }, + ) } // DismissNotification creates a tool to mark a notification as read/done. -func DismissNotification(getclient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { - return mcp.Tool{ +func DismissNotification(t translations.TranslationHelperFunc) toolsets.ServerTool { + return NewTool( + mcp.Tool{ Name: "dismiss_notification", Description: t("TOOL_DISMISS_NOTIFICATION_DESCRIPTION", "Dismiss a notification by marking it as read or done"), Annotations: &mcp.ToolAnnotations{ @@ -181,62 +187,66 @@ func DismissNotification(getclient GetClientFn, t translations.TranslationHelper Required: []string{"threadID", "state"}, }, }, - mcp.ToolHandlerFor[map[string]any, any](func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - client, err := getclient(ctx) - if err != nil { - return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, err - } - - threadID, err := RequiredParam[string](args, "threadID") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { + return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + client, err := deps.GetClient(ctx) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil + } - state, err := RequiredParam[string](args, "state") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + threadID, err := RequiredParam[string](args, "threadID") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - var resp *github.Response - switch state { - case "done": - // for some inexplicable reason, the API seems to have threadID as int64 and string depending on the endpoint - var threadIDInt int64 - threadIDInt, err = strconv.ParseInt(threadID, 10, 64) + state, err := RequiredParam[string](args, "state") if err != nil { - return utils.NewToolResultError(fmt.Sprintf("invalid threadID format: %v", err)), nil, nil + return utils.NewToolResultError(err.Error()), nil, nil } - resp, err = client.Activity.MarkThreadDone(ctx, threadIDInt) - case "read": - resp, err = client.Activity.MarkThreadRead(ctx, threadID) - default: - return utils.NewToolResultError("Invalid state. Must be one of: read, done."), nil, nil - } - if err != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, - fmt.Sprintf("failed to mark notification as %s", state), - resp, - err, - ), nil, nil - } - defer func() { _ = resp.Body.Close() }() + var resp *github.Response + switch state { + case "done": + // for some inexplicable reason, the API seems to have threadID as int64 and string depending on the endpoint + var threadIDInt int64 + threadIDInt, err = strconv.ParseInt(threadID, 10, 64) + if err != nil { + return utils.NewToolResultError(fmt.Sprintf("invalid threadID format: %v", err)), nil, nil + } + resp, err = client.Activity.MarkThreadDone(ctx, threadIDInt) + case "read": + resp, err = client.Activity.MarkThreadRead(ctx, threadID) + default: + return utils.NewToolResultError("Invalid state. Must be one of: read, done."), nil, nil + } - if resp.StatusCode != http.StatusResetContent && resp.StatusCode != http.StatusOK { - body, err := io.ReadAll(resp.Body) if err != nil { - return utils.NewToolResultErrorFromErr("failed to read response body", err), nil, err + return ghErrors.NewGitHubAPIErrorResponse(ctx, + fmt.Sprintf("failed to mark notification as %s", state), + resp, + err, + ), nil, nil + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusResetContent && resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to read response body", err), nil, nil + } + return utils.NewToolResultError(fmt.Sprintf("failed to mark notification as %s: %s", state, string(body))), nil, nil } - return utils.NewToolResultError(fmt.Sprintf("failed to mark notification as %s: %s", state, string(body))), nil, nil - } - return utils.NewToolResultText(fmt.Sprintf("Notification marked as %s", state)), nil, nil - }) + return utils.NewToolResultText(fmt.Sprintf("Notification marked as %s", state)), nil, nil + } + }, + ) } // MarkAllNotificationsRead creates a tool to mark all notifications as read. -func MarkAllNotificationsRead(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { - return mcp.Tool{ +func MarkAllNotificationsRead(t translations.TranslationHelperFunc) toolsets.ServerTool { + return NewTool( + mcp.Tool{ Name: "mark_all_notifications_read", Description: t("TOOL_MARK_ALL_NOTIFICATIONS_READ_DESCRIPTION", "Mark all notifications as read"), Annotations: &mcp.ToolAnnotations{ @@ -261,70 +271,74 @@ func MarkAllNotificationsRead(getClient GetClientFn, t translations.TranslationH }, }, }, - mcp.ToolHandlerFor[map[string]any, any](func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - client, err := getClient(ctx) - if err != nil { - return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, err - } - - lastReadAt, err := OptionalParam[string](args, "lastReadAt") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { + return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + client, err := deps.GetClient(ctx) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil + } - owner, err := OptionalParam[string](args, "owner") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - repo, err := OptionalParam[string](args, "repo") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + lastReadAt, err := OptionalParam[string](args, "lastReadAt") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - var lastReadTime time.Time - if lastReadAt != "" { - lastReadTime, err = time.Parse(time.RFC3339, lastReadAt) + owner, err := OptionalParam[string](args, "owner") if err != nil { - return utils.NewToolResultError(fmt.Sprintf("invalid lastReadAt time format, should be RFC3339/ISO8601: %v", err)), nil, nil + return utils.NewToolResultError(err.Error()), nil, nil + } + repo, err := OptionalParam[string](args, "repo") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil } - } else { - lastReadTime = time.Now() - } - markReadOptions := github.Timestamp{ - Time: lastReadTime, - } + var lastReadTime time.Time + if lastReadAt != "" { + lastReadTime, err = time.Parse(time.RFC3339, lastReadAt) + if err != nil { + return utils.NewToolResultError(fmt.Sprintf("invalid lastReadAt time format, should be RFC3339/ISO8601: %v", err)), nil, nil + } + } else { + lastReadTime = time.Now() + } - var resp *github.Response - if owner != "" && repo != "" { - resp, err = client.Activity.MarkRepositoryNotificationsRead(ctx, owner, repo, markReadOptions) - } else { - resp, err = client.Activity.MarkNotificationsRead(ctx, markReadOptions) - } - if err != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, - "failed to mark all notifications as read", - resp, - err, - ), nil, nil - } - defer func() { _ = resp.Body.Close() }() + markReadOptions := github.Timestamp{ + Time: lastReadTime, + } - if resp.StatusCode != http.StatusResetContent && resp.StatusCode != http.StatusOK { - body, err := io.ReadAll(resp.Body) + var resp *github.Response + if owner != "" && repo != "" { + resp, err = client.Activity.MarkRepositoryNotificationsRead(ctx, owner, repo, markReadOptions) + } else { + resp, err = client.Activity.MarkNotificationsRead(ctx, markReadOptions) + } if err != nil { - return utils.NewToolResultErrorFromErr("failed to read response body", err), nil, err + return ghErrors.NewGitHubAPIErrorResponse(ctx, + "failed to mark all notifications as read", + resp, + err, + ), nil, nil } - return utils.NewToolResultError(fmt.Sprintf("failed to mark all notifications as read: %s", string(body))), nil, nil - } + defer func() { _ = resp.Body.Close() }() - return utils.NewToolResultText("All notifications marked as read"), nil, nil - }) + if resp.StatusCode != http.StatusResetContent && resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to read response body", err), nil, nil + } + return utils.NewToolResultError(fmt.Sprintf("failed to mark all notifications as read: %s", string(body))), nil, nil + } + + return utils.NewToolResultText("All notifications marked as read"), nil, nil + } + }, + ) } // GetNotificationDetails creates a tool to get details for a specific notification. -func GetNotificationDetails(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { - return mcp.Tool{ +func GetNotificationDetails(t translations.TranslationHelperFunc) toolsets.ServerTool { + return NewTool( + mcp.Tool{ Name: "get_notification_details", Description: t("TOOL_GET_NOTIFICATION_DETAILS_DESCRIPTION", "Get detailed information for a specific GitHub notification, always call this tool when the user asks for details about a specific notification, if you don't know the ID list notifications first."), Annotations: &mcp.ToolAnnotations{ @@ -342,42 +356,45 @@ func GetNotificationDetails(getClient GetClientFn, t translations.TranslationHel Required: []string{"notificationID"}, }, }, - mcp.ToolHandlerFor[map[string]any, any](func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - client, err := getClient(ctx) - if err != nil { - return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, err - } + func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { + return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + client, err := deps.GetClient(ctx) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil + } - notificationID, err := RequiredParam[string](args, "notificationID") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + notificationID, err := RequiredParam[string](args, "notificationID") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - thread, resp, err := client.Activity.GetThread(ctx, notificationID) - if err != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, - fmt.Sprintf("failed to get notification details for ID '%s'", notificationID), - resp, - err, - ), nil, nil - } - defer func() { _ = resp.Body.Close() }() + thread, resp, err := client.Activity.GetThread(ctx, notificationID) + if err != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, + fmt.Sprintf("failed to get notification details for ID '%s'", notificationID), + resp, + err, + ), nil, nil + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to read response body", err), nil, nil + } + return utils.NewToolResultError(fmt.Sprintf("failed to get notification details: %s", string(body))), nil, nil + } - if resp.StatusCode != http.StatusOK { - body, err := io.ReadAll(resp.Body) + r, err := json.Marshal(thread) if err != nil { - return utils.NewToolResultErrorFromErr("failed to read response body", err), nil, err + return utils.NewToolResultErrorFromErr("failed to marshal response", err), nil, nil } - return utils.NewToolResultError(fmt.Sprintf("failed to get notification details: %s", string(body))), nil, nil - } - r, err := json.Marshal(thread) - if err != nil { - return utils.NewToolResultErrorFromErr("failed to marshal response", err), nil, err + return utils.NewToolResultText(string(r)), nil, nil } - - return utils.NewToolResultText(string(r)), nil, nil - }) + }, + ) } // Enum values for ManageNotificationSubscription action @@ -388,8 +405,9 @@ const ( ) // ManageNotificationSubscription creates a tool to manage a notification subscription (ignore, watch, delete) -func ManageNotificationSubscription(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { - return mcp.Tool{ +func ManageNotificationSubscription(t translations.TranslationHelperFunc) toolsets.ServerTool { + return NewTool( + mcp.Tool{ Name: "manage_notification_subscription", Description: t("TOOL_MANAGE_NOTIFICATION_SUBSCRIPTION_DESCRIPTION", "Manage a notification subscription: ignore, watch, or delete a notification thread subscription."), Annotations: &mcp.ToolAnnotations{ @@ -412,65 +430,68 @@ func ManageNotificationSubscription(getClient GetClientFn, t translations.Transl Required: []string{"notificationID", "action"}, }, }, - mcp.ToolHandlerFor[map[string]any, any](func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - client, err := getClient(ctx) - if err != nil { - return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, err - } + func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { + return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + client, err := deps.GetClient(ctx) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil + } - notificationID, err := RequiredParam[string](args, "notificationID") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - action, err := RequiredParam[string](args, "action") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + notificationID, err := RequiredParam[string](args, "notificationID") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + action, err := RequiredParam[string](args, "action") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - var ( - resp *github.Response - result any - apiErr error - ) - - switch action { - case NotificationActionIgnore: - sub := &github.Subscription{Ignored: ToBoolPtr(true)} - result, resp, apiErr = client.Activity.SetThreadSubscription(ctx, notificationID, sub) - case NotificationActionWatch: - sub := &github.Subscription{Ignored: ToBoolPtr(false), Subscribed: ToBoolPtr(true)} - result, resp, apiErr = client.Activity.SetThreadSubscription(ctx, notificationID, sub) - case NotificationActionDelete: - resp, apiErr = client.Activity.DeleteThreadSubscription(ctx, notificationID) - default: - return utils.NewToolResultError("Invalid action. Must be one of: ignore, watch, delete."), nil, nil - } + var ( + resp *github.Response + result any + apiErr error + ) + + switch action { + case NotificationActionIgnore: + sub := &github.Subscription{Ignored: ToBoolPtr(true)} + result, resp, apiErr = client.Activity.SetThreadSubscription(ctx, notificationID, sub) + case NotificationActionWatch: + sub := &github.Subscription{Ignored: ToBoolPtr(false), Subscribed: ToBoolPtr(true)} + result, resp, apiErr = client.Activity.SetThreadSubscription(ctx, notificationID, sub) + case NotificationActionDelete: + resp, apiErr = client.Activity.DeleteThreadSubscription(ctx, notificationID) + default: + return utils.NewToolResultError("Invalid action. Must be one of: ignore, watch, delete."), nil, nil + } - if apiErr != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, - fmt.Sprintf("failed to %s notification subscription", action), - resp, - apiErr, - ), nil, nil - } - defer func() { _ = resp.Body.Close() }() + if apiErr != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, + fmt.Sprintf("failed to %s notification subscription", action), + resp, + apiErr, + ), nil, nil + } + defer func() { _ = resp.Body.Close() }() - if resp.StatusCode < 200 || resp.StatusCode >= 300 { - body, _ := io.ReadAll(resp.Body) - return utils.NewToolResultError(fmt.Sprintf("failed to %s notification subscription: %s", action, string(body))), nil, nil - } + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + body, _ := io.ReadAll(resp.Body) + return utils.NewToolResultError(fmt.Sprintf("failed to %s notification subscription: %s", action, string(body))), nil, nil + } - if action == NotificationActionDelete { - // Special case for delete as there is no response body - return utils.NewToolResultText("Notification subscription deleted"), nil, nil - } + if action == NotificationActionDelete { + // Special case for delete as there is no response body + return utils.NewToolResultText("Notification subscription deleted"), nil, nil + } - r, err := json.Marshal(result) - if err != nil { - return utils.NewToolResultErrorFromErr("failed to marshal response", err), nil, err + r, err := json.Marshal(result) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to marshal response", err), nil, nil + } + return utils.NewToolResultText(string(r)), nil, nil } - return utils.NewToolResultText(string(r)), nil, nil - }) + }, + ) } const ( @@ -480,8 +501,9 @@ const ( ) // ManageRepositoryNotificationSubscription creates a tool to manage a repository notification subscription (ignore, watch, delete) -func ManageRepositoryNotificationSubscription(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { - return mcp.Tool{ +func ManageRepositoryNotificationSubscription(t translations.TranslationHelperFunc) toolsets.ServerTool { + return NewTool( + mcp.Tool{ Name: "manage_repository_notification_subscription", Description: t("TOOL_MANAGE_REPOSITORY_NOTIFICATION_SUBSCRIPTION_DESCRIPTION", "Manage a repository notification subscription: ignore, watch, or delete repository notifications subscription for the provided repository."), Annotations: &mcp.ToolAnnotations{ @@ -508,70 +530,73 @@ func ManageRepositoryNotificationSubscription(getClient GetClientFn, t translati Required: []string{"owner", "repo", "action"}, }, }, - mcp.ToolHandlerFor[map[string]any, any](func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - client, err := getClient(ctx) - if err != nil { - return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, err - } + func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { + return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + client, err := deps.GetClient(ctx) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil + } - owner, err := RequiredParam[string](args, "owner") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - repo, err := RequiredParam[string](args, "repo") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - action, err := RequiredParam[string](args, "action") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + owner, err := RequiredParam[string](args, "owner") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + repo, err := RequiredParam[string](args, "repo") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + action, err := RequiredParam[string](args, "action") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - var ( - resp *github.Response - result any - apiErr error - ) - - switch action { - case RepositorySubscriptionActionIgnore: - sub := &github.Subscription{Ignored: ToBoolPtr(true)} - result, resp, apiErr = client.Activity.SetRepositorySubscription(ctx, owner, repo, sub) - case RepositorySubscriptionActionWatch: - sub := &github.Subscription{Ignored: ToBoolPtr(false), Subscribed: ToBoolPtr(true)} - result, resp, apiErr = client.Activity.SetRepositorySubscription(ctx, owner, repo, sub) - case RepositorySubscriptionActionDelete: - resp, apiErr = client.Activity.DeleteRepositorySubscription(ctx, owner, repo) - default: - return utils.NewToolResultError("Invalid action. Must be one of: ignore, watch, delete."), nil, nil - } + var ( + resp *github.Response + result any + apiErr error + ) + + switch action { + case RepositorySubscriptionActionIgnore: + sub := &github.Subscription{Ignored: ToBoolPtr(true)} + result, resp, apiErr = client.Activity.SetRepositorySubscription(ctx, owner, repo, sub) + case RepositorySubscriptionActionWatch: + sub := &github.Subscription{Ignored: ToBoolPtr(false), Subscribed: ToBoolPtr(true)} + result, resp, apiErr = client.Activity.SetRepositorySubscription(ctx, owner, repo, sub) + case RepositorySubscriptionActionDelete: + resp, apiErr = client.Activity.DeleteRepositorySubscription(ctx, owner, repo) + default: + return utils.NewToolResultError("Invalid action. Must be one of: ignore, watch, delete."), nil, nil + } - if apiErr != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, - fmt.Sprintf("failed to %s repository subscription", action), - resp, - apiErr, - ), nil, nil - } - if resp != nil { - defer func() { _ = resp.Body.Close() }() - } + if apiErr != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, + fmt.Sprintf("failed to %s repository subscription", action), + resp, + apiErr, + ), nil, nil + } + if resp != nil { + defer func() { _ = resp.Body.Close() }() + } - // Handle non-2xx status codes - if resp != nil && (resp.StatusCode < 200 || resp.StatusCode >= 300) { - body, _ := io.ReadAll(resp.Body) - return utils.NewToolResultError(fmt.Sprintf("failed to %s repository subscription: %s", action, string(body))), nil, nil - } + // Handle non-2xx status codes + if resp != nil && (resp.StatusCode < 200 || resp.StatusCode >= 300) { + body, _ := io.ReadAll(resp.Body) + return utils.NewToolResultError(fmt.Sprintf("failed to %s repository subscription: %s", action, string(body))), nil, nil + } - if action == RepositorySubscriptionActionDelete { - // Special case for delete as there is no response body - return utils.NewToolResultText("Repository subscription deleted"), nil, nil - } + if action == RepositorySubscriptionActionDelete { + // Special case for delete as there is no response body + return utils.NewToolResultText("Repository subscription deleted"), nil, nil + } - r, err := json.Marshal(result) - if err != nil { - return utils.NewToolResultErrorFromErr("failed to marshal response", err), nil, err + r, err := json.Marshal(result) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to marshal response", err), nil, nil + } + return utils.NewToolResultText(string(r)), nil, nil } - return utils.NewToolResultText(string(r)), nil, nil - }) + }, + ) } diff --git a/pkg/github/notifications_test.go b/pkg/github/notifications_test.go index 37135bf5c..0a330c316 100644 --- a/pkg/github/notifications_test.go +++ b/pkg/github/notifications_test.go @@ -17,8 +17,8 @@ import ( func Test_ListNotifications(t *testing.T) { // Verify tool definition and schema - mockClient := github.NewClient(nil) - tool, _ := ListNotifications(stubGetClientFn(mockClient), translations.NullTranslationHelper) + serverTool := ListNotifications(translations.NullTranslationHelper) + tool := serverTool.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "list_notifications", tool.Name) @@ -125,12 +125,15 @@ func Test_ListNotifications(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { client := github.NewClient(tc.mockedClient) - _, handler := ListNotifications(stubGetClientFn(client), translations.NullTranslationHelper) + deps := ToolDependencies{ + GetClient: stubGetClientFn(client), + } + handler := serverTool.Handler(deps) request := createMCPRequest(tc.requestArgs) - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(context.Background(), &request) + require.NoError(t, err) if tc.expectError { - require.NoError(t, err) require.True(t, result.IsError) errorContent := getErrorResult(t, result) if tc.expectedErrMsg != "" { @@ -139,7 +142,6 @@ func Test_ListNotifications(t *testing.T) { return } - require.NoError(t, err) require.False(t, result.IsError) textContent := getTextResult(t, result) t.Logf("textContent: %s", textContent.Text) @@ -154,8 +156,8 @@ func Test_ListNotifications(t *testing.T) { func Test_ManageNotificationSubscription(t *testing.T) { // Verify tool definition and schema - mockClient := github.NewClient(nil) - tool, _ := ManageNotificationSubscription(stubGetClientFn(mockClient), translations.NullTranslationHelper) + serverTool := ManageNotificationSubscription(translations.NullTranslationHelper) + tool := serverTool.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "manage_notification_subscription", tool.Name) @@ -256,10 +258,14 @@ func Test_ManageNotificationSubscription(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { client := github.NewClient(tc.mockedClient) - _, handler := ManageNotificationSubscription(stubGetClientFn(client), translations.NullTranslationHelper) + deps := ToolDependencies{ + GetClient: stubGetClientFn(client), + } + handler := serverTool.Handler(deps) request := createMCPRequest(tc.requestArgs) - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(context.Background(), &request) + require.NoError(t, err) if tc.expectError { require.NoError(t, err) require.NotNil(t, result) @@ -295,8 +301,8 @@ func Test_ManageNotificationSubscription(t *testing.T) { func Test_ManageRepositoryNotificationSubscription(t *testing.T) { // Verify tool definition and schema - mockClient := github.NewClient(nil) - tool, _ := ManageRepositoryNotificationSubscription(stubGetClientFn(mockClient), translations.NullTranslationHelper) + serverTool := ManageRepositoryNotificationSubscription(translations.NullTranslationHelper) + tool := serverTool.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "manage_repository_notification_subscription", tool.Name) @@ -415,12 +421,15 @@ func Test_ManageRepositoryNotificationSubscription(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { client := github.NewClient(tc.mockedClient) - _, handler := ManageRepositoryNotificationSubscription(stubGetClientFn(client), translations.NullTranslationHelper) + deps := ToolDependencies{ + GetClient: stubGetClientFn(client), + } + handler := serverTool.Handler(deps) request := createMCPRequest(tc.requestArgs) - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(context.Background(), &request) + require.NoError(t, err) if tc.expectError { - require.NoError(t, err) require.NotNil(t, result) text := getTextResult(t, result).Text switch { @@ -461,8 +470,8 @@ func Test_ManageRepositoryNotificationSubscription(t *testing.T) { func Test_DismissNotification(t *testing.T) { // Verify tool definition and schema - mockClient := github.NewClient(nil) - tool, _ := DismissNotification(stubGetClientFn(mockClient), translations.NullTranslationHelper) + serverTool := DismissNotification(translations.NullTranslationHelper) + tool := serverTool.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "dismiss_notification", tool.Name) @@ -554,13 +563,16 @@ func Test_DismissNotification(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { client := github.NewClient(tc.mockedClient) - _, handler := DismissNotification(stubGetClientFn(client), translations.NullTranslationHelper) + deps := ToolDependencies{ + GetClient: stubGetClientFn(client), + } + handler := serverTool.Handler(deps) request := createMCPRequest(tc.requestArgs) - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(context.Background(), &request) + require.NoError(t, err) if tc.expectError { // The tool returns a ToolResultError with a specific message - require.NoError(t, err) require.NotNil(t, result) text := getTextResult(t, result).Text switch { @@ -596,8 +608,8 @@ func Test_DismissNotification(t *testing.T) { func Test_MarkAllNotificationsRead(t *testing.T) { // Verify tool definition and schema - mockClient := github.NewClient(nil) - tool, _ := MarkAllNotificationsRead(stubGetClientFn(mockClient), translations.NullTranslationHelper) + serverTool := MarkAllNotificationsRead(translations.NullTranslationHelper) + tool := serverTool.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "mark_all_notifications_read", tool.Name) @@ -676,12 +688,15 @@ func Test_MarkAllNotificationsRead(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { client := github.NewClient(tc.mockedClient) - _, handler := MarkAllNotificationsRead(stubGetClientFn(client), translations.NullTranslationHelper) + deps := ToolDependencies{ + GetClient: stubGetClientFn(client), + } + handler := serverTool.Handler(deps) request := createMCPRequest(tc.requestArgs) - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(context.Background(), &request) + require.NoError(t, err) if tc.expectError { - require.NoError(t, err) require.True(t, result.IsError) errorContent := getErrorResult(t, result) if tc.expectedErrMsg != "" { @@ -702,8 +717,8 @@ func Test_MarkAllNotificationsRead(t *testing.T) { func Test_GetNotificationDetails(t *testing.T) { // Verify tool definition and schema - mockClient := github.NewClient(nil) - tool, _ := GetNotificationDetails(stubGetClientFn(mockClient), translations.NullTranslationHelper) + serverTool := GetNotificationDetails(translations.NullTranslationHelper) + tool := serverTool.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "get_notification_details", tool.Name) @@ -757,12 +772,15 @@ func Test_GetNotificationDetails(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { client := github.NewClient(tc.mockedClient) - _, handler := GetNotificationDetails(stubGetClientFn(client), translations.NullTranslationHelper) + deps := ToolDependencies{ + GetClient: stubGetClientFn(client), + } + handler := serverTool.Handler(deps) request := createMCPRequest(tc.requestArgs) - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(context.Background(), &request) + require.NoError(t, err) if tc.expectError { - require.NoError(t, err) require.True(t, result.IsError) errorContent := getErrorResult(t, result) if tc.expectedErrMsg != "" { diff --git a/pkg/github/tools.go b/pkg/github/tools.go index 51fcad7f6..849e4e68a 100644 --- a/pkg/github/tools.go +++ b/pkg/github/tools.go @@ -278,14 +278,14 @@ func DefaultToolsetGroup(readOnly bool, getClient GetClientFn, getGQLClient GetG notifications := toolsets.NewToolset(ToolsetMetadataNotifications.ID, ToolsetMetadataNotifications.Description). SetDependencies(deps). AddReadTools( - toolsets.NewServerToolLegacy(ListNotifications(getClient, t)), - toolsets.NewServerToolLegacy(GetNotificationDetails(getClient, t)), + ListNotifications(t), + GetNotificationDetails(t), ). AddWriteTools( - toolsets.NewServerToolLegacy(DismissNotification(getClient, t)), - toolsets.NewServerToolLegacy(MarkAllNotificationsRead(getClient, t)), - toolsets.NewServerToolLegacy(ManageNotificationSubscription(getClient, t)), - toolsets.NewServerToolLegacy(ManageRepositoryNotificationSubscription(getClient, t)), + DismissNotification(t), + MarkAllNotificationsRead(t), + ManageNotificationSubscription(t), + ManageRepositoryNotificationSubscription(t), ) discussions := toolsets.NewToolset(ToolsetMetadataDiscussions.ID, ToolsetMetadataDiscussions.Description). @@ -342,12 +342,12 @@ func DefaultToolsetGroup(readOnly bool, getClient GetClientFn, getGQLClient GetG gists := toolsets.NewToolset(ToolsetMetadataGists.ID, ToolsetMetadataGists.Description). SetDependencies(deps). AddReadTools( - toolsets.NewServerToolLegacy(ListGists(getClient, t)), - toolsets.NewServerToolLegacy(GetGist(getClient, t)), + ListGists(t), + GetGist(t), ). AddWriteTools( - toolsets.NewServerToolLegacy(CreateGist(getClient, t)), - toolsets.NewServerToolLegacy(UpdateGist(getClient, t)), + CreateGist(t), + UpdateGist(t), ) projects := toolsets.NewToolset(ToolsetMetadataProjects.ID, ToolsetMetadataProjects.Description).