diff --git a/README.md b/README.md index 1ad1c91..e426523 100644 --- a/README.md +++ b/README.md @@ -88,16 +88,17 @@ GET /api/v1/job//stderr Heimdall supports a growing set of pluggable command types: -| Plugin | Description | Execution Mode | -| ----------- | -------------------------------------- | -------------- | -| `ping` | [Basic plugin used for testing](https://github.com/patterninc/heimdall/blob/main/plugins/ping/README.md) | Sync or Async | -| `shell` | [Shell command execution](https://github.com/patterninc/heimdall/blob/main/plugins/shell/README.md) | Sync or Async | -| `glue` | [Pulling Iceberg table metadata](https://github.com/patterninc/heimdall/blob/main/plugins/glue/README.md) | Sync or Async | -| `dynamo` | [DynamoDB read operation](https://github.com/patterninc/heimdall/blob/main/plugins/dynamo/README.md) | Sync or Async | -| `snowflake` | [Query execution in Snowflake](https://github.com/patterninc/heimdall/blob/main/plugins/snowflake/README.md) | Async | -| `spark` | [SparkSQL query execution on EMR on EKS](https://github.com/patterninc/heimdall/blob/main/plugins/spark/README.md) | Async | -| `trino` | [Query execution in Trino](https://github.com/patterninc/heimdall/blob/main/plugins/trino/README.md) | Async | -| `clickhouse`| [Query execution in Clickhouse](https://github.com/patterninc/heimdall/blob/main/plugins/clickhouse/README.md) | Sync | +| Plugin | Description | Execution Mode | +| ----------- | -------------------------------------- | -------------- | +| `ping` | [Basic plugin used for testing](https://github.com/patterninc/heimdall/blob/main/plugins/ping/README.md) | Sync or Async | +| `shell` | [Shell command execution](https://github.com/patterninc/heimdall/blob/main/plugins/shell/README.md) | Sync or Async | +| `glue` | [Pulling Iceberg table metadata](https://github.com/patterninc/heimdall/blob/main/plugins/glue/README.md) | Sync or Async | +| `dynamo` | [DynamoDB read operation](https://github.com/patterninc/heimdall/blob/main/plugins/dynamo/README.md) | Sync or Async | +| `snowflake` | [Query execution in Snowflake](https://github.com/patterninc/heimdall/blob/main/plugins/snowflake/README.md) | Async | +| `spark` | [SparkSQL query execution on EMR on EKS](https://github.com/patterninc/heimdall/blob/main/plugins/spark/README.md) | Async | +| `trino` | [Query execution in Trino](https://github.com/patterninc/heimdall/blob/main/plugins/trino/README.md) | Async | +| `clickhouse` | [Query execution in Clickhouse](https://github.com/patterninc/heimdall/blob/main/plugins/clickhouse/README.md) | Sync | +| `ecs fargate` | [Task Deployment in ECS Fargate](https://github.com/patterninc/heimdall/blob/main/plugins/ecs/README.md) | Async | --- @@ -163,6 +164,7 @@ It centralizes execution logic, logging, and auditing—all accessible via API o | `POST /api/v1/job` | Submit a job | | `GET /api/v1/job/` | Get job details | | `GET /api/v1/job//status` | Check job status | +| `POST /api/v1/job//cancel` | Cancel an async job | | `GET /api/v1/job//stdout` | Get stdout for a completed job | | `GET /api/v1/job//stderr` | Get stderr for a completed job | | `GET /api/v1/job//result` | Get job's result | diff --git a/assets/databases/heimdall/data/job_statuses.sql b/assets/databases/heimdall/data/job_statuses.sql index 0acc895..b0e5420 100644 --- a/assets/databases/heimdall/data/job_statuses.sql +++ b/assets/databases/heimdall/data/job_statuses.sql @@ -9,7 +9,9 @@ values (3, 'RUNNING'), (4, 'FAILED'), (5, 'KILLED'), - (6, 'SUCCEEDED') + (6, 'SUCCEEDED'), + (7, 'CANCELLING'), + 
(8, 'CANCELLED') on conflict (job_status_id) do update set job_status_name = excluded.job_status_name; diff --git a/assets/databases/heimdall/tables/jobs.sql b/assets/databases/heimdall/tables/jobs.sql index 11de46d..e40a6fb 100644 --- a/assets/databases/heimdall/tables/jobs.sql +++ b/assets/databases/heimdall/tables/jobs.sql @@ -19,4 +19,6 @@ create table if not exists jobs constraint _jobs_job_id unique (job_id) ); -alter table jobs add column if not exists store_result_sync boolean not null default false; \ No newline at end of file +alter table jobs add column if not exists store_result_sync boolean not null default false; +alter table jobs add column if not exists cancelled_by varchar(64) null; +update jobs set cancelled_by = '' where cancelled_by is null; \ No newline at end of file diff --git a/internal/pkg/aws/cloudwatch.go b/internal/pkg/aws/cloudwatch.go index d8467ba..543c9f9 100644 --- a/internal/pkg/aws/cloudwatch.go +++ b/internal/pkg/aws/cloudwatch.go @@ -1,6 +1,7 @@ package aws import ( + "context" "fmt" "os" "time" @@ -10,7 +11,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/cloudwatchlogs" ) -func PullLogs(writer *os.File, logGroup, logStream string, chunkSize int, memoryLimit int64) error { +func PullLogs(ctx context.Context, writer *os.File, logGroup, logStream string, chunkSize int, memoryLimit int64) error { // initialize AWS session cfg, err := config.LoadDefaultConfig(ctx) diff --git a/internal/pkg/aws/glue.go b/internal/pkg/aws/glue.go index 94489a0..a540d94 100644 --- a/internal/pkg/aws/glue.go +++ b/internal/pkg/aws/glue.go @@ -1,6 +1,7 @@ package aws import ( + "context" "fmt" "strings" @@ -17,7 +18,7 @@ var ( ErrMissingCatalogTableMetadata = fmt.Errorf(`missing table metadata in the glue catalog`) ) -func GetTableMetadata(catalogID, tableName string) ([]byte, error) { +func GetTableMetadata(ctx context.Context, catalogID, tableName string) ([]byte, error) { // split tableName to namespace and table names tableNameParts := strings.Split(tableName, `.`) @@ -27,18 +28,18 @@ func GetTableMetadata(catalogID, tableName string) ([]byte, error) { } // let's get the latest metadata file location - location, err := getTableMetadataLocation(catalogID, tableNameParts[0], tableNameParts[1]) + location, err := getTableMetadataLocation(ctx, catalogID, tableNameParts[0], tableNameParts[1]) if err != nil { return nil, err } // let's pull the file content - return ReadFromS3(location) + return ReadFromS3(ctx, location) } // function that calls AWS glue catalog to get the snapshot ID for a given database, table and branch -func getTableMetadataLocation(catalogID, databaseName, tableName string) (string, error) { +func getTableMetadataLocation(ctx context.Context, catalogID, databaseName, tableName string) (string, error) { // Return an error if databaseName or tableName is empty if databaseName == `` || tableName == `` { diff --git a/internal/pkg/aws/s3.go b/internal/pkg/aws/s3.go index 1be4514..51843e2 100644 --- a/internal/pkg/aws/s3.go +++ b/internal/pkg/aws/s3.go @@ -13,12 +13,11 @@ import ( ) var ( - ctx = context.Background() rxS3Path = regexp.MustCompile(`^s3://([^/]+)/(.*)$`) ) // WriteToS3 writes a file to S3, providing the same interface as os.WriteFile function -func WriteToS3(name string, data []byte, _ os.FileMode) error { +func WriteToS3(ctx context.Context, name string, data []byte, _ os.FileMode) error { bucket, key, err := parseS3Path(name) if err != nil { @@ -47,7 +46,7 @@ func WriteToS3(name string, data []byte, _ os.FileMode) error { } -func 
ReadFromS3(name string) ([]byte, error) { +func ReadFromS3(ctx context.Context, name string) ([]byte, error) { bucket, key, err := parseS3Path(name) if err != nil { diff --git a/internal/pkg/heimdall/cluster_dal.go b/internal/pkg/heimdall/cluster_dal.go index 188aabf..02173cc 100644 --- a/internal/pkg/heimdall/cluster_dal.go +++ b/internal/pkg/heimdall/cluster_dal.go @@ -1,6 +1,7 @@ package heimdall import ( + "context" "database/sql" _ "embed" "encoding/json" @@ -77,13 +78,13 @@ var ( getClustersMethod = telemetry.NewMethod("db_connection", "get_clusters") ) -func (h *Heimdall) submitCluster(c *cluster.Cluster) (any, error) { +func (h *Heimdall) submitCluster(ctx context.Context, c *cluster.Cluster) (any, error) { if err := h.clusterUpsert(c); err != nil { return nil, err } - return h.getCluster(&cluster.Cluster{Object: object.Object{ID: c.ID}}) + return h.getCluster(ctx, &cluster.Cluster{Object: object.Object{ID: c.ID}}) } @@ -134,7 +135,7 @@ func (h *Heimdall) clusterUpsert(c *cluster.Cluster) error { } -func (h *Heimdall) getCluster(c *cluster.Cluster) (any, error) { +func (h *Heimdall) getCluster(ctx context.Context, c *cluster.Cluster) (any, error) { // Track DB connection for get cluster operation defer getClusterMethod.RecordLatency(time.Now()) @@ -181,7 +182,7 @@ func (h *Heimdall) getCluster(c *cluster.Cluster) (any, error) { } -func (h *Heimdall) getClusterStatus(c *cluster.Cluster) (any, error) { +func (h *Heimdall) getClusterStatus(ctx context.Context, c *cluster.Cluster) (any, error) { // Track DB connection for cluster status operation defer getClusterStatusMethod.RecordLatency(time.Now()) @@ -216,7 +217,7 @@ func (h *Heimdall) getClusterStatus(c *cluster.Cluster) (any, error) { } -func (h *Heimdall) updateClusterStatus(c *cluster.Cluster) (any, error) { +func (h *Heimdall) updateClusterStatus(ctx context.Context, c *cluster.Cluster) (any, error) { defer updateClusterStatusMethod.RecordLatency(time.Now()) updateClusterStatusMethod.CountRequest() @@ -239,11 +240,11 @@ func (h *Heimdall) updateClusterStatus(c *cluster.Cluster) (any, error) { } updateClusterStatusMethod.CountSuccess() - return h.getClusterStatus(c) + return h.getClusterStatus(ctx, c) } -func (h *Heimdall) getClusters(f *database.Filter) (any, error) { +func (h *Heimdall) getClusters(ctx context.Context, f *database.Filter) (any, error) { // Track DB connection for clusters list operation defer getClustersMethod.RecordLatency(time.Now()) @@ -295,7 +296,7 @@ func (h *Heimdall) getClusters(f *database.Filter) (any, error) { } -func (h *Heimdall) getClusterStatuses(_ *database.Filter) (any, error) { +func (h *Heimdall) getClusterStatuses(ctx context.Context, _ *database.Filter) (any, error) { return database.GetSlice(h.Database, queryClusterStatusesSelect) diff --git a/internal/pkg/heimdall/command_dal.go b/internal/pkg/heimdall/command_dal.go index ad9cd99..2274801 100644 --- a/internal/pkg/heimdall/command_dal.go +++ b/internal/pkg/heimdall/command_dal.go @@ -1,6 +1,7 @@ package heimdall import ( + "context" "database/sql" _ "embed" "encoding/json" @@ -89,13 +90,13 @@ var ( getCommandsMethod = telemetry.NewMethod("db_connection", "get_commands") ) -func (h *Heimdall) submitCommand(c *command.Command) (any, error) { +func (h *Heimdall) submitCommand(ctx context.Context, c *command.Command) (any, error) { if err := h.commandUpsert(c); err != nil { return nil, err } - return h.getCommand(&command.Command{Object: object.Object{ID: c.ID}}) + return h.getCommand(ctx, &command.Command{Object: object.Object{ID: c.ID}}) } 
@@ -163,7 +164,7 @@ func (h *Heimdall) commandUpsert(c *command.Command) error { } -func (h *Heimdall) getCommand(c *command.Command) (any, error) { +func (h *Heimdall) getCommand(ctx context.Context, c *command.Command) (any, error) { // Track DB connection for get command operation defer getCommandMethod.RecordLatency(time.Now()) @@ -210,7 +211,7 @@ func (h *Heimdall) getCommand(c *command.Command) (any, error) { } -func (h *Heimdall) getCommandStatus(c *command.Command) (any, error) { +func (h *Heimdall) getCommandStatus(ctx context.Context, c *command.Command) (any, error) { // Track DB connection for command status operation defer getCommandStatusMethod.RecordLatency(time.Now()) @@ -245,7 +246,7 @@ func (h *Heimdall) getCommandStatus(c *command.Command) (any, error) { } -func (h *Heimdall) updateCommandStatus(c *command.Command) (any, error) { +func (h *Heimdall) updateCommandStatus(ctx context.Context, c *command.Command) (any, error) { // Track DB connection for command status update operation defer updateCommandStatusMethod.RecordLatency(time.Now()) @@ -269,11 +270,11 @@ func (h *Heimdall) updateCommandStatus(c *command.Command) (any, error) { } updateCommandStatusMethod.CountSuccess() - return h.getCommandStatus(c) + return h.getCommandStatus(ctx, c) } -func (h *Heimdall) getCommands(f *database.Filter) (any, error) { +func (h *Heimdall) getCommands(ctx context.Context, f *database.Filter) (any, error) { // Track DB connection for commands list operation defer getCommandsMethod.RecordLatency(time.Now()) @@ -325,7 +326,7 @@ func (h *Heimdall) getCommands(f *database.Filter) (any, error) { } -func (h *Heimdall) getCommandStatuses(_ *database.Filter) (any, error) { +func (h *Heimdall) getCommandStatuses(ctx context.Context, _ *database.Filter) (any, error) { return database.GetSlice(h.Database, queryCommandStatusesSelect) diff --git a/internal/pkg/heimdall/handler.go b/internal/pkg/heimdall/handler.go index 194169b..55607ff 100644 --- a/internal/pkg/heimdall/handler.go +++ b/internal/pkg/heimdall/handler.go @@ -1,6 +1,7 @@ package heimdall import ( + "context" "encoding/json" "fmt" "io" @@ -48,7 +49,7 @@ func writeAPIError(w http.ResponseWriter, err error, obj any) { w.Write(responseJSON) } -func payloadHandler[T any](fn func(*T) (any, error)) http.HandlerFunc { +func payloadHandler[T any](fn func(context.Context, *T) (any, error)) http.HandlerFunc { // start latency timer defer payloadHandlerMethod.RecordLatency(time.Now()) @@ -81,7 +82,7 @@ func payloadHandler[T any](fn func(*T) (any, error)) http.HandlerFunc { } // execute request - result, err := fn(&payload) + result, err := fn(r.Context(), &payload) if err != nil { writeAPIError(w, err, result) return diff --git a/internal/pkg/heimdall/heimdall.go b/internal/pkg/heimdall/heimdall.go index d8a5840..871546d 100644 --- a/internal/pkg/heimdall/heimdall.go +++ b/internal/pkg/heimdall/heimdall.go @@ -15,12 +15,12 @@ import ( "github.com/patterninc/heimdall/internal/pkg/database" "github.com/patterninc/heimdall/internal/pkg/janitor" "github.com/patterninc/heimdall/internal/pkg/pool" + "github.com/patterninc/heimdall/internal/pkg/rbac" "github.com/patterninc/heimdall/internal/pkg/server" "github.com/patterninc/heimdall/pkg/object/cluster" "github.com/patterninc/heimdall/pkg/object/command" "github.com/patterninc/heimdall/pkg/object/job" "github.com/patterninc/heimdall/pkg/plugin" - "github.com/patterninc/heimdall/internal/pkg/rbac" rbacI "github.com/patterninc/heimdall/pkg/rbac" ) @@ -173,6 +173,7 @@ func (h *Heimdall) Start() error { 
// job(s) endpoints... apiRouter.Methods(methodGET).PathPrefix(`/job/statuses`).HandlerFunc(payloadHandler(h.getJobStatuses)) apiRouter.Methods(methodGET).PathPrefix(`/job/{id}/status`).HandlerFunc(payloadHandler(h.getJobStatus)) + apiRouter.Methods(methodPOST).PathPrefix(`/job/{id}/cancel`).HandlerFunc(payloadHandler(h.cancelJob)) apiRouter.Methods(methodGET).PathPrefix(`/job/{id}/{file}`).HandlerFunc(h.getJobFile) apiRouter.Methods(methodGET).PathPrefix(`/job/{id}`).HandlerFunc(payloadHandler(h.getJob)) apiRouter.Methods(methodGET).PathPrefix(`/jobs`).HandlerFunc(payloadHandler(h.getJobs)) diff --git a/internal/pkg/heimdall/job.go b/internal/pkg/heimdall/job.go index b0cd043..41151d1 100644 --- a/internal/pkg/heimdall/job.go +++ b/internal/pkg/heimdall/job.go @@ -1,7 +1,9 @@ package heimdall import ( + "context" "crypto/rand" + _ "embed" "encoding/json" "fmt" "math/big" @@ -36,15 +38,20 @@ const ( var ( ErrCommandClusterPairNotFound = fmt.Errorf(`command-cluster pair is not found`) + ErrJobCancelFailed = fmt.Errorf(`async job unrecognized or already in final state`) runJobMethod = telemetry.NewMethod("runJob", "heimdall") + cancelJobMethod = telemetry.NewMethod("db_connection", "cancel_job") ) +//go:embed queries/job/status_cancel_update.sql +var queryJobCancelUpdate string + type commandOnCluster struct { command *command.Command cluster *cluster.Cluster } -func (h *Heimdall) submitJob(j *job.Job) (any, error) { +func (h *Heimdall) submitJob(ctx context.Context, j *job.Job) (any, error) { // set / add job properties if err := j.Init(); err != nil { @@ -67,7 +74,7 @@ func (h *Heimdall) submitJob(j *job.Job) (any, error) { } // let's run the job - err = h.runJob(j, command, cluster) + err = h.runJob(ctx, j, command, cluster) // before we process the error, we'll make the best effort to record this job in the database go h.insertJob(j, cluster.ID, command.ID) @@ -76,16 +83,16 @@ func (h *Heimdall) submitJob(j *job.Job) (any, error) { } -func (h *Heimdall) runJob(job *job.Job, command *command.Command, cluster *cluster.Cluster) error { +func (h *Heimdall) runJob(ctx context.Context, j *job.Job, command *command.Command, cluster *cluster.Cluster) error { defer runJobMethod.RecordLatency(time.Now(), command.Name, cluster.Name) runJobMethod.CountRequest(command.Name, cluster.Name) // let's set environment runtime := &plugin.Runtime{ - WorkingDirectory: h.JobsDirectory + separator + job.ID, - ArchiveDirectory: h.ArchiveDirectory + separator + job.ID, - ResultDirectory: h.ResultDirectory + separator + job.ID, + WorkingDirectory: h.JobsDirectory + separator + j.ID, + ArchiveDirectory: h.ArchiveDirectory + separator + j.ID, + ResultDirectory: h.ResultDirectory + separator + j.ID, Version: h.Version, UserAgent: fmt.Sprintf(formatUserAgent, h.Version), } @@ -104,41 +111,87 @@ func (h *Heimdall) runJob(job *job.Job, command *command.Command, cluster *clust defer close(keepaliveActive) // ...and now we just start keepalive function for this job - go h.jobKeepalive(keepaliveActive, job.SystemID, h.agentName) - - // let's execute command - if err := h.commandHandlers[command.ID](runtime, job, cluster); err != nil { - - job.Status = jobStatus.Failed - job.Error = err.Error() + go h.jobKeepalive(keepaliveActive, j.SystemID, h.agentName) + + // Create channels for coordination between plugin execution and cancellation monitoring + jobDone := make(chan error, 1) + cancelMonitorDone := make(chan struct{}) + + // Create cancellable context for the job + pluginCtx, cancel := context.WithCancel(ctx) + defer 
cancel() + + // Start plugin execution in goroutine + go func() { + defer close(cancelMonitorDone) // signal monitoring to stop + err := h.commandHandlers[command.ID](pluginCtx, runtime, j, cluster) + jobDone <- err + }() + + // Start cancellation monitoring for async jobs + if !j.IsSync { + go func() { + ticker := time.NewTicker(10 * time.Second) + defer ticker.Stop() + + for { + select { + // plugin finished, stop monitoring + case <-cancelMonitorDone: + return + case <-ticker.C: + // If job is in cancelling state, trigger context cancellation + result, err := h.getJobStatus(ctx, &jobRequest{ID: j.ID}) + if err == nil { + if job, ok := result.(*job.Job); ok && job.Status == jobStatus.Cancelling { + cancel() + return + } + } + } + } + }() + } - runJobMethod.LogAndCountError(err, command.Name, cluster.Name) + // Wait for job execution to complete + jobErr := <-jobDone - return err + // Check if context was cancelled and mark status appropriately + if pluginCtx.Err() != nil { + j.Status = jobStatus.Cancelling // janitor will update to cancelled when resources are cleaned up + runJobMethod.LogAndCountError(pluginCtx.Err(), command.Name, cluster.Name) + return nil + } + // Handle plugin execution result (only if not cancelled) + if jobErr != nil { + j.Status = jobStatus.Failed + j.Error = jobErr.Error() + runJobMethod.LogAndCountError(jobErr, command.Name, cluster.Name) + return jobErr } - if job.StoreResultSync || !job.IsSync { - h.storeResults(runtime, job) + if j.StoreResultSync || !j.IsSync { + h.storeResults(runtime, j) } else { - go h.storeResults(runtime, job) + go h.storeResults(runtime, j) } - job.Status = jobStatus.Succeeded + j.Status = jobStatus.Succeeded runJobMethod.CountSuccess(command.Name, cluster.Name) return nil } -func (h *Heimdall) storeResults(runtime *plugin.Runtime, job *job.Job) error { +func (h *Heimdall) storeResults(runtime *plugin.Runtime, j *job.Job) error { // do we have result to be written? 
- if job.Result == nil { + if j.Result == nil { return nil } // prepare result - data, err := json.Marshal(job.Result) + data, err := json.Marshal(j.Result) if err != nil { return err @@ -147,7 +200,9 @@ func (h *Heimdall) storeResults(runtime *plugin.Runtime, job *job.Job) error { // write result writeFileFunc := os.WriteFile if strings.HasPrefix(runtime.ResultDirectory, s3Prefix) { - writeFileFunc = aws.WriteToS3 + writeFileFunc = func(name string, data []byte, perm os.FileMode) error { + return aws.WriteToS3(context.Background(), name, data, perm) + } } if err := writeFileFunc(runtime.ResultDirectory+separator+resultFilename, data, 0600); err != nil { @@ -157,6 +212,34 @@ func (h *Heimdall) storeResults(runtime *plugin.Runtime, job *job.Job) error { return nil } +func (h *Heimdall) cancelJob(ctx context.Context, req *jobRequest) (any, error) { + + defer cancelJobMethod.RecordLatency(time.Now()) + cancelJobMethod.CountRequest() + + sess, err := h.Database.NewSession(false) + if err != nil { + cancelJobMethod.LogAndCountError(err, "new_session") + return nil, err + } + defer sess.Close() + + // Attempt to cancel + rowsAffected, err := sess.Exec(queryJobCancelUpdate, req.ID, req.User) + if err != nil { + return nil, err + } + + if rowsAffected == 0 { + return nil, ErrJobCancelFailed + } + + // return job status + return &job.Job{ + Status: jobStatus.Cancelling, + }, nil +} + func (h *Heimdall) getJobFile(w http.ResponseWriter, r *http.Request) { // get vars @@ -173,7 +256,7 @@ func (h *Heimdall) getJobFile(w http.ResponseWriter, r *http.Request) { } // let's validate jobID we got - if _, err := h.getJobStatus(&jobRequest{ID: jobID}); err != nil { + if _, err := h.getJobStatus(r.Context(), &jobRequest{ID: jobID}); err != nil { writeAPIError(w, err, nil) return } @@ -192,7 +275,9 @@ func (h *Heimdall) getJobFile(w http.ResponseWriter, r *http.Request) { readFileFunc := os.ReadFile filenamePath := fmt.Sprintf(jobFileFormat, sourceDirectory, jobID, filename) if strings.HasPrefix(filenamePath, s3Prefix) { - readFileFunc = aws.ReadFromS3 + readFileFunc = func(path string) ([]byte, error) { + return aws.ReadFromS3(r.Context(), path) + } } // get file's content diff --git a/internal/pkg/heimdall/job_dal.go b/internal/pkg/heimdall/job_dal.go index 31f070a..a845be0 100644 --- a/internal/pkg/heimdall/job_dal.go +++ b/internal/pkg/heimdall/job_dal.go @@ -1,6 +1,7 @@ package heimdall import ( + "context" "database/sql" _ "embed" "encoding/json" @@ -94,6 +95,7 @@ var ( type jobRequest struct { ID string `yaml:"id,omitempty" json:"id,omitempty"` File string `yaml:"file,omitempty" json:"file,omitempty"` + User string `yaml:"user,omitempty" json:"user,omitempty"` } func (h *Heimdall) insertJob(j *job.Job, clusterID, commandID string) (int64, error) { @@ -111,7 +113,7 @@ func (h *Heimdall) insertJob(j *job.Job, clusterID, commandID string) (int64, er defer sess.Close() // insert job row - jobID, err := sess.InsertRow(queryJobInsert, clusterID, commandID, j.Status, j.ID, j.Name, j.Version, j.Description, j.Context.String(), j.Error, j.User, j.IsSync, j.StoreResultSync) + jobID, err := sess.InsertRow(queryJobInsert, clusterID, commandID, j.Status, j.ID, j.Name, j.Version, j.Description, j.Context.String(), j.Error, j.User, j.IsSync, j.StoreResultSync, j.CancelledBy) if err != nil { return 0, err } @@ -169,7 +171,7 @@ func (h *Heimdall) insertJob(j *job.Job, clusterID, commandID string) (int64, er } -func (h *Heimdall) getJob(j *jobRequest) (any, error) { +func (h *Heimdall) getJob(ctx context.Context, j 
*jobRequest) (any, error) { // Track DB connection for job get operation defer getJobMethod.RecordLatency(time.Now()) @@ -198,7 +200,7 @@ func (h *Heimdall) getJob(j *jobRequest) (any, error) { var jobContext string if err := row.Scan(&r.SystemID, &r.Status, &r.Name, &r.Version, &r.Description, &jobContext, &r.Error, &r.User, &r.IsSync, - &r.CreatedAt, &r.UpdatedAt, &r.CommandID, &r.CommandName, &r.CluserID, &r.ClusterName, &r.StoreResultSync); err != nil { + &r.CreatedAt, &r.UpdatedAt, &r.CommandID, &r.CommandName, &r.ClusterID, &r.ClusterName, &r.StoreResultSync, &r.CancelledBy); err != nil { if err == sql.ErrNoRows { return nil, ErrUnknownJobID } else { @@ -216,7 +218,7 @@ func (h *Heimdall) getJob(j *jobRequest) (any, error) { } -func (h *Heimdall) getJobs(f *database.Filter) (any, error) { +func (h *Heimdall) getJobs(ctx context.Context, f *database.Filter) (any, error) { // Track DB connection for jobs list operation defer getJobsMethod.RecordLatency(time.Now()) @@ -251,7 +253,7 @@ func (h *Heimdall) getJobs(f *database.Filter) (any, error) { r := &job.Job{} if err := rows.Scan(&r.SystemID, &r.ID, &r.Status, &r.Name, &r.Version, &r.Description, &jobContext, &r.Error, &r.User, &r.IsSync, - &r.CreatedAt, &r.UpdatedAt, &r.CommandID, &r.CommandName, &r.CluserID, &r.ClusterName, &r.StoreResultSync); err != nil { + &r.CreatedAt, &r.UpdatedAt, &r.CommandID, &r.CommandName, &r.ClusterID, &r.ClusterName, &r.StoreResultSync, &r.CancelledBy); err != nil { getJobsMethod.LogAndCountError(err, "scan") return nil, err } @@ -272,7 +274,7 @@ func (h *Heimdall) getJobs(f *database.Filter) (any, error) { } -func (h *Heimdall) getJobStatus(j *jobRequest) (any, error) { +func (h *Heimdall) getJobStatus(ctx context.Context, j *jobRequest) (any, error) { // Track DB connection for job status operation defer getJobStatusMethod.RecordLatency(time.Now()) @@ -337,7 +339,7 @@ func jobParseContextAndTags(j *job.Job, jobContext string, sess *database.Sessio } -func (h *Heimdall) getJobStatuses(_ *database.Filter) (any, error) { +func (h *Heimdall) getJobStatuses(ctx context.Context, _ *database.Filter) (any, error) { return database.GetSlice(h.Database, queryJobStatusesSelect) diff --git a/internal/pkg/heimdall/jobs_async.go b/internal/pkg/heimdall/jobs_async.go index 49a23f4..0f2e82a 100644 --- a/internal/pkg/heimdall/jobs_async.go +++ b/internal/pkg/heimdall/jobs_async.go @@ -1,6 +1,7 @@ package heimdall import ( + "context" _ "embed" "fmt" "time" @@ -63,7 +64,7 @@ func (h *Heimdall) getAsyncJobs(limit int) ([]*job.Job, error) { jobContext, j := ``, &job.Job{} - if err := rows.Scan(&j.SystemID, &j.CommandID, &j.CluserID, &j.Status, &j.ID, &j.Name, + if err := rows.Scan(&j.SystemID, &j.CommandID, &j.ClusterID, &j.Status, &j.ID, &j.Name, &j.Version, &j.Description, &jobContext, &j.User, &j.IsSync, &j.CreatedAt, &j.UpdatedAt, &j.StoreResultSync); err != nil { return nil, err } @@ -103,7 +104,7 @@ func (h *Heimdall) getAsyncJobs(limit int) ([]*job.Job, error) { } -func (h *Heimdall) runAsyncJob(j *job.Job) error { +func (h *Heimdall) runAsyncJob(ctx context.Context, j *job.Job) error { // Track DB connection for async job execution defer runAsyncJobMethod.RecordLatency(time.Now()) @@ -128,13 +129,13 @@ func (h *Heimdall) runAsyncJob(j *job.Job) error { } // do we have hte cluster? 
- cluster, found := h.Clusters[j.CluserID] + cluster, found := h.Clusters[j.ClusterID] if !found { - return h.updateAsyncJobStatus(j, fmt.Errorf(formatErrUnknownCluster, j.CluserID)) + return h.updateAsyncJobStatus(j, fmt.Errorf(formatErrUnknownCluster, j.ClusterID)) } runAsyncJobMethod.CountSuccess() - return h.updateAsyncJobStatus(j, h.runJob(j, command, cluster)) + return h.updateAsyncJobStatus(j, h.runJob(ctx, j, command, cluster)) } @@ -144,14 +145,6 @@ func (h *Heimdall) updateAsyncJobStatus(j *job.Job, jobError error) error { defer updateAsyncJobStatusMethod.RecordLatency(time.Now()) updateAsyncJobStatusMethod.CountRequest() - // we updte the final job status based on presence of the error - if jobError == nil { - j.Status = status.Succeeded - } else { - j.Status = status.Failed - j.Error = jobError.Error() - } - // now we update that status in the database sess, err := h.Database.NewSession(true) if err != nil { diff --git a/internal/pkg/heimdall/queries/job/insert.sql b/internal/pkg/heimdall/queries/job/insert.sql index 46813a6..c641363 100644 --- a/internal/pkg/heimdall/queries/job/insert.sql +++ b/internal/pkg/heimdall/queries/job/insert.sql @@ -11,7 +11,8 @@ insert into jobs job_error, username, is_sync, - store_result_sync + store_result_sync, + cancelled_by ) select cm.system_command_id, @@ -25,7 +26,8 @@ select $9, -- job_error $10, -- username $11, -- is_sync - $12 -- store_result_sync + $12, -- store_result_sync + $13 -- cancelled_by from clusters cl, commands cm diff --git a/internal/pkg/heimdall/queries/job/select.sql b/internal/pkg/heimdall/queries/job/select.sql index bc84784..716633f 100644 --- a/internal/pkg/heimdall/queries/job/select.sql +++ b/internal/pkg/heimdall/queries/job/select.sql @@ -14,7 +14,8 @@ select cm.command_name, cl.cluster_id, cl.cluster_name, - j.store_result_sync + j.store_result_sync, + j.cancelled_by from jobs j left join commands cm on cm.system_command_id = j.job_command_id diff --git a/internal/pkg/heimdall/queries/job/select_jobs.sql b/internal/pkg/heimdall/queries/job/select_jobs.sql index 39b4d81..3a439fa 100644 --- a/internal/pkg/heimdall/queries/job/select_jobs.sql +++ b/internal/pkg/heimdall/queries/job/select_jobs.sql @@ -15,7 +15,8 @@ select cm.command_name, cl.cluster_id, cl.cluster_name, - j.store_result_sync + j.store_result_sync, + j.cancelled_by from jobs j join job_statuses js on js.job_status_id = j.job_status_id diff --git a/internal/pkg/heimdall/queries/job/status_cancel_update.sql b/internal/pkg/heimdall/queries/job/status_cancel_update.sql new file mode 100644 index 0000000..8597499 --- /dev/null +++ b/internal/pkg/heimdall/queries/job/status_cancel_update.sql @@ -0,0 +1,8 @@ +update jobs +set + job_status_id = 7, -- CANCELLING + cancelled_by = $2, + updated_at = extract(epoch from now())::int +where + job_id = $1 + and job_status_id not in (4, 5, 6, 8); -- Not in FAILED, KILLED, SUCCEEDED, CANCELLED diff --git a/internal/pkg/object/command/clickhouse/clickhouse.go b/internal/pkg/object/command/clickhouse/clickhouse.go index 819f83a..5581667 100644 --- a/internal/pkg/object/command/clickhouse/clickhouse.go +++ b/internal/pkg/object/command/clickhouse/clickhouse.go @@ -7,7 +7,7 @@ import ( "github.com/ClickHouse/clickhouse-go/v2" "github.com/ClickHouse/clickhouse-go/v2/lib/driver" "github.com/hladush/go-telemetry/pkg/telemetry" - hdctx "github.com/patterninc/heimdall/pkg/context" + heimdallContext "github.com/patterninc/heimdall/pkg/context" "github.com/patterninc/heimdall/pkg/object/cluster" 
"github.com/patterninc/heimdall/pkg/object/job" "github.com/patterninc/heimdall/pkg/object/job/status" @@ -45,11 +45,11 @@ var ( ) // New creates a new clickhouse plugin handler -func New(ctx *hdctx.Context) (plugin.Handler, error) { +func New(commandCtx *heimdallContext.Context) (plugin.Handler, error) { t := &commandContext{} - if ctx != nil { - if err := ctx.Unmarshal(t); err != nil { + if commandCtx != nil { + if err := commandCtx.Unmarshal(t); err != nil { return nil, err } } @@ -57,10 +57,9 @@ func New(ctx *hdctx.Context) (plugin.Handler, error) { return t.handler, nil } -func (cmd *commandContext) handler(r *plugin.Runtime, j *job.Job, c *cluster.Cluster) error { - ctx := context.Background() +func (cmd *commandContext) handler(ctx context.Context, r *plugin.Runtime, j *job.Job, c *cluster.Cluster) error { - jobContext, err := cmd.createJobContext(j, c) + jobContext, err := cmd.createJobContext(ctx, j, c) if err != nil { handleMethod.LogAndCountError(err, "create_job_context") return err @@ -82,7 +81,7 @@ func (cmd *commandContext) handler(r *plugin.Runtime, j *job.Job, c *cluster.Clu return nil } -func (cmd *commandContext) createJobContext(j *job.Job, c *cluster.Cluster) (*jobContext, error) { +func (cmd *commandContext) createJobContext(ctx context.Context, j *job.Job, c *cluster.Cluster) (*jobContext, error) { // get cluster context clusterCtx := &clusterContext{} if c.Context != nil { diff --git a/internal/pkg/object/command/dynamo/dynamo.go b/internal/pkg/object/command/dynamo/dynamo.go index f0a3334..283a6b8 100644 --- a/internal/pkg/object/command/dynamo/dynamo.go +++ b/internal/pkg/object/command/dynamo/dynamo.go @@ -1,7 +1,7 @@ package dynamo import ( - ct "context" + "context" "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/config" @@ -9,7 +9,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/dynamodb" "github.com/aws/aws-sdk-go-v2/service/sts" - "github.com/patterninc/heimdall/pkg/context" + heimdallContext "github.com/patterninc/heimdall/pkg/context" "github.com/patterninc/heimdall/pkg/object/cluster" "github.com/patterninc/heimdall/pkg/object/job" "github.com/patterninc/heimdall/pkg/plugin" @@ -17,37 +17,36 @@ import ( ) // dynamoJobContext represents the context for a dynamo job -type dynamoJobContext struct { +type jobContext struct { Query string `yaml:"query,omitempty" json:"query,omitempty"` Limit int `yaml:"limit,omitempty" json:"limit,omitempty"` } // dynamoClusterContext represents the context for a dynamo endpoint -type dynamoClusterContext struct { +type clusterContext struct { RoleARN *string `yaml:"role_arn,omitempty" json:"role_arn,omitempty"` } // dynamoCommandContext represents the dynamo command context -type dynamoCommandContext struct{} +type commandContext struct{} var ( - ctx = ct.Background() assumeRoleSession = aws.String("AssumeRoleSession") ) // New creates a new dynamo plugin handler. -func New(_ *context.Context) (plugin.Handler, error) { +func New(_ *heimdallContext.Context) (plugin.Handler, error) { - s := &dynamoCommandContext{} + s := &commandContext{} return s.handler, nil } // Handler for the Spark job submission. 
-func (d *dynamoCommandContext) handler(r *plugin.Runtime, j *job.Job, c *cluster.Cluster) (err error) { +func (d *commandContext) handler(ctx context.Context, r *plugin.Runtime, j *job.Job, c *cluster.Cluster) (err error) { // let's unmarshal job context - jobContext := &dynamoJobContext{} + jobContext := &jobContext{} if j.Context != nil { if err := j.Context.Unmarshal(jobContext); err != nil { return err @@ -55,7 +54,7 @@ func (d *dynamoCommandContext) handler(r *plugin.Runtime, j *job.Job, c *cluster } // let's unmarshal cluster context - clusterContext := &dynamoClusterContext{} + clusterContext := &clusterContext{} if c.Context != nil { if err := c.Context.Unmarshal(clusterContext); err != nil { return err diff --git a/internal/pkg/object/command/ecs/ecs.go b/internal/pkg/object/command/ecs/ecs.go index 21672cc..a15819f 100644 --- a/internal/pkg/object/command/ecs/ecs.go +++ b/internal/pkg/object/command/ecs/ecs.go @@ -1,7 +1,7 @@ package ecs import ( - ct "context" + "context" "encoding/json" "fmt" "os" @@ -15,7 +15,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/ecs/types" "github.com/hladush/go-telemetry/pkg/telemetry" heimdallAws "github.com/patterninc/heimdall/internal/pkg/aws" - "github.com/patterninc/heimdall/pkg/context" + heimdallContext "github.com/patterninc/heimdall/pkg/context" "github.com/patterninc/heimdall/pkg/duration" "github.com/patterninc/heimdall/pkg/object/cluster" "github.com/patterninc/heimdall/pkg/object/job" @@ -23,7 +23,7 @@ import ( ) // ECS command context structure -type ecsCommandContext struct { +type commandContext struct { TaskDefinitionTemplate string `yaml:"task_definition_template,omitempty" json:"task_definition_template,omitempty"` TaskCount int `yaml:"task_count,omitempty" json:"task_count,omitempty"` CPU int `yaml:"cpu,omitempty" json:"cpu,omitempty"` @@ -35,7 +35,7 @@ type ecsCommandContext struct { } // ECS cluster context structure -type ecsClusterContext struct { +type clusterContext struct { MaxCPU int `yaml:"max_cpu,omitempty" json:"max_cpu,omitempty"` MaxMemory int `yaml:"max_memory,omitempty" json:"max_memory,omitempty"` MaxTaskCount int `yaml:"max_task_count,omitempty" json:"max_task_count,omitempty"` @@ -86,7 +86,7 @@ type executionContext struct { Memory int `json:"memory"` TaskDefinitionWrapper *taskDefinitionWrapper `json:"task_definition_wrapper"` ContainerOverrides []types.ContainerOverride `json:"container_overrides"` - ClusterConfig *ecsClusterContext `json:"cluster_config"` + ClusterConfig *clusterContext `json:"cluster_config"` PollingInterval duration.Duration `json:"polling_interval"` Timeout duration.Duration `json:"timeout"` @@ -116,22 +116,21 @@ const ( ) var ( - ctx = ct.Background() errMissingTemplate = fmt.Errorf("task definition template is required") methodMetrics = telemetry.NewMethod("ecs", "ecs plugin") ) -func New(commandContext *context.Context) (plugin.Handler, error) { +func New(commandCtx *heimdallContext.Context) (plugin.Handler, error) { - e := &ecsCommandContext{ + e := &commandContext{ PollingInterval: defaultPollingInterval, Timeout: defaultTaskTimeout, MaxFailCount: defaultMaxFailCount, TaskCount: defaultTaskCount, } - if commandContext != nil { - if err := commandContext.Unmarshal(e); err != nil { + if commandCtx != nil { + if err := commandCtx.Unmarshal(e); err != nil { return nil, err } } @@ -141,31 +140,31 @@ func New(commandContext *context.Context) (plugin.Handler, error) { } // handler implements the main ECS plugin logic -func (e *ecsCommandContext) handler(r *plugin.Runtime, job *job.Job, 
cluster *cluster.Cluster) error { +func (e *commandContext) handler(ctx context.Context, r *plugin.Runtime, job *job.Job, cluster *cluster.Cluster) error { // Build execution context with resolved configuration and loaded template - execCtx, err := buildExecutionContext(e, job, cluster, r) + execCtx, err := buildExecutionContext(ctx, e, job, cluster, r) if err != nil { return err } // register task definition - if err := execCtx.registerTaskDefinition(); err != nil { + if err := execCtx.registerTaskDefinition(ctx); err != nil { return err } // Start tasks - if err := execCtx.startTasks(job.ID); err != nil { + if err := execCtx.startTasks(ctx, job.ID); err != nil { return err } // Poll for completion - if err := execCtx.pollForCompletion(); err != nil { + if err := execCtx.pollForCompletion(ctx); err != nil { return err } // Try to retrieve logs, but don't fail the job if it fails - if err := execCtx.retrieveLogs(); err != nil { + if err := execCtx.retrieveLogs(ctx); err != nil { execCtx.runtime.Stderr.WriteString(fmt.Sprintf("Failed to retrieve logs: %v\n", err)) } @@ -180,7 +179,7 @@ func (e *ecsCommandContext) handler(r *plugin.Runtime, job *job.Job, cluster *cl } // prepare and register task definition with ECS -func (execCtx *executionContext) registerTaskDefinition() error { +func (execCtx *executionContext) registerTaskDefinition(ctx context.Context) error { registerInput := &ecs.RegisterTaskDefinitionInput{ Family: aws.String(aws.ToString(execCtx.TaskDefinitionWrapper.TaskDefinition.Family)), RequiresCompatibilities: []types.Compatibility{types.CompatibilityFargate}, @@ -204,10 +203,10 @@ func (execCtx *executionContext) registerTaskDefinition() error { } // startTasks launches all tasks and returns a map of task trackers -func (execCtx *executionContext) startTasks(jobID string) error { +func (execCtx *executionContext) startTasks(ctx context.Context, jobID string) error { for i := 0; i < execCtx.TaskCount; i++ { - taskARN, err := runTask(execCtx, fmt.Sprintf("%s%s-%d", startedByPrefix, jobID, i), i) + taskARN, err := runTask(ctx, execCtx, fmt.Sprintf("%s%s-%d", startedByPrefix, jobID, i), i) if err != nil { return err } @@ -223,7 +222,7 @@ func (execCtx *executionContext) startTasks(jobID string) error { } // monitor tasks until completion, faliure, or timeout -func (execCtx *executionContext) pollForCompletion() error { +func (execCtx *executionContext) pollForCompletion(ctx context.Context) error { startTime := time.Now() stopTime := startTime.Add(time.Duration(execCtx.Timeout)) @@ -287,7 +286,7 @@ func (execCtx *executionContext) pollForCompletion() error { // Stop all other running tasks reason := fmt.Sprintf(errMaxFailCount, tracker.ActiveARN, tracker.Retries, execCtx.MaxFailCount) - if err := stopAllTasks(execCtx, reason); err != nil { + if err := stopAllTasks(ctx, execCtx, reason); err != nil { return err } @@ -296,7 +295,7 @@ func (execCtx *executionContext) pollForCompletion() error { break } - newTaskARN, err := runTask(execCtx, tracker.Name, tracker.TaskNum) + newTaskARN, err := runTask(ctx, execCtx, tracker.Name, tracker.TaskNum) if err != nil { return err } @@ -330,7 +329,7 @@ func (execCtx *executionContext) pollForCompletion() error { // Stop all remaining tasks reason := fmt.Sprintf(errPollingTimeout, incompleteARNs, execCtx.Timeout) - if err := stopAllTasks(execCtx, reason); err != nil { + if err := stopAllTasks(ctx, execCtx, reason); err != nil { return err } @@ -338,7 +337,7 @@ func (execCtx *executionContext) pollForCompletion() error { break } - // Sleep until 
next poll time + // sleep for polling interval time.Sleep(time.Duration(execCtx.PollingInterval)) } @@ -347,7 +346,7 @@ func (execCtx *executionContext) pollForCompletion() error { } -func buildExecutionContext(commandCtx *ecsCommandContext, j *job.Job, c *cluster.Cluster, runtime *plugin.Runtime) (*executionContext, error) { +func buildExecutionContext(ctx context.Context, commandCtx *commandContext, j *job.Job, c *cluster.Cluster, runtime *plugin.Runtime) (*executionContext, error) { execCtx := &executionContext{ tasks: make(map[string]*taskTracker), @@ -355,7 +354,7 @@ func buildExecutionContext(commandCtx *ecsCommandContext, j *job.Job, c *cluster } // Create a context from commandCtx and unmarshal onto execCtx (defaults) - commandContext := context.New(commandCtx) + commandContext := heimdallContext.New(commandCtx) if err := commandContext.Unmarshal(execCtx); err != nil { return nil, err } @@ -368,7 +367,7 @@ func buildExecutionContext(commandCtx *ecsCommandContext, j *job.Job, c *cluster } // Add cluster config (no overlapping values) - clusterContext := &ecsClusterContext{} + clusterContext := &clusterContext{} if c.Context != nil { if err := c.Context.Unmarshal(clusterContext); err != nil { return nil, err @@ -464,7 +463,7 @@ func buildContainerOverrides(execCtx *executionContext) error { } // stopAllTasks stops all non-completed tasks with the given reason -func stopAllTasks(execCtx *executionContext, reason string) error { +func stopAllTasks(ctx context.Context, execCtx *executionContext, reason string) error { // AWS ECS has a 1024 character limit on the reason field if len(reason) > 1024 { reason = reason[:1021] + "..." @@ -537,7 +536,7 @@ func loadTaskDefinitionTemplate(templatePath string) (*taskDefinitionWrapper, er } // runTask runs a single task and returns the task ARN -func runTask(execCtx *executionContext, startedBy string, taskNum int) (string, error) { +func runTask(ctx context.Context, execCtx *executionContext, startedBy string, taskNum int) (string, error) { // Create a copy of the overrides and add TASK_NAME and TASK_NUM env variables finalOverrides := append([]types.ContainerOverride{}, execCtx.ContainerOverrides...) 
@@ -603,7 +602,7 @@ func isTaskSuccessful(task types.Task, execCtx *executionContext) bool { } // We pull logs from cloudwatch for all containers in a single task that represents the job outcome -func (execCtx *executionContext) retrieveLogs() error { +func (execCtx *executionContext) retrieveLogs(ctx context.Context) error { var selectedTask *taskTracker var writer *os.File @@ -655,7 +654,7 @@ func (execCtx *executionContext) retrieveLogs() error { case types.LogDriverAwslogs: logGroup := logInfo.options["awslogs-group"] logStream := fmt.Sprintf("%s/%s/%s", logInfo.options["awslogs-stream-prefix"], logInfo.containerName, taskID) - if err := heimdallAws.PullLogs(writer, logGroup, logStream, maxLogChunkSize, maxLogMemoryBytes); err != nil { + if err := heimdallAws.PullLogs(ctx, writer, logGroup, logStream, maxLogChunkSize, maxLogMemoryBytes); err != nil { return err } default: diff --git a/internal/pkg/object/command/glue/glue.go b/internal/pkg/object/command/glue/glue.go index 63128a9..63503b6 100644 --- a/internal/pkg/object/command/glue/glue.go +++ b/internal/pkg/object/command/glue/glue.go @@ -1,28 +1,30 @@ package glue import ( + "context" + "github.com/patterninc/heimdall/internal/pkg/aws" - "github.com/patterninc/heimdall/pkg/context" + heimdallContext "github.com/patterninc/heimdall/pkg/context" "github.com/patterninc/heimdall/pkg/object/cluster" "github.com/patterninc/heimdall/pkg/object/job" "github.com/patterninc/heimdall/pkg/plugin" "github.com/patterninc/heimdall/pkg/result" ) -type glueCommandContext struct { +type commandContext struct { CatalogID string `yaml:"catalog_id,omitempty" json:"catalog_id,omitempty"` } -type glueJobContext struct { +type jobContext struct { TableName string `yaml:"table_name,omitempty" json:"table_name,omitempty"` } -func New(commandContext *context.Context) (plugin.Handler, error) { +func New(commandCtx *heimdallContext.Context) (plugin.Handler, error) { - g := &glueCommandContext{} + g := &commandContext{} - if commandContext != nil { - if err := commandContext.Unmarshal(g); err != nil { + if commandCtx != nil { + if err := commandCtx.Unmarshal(g); err != nil { return nil, err } } @@ -31,10 +33,10 @@ func New(commandContext *context.Context) (plugin.Handler, error) { } -func (g *glueCommandContext) handler(_ *plugin.Runtime, j *job.Job, _ *cluster.Cluster) (err error) { +func (g *commandContext) handler(ctx context.Context, _ *plugin.Runtime, j *job.Job, _ *cluster.Cluster) (err error) { // let's unmarshal job context - jc := &glueJobContext{} + jc := &jobContext{} if j.Context != nil { if err = j.Context.Unmarshal(jc); err != nil { return @@ -42,7 +44,7 @@ func (g *glueCommandContext) handler(_ *plugin.Runtime, j *job.Job, _ *cluster.C } // let's get our metadata - metadata, err := aws.GetTableMetadata(g.CatalogID, jc.TableName) + metadata, err := aws.GetTableMetadata(ctx, g.CatalogID, jc.TableName) if err != nil { return } diff --git a/internal/pkg/object/command/ping/ping.go b/internal/pkg/object/command/ping/ping.go index 64fd333..b09a22c 100644 --- a/internal/pkg/object/command/ping/ping.go +++ b/internal/pkg/object/command/ping/ping.go @@ -1,9 +1,10 @@ package ping import ( + "context" "fmt" - "github.com/patterninc/heimdall/pkg/context" + heimdallContext "github.com/patterninc/heimdall/pkg/context" "github.com/patterninc/heimdall/pkg/object/cluster" "github.com/patterninc/heimdall/pkg/object/job" "github.com/patterninc/heimdall/pkg/plugin" @@ -14,16 +15,16 @@ const ( messageFormat = `Hello, %s!` ) -type pingCommandContext struct{} +type 
commandContext struct{} -func New(_ *context.Context) (plugin.Handler, error) { +func New(_ *heimdallContext.Context) (plugin.Handler, error) { - p := &pingCommandContext{} + p := &commandContext{} return p.handler, nil } -func (p *pingCommandContext) handler(_ *plugin.Runtime, j *job.Job, _ *cluster.Cluster) (err error) { +func (p *commandContext) handler(ctx context.Context, _ *plugin.Runtime, j *job.Job, _ *cluster.Cluster) (err error) { j.Result, err = result.FromMessage(fmt.Sprintf(messageFormat, j.User)) return diff --git a/internal/pkg/object/command/shell/shell.go b/internal/pkg/object/command/shell/shell.go index d16b20d..aa34040 100644 --- a/internal/pkg/object/command/shell/shell.go +++ b/internal/pkg/object/command/shell/shell.go @@ -1,12 +1,13 @@ package shell import ( + "context" "encoding/json" "os" "os/exec" "path" - "github.com/patterninc/heimdall/pkg/context" + heimdallContext "github.com/patterninc/heimdall/pkg/context" "github.com/patterninc/heimdall/pkg/object/cluster" "github.com/patterninc/heimdall/pkg/object/job" "github.com/patterninc/heimdall/pkg/plugin" @@ -18,27 +19,27 @@ const ( contextFilename = `context.json` ) -type shellCommandContext struct { +type commandContext struct { Command []string `yaml:"command,omitempty" json:"command,omitempty"` } -type shellJobContext struct { +type jobContext struct { Arguments []string `yaml:"arguments,omitempty" json:"arguments,omitempty"` } type runtimeContext struct { - Job *job.Job `yaml:"job,omitempty" json:"job,omitempty"` - Command *shellCommandContext `yaml:"command,omitempty" json:"command,omitempty"` - Cluster *cluster.Cluster `yaml:"cluster,omitempty" json:"cluster,omitempty"` - Runtime *plugin.Runtime `yaml:"runtime,omitempty" json:"runtime,omitempty"` + Job *job.Job `yaml:"job,omitempty" json:"job,omitempty"` + Command *commandContext `yaml:"command,omitempty" json:"command,omitempty"` + Cluster *cluster.Cluster `yaml:"cluster,omitempty" json:"cluster,omitempty"` + Runtime *plugin.Runtime `yaml:"runtime,omitempty" json:"runtime,omitempty"` } -func New(commandContext *context.Context) (plugin.Handler, error) { +func New(commandCtx *heimdallContext.Context) (plugin.Handler, error) { - s := &shellCommandContext{} + s := &commandContext{} - if commandContext != nil { - if err := commandContext.Unmarshal(s); err != nil { + if commandCtx != nil { + if err := commandCtx.Unmarshal(s); err != nil { return nil, err } } @@ -47,10 +48,10 @@ func New(commandContext *context.Context) (plugin.Handler, error) { } -func (s *shellCommandContext) handler(r *plugin.Runtime, j *job.Job, c *cluster.Cluster) error { +func (s *commandContext) handler(ctx context.Context, r *plugin.Runtime, j *job.Job, c *cluster.Cluster) error { // let's unmarshal job context - jc := &shellJobContext{} + jc := &jobContext{} if j.Context != nil { if err := j.Context.Unmarshal(jc); err != nil { return err @@ -82,7 +83,7 @@ func (s *shellCommandContext) handler(r *plugin.Runtime, j *job.Job, c *cluster. commandWithArguments = append(commandWithArguments, jc.Arguments...) // configure command - cmd := exec.Command(commandWithArguments[0], commandWithArguments[1:]...) + cmd := exec.CommandContext(ctx, commandWithArguments[0], commandWithArguments[1:]...) 
cmd.Stdout = r.Stdout cmd.Stderr = r.Stderr diff --git a/internal/pkg/object/command/snowflake/snowflake.go b/internal/pkg/object/command/snowflake/snowflake.go index be267a7..cf850db 100644 --- a/internal/pkg/object/command/snowflake/snowflake.go +++ b/internal/pkg/object/command/snowflake/snowflake.go @@ -1,6 +1,7 @@ package snowflake import ( + "context" "crypto/rsa" "crypto/x509" "database/sql" @@ -10,7 +11,7 @@ import ( sf "github.com/snowflakedb/gosnowflake" - "github.com/patterninc/heimdall/pkg/context" + heimdallContext "github.com/patterninc/heimdall/pkg/context" "github.com/patterninc/heimdall/pkg/object/cluster" "github.com/patterninc/heimdall/pkg/object/job" "github.com/patterninc/heimdall/pkg/plugin" @@ -27,13 +28,13 @@ var ( ErrInvalidKeyType = fmt.Errorf(`invalida key type`) ) -type snowflakeCommandContext struct{} +type commandContext struct{} -type snowflakeJobContext struct { +type jobContext struct { Query string `yaml:"query,omitempty" json:"query,omitempty"` } -type snowflakeClusterContext struct { +type clusterContext struct { Account string `yaml:"account,omitempty" json:"account,omitempty"` User string `yaml:"user,omitempty" json:"user,omitempty"` Database string `yaml:"database,omitempty" json:"database,omitempty"` @@ -65,14 +66,14 @@ func parsePrivateKey(privateKeyBytes []byte) (*rsa.PrivateKey, error) { } -func New(_ *context.Context) (plugin.Handler, error) { - s := &snowflakeCommandContext{} +func New(_ *heimdallContext.Context) (plugin.Handler, error) { + s := &commandContext{} return s.handler, nil } -func (s *snowflakeCommandContext) handler(r *plugin.Runtime, j *job.Job, c *cluster.Cluster) error { +func (s *commandContext) handler(ctx context.Context, r *plugin.Runtime, j *job.Job, c *cluster.Cluster) error { - clusterContext := &snowflakeClusterContext{} + clusterContext := &clusterContext{} if c.Context != nil { if err := c.Context.Unmarshal(clusterContext); err != nil { return err @@ -80,7 +81,7 @@ func (s *snowflakeCommandContext) handler(r *plugin.Runtime, j *job.Job, c *clus } // let's unmarshal job context - jobContext := &snowflakeJobContext{} + jobContext := &jobContext{} if err := j.Context.Unmarshal(jobContext); err != nil { return err } @@ -117,7 +118,7 @@ func (s *snowflakeCommandContext) handler(r *plugin.Runtime, j *job.Job, c *clus } defer db.Close() - rows, err := db.Query(jobContext.Query) + rows, err := db.QueryContext(ctx, jobContext.Query) if err != nil { return err } diff --git a/internal/pkg/object/command/spark/spark.go b/internal/pkg/object/command/spark/spark.go index 9d0f382..0583a04 100644 --- a/internal/pkg/object/command/spark/spark.go +++ b/internal/pkg/object/command/spark/spark.go @@ -1,7 +1,7 @@ package spark import ( - ct "context" + "context" "encoding/json" "fmt" "os" @@ -19,7 +19,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/sts" "github.com/babourine/x/pkg/set" - "github.com/patterninc/heimdall/pkg/context" + heimdallContext "github.com/patterninc/heimdall/pkg/context" "github.com/patterninc/heimdall/pkg/object/cluster" "github.com/patterninc/heimdall/pkg/object/job" "github.com/patterninc/heimdall/pkg/plugin" @@ -32,7 +32,7 @@ type sparkSubmitParameters struct { } // spark represents the Spark command context -type sparkCommandContext struct { +type commandContext struct { QueriesURI string `yaml:"queries_uri,omitempty" json:"queries_uri,omitempty"` ResultsURI string `yaml:"results_uri,omitempty" json:"results_uri,omitempty"` LogsURI *string `yaml:"logs_uri,omitempty" json:"logs_uri,omitempty"` @@ -41,7 +41,7 
@@ type sparkCommandContext struct { } // sparkJobContext represents the context for a spark job -type sparkJobContext struct { +type jobContext struct { Query string `yaml:"query,omitempty" json:"query,omitempty"` Arguments []string `yaml:"arguments,omitempty" json:"arguments,omitempty"` Parameters *sparkSubmitParameters `yaml:"parameters,omitempty" json:"parameters,omitempty"` @@ -49,7 +49,7 @@ type sparkJobContext struct { } // sparkClusterContext represents the context for a spark cluster -type sparkClusterContext struct { +type clusterContext struct { ExecutionRoleArn *string `yaml:"execution_role_arn,omitempty" json:"execution_role_arn,omitempty"` EMRReleaseLabel *string `yaml:"emr_release_label,omitempty" json:"emr_release_label,omitempty"` RoleARN *string `yaml:"role_arn,omitempty" json:"role_arn,omitempty"` @@ -64,7 +64,6 @@ const ( ) var ( - ctx = ct.Background() sparkDefaults = aws.String(`spark-defaults`) assumeRoleSession = aws.String("AssumeRoleSession") runtimeStates = set.New([]types.JobRunState{types.JobRunStateCompleted, types.JobRunStateFailed, types.JobRunStateCancelled}) @@ -77,12 +76,12 @@ var ( ) // New creates a new Spark plugin handler. -func New(commandContext *context.Context) (plugin.Handler, error) { +func New(commandCtx *heimdallContext.Context) (plugin.Handler, error) { - s := &sparkCommandContext{} + s := &commandContext{} - if commandContext != nil { - if err := commandContext.Unmarshal(s); err != nil { + if commandCtx != nil { + if err := commandCtx.Unmarshal(s); err != nil { return nil, err } } @@ -92,10 +91,10 @@ func New(commandContext *context.Context) (plugin.Handler, error) { } // Handler for the Spark job submission. -func (s *sparkCommandContext) handler(r *plugin.Runtime, j *job.Job, c *cluster.Cluster) (err error) { +func (s *commandContext) handler(ctx context.Context, r *plugin.Runtime, j *job.Job, c *cluster.Cluster) (err error) { // let's unmarshal job context - jobContext := &sparkJobContext{} + jobContext := &jobContext{} if j.Context != nil { if err := j.Context.Unmarshal(jobContext); err != nil { return err @@ -103,7 +102,7 @@ func (s *sparkCommandContext) handler(r *plugin.Runtime, j *job.Job, c *cluster. } // let's unmarshal cluster context - clusterContext := &sparkClusterContext{} + clusterContext := &clusterContext{} if c.Context != nil { if err := c.Context.Unmarshal(clusterContext); err != nil { return err @@ -165,7 +164,7 @@ func (s *sparkCommandContext) handler(r *plugin.Runtime, j *job.Job, c *cluster. svc := emrcontainers.NewFromConfig(awsConfig, assumeRoleOptions) // let's get the cluster ID - clusterID, err := getClusterID(svc, c.Name) + clusterID, err := getClusterID(ctx, svc, c.Name) if err != nil { return err } @@ -175,7 +174,7 @@ func (s *sparkCommandContext) handler(r *plugin.Runtime, j *job.Job, c *cluster. // upload query to s3 here... 
queryURI := fmt.Sprintf("%s/%s/query.sql", s.QueriesURI, j.ID) - if err := uploadFileToS3(queryURI, jobContext.Query); err != nil { + if err := uploadFileToS3(ctx, queryURI, jobContext.Query); err != nil { return err } @@ -262,7 +261,7 @@ timeoutLoop: } -func (s *sparkCommandContext) setJobDriver(jobContext *sparkJobContext, jobDriver *types.JobDriver, queryURI string, resultURI string) { +func (s *commandContext) setJobDriver(jobContext *jobContext, jobDriver *types.JobDriver, queryURI string, resultURI string) { jobParameters := getSparkSubmitParameters(jobContext) if jobContext.Arguments != nil { jobDriver.SparkSubmitJobDriver = &types.SparkSubmitJobDriver{ @@ -288,7 +287,7 @@ func (s *sparkCommandContext) setJobDriver(jobContext *sparkJobContext, jobDrive } -func getClusterID(svc *emrcontainers.Client, clusterName string) (*string, error) { +func getClusterID(ctx context.Context, svc *emrcontainers.Client, clusterName string) (*string, error) { // let's get the cluster ID outputListClusters, err := svc.ListVirtualClusters(ctx, &emrcontainers.ListVirtualClustersInput{ @@ -309,7 +308,7 @@ func getClusterID(svc *emrcontainers.Client, clusterName string) (*string, error } -func getSparkSubmitParameters(context *sparkJobContext) *string { +func getSparkSubmitParameters(context *jobContext) *string { properties := context.Parameters.Properties conf := make([]string, 0, len(properties)) @@ -327,7 +326,7 @@ func printState(stdout *os.File, state types.JobRunState) { stdout.WriteString(fmt.Sprintf("%v - job is still running. latest status: %v\n", time.Now(), state)) } -func uploadFileToS3(fileURI, content string) error { +func uploadFileToS3(ctx context.Context, fileURI, content string) error { // get bucket name and prefix s3Parts := rxS3.FindAllStringSubmatch(fileURI, -1) diff --git a/internal/pkg/object/command/sparkeks/sparkeks.go b/internal/pkg/object/command/sparkeks/sparkeks.go index 4750889..ac9e8a1 100644 --- a/internal/pkg/object/command/sparkeks/sparkeks.go +++ b/internal/pkg/object/command/sparkeks/sparkeks.go @@ -73,7 +73,6 @@ const ( ) var ( - ctx = context.Background() rxS3 = regexp.MustCompile(`^s3://([^/]+)/(.*)$`) runtimeStates = []v1beta2.ApplicationStateType{ v1beta2.ApplicationStateCompleted, @@ -91,24 +90,24 @@ var ( ErrSparkApplicationFile = fmt.Errorf("failed to read SparkApplication application template file: check file path and permissions") ) -type sparkEksCommandContext struct { +type commandContext struct { JobsURI string `yaml:"jobs_uri,omitempty" json:"jobs_uri,omitempty"` WrapperURI string `yaml:"wrapper_uri,omitempty" json:"wrapper_uri,omitempty"` Properties map[string]string `yaml:"properties,omitempty" json:"properties,omitempty"` KubeNamespace string `yaml:"kube_namespace,omitempty" json:"kube_namespace,omitempty"` } -type sparkEksJobParameters struct { +type jobParameters struct { Properties map[string]string `yaml:"properties,omitempty" json:"properties,omitempty"` } -type sparkEksJobContext struct { - Query string `yaml:"query,omitempty" json:"query,omitempty"` - Parameters *sparkEksJobParameters `yaml:"parameters,omitempty" json:"parameters,omitempty"` - ReturnResult bool `yaml:"return_result,omitempty" json:"return_result,omitempty"` +type jobContext struct { + Query string `yaml:"query,omitempty" json:"query,omitempty"` + Parameters *jobParameters `yaml:"parameters,omitempty" json:"parameters,omitempty"` + ReturnResult bool `yaml:"return_result,omitempty" json:"return_result,omitempty"` } -type sparkEksClusterContext struct { +type clusterContext struct 
{ RoleARN *string `yaml:"role_arn,omitempty" json:"role_arn,omitempty"` Properties map[string]string `yaml:"properties,omitempty" json:"properties,omitempty"` Image *string `yaml:"image,omitempty" json:"image,omitempty"` @@ -122,9 +121,9 @@ type executionContext struct { runtime *plugin.Runtime job *job.Job cluster *cluster.Cluster - commandContext *sparkEksCommandContext - jobContext *sparkEksJobContext - clusterContext *sparkEksClusterContext + commandContext *commandContext + jobContext *jobContext + clusterContext *clusterContext sparkClient *sparkClientSet.Clientset kubeClient *kubernetes.Clientset @@ -140,13 +139,13 @@ type executionContext struct { } // New creates a new Spark EKS plugin handler. -func New(commandContext *heimdallContext.Context) (plugin.Handler, error) { - s := &sparkEksCommandContext{ +func New(commandCtx *heimdallContext.Context) (plugin.Handler, error) { + s := &commandContext{ KubeNamespace: defaultNamespace, } - if commandContext != nil { - if err := commandContext.Unmarshal(s); err != nil { + if commandCtx != nil { + if err := commandCtx.Unmarshal(s); err != nil { return nil, err } } @@ -155,23 +154,24 @@ func New(commandContext *heimdallContext.Context) (plugin.Handler, error) { } // handler executes the Spark EKS job submission and execution. -func (s *sparkEksCommandContext) handler(r *plugin.Runtime, j *job.Job, c *cluster.Cluster) error { +func (s *commandContext) handler(ctx context.Context, r *plugin.Runtime, j *job.Job, c *cluster.Cluster) error { + // 1. Build execution context, create URIs, and upload query - execCtx, err := buildExecutionContextAndURI(r, j, c, s) + execCtx, err := buildExecutionContextAndURI(ctx, r, j, c, s) if err != nil { return fmt.Errorf("failed to build execution context: %w", err) } // 2. Submit the Spark Application to the cluster - if err := execCtx.submitSparkApp(); err != nil { + if err := execCtx.submitSparkApp(ctx); err != nil { return err } // 3. Monitor the job until completion and collect logs - monitorErr := execCtx.monitorJobAndCollectLogs() + monitorErr := execCtx.monitorJobAndCollectLogs(ctx) // 4. Cleanup any resources that are still pending - if err := execCtx.cleanupSparkApp(); err != nil { + if err := execCtx.cleanupSparkApp(ctx); err != nil { // Log cleanup error but don't override the main monitoring error execCtx.runtime.Stderr.WriteString(fmt.Sprintf("Warning: failed to cleanup application %s: %v\n", execCtx.submittedApp.Name, err)) } @@ -182,7 +182,7 @@ func (s *sparkEksCommandContext) handler(r *plugin.Runtime, j *job.Job, c *clust } // 5. Get and store results if required - if err := execCtx.getAndStoreResults(); err != nil { + if err := execCtx.getAndStoreResults(ctx); err != nil { return err } @@ -190,7 +190,7 @@ func (s *sparkEksCommandContext) handler(r *plugin.Runtime, j *job.Job, c *clust } // buildExecutionContextAndURI prepares the context, merges configurations, and uploads the query. 
-func buildExecutionContextAndURI(r *plugin.Runtime, j *job.Job, c *cluster.Cluster, s *sparkEksCommandContext) (*executionContext, error) { +func buildExecutionContextAndURI(ctx context.Context, r *plugin.Runtime, j *job.Job, c *cluster.Cluster, s *commandContext) (*executionContext, error) { execCtx := &executionContext{ runtime: r, job: j, @@ -199,7 +199,7 @@ func buildExecutionContextAndURI(r *plugin.Runtime, j *job.Job, c *cluster.Clust } // Parse job context - jobContext := &sparkEksJobContext{} + jobContext := &jobContext{} if j.Context != nil { if err := j.Context.Unmarshal(jobContext); err != nil { return nil, fmt.Errorf("failed to unmarshal job context: %w", err) @@ -208,7 +208,7 @@ func buildExecutionContextAndURI(r *plugin.Runtime, j *job.Job, c *cluster.Clust execCtx.jobContext = jobContext // Parse cluster context - clusterContext := &sparkEksClusterContext{} + clusterContext := &clusterContext{} if c.Context != nil { if err := c.Context.Unmarshal(clusterContext); err != nil { return nil, fmt.Errorf("failed to unmarshal cluster context: %w", err) @@ -218,7 +218,7 @@ func buildExecutionContextAndURI(r *plugin.Runtime, j *job.Job, c *cluster.Clust // Initialize and merge properties from command -> job if execCtx.jobContext.Parameters == nil { - execCtx.jobContext.Parameters = &sparkEksJobParameters{ + execCtx.jobContext.Parameters = &jobParameters{ Properties: make(map[string]string), } } @@ -248,12 +248,12 @@ func buildExecutionContextAndURI(r *plugin.Runtime, j *job.Job, c *cluster.Clust execCtx.logURI = fmt.Sprintf("%s/%s/%s", s.JobsURI, j.ID, logsPath) // Upload query to S3 - if err := uploadFileToS3(execCtx.awsConfig, execCtx.queryURI, execCtx.jobContext.Query); err != nil { + if err := uploadFileToS3(ctx, execCtx.awsConfig, execCtx.queryURI, execCtx.jobContext.Query); err != nil { return nil, fmt.Errorf("failed to upload query to S3: %w", err) } // create empty log s3 directory to avoid spark event log dir errors - if err := uploadFileToS3(execCtx.awsConfig, fmt.Sprintf("%s/.keepdir", execCtx.logURI), ""); err != nil { + if err := uploadFileToS3(ctx, execCtx.awsConfig, fmt.Sprintf("%s/.keepdir", execCtx.logURI), ""); err != nil { return nil, fmt.Errorf("failed to create log directory in S3: %w", err) } @@ -261,9 +261,9 @@ func buildExecutionContextAndURI(r *plugin.Runtime, j *job.Job, c *cluster.Clust } // submitSparkApp creates clients, generates the spec, and submits it to Kubernetes. -func (e *executionContext) submitSparkApp() error { +func (e *executionContext) submitSparkApp(ctx context.Context) error { // Create Kubernetes and Spark Operator clients - if err := createSparkClients(e); err != nil { + if err := createSparkClients(ctx, e); err != nil { return fmt.Errorf("failed to create Spark Operator client: %w", err) } @@ -297,7 +297,7 @@ func (e *executionContext) submitSparkApp() error { } // cleanupSparkApp removes the SparkApplication from the cluster if it still exists. -func (e *executionContext) cleanupSparkApp() error { +func (e *executionContext) cleanupSparkApp(ctx context.Context) error { if e.submittedApp == nil { return nil } @@ -316,12 +316,12 @@ func (e *executionContext) cleanupSparkApp() error { } // getAndStoreResults fetches the job output from S3 and stores it. 
-func (e *executionContext) getAndStoreResults() error { +func (e *executionContext) getAndStoreResults(ctx context.Context) error { if !e.jobContext.ReturnResult { return nil } - returnResultFileURI, err := getS3FileURI(e.awsConfig, e.resultURI, avroFileExtension) + returnResultFileURI, err := getS3FileURI(ctx, e.awsConfig, e.resultURI, avroFileExtension) if err != nil { e.runtime.Stdout.WriteString(fmt.Sprintf("failed to find .avro file in results directory %s: %s", e.resultURI, err)) return fmt.Errorf("failed to find .avro file in results directory %s: %w", e.resultURI, err) @@ -335,7 +335,7 @@ func (e *executionContext) getAndStoreResults() error { } // uploadFileToS3 uploads content to S3. -func uploadFileToS3(awsConfig aws.Config, fileURI, content string) error { +func uploadFileToS3(ctx context.Context, awsConfig aws.Config, fileURI, content string) error { s3Parts := rxS3.FindAllStringSubmatch(fileURI, -1) if len(s3Parts) == 0 || len(s3Parts[0]) < 3 { return fmt.Errorf("unexpected S3 URI format: %s", fileURI) @@ -358,7 +358,7 @@ func updateS3ToS3aURI(uri string) string { } // getS3FileURI finds a file in an S3 directory that matches the given extension. -func getS3FileURI(awsConfig aws.Config, directoryURI, matchingExtension string) (string, error) { +func getS3FileURI(ctx context.Context, awsConfig aws.Config, directoryURI, matchingExtension string) (string, error) { s3Parts := rxS3.FindAllStringSubmatch(directoryURI, -1) if len(s3Parts) == 0 || len(s3Parts[0]) < 3 { return "", fmt.Errorf("invalid S3 URI format: %s", directoryURI) @@ -390,7 +390,7 @@ func printState(writer io.Writer, state v1beta2.ApplicationStateType) { } // getSparkSubmitParameters returns Spark submit parameters as a string. -func getSparkSubmitParameters(context *sparkEksJobContext) *string { +func getSparkSubmitParameters(context *jobContext) *string { if context.Parameters == nil || len(context.Parameters.Properties) == 0 { return nil } @@ -403,7 +403,7 @@ func getSparkSubmitParameters(context *sparkEksJobContext) *string { } // getSparkApplicationPods returns the list of pods associated with a Spark application. -func getSparkApplicationPods(kubeClient *kubernetes.Clientset, appName, namespace string) ([]corev1.Pod, error) { +func getSparkApplicationPods(ctx context.Context, kubeClient *kubernetes.Clientset, appName, namespace string) ([]corev1.Pod, error) { labelSelector := fmt.Sprintf("%s=%s", sparkAppLabelSelectorFormat, appName) podList, err := kubeClient.CoreV1().Pods(namespace).List(ctx, metav1.ListOptions{LabelSelector: labelSelector}) if err != nil { @@ -437,7 +437,7 @@ func writeDriverLogsToStderr(execCtx *executionContext, pod corev1.Pod, logConte } // getAndUploadPodContainerLogs fetches logs from a specific container in a pod and uploads them to S3. 
-func getAndUploadPodContainerLogs(execCtx *executionContext, pod corev1.Pod, container corev1.Container, previous bool, logType string, writeToStderr bool) { +func getAndUploadPodContainerLogs(ctx context.Context, execCtx *executionContext, pod corev1.Pod, container corev1.Container, previous bool, logType string, writeToStderr bool) { logOptions := &corev1.PodLogOptions{Container: container.Name, Previous: previous} req := execCtx.kubeClient.CoreV1().Pods(pod.Namespace).GetLogs(pod.Name, logOptions) logs, err := req.Stream(ctx) @@ -457,31 +457,31 @@ func getAndUploadPodContainerLogs(execCtx *executionContext, pod corev1.Pod, con } logURI := fmt.Sprintf("%s/%s-%s", execCtx.logURI, pod.Name, logType) - if err := uploadFileToS3(execCtx.awsConfig, logURI, string(logContent)); err != nil { + if err := uploadFileToS3(ctx, execCtx.awsConfig, logURI, string(logContent)); err != nil { execCtx.runtime.Stderr.WriteString(fmt.Sprintf("Pod %s, container %s: %s upload error: %v\n", pod.Name, container.Name, logType, err)) } } } // getSparkApplicationPodLogs fetches logs from pods and uploads them to S3. -func getSparkApplicationPodLogs(execCtx *executionContext, pods []corev1.Pod, writeToStderr bool) error { +func getSparkApplicationPodLogs(ctx context.Context, execCtx *executionContext, pods []corev1.Pod, writeToStderr bool) error { for _, pod := range pods { if !isPodInValidPhase(pod) { continue } for _, container := range pod.Spec.Containers { // Get current logs and upload - getAndUploadPodContainerLogs(execCtx, pod, container, false, stdoutLogSuffix, writeToStderr) + getAndUploadPodContainerLogs(ctx, execCtx, pod, container, false, stdoutLogSuffix, writeToStderr) // Get logs from previous (failed) runs and upload - getAndUploadPodContainerLogs(execCtx, pod, container, true, stderrLogSuffix, false) + getAndUploadPodContainerLogs(ctx, execCtx, pod, container, true, stderrLogSuffix, false) } } return nil } // createSparkClients creates Kubernetes and Spark clients for the EKS cluster. -func createSparkClients(execCtx *executionContext) error { - kubeconfigPath, err := updateKubeConfig(execCtx) +func createSparkClients(ctx context.Context, execCtx *executionContext) error { + kubeconfigPath, err := updateKubeConfig(ctx, execCtx) if err != nil { return fmt.Errorf("failed to update kubeconfig: %w", err) } @@ -513,7 +513,7 @@ func createSparkClients(execCtx *executionContext) error { return nil } -func updateKubeConfig(execCtx *executionContext) (string, error) { +func updateKubeConfig(ctx context.Context, execCtx *executionContext) (string, error) { region := os.Getenv(awsRegionEnvVar) if execCtx.clusterContext.Region != nil { region = *execCtx.clusterContext.Region @@ -762,7 +762,7 @@ func loadTemplate(execCtx *executionContext) (*v1beta2.SparkApplication, error) } // monitorJobAndCollectLogs monitors the Spark job until completion and collects logs. 
-func (e *executionContext) monitorJobAndCollectLogs() error { +func (e *executionContext) monitorJobAndCollectLogs(ctx context.Context) error { appName, namespace := e.submittedApp.Name, e.submittedApp.Namespace e.runtime.Stdout.WriteString(fmt.Sprintf("Monitoring Spark application: %s\n", appName)) @@ -774,7 +774,7 @@ func (e *executionContext) monitorJobAndCollectLogs() error { for { if monitorCtx.Err() != nil { if finalSparkApp != nil { - collectSparkApplicationLogs(e, finalSparkApp, true) + collectSparkApplicationLogs(ctx, e, finalSparkApp, true) } if monitorCtx.Err() == context.DeadlineExceeded { return fmt.Errorf("spark job timed out after %v", jobTimeout) @@ -788,7 +788,7 @@ func (e *executionContext) monitorJobAndCollectLogs() error { if err != nil { e.runtime.Stderr.WriteString(fmt.Sprintf("Spark application %s/%s not found or deleted externally: %v\n", namespace, appName, err)) if finalSparkApp != nil { - collectSparkApplicationLogs(e, finalSparkApp, true) + collectSparkApplicationLogs(ctx, e, finalSparkApp, true) } return fmt.Errorf("spark application %s/%s not found: %w", namespace, appName, err) } @@ -802,17 +802,17 @@ func (e *executionContext) monitorJobAndCollectLogs() error { } if state == v1beta2.ApplicationStateRunning { - collectSparkApplicationLogs(e, sparkApp, false) + collectSparkApplicationLogs(ctx, e, sparkApp, false) continue } switch state { case v1beta2.ApplicationStateCompleted: - collectSparkApplicationLogs(e, sparkApp, false) + collectSparkApplicationLogs(ctx, e, sparkApp, false) e.runtime.Stdout.WriteString("Spark job completed successfully\n") return nil case v1beta2.ApplicationStateFailed: - collectSparkApplicationLogs(e, sparkApp, true) + collectSparkApplicationLogs(ctx, e, sparkApp, true) errorMessage := sparkApp.Status.AppState.ErrorMessage if errorMessage == "" { errorMessage = unknownErrorMsg @@ -820,7 +820,7 @@ func (e *executionContext) monitorJobAndCollectLogs() error { e.runtime.Stderr.WriteString(fmt.Sprintf("Spark job failed: %s\n", errorMessage)) return fmt.Errorf("spark job failed: %s", errorMessage) case v1beta2.ApplicationStateFailedSubmission, v1beta2.ApplicationStateUnknown: - collectSparkApplicationLogs(e, finalSparkApp, true) + collectSparkApplicationLogs(ctx, e, finalSparkApp, true) msg := sparkJobSubmissionFailedMsg if state == v1beta2.ApplicationStateUnknown { msg = sparkAppUnknownStateMsg @@ -831,17 +831,17 @@ func (e *executionContext) monitorJobAndCollectLogs() error { } // collectSparkApplicationLogs collects logs from Spark application pods. 
-func collectSparkApplicationLogs(execCtx *executionContext, sparkApp *v1beta2.SparkApplication, writeToStderr bool) { +func collectSparkApplicationLogs(ctx context.Context, execCtx *executionContext, sparkApp *v1beta2.SparkApplication, writeToStderr bool) { if sparkApp == nil { return } - pods, err := getSparkApplicationPods(execCtx.kubeClient, sparkApp.Name, sparkApp.Namespace) + pods, err := getSparkApplicationPods(ctx, execCtx.kubeClient, sparkApp.Name, sparkApp.Namespace) if err != nil { execCtx.runtime.Stderr.WriteString(fmt.Sprintf("Warning: failed to get Spark application pods: %v\n", err)) return } - if err := getSparkApplicationPodLogs(execCtx, pods, writeToStderr); err != nil { + if err := getSparkApplicationPodLogs(ctx, execCtx, pods, writeToStderr); err != nil { execCtx.runtime.Stderr.WriteString(fmt.Sprintf("Warning: failed to collect pod logs: %v\n", err)) } } diff --git a/internal/pkg/object/command/sparkeks/sparkeks_test.go b/internal/pkg/object/command/sparkeks/sparkeks_test.go index ad56b69..dc823c4 100644 --- a/internal/pkg/object/command/sparkeks/sparkeks_test.go +++ b/internal/pkg/object/command/sparkeks/sparkeks_test.go @@ -1,6 +1,7 @@ package sparkeks import ( + "context" "os" "strings" "testing" @@ -24,8 +25,8 @@ func TestUpdateS3ToS3aURI(t *testing.T) { } func TestGetSparkSubmitParameters(t *testing.T) { - ctx := &sparkEksJobContext{ - Parameters: &sparkEksJobParameters{ + ctx := &jobContext{ + Parameters: &jobParameters{ Properties: map[string]string{ "spark.executor.memory": "4g", "spark.driver.cores": "2", @@ -39,8 +40,8 @@ func TestGetSparkSubmitParameters(t *testing.T) { } func TestGetSparkSubmitParameters_Empty(t *testing.T) { - ctx := &sparkEksJobContext{ - Parameters: &sparkEksJobParameters{ + ctx := &jobContext{ + Parameters: &jobParameters{ Properties: map[string]string{}, }, } @@ -51,7 +52,7 @@ func TestGetSparkSubmitParameters_Empty(t *testing.T) { } func TestGetSparkSubmitParameters_NilParameters(t *testing.T) { - ctx := &sparkEksJobContext{ + ctx := &jobContext{ Parameters: nil, } params := getSparkSubmitParameters(ctx) @@ -61,7 +62,7 @@ func TestGetSparkSubmitParameters_NilParameters(t *testing.T) { } func TestGetS3FileURI_InvalidFormat(t *testing.T) { - _, err := getS3FileURI(aws.Config{}, "invalid-uri", "avro") + _, err := getS3FileURI(context.Background(), aws.Config{}, "invalid-uri", "avro") if err == nil { t.Error("Expected error for invalid S3 URI format") } @@ -70,7 +71,7 @@ func TestGetS3FileURI_InvalidFormat(t *testing.T) { func TestGetS3FileURI_ValidFormat(t *testing.T) { // This test only checks parsing, not actual AWS interaction uri := "s3://bucket/path/" - _, err := getS3FileURI(aws.Config{}, uri, "avro") + _, err := getS3FileURI(context.Background(), aws.Config{}, uri, "avro") // Should not error on parsing, but will error on AWS call (which is fine for unit test context) if err == nil || !strings.Contains(err.Error(), "failed to list S3 objects") { t.Logf("Expected AWS list objects error, got: %v", err) diff --git a/internal/pkg/object/command/trino/client.go b/internal/pkg/object/command/trino/client.go index 793ddf9..ce39bd5 100644 --- a/internal/pkg/object/command/trino/client.go +++ b/internal/pkg/object/command/trino/client.go @@ -2,6 +2,7 @@ package trino import ( "bytes" + "context" "encoding/json" "fmt" "io" @@ -45,7 +46,7 @@ type response struct { Error map[string]any `json:"error"` } -func newRequest(r *plugin.Runtime, j *job.Job, c *cluster.Cluster, jobCtx *jobContext) (*request, error) { +func newRequest(ctx 
context.Context, r *plugin.Runtime, j *job.Job, c *cluster.Cluster, jobCtx *jobContext) (*request, error) { // get cluster context clusterCtx := &clusterContext{} @@ -72,7 +73,7 @@ func newRequest(r *plugin.Runtime, j *job.Job, c *cluster.Cluster, jobCtx *jobCo } // submit query - if err := req.submit(jobCtx.Query); err != nil { + if err := req.submit(ctx, jobCtx.Query); err != nil { return nil, err } @@ -80,9 +81,9 @@ func newRequest(r *plugin.Runtime, j *job.Job, c *cluster.Cluster, jobCtx *jobCo } -func (r *request) submit(query string) error { +func (r *request) submit(ctx context.Context, query string) error { - req, err := http.NewRequest(http.MethodPost, fmt.Sprintf("%s/v1/statement", r.endpoint), bytes.NewBuffer([]byte(query))) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("%s/v1/statement", r.endpoint), bytes.NewBuffer([]byte(query))) if err != nil { return err } @@ -91,9 +92,9 @@ func (r *request) submit(query string) error { } -func (r *request) poll() error { +func (r *request) poll(ctx context.Context) error { - req, err := http.NewRequest(http.MethodGet, r.nextUri, nil) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, r.nextUri, nil) if err != nil { return err } @@ -174,4 +175,4 @@ func (r *request) api(req *http.Request) error { func normalizeTrinoQuery(query string) string { // Trino does not support semicolon at the end of the query, so we remove it if present return strings.TrimSuffix(query, ";") -} \ No newline at end of file +} diff --git a/internal/pkg/object/command/trino/trino.go b/internal/pkg/object/command/trino/trino.go index ce1891c..8d246da 100644 --- a/internal/pkg/object/command/trino/trino.go +++ b/internal/pkg/object/command/trino/trino.go @@ -1,12 +1,13 @@ package trino import ( + "context" "fmt" "log" "time" "github.com/hladush/go-telemetry/pkg/telemetry" - "github.com/patterninc/heimdall/pkg/context" + heimdallContext "github.com/patterninc/heimdall/pkg/context" "github.com/patterninc/heimdall/pkg/object/cluster" "github.com/patterninc/heimdall/pkg/object/job" "github.com/patterninc/heimdall/pkg/plugin" @@ -36,14 +37,14 @@ type jobContext struct { } // New creates a new trino plugin handler -func New(ctx *context.Context) (plugin.Handler, error) { +func New(commandCtx *heimdallContext.Context) (plugin.Handler, error) { t := &commandContext{ PollInterval: defaultPollInterval, } - if ctx != nil { - if err := ctx.Unmarshal(t); err != nil { + if commandCtx != nil { + if err := commandCtx.Unmarshal(t); err != nil { return nil, err } } @@ -52,7 +53,7 @@ func New(ctx *context.Context) (plugin.Handler, error) { } -func (t *commandContext) handler(r *plugin.Runtime, j *job.Job, c *cluster.Cluster) error { +func (t *commandContext) handler(ctx context.Context, r *plugin.Runtime, j *job.Job, c *cluster.Cluster) error { // get job context jobCtx := &jobContext{} @@ -68,7 +69,7 @@ func (t *commandContext) handler(r *plugin.Runtime, j *job.Job, c *cluster.Clust // this code will be enabled in prod after some testing } // let's submit our query to trino - req, err := newRequest(r, j, c, jobCtx) + req, err := newRequest(ctx, r, j, c, jobCtx) if err != nil { return err } @@ -76,7 +77,7 @@ func (t *commandContext) handler(r *plugin.Runtime, j *job.Job, c *cluster.Clust // now let's keep pooling until we get the full result... 
for req.nextUri != `` { time.Sleep(time.Duration(t.PollInterval) * time.Millisecond) - if err := req.poll(); err != nil { + if err := req.poll(ctx); err != nil { return err } } @@ -112,7 +113,7 @@ func canQueryBeExecuted(query, user, id string, c *cluster.Cluster) bool { return false } } - + canBeExecutedMethod.CountSuccess() return true } diff --git a/internal/pkg/pool/pool.go b/internal/pkg/pool/pool.go index 368a12d..dec05db 100644 --- a/internal/pkg/pool/pool.go +++ b/internal/pkg/pool/pool.go @@ -1,6 +1,7 @@ package pool import ( + "context" "fmt" "time" ) @@ -16,7 +17,7 @@ type Pool[T any] struct { queue chan T } -func (p *Pool[T]) Start(worker func(T) error, getWork func(int) ([]T, error)) error { +func (p *Pool[T]) Start(worker func(context.Context, T) error, getWork func(int) ([]T, error)) error { // do we have the size set? if p.Size <= 0 { @@ -53,7 +54,9 @@ func (p *Pool[T]) Start(worker func(T) error, getWork func(int) ([]T, error)) er } // do the work.... - if err := worker(w); err != nil { + err := worker(context.Background(), w) + + if err != nil { // TODO: implement proper error logging fmt.Println(`worker:`, err) } @@ -106,5 +109,4 @@ func (p *Pool[T]) Start(worker func(T) error, getWork func(int) ([]T, error)) er }(tokens) return nil - } diff --git a/pkg/object/job/job.go b/pkg/object/job/job.go index b0d6855..22b9c83 100644 --- a/pkg/object/job/job.go +++ b/pkg/object/job/job.go @@ -19,8 +19,9 @@ type Job struct { ClusterCriteria *set.Set[string] `yaml:"cluster_criteria,omitempty" json:"cluster_criteria,omitempty"` CommandID string `yaml:"command_id,omitempty" json:"command_id,omitempty"` CommandName string `yaml:"command_name,omitempty" json:"command_name,omitempty"` - CluserID string `yaml:"cluster_id,omitempty" json:"cluster_id,omitempty"` + ClusterID string `yaml:"cluster_id,omitempty" json:"cluster_id,omitempty"` ClusterName string `yaml:"cluster_name,omitempty" json:"cluster_name,omitempty"` + CancelledBy string `yaml:"cancelled_by,omitempty" json:"cancelled_by,omitempty"` Result *result.Result `yaml:"result,omitempty" json:"result,omitempty"` } diff --git a/pkg/object/job/status/status.go b/pkg/object/job/status/status.go index efe449a..375c540 100644 --- a/pkg/object/job/status/status.go +++ b/pkg/object/job/status/status.go @@ -11,12 +11,14 @@ import ( type Status status.Status const ( - New Status = 1 - Accepted Status = 2 - Running Status = 3 - Failed Status = 4 - Killed Status = 5 - Succeeded Status = 6 + New Status = 1 + Accepted Status = 2 + Running Status = 3 + Failed Status = 4 + Killed Status = 5 + Succeeded Status = 6 + Cancelling Status = 7 + Cancelled Status = 8 ) const ( @@ -25,12 +27,14 @@ const ( var ( statusMapping = map[string]status.Status{ - `new`: status.Status(New), - `accepted`: status.Status(Accepted), - `running`: status.Status(Running), - `failed`: status.Status(Failed), - `killed`: status.Status(Killed), - `succeeded`: status.Status(Succeeded), + `new`: status.Status(New), + `accepted`: status.Status(Accepted), + `running`: status.Status(Running), + `failed`: status.Status(Failed), + `killed`: status.Status(Killed), + `succeeded`: status.Status(Succeeded), + `cancelling`: status.Status(Cancelling), + `cancelled`: status.Status(Cancelled), } ) diff --git a/pkg/plugin/plugin.go b/pkg/plugin/plugin.go index 3512cbf..cfb6db6 100644 --- a/pkg/plugin/plugin.go +++ b/pkg/plugin/plugin.go @@ -1,8 +1,10 @@ package plugin import ( + "context" + "github.com/patterninc/heimdall/pkg/object/cluster" "github.com/patterninc/heimdall/pkg/object/job" ) 
-type Handler func(*Runtime, *job.Job, *cluster.Cluster) error
+type Handler func(context.Context, *Runtime, *job.Job, *cluster.Cluster) error
diff --git a/pkg/plugin/runtime.go b/pkg/plugin/runtime.go
index 6c0d616..179fb81 100644
--- a/pkg/plugin/runtime.go
+++ b/pkg/plugin/runtime.go
@@ -1,6 +1,7 @@
 package plugin
 
 import (
+	"context"
 	"fmt"
 	"io/fs"
 	"os"
@@ -97,7 +98,9 @@ func copyDir(src, dst string) error {
 	// if we have local filesystem, crete directory as appropriate
 	writeFileFunc := os.WriteFile
 	if strings.HasPrefix(dst, s3Prefix) {
-		writeFileFunc = aws.WriteToS3
+		writeFileFunc = func(name string, data []byte, perm os.FileMode) error {
+			return aws.WriteToS3(context.Background(), name, data, perm)
+		}
 	} else {
 		if _, err := os.Stat(dst); os.IsNotExist(err) {
 			os.MkdirAll(dst, jobDirectoryPermissions)
diff --git a/web/src/modules/Jobs/Helper.tsx b/web/src/modules/Jobs/Helper.tsx
index 50ce30d..c47f205 100644
--- a/web/src/modules/Jobs/Helper.tsx
+++ b/web/src/modules/Jobs/Helper.tsx
@@ -36,6 +36,7 @@ export type JobType = {
   command_name: string
   cluster_id: string
   cluster_name: string
+  cancelled_by: string
   error?: string
   context?: {
     properties: {
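
For orientation, here is a minimal sketch of a command plugin written against the new context-aware plugin.Handler signature above. The `sleep` command, its `duration_seconds` field, and the `return s.handler, nil` constructor ending are illustrative assumptions (the constructors' return statements sit outside the hunks in this diff); the Runtime, Job, and heimdallContext.Context usages mirror the trino and sparkeks plugins. Note that pool.Start currently hands workers context.Background(), so the cancellation branch below only demonstrates the contract a handler should satisfy once the cancel endpoint is wired through to the context.

package sleep

import (
	"context"
	"fmt"
	"time"

	heimdallContext "github.com/patterninc/heimdall/pkg/context"
	"github.com/patterninc/heimdall/pkg/object/cluster"
	"github.com/patterninc/heimdall/pkg/object/job"
	"github.com/patterninc/heimdall/pkg/plugin"
)

// commandContext holds the plugin's own configuration; duration_seconds is a
// hypothetical knob used only for this sketch.
type commandContext struct {
	DurationSeconds int `yaml:"duration_seconds,omitempty" json:"duration_seconds,omitempty"`
}

// New follows the constructor pattern used by the trino and sparkeks plugins:
// unmarshal the command context, then hand back the handler method value.
func New(commandCtx *heimdallContext.Context) (plugin.Handler, error) {
	s := &commandContext{DurationSeconds: 1}
	if commandCtx != nil {
		if err := commandCtx.Unmarshal(s); err != nil {
			return nil, err
		}
	}
	return s.handler, nil
}

// handler satisfies the new plugin.Handler signature: context first, then
// runtime, job, and cluster.
func (s *commandContext) handler(ctx context.Context, r *plugin.Runtime, j *job.Job, _ *cluster.Cluster) error {
	r.Stdout.WriteString(fmt.Sprintf("job %s: working for %ds\n", j.ID, s.DurationSeconds))
	select {
	case <-time.After(time.Duration(s.DurationSeconds) * time.Second):
		return nil
	case <-ctx.Done():
		// surface the cancellation so the job is not reported as succeeded
		return fmt.Errorf("job %s cancelled: %w", j.ID, ctx.Err())
	}
}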
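
The runtime.go hunk above keeps copyDir's os.WriteFile-shaped callback working by pinning a context inside a closure around the now context-taking aws.WriteToS3. A self-contained sketch of that adapter pattern, using a stand-in writeTo function instead of the real AWS call:

package main

import (
	"context"
	"fmt"
	"os"
	"strings"
)

// writeTo mimics the shape aws.WriteToS3 has after this patch: a context first,
// followed by the same arguments os.WriteFile takes. It is a stand-in, not the
// real AWS call.
func writeTo(ctx context.Context, name string, data []byte, _ os.FileMode) error {
	if err := ctx.Err(); err != nil {
		return err // honor cancellation before doing any work
	}
	fmt.Printf("would upload %d bytes to %s\n", len(data), name)
	return nil
}

func main() {
	ctx := context.Background()
	dst := "s3://bucket/prefix/file.txt"

	// Callers that only know the os.WriteFile signature keep working: the
	// closure pins the context, exactly as copyDir pins context.Background().
	writeFileFunc := os.WriteFile
	if strings.HasPrefix(dst, "s3://") {
		writeFileFunc = func(name string, data []byte, perm os.FileMode) error {
			return writeTo(ctx, name, data, perm)
		}
	}

	if err := writeFileFunc(dst, []byte("hello"), 0o644); err != nil {
		fmt.Println("write failed:", err)
	}
}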
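
Similarly, the trino client now builds its submit and poll requests with http.NewRequestWithContext, so an in-flight HTTP call is abandoned as soon as the job's context ends rather than blocking until Trino responds. A standalone sketch of that behavior; the statement URI and the two-second deadline are invented for illustration:

package main

import (
	"context"
	"fmt"
	"net/http"
	"time"
)

// pollOnce issues one GET against a Trino nextUri-style endpoint. Because the
// request carries ctx, client.Do fails fast once ctx is cancelled instead of
// waiting on the server.
func pollOnce(ctx context.Context, client *http.Client, nextURI string) (int, error) {
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, nextURI, nil)
	if err != nil {
		return 0, err
	}
	resp, err := client.Do(req)
	if err != nil {
		return 0, err
	}
	defer resp.Body.Close()
	return resp.StatusCode, nil
}

func main() {
	// a two-second deadline stands in for the job being cancelled mid-poll
	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
	defer cancel()

	status, err := pollOnce(ctx, http.DefaultClient, "http://localhost:8080/v1/statement/executing/example-query-id/token/1")
	fmt.Println(status, err)
}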