diff --git a/pkg/github/code_scanning.go b/pkg/github/code_scanning.go index 0f8e2780b..518855a59 100644 --- a/pkg/github/code_scanning.go +++ b/pkg/github/code_scanning.go @@ -8,6 +8,7 @@ import ( "net/http" ghErrors "github.com/github/github-mcp-server/pkg/errors" + "github.com/github/github-mcp-server/pkg/toolsets" "github.com/github/github-mcp-server/pkg/translations" "github.com/github/github-mcp-server/pkg/utils" "github.com/google/go-github/v79/github" @@ -15,8 +16,9 @@ import ( "github.com/modelcontextprotocol/go-sdk/mcp" ) -func GetCodeScanningAlert(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { - return mcp.Tool{ +func GetCodeScanningAlert(t translations.TranslationHelperFunc) toolsets.ServerTool { + return NewTool( + mcp.Tool{ Name: "get_code_scanning_alert", Description: t("TOOL_GET_CODE_SCANNING_ALERT_DESCRIPTION", "Get details of a specific code scanning alert in a GitHub repository."), Annotations: &mcp.ToolAnnotations{ @@ -42,54 +44,58 @@ func GetCodeScanningAlert(getClient GetClientFn, t translations.TranslationHelpe Required: []string{"owner", "repo", "alertNumber"}, }, }, - func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - owner, err := RequiredParam[string](args, "owner") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - repo, err := RequiredParam[string](args, "repo") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - alertNumber, err := RequiredInt(args, "alertNumber") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { + return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + owner, err := RequiredParam[string](args, "owner") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + repo, err := RequiredParam[string](args, "repo") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + alertNumber, err := RequiredInt(args, "alertNumber") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - client, err := getClient(ctx) - if err != nil { - return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil - } + client, err := deps.GetClient(ctx) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil + } - alert, resp, err := client.CodeScanning.GetAlert(ctx, owner, repo, int64(alertNumber)) - if err != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, - "failed to get alert", - resp, - err, - ), nil, nil - } - defer func() { _ = resp.Body.Close() }() + alert, resp, err := client.CodeScanning.GetAlert(ctx, owner, repo, int64(alertNumber)) + if err != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, + "failed to get alert", + resp, + err, + ), nil, nil + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to read response body", err), nil, nil + } + return utils.NewToolResultError(fmt.Sprintf("failed to get alert: %s", string(body))), nil, nil + } - if resp.StatusCode != http.StatusOK { - body, err := io.ReadAll(resp.Body) + r, err := json.Marshal(alert) if err != nil { - return utils.NewToolResultErrorFromErr("failed to read response body", err), nil, nil + return utils.NewToolResultErrorFromErr("failed to marshal alert", err), nil, nil } - return utils.NewToolResultError(fmt.Sprintf("failed to get alert: %s", string(body))), nil, nil - } - r, err := json.Marshal(alert) - if err != nil { - return utils.NewToolResultErrorFromErr("failed to marshal alert", err), nil, nil + return utils.NewToolResultText(string(r)), nil, nil } - - return utils.NewToolResultText(string(r)), nil, nil - } + }, + ) } -func ListCodeScanningAlerts(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { - return mcp.Tool{ +func ListCodeScanningAlerts(t translations.TranslationHelperFunc) toolsets.ServerTool { + return NewTool( + mcp.Tool{ Name: "list_code_scanning_alerts", Description: t("TOOL_LIST_CODE_SCANNING_ALERTS_DESCRIPTION", "List code scanning alerts in a GitHub repository."), Annotations: &mcp.ToolAnnotations{ @@ -130,59 +136,62 @@ func ListCodeScanningAlerts(getClient GetClientFn, t translations.TranslationHel Required: []string{"owner", "repo"}, }, }, - func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - owner, err := RequiredParam[string](args, "owner") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - repo, err := RequiredParam[string](args, "repo") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - ref, err := OptionalParam[string](args, "ref") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - state, err := OptionalParam[string](args, "state") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - severity, err := OptionalParam[string](args, "severity") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - toolName, err := OptionalParam[string](args, "tool_name") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { + return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + owner, err := RequiredParam[string](args, "owner") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + repo, err := RequiredParam[string](args, "repo") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + ref, err := OptionalParam[string](args, "ref") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + state, err := OptionalParam[string](args, "state") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + severity, err := OptionalParam[string](args, "severity") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + toolName, err := OptionalParam[string](args, "tool_name") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - client, err := getClient(ctx) - if err != nil { - return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil - } - alerts, resp, err := client.CodeScanning.ListAlertsForRepo(ctx, owner, repo, &github.AlertListOptions{Ref: ref, State: state, Severity: severity, ToolName: toolName}) - if err != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, - "failed to list alerts", - resp, - err, - ), nil, nil - } - defer func() { _ = resp.Body.Close() }() + client, err := deps.GetClient(ctx) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil + } + alerts, resp, err := client.CodeScanning.ListAlertsForRepo(ctx, owner, repo, &github.AlertListOptions{Ref: ref, State: state, Severity: severity, ToolName: toolName}) + if err != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, + "failed to list alerts", + resp, + err, + ), nil, nil + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to read response body", err), nil, nil + } + return utils.NewToolResultError(fmt.Sprintf("failed to list alerts: %s", string(body))), nil, nil + } - if resp.StatusCode != http.StatusOK { - body, err := io.ReadAll(resp.Body) + r, err := json.Marshal(alerts) if err != nil { - return utils.NewToolResultErrorFromErr("failed to read response body", err), nil, nil + return utils.NewToolResultErrorFromErr("failed to marshal alerts", err), nil, nil } - return utils.NewToolResultError(fmt.Sprintf("failed to list alerts: %s", string(body))), nil, nil - } - r, err := json.Marshal(alerts) - if err != nil { - return utils.NewToolResultErrorFromErr("failed to marshal alerts", err), nil, nil + return utils.NewToolResultText(string(r)), nil, nil } - - return utils.NewToolResultText(string(r)), nil, nil - } + }, + ) } diff --git a/pkg/github/code_scanning_test.go b/pkg/github/code_scanning_test.go index 13e89fc30..5e56e6788 100644 --- a/pkg/github/code_scanning_test.go +++ b/pkg/github/code_scanning_test.go @@ -17,15 +17,14 @@ import ( func Test_GetCodeScanningAlert(t *testing.T) { // Verify tool definition once - mockClient := github.NewClient(nil) - tool, _ := GetCodeScanningAlert(stubGetClientFn(mockClient), translations.NullTranslationHelper) - require.NoError(t, toolsnaps.Test(tool.Name, tool)) + toolDef := GetCodeScanningAlert(translations.NullTranslationHelper) + require.NoError(t, toolsnaps.Test(toolDef.Tool.Name, toolDef.Tool)) - assert.Equal(t, "get_code_scanning_alert", tool.Name) - assert.NotEmpty(t, tool.Description) + assert.Equal(t, "get_code_scanning_alert", toolDef.Tool.Name) + assert.NotEmpty(t, toolDef.Tool.Description) // InputSchema is of type any, need to cast to *jsonschema.Schema - schema, ok := tool.InputSchema.(*jsonschema.Schema) + schema, ok := toolDef.Tool.InputSchema.(*jsonschema.Schema) require.True(t, ok, "InputSchema should be *jsonschema.Schema") assert.Contains(t, schema.Properties, "owner") assert.Contains(t, schema.Properties, "repo") @@ -89,13 +88,16 @@ func Test_GetCodeScanningAlert(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := GetCodeScanningAlert(stubGetClientFn(client), translations.NullTranslationHelper) + deps := ToolDependencies{ + GetClient: stubGetClientFn(client), + } + handler := toolDef.Handler(deps) // Create call request request := createMCPRequest(tc.requestArgs) // Call handler with new signature - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(context.Background(), &request) // Verify results if tc.expectError { @@ -127,15 +129,14 @@ func Test_GetCodeScanningAlert(t *testing.T) { func Test_ListCodeScanningAlerts(t *testing.T) { // Verify tool definition once - mockClient := github.NewClient(nil) - tool, _ := ListCodeScanningAlerts(stubGetClientFn(mockClient), translations.NullTranslationHelper) - require.NoError(t, toolsnaps.Test(tool.Name, tool)) + toolDef := ListCodeScanningAlerts(translations.NullTranslationHelper) + require.NoError(t, toolsnaps.Test(toolDef.Tool.Name, toolDef.Tool)) - assert.Equal(t, "list_code_scanning_alerts", tool.Name) - assert.NotEmpty(t, tool.Description) + assert.Equal(t, "list_code_scanning_alerts", toolDef.Tool.Name) + assert.NotEmpty(t, toolDef.Tool.Description) // InputSchema is of type any, need to cast to *jsonschema.Schema - schema, ok := tool.InputSchema.(*jsonschema.Schema) + schema, ok := toolDef.Tool.InputSchema.(*jsonschema.Schema) require.True(t, ok, "InputSchema should be *jsonschema.Schema") assert.Contains(t, schema.Properties, "owner") assert.Contains(t, schema.Properties, "repo") @@ -219,13 +220,16 @@ func Test_ListCodeScanningAlerts(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := ListCodeScanningAlerts(stubGetClientFn(client), translations.NullTranslationHelper) + deps := ToolDependencies{ + GetClient: stubGetClientFn(client), + } + handler := toolDef.Handler(deps) // Create call request request := createMCPRequest(tc.requestArgs) // Call handler with new signature - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(context.Background(), &request) // Verify results if tc.expectError { diff --git a/pkg/github/dependabot.go b/pkg/github/dependabot.go index 351cbdb37..b80fd0aa0 100644 --- a/pkg/github/dependabot.go +++ b/pkg/github/dependabot.go @@ -8,6 +8,7 @@ import ( "net/http" ghErrors "github.com/github/github-mcp-server/pkg/errors" + "github.com/github/github-mcp-server/pkg/toolsets" "github.com/github/github-mcp-server/pkg/translations" "github.com/github/github-mcp-server/pkg/utils" "github.com/google/go-github/v79/github" @@ -15,168 +16,170 @@ import ( "github.com/modelcontextprotocol/go-sdk/mcp" ) -func GetDependabotAlert(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { - tool := mcp.Tool{ - Name: "get_dependabot_alert", - Description: t("TOOL_GET_DEPENDABOT_ALERT_DESCRIPTION", "Get details of a specific dependabot alert in a GitHub repository."), - Annotations: &mcp.ToolAnnotations{ - Title: t("TOOL_GET_DEPENDABOT_ALERT_USER_TITLE", "Get dependabot alert"), - ReadOnlyHint: true, - }, - InputSchema: &jsonschema.Schema{ - Type: "object", - Properties: map[string]*jsonschema.Schema{ - "owner": { - Type: "string", - Description: "The owner of the repository.", - }, - "repo": { - Type: "string", - Description: "The name of the repository.", - }, - "alertNumber": { - Type: "number", - Description: "The number of the alert.", +func GetDependabotAlert(t translations.TranslationHelperFunc) toolsets.ServerTool { + return NewTool( + mcp.Tool{ + Name: "get_dependabot_alert", + Description: t("TOOL_GET_DEPENDABOT_ALERT_DESCRIPTION", "Get details of a specific dependabot alert in a GitHub repository."), + Annotations: &mcp.ToolAnnotations{ + Title: t("TOOL_GET_DEPENDABOT_ALERT_USER_TITLE", "Get dependabot alert"), + ReadOnlyHint: true, + }, + InputSchema: &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "owner": { + Type: "string", + Description: "The owner of the repository.", + }, + "repo": { + Type: "string", + Description: "The name of the repository.", + }, + "alertNumber": { + Type: "number", + Description: "The number of the alert.", + }, }, + Required: []string{"owner", "repo", "alertNumber"}, }, - Required: []string{"owner", "repo", "alertNumber"}, }, - } - - handler := mcp.ToolHandlerFor[map[string]any, any](func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - owner, err := RequiredParam[string](args, "owner") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - repo, err := RequiredParam[string](args, "repo") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - alertNumber, err := RequiredInt(args, "alertNumber") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - - client, err := getClient(ctx) - if err != nil { - return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, err - } - - alert, resp, err := client.Dependabot.GetRepoAlert(ctx, owner, repo, alertNumber) - if err != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, - fmt.Sprintf("failed to get alert with number '%d'", alertNumber), - resp, - err, - ), nil, nil - } - defer func() { _ = resp.Body.Close() }() - - if resp.StatusCode != http.StatusOK { - body, err := io.ReadAll(resp.Body) - if err != nil { - return utils.NewToolResultErrorFromErr("failed to read response body", err), nil, err + func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { + return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + owner, err := RequiredParam[string](args, "owner") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + repo, err := RequiredParam[string](args, "repo") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + alertNumber, err := RequiredInt(args, "alertNumber") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + + client, err := deps.GetClient(ctx) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, err + } + + alert, resp, err := client.Dependabot.GetRepoAlert(ctx, owner, repo, alertNumber) + if err != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, + fmt.Sprintf("failed to get alert with number '%d'", alertNumber), + resp, + err, + ), nil, nil + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to read response body", err), nil, err + } + return utils.NewToolResultError(fmt.Sprintf("failed to get alert: %s", string(body))), nil, nil + } + + r, err := json.Marshal(alert) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to marshal alert", err), nil, err + } + + return utils.NewToolResultText(string(r)), nil, nil } - return utils.NewToolResultError(fmt.Sprintf("failed to get alert: %s", string(body))), nil, nil - } - - r, err := json.Marshal(alert) - if err != nil { - return utils.NewToolResultErrorFromErr("failed to marshal alert", err), nil, err - } - - return utils.NewToolResultText(string(r)), nil, nil - }) - - return tool, handler + }, + ) } -func ListDependabotAlerts(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { - tool := mcp.Tool{ - Name: "list_dependabot_alerts", - Description: t("TOOL_LIST_DEPENDABOT_ALERTS_DESCRIPTION", "List dependabot alerts in a GitHub repository."), - Annotations: &mcp.ToolAnnotations{ - Title: t("TOOL_LIST_DEPENDABOT_ALERTS_USER_TITLE", "List dependabot alerts"), - ReadOnlyHint: true, - }, - InputSchema: &jsonschema.Schema{ - Type: "object", - Properties: map[string]*jsonschema.Schema{ - "owner": { - Type: "string", - Description: "The owner of the repository.", - }, - "repo": { - Type: "string", - Description: "The name of the repository.", - }, - "state": { - Type: "string", - Description: "Filter dependabot alerts by state. Defaults to open", - Enum: []any{"open", "fixed", "dismissed", "auto_dismissed"}, - Default: json.RawMessage(`"open"`), - }, - "severity": { - Type: "string", - Description: "Filter dependabot alerts by severity", - Enum: []any{"low", "medium", "high", "critical"}, +func ListDependabotAlerts(t translations.TranslationHelperFunc) toolsets.ServerTool { + return NewTool( + mcp.Tool{ + Name: "list_dependabot_alerts", + Description: t("TOOL_LIST_DEPENDABOT_ALERTS_DESCRIPTION", "List dependabot alerts in a GitHub repository."), + Annotations: &mcp.ToolAnnotations{ + Title: t("TOOL_LIST_DEPENDABOT_ALERTS_USER_TITLE", "List dependabot alerts"), + ReadOnlyHint: true, + }, + InputSchema: &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "owner": { + Type: "string", + Description: "The owner of the repository.", + }, + "repo": { + Type: "string", + Description: "The name of the repository.", + }, + "state": { + Type: "string", + Description: "Filter dependabot alerts by state. Defaults to open", + Enum: []any{"open", "fixed", "dismissed", "auto_dismissed"}, + Default: json.RawMessage(`"open"`), + }, + "severity": { + Type: "string", + Description: "Filter dependabot alerts by severity", + Enum: []any{"low", "medium", "high", "critical"}, + }, }, + Required: []string{"owner", "repo"}, }, - Required: []string{"owner", "repo"}, }, - } - - handler := mcp.ToolHandlerFor[map[string]any, any](func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - owner, err := RequiredParam[string](args, "owner") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - repo, err := RequiredParam[string](args, "repo") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - state, err := OptionalParam[string](args, "state") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - severity, err := OptionalParam[string](args, "severity") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - - client, err := getClient(ctx) - if err != nil { - return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, err - } - - alerts, resp, err := client.Dependabot.ListRepoAlerts(ctx, owner, repo, &github.ListAlertsOptions{ - State: ToStringPtr(state), - Severity: ToStringPtr(severity), - }) - if err != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, - fmt.Sprintf("failed to list alerts for repository '%s/%s'", owner, repo), - resp, - err, - ), nil, nil - } - defer func() { _ = resp.Body.Close() }() - - if resp.StatusCode != http.StatusOK { - body, err := io.ReadAll(resp.Body) - if err != nil { - return utils.NewToolResultErrorFromErr("failed to read response body", err), nil, err + func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { + return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + owner, err := RequiredParam[string](args, "owner") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + repo, err := RequiredParam[string](args, "repo") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + state, err := OptionalParam[string](args, "state") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + severity, err := OptionalParam[string](args, "severity") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + + client, err := deps.GetClient(ctx) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, err + } + + alerts, resp, err := client.Dependabot.ListRepoAlerts(ctx, owner, repo, &github.ListAlertsOptions{ + State: ToStringPtr(state), + Severity: ToStringPtr(severity), + }) + if err != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, + fmt.Sprintf("failed to list alerts for repository '%s/%s'", owner, repo), + resp, + err, + ), nil, nil + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to read response body", err), nil, err + } + return utils.NewToolResultError(fmt.Sprintf("failed to list alerts: %s", string(body))), nil, nil + } + + r, err := json.Marshal(alerts) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to marshal alerts", err), nil, err + } + + return utils.NewToolResultText(string(r)), nil, nil } - return utils.NewToolResultError(fmt.Sprintf("failed to list alerts: %s", string(body))), nil, nil - } - - r, err := json.Marshal(alerts) - if err != nil { - return utils.NewToolResultErrorFromErr("failed to marshal alerts", err), nil, err - } - - return utils.NewToolResultText(string(r)), nil, nil - }) - - return tool, handler + }, + ) } diff --git a/pkg/github/dependabot_test.go b/pkg/github/dependabot_test.go index 24e5130e9..ace0eb07a 100644 --- a/pkg/github/dependabot_test.go +++ b/pkg/github/dependabot_test.go @@ -16,8 +16,8 @@ import ( func Test_GetDependabotAlert(t *testing.T) { // Verify tool definition - mockClient := github.NewClient(nil) - tool, _ := GetDependabotAlert(stubGetClientFn(mockClient), translations.NullTranslationHelper) + toolDef := GetDependabotAlert(translations.NullTranslationHelper) + tool := toolDef.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) // Validate tool schema @@ -81,13 +81,14 @@ func Test_GetDependabotAlert(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := GetDependabotAlert(stubGetClientFn(client), translations.NullTranslationHelper) + deps := ToolDependencies{GetClient: stubGetClientFn(client)} + handler := toolDef.Handler(deps) // Create call request request := createMCPRequest(tc.requestArgs) // Call handler - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(context.Background(), &request) // Verify results if tc.expectError { @@ -117,8 +118,8 @@ func Test_GetDependabotAlert(t *testing.T) { func Test_ListDependabotAlerts(t *testing.T) { // Verify tool definition once - mockClient := github.NewClient(nil) - tool, _ := ListDependabotAlerts(stubGetClientFn(mockClient), translations.NullTranslationHelper) + toolDef := ListDependabotAlerts(translations.NullTranslationHelper) + tool := toolDef.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "list_dependabot_alerts", tool.Name) @@ -231,11 +232,12 @@ func Test_ListDependabotAlerts(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { client := github.NewClient(tc.mockedClient) - _, handler := ListDependabotAlerts(stubGetClientFn(client), translations.NullTranslationHelper) + deps := ToolDependencies{GetClient: stubGetClientFn(client)} + handler := toolDef.Handler(deps) request := createMCPRequest(tc.requestArgs) - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(context.Background(), &request) if tc.expectError { require.NoError(t, err) diff --git a/pkg/github/secret_scanning.go b/pkg/github/secret_scanning.go index 297e1ebfe..fa5618791 100644 --- a/pkg/github/secret_scanning.go +++ b/pkg/github/secret_scanning.go @@ -8,6 +8,7 @@ import ( "net/http" ghErrors "github.com/github/github-mcp-server/pkg/errors" + "github.com/github/github-mcp-server/pkg/toolsets" "github.com/github/github-mcp-server/pkg/translations" "github.com/github/github-mcp-server/pkg/utils" "github.com/google/go-github/v79/github" @@ -15,8 +16,9 @@ import ( "github.com/modelcontextprotocol/go-sdk/mcp" ) -func GetSecretScanningAlert(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { - return mcp.Tool{ +func GetSecretScanningAlert(t translations.TranslationHelperFunc) toolsets.ServerTool { + return NewTool( + mcp.Tool{ Name: "get_secret_scanning_alert", Description: t("TOOL_GET_SECRET_SCANNING_ALERT_DESCRIPTION", "Get details of a specific secret scanning alert in a GitHub repository."), Annotations: &mcp.ToolAnnotations{ @@ -42,54 +44,58 @@ func GetSecretScanningAlert(getClient GetClientFn, t translations.TranslationHel Required: []string{"owner", "repo", "alertNumber"}, }, }, - func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - owner, err := RequiredParam[string](args, "owner") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - repo, err := RequiredParam[string](args, "repo") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - alertNumber, err := RequiredInt(args, "alertNumber") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { + return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + owner, err := RequiredParam[string](args, "owner") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + repo, err := RequiredParam[string](args, "repo") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + alertNumber, err := RequiredInt(args, "alertNumber") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - client, err := getClient(ctx) - if err != nil { - return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) - } + client, err := deps.GetClient(ctx) + if err != nil { + return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) + } - alert, resp, err := client.SecretScanning.GetAlert(ctx, owner, repo, int64(alertNumber)) - if err != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, - fmt.Sprintf("failed to get alert with number '%d'", alertNumber), - resp, - err, - ), nil, nil - } - defer func() { _ = resp.Body.Close() }() + alert, resp, err := client.SecretScanning.GetAlert(ctx, owner, repo, int64(alertNumber)) + if err != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, + fmt.Sprintf("failed to get alert with number '%d'", alertNumber), + resp, + err, + ), nil, nil + } + defer func() { _ = resp.Body.Close() }() - if resp.StatusCode != http.StatusOK { - body, err := io.ReadAll(resp.Body) + if resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, nil, fmt.Errorf("failed to read response body: %w", err) + } + return utils.NewToolResultError(fmt.Sprintf("failed to get alert: %s", string(body))), nil, nil + } + + r, err := json.Marshal(alert) if err != nil { - return nil, nil, fmt.Errorf("failed to read response body: %w", err) + return nil, nil, fmt.Errorf("failed to marshal alert: %w", err) } - return utils.NewToolResultError(fmt.Sprintf("failed to get alert: %s", string(body))), nil, nil - } - r, err := json.Marshal(alert) - if err != nil { - return nil, nil, fmt.Errorf("failed to marshal alert: %w", err) + return utils.NewToolResultText(string(r)), nil, nil } - - return utils.NewToolResultText(string(r)), nil, nil - } + }, + ) } -func ListSecretScanningAlerts(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { - return mcp.Tool{ +func ListSecretScanningAlerts(t translations.TranslationHelperFunc) toolsets.ServerTool { + return NewTool( + mcp.Tool{ Name: "list_secret_scanning_alerts", Description: t("TOOL_LIST_SECRET_SCANNING_ALERTS_DESCRIPTION", "List secret scanning alerts in a GitHub repository."), Annotations: &mcp.ToolAnnotations{ @@ -125,55 +131,58 @@ func ListSecretScanningAlerts(getClient GetClientFn, t translations.TranslationH Required: []string{"owner", "repo"}, }, }, - func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - owner, err := RequiredParam[string](args, "owner") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - repo, err := RequiredParam[string](args, "repo") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - state, err := OptionalParam[string](args, "state") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - secretType, err := OptionalParam[string](args, "secret_type") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - resolution, err := OptionalParam[string](args, "resolution") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { + return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + owner, err := RequiredParam[string](args, "owner") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + repo, err := RequiredParam[string](args, "repo") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + state, err := OptionalParam[string](args, "state") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + secretType, err := OptionalParam[string](args, "secret_type") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + resolution, err := OptionalParam[string](args, "resolution") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - client, err := getClient(ctx) - if err != nil { - return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) - } - alerts, resp, err := client.SecretScanning.ListAlertsForRepo(ctx, owner, repo, &github.SecretScanningAlertListOptions{State: state, SecretType: secretType, Resolution: resolution}) - if err != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, - fmt.Sprintf("failed to list alerts for repository '%s/%s'", owner, repo), - resp, - err, - ), nil, nil - } - defer func() { _ = resp.Body.Close() }() + client, err := deps.GetClient(ctx) + if err != nil { + return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) + } + alerts, resp, err := client.SecretScanning.ListAlertsForRepo(ctx, owner, repo, &github.SecretScanningAlertListOptions{State: state, SecretType: secretType, Resolution: resolution}) + if err != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, + fmt.Sprintf("failed to list alerts for repository '%s/%s'", owner, repo), + resp, + err, + ), nil, nil + } + defer func() { _ = resp.Body.Close() }() - if resp.StatusCode != http.StatusOK { - body, err := io.ReadAll(resp.Body) + if resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, nil, fmt.Errorf("failed to read response body: %w", err) + } + return utils.NewToolResultError(fmt.Sprintf("failed to list alerts: %s", string(body))), nil, nil + } + + r, err := json.Marshal(alerts) if err != nil { - return nil, nil, fmt.Errorf("failed to read response body: %w", err) + return nil, nil, fmt.Errorf("failed to marshal alerts: %w", err) } - return utils.NewToolResultError(fmt.Sprintf("failed to list alerts: %s", string(body))), nil, nil - } - r, err := json.Marshal(alerts) - if err != nil { - return nil, nil, fmt.Errorf("failed to marshal alerts: %w", err) + return utils.NewToolResultText(string(r)), nil, nil } - - return utils.NewToolResultText(string(r)), nil, nil - } + }, + ) } diff --git a/pkg/github/secret_scanning_test.go b/pkg/github/secret_scanning_test.go index 6eeac1862..83de16409 100644 --- a/pkg/github/secret_scanning_test.go +++ b/pkg/github/secret_scanning_test.go @@ -16,16 +16,15 @@ import ( ) func Test_GetSecretScanningAlert(t *testing.T) { - mockClient := github.NewClient(nil) - tool, _ := GetSecretScanningAlert(stubGetClientFn(mockClient), translations.NullTranslationHelper) + toolDef := GetSecretScanningAlert(translations.NullTranslationHelper) - require.NoError(t, toolsnaps.Test(tool.Name, tool)) + require.NoError(t, toolsnaps.Test(toolDef.Tool.Name, toolDef.Tool)) - assert.Equal(t, "get_secret_scanning_alert", tool.Name) - assert.NotEmpty(t, tool.Description) + assert.Equal(t, "get_secret_scanning_alert", toolDef.Tool.Name) + assert.NotEmpty(t, toolDef.Tool.Description) // Verify InputSchema structure - schema, ok := tool.InputSchema.(*jsonschema.Schema) + schema, ok := toolDef.Tool.InputSchema.(*jsonschema.Schema) require.True(t, ok, "InputSchema should be *jsonschema.Schema") assert.Contains(t, schema.Properties, "owner") assert.Contains(t, schema.Properties, "repo") @@ -88,13 +87,16 @@ func Test_GetSecretScanningAlert(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := GetSecretScanningAlert(stubGetClientFn(client), translations.NullTranslationHelper) + deps := ToolDependencies{ + GetClient: stubGetClientFn(client), + } + handler := toolDef.Handler(deps) // Create call request request := createMCPRequest(tc.requestArgs) // Call handler - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(context.Background(), &request) // Verify results if tc.expectError { @@ -125,16 +127,15 @@ func Test_GetSecretScanningAlert(t *testing.T) { func Test_ListSecretScanningAlerts(t *testing.T) { // Verify tool definition once - mockClient := github.NewClient(nil) - tool, _ := ListSecretScanningAlerts(stubGetClientFn(mockClient), translations.NullTranslationHelper) + toolDef := ListSecretScanningAlerts(translations.NullTranslationHelper) - require.NoError(t, toolsnaps.Test(tool.Name, tool)) + require.NoError(t, toolsnaps.Test(toolDef.Tool.Name, toolDef.Tool)) - assert.Equal(t, "list_secret_scanning_alerts", tool.Name) - assert.NotEmpty(t, tool.Description) + assert.Equal(t, "list_secret_scanning_alerts", toolDef.Tool.Name) + assert.NotEmpty(t, toolDef.Tool.Description) // Verify InputSchema structure - schema, ok := tool.InputSchema.(*jsonschema.Schema) + schema, ok := toolDef.Tool.InputSchema.(*jsonschema.Schema) require.True(t, ok, "InputSchema should be *jsonschema.Schema") assert.Contains(t, schema.Properties, "owner") assert.Contains(t, schema.Properties, "repo") @@ -227,11 +228,14 @@ func Test_ListSecretScanningAlerts(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { client := github.NewClient(tc.mockedClient) - _, handler := ListSecretScanningAlerts(stubGetClientFn(client), translations.NullTranslationHelper) + deps := ToolDependencies{ + GetClient: stubGetClientFn(client), + } + handler := toolDef.Handler(deps) request := createMCPRequest(tc.requestArgs) - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(context.Background(), &request) if tc.expectError { require.NoError(t, err) diff --git a/pkg/github/tools.go b/pkg/github/tools.go index 1be7e6151..947c727c2 100644 --- a/pkg/github/tools.go +++ b/pkg/github/tools.go @@ -259,20 +259,20 @@ func DefaultToolsetGroup(readOnly bool, getClient GetClientFn, getGQLClient GetG codeSecurity := toolsets.NewToolset(ToolsetMetadataCodeSecurity.ID, ToolsetMetadataCodeSecurity.Description). SetDependencies(deps). AddReadTools( - toolsets.NewServerToolLegacy(GetCodeScanningAlert(getClient, t)), - toolsets.NewServerToolLegacy(ListCodeScanningAlerts(getClient, t)), + GetCodeScanningAlert(t), + ListCodeScanningAlerts(t), ) secretProtection := toolsets.NewToolset(ToolsetMetadataSecretProtection.ID, ToolsetMetadataSecretProtection.Description). SetDependencies(deps). AddReadTools( - toolsets.NewServerToolLegacy(GetSecretScanningAlert(getClient, t)), - toolsets.NewServerToolLegacy(ListSecretScanningAlerts(getClient, t)), + GetSecretScanningAlert(t), + ListSecretScanningAlerts(t), ) dependabot := toolsets.NewToolset(ToolsetMetadataDependabot.ID, ToolsetMetadataDependabot.Description). SetDependencies(deps). AddReadTools( - toolsets.NewServerToolLegacy(GetDependabotAlert(getClient, t)), - toolsets.NewServerToolLegacy(ListDependabotAlerts(getClient, t)), + GetDependabotAlert(t), + ListDependabotAlerts(t), ) notifications := toolsets.NewToolset(ToolsetMetadataNotifications.ID, ToolsetMetadataNotifications.Description).