Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
189 changes: 99 additions & 90 deletions pkg/github/code_scanning.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,17 @@ import (
"net/http"

ghErrors "github.com/github/github-mcp-server/pkg/errors"
"github.com/github/github-mcp-server/pkg/toolsets"
"github.com/github/github-mcp-server/pkg/translations"
"github.com/github/github-mcp-server/pkg/utils"
"github.com/google/go-github/v79/github"
"github.com/google/jsonschema-go/jsonschema"
"github.com/modelcontextprotocol/go-sdk/mcp"
)

func GetCodeScanningAlert(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) {
return mcp.Tool{
func GetCodeScanningAlert(t translations.TranslationHelperFunc) toolsets.ServerTool {
return NewTool(
mcp.Tool{
Name: "get_code_scanning_alert",
Description: t("TOOL_GET_CODE_SCANNING_ALERT_DESCRIPTION", "Get details of a specific code scanning alert in a GitHub repository."),
Annotations: &mcp.ToolAnnotations{
Expand All @@ -42,54 +44,58 @@ func GetCodeScanningAlert(getClient GetClientFn, t translations.TranslationHelpe
Required: []string{"owner", "repo", "alertNumber"},
},
},
func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) {
owner, err := RequiredParam[string](args, "owner")
if err != nil {
return utils.NewToolResultError(err.Error()), nil, nil
}
repo, err := RequiredParam[string](args, "repo")
if err != nil {
return utils.NewToolResultError(err.Error()), nil, nil
}
alertNumber, err := RequiredInt(args, "alertNumber")
if err != nil {
return utils.NewToolResultError(err.Error()), nil, nil
}
func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] {
return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) {
owner, err := RequiredParam[string](args, "owner")
if err != nil {
return utils.NewToolResultError(err.Error()), nil, nil
}
repo, err := RequiredParam[string](args, "repo")
if err != nil {
return utils.NewToolResultError(err.Error()), nil, nil
}
alertNumber, err := RequiredInt(args, "alertNumber")
if err != nil {
return utils.NewToolResultError(err.Error()), nil, nil
}

client, err := getClient(ctx)
if err != nil {
return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil
}
client, err := deps.GetClient(ctx)
if err != nil {
return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil
}

alert, resp, err := client.CodeScanning.GetAlert(ctx, owner, repo, int64(alertNumber))
if err != nil {
return ghErrors.NewGitHubAPIErrorResponse(ctx,
"failed to get alert",
resp,
err,
), nil, nil
}
defer func() { _ = resp.Body.Close() }()
alert, resp, err := client.CodeScanning.GetAlert(ctx, owner, repo, int64(alertNumber))
if err != nil {
return ghErrors.NewGitHubAPIErrorResponse(ctx,
"failed to get alert",
resp,
err,
), nil, nil
}
defer func() { _ = resp.Body.Close() }()

if resp.StatusCode != http.StatusOK {
body, err := io.ReadAll(resp.Body)
if err != nil {
return utils.NewToolResultErrorFromErr("failed to read response body", err), nil, nil
}
return utils.NewToolResultError(fmt.Sprintf("failed to get alert: %s", string(body))), nil, nil
}

if resp.StatusCode != http.StatusOK {
body, err := io.ReadAll(resp.Body)
r, err := json.Marshal(alert)
if err != nil {
return utils.NewToolResultErrorFromErr("failed to read response body", err), nil, nil
return utils.NewToolResultErrorFromErr("failed to marshal alert", err), nil, nil
}
return utils.NewToolResultError(fmt.Sprintf("failed to get alert: %s", string(body))), nil, nil
}

r, err := json.Marshal(alert)
if err != nil {
return utils.NewToolResultErrorFromErr("failed to marshal alert", err), nil, nil
return utils.NewToolResultText(string(r)), nil, nil
}

return utils.NewToolResultText(string(r)), nil, nil
}
},
)
}

func ListCodeScanningAlerts(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) {
return mcp.Tool{
func ListCodeScanningAlerts(t translations.TranslationHelperFunc) toolsets.ServerTool {
return NewTool(
mcp.Tool{
Name: "list_code_scanning_alerts",
Description: t("TOOL_LIST_CODE_SCANNING_ALERTS_DESCRIPTION", "List code scanning alerts in a GitHub repository."),
Annotations: &mcp.ToolAnnotations{
Expand Down Expand Up @@ -130,59 +136,62 @@ func ListCodeScanningAlerts(getClient GetClientFn, t translations.TranslationHel
Required: []string{"owner", "repo"},
},
},
func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) {
owner, err := RequiredParam[string](args, "owner")
if err != nil {
return utils.NewToolResultError(err.Error()), nil, nil
}
repo, err := RequiredParam[string](args, "repo")
if err != nil {
return utils.NewToolResultError(err.Error()), nil, nil
}
ref, err := OptionalParam[string](args, "ref")
if err != nil {
return utils.NewToolResultError(err.Error()), nil, nil
}
state, err := OptionalParam[string](args, "state")
if err != nil {
return utils.NewToolResultError(err.Error()), nil, nil
}
severity, err := OptionalParam[string](args, "severity")
if err != nil {
return utils.NewToolResultError(err.Error()), nil, nil
}
toolName, err := OptionalParam[string](args, "tool_name")
if err != nil {
return utils.NewToolResultError(err.Error()), nil, nil
}
func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] {
return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) {
owner, err := RequiredParam[string](args, "owner")
if err != nil {
return utils.NewToolResultError(err.Error()), nil, nil
}
repo, err := RequiredParam[string](args, "repo")
if err != nil {
return utils.NewToolResultError(err.Error()), nil, nil
}
ref, err := OptionalParam[string](args, "ref")
if err != nil {
return utils.NewToolResultError(err.Error()), nil, nil
}
state, err := OptionalParam[string](args, "state")
if err != nil {
return utils.NewToolResultError(err.Error()), nil, nil
}
severity, err := OptionalParam[string](args, "severity")
if err != nil {
return utils.NewToolResultError(err.Error()), nil, nil
}
toolName, err := OptionalParam[string](args, "tool_name")
if err != nil {
return utils.NewToolResultError(err.Error()), nil, nil
}

client, err := getClient(ctx)
if err != nil {
return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil
}
alerts, resp, err := client.CodeScanning.ListAlertsForRepo(ctx, owner, repo, &github.AlertListOptions{Ref: ref, State: state, Severity: severity, ToolName: toolName})
if err != nil {
return ghErrors.NewGitHubAPIErrorResponse(ctx,
"failed to list alerts",
resp,
err,
), nil, nil
}
defer func() { _ = resp.Body.Close() }()
client, err := deps.GetClient(ctx)
if err != nil {
return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil
}
alerts, resp, err := client.CodeScanning.ListAlertsForRepo(ctx, owner, repo, &github.AlertListOptions{Ref: ref, State: state, Severity: severity, ToolName: toolName})
if err != nil {
return ghErrors.NewGitHubAPIErrorResponse(ctx,
"failed to list alerts",
resp,
err,
), nil, nil
}
defer func() { _ = resp.Body.Close() }()

if resp.StatusCode != http.StatusOK {
body, err := io.ReadAll(resp.Body)
if err != nil {
return utils.NewToolResultErrorFromErr("failed to read response body", err), nil, nil
}
return utils.NewToolResultError(fmt.Sprintf("failed to list alerts: %s", string(body))), nil, nil
}

if resp.StatusCode != http.StatusOK {
body, err := io.ReadAll(resp.Body)
r, err := json.Marshal(alerts)
if err != nil {
return utils.NewToolResultErrorFromErr("failed to read response body", err), nil, nil
return utils.NewToolResultErrorFromErr("failed to marshal alerts", err), nil, nil
}
return utils.NewToolResultError(fmt.Sprintf("failed to list alerts: %s", string(body))), nil, nil
}

r, err := json.Marshal(alerts)
if err != nil {
return utils.NewToolResultErrorFromErr("failed to marshal alerts", err), nil, nil
return utils.NewToolResultText(string(r)), nil, nil
}

return utils.NewToolResultText(string(r)), nil, nil
}
},
)
}
36 changes: 20 additions & 16 deletions pkg/github/code_scanning_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,14 @@ import (

func Test_GetCodeScanningAlert(t *testing.T) {
// Verify tool definition once
mockClient := github.NewClient(nil)
tool, _ := GetCodeScanningAlert(stubGetClientFn(mockClient), translations.NullTranslationHelper)
require.NoError(t, toolsnaps.Test(tool.Name, tool))
toolDef := GetCodeScanningAlert(translations.NullTranslationHelper)
require.NoError(t, toolsnaps.Test(toolDef.Tool.Name, toolDef.Tool))

assert.Equal(t, "get_code_scanning_alert", tool.Name)
assert.NotEmpty(t, tool.Description)
assert.Equal(t, "get_code_scanning_alert", toolDef.Tool.Name)
assert.NotEmpty(t, toolDef.Tool.Description)

// InputSchema is of type any, need to cast to *jsonschema.Schema
schema, ok := tool.InputSchema.(*jsonschema.Schema)
schema, ok := toolDef.Tool.InputSchema.(*jsonschema.Schema)
require.True(t, ok, "InputSchema should be *jsonschema.Schema")
assert.Contains(t, schema.Properties, "owner")
assert.Contains(t, schema.Properties, "repo")
Expand Down Expand Up @@ -89,13 +88,16 @@ func Test_GetCodeScanningAlert(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
// Setup client with mock
client := github.NewClient(tc.mockedClient)
_, handler := GetCodeScanningAlert(stubGetClientFn(client), translations.NullTranslationHelper)
deps := ToolDependencies{
GetClient: stubGetClientFn(client),
}
handler := toolDef.Handler(deps)

// Create call request
request := createMCPRequest(tc.requestArgs)

// Call handler with new signature
result, _, err := handler(context.Background(), &request, tc.requestArgs)
result, err := handler(context.Background(), &request)

// Verify results
if tc.expectError {
Expand Down Expand Up @@ -127,15 +129,14 @@ func Test_GetCodeScanningAlert(t *testing.T) {

func Test_ListCodeScanningAlerts(t *testing.T) {
// Verify tool definition once
mockClient := github.NewClient(nil)
tool, _ := ListCodeScanningAlerts(stubGetClientFn(mockClient), translations.NullTranslationHelper)
require.NoError(t, toolsnaps.Test(tool.Name, tool))
toolDef := ListCodeScanningAlerts(translations.NullTranslationHelper)
require.NoError(t, toolsnaps.Test(toolDef.Tool.Name, toolDef.Tool))

assert.Equal(t, "list_code_scanning_alerts", tool.Name)
assert.NotEmpty(t, tool.Description)
assert.Equal(t, "list_code_scanning_alerts", toolDef.Tool.Name)
assert.NotEmpty(t, toolDef.Tool.Description)

// InputSchema is of type any, need to cast to *jsonschema.Schema
schema, ok := tool.InputSchema.(*jsonschema.Schema)
schema, ok := toolDef.Tool.InputSchema.(*jsonschema.Schema)
require.True(t, ok, "InputSchema should be *jsonschema.Schema")
assert.Contains(t, schema.Properties, "owner")
assert.Contains(t, schema.Properties, "repo")
Expand Down Expand Up @@ -219,13 +220,16 @@ func Test_ListCodeScanningAlerts(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
// Setup client with mock
client := github.NewClient(tc.mockedClient)
_, handler := ListCodeScanningAlerts(stubGetClientFn(client), translations.NullTranslationHelper)
deps := ToolDependencies{
GetClient: stubGetClientFn(client),
}
handler := toolDef.Handler(deps)

// Create call request
request := createMCPRequest(tc.requestArgs)

// Call handler with new signature
result, _, err := handler(context.Background(), &request, tc.requestArgs)
result, err := handler(context.Background(), &request)

// Verify results
if tc.expectError {
Expand Down
Loading