From 3b7fa6d5d2c1452630093bb9ff4046dc0b83fe17 Mon Sep 17 00:00:00 2001 From: Sam Morrow Date: Sat, 13 Dec 2025 15:37:16 +0100 Subject: [PATCH] Refactor actions.go to use NewTool pattern Convert all 14 tool functions in actions.go to use the NewTool pattern with ToolDependencies for dependency injection. This is part of a broader effort to standardize the tool implementation pattern across the codebase. Changes: - ListWorkflows, ListWorkflowRuns, RunWorkflow, GetWorkflowRun - GetWorkflowRunLogs, ListWorkflowJobs, GetJobLogs - RerunWorkflowRun, RerunFailedJobs, CancelWorkflowRun - ListWorkflowRunArtifacts, DownloadWorkflowRunArtifact - DeleteWorkflowRunLogs, GetWorkflowRunUsage The new pattern: - Takes only translations.TranslationHelperFunc as parameter - Returns toolsets.ServerTool with Tool and Handler - Handler receives ToolDependencies for client access - Enables better testability and consistent interface Co-authored-by: Adam Holt --- pkg/github/actions.go | 1265 +++++++++++++++++++----------------- pkg/github/actions_test.go | 329 +++++----- pkg/github/tools.go | 28 +- 3 files changed, 845 insertions(+), 777 deletions(-) diff --git a/pkg/github/actions.go b/pkg/github/actions.go index 81ed55296..e9c7c11a8 100644 --- a/pkg/github/actions.go +++ b/pkg/github/actions.go @@ -11,6 +11,7 @@ import ( "github.com/github/github-mcp-server/internal/profiler" buffer "github.com/github/github-mcp-server/pkg/buffer" ghErrors "github.com/github/github-mcp-server/pkg/errors" + "github.com/github/github-mcp-server/pkg/toolsets" "github.com/github/github-mcp-server/pkg/translations" "github.com/github/github-mcp-server/pkg/utils" "github.com/google/go-github/v79/github" @@ -24,8 +25,9 @@ const ( ) // ListWorkflows creates a tool to list workflows in a repository -func ListWorkflows(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { - return mcp.Tool{ +func ListWorkflows(t translations.TranslationHelperFunc) toolsets.ServerTool { + return NewTool( + mcp.Tool{ Name: "list_workflows", Description: t("TOOL_LIST_WORKFLOWS_DESCRIPTION", "List workflows in a repository"), Annotations: &mcp.ToolAnnotations{ @@ -47,51 +49,55 @@ func ListWorkflows(getClient GetClientFn, t translations.TranslationHelperFunc) Required: []string{"owner", "repo"}, }), }, - func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - owner, err := RequiredParam[string](args, "owner") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - repo, err := RequiredParam[string](args, "repo") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { + return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + client, err := deps.GetClient(ctx) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil + } - // Get optional pagination parameters - pagination, err := OptionalPaginationParams(args) - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + owner, err := RequiredParam[string](args, "owner") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + repo, err := RequiredParam[string](args, "repo") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - client, err := getClient(ctx) - if err != nil { - return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) - } + // Get optional pagination parameters + pagination, err := OptionalPaginationParams(args) + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - // Set up list options - opts := &github.ListOptions{ - PerPage: pagination.PerPage, - Page: pagination.Page, - } + // Set up list options + opts := &github.ListOptions{ + PerPage: pagination.PerPage, + Page: pagination.Page, + } - workflows, resp, err := client.Actions.ListWorkflows(ctx, owner, repo, opts) - if err != nil { - return nil, nil, fmt.Errorf("failed to list workflows: %w", err) - } - defer func() { _ = resp.Body.Close() }() + workflows, resp, err := client.Actions.ListWorkflows(ctx, owner, repo, opts) + if err != nil { + return nil, nil, fmt.Errorf("failed to list workflows: %w", err) + } + defer func() { _ = resp.Body.Close() }() - r, err := json.Marshal(workflows) - if err != nil { - return nil, nil, fmt.Errorf("failed to marshal response: %w", err) - } + r, err := json.Marshal(workflows) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal response: %w", err) + } - return utils.NewToolResultText(string(r)), nil, nil - } + return utils.NewToolResultText(string(r)), nil, nil + } + }, + ) } // ListWorkflowRuns creates a tool to list workflow runs for a specific workflow -func ListWorkflowRuns(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { - return mcp.Tool{ +func ListWorkflowRuns(t translations.TranslationHelperFunc) toolsets.ServerTool { + return NewTool( + mcp.Tool{ Name: "list_workflow_runs", Description: t("TOOL_LIST_WORKFLOW_RUNS_DESCRIPTION", "List workflow runs for a specific workflow"), Annotations: &mcp.ToolAnnotations{ @@ -168,79 +174,83 @@ func ListWorkflowRuns(getClient GetClientFn, t translations.TranslationHelperFun Required: []string{"owner", "repo", "workflow_id"}, }), }, - func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - owner, err := RequiredParam[string](args, "owner") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - repo, err := RequiredParam[string](args, "repo") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - workflowID, err := RequiredParam[string](args, "workflow_id") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { + return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + client, err := deps.GetClient(ctx) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil + } - // Get optional filtering parameters - actor, err := OptionalParam[string](args, "actor") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - branch, err := OptionalParam[string](args, "branch") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - event, err := OptionalParam[string](args, "event") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - status, err := OptionalParam[string](args, "status") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + owner, err := RequiredParam[string](args, "owner") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + repo, err := RequiredParam[string](args, "repo") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + workflowID, err := RequiredParam[string](args, "workflow_id") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - // Get optional pagination parameters - pagination, err := OptionalPaginationParams(args) - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + // Get optional filtering parameters + actor, err := OptionalParam[string](args, "actor") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + branch, err := OptionalParam[string](args, "branch") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + event, err := OptionalParam[string](args, "event") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + status, err := OptionalParam[string](args, "status") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - client, err := getClient(ctx) - if err != nil { - return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) - } + // Get optional pagination parameters + pagination, err := OptionalPaginationParams(args) + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - // Set up list options - opts := &github.ListWorkflowRunsOptions{ - Actor: actor, - Branch: branch, - Event: event, - Status: status, - ListOptions: github.ListOptions{ - PerPage: pagination.PerPage, - Page: pagination.Page, - }, - } + // Set up list options + opts := &github.ListWorkflowRunsOptions{ + Actor: actor, + Branch: branch, + Event: event, + Status: status, + ListOptions: github.ListOptions{ + PerPage: pagination.PerPage, + Page: pagination.Page, + }, + } - workflowRuns, resp, err := client.Actions.ListWorkflowRunsByFileName(ctx, owner, repo, workflowID, opts) - if err != nil { - return nil, nil, fmt.Errorf("failed to list workflow runs: %w", err) - } - defer func() { _ = resp.Body.Close() }() + workflowRuns, resp, err := client.Actions.ListWorkflowRunsByFileName(ctx, owner, repo, workflowID, opts) + if err != nil { + return nil, nil, fmt.Errorf("failed to list workflow runs: %w", err) + } + defer func() { _ = resp.Body.Close() }() - r, err := json.Marshal(workflowRuns) - if err != nil { - return nil, nil, fmt.Errorf("failed to marshal response: %w", err) - } + r, err := json.Marshal(workflowRuns) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal response: %w", err) + } - return utils.NewToolResultText(string(r)), nil, nil - } + return utils.NewToolResultText(string(r)), nil, nil + } + }, + ) } // RunWorkflow creates a tool to run an Actions workflow -func RunWorkflow(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { - return mcp.Tool{ +func RunWorkflow(t translations.TranslationHelperFunc) toolsets.ServerTool { + return NewTool( + mcp.Tool{ Name: "run_workflow", Description: t("TOOL_RUN_WORKFLOW_DESCRIPTION", "Run an Actions workflow by workflow ID or filename"), Annotations: &mcp.ToolAnnotations{ @@ -274,80 +284,84 @@ func RunWorkflow(getClient GetClientFn, t translations.TranslationHelperFunc) (m Required: []string{"owner", "repo", "workflow_id", "ref"}, }, }, - func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - owner, err := RequiredParam[string](args, "owner") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - repo, err := RequiredParam[string](args, "repo") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - workflowID, err := RequiredParam[string](args, "workflow_id") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - ref, err := RequiredParam[string](args, "ref") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { + return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + client, err := deps.GetClient(ctx) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil + } - // Get optional inputs parameter - var inputs map[string]interface{} - if requestInputs, ok := args["inputs"]; ok { - if inputsMap, ok := requestInputs.(map[string]interface{}); ok { - inputs = inputsMap + owner, err := RequiredParam[string](args, "owner") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + repo, err := RequiredParam[string](args, "repo") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + workflowID, err := RequiredParam[string](args, "workflow_id") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + ref, err := RequiredParam[string](args, "ref") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil } - } - client, err := getClient(ctx) - if err != nil { - return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) - } + // Get optional inputs parameter + var inputs map[string]interface{} + if requestInputs, ok := args["inputs"]; ok { + if inputsMap, ok := requestInputs.(map[string]interface{}); ok { + inputs = inputsMap + } + } - event := github.CreateWorkflowDispatchEventRequest{ - Ref: ref, - Inputs: inputs, - } + event := github.CreateWorkflowDispatchEventRequest{ + Ref: ref, + Inputs: inputs, + } - var resp *github.Response - var workflowType string + var resp *github.Response + var workflowType string - if workflowIDInt, parseErr := strconv.ParseInt(workflowID, 10, 64); parseErr == nil { - resp, err = client.Actions.CreateWorkflowDispatchEventByID(ctx, owner, repo, workflowIDInt, event) - workflowType = "workflow_id" - } else { - resp, err = client.Actions.CreateWorkflowDispatchEventByFileName(ctx, owner, repo, workflowID, event) - workflowType = "workflow_file" - } + if workflowIDInt, parseErr := strconv.ParseInt(workflowID, 10, 64); parseErr == nil { + resp, err = client.Actions.CreateWorkflowDispatchEventByID(ctx, owner, repo, workflowIDInt, event) + workflowType = "workflow_id" + } else { + resp, err = client.Actions.CreateWorkflowDispatchEventByFileName(ctx, owner, repo, workflowID, event) + workflowType = "workflow_file" + } - if err != nil { - return nil, nil, fmt.Errorf("failed to run workflow: %w", err) - } - defer func() { _ = resp.Body.Close() }() - - result := map[string]any{ - "message": "Workflow run has been queued", - "workflow_type": workflowType, - "workflow_id": workflowID, - "ref": ref, - "inputs": inputs, - "status": resp.Status, - "status_code": resp.StatusCode, - } + if err != nil { + return nil, nil, fmt.Errorf("failed to run workflow: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + result := map[string]any{ + "message": "Workflow run has been queued", + "workflow_type": workflowType, + "workflow_id": workflowID, + "ref": ref, + "inputs": inputs, + "status": resp.Status, + "status_code": resp.StatusCode, + } - r, err := json.Marshal(result) - if err != nil { - return nil, nil, fmt.Errorf("failed to marshal response: %w", err) - } + r, err := json.Marshal(result) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal response: %w", err) + } - return utils.NewToolResultText(string(r)), nil, nil - } + return utils.NewToolResultText(string(r)), nil, nil + } + }, + ) } // GetWorkflowRun creates a tool to get details of a specific workflow run -func GetWorkflowRun(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { - return mcp.Tool{ +func GetWorkflowRun(t translations.TranslationHelperFunc) toolsets.ServerTool { + return NewTool( + mcp.Tool{ Name: "get_workflow_run", Description: t("TOOL_GET_WORKFLOW_RUN_DESCRIPTION", "Get details of a specific workflow run"), Annotations: &mcp.ToolAnnotations{ @@ -373,44 +387,48 @@ func GetWorkflowRun(getClient GetClientFn, t translations.TranslationHelperFunc) Required: []string{"owner", "repo", "run_id"}, }, }, - func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - owner, err := RequiredParam[string](args, "owner") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - repo, err := RequiredParam[string](args, "repo") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - runIDInt, err := RequiredInt(args, "run_id") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - runID := int64(runIDInt) + func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { + return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + client, err := deps.GetClient(ctx) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil + } - client, err := getClient(ctx) - if err != nil { - return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) - } + owner, err := RequiredParam[string](args, "owner") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + repo, err := RequiredParam[string](args, "repo") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + runIDInt, err := RequiredInt(args, "run_id") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + runID := int64(runIDInt) - workflowRun, resp, err := client.Actions.GetWorkflowRunByID(ctx, owner, repo, runID) - if err != nil { - return nil, nil, fmt.Errorf("failed to get workflow run: %w", err) - } - defer func() { _ = resp.Body.Close() }() + workflowRun, resp, err := client.Actions.GetWorkflowRunByID(ctx, owner, repo, runID) + if err != nil { + return nil, nil, fmt.Errorf("failed to get workflow run: %w", err) + } + defer func() { _ = resp.Body.Close() }() - r, err := json.Marshal(workflowRun) - if err != nil { - return nil, nil, fmt.Errorf("failed to marshal response: %w", err) - } + r, err := json.Marshal(workflowRun) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal response: %w", err) + } - return utils.NewToolResultText(string(r)), nil, nil - } + return utils.NewToolResultText(string(r)), nil, nil + } + }, + ) } // GetWorkflowRunLogs creates a tool to download logs for a specific workflow run -func GetWorkflowRunLogs(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { - return mcp.Tool{ +func GetWorkflowRunLogs(t translations.TranslationHelperFunc) toolsets.ServerTool { + return NewTool( + mcp.Tool{ Name: "get_workflow_run_logs", Description: t("TOOL_GET_WORKFLOW_RUN_LOGS_DESCRIPTION", "Download logs for a specific workflow run (EXPENSIVE: downloads ALL logs as ZIP. Consider using get_job_logs with failed_only=true for debugging failed jobs)"), Annotations: &mcp.ToolAnnotations{ @@ -436,54 +454,58 @@ func GetWorkflowRunLogs(getClient GetClientFn, t translations.TranslationHelperF Required: []string{"owner", "repo", "run_id"}, }, }, - func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - owner, err := RequiredParam[string](args, "owner") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - repo, err := RequiredParam[string](args, "repo") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - runIDInt, err := RequiredInt(args, "run_id") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - runID := int64(runIDInt) + func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { + return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + client, err := deps.GetClient(ctx) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil + } - client, err := getClient(ctx) - if err != nil { - return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) - } + owner, err := RequiredParam[string](args, "owner") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + repo, err := RequiredParam[string](args, "repo") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + runIDInt, err := RequiredInt(args, "run_id") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + runID := int64(runIDInt) - // Get the download URL for the logs - url, resp, err := client.Actions.GetWorkflowRunLogs(ctx, owner, repo, runID, 1) - if err != nil { - return nil, nil, fmt.Errorf("failed to get workflow run logs: %w", err) - } - defer func() { _ = resp.Body.Close() }() - - // Create response with the logs URL and information - result := map[string]any{ - "logs_url": url.String(), - "message": "Workflow run logs are available for download", - "note": "The logs_url provides a download link for the complete workflow run logs as a ZIP archive. You can download this archive to extract and examine individual job logs.", - "warning": "This downloads ALL logs as a ZIP file which can be large and expensive. For debugging failed jobs, consider using get_job_logs with failed_only=true and run_id instead.", - "optimization_tip": "Use: get_job_logs with parameters {run_id: " + fmt.Sprintf("%d", runID) + ", failed_only: true} for more efficient failed job debugging", - } + // Get the download URL for the logs + url, resp, err := client.Actions.GetWorkflowRunLogs(ctx, owner, repo, runID, 1) + if err != nil { + return nil, nil, fmt.Errorf("failed to get workflow run logs: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + // Create response with the logs URL and information + result := map[string]any{ + "logs_url": url.String(), + "message": "Workflow run logs are available for download", + "note": "The logs_url provides a download link for the complete workflow run logs as a ZIP archive. You can download this archive to extract and examine individual job logs.", + "warning": "This downloads ALL logs as a ZIP file which can be large and expensive. For debugging failed jobs, consider using get_job_logs with failed_only=true and run_id instead.", + "optimization_tip": "Use: get_job_logs with parameters {run_id: " + fmt.Sprintf("%d", runID) + ", failed_only: true} for more efficient failed job debugging", + } - r, err := json.Marshal(result) - if err != nil { - return nil, nil, fmt.Errorf("failed to marshal response: %w", err) - } + r, err := json.Marshal(result) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal response: %w", err) + } - return utils.NewToolResultText(string(r)), nil, nil - } + return utils.NewToolResultText(string(r)), nil, nil + } + }, + ) } // ListWorkflowJobs creates a tool to list jobs for a specific workflow run -func ListWorkflowJobs(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { - return mcp.Tool{ +func ListWorkflowJobs(t translations.TranslationHelperFunc) toolsets.ServerTool { + return NewTool( + mcp.Tool{ Name: "list_workflow_jobs", Description: t("TOOL_LIST_WORKFLOW_JOBS_DESCRIPTION", "List jobs for a specific workflow run"), Annotations: &mcp.ToolAnnotations{ @@ -514,71 +536,75 @@ func ListWorkflowJobs(getClient GetClientFn, t translations.TranslationHelperFun Required: []string{"owner", "repo", "run_id"}, }), }, - func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - owner, err := RequiredParam[string](args, "owner") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - repo, err := RequiredParam[string](args, "repo") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - runIDInt, err := RequiredInt(args, "run_id") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - runID := int64(runIDInt) + func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { + return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + client, err := deps.GetClient(ctx) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil + } - // Get optional filtering parameters - filter, err := OptionalParam[string](args, "filter") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + owner, err := RequiredParam[string](args, "owner") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + repo, err := RequiredParam[string](args, "repo") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + runIDInt, err := RequiredInt(args, "run_id") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + runID := int64(runIDInt) - // Get optional pagination parameters - pagination, err := OptionalPaginationParams(args) - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + // Get optional filtering parameters + filter, err := OptionalParam[string](args, "filter") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - client, err := getClient(ctx) - if err != nil { - return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) - } + // Get optional pagination parameters + pagination, err := OptionalPaginationParams(args) + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - // Set up list options - opts := &github.ListWorkflowJobsOptions{ - Filter: filter, - ListOptions: github.ListOptions{ - PerPage: pagination.PerPage, - Page: pagination.Page, - }, - } + // Set up list options + opts := &github.ListWorkflowJobsOptions{ + Filter: filter, + ListOptions: github.ListOptions{ + PerPage: pagination.PerPage, + Page: pagination.Page, + }, + } - jobs, resp, err := client.Actions.ListWorkflowJobs(ctx, owner, repo, runID, opts) - if err != nil { - return nil, nil, fmt.Errorf("failed to list workflow jobs: %w", err) - } - defer func() { _ = resp.Body.Close() }() + jobs, resp, err := client.Actions.ListWorkflowJobs(ctx, owner, repo, runID, opts) + if err != nil { + return nil, nil, fmt.Errorf("failed to list workflow jobs: %w", err) + } + defer func() { _ = resp.Body.Close() }() - // Add optimization tip for failed job debugging - response := map[string]any{ - "jobs": jobs, - "optimization_tip": "For debugging failed jobs, consider using get_job_logs with failed_only=true and run_id=" + fmt.Sprintf("%d", runID) + " to get logs directly without needing to list jobs first", - } + // Add optimization tip for failed job debugging + response := map[string]any{ + "jobs": jobs, + "optimization_tip": "For debugging failed jobs, consider using get_job_logs with failed_only=true and run_id=" + fmt.Sprintf("%d", runID) + " to get logs directly without needing to list jobs first", + } - r, err := json.Marshal(response) - if err != nil { - return nil, nil, fmt.Errorf("failed to marshal response: %w", err) - } + r, err := json.Marshal(response) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal response: %w", err) + } - return utils.NewToolResultText(string(r)), nil, nil - } + return utils.NewToolResultText(string(r)), nil, nil + } + }, + ) } // GetJobLogs creates a tool to download logs for a specific workflow job or efficiently get all failed job logs for a workflow run -func GetJobLogs(getClient GetClientFn, t translations.TranslationHelperFunc, contentWindowSize int) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { - return mcp.Tool{ +func GetJobLogs(t translations.TranslationHelperFunc) toolsets.ServerTool { + return NewTool( + mcp.Tool{ Name: "get_job_logs", Description: t("TOOL_GET_JOB_LOGS_DESCRIPTION", "Download logs for a specific workflow job or efficiently get all failed job logs for a workflow run"), Annotations: &mcp.ToolAnnotations{ @@ -621,65 +647,68 @@ func GetJobLogs(getClient GetClientFn, t translations.TranslationHelperFunc, con Required: []string{"owner", "repo"}, }, }, - func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - owner, err := RequiredParam[string](args, "owner") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - repo, err := RequiredParam[string](args, "repo") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { + return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + client, err := deps.GetClient(ctx) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil + } - // Get optional parameters - jobID, err := OptionalIntParam(args, "job_id") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - runID, err := OptionalIntParam(args, "run_id") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - failedOnly, err := OptionalParam[bool](args, "failed_only") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - returnContent, err := OptionalParam[bool](args, "return_content") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - tailLines, err := OptionalIntParam(args, "tail_lines") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - // Default to 500 lines if not specified - if tailLines == 0 { - tailLines = 500 - } + owner, err := RequiredParam[string](args, "owner") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + repo, err := RequiredParam[string](args, "repo") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - client, err := getClient(ctx) - if err != nil { - return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) - } + // Get optional parameters + jobID, err := OptionalIntParam(args, "job_id") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + runID, err := OptionalIntParam(args, "run_id") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + failedOnly, err := OptionalParam[bool](args, "failed_only") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + returnContent, err := OptionalParam[bool](args, "return_content") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + tailLines, err := OptionalIntParam(args, "tail_lines") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + // Default to 500 lines if not specified + if tailLines == 0 { + tailLines = 500 + } - // Validate parameters - if failedOnly && runID == 0 { - return utils.NewToolResultError("run_id is required when failed_only is true"), nil, nil - } - if !failedOnly && jobID == 0 { - return utils.NewToolResultError("job_id is required when failed_only is false"), nil, nil - } + // Validate parameters + if failedOnly && runID == 0 { + return utils.NewToolResultError("run_id is required when failed_only is true"), nil, nil + } + if !failedOnly && jobID == 0 { + return utils.NewToolResultError("job_id is required when failed_only is false"), nil, nil + } - if failedOnly && runID > 0 { - // Handle failed-only mode: get logs for all failed jobs in the workflow run - return handleFailedJobLogs(ctx, client, owner, repo, int64(runID), returnContent, tailLines, contentWindowSize) - } else if jobID > 0 { - // Handle single job mode - return handleSingleJobLogs(ctx, client, owner, repo, int64(jobID), returnContent, tailLines, contentWindowSize) - } + if failedOnly && runID > 0 { + // Handle failed-only mode: get logs for all failed jobs in the workflow run + return handleFailedJobLogs(ctx, client, owner, repo, int64(runID), returnContent, tailLines, deps.ContentWindowSize) + } else if jobID > 0 { + // Handle single job mode + return handleSingleJobLogs(ctx, client, owner, repo, int64(jobID), returnContent, tailLines, deps.ContentWindowSize) + } - return utils.NewToolResultError("Either job_id must be provided for single job logs, or run_id with failed_only=true for failed job logs"), nil, nil - } + return utils.NewToolResultError("Either job_id must be provided for single job logs, or run_id with failed_only=true for failed job logs"), nil, nil + } + }, + ) } // handleFailedJobLogs gets logs for all failed jobs in a workflow run @@ -837,8 +866,9 @@ func downloadLogContent(ctx context.Context, logURL string, tailLines int, maxLi } // RerunWorkflowRun creates a tool to re-run an entire workflow run -func RerunWorkflowRun(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { - return mcp.Tool{ +func RerunWorkflowRun(t translations.TranslationHelperFunc) toolsets.ServerTool { + return NewTool( + mcp.Tool{ Name: "rerun_workflow_run", Description: t("TOOL_RERUN_WORKFLOW_RUN_DESCRIPTION", "Re-run an entire workflow run"), Annotations: &mcp.ToolAnnotations{ @@ -864,51 +894,55 @@ func RerunWorkflowRun(getClient GetClientFn, t translations.TranslationHelperFun Required: []string{"owner", "repo", "run_id"}, }, }, - func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - owner, err := RequiredParam[string](args, "owner") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - repo, err := RequiredParam[string](args, "repo") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - runIDInt, err := RequiredInt(args, "run_id") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - runID := int64(runIDInt) + func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { + return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + client, err := deps.GetClient(ctx) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil + } - client, err := getClient(ctx) - if err != nil { - return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) - } + owner, err := RequiredParam[string](args, "owner") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + repo, err := RequiredParam[string](args, "repo") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + runIDInt, err := RequiredInt(args, "run_id") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + runID := int64(runIDInt) - resp, err := client.Actions.RerunWorkflowByID(ctx, owner, repo, runID) - if err != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, "failed to rerun workflow run", resp, err), nil, nil - } - defer func() { _ = resp.Body.Close() }() + resp, err := client.Actions.RerunWorkflowByID(ctx, owner, repo, runID) + if err != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, "failed to rerun workflow run", resp, err), nil, nil + } + defer func() { _ = resp.Body.Close() }() - result := map[string]any{ - "message": "Workflow run has been queued for re-run", - "run_id": runID, - "status": resp.Status, - "status_code": resp.StatusCode, - } + result := map[string]any{ + "message": "Workflow run has been queued for re-run", + "run_id": runID, + "status": resp.Status, + "status_code": resp.StatusCode, + } - r, err := json.Marshal(result) - if err != nil { - return nil, nil, fmt.Errorf("failed to marshal response: %w", err) - } + r, err := json.Marshal(result) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal response: %w", err) + } - return utils.NewToolResultText(string(r)), nil, nil - } + return utils.NewToolResultText(string(r)), nil, nil + } + }, + ) } // RerunFailedJobs creates a tool to re-run only the failed jobs in a workflow run -func RerunFailedJobs(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { - return mcp.Tool{ +func RerunFailedJobs(t translations.TranslationHelperFunc) toolsets.ServerTool { + return NewTool( + mcp.Tool{ Name: "rerun_failed_jobs", Description: t("TOOL_RERUN_FAILED_JOBS_DESCRIPTION", "Re-run only the failed jobs in a workflow run"), Annotations: &mcp.ToolAnnotations{ @@ -934,51 +968,55 @@ func RerunFailedJobs(getClient GetClientFn, t translations.TranslationHelperFunc Required: []string{"owner", "repo", "run_id"}, }, }, - func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - owner, err := RequiredParam[string](args, "owner") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - repo, err := RequiredParam[string](args, "repo") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - runIDInt, err := RequiredInt(args, "run_id") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - runID := int64(runIDInt) + func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { + return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + client, err := deps.GetClient(ctx) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil + } - client, err := getClient(ctx) - if err != nil { - return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) - } + owner, err := RequiredParam[string](args, "owner") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + repo, err := RequiredParam[string](args, "repo") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + runIDInt, err := RequiredInt(args, "run_id") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + runID := int64(runIDInt) - resp, err := client.Actions.RerunFailedJobsByID(ctx, owner, repo, runID) - if err != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, "failed to rerun failed jobs", resp, err), nil, nil - } - defer func() { _ = resp.Body.Close() }() + resp, err := client.Actions.RerunFailedJobsByID(ctx, owner, repo, runID) + if err != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, "failed to rerun failed jobs", resp, err), nil, nil + } + defer func() { _ = resp.Body.Close() }() - result := map[string]any{ - "message": "Failed jobs have been queued for re-run", - "run_id": runID, - "status": resp.Status, - "status_code": resp.StatusCode, - } + result := map[string]any{ + "message": "Failed jobs have been queued for re-run", + "run_id": runID, + "status": resp.Status, + "status_code": resp.StatusCode, + } - r, err := json.Marshal(result) - if err != nil { - return nil, nil, fmt.Errorf("failed to marshal response: %w", err) - } + r, err := json.Marshal(result) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal response: %w", err) + } - return utils.NewToolResultText(string(r)), nil, nil - } + return utils.NewToolResultText(string(r)), nil, nil + } + }, + ) } // CancelWorkflowRun creates a tool to cancel a workflow run -func CancelWorkflowRun(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { - return mcp.Tool{ +func CancelWorkflowRun(t translations.TranslationHelperFunc) toolsets.ServerTool { + return NewTool( + mcp.Tool{ Name: "cancel_workflow_run", Description: t("TOOL_CANCEL_WORKFLOW_RUN_DESCRIPTION", "Cancel a workflow run"), Annotations: &mcp.ToolAnnotations{ @@ -1004,53 +1042,57 @@ func CancelWorkflowRun(getClient GetClientFn, t translations.TranslationHelperFu Required: []string{"owner", "repo", "run_id"}, }, }, - func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - owner, err := RequiredParam[string](args, "owner") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - repo, err := RequiredParam[string](args, "repo") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - runIDInt, err := RequiredInt(args, "run_id") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - runID := int64(runIDInt) + func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { + return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + client, err := deps.GetClient(ctx) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil + } - client, err := getClient(ctx) - if err != nil { - return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) - } + owner, err := RequiredParam[string](args, "owner") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + repo, err := RequiredParam[string](args, "repo") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + runIDInt, err := RequiredInt(args, "run_id") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + runID := int64(runIDInt) - resp, err := client.Actions.CancelWorkflowRunByID(ctx, owner, repo, runID) - if err != nil { - if _, ok := err.(*github.AcceptedError); !ok { - return ghErrors.NewGitHubAPIErrorResponse(ctx, "failed to cancel workflow run", resp, err), nil, nil + resp, err := client.Actions.CancelWorkflowRunByID(ctx, owner, repo, runID) + if err != nil { + if _, ok := err.(*github.AcceptedError); !ok { + return ghErrors.NewGitHubAPIErrorResponse(ctx, "failed to cancel workflow run", resp, err), nil, nil + } } - } - defer func() { _ = resp.Body.Close() }() + defer func() { _ = resp.Body.Close() }() - result := map[string]any{ - "message": "Workflow run has been cancelled", - "run_id": runID, - "status": resp.Status, - "status_code": resp.StatusCode, - } + result := map[string]any{ + "message": "Workflow run has been cancelled", + "run_id": runID, + "status": resp.Status, + "status_code": resp.StatusCode, + } - r, err := json.Marshal(result) - if err != nil { - return nil, nil, fmt.Errorf("failed to marshal response: %w", err) - } + r, err := json.Marshal(result) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal response: %w", err) + } - return utils.NewToolResultText(string(r)), nil, nil - } + return utils.NewToolResultText(string(r)), nil, nil + } + }, + ) } // ListWorkflowRunArtifacts creates a tool to list artifacts for a workflow run -func ListWorkflowRunArtifacts(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { - return mcp.Tool{ +func ListWorkflowRunArtifacts(t translations.TranslationHelperFunc) toolsets.ServerTool { + return NewTool( + mcp.Tool{ Name: "list_workflow_run_artifacts", Description: t("TOOL_LIST_WORKFLOW_RUN_ARTIFACTS_DESCRIPTION", "List artifacts for a workflow run"), Annotations: &mcp.ToolAnnotations{ @@ -1076,56 +1118,60 @@ func ListWorkflowRunArtifacts(getClient GetClientFn, t translations.TranslationH Required: []string{"owner", "repo", "run_id"}, }), }, - func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - owner, err := RequiredParam[string](args, "owner") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - repo, err := RequiredParam[string](args, "repo") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - runIDInt, err := RequiredInt(args, "run_id") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - runID := int64(runIDInt) + func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { + return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + client, err := deps.GetClient(ctx) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil + } - // Get optional pagination parameters - pagination, err := OptionalPaginationParams(args) - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + owner, err := RequiredParam[string](args, "owner") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + repo, err := RequiredParam[string](args, "repo") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + runIDInt, err := RequiredInt(args, "run_id") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + runID := int64(runIDInt) - client, err := getClient(ctx) - if err != nil { - return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) - } + // Get optional pagination parameters + pagination, err := OptionalPaginationParams(args) + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - // Set up list options - opts := &github.ListOptions{ - PerPage: pagination.PerPage, - Page: pagination.Page, - } + // Set up list options + opts := &github.ListOptions{ + PerPage: pagination.PerPage, + Page: pagination.Page, + } - artifacts, resp, err := client.Actions.ListWorkflowRunArtifacts(ctx, owner, repo, runID, opts) - if err != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, "failed to list workflow run artifacts", resp, err), nil, nil - } - defer func() { _ = resp.Body.Close() }() + artifacts, resp, err := client.Actions.ListWorkflowRunArtifacts(ctx, owner, repo, runID, opts) + if err != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, "failed to list workflow run artifacts", resp, err), nil, nil + } + defer func() { _ = resp.Body.Close() }() - r, err := json.Marshal(artifacts) - if err != nil { - return nil, nil, fmt.Errorf("failed to marshal response: %w", err) - } + r, err := json.Marshal(artifacts) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal response: %w", err) + } - return utils.NewToolResultText(string(r)), nil, nil - } + return utils.NewToolResultText(string(r)), nil, nil + } + }, + ) } // DownloadWorkflowRunArtifact creates a tool to download a workflow run artifact -func DownloadWorkflowRunArtifact(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { - return mcp.Tool{ +func DownloadWorkflowRunArtifact(t translations.TranslationHelperFunc) toolsets.ServerTool { + return NewTool( + mcp.Tool{ Name: "download_workflow_run_artifact", Description: t("TOOL_DOWNLOAD_WORKFLOW_RUN_ARTIFACT_DESCRIPTION", "Get download URL for a workflow run artifact"), Annotations: &mcp.ToolAnnotations{ @@ -1151,53 +1197,57 @@ func DownloadWorkflowRunArtifact(getClient GetClientFn, t translations.Translati Required: []string{"owner", "repo", "artifact_id"}, }, }, - func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - owner, err := RequiredParam[string](args, "owner") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - repo, err := RequiredParam[string](args, "repo") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - artifactIDInt, err := RequiredInt(args, "artifact_id") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - artifactID := int64(artifactIDInt) + func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { + return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + client, err := deps.GetClient(ctx) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil + } - client, err := getClient(ctx) - if err != nil { - return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) - } + owner, err := RequiredParam[string](args, "owner") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + repo, err := RequiredParam[string](args, "repo") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + artifactIDInt, err := RequiredInt(args, "artifact_id") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + artifactID := int64(artifactIDInt) - // Get the download URL for the artifact - url, resp, err := client.Actions.DownloadArtifact(ctx, owner, repo, artifactID, 1) - if err != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, "failed to get artifact download URL", resp, err), nil, nil - } - defer func() { _ = resp.Body.Close() }() - - // Create response with the download URL and information - result := map[string]any{ - "download_url": url.String(), - "message": "Artifact is available for download", - "note": "The download_url provides a download link for the artifact as a ZIP archive. The link is temporary and expires after a short time.", - "artifact_id": artifactID, - } + // Get the download URL for the artifact + url, resp, err := client.Actions.DownloadArtifact(ctx, owner, repo, artifactID, 1) + if err != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, "failed to get artifact download URL", resp, err), nil, nil + } + defer func() { _ = resp.Body.Close() }() + + // Create response with the download URL and information + result := map[string]any{ + "download_url": url.String(), + "message": "Artifact is available for download", + "note": "The download_url provides a download link for the artifact as a ZIP archive. The link is temporary and expires after a short time.", + "artifact_id": artifactID, + } - r, err := json.Marshal(result) - if err != nil { - return nil, nil, fmt.Errorf("failed to marshal response: %w", err) - } + r, err := json.Marshal(result) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal response: %w", err) + } - return utils.NewToolResultText(string(r)), nil, nil - } + return utils.NewToolResultText(string(r)), nil, nil + } + }, + ) } // DeleteWorkflowRunLogs creates a tool to delete logs for a workflow run -func DeleteWorkflowRunLogs(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { - return mcp.Tool{ +func DeleteWorkflowRunLogs(t translations.TranslationHelperFunc) toolsets.ServerTool { + return NewTool( + mcp.Tool{ Name: "delete_workflow_run_logs", Description: t("TOOL_DELETE_WORKFLOW_RUN_LOGS_DESCRIPTION", "Delete logs for a workflow run"), Annotations: &mcp.ToolAnnotations{ @@ -1224,51 +1274,55 @@ func DeleteWorkflowRunLogs(getClient GetClientFn, t translations.TranslationHelp Required: []string{"owner", "repo", "run_id"}, }, }, - func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - owner, err := RequiredParam[string](args, "owner") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - repo, err := RequiredParam[string](args, "repo") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - runIDInt, err := RequiredInt(args, "run_id") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - runID := int64(runIDInt) + func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { + return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + client, err := deps.GetClient(ctx) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil + } - client, err := getClient(ctx) - if err != nil { - return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) - } + owner, err := RequiredParam[string](args, "owner") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + repo, err := RequiredParam[string](args, "repo") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + runIDInt, err := RequiredInt(args, "run_id") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + runID := int64(runIDInt) - resp, err := client.Actions.DeleteWorkflowRunLogs(ctx, owner, repo, runID) - if err != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, "failed to delete workflow run logs", resp, err), nil, nil - } - defer func() { _ = resp.Body.Close() }() + resp, err := client.Actions.DeleteWorkflowRunLogs(ctx, owner, repo, runID) + if err != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, "failed to delete workflow run logs", resp, err), nil, nil + } + defer func() { _ = resp.Body.Close() }() - result := map[string]any{ - "message": "Workflow run logs have been deleted", - "run_id": runID, - "status": resp.Status, - "status_code": resp.StatusCode, - } + result := map[string]any{ + "message": "Workflow run logs have been deleted", + "run_id": runID, + "status": resp.Status, + "status_code": resp.StatusCode, + } - r, err := json.Marshal(result) - if err != nil { - return nil, nil, fmt.Errorf("failed to marshal response: %w", err) - } + r, err := json.Marshal(result) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal response: %w", err) + } - return utils.NewToolResultText(string(r)), nil, nil - } + return utils.NewToolResultText(string(r)), nil, nil + } + }, + ) } // GetWorkflowRunUsage creates a tool to get usage metrics for a workflow run -func GetWorkflowRunUsage(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { - return mcp.Tool{ +func GetWorkflowRunUsage(t translations.TranslationHelperFunc) toolsets.ServerTool { + return NewTool( + mcp.Tool{ Name: "get_workflow_run_usage", Description: t("TOOL_GET_WORKFLOW_RUN_USAGE_DESCRIPTION", "Get usage metrics for a workflow run"), Annotations: &mcp.ToolAnnotations{ @@ -1294,37 +1348,40 @@ func GetWorkflowRunUsage(getClient GetClientFn, t translations.TranslationHelper Required: []string{"owner", "repo", "run_id"}, }, }, - func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - owner, err := RequiredParam[string](args, "owner") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - repo, err := RequiredParam[string](args, "repo") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - runIDInt, err := RequiredInt(args, "run_id") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - runID := int64(runIDInt) + func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { + return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + client, err := deps.GetClient(ctx) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil + } - client, err := getClient(ctx) - if err != nil { - return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) - } + owner, err := RequiredParam[string](args, "owner") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + repo, err := RequiredParam[string](args, "repo") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + runIDInt, err := RequiredInt(args, "run_id") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + runID := int64(runIDInt) - usage, resp, err := client.Actions.GetWorkflowRunUsageByID(ctx, owner, repo, runID) - if err != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, "failed to get workflow run usage", resp, err), nil, nil - } - defer func() { _ = resp.Body.Close() }() + usage, resp, err := client.Actions.GetWorkflowRunUsageByID(ctx, owner, repo, runID) + if err != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, "failed to get workflow run usage", resp, err), nil, nil + } + defer func() { _ = resp.Body.Close() }() - r, err := json.Marshal(usage) - if err != nil { - return nil, nil, fmt.Errorf("failed to marshal response: %w", err) - } + r, err := json.Marshal(usage) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal response: %w", err) + } - return utils.NewToolResultText(string(r)), nil, nil - } + return utils.NewToolResultText(string(r)), nil, nil + } + }, + ) } diff --git a/pkg/github/actions_test.go b/pkg/github/actions_test.go index 6d9921f2e..09ab3b2cc 100644 --- a/pkg/github/actions_test.go +++ b/pkg/github/actions_test.go @@ -25,13 +25,12 @@ import ( func Test_ListWorkflows(t *testing.T) { // Verify tool definition once - mockClient := github.NewClient(nil) - tool, _ := ListWorkflows(stubGetClientFn(mockClient), translations.NullTranslationHelper) - require.NoError(t, toolsnaps.Test(tool.Name, tool)) + toolDef := ListWorkflows(translations.NullTranslationHelper) + require.NoError(t, toolsnaps.Test(toolDef.Tool.Name, toolDef.Tool)) - assert.Equal(t, "list_workflows", tool.Name) - assert.NotEmpty(t, tool.Description) - inputSchema := tool.InputSchema.(*jsonschema.Schema) + assert.Equal(t, "list_workflows", toolDef.Tool.Name) + assert.NotEmpty(t, toolDef.Tool.Description) + inputSchema := toolDef.Tool.InputSchema.(*jsonschema.Schema) assert.Contains(t, inputSchema.Properties, "owner") assert.Contains(t, inputSchema.Properties, "repo") assert.Contains(t, inputSchema.Properties, "perPage") @@ -106,13 +105,16 @@ func Test_ListWorkflows(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := ListWorkflows(stubGetClientFn(client), translations.NullTranslationHelper) + deps := ToolDependencies{ + GetClient: stubGetClientFn(client), + } + handler := toolDef.Handler(deps) // Create call request request := createMCPRequest(tc.requestArgs) // Call handler - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(context.Background(), &request) require.NoError(t, err) require.Equal(t, tc.expectError, result.IsError) @@ -138,18 +140,17 @@ func Test_ListWorkflows(t *testing.T) { func Test_RunWorkflow(t *testing.T) { // Verify tool definition once - mockClient := github.NewClient(nil) - tool, _ := RunWorkflow(stubGetClientFn(mockClient), translations.NullTranslationHelper) - require.NoError(t, toolsnaps.Test(tool.Name, tool)) - - assert.Equal(t, "run_workflow", tool.Name) - assert.NotEmpty(t, tool.Description) - assert.Contains(t, tool.InputSchema.(*jsonschema.Schema).Properties, "owner") - assert.Contains(t, tool.InputSchema.(*jsonschema.Schema).Properties, "repo") - assert.Contains(t, tool.InputSchema.(*jsonschema.Schema).Properties, "workflow_id") - assert.Contains(t, tool.InputSchema.(*jsonschema.Schema).Properties, "ref") - assert.Contains(t, tool.InputSchema.(*jsonschema.Schema).Properties, "inputs") - assert.ElementsMatch(t, tool.InputSchema.(*jsonschema.Schema).Required, []string{"owner", "repo", "workflow_id", "ref"}) + toolDef := RunWorkflow(translations.NullTranslationHelper) + require.NoError(t, toolsnaps.Test(toolDef.Tool.Name, toolDef.Tool)) + + assert.Equal(t, "run_workflow", toolDef.Tool.Name) + assert.NotEmpty(t, toolDef.Tool.Description) + assert.Contains(t, toolDef.Tool.InputSchema.(*jsonschema.Schema).Properties, "owner") + assert.Contains(t, toolDef.Tool.InputSchema.(*jsonschema.Schema).Properties, "repo") + assert.Contains(t, toolDef.Tool.InputSchema.(*jsonschema.Schema).Properties, "workflow_id") + assert.Contains(t, toolDef.Tool.InputSchema.(*jsonschema.Schema).Properties, "ref") + assert.Contains(t, toolDef.Tool.InputSchema.(*jsonschema.Schema).Properties, "inputs") + assert.ElementsMatch(t, toolDef.Tool.InputSchema.(*jsonschema.Schema).Required, []string{"owner", "repo", "workflow_id", "ref"}) tests := []struct { name string @@ -193,13 +194,16 @@ func Test_RunWorkflow(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := RunWorkflow(stubGetClientFn(client), translations.NullTranslationHelper) + deps := ToolDependencies{ + GetClient: stubGetClientFn(client), + } + handler := toolDef.Handler(deps) // Create call request request := createMCPRequest(tc.requestArgs) // Call handler - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(context.Background(), &request) require.NoError(t, err) require.Equal(t, tc.expectError, result.IsError) @@ -224,6 +228,8 @@ func Test_RunWorkflow(t *testing.T) { func Test_RunWorkflow_WithFilename(t *testing.T) { // Test the unified RunWorkflow function with filenames + toolDef := RunWorkflow(translations.NullTranslationHelper) + tests := []struct { name string mockedClient *http.Client @@ -284,13 +290,16 @@ func Test_RunWorkflow_WithFilename(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := RunWorkflow(stubGetClientFn(client), translations.NullTranslationHelper) + deps := ToolDependencies{ + GetClient: stubGetClientFn(client), + } + handler := toolDef.Handler(deps) // Create call request request := createMCPRequest(tc.requestArgs) // Call handler - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(context.Background(), &request) require.NoError(t, err) require.Equal(t, tc.expectError, result.IsError) @@ -315,16 +324,15 @@ func Test_RunWorkflow_WithFilename(t *testing.T) { func Test_CancelWorkflowRun(t *testing.T) { // Verify tool definition once - mockClient := github.NewClient(nil) - tool, _ := CancelWorkflowRun(stubGetClientFn(mockClient), translations.NullTranslationHelper) - require.NoError(t, toolsnaps.Test(tool.Name, tool)) + toolDef := CancelWorkflowRun(translations.NullTranslationHelper) + require.NoError(t, toolsnaps.Test(toolDef.Tool.Name, toolDef.Tool)) - assert.Equal(t, "cancel_workflow_run", tool.Name) - assert.NotEmpty(t, tool.Description) - assert.Contains(t, tool.InputSchema.(*jsonschema.Schema).Properties, "owner") - assert.Contains(t, tool.InputSchema.(*jsonschema.Schema).Properties, "repo") - assert.Contains(t, tool.InputSchema.(*jsonschema.Schema).Properties, "run_id") - assert.ElementsMatch(t, tool.InputSchema.(*jsonschema.Schema).Required, []string{"owner", "repo", "run_id"}) + assert.Equal(t, "cancel_workflow_run", toolDef.Tool.Name) + assert.NotEmpty(t, toolDef.Tool.Description) + assert.Contains(t, toolDef.Tool.InputSchema.(*jsonschema.Schema).Properties, "owner") + assert.Contains(t, toolDef.Tool.InputSchema.(*jsonschema.Schema).Properties, "repo") + assert.Contains(t, toolDef.Tool.InputSchema.(*jsonschema.Schema).Properties, "run_id") + assert.ElementsMatch(t, toolDef.Tool.InputSchema.(*jsonschema.Schema).Required, []string{"owner", "repo", "run_id"}) tests := []struct { name string @@ -390,13 +398,16 @@ func Test_CancelWorkflowRun(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := CancelWorkflowRun(stubGetClientFn(client), translations.NullTranslationHelper) + deps := ToolDependencies{ + GetClient: stubGetClientFn(client), + } + handler := toolDef.Handler(deps) // Create call request request := createMCPRequest(tc.requestArgs) // Call handler - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(context.Background(), &request) require.NoError(t, err) require.Equal(t, tc.expectError, result.IsError) @@ -421,18 +432,17 @@ func Test_CancelWorkflowRun(t *testing.T) { func Test_ListWorkflowRunArtifacts(t *testing.T) { // Verify tool definition once - mockClient := github.NewClient(nil) - tool, _ := ListWorkflowRunArtifacts(stubGetClientFn(mockClient), translations.NullTranslationHelper) - require.NoError(t, toolsnaps.Test(tool.Name, tool)) - - assert.Equal(t, "list_workflow_run_artifacts", tool.Name) - assert.NotEmpty(t, tool.Description) - assert.Contains(t, tool.InputSchema.(*jsonschema.Schema).Properties, "owner") - assert.Contains(t, tool.InputSchema.(*jsonschema.Schema).Properties, "repo") - assert.Contains(t, tool.InputSchema.(*jsonschema.Schema).Properties, "run_id") - assert.Contains(t, tool.InputSchema.(*jsonschema.Schema).Properties, "perPage") - assert.Contains(t, tool.InputSchema.(*jsonschema.Schema).Properties, "page") - assert.ElementsMatch(t, tool.InputSchema.(*jsonschema.Schema).Required, []string{"owner", "repo", "run_id"}) + toolDef := ListWorkflowRunArtifacts(translations.NullTranslationHelper) + require.NoError(t, toolsnaps.Test(toolDef.Tool.Name, toolDef.Tool)) + + assert.Equal(t, "list_workflow_run_artifacts", toolDef.Tool.Name) + assert.NotEmpty(t, toolDef.Tool.Description) + assert.Contains(t, toolDef.Tool.InputSchema.(*jsonschema.Schema).Properties, "owner") + assert.Contains(t, toolDef.Tool.InputSchema.(*jsonschema.Schema).Properties, "repo") + assert.Contains(t, toolDef.Tool.InputSchema.(*jsonschema.Schema).Properties, "run_id") + assert.Contains(t, toolDef.Tool.InputSchema.(*jsonschema.Schema).Properties, "perPage") + assert.Contains(t, toolDef.Tool.InputSchema.(*jsonschema.Schema).Properties, "page") + assert.ElementsMatch(t, toolDef.Tool.InputSchema.(*jsonschema.Schema).Required, []string{"owner", "repo", "run_id"}) tests := []struct { name string @@ -518,13 +528,16 @@ func Test_ListWorkflowRunArtifacts(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := ListWorkflowRunArtifacts(stubGetClientFn(client), translations.NullTranslationHelper) + deps := ToolDependencies{ + GetClient: stubGetClientFn(client), + } + handler := toolDef.Handler(deps) // Create call request request := createMCPRequest(tc.requestArgs) // Call handler - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(context.Background(), &request) require.NoError(t, err) require.Equal(t, tc.expectError, result.IsError) @@ -550,16 +563,15 @@ func Test_ListWorkflowRunArtifacts(t *testing.T) { func Test_DownloadWorkflowRunArtifact(t *testing.T) { // Verify tool definition once - mockClient := github.NewClient(nil) - tool, _ := DownloadWorkflowRunArtifact(stubGetClientFn(mockClient), translations.NullTranslationHelper) - require.NoError(t, toolsnaps.Test(tool.Name, tool)) + toolDef := DownloadWorkflowRunArtifact(translations.NullTranslationHelper) + require.NoError(t, toolsnaps.Test(toolDef.Tool.Name, toolDef.Tool)) - assert.Equal(t, "download_workflow_run_artifact", tool.Name) - assert.NotEmpty(t, tool.Description) - assert.Contains(t, tool.InputSchema.(*jsonschema.Schema).Properties, "owner") - assert.Contains(t, tool.InputSchema.(*jsonschema.Schema).Properties, "repo") - assert.Contains(t, tool.InputSchema.(*jsonschema.Schema).Properties, "artifact_id") - assert.ElementsMatch(t, tool.InputSchema.(*jsonschema.Schema).Required, []string{"owner", "repo", "artifact_id"}) + assert.Equal(t, "download_workflow_run_artifact", toolDef.Tool.Name) + assert.NotEmpty(t, toolDef.Tool.Description) + assert.Contains(t, toolDef.Tool.InputSchema.(*jsonschema.Schema).Properties, "owner") + assert.Contains(t, toolDef.Tool.InputSchema.(*jsonschema.Schema).Properties, "repo") + assert.Contains(t, toolDef.Tool.InputSchema.(*jsonschema.Schema).Properties, "artifact_id") + assert.ElementsMatch(t, toolDef.Tool.InputSchema.(*jsonschema.Schema).Required, []string{"owner", "repo", "artifact_id"}) tests := []struct { name string @@ -606,13 +618,16 @@ func Test_DownloadWorkflowRunArtifact(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := DownloadWorkflowRunArtifact(stubGetClientFn(client), translations.NullTranslationHelper) + deps := ToolDependencies{ + GetClient: stubGetClientFn(client), + } + handler := toolDef.Handler(deps) // Create call request request := createMCPRequest(tc.requestArgs) // Call handler - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(context.Background(), &request) require.NoError(t, err) require.Equal(t, tc.expectError, result.IsError) @@ -639,16 +654,15 @@ func Test_DownloadWorkflowRunArtifact(t *testing.T) { func Test_DeleteWorkflowRunLogs(t *testing.T) { // Verify tool definition once - mockClient := github.NewClient(nil) - tool, _ := DeleteWorkflowRunLogs(stubGetClientFn(mockClient), translations.NullTranslationHelper) - require.NoError(t, toolsnaps.Test(tool.Name, tool)) + toolDef := DeleteWorkflowRunLogs(translations.NullTranslationHelper) + require.NoError(t, toolsnaps.Test(toolDef.Tool.Name, toolDef.Tool)) - assert.Equal(t, "delete_workflow_run_logs", tool.Name) - assert.NotEmpty(t, tool.Description) - assert.Contains(t, tool.InputSchema.(*jsonschema.Schema).Properties, "owner") - assert.Contains(t, tool.InputSchema.(*jsonschema.Schema).Properties, "repo") - assert.Contains(t, tool.InputSchema.(*jsonschema.Schema).Properties, "run_id") - assert.ElementsMatch(t, tool.InputSchema.(*jsonschema.Schema).Required, []string{"owner", "repo", "run_id"}) + assert.Equal(t, "delete_workflow_run_logs", toolDef.Tool.Name) + assert.NotEmpty(t, toolDef.Tool.Description) + assert.Contains(t, toolDef.Tool.InputSchema.(*jsonschema.Schema).Properties, "owner") + assert.Contains(t, toolDef.Tool.InputSchema.(*jsonschema.Schema).Properties, "repo") + assert.Contains(t, toolDef.Tool.InputSchema.(*jsonschema.Schema).Properties, "run_id") + assert.ElementsMatch(t, toolDef.Tool.InputSchema.(*jsonschema.Schema).Required, []string{"owner", "repo", "run_id"}) tests := []struct { name string @@ -690,13 +704,16 @@ func Test_DeleteWorkflowRunLogs(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := DeleteWorkflowRunLogs(stubGetClientFn(client), translations.NullTranslationHelper) + deps := ToolDependencies{ + GetClient: stubGetClientFn(client), + } + handler := toolDef.Handler(deps) // Create call request request := createMCPRequest(tc.requestArgs) // Call handler - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(context.Background(), &request) require.NoError(t, err) require.Equal(t, tc.expectError, result.IsError) @@ -721,16 +738,15 @@ func Test_DeleteWorkflowRunLogs(t *testing.T) { func Test_GetWorkflowRunUsage(t *testing.T) { // Verify tool definition once - mockClient := github.NewClient(nil) - tool, _ := GetWorkflowRunUsage(stubGetClientFn(mockClient), translations.NullTranslationHelper) - require.NoError(t, toolsnaps.Test(tool.Name, tool)) + toolDef := GetWorkflowRunUsage(translations.NullTranslationHelper) + require.NoError(t, toolsnaps.Test(toolDef.Tool.Name, toolDef.Tool)) - assert.Equal(t, "get_workflow_run_usage", tool.Name) - assert.NotEmpty(t, tool.Description) - assert.Contains(t, tool.InputSchema.(*jsonschema.Schema).Properties, "owner") - assert.Contains(t, tool.InputSchema.(*jsonschema.Schema).Properties, "repo") - assert.Contains(t, tool.InputSchema.(*jsonschema.Schema).Properties, "run_id") - assert.ElementsMatch(t, tool.InputSchema.(*jsonschema.Schema).Required, []string{"owner", "repo", "run_id"}) + assert.Equal(t, "get_workflow_run_usage", toolDef.Tool.Name) + assert.NotEmpty(t, toolDef.Tool.Description) + assert.Contains(t, toolDef.Tool.InputSchema.(*jsonschema.Schema).Properties, "owner") + assert.Contains(t, toolDef.Tool.InputSchema.(*jsonschema.Schema).Properties, "repo") + assert.Contains(t, toolDef.Tool.InputSchema.(*jsonschema.Schema).Properties, "run_id") + assert.ElementsMatch(t, toolDef.Tool.InputSchema.(*jsonschema.Schema).Required, []string{"owner", "repo", "run_id"}) tests := []struct { name string @@ -792,13 +808,16 @@ func Test_GetWorkflowRunUsage(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := GetWorkflowRunUsage(stubGetClientFn(client), translations.NullTranslationHelper) + deps := ToolDependencies{ + GetClient: stubGetClientFn(client), + } + handler := toolDef.Handler(deps) // Create call request request := createMCPRequest(tc.requestArgs) // Call handler - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(context.Background(), &request) require.NoError(t, err) require.Equal(t, tc.expectError, result.IsError) @@ -823,19 +842,18 @@ func Test_GetWorkflowRunUsage(t *testing.T) { func Test_GetJobLogs(t *testing.T) { // Verify tool definition once - mockClient := github.NewClient(nil) - tool, _ := GetJobLogs(stubGetClientFn(mockClient), translations.NullTranslationHelper, 5000) - require.NoError(t, toolsnaps.Test(tool.Name, tool)) - - assert.Equal(t, "get_job_logs", tool.Name) - assert.NotEmpty(t, tool.Description) - assert.Contains(t, tool.InputSchema.(*jsonschema.Schema).Properties, "owner") - assert.Contains(t, tool.InputSchema.(*jsonschema.Schema).Properties, "repo") - assert.Contains(t, tool.InputSchema.(*jsonschema.Schema).Properties, "job_id") - assert.Contains(t, tool.InputSchema.(*jsonschema.Schema).Properties, "run_id") - assert.Contains(t, tool.InputSchema.(*jsonschema.Schema).Properties, "failed_only") - assert.Contains(t, tool.InputSchema.(*jsonschema.Schema).Properties, "return_content") - assert.ElementsMatch(t, tool.InputSchema.(*jsonschema.Schema).Required, []string{"owner", "repo"}) + toolDef := GetJobLogs(translations.NullTranslationHelper) + require.NoError(t, toolsnaps.Test(toolDef.Tool.Name, toolDef.Tool)) + + assert.Equal(t, "get_job_logs", toolDef.Tool.Name) + assert.NotEmpty(t, toolDef.Tool.Description) + assert.Contains(t, toolDef.Tool.InputSchema.(*jsonschema.Schema).Properties, "owner") + assert.Contains(t, toolDef.Tool.InputSchema.(*jsonschema.Schema).Properties, "repo") + assert.Contains(t, toolDef.Tool.InputSchema.(*jsonschema.Schema).Properties, "job_id") + assert.Contains(t, toolDef.Tool.InputSchema.(*jsonschema.Schema).Properties, "run_id") + assert.Contains(t, toolDef.Tool.InputSchema.(*jsonschema.Schema).Properties, "failed_only") + assert.Contains(t, toolDef.Tool.InputSchema.(*jsonschema.Schema).Properties, "return_content") + assert.ElementsMatch(t, toolDef.Tool.InputSchema.(*jsonschema.Schema).Required, []string{"owner", "repo"}) tests := []struct { name string @@ -1054,13 +1072,17 @@ func Test_GetJobLogs(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := GetJobLogs(stubGetClientFn(client), translations.NullTranslationHelper, 5000) + deps := ToolDependencies{ + GetClient: stubGetClientFn(client), + ContentWindowSize: 5000, + } + handler := toolDef.Handler(deps) // Create call request request := createMCPRequest(tc.requestArgs) // Call handler - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(context.Background(), &request) require.NoError(t, err) require.Equal(t, tc.expectError, result.IsError) @@ -1113,7 +1135,12 @@ func Test_GetJobLogs_WithContentReturn(t *testing.T) { ) client := github.NewClient(mockedClient) - _, handler := GetJobLogs(stubGetClientFn(client), translations.NullTranslationHelper, 5000) + toolDef := GetJobLogs(translations.NullTranslationHelper) + deps := ToolDependencies{ + GetClient: stubGetClientFn(client), + ContentWindowSize: 5000, + } + handler := toolDef.Handler(deps) request := createMCPRequest(map[string]any{ "owner": "owner", @@ -1121,14 +1148,8 @@ func Test_GetJobLogs_WithContentReturn(t *testing.T) { "job_id": float64(123), "return_content": true, }) - args := map[string]any{ - "owner": "owner", - "repo": "repo", - "job_id": float64(123), - "return_content": true, - } - result, _, err := handler(context.Background(), &request, args) + result, err := handler(context.Background(), &request) require.NoError(t, err) require.False(t, result.IsError) @@ -1166,7 +1187,12 @@ func Test_GetJobLogs_WithContentReturnAndTailLines(t *testing.T) { ) client := github.NewClient(mockedClient) - _, handler := GetJobLogs(stubGetClientFn(client), translations.NullTranslationHelper, 5000) + toolDef := GetJobLogs(translations.NullTranslationHelper) + deps := ToolDependencies{ + GetClient: stubGetClientFn(client), + ContentWindowSize: 5000, + } + handler := toolDef.Handler(deps) request := createMCPRequest(map[string]any{ "owner": "owner", @@ -1175,15 +1201,8 @@ func Test_GetJobLogs_WithContentReturnAndTailLines(t *testing.T) { "return_content": true, "tail_lines": float64(1), // Requesting last 1 line }) - args := map[string]any{ - "owner": "owner", - "repo": "repo", - "job_id": float64(123), - "return_content": true, - "tail_lines": float64(1), - } - result, _, err := handler(context.Background(), &request, args) + result, err := handler(context.Background(), &request) require.NoError(t, err) require.False(t, result.IsError) @@ -1220,7 +1239,12 @@ func Test_GetJobLogs_WithContentReturnAndLargeTailLines(t *testing.T) { ) client := github.NewClient(mockedClient) - _, handler := GetJobLogs(stubGetClientFn(client), translations.NullTranslationHelper, 5000) + toolDef := GetJobLogs(translations.NullTranslationHelper) + deps := ToolDependencies{ + GetClient: stubGetClientFn(client), + ContentWindowSize: 5000, + } + handler := toolDef.Handler(deps) request := createMCPRequest(map[string]any{ "owner": "owner", @@ -1229,15 +1253,8 @@ func Test_GetJobLogs_WithContentReturnAndLargeTailLines(t *testing.T) { "return_content": true, "tail_lines": float64(100), }) - args := map[string]any{ - "owner": "owner", - "repo": "repo", - "job_id": float64(123), - "return_content": true, - "tail_lines": float64(100), - } - result, _, err := handler(context.Background(), &request, args) + result, err := handler(context.Background(), &request) require.NoError(t, err) require.False(t, result.IsError) @@ -1353,13 +1370,12 @@ func Test_MemoryUsage_SlidingWindow_vs_NoWindow(t *testing.T) { func Test_ListWorkflowRuns(t *testing.T) { // Verify tool definition once - mockClient := github.NewClient(nil) - tool, _ := ListWorkflowRuns(stubGetClientFn(mockClient), translations.NullTranslationHelper) - require.NoError(t, toolsnaps.Test(tool.Name, tool)) + toolDef := ListWorkflowRuns(translations.NullTranslationHelper) + require.NoError(t, toolsnaps.Test(toolDef.Tool.Name, toolDef.Tool)) - assert.Equal(t, "list_workflow_runs", tool.Name) - assert.NotEmpty(t, tool.Description) - inputSchema := tool.InputSchema.(*jsonschema.Schema) + assert.Equal(t, "list_workflow_runs", toolDef.Tool.Name) + assert.NotEmpty(t, toolDef.Tool.Description) + inputSchema := toolDef.Tool.InputSchema.(*jsonschema.Schema) assert.Contains(t, inputSchema.Properties, "owner") assert.Contains(t, inputSchema.Properties, "repo") assert.Contains(t, inputSchema.Properties, "workflow_id") @@ -1368,13 +1384,12 @@ func Test_ListWorkflowRuns(t *testing.T) { func Test_GetWorkflowRun(t *testing.T) { // Verify tool definition once - mockClient := github.NewClient(nil) - tool, _ := GetWorkflowRun(stubGetClientFn(mockClient), translations.NullTranslationHelper) - require.NoError(t, toolsnaps.Test(tool.Name, tool)) + toolDef := GetWorkflowRun(translations.NullTranslationHelper) + require.NoError(t, toolsnaps.Test(toolDef.Tool.Name, toolDef.Tool)) - assert.Equal(t, "get_workflow_run", tool.Name) - assert.NotEmpty(t, tool.Description) - inputSchema := tool.InputSchema.(*jsonschema.Schema) + assert.Equal(t, "get_workflow_run", toolDef.Tool.Name) + assert.NotEmpty(t, toolDef.Tool.Description) + inputSchema := toolDef.Tool.InputSchema.(*jsonschema.Schema) assert.Contains(t, inputSchema.Properties, "owner") assert.Contains(t, inputSchema.Properties, "repo") assert.Contains(t, inputSchema.Properties, "run_id") @@ -1383,13 +1398,12 @@ func Test_GetWorkflowRun(t *testing.T) { func Test_GetWorkflowRunLogs(t *testing.T) { // Verify tool definition once - mockClient := github.NewClient(nil) - tool, _ := GetWorkflowRunLogs(stubGetClientFn(mockClient), translations.NullTranslationHelper) - require.NoError(t, toolsnaps.Test(tool.Name, tool)) + toolDef := GetWorkflowRunLogs(translations.NullTranslationHelper) + require.NoError(t, toolsnaps.Test(toolDef.Tool.Name, toolDef.Tool)) - assert.Equal(t, "get_workflow_run_logs", tool.Name) - assert.NotEmpty(t, tool.Description) - inputSchema := tool.InputSchema.(*jsonschema.Schema) + assert.Equal(t, "get_workflow_run_logs", toolDef.Tool.Name) + assert.NotEmpty(t, toolDef.Tool.Description) + inputSchema := toolDef.Tool.InputSchema.(*jsonschema.Schema) assert.Contains(t, inputSchema.Properties, "owner") assert.Contains(t, inputSchema.Properties, "repo") assert.Contains(t, inputSchema.Properties, "run_id") @@ -1398,13 +1412,12 @@ func Test_GetWorkflowRunLogs(t *testing.T) { func Test_ListWorkflowJobs(t *testing.T) { // Verify tool definition once - mockClient := github.NewClient(nil) - tool, _ := ListWorkflowJobs(stubGetClientFn(mockClient), translations.NullTranslationHelper) - require.NoError(t, toolsnaps.Test(tool.Name, tool)) + toolDef := ListWorkflowJobs(translations.NullTranslationHelper) + require.NoError(t, toolsnaps.Test(toolDef.Tool.Name, toolDef.Tool)) - assert.Equal(t, "list_workflow_jobs", tool.Name) - assert.NotEmpty(t, tool.Description) - inputSchema := tool.InputSchema.(*jsonschema.Schema) + assert.Equal(t, "list_workflow_jobs", toolDef.Tool.Name) + assert.NotEmpty(t, toolDef.Tool.Description) + inputSchema := toolDef.Tool.InputSchema.(*jsonschema.Schema) assert.Contains(t, inputSchema.Properties, "owner") assert.Contains(t, inputSchema.Properties, "repo") assert.Contains(t, inputSchema.Properties, "run_id") @@ -1413,13 +1426,12 @@ func Test_ListWorkflowJobs(t *testing.T) { func Test_RerunWorkflowRun(t *testing.T) { // Verify tool definition once - mockClient := github.NewClient(nil) - tool, _ := RerunWorkflowRun(stubGetClientFn(mockClient), translations.NullTranslationHelper) - require.NoError(t, toolsnaps.Test(tool.Name, tool)) + toolDef := RerunWorkflowRun(translations.NullTranslationHelper) + require.NoError(t, toolsnaps.Test(toolDef.Tool.Name, toolDef.Tool)) - assert.Equal(t, "rerun_workflow_run", tool.Name) - assert.NotEmpty(t, tool.Description) - inputSchema := tool.InputSchema.(*jsonschema.Schema) + assert.Equal(t, "rerun_workflow_run", toolDef.Tool.Name) + assert.NotEmpty(t, toolDef.Tool.Description) + inputSchema := toolDef.Tool.InputSchema.(*jsonschema.Schema) assert.Contains(t, inputSchema.Properties, "owner") assert.Contains(t, inputSchema.Properties, "repo") assert.Contains(t, inputSchema.Properties, "run_id") @@ -1428,13 +1440,12 @@ func Test_RerunWorkflowRun(t *testing.T) { func Test_RerunFailedJobs(t *testing.T) { // Verify tool definition once - mockClient := github.NewClient(nil) - tool, _ := RerunFailedJobs(stubGetClientFn(mockClient), translations.NullTranslationHelper) - require.NoError(t, toolsnaps.Test(tool.Name, tool)) + toolDef := RerunFailedJobs(translations.NullTranslationHelper) + require.NoError(t, toolsnaps.Test(toolDef.Tool.Name, toolDef.Tool)) - assert.Equal(t, "rerun_failed_jobs", tool.Name) - assert.NotEmpty(t, tool.Description) - inputSchema := tool.InputSchema.(*jsonschema.Schema) + assert.Equal(t, "rerun_failed_jobs", toolDef.Tool.Name) + assert.NotEmpty(t, toolDef.Tool.Description) + inputSchema := toolDef.Tool.InputSchema.(*jsonschema.Schema) assert.Contains(t, inputSchema.Properties, "owner") assert.Contains(t, inputSchema.Properties, "repo") assert.Contains(t, inputSchema.Properties, "run_id") diff --git a/pkg/github/tools.go b/pkg/github/tools.go index ba53f22af..8e811c9bf 100644 --- a/pkg/github/tools.go +++ b/pkg/github/tools.go @@ -300,22 +300,22 @@ func DefaultToolsetGroup(readOnly bool, getClient GetClientFn, getGQLClient GetG actions := toolsets.NewToolset(ToolsetMetadataActions.ID, ToolsetMetadataActions.Description). SetDependencies(deps). AddReadTools( - toolsets.NewServerToolLegacy(ListWorkflows(getClient, t)), - toolsets.NewServerToolLegacy(ListWorkflowRuns(getClient, t)), - toolsets.NewServerToolLegacy(GetWorkflowRun(getClient, t)), - toolsets.NewServerToolLegacy(GetWorkflowRunLogs(getClient, t)), - toolsets.NewServerToolLegacy(ListWorkflowJobs(getClient, t)), - toolsets.NewServerToolLegacy(GetJobLogs(getClient, t, contentWindowSize)), - toolsets.NewServerToolLegacy(ListWorkflowRunArtifacts(getClient, t)), - toolsets.NewServerToolLegacy(DownloadWorkflowRunArtifact(getClient, t)), - toolsets.NewServerToolLegacy(GetWorkflowRunUsage(getClient, t)), + ListWorkflows(t), + ListWorkflowRuns(t), + GetWorkflowRun(t), + GetWorkflowRunLogs(t), + ListWorkflowJobs(t), + GetJobLogs(t), + ListWorkflowRunArtifacts(t), + DownloadWorkflowRunArtifact(t), + GetWorkflowRunUsage(t), ). AddWriteTools( - toolsets.NewServerToolLegacy(RunWorkflow(getClient, t)), - toolsets.NewServerToolLegacy(RerunWorkflowRun(getClient, t)), - toolsets.NewServerToolLegacy(RerunFailedJobs(getClient, t)), - toolsets.NewServerToolLegacy(CancelWorkflowRun(getClient, t)), - toolsets.NewServerToolLegacy(DeleteWorkflowRunLogs(getClient, t)), + RunWorkflow(t), + RerunWorkflowRun(t), + RerunFailedJobs(t), + CancelWorkflowRun(t), + DeleteWorkflowRunLogs(t), ) securityAdvisories := toolsets.NewToolset(ToolsetMetadataSecurityAdvisories.ID, ToolsetMetadataSecurityAdvisories.Description).