diff --git a/pkg/github/security_advisories.go b/pkg/github/security_advisories.go index 58148a7a3..8c7df4265 100644 --- a/pkg/github/security_advisories.go +++ b/pkg/github/security_advisories.go @@ -7,6 +7,7 @@ import ( "io" "net/http" + "github.com/github/github-mcp-server/pkg/toolsets" "github.com/github/github-mcp-server/pkg/translations" "github.com/github/github-mcp-server/pkg/utils" "github.com/google/go-github/v79/github" @@ -14,445 +15,449 @@ import ( "github.com/modelcontextprotocol/go-sdk/mcp" ) -func ListGlobalSecurityAdvisories(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { - tool := mcp.Tool{ - Name: "list_global_security_advisories", - Description: t("TOOL_LIST_GLOBAL_SECURITY_ADVISORIES_DESCRIPTION", "List global security advisories from GitHub."), - Annotations: &mcp.ToolAnnotations{ - Title: t("TOOL_LIST_GLOBAL_SECURITY_ADVISORIES_USER_TITLE", "List global security advisories"), - ReadOnlyHint: true, - }, - InputSchema: &jsonschema.Schema{ - Type: "object", - Properties: map[string]*jsonschema.Schema{ - "ghsaId": { - Type: "string", - Description: "Filter by GitHub Security Advisory ID (format: GHSA-xxxx-xxxx-xxxx).", - }, - "type": { - Type: "string", - Description: "Advisory type.", - Enum: []any{"reviewed", "malware", "unreviewed"}, - Default: json.RawMessage(`"reviewed"`), - }, - "cveId": { - Type: "string", - Description: "Filter by CVE ID.", - }, - "ecosystem": { - Type: "string", - Description: "Filter by package ecosystem.", - Enum: []any{"actions", "composer", "erlang", "go", "maven", "npm", "nuget", "other", "pip", "pub", "rubygems", "rust"}, - }, - "severity": { - Type: "string", - Description: "Filter by severity.", - Enum: []any{"unknown", "low", "medium", "high", "critical"}, - }, - "cwes": { - Type: "array", - Description: "Filter by Common Weakness Enumeration IDs (e.g. [\"79\", \"284\", \"22\"]).", - Items: &jsonschema.Schema{ - Type: "string", +func ListGlobalSecurityAdvisories(t translations.TranslationHelperFunc) toolsets.ServerTool { + return NewTool( + mcp.Tool{ + Name: "list_global_security_advisories", + Description: t("TOOL_LIST_GLOBAL_SECURITY_ADVISORIES_DESCRIPTION", "List global security advisories from GitHub."), + Annotations: &mcp.ToolAnnotations{ + Title: t("TOOL_LIST_GLOBAL_SECURITY_ADVISORIES_USER_TITLE", "List global security advisories"), + ReadOnlyHint: true, + }, + InputSchema: &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "ghsaId": { + Type: "string", + Description: "Filter by GitHub Security Advisory ID (format: GHSA-xxxx-xxxx-xxxx).", + }, + "type": { + Type: "string", + Description: "Advisory type.", + Enum: []any{"reviewed", "malware", "unreviewed"}, + Default: json.RawMessage(`"reviewed"`), + }, + "cveId": { + Type: "string", + Description: "Filter by CVE ID.", + }, + "ecosystem": { + Type: "string", + Description: "Filter by package ecosystem.", + Enum: []any{"actions", "composer", "erlang", "go", "maven", "npm", "nuget", "other", "pip", "pub", "rubygems", "rust"}, + }, + "severity": { + Type: "string", + Description: "Filter by severity.", + Enum: []any{"unknown", "low", "medium", "high", "critical"}, + }, + "cwes": { + Type: "array", + Description: "Filter by Common Weakness Enumeration IDs (e.g. [\"79\", \"284\", \"22\"]).", + Items: &jsonschema.Schema{ + Type: "string", + }, + }, + "isWithdrawn": { + Type: "boolean", + Description: "Whether to only return withdrawn advisories.", + }, + "affects": { + Type: "string", + Description: "Filter advisories by affected package or version (e.g. \"package1,package2@1.0.0\").", + }, + "published": { + Type: "string", + Description: "Filter by publish date or date range (ISO 8601 date or range).", + }, + "updated": { + Type: "string", + Description: "Filter by update date or date range (ISO 8601 date or range).", + }, + "modified": { + Type: "string", + Description: "Filter by publish or update date or date range (ISO 8601 date or range).", }, - }, - "isWithdrawn": { - Type: "boolean", - Description: "Whether to only return withdrawn advisories.", - }, - "affects": { - Type: "string", - Description: "Filter advisories by affected package or version (e.g. \"package1,package2@1.0.0\").", - }, - "published": { - Type: "string", - Description: "Filter by publish date or date range (ISO 8601 date or range).", - }, - "updated": { - Type: "string", - Description: "Filter by update date or date range (ISO 8601 date or range).", - }, - "modified": { - Type: "string", - Description: "Filter by publish or update date or date range (ISO 8601 date or range).", }, }, }, - } - - handler := mcp.ToolHandlerFor[map[string]any, any](func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - client, err := getClient(ctx) - if err != nil { - return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) - } - - ghsaID, err := OptionalParam[string](args, "ghsaId") - if err != nil { - return utils.NewToolResultError(fmt.Sprintf("invalid ghsaId: %v", err)), nil, nil - } - - typ, err := OptionalParam[string](args, "type") - if err != nil { - return utils.NewToolResultError(fmt.Sprintf("invalid type: %v", err)), nil, nil - } - - cveID, err := OptionalParam[string](args, "cveId") - if err != nil { - return utils.NewToolResultError(fmt.Sprintf("invalid cveId: %v", err)), nil, nil - } - - eco, err := OptionalParam[string](args, "ecosystem") - if err != nil { - return utils.NewToolResultError(fmt.Sprintf("invalid ecosystem: %v", err)), nil, nil - } - - sev, err := OptionalParam[string](args, "severity") - if err != nil { - return utils.NewToolResultError(fmt.Sprintf("invalid severity: %v", err)), nil, nil - } - - cwes, err := OptionalStringArrayParam(args, "cwes") - if err != nil { - return utils.NewToolResultError(fmt.Sprintf("invalid cwes: %v", err)), nil, nil - } - - isWithdrawn, err := OptionalParam[bool](args, "isWithdrawn") - if err != nil { - return utils.NewToolResultError(fmt.Sprintf("invalid isWithdrawn: %v", err)), nil, nil - } - - affects, err := OptionalParam[string](args, "affects") - if err != nil { - return utils.NewToolResultError(fmt.Sprintf("invalid affects: %v", err)), nil, nil - } - - published, err := OptionalParam[string](args, "published") - if err != nil { - return utils.NewToolResultError(fmt.Sprintf("invalid published: %v", err)), nil, nil - } - - updated, err := OptionalParam[string](args, "updated") - if err != nil { - return utils.NewToolResultError(fmt.Sprintf("invalid updated: %v", err)), nil, nil - } - - modified, err := OptionalParam[string](args, "modified") - if err != nil { - return utils.NewToolResultError(fmt.Sprintf("invalid modified: %v", err)), nil, nil - } - - opts := &github.ListGlobalSecurityAdvisoriesOptions{} - - if ghsaID != "" { - opts.GHSAID = &ghsaID - } - if typ != "" { - opts.Type = &typ - } - if cveID != "" { - opts.CVEID = &cveID - } - if eco != "" { - opts.Ecosystem = &eco - } - if sev != "" { - opts.Severity = &sev - } - if len(cwes) > 0 { - opts.CWEs = cwes - } - - if isWithdrawn { - opts.IsWithdrawn = &isWithdrawn - } - - if affects != "" { - opts.Affects = &affects - } - if published != "" { - opts.Published = &published - } - if updated != "" { - opts.Updated = &updated - } - if modified != "" { - opts.Modified = &modified - } - - advisories, resp, err := client.SecurityAdvisories.ListGlobalSecurityAdvisories(ctx, opts) - if err != nil { - return nil, nil, fmt.Errorf("failed to list global security advisories: %w", err) - } - defer func() { _ = resp.Body.Close() }() - - if resp.StatusCode != http.StatusOK { - body, err := io.ReadAll(resp.Body) - if err != nil { - return nil, nil, fmt.Errorf("failed to read response body: %w", err) + func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { + return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + client, err := deps.GetClient(ctx) + if err != nil { + return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) + } + + ghsaID, err := OptionalParam[string](args, "ghsaId") + if err != nil { + return utils.NewToolResultError(fmt.Sprintf("invalid ghsaId: %v", err)), nil, nil + } + + typ, err := OptionalParam[string](args, "type") + if err != nil { + return utils.NewToolResultError(fmt.Sprintf("invalid type: %v", err)), nil, nil + } + + cveID, err := OptionalParam[string](args, "cveId") + if err != nil { + return utils.NewToolResultError(fmt.Sprintf("invalid cveId: %v", err)), nil, nil + } + + eco, err := OptionalParam[string](args, "ecosystem") + if err != nil { + return utils.NewToolResultError(fmt.Sprintf("invalid ecosystem: %v", err)), nil, nil + } + + sev, err := OptionalParam[string](args, "severity") + if err != nil { + return utils.NewToolResultError(fmt.Sprintf("invalid severity: %v", err)), nil, nil + } + + cwes, err := OptionalStringArrayParam(args, "cwes") + if err != nil { + return utils.NewToolResultError(fmt.Sprintf("invalid cwes: %v", err)), nil, nil + } + + isWithdrawn, err := OptionalParam[bool](args, "isWithdrawn") + if err != nil { + return utils.NewToolResultError(fmt.Sprintf("invalid isWithdrawn: %v", err)), nil, nil + } + + affects, err := OptionalParam[string](args, "affects") + if err != nil { + return utils.NewToolResultError(fmt.Sprintf("invalid affects: %v", err)), nil, nil + } + + published, err := OptionalParam[string](args, "published") + if err != nil { + return utils.NewToolResultError(fmt.Sprintf("invalid published: %v", err)), nil, nil + } + + updated, err := OptionalParam[string](args, "updated") + if err != nil { + return utils.NewToolResultError(fmt.Sprintf("invalid updated: %v", err)), nil, nil + } + + modified, err := OptionalParam[string](args, "modified") + if err != nil { + return utils.NewToolResultError(fmt.Sprintf("invalid modified: %v", err)), nil, nil + } + + opts := &github.ListGlobalSecurityAdvisoriesOptions{} + + if ghsaID != "" { + opts.GHSAID = &ghsaID + } + if typ != "" { + opts.Type = &typ + } + if cveID != "" { + opts.CVEID = &cveID + } + if eco != "" { + opts.Ecosystem = &eco + } + if sev != "" { + opts.Severity = &sev + } + if len(cwes) > 0 { + opts.CWEs = cwes + } + + if isWithdrawn { + opts.IsWithdrawn = &isWithdrawn + } + + if affects != "" { + opts.Affects = &affects + } + if published != "" { + opts.Published = &published + } + if updated != "" { + opts.Updated = &updated + } + if modified != "" { + opts.Modified = &modified + } + + advisories, resp, err := client.SecurityAdvisories.ListGlobalSecurityAdvisories(ctx, opts) + if err != nil { + return nil, nil, fmt.Errorf("failed to list global security advisories: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, nil, fmt.Errorf("failed to read response body: %w", err) + } + return utils.NewToolResultError(fmt.Sprintf("failed to list advisories: %s", string(body))), nil, nil + } + + r, err := json.Marshal(advisories) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal advisories: %w", err) + } + + return utils.NewToolResultText(string(r)), nil, nil } - return utils.NewToolResultError(fmt.Sprintf("failed to list advisories: %s", string(body))), nil, nil - } - - r, err := json.Marshal(advisories) - if err != nil { - return nil, nil, fmt.Errorf("failed to marshal advisories: %w", err) - } - - return utils.NewToolResultText(string(r)), nil, nil - }) - - return tool, handler + }, + ) } -func ListRepositorySecurityAdvisories(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { - tool := mcp.Tool{ - Name: "list_repository_security_advisories", - Description: t("TOOL_LIST_REPOSITORY_SECURITY_ADVISORIES_DESCRIPTION", "List repository security advisories for a GitHub repository."), - Annotations: &mcp.ToolAnnotations{ - Title: t("TOOL_LIST_REPOSITORY_SECURITY_ADVISORIES_USER_TITLE", "List repository security advisories"), - ReadOnlyHint: true, - }, - InputSchema: &jsonschema.Schema{ - Type: "object", - Properties: map[string]*jsonschema.Schema{ - "owner": { - Type: "string", - Description: "The owner of the repository.", - }, - "repo": { - Type: "string", - Description: "The name of the repository.", - }, - "direction": { - Type: "string", - Description: "Sort direction.", - Enum: []any{"asc", "desc"}, - }, - "sort": { - Type: "string", - Description: "Sort field.", - Enum: []any{"created", "updated", "published"}, - }, - "state": { - Type: "string", - Description: "Filter by advisory state.", - Enum: []any{"triage", "draft", "published", "closed"}, +func ListRepositorySecurityAdvisories(t translations.TranslationHelperFunc) toolsets.ServerTool { + return NewTool( + mcp.Tool{ + Name: "list_repository_security_advisories", + Description: t("TOOL_LIST_REPOSITORY_SECURITY_ADVISORIES_DESCRIPTION", "List repository security advisories for a GitHub repository."), + Annotations: &mcp.ToolAnnotations{ + Title: t("TOOL_LIST_REPOSITORY_SECURITY_ADVISORIES_USER_TITLE", "List repository security advisories"), + ReadOnlyHint: true, + }, + InputSchema: &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "owner": { + Type: "string", + Description: "The owner of the repository.", + }, + "repo": { + Type: "string", + Description: "The name of the repository.", + }, + "direction": { + Type: "string", + Description: "Sort direction.", + Enum: []any{"asc", "desc"}, + }, + "sort": { + Type: "string", + Description: "Sort field.", + Enum: []any{"created", "updated", "published"}, + }, + "state": { + Type: "string", + Description: "Filter by advisory state.", + Enum: []any{"triage", "draft", "published", "closed"}, + }, }, + Required: []string{"owner", "repo"}, }, - Required: []string{"owner", "repo"}, }, - } - - handler := mcp.ToolHandlerFor[map[string]any, any](func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - owner, err := RequiredParam[string](args, "owner") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - repo, err := RequiredParam[string](args, "repo") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - - direction, err := OptionalParam[string](args, "direction") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - sortField, err := OptionalParam[string](args, "sort") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - state, err := OptionalParam[string](args, "state") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - - client, err := getClient(ctx) - if err != nil { - return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) - } - - opts := &github.ListRepositorySecurityAdvisoriesOptions{} - if direction != "" { - opts.Direction = direction - } - if sortField != "" { - opts.Sort = sortField - } - if state != "" { - opts.State = state - } - - advisories, resp, err := client.SecurityAdvisories.ListRepositorySecurityAdvisories(ctx, owner, repo, opts) - if err != nil { - return nil, nil, fmt.Errorf("failed to list repository security advisories: %w", err) - } - defer func() { _ = resp.Body.Close() }() - - if resp.StatusCode != http.StatusOK { - body, err := io.ReadAll(resp.Body) - if err != nil { - return nil, nil, fmt.Errorf("failed to read response body: %w", err) + func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { + return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + owner, err := RequiredParam[string](args, "owner") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + repo, err := RequiredParam[string](args, "repo") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + + direction, err := OptionalParam[string](args, "direction") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + sortField, err := OptionalParam[string](args, "sort") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + state, err := OptionalParam[string](args, "state") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + + client, err := deps.GetClient(ctx) + if err != nil { + return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) + } + + opts := &github.ListRepositorySecurityAdvisoriesOptions{} + if direction != "" { + opts.Direction = direction + } + if sortField != "" { + opts.Sort = sortField + } + if state != "" { + opts.State = state + } + + advisories, resp, err := client.SecurityAdvisories.ListRepositorySecurityAdvisories(ctx, owner, repo, opts) + if err != nil { + return nil, nil, fmt.Errorf("failed to list repository security advisories: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, nil, fmt.Errorf("failed to read response body: %w", err) + } + return utils.NewToolResultError(fmt.Sprintf("failed to list repository advisories: %s", string(body))), nil, nil + } + + r, err := json.Marshal(advisories) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal advisories: %w", err) + } + + return utils.NewToolResultText(string(r)), nil, nil } - return utils.NewToolResultError(fmt.Sprintf("failed to list repository advisories: %s", string(body))), nil, nil - } - - r, err := json.Marshal(advisories) - if err != nil { - return nil, nil, fmt.Errorf("failed to marshal advisories: %w", err) - } - - return utils.NewToolResultText(string(r)), nil, nil - }) - - return tool, handler + }, + ) } -func GetGlobalSecurityAdvisory(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { - tool := mcp.Tool{ - Name: "get_global_security_advisory", - Description: t("TOOL_GET_GLOBAL_SECURITY_ADVISORY_DESCRIPTION", "Get a global security advisory"), - Annotations: &mcp.ToolAnnotations{ - Title: t("TOOL_GET_GLOBAL_SECURITY_ADVISORY_USER_TITLE", "Get a global security advisory"), - ReadOnlyHint: true, - }, - InputSchema: &jsonschema.Schema{ - Type: "object", - Properties: map[string]*jsonschema.Schema{ - "ghsaId": { - Type: "string", - Description: "GitHub Security Advisory ID (format: GHSA-xxxx-xxxx-xxxx).", +func GetGlobalSecurityAdvisory(t translations.TranslationHelperFunc) toolsets.ServerTool { + return NewTool( + mcp.Tool{ + Name: "get_global_security_advisory", + Description: t("TOOL_GET_GLOBAL_SECURITY_ADVISORY_DESCRIPTION", "Get a global security advisory"), + Annotations: &mcp.ToolAnnotations{ + Title: t("TOOL_GET_GLOBAL_SECURITY_ADVISORY_USER_TITLE", "Get a global security advisory"), + ReadOnlyHint: true, + }, + InputSchema: &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "ghsaId": { + Type: "string", + Description: "GitHub Security Advisory ID (format: GHSA-xxxx-xxxx-xxxx).", + }, }, + Required: []string{"ghsaId"}, }, - Required: []string{"ghsaId"}, }, - } - - handler := mcp.ToolHandlerFor[map[string]any, any](func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - client, err := getClient(ctx) - if err != nil { - return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) - } - - ghsaID, err := RequiredParam[string](args, "ghsaId") - if err != nil { - return utils.NewToolResultError(fmt.Sprintf("invalid ghsaId: %v", err)), nil, nil - } - - advisory, resp, err := client.SecurityAdvisories.GetGlobalSecurityAdvisories(ctx, ghsaID) - if err != nil { - return nil, nil, fmt.Errorf("failed to get advisory: %w", err) - } - defer func() { _ = resp.Body.Close() }() - - if resp.StatusCode != http.StatusOK { - body, err := io.ReadAll(resp.Body) - if err != nil { - return nil, nil, fmt.Errorf("failed to read response body: %w", err) + func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { + return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + client, err := deps.GetClient(ctx) + if err != nil { + return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) + } + + ghsaID, err := RequiredParam[string](args, "ghsaId") + if err != nil { + return utils.NewToolResultError(fmt.Sprintf("invalid ghsaId: %v", err)), nil, nil + } + + advisory, resp, err := client.SecurityAdvisories.GetGlobalSecurityAdvisories(ctx, ghsaID) + if err != nil { + return nil, nil, fmt.Errorf("failed to get advisory: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, nil, fmt.Errorf("failed to read response body: %w", err) + } + return utils.NewToolResultError(fmt.Sprintf("failed to get advisory: %s", string(body))), nil, nil + } + + r, err := json.Marshal(advisory) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal advisory: %w", err) + } + + return utils.NewToolResultText(string(r)), nil, nil } - return utils.NewToolResultError(fmt.Sprintf("failed to get advisory: %s", string(body))), nil, nil - } - - r, err := json.Marshal(advisory) - if err != nil { - return nil, nil, fmt.Errorf("failed to marshal advisory: %w", err) - } - - return utils.NewToolResultText(string(r)), nil, nil - }) - - return tool, handler + }, + ) } -func ListOrgRepositorySecurityAdvisories(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { - tool := mcp.Tool{ - Name: "list_org_repository_security_advisories", - Description: t("TOOL_LIST_ORG_REPOSITORY_SECURITY_ADVISORIES_DESCRIPTION", "List repository security advisories for a GitHub organization."), - Annotations: &mcp.ToolAnnotations{ - Title: t("TOOL_LIST_ORG_REPOSITORY_SECURITY_ADVISORIES_USER_TITLE", "List org repository security advisories"), - ReadOnlyHint: true, - }, - InputSchema: &jsonschema.Schema{ - Type: "object", - Properties: map[string]*jsonschema.Schema{ - "org": { - Type: "string", - Description: "The organization login.", - }, - "direction": { - Type: "string", - Description: "Sort direction.", - Enum: []any{"asc", "desc"}, - }, - "sort": { - Type: "string", - Description: "Sort field.", - Enum: []any{"created", "updated", "published"}, - }, - "state": { - Type: "string", - Description: "Filter by advisory state.", - Enum: []any{"triage", "draft", "published", "closed"}, +func ListOrgRepositorySecurityAdvisories(t translations.TranslationHelperFunc) toolsets.ServerTool { + return NewTool( + mcp.Tool{ + Name: "list_org_repository_security_advisories", + Description: t("TOOL_LIST_ORG_REPOSITORY_SECURITY_ADVISORIES_DESCRIPTION", "List repository security advisories for a GitHub organization."), + Annotations: &mcp.ToolAnnotations{ + Title: t("TOOL_LIST_ORG_REPOSITORY_SECURITY_ADVISORIES_USER_TITLE", "List org repository security advisories"), + ReadOnlyHint: true, + }, + InputSchema: &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "org": { + Type: "string", + Description: "The organization login.", + }, + "direction": { + Type: "string", + Description: "Sort direction.", + Enum: []any{"asc", "desc"}, + }, + "sort": { + Type: "string", + Description: "Sort field.", + Enum: []any{"created", "updated", "published"}, + }, + "state": { + Type: "string", + Description: "Filter by advisory state.", + Enum: []any{"triage", "draft", "published", "closed"}, + }, }, + Required: []string{"org"}, }, - Required: []string{"org"}, }, - } - - handler := mcp.ToolHandlerFor[map[string]any, any](func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - org, err := RequiredParam[string](args, "org") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - direction, err := OptionalParam[string](args, "direction") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - sortField, err := OptionalParam[string](args, "sort") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - state, err := OptionalParam[string](args, "state") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - - client, err := getClient(ctx) - if err != nil { - return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) - } - - opts := &github.ListRepositorySecurityAdvisoriesOptions{} - if direction != "" { - opts.Direction = direction - } - if sortField != "" { - opts.Sort = sortField - } - if state != "" { - opts.State = state - } - - advisories, resp, err := client.SecurityAdvisories.ListRepositorySecurityAdvisoriesForOrg(ctx, org, opts) - if err != nil { - return nil, nil, fmt.Errorf("failed to list organization repository security advisories: %w", err) - } - defer func() { _ = resp.Body.Close() }() - - if resp.StatusCode != http.StatusOK { - body, err := io.ReadAll(resp.Body) - if err != nil { - return nil, nil, fmt.Errorf("failed to read response body: %w", err) + func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { + return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + org, err := RequiredParam[string](args, "org") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + direction, err := OptionalParam[string](args, "direction") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + sortField, err := OptionalParam[string](args, "sort") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + state, err := OptionalParam[string](args, "state") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + + client, err := deps.GetClient(ctx) + if err != nil { + return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) + } + + opts := &github.ListRepositorySecurityAdvisoriesOptions{} + if direction != "" { + opts.Direction = direction + } + if sortField != "" { + opts.Sort = sortField + } + if state != "" { + opts.State = state + } + + advisories, resp, err := client.SecurityAdvisories.ListRepositorySecurityAdvisoriesForOrg(ctx, org, opts) + if err != nil { + return nil, nil, fmt.Errorf("failed to list organization repository security advisories: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, nil, fmt.Errorf("failed to read response body: %w", err) + } + return utils.NewToolResultError(fmt.Sprintf("failed to list organization repository advisories: %s", string(body))), nil, nil + } + + r, err := json.Marshal(advisories) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal advisories: %w", err) + } + + return utils.NewToolResultText(string(r)), nil, nil } - return utils.NewToolResultError(fmt.Sprintf("failed to list organization repository advisories: %s", string(body))), nil, nil - } - - r, err := json.Marshal(advisories) - if err != nil { - return nil, nil, fmt.Errorf("failed to marshal advisories: %w", err) - } - - return utils.NewToolResultText(string(r)), nil, nil - }) - - return tool, handler + }, + ) } diff --git a/pkg/github/security_advisories_test.go b/pkg/github/security_advisories_test.go index ed632d0be..16506a3e8 100644 --- a/pkg/github/security_advisories_test.go +++ b/pkg/github/security_advisories_test.go @@ -16,8 +16,8 @@ import ( ) func Test_ListGlobalSecurityAdvisories(t *testing.T) { - mockClient := github.NewClient(nil) - tool, _ := ListGlobalSecurityAdvisories(stubGetClientFn(mockClient), translations.NullTranslationHelper) + toolDef := ListGlobalSecurityAdvisories(translations.NullTranslationHelper) + tool := toolDef.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "list_global_security_advisories", tool.Name) @@ -103,13 +103,14 @@ func Test_ListGlobalSecurityAdvisories(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := ListGlobalSecurityAdvisories(stubGetClientFn(client), translations.NullTranslationHelper) + deps := ToolDependencies{GetClient: stubGetClientFn(client)} + handler := toolDef.Handler(deps) // Create call request request := createMCPRequest(tc.requestArgs) - // Call handler - note the new signature with 3 parameters and 3 return values - result, _, err := handler(context.Background(), &request, tc.requestArgs) + // Call handler + result, err := handler(context.Background(), &request) // Verify results if tc.expectError { @@ -139,8 +140,8 @@ func Test_ListGlobalSecurityAdvisories(t *testing.T) { } func Test_GetGlobalSecurityAdvisory(t *testing.T) { - mockClient := github.NewClient(nil) - tool, _ := GetGlobalSecurityAdvisory(stubGetClientFn(mockClient), translations.NullTranslationHelper) + toolDef := GetGlobalSecurityAdvisory(translations.NullTranslationHelper) + tool := toolDef.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "get_global_security_advisory", tool.Name) @@ -223,13 +224,14 @@ func Test_GetGlobalSecurityAdvisory(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := GetGlobalSecurityAdvisory(stubGetClientFn(client), translations.NullTranslationHelper) + deps := ToolDependencies{GetClient: stubGetClientFn(client)} + handler := toolDef.Handler(deps) // Create call request request := createMCPRequest(tc.requestArgs) - // Call handler - note the new signature with 3 parameters and 3 return values - result, _, err := handler(context.Background(), &request, tc.requestArgs) + // Call handler + result, err := handler(context.Background(), &request) // Verify results if tc.expectError { @@ -254,8 +256,8 @@ func Test_GetGlobalSecurityAdvisory(t *testing.T) { func Test_ListRepositorySecurityAdvisories(t *testing.T) { // Verify tool definition once - mockClient := github.NewClient(nil) - tool, _ := ListRepositorySecurityAdvisories(stubGetClientFn(mockClient), translations.NullTranslationHelper) + toolDef := ListRepositorySecurityAdvisories(translations.NullTranslationHelper) + tool := toolDef.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "list_repository_security_advisories", tool.Name) @@ -370,12 +372,13 @@ func Test_ListRepositorySecurityAdvisories(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { client := github.NewClient(tc.mockedClient) - _, handler := ListRepositorySecurityAdvisories(stubGetClientFn(client), translations.NullTranslationHelper) + deps := ToolDependencies{GetClient: stubGetClientFn(client)} + handler := toolDef.Handler(deps) request := createMCPRequest(tc.requestArgs) - // Call handler - note the new signature with 3 parameters and 3 return values - result, _, err := handler(context.Background(), &request, tc.requestArgs) + // Call handler + result, err := handler(context.Background(), &request) if tc.expectError { require.Error(t, err) @@ -403,8 +406,8 @@ func Test_ListRepositorySecurityAdvisories(t *testing.T) { func Test_ListOrgRepositorySecurityAdvisories(t *testing.T) { // Verify tool definition once - mockClient := github.NewClient(nil) - tool, _ := ListOrgRepositorySecurityAdvisories(stubGetClientFn(mockClient), translations.NullTranslationHelper) + toolDef := ListOrgRepositorySecurityAdvisories(translations.NullTranslationHelper) + tool := toolDef.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "list_org_repository_security_advisories", tool.Name) @@ -514,12 +517,13 @@ func Test_ListOrgRepositorySecurityAdvisories(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { client := github.NewClient(tc.mockedClient) - _, handler := ListOrgRepositorySecurityAdvisories(stubGetClientFn(client), translations.NullTranslationHelper) + deps := ToolDependencies{GetClient: stubGetClientFn(client)} + handler := toolDef.Handler(deps) request := createMCPRequest(tc.requestArgs) - // Call handler - note the new signature with 3 parameters and 3 return values - result, _, err := handler(context.Background(), &request, tc.requestArgs) + // Call handler + result, err := handler(context.Background(), &request) if tc.expectError { require.Error(t, err) diff --git a/pkg/github/tools.go b/pkg/github/tools.go index d4f473724..6af9ce40d 100644 --- a/pkg/github/tools.go +++ b/pkg/github/tools.go @@ -321,10 +321,10 @@ func DefaultToolsetGroup(readOnly bool, getClient GetClientFn, getGQLClient GetG securityAdvisories := toolsets.NewToolset(ToolsetMetadataSecurityAdvisories.ID, ToolsetMetadataSecurityAdvisories.Description). SetDependencies(deps). AddReadTools( - toolsets.NewServerToolLegacy(ListGlobalSecurityAdvisories(getClient, t)), - toolsets.NewServerToolLegacy(GetGlobalSecurityAdvisory(getClient, t)), - toolsets.NewServerToolLegacy(ListRepositorySecurityAdvisories(getClient, t)), - toolsets.NewServerToolLegacy(ListOrgRepositorySecurityAdvisories(getClient, t)), + ListGlobalSecurityAdvisories(t), + GetGlobalSecurityAdvisory(t), + ListRepositorySecurityAdvisories(t), + ListOrgRepositorySecurityAdvisories(t), ) // // Keep experiments alive so the system doesn't error out when it's always enabled