diff --git a/README.md b/README.md index e426523..1ad1c91 100644 --- a/README.md +++ b/README.md @@ -88,17 +88,16 @@ GET /api/v1/job//stderr Heimdall supports a growing set of pluggable command types: -| Plugin | Description | Execution Mode | -| ----------- | -------------------------------------- | -------------- | -| `ping` | [Basic plugin used for testing](https://github.com/patterninc/heimdall/blob/main/plugins/ping/README.md) | Sync or Async | -| `shell` | [Shell command execution](https://github.com/patterninc/heimdall/blob/main/plugins/shell/README.md) | Sync or Async | -| `glue` | [Pulling Iceberg table metadata](https://github.com/patterninc/heimdall/blob/main/plugins/glue/README.md) | Sync or Async | -| `dynamo` | [DynamoDB read operation](https://github.com/patterninc/heimdall/blob/main/plugins/dynamo/README.md) | Sync or Async | -| `snowflake` | [Query execution in Snowflake](https://github.com/patterninc/heimdall/blob/main/plugins/snowflake/README.md) | Async | -| `spark` | [SparkSQL query execution on EMR on EKS](https://github.com/patterninc/heimdall/blob/main/plugins/spark/README.md) | Async | -| `trino` | [Query execution in Trino](https://github.com/patterninc/heimdall/blob/main/plugins/trino/README.md) | Async | -| `clickhouse` | [Query execution in Clickhouse](https://github.com/patterninc/heimdall/blob/main/plugins/clickhouse/README.md) | Sync | -| `ecs fargate` | [Task Deployment in ECS Fargate](https://github.com/patterninc/heimdall/blob/main/plugins/ecs/README.md) | Async | +| Plugin | Description | Execution Mode | +| ----------- | -------------------------------------- | -------------- | +| `ping` | [Basic plugin used for testing](https://github.com/patterninc/heimdall/blob/main/plugins/ping/README.md) | Sync or Async | +| `shell` | [Shell command execution](https://github.com/patterninc/heimdall/blob/main/plugins/shell/README.md) | Sync or Async | +| `glue` | [Pulling Iceberg table metadata](https://github.com/patterninc/heimdall/blob/main/plugins/glue/README.md) | Sync or Async | +| `dynamo` | [DynamoDB read operation](https://github.com/patterninc/heimdall/blob/main/plugins/dynamo/README.md) | Sync or Async | +| `snowflake` | [Query execution in Snowflake](https://github.com/patterninc/heimdall/blob/main/plugins/snowflake/README.md) | Async | +| `spark` | [SparkSQL query execution on EMR on EKS](https://github.com/patterninc/heimdall/blob/main/plugins/spark/README.md) | Async | +| `trino` | [Query execution in Trino](https://github.com/patterninc/heimdall/blob/main/plugins/trino/README.md) | Async | +| `clickhouse`| [Query execution in Clickhouse](https://github.com/patterninc/heimdall/blob/main/plugins/clickhouse/README.md) | Sync | --- @@ -164,7 +163,6 @@ It centralizes execution logic, logging, and auditing—all accessible via API o | `POST /api/v1/job` | Submit a job | | `GET /api/v1/job/` | Get job details | | `GET /api/v1/job//status` | Check job status | -| `POST /api/v1/job//cancel` | Cancel an async job | | `GET /api/v1/job//stdout` | Get stdout for a completed job | | `GET /api/v1/job//stderr` | Get stderr for a completed job | | `GET /api/v1/job//result` | Get job's result | diff --git a/assets/databases/heimdall/data/job_statuses.sql b/assets/databases/heimdall/data/job_statuses.sql index b0e5420..0acc895 100644 --- a/assets/databases/heimdall/data/job_statuses.sql +++ b/assets/databases/heimdall/data/job_statuses.sql @@ -9,9 +9,7 @@ values (3, 'RUNNING'), (4, 'FAILED'), (5, 'KILLED'), - (6, 'SUCCEEDED'), - (7, 'CANCELLING'), - (8, 'CANCELLED') + 
(6, 'SUCCEEDED') on conflict (job_status_id) do update set job_status_name = excluded.job_status_name; diff --git a/assets/databases/heimdall/tables/jobs.sql b/assets/databases/heimdall/tables/jobs.sql index e40a6fb..11de46d 100644 --- a/assets/databases/heimdall/tables/jobs.sql +++ b/assets/databases/heimdall/tables/jobs.sql @@ -19,6 +19,4 @@ create table if not exists jobs constraint _jobs_job_id unique (job_id) ); -alter table jobs add column if not exists store_result_sync boolean not null default false; -alter table jobs add column if not exists cancelled_by varchar(64) null; -update jobs set cancelled_by = '' where cancelled_by is null; \ No newline at end of file +alter table jobs add column if not exists store_result_sync boolean not null default false; \ No newline at end of file diff --git a/internal/pkg/aws/cloudwatch.go b/internal/pkg/aws/cloudwatch.go index 543c9f9..d8467ba 100644 --- a/internal/pkg/aws/cloudwatch.go +++ b/internal/pkg/aws/cloudwatch.go @@ -1,7 +1,6 @@ package aws import ( - "context" "fmt" "os" "time" @@ -11,7 +10,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/cloudwatchlogs" ) -func PullLogs(ctx context.Context, writer *os.File, logGroup, logStream string, chunkSize int, memoryLimit int64) error { +func PullLogs(writer *os.File, logGroup, logStream string, chunkSize int, memoryLimit int64) error { // initialize AWS session cfg, err := config.LoadDefaultConfig(ctx) diff --git a/internal/pkg/aws/glue.go b/internal/pkg/aws/glue.go index a540d94..94489a0 100644 --- a/internal/pkg/aws/glue.go +++ b/internal/pkg/aws/glue.go @@ -1,7 +1,6 @@ package aws import ( - "context" "fmt" "strings" @@ -18,7 +17,7 @@ var ( ErrMissingCatalogTableMetadata = fmt.Errorf(`missing table metadata in the glue catalog`) ) -func GetTableMetadata(ctx context.Context, catalogID, tableName string) ([]byte, error) { +func GetTableMetadata(catalogID, tableName string) ([]byte, error) { // split tableName to namespace and table names tableNameParts := strings.Split(tableName, `.`) @@ -28,18 +27,18 @@ func GetTableMetadata(ctx context.Context, catalogID, tableName string) ([]byte, } // let's get the latest metadata file location - location, err := getTableMetadataLocation(ctx, catalogID, tableNameParts[0], tableNameParts[1]) + location, err := getTableMetadataLocation(catalogID, tableNameParts[0], tableNameParts[1]) if err != nil { return nil, err } // let's pull the file content - return ReadFromS3(ctx, location) + return ReadFromS3(location) } // function that calls AWS glue catalog to get the snapshot ID for a given database, table and branch -func getTableMetadataLocation(ctx context.Context, catalogID, databaseName, tableName string) (string, error) { +func getTableMetadataLocation(catalogID, databaseName, tableName string) (string, error) { // Return an error if databaseName or tableName is empty if databaseName == `` || tableName == `` { diff --git a/internal/pkg/aws/s3.go b/internal/pkg/aws/s3.go index 51843e2..1be4514 100644 --- a/internal/pkg/aws/s3.go +++ b/internal/pkg/aws/s3.go @@ -13,11 +13,12 @@ import ( ) var ( + ctx = context.Background() rxS3Path = regexp.MustCompile(`^s3://([^/]+)/(.*)$`) ) // WriteToS3 writes a file to S3, providing the same interface as os.WriteFile function -func WriteToS3(ctx context.Context, name string, data []byte, _ os.FileMode) error { +func WriteToS3(name string, data []byte, _ os.FileMode) error { bucket, key, err := parseS3Path(name) if err != nil { @@ -46,7 +47,7 @@ func WriteToS3(ctx context.Context, name string, data []byte, _ 
os.FileMode) err } -func ReadFromS3(ctx context.Context, name string) ([]byte, error) { +func ReadFromS3(name string) ([]byte, error) { bucket, key, err := parseS3Path(name) if err != nil { diff --git a/internal/pkg/heimdall/cluster_dal.go b/internal/pkg/heimdall/cluster_dal.go index 02173cc..188aabf 100644 --- a/internal/pkg/heimdall/cluster_dal.go +++ b/internal/pkg/heimdall/cluster_dal.go @@ -1,7 +1,6 @@ package heimdall import ( - "context" "database/sql" _ "embed" "encoding/json" @@ -78,13 +77,13 @@ var ( getClustersMethod = telemetry.NewMethod("db_connection", "get_clusters") ) -func (h *Heimdall) submitCluster(ctx context.Context, c *cluster.Cluster) (any, error) { +func (h *Heimdall) submitCluster(c *cluster.Cluster) (any, error) { if err := h.clusterUpsert(c); err != nil { return nil, err } - return h.getCluster(ctx, &cluster.Cluster{Object: object.Object{ID: c.ID}}) + return h.getCluster(&cluster.Cluster{Object: object.Object{ID: c.ID}}) } @@ -135,7 +134,7 @@ func (h *Heimdall) clusterUpsert(c *cluster.Cluster) error { } -func (h *Heimdall) getCluster(ctx context.Context, c *cluster.Cluster) (any, error) { +func (h *Heimdall) getCluster(c *cluster.Cluster) (any, error) { // Track DB connection for get cluster operation defer getClusterMethod.RecordLatency(time.Now()) @@ -182,7 +181,7 @@ func (h *Heimdall) getCluster(ctx context.Context, c *cluster.Cluster) (any, err } -func (h *Heimdall) getClusterStatus(ctx context.Context, c *cluster.Cluster) (any, error) { +func (h *Heimdall) getClusterStatus(c *cluster.Cluster) (any, error) { // Track DB connection for cluster status operation defer getClusterStatusMethod.RecordLatency(time.Now()) @@ -217,7 +216,7 @@ func (h *Heimdall) getClusterStatus(ctx context.Context, c *cluster.Cluster) (an } -func (h *Heimdall) updateClusterStatus(ctx context.Context, c *cluster.Cluster) (any, error) { +func (h *Heimdall) updateClusterStatus(c *cluster.Cluster) (any, error) { defer updateClusterStatusMethod.RecordLatency(time.Now()) updateClusterStatusMethod.CountRequest() @@ -240,11 +239,11 @@ func (h *Heimdall) updateClusterStatus(ctx context.Context, c *cluster.Cluster) } updateClusterStatusMethod.CountSuccess() - return h.getClusterStatus(ctx, c) + return h.getClusterStatus(c) } -func (h *Heimdall) getClusters(ctx context.Context, f *database.Filter) (any, error) { +func (h *Heimdall) getClusters(f *database.Filter) (any, error) { // Track DB connection for clusters list operation defer getClustersMethod.RecordLatency(time.Now()) @@ -296,7 +295,7 @@ func (h *Heimdall) getClusters(ctx context.Context, f *database.Filter) (any, er } -func (h *Heimdall) getClusterStatuses(ctx context.Context, _ *database.Filter) (any, error) { +func (h *Heimdall) getClusterStatuses(_ *database.Filter) (any, error) { return database.GetSlice(h.Database, queryClusterStatusesSelect) diff --git a/internal/pkg/heimdall/command_dal.go b/internal/pkg/heimdall/command_dal.go index 2274801..ad9cd99 100644 --- a/internal/pkg/heimdall/command_dal.go +++ b/internal/pkg/heimdall/command_dal.go @@ -1,7 +1,6 @@ package heimdall import ( - "context" "database/sql" _ "embed" "encoding/json" @@ -90,13 +89,13 @@ var ( getCommandsMethod = telemetry.NewMethod("db_connection", "get_commands") ) -func (h *Heimdall) submitCommand(ctx context.Context, c *command.Command) (any, error) { +func (h *Heimdall) submitCommand(c *command.Command) (any, error) { if err := h.commandUpsert(c); err != nil { return nil, err } - return h.getCommand(ctx, &command.Command{Object: object.Object{ID: c.ID}}) + 
return h.getCommand(&command.Command{Object: object.Object{ID: c.ID}}) } @@ -164,7 +163,7 @@ func (h *Heimdall) commandUpsert(c *command.Command) error { } -func (h *Heimdall) getCommand(ctx context.Context, c *command.Command) (any, error) { +func (h *Heimdall) getCommand(c *command.Command) (any, error) { // Track DB connection for get command operation defer getCommandMethod.RecordLatency(time.Now()) @@ -211,7 +210,7 @@ func (h *Heimdall) getCommand(ctx context.Context, c *command.Command) (any, err } -func (h *Heimdall) getCommandStatus(ctx context.Context, c *command.Command) (any, error) { +func (h *Heimdall) getCommandStatus(c *command.Command) (any, error) { // Track DB connection for command status operation defer getCommandStatusMethod.RecordLatency(time.Now()) @@ -246,7 +245,7 @@ func (h *Heimdall) getCommandStatus(ctx context.Context, c *command.Command) (an } -func (h *Heimdall) updateCommandStatus(ctx context.Context, c *command.Command) (any, error) { +func (h *Heimdall) updateCommandStatus(c *command.Command) (any, error) { // Track DB connection for command status update operation defer updateCommandStatusMethod.RecordLatency(time.Now()) @@ -270,11 +269,11 @@ func (h *Heimdall) updateCommandStatus(ctx context.Context, c *command.Command) } updateCommandStatusMethod.CountSuccess() - return h.getCommandStatus(ctx, c) + return h.getCommandStatus(c) } -func (h *Heimdall) getCommands(ctx context.Context, f *database.Filter) (any, error) { +func (h *Heimdall) getCommands(f *database.Filter) (any, error) { // Track DB connection for commands list operation defer getCommandsMethod.RecordLatency(time.Now()) @@ -326,7 +325,7 @@ func (h *Heimdall) getCommands(ctx context.Context, f *database.Filter) (any, er } -func (h *Heimdall) getCommandStatuses(ctx context.Context, _ *database.Filter) (any, error) { +func (h *Heimdall) getCommandStatuses(_ *database.Filter) (any, error) { return database.GetSlice(h.Database, queryCommandStatusesSelect) diff --git a/internal/pkg/heimdall/handler.go b/internal/pkg/heimdall/handler.go index 55607ff..194169b 100644 --- a/internal/pkg/heimdall/handler.go +++ b/internal/pkg/heimdall/handler.go @@ -1,7 +1,6 @@ package heimdall import ( - "context" "encoding/json" "fmt" "io" @@ -49,7 +48,7 @@ func writeAPIError(w http.ResponseWriter, err error, obj any) { w.Write(responseJSON) } -func payloadHandler[T any](fn func(context.Context, *T) (any, error)) http.HandlerFunc { +func payloadHandler[T any](fn func(*T) (any, error)) http.HandlerFunc { // start latency timer defer payloadHandlerMethod.RecordLatency(time.Now()) @@ -82,7 +81,7 @@ func payloadHandler[T any](fn func(context.Context, *T) (any, error)) http.Handl } // execute request - result, err := fn(r.Context(), &payload) + result, err := fn(&payload) if err != nil { writeAPIError(w, err, result) return diff --git a/internal/pkg/heimdall/heimdall.go b/internal/pkg/heimdall/heimdall.go index 871546d..d8a5840 100644 --- a/internal/pkg/heimdall/heimdall.go +++ b/internal/pkg/heimdall/heimdall.go @@ -15,12 +15,12 @@ import ( "github.com/patterninc/heimdall/internal/pkg/database" "github.com/patterninc/heimdall/internal/pkg/janitor" "github.com/patterninc/heimdall/internal/pkg/pool" - "github.com/patterninc/heimdall/internal/pkg/rbac" "github.com/patterninc/heimdall/internal/pkg/server" "github.com/patterninc/heimdall/pkg/object/cluster" "github.com/patterninc/heimdall/pkg/object/command" "github.com/patterninc/heimdall/pkg/object/job" "github.com/patterninc/heimdall/pkg/plugin" + 
"github.com/patterninc/heimdall/internal/pkg/rbac" rbacI "github.com/patterninc/heimdall/pkg/rbac" ) @@ -173,7 +173,6 @@ func (h *Heimdall) Start() error { // job(s) endpoints... apiRouter.Methods(methodGET).PathPrefix(`/job/statuses`).HandlerFunc(payloadHandler(h.getJobStatuses)) apiRouter.Methods(methodGET).PathPrefix(`/job/{id}/status`).HandlerFunc(payloadHandler(h.getJobStatus)) - apiRouter.Methods(methodPOST).PathPrefix(`/job/{id}/cancel`).HandlerFunc(payloadHandler(h.cancelJob)) apiRouter.Methods(methodGET).PathPrefix(`/job/{id}/{file}`).HandlerFunc(h.getJobFile) apiRouter.Methods(methodGET).PathPrefix(`/job/{id}`).HandlerFunc(payloadHandler(h.getJob)) apiRouter.Methods(methodGET).PathPrefix(`/jobs`).HandlerFunc(payloadHandler(h.getJobs)) diff --git a/internal/pkg/heimdall/job.go b/internal/pkg/heimdall/job.go index 41151d1..b0cd043 100644 --- a/internal/pkg/heimdall/job.go +++ b/internal/pkg/heimdall/job.go @@ -1,9 +1,7 @@ package heimdall import ( - "context" "crypto/rand" - _ "embed" "encoding/json" "fmt" "math/big" @@ -38,20 +36,15 @@ const ( var ( ErrCommandClusterPairNotFound = fmt.Errorf(`command-cluster pair is not found`) - ErrJobCancelFailed = fmt.Errorf(`async job unrecognized or already in final state`) runJobMethod = telemetry.NewMethod("runJob", "heimdall") - cancelJobMethod = telemetry.NewMethod("db_connection", "cancel_job") ) -//go:embed queries/job/status_cancel_update.sql -var queryJobCancelUpdate string - type commandOnCluster struct { command *command.Command cluster *cluster.Cluster } -func (h *Heimdall) submitJob(ctx context.Context, j *job.Job) (any, error) { +func (h *Heimdall) submitJob(j *job.Job) (any, error) { // set / add job properties if err := j.Init(); err != nil { @@ -74,7 +67,7 @@ func (h *Heimdall) submitJob(ctx context.Context, j *job.Job) (any, error) { } // let's run the job - err = h.runJob(ctx, j, command, cluster) + err = h.runJob(j, command, cluster) // before we process the error, we'll make the best effort to record this job in the database go h.insertJob(j, cluster.ID, command.ID) @@ -83,16 +76,16 @@ func (h *Heimdall) submitJob(ctx context.Context, j *job.Job) (any, error) { } -func (h *Heimdall) runJob(ctx context.Context, j *job.Job, command *command.Command, cluster *cluster.Cluster) error { +func (h *Heimdall) runJob(job *job.Job, command *command.Command, cluster *cluster.Cluster) error { defer runJobMethod.RecordLatency(time.Now(), command.Name, cluster.Name) runJobMethod.CountRequest(command.Name, cluster.Name) // let's set environment runtime := &plugin.Runtime{ - WorkingDirectory: h.JobsDirectory + separator + j.ID, - ArchiveDirectory: h.ArchiveDirectory + separator + j.ID, - ResultDirectory: h.ResultDirectory + separator + j.ID, + WorkingDirectory: h.JobsDirectory + separator + job.ID, + ArchiveDirectory: h.ArchiveDirectory + separator + job.ID, + ResultDirectory: h.ResultDirectory + separator + job.ID, Version: h.Version, UserAgent: fmt.Sprintf(formatUserAgent, h.Version), } @@ -111,87 +104,41 @@ func (h *Heimdall) runJob(ctx context.Context, j *job.Job, command *command.Comm defer close(keepaliveActive) // ...and now we just start keepalive function for this job - go h.jobKeepalive(keepaliveActive, j.SystemID, h.agentName) - - // Create channels for coordination between plugin execution and cancellation monitoring - jobDone := make(chan error, 1) - cancelMonitorDone := make(chan struct{}) - - // Create cancellable context for the job - pluginCtx, cancel := context.WithCancel(ctx) - defer cancel() - - // Start plugin 
execution in goroutine - go func() { - defer close(cancelMonitorDone) // signal monitoring to stop - err := h.commandHandlers[command.ID](pluginCtx, runtime, j, cluster) - jobDone <- err - }() - - // Start cancellation monitoring for async jobs - if !j.IsSync { - go func() { - ticker := time.NewTicker(10 * time.Second) - defer ticker.Stop() - - for { - select { - // plugin finished, stop monitoring - case <-cancelMonitorDone: - return - case <-ticker.C: - // If job is in cancelling state, trigger context cancellation - result, err := h.getJobStatus(ctx, &jobRequest{ID: j.ID}) - if err == nil { - if job, ok := result.(*job.Job); ok && job.Status == jobStatus.Cancelling { - cancel() - return - } - } - } - } - }() - } + go h.jobKeepalive(keepaliveActive, job.SystemID, h.agentName) - // Wait for job execution to complete - jobErr := <-jobDone + // let's execute command + if err := h.commandHandlers[command.ID](runtime, job, cluster); err != nil { - // Check if context was cancelled and mark status appropriately - if pluginCtx.Err() != nil { - j.Status = jobStatus.Cancelling // janitor will update to cancelled when resources are cleaned up - runJobMethod.LogAndCountError(pluginCtx.Err(), command.Name, cluster.Name) - return nil - } + job.Status = jobStatus.Failed + job.Error = err.Error() + + runJobMethod.LogAndCountError(err, command.Name, cluster.Name) + + return err - // Handle plugin execution result (only if not cancelled) - if jobErr != nil { - j.Status = jobStatus.Failed - j.Error = jobErr.Error() - runJobMethod.LogAndCountError(jobErr, command.Name, cluster.Name) - return jobErr } - if j.StoreResultSync || !j.IsSync { - h.storeResults(runtime, j) + if job.StoreResultSync || !job.IsSync { + h.storeResults(runtime, job) } else { - go h.storeResults(runtime, j) + go h.storeResults(runtime, job) } - j.Status = jobStatus.Succeeded + job.Status = jobStatus.Succeeded runJobMethod.CountSuccess(command.Name, cluster.Name) return nil } -func (h *Heimdall) storeResults(runtime *plugin.Runtime, j *job.Job) error { +func (h *Heimdall) storeResults(runtime *plugin.Runtime, job *job.Job) error { // do we have result to be written? 
- if j.Result == nil { + if job.Result == nil { return nil } // prepare result - data, err := json.Marshal(j.Result) + data, err := json.Marshal(job.Result) if err != nil { return err @@ -200,9 +147,7 @@ func (h *Heimdall) storeResults(runtime *plugin.Runtime, j *job.Job) error { // write result writeFileFunc := os.WriteFile if strings.HasPrefix(runtime.ResultDirectory, s3Prefix) { - writeFileFunc = func(name string, data []byte, perm os.FileMode) error { - return aws.WriteToS3(context.Background(), name, data, perm) - } + writeFileFunc = aws.WriteToS3 } if err := writeFileFunc(runtime.ResultDirectory+separator+resultFilename, data, 0600); err != nil { @@ -212,34 +157,6 @@ func (h *Heimdall) storeResults(runtime *plugin.Runtime, j *job.Job) error { return nil } -func (h *Heimdall) cancelJob(ctx context.Context, req *jobRequest) (any, error) { - - defer cancelJobMethod.RecordLatency(time.Now()) - cancelJobMethod.CountRequest() - - sess, err := h.Database.NewSession(false) - if err != nil { - cancelJobMethod.LogAndCountError(err, "new_session") - return nil, err - } - defer sess.Close() - - // Attempt to cancel - rowsAffected, err := sess.Exec(queryJobCancelUpdate, req.ID, req.User) - if err != nil { - return nil, err - } - - if rowsAffected == 0 { - return nil, ErrJobCancelFailed - } - - // return job status - return &job.Job{ - Status: jobStatus.Cancelling, - }, nil -} - func (h *Heimdall) getJobFile(w http.ResponseWriter, r *http.Request) { // get vars @@ -256,7 +173,7 @@ func (h *Heimdall) getJobFile(w http.ResponseWriter, r *http.Request) { } // let's validate jobID we got - if _, err := h.getJobStatus(r.Context(), &jobRequest{ID: jobID}); err != nil { + if _, err := h.getJobStatus(&jobRequest{ID: jobID}); err != nil { writeAPIError(w, err, nil) return } @@ -275,9 +192,7 @@ func (h *Heimdall) getJobFile(w http.ResponseWriter, r *http.Request) { readFileFunc := os.ReadFile filenamePath := fmt.Sprintf(jobFileFormat, sourceDirectory, jobID, filename) if strings.HasPrefix(filenamePath, s3Prefix) { - readFileFunc = func(path string) ([]byte, error) { - return aws.ReadFromS3(r.Context(), path) - } + readFileFunc = aws.ReadFromS3 } // get file's content diff --git a/internal/pkg/heimdall/job_dal.go b/internal/pkg/heimdall/job_dal.go index a845be0..31f070a 100644 --- a/internal/pkg/heimdall/job_dal.go +++ b/internal/pkg/heimdall/job_dal.go @@ -1,7 +1,6 @@ package heimdall import ( - "context" "database/sql" _ "embed" "encoding/json" @@ -95,7 +94,6 @@ var ( type jobRequest struct { ID string `yaml:"id,omitempty" json:"id,omitempty"` File string `yaml:"file,omitempty" json:"file,omitempty"` - User string `yaml:"user,omitempty" json:"user,omitempty"` } func (h *Heimdall) insertJob(j *job.Job, clusterID, commandID string) (int64, error) { @@ -113,7 +111,7 @@ func (h *Heimdall) insertJob(j *job.Job, clusterID, commandID string) (int64, er defer sess.Close() // insert job row - jobID, err := sess.InsertRow(queryJobInsert, clusterID, commandID, j.Status, j.ID, j.Name, j.Version, j.Description, j.Context.String(), j.Error, j.User, j.IsSync, j.StoreResultSync, j.CancelledBy) + jobID, err := sess.InsertRow(queryJobInsert, clusterID, commandID, j.Status, j.ID, j.Name, j.Version, j.Description, j.Context.String(), j.Error, j.User, j.IsSync, j.StoreResultSync) if err != nil { return 0, err } @@ -171,7 +169,7 @@ func (h *Heimdall) insertJob(j *job.Job, clusterID, commandID string) (int64, er } -func (h *Heimdall) getJob(ctx context.Context, j *jobRequest) (any, error) { +func (h *Heimdall) getJob(j 
*jobRequest) (any, error) { // Track DB connection for job get operation defer getJobMethod.RecordLatency(time.Now()) @@ -200,7 +198,7 @@ func (h *Heimdall) getJob(ctx context.Context, j *jobRequest) (any, error) { var jobContext string if err := row.Scan(&r.SystemID, &r.Status, &r.Name, &r.Version, &r.Description, &jobContext, &r.Error, &r.User, &r.IsSync, - &r.CreatedAt, &r.UpdatedAt, &r.CommandID, &r.CommandName, &r.ClusterID, &r.ClusterName, &r.StoreResultSync, &r.CancelledBy); err != nil { + &r.CreatedAt, &r.UpdatedAt, &r.CommandID, &r.CommandName, &r.CluserID, &r.ClusterName, &r.StoreResultSync); err != nil { if err == sql.ErrNoRows { return nil, ErrUnknownJobID } else { @@ -218,7 +216,7 @@ func (h *Heimdall) getJob(ctx context.Context, j *jobRequest) (any, error) { } -func (h *Heimdall) getJobs(ctx context.Context, f *database.Filter) (any, error) { +func (h *Heimdall) getJobs(f *database.Filter) (any, error) { // Track DB connection for jobs list operation defer getJobsMethod.RecordLatency(time.Now()) @@ -253,7 +251,7 @@ func (h *Heimdall) getJobs(ctx context.Context, f *database.Filter) (any, error) r := &job.Job{} if err := rows.Scan(&r.SystemID, &r.ID, &r.Status, &r.Name, &r.Version, &r.Description, &jobContext, &r.Error, &r.User, &r.IsSync, - &r.CreatedAt, &r.UpdatedAt, &r.CommandID, &r.CommandName, &r.ClusterID, &r.ClusterName, &r.StoreResultSync, &r.CancelledBy); err != nil { + &r.CreatedAt, &r.UpdatedAt, &r.CommandID, &r.CommandName, &r.CluserID, &r.ClusterName, &r.StoreResultSync); err != nil { getJobsMethod.LogAndCountError(err, "scan") return nil, err } @@ -274,7 +272,7 @@ func (h *Heimdall) getJobs(ctx context.Context, f *database.Filter) (any, error) } -func (h *Heimdall) getJobStatus(ctx context.Context, j *jobRequest) (any, error) { +func (h *Heimdall) getJobStatus(j *jobRequest) (any, error) { // Track DB connection for job status operation defer getJobStatusMethod.RecordLatency(time.Now()) @@ -339,7 +337,7 @@ func jobParseContextAndTags(j *job.Job, jobContext string, sess *database.Sessio } -func (h *Heimdall) getJobStatuses(ctx context.Context, _ *database.Filter) (any, error) { +func (h *Heimdall) getJobStatuses(_ *database.Filter) (any, error) { return database.GetSlice(h.Database, queryJobStatusesSelect) diff --git a/internal/pkg/heimdall/jobs_async.go b/internal/pkg/heimdall/jobs_async.go index 0f2e82a..49a23f4 100644 --- a/internal/pkg/heimdall/jobs_async.go +++ b/internal/pkg/heimdall/jobs_async.go @@ -1,7 +1,6 @@ package heimdall import ( - "context" _ "embed" "fmt" "time" @@ -64,7 +63,7 @@ func (h *Heimdall) getAsyncJobs(limit int) ([]*job.Job, error) { jobContext, j := ``, &job.Job{} - if err := rows.Scan(&j.SystemID, &j.CommandID, &j.ClusterID, &j.Status, &j.ID, &j.Name, + if err := rows.Scan(&j.SystemID, &j.CommandID, &j.CluserID, &j.Status, &j.ID, &j.Name, &j.Version, &j.Description, &jobContext, &j.User, &j.IsSync, &j.CreatedAt, &j.UpdatedAt, &j.StoreResultSync); err != nil { return nil, err } @@ -104,7 +103,7 @@ func (h *Heimdall) getAsyncJobs(limit int) ([]*job.Job, error) { } -func (h *Heimdall) runAsyncJob(ctx context.Context, j *job.Job) error { +func (h *Heimdall) runAsyncJob(j *job.Job) error { // Track DB connection for async job execution defer runAsyncJobMethod.RecordLatency(time.Now()) @@ -129,13 +128,13 @@ func (h *Heimdall) runAsyncJob(ctx context.Context, j *job.Job) error { } // do we have hte cluster? 
-	cluster, found := h.Clusters[j.ClusterID]
+	cluster, found := h.Clusters[j.CluserID]
 	if !found {
-		return h.updateAsyncJobStatus(j, fmt.Errorf(formatErrUnknownCluster, j.ClusterID))
+		return h.updateAsyncJobStatus(j, fmt.Errorf(formatErrUnknownCluster, j.CluserID))
 	}
 
 	runAsyncJobMethod.CountSuccess()
-	return h.updateAsyncJobStatus(j, h.runJob(ctx, j, command, cluster))
+	return h.updateAsyncJobStatus(j, h.runJob(j, command, cluster))
 
 }
 
@@ -145,6 +144,14 @@ func (h *Heimdall) updateAsyncJobStatus(j *job.Job, jobError error) error {
 	defer updateAsyncJobStatusMethod.RecordLatency(time.Now())
 	updateAsyncJobStatusMethod.CountRequest()
 
+	// we update the final job status based on the presence of the error
+	if jobError == nil {
+		j.Status = status.Succeeded
+	} else {
+		j.Status = status.Failed
+		j.Error = jobError.Error()
+	}
+
 	// now we update that status in the database
 	sess, err := h.Database.NewSession(true)
 	if err != nil {
diff --git a/internal/pkg/heimdall/queries/job/insert.sql b/internal/pkg/heimdall/queries/job/insert.sql
index c641363..46813a6 100644
--- a/internal/pkg/heimdall/queries/job/insert.sql
+++ b/internal/pkg/heimdall/queries/job/insert.sql
@@ -11,8 +11,7 @@ insert into jobs
     job_error,
     username,
     is_sync,
-    store_result_sync,
-    cancelled_by
+    store_result_sync
 )
 select
     cm.system_command_id,
@@ -26,8 +25,7 @@ select
     $9, -- job_error
     $10, -- username
     $11, -- is_sync
-    $12, -- store_result_sync
-    $13 -- cancelled_by
+    $12 -- store_result_sync
 from
     clusters cl,
     commands cm
diff --git a/internal/pkg/heimdall/queries/job/select.sql b/internal/pkg/heimdall/queries/job/select.sql
index 716633f..bc84784 100644
--- a/internal/pkg/heimdall/queries/job/select.sql
+++ b/internal/pkg/heimdall/queries/job/select.sql
@@ -14,8 +14,7 @@ select
     cm.command_name,
     cl.cluster_id,
     cl.cluster_name,
-    j.store_result_sync,
-    j.cancelled_by
+    j.store_result_sync
 from
     jobs j
     left join commands cm on cm.system_command_id = j.job_command_id
diff --git a/internal/pkg/heimdall/queries/job/select_jobs.sql b/internal/pkg/heimdall/queries/job/select_jobs.sql
index 3a439fa..39b4d81 100644
--- a/internal/pkg/heimdall/queries/job/select_jobs.sql
+++ b/internal/pkg/heimdall/queries/job/select_jobs.sql
@@ -15,8 +15,7 @@ select
     cm.command_name,
     cl.cluster_id,
     cl.cluster_name,
-    j.store_result_sync,
-    j.cancelled_by
+    j.store_result_sync
 from
     jobs j
     join job_statuses js on js.job_status_id = j.job_status_id
diff --git a/internal/pkg/heimdall/queries/job/status_cancel_update.sql b/internal/pkg/heimdall/queries/job/status_cancel_update.sql
deleted file mode 100644
index 8597499..0000000
--- a/internal/pkg/heimdall/queries/job/status_cancel_update.sql
+++ /dev/null
@@ -1,8 +0,0 @@
-update jobs
-set
-    job_status_id = 7, -- CANCELLING
-    cancelled_by = $2,
-    updated_at = extract(epoch from now())::int
-where
-    job_id = $1
-    and job_status_id not in (4, 5, 6, 8); -- Not in FAILED, KILLED, SUCCEEDED, CANCELLED
diff --git a/internal/pkg/object/command/clickhouse/clickhouse.go b/internal/pkg/object/command/clickhouse/clickhouse.go
index 5581667..819f83a 100644
--- a/internal/pkg/object/command/clickhouse/clickhouse.go
+++ b/internal/pkg/object/command/clickhouse/clickhouse.go
@@ -7,7 +7,7 @@ import (
 	"github.com/ClickHouse/clickhouse-go/v2"
 	"github.com/ClickHouse/clickhouse-go/v2/lib/driver"
 	"github.com/hladush/go-telemetry/pkg/telemetry"
-	heimdallContext "github.com/patterninc/heimdall/pkg/context"
+	hdctx "github.com/patterninc/heimdall/pkg/context"
 	"github.com/patterninc/heimdall/pkg/object/cluster"
"github.com/patterninc/heimdall/pkg/object/job" "github.com/patterninc/heimdall/pkg/object/job/status" @@ -45,11 +45,11 @@ var ( ) // New creates a new clickhouse plugin handler -func New(commandCtx *heimdallContext.Context) (plugin.Handler, error) { +func New(ctx *hdctx.Context) (plugin.Handler, error) { t := &commandContext{} - if commandCtx != nil { - if err := commandCtx.Unmarshal(t); err != nil { + if ctx != nil { + if err := ctx.Unmarshal(t); err != nil { return nil, err } } @@ -57,9 +57,10 @@ func New(commandCtx *heimdallContext.Context) (plugin.Handler, error) { return t.handler, nil } -func (cmd *commandContext) handler(ctx context.Context, r *plugin.Runtime, j *job.Job, c *cluster.Cluster) error { +func (cmd *commandContext) handler(r *plugin.Runtime, j *job.Job, c *cluster.Cluster) error { + ctx := context.Background() - jobContext, err := cmd.createJobContext(ctx, j, c) + jobContext, err := cmd.createJobContext(j, c) if err != nil { handleMethod.LogAndCountError(err, "create_job_context") return err @@ -81,7 +82,7 @@ func (cmd *commandContext) handler(ctx context.Context, r *plugin.Runtime, j *jo return nil } -func (cmd *commandContext) createJobContext(ctx context.Context, j *job.Job, c *cluster.Cluster) (*jobContext, error) { +func (cmd *commandContext) createJobContext(j *job.Job, c *cluster.Cluster) (*jobContext, error) { // get cluster context clusterCtx := &clusterContext{} if c.Context != nil { diff --git a/internal/pkg/object/command/dynamo/dynamo.go b/internal/pkg/object/command/dynamo/dynamo.go index 283a6b8..f0a3334 100644 --- a/internal/pkg/object/command/dynamo/dynamo.go +++ b/internal/pkg/object/command/dynamo/dynamo.go @@ -1,7 +1,7 @@ package dynamo import ( - "context" + ct "context" "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/config" @@ -9,7 +9,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/dynamodb" "github.com/aws/aws-sdk-go-v2/service/sts" - heimdallContext "github.com/patterninc/heimdall/pkg/context" + "github.com/patterninc/heimdall/pkg/context" "github.com/patterninc/heimdall/pkg/object/cluster" "github.com/patterninc/heimdall/pkg/object/job" "github.com/patterninc/heimdall/pkg/plugin" @@ -17,36 +17,37 @@ import ( ) // dynamoJobContext represents the context for a dynamo job -type jobContext struct { +type dynamoJobContext struct { Query string `yaml:"query,omitempty" json:"query,omitempty"` Limit int `yaml:"limit,omitempty" json:"limit,omitempty"` } // dynamoClusterContext represents the context for a dynamo endpoint -type clusterContext struct { +type dynamoClusterContext struct { RoleARN *string `yaml:"role_arn,omitempty" json:"role_arn,omitempty"` } // dynamoCommandContext represents the dynamo command context -type commandContext struct{} +type dynamoCommandContext struct{} var ( + ctx = ct.Background() assumeRoleSession = aws.String("AssumeRoleSession") ) // New creates a new dynamo plugin handler. -func New(_ *heimdallContext.Context) (plugin.Handler, error) { +func New(_ *context.Context) (plugin.Handler, error) { - s := &commandContext{} + s := &dynamoCommandContext{} return s.handler, nil } // Handler for the Spark job submission. 
-func (d *commandContext) handler(ctx context.Context, r *plugin.Runtime, j *job.Job, c *cluster.Cluster) (err error) { +func (d *dynamoCommandContext) handler(r *plugin.Runtime, j *job.Job, c *cluster.Cluster) (err error) { // let's unmarshal job context - jobContext := &jobContext{} + jobContext := &dynamoJobContext{} if j.Context != nil { if err := j.Context.Unmarshal(jobContext); err != nil { return err @@ -54,7 +55,7 @@ func (d *commandContext) handler(ctx context.Context, r *plugin.Runtime, j *job. } // let's unmarshal cluster context - clusterContext := &clusterContext{} + clusterContext := &dynamoClusterContext{} if c.Context != nil { if err := c.Context.Unmarshal(clusterContext); err != nil { return err diff --git a/internal/pkg/object/command/ecs/ecs.go b/internal/pkg/object/command/ecs/ecs.go index a15819f..21672cc 100644 --- a/internal/pkg/object/command/ecs/ecs.go +++ b/internal/pkg/object/command/ecs/ecs.go @@ -1,7 +1,7 @@ package ecs import ( - "context" + ct "context" "encoding/json" "fmt" "os" @@ -15,7 +15,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/ecs/types" "github.com/hladush/go-telemetry/pkg/telemetry" heimdallAws "github.com/patterninc/heimdall/internal/pkg/aws" - heimdallContext "github.com/patterninc/heimdall/pkg/context" + "github.com/patterninc/heimdall/pkg/context" "github.com/patterninc/heimdall/pkg/duration" "github.com/patterninc/heimdall/pkg/object/cluster" "github.com/patterninc/heimdall/pkg/object/job" @@ -23,7 +23,7 @@ import ( ) // ECS command context structure -type commandContext struct { +type ecsCommandContext struct { TaskDefinitionTemplate string `yaml:"task_definition_template,omitempty" json:"task_definition_template,omitempty"` TaskCount int `yaml:"task_count,omitempty" json:"task_count,omitempty"` CPU int `yaml:"cpu,omitempty" json:"cpu,omitempty"` @@ -35,7 +35,7 @@ type commandContext struct { } // ECS cluster context structure -type clusterContext struct { +type ecsClusterContext struct { MaxCPU int `yaml:"max_cpu,omitempty" json:"max_cpu,omitempty"` MaxMemory int `yaml:"max_memory,omitempty" json:"max_memory,omitempty"` MaxTaskCount int `yaml:"max_task_count,omitempty" json:"max_task_count,omitempty"` @@ -86,7 +86,7 @@ type executionContext struct { Memory int `json:"memory"` TaskDefinitionWrapper *taskDefinitionWrapper `json:"task_definition_wrapper"` ContainerOverrides []types.ContainerOverride `json:"container_overrides"` - ClusterConfig *clusterContext `json:"cluster_config"` + ClusterConfig *ecsClusterContext `json:"cluster_config"` PollingInterval duration.Duration `json:"polling_interval"` Timeout duration.Duration `json:"timeout"` @@ -116,21 +116,22 @@ const ( ) var ( + ctx = ct.Background() errMissingTemplate = fmt.Errorf("task definition template is required") methodMetrics = telemetry.NewMethod("ecs", "ecs plugin") ) -func New(commandCtx *heimdallContext.Context) (plugin.Handler, error) { +func New(commandContext *context.Context) (plugin.Handler, error) { - e := &commandContext{ + e := &ecsCommandContext{ PollingInterval: defaultPollingInterval, Timeout: defaultTaskTimeout, MaxFailCount: defaultMaxFailCount, TaskCount: defaultTaskCount, } - if commandCtx != nil { - if err := commandCtx.Unmarshal(e); err != nil { + if commandContext != nil { + if err := commandContext.Unmarshal(e); err != nil { return nil, err } } @@ -140,31 +141,31 @@ func New(commandCtx *heimdallContext.Context) (plugin.Handler, error) { } // handler implements the main ECS plugin logic -func (e *commandContext) handler(ctx context.Context, r 
*plugin.Runtime, job *job.Job, cluster *cluster.Cluster) error { +func (e *ecsCommandContext) handler(r *plugin.Runtime, job *job.Job, cluster *cluster.Cluster) error { // Build execution context with resolved configuration and loaded template - execCtx, err := buildExecutionContext(ctx, e, job, cluster, r) + execCtx, err := buildExecutionContext(e, job, cluster, r) if err != nil { return err } // register task definition - if err := execCtx.registerTaskDefinition(ctx); err != nil { + if err := execCtx.registerTaskDefinition(); err != nil { return err } // Start tasks - if err := execCtx.startTasks(ctx, job.ID); err != nil { + if err := execCtx.startTasks(job.ID); err != nil { return err } // Poll for completion - if err := execCtx.pollForCompletion(ctx); err != nil { + if err := execCtx.pollForCompletion(); err != nil { return err } // Try to retrieve logs, but don't fail the job if it fails - if err := execCtx.retrieveLogs(ctx); err != nil { + if err := execCtx.retrieveLogs(); err != nil { execCtx.runtime.Stderr.WriteString(fmt.Sprintf("Failed to retrieve logs: %v\n", err)) } @@ -179,7 +180,7 @@ func (e *commandContext) handler(ctx context.Context, r *plugin.Runtime, job *jo } // prepare and register task definition with ECS -func (execCtx *executionContext) registerTaskDefinition(ctx context.Context) error { +func (execCtx *executionContext) registerTaskDefinition() error { registerInput := &ecs.RegisterTaskDefinitionInput{ Family: aws.String(aws.ToString(execCtx.TaskDefinitionWrapper.TaskDefinition.Family)), RequiresCompatibilities: []types.Compatibility{types.CompatibilityFargate}, @@ -203,10 +204,10 @@ func (execCtx *executionContext) registerTaskDefinition(ctx context.Context) err } // startTasks launches all tasks and returns a map of task trackers -func (execCtx *executionContext) startTasks(ctx context.Context, jobID string) error { +func (execCtx *executionContext) startTasks(jobID string) error { for i := 0; i < execCtx.TaskCount; i++ { - taskARN, err := runTask(ctx, execCtx, fmt.Sprintf("%s%s-%d", startedByPrefix, jobID, i), i) + taskARN, err := runTask(execCtx, fmt.Sprintf("%s%s-%d", startedByPrefix, jobID, i), i) if err != nil { return err } @@ -222,7 +223,7 @@ func (execCtx *executionContext) startTasks(ctx context.Context, jobID string) e } // monitor tasks until completion, faliure, or timeout -func (execCtx *executionContext) pollForCompletion(ctx context.Context) error { +func (execCtx *executionContext) pollForCompletion() error { startTime := time.Now() stopTime := startTime.Add(time.Duration(execCtx.Timeout)) @@ -286,7 +287,7 @@ func (execCtx *executionContext) pollForCompletion(ctx context.Context) error { // Stop all other running tasks reason := fmt.Sprintf(errMaxFailCount, tracker.ActiveARN, tracker.Retries, execCtx.MaxFailCount) - if err := stopAllTasks(ctx, execCtx, reason); err != nil { + if err := stopAllTasks(execCtx, reason); err != nil { return err } @@ -295,7 +296,7 @@ func (execCtx *executionContext) pollForCompletion(ctx context.Context) error { break } - newTaskARN, err := runTask(ctx, execCtx, tracker.Name, tracker.TaskNum) + newTaskARN, err := runTask(execCtx, tracker.Name, tracker.TaskNum) if err != nil { return err } @@ -329,7 +330,7 @@ func (execCtx *executionContext) pollForCompletion(ctx context.Context) error { // Stop all remaining tasks reason := fmt.Sprintf(errPollingTimeout, incompleteARNs, execCtx.Timeout) - if err := stopAllTasks(ctx, execCtx, reason); err != nil { + if err := stopAllTasks(execCtx, reason); err != nil { return err } @@ 
-337,7 +338,7 @@ func (execCtx *executionContext) pollForCompletion(ctx context.Context) error { break } - // sleep for polling interval + // Sleep until next poll time time.Sleep(time.Duration(execCtx.PollingInterval)) } @@ -346,7 +347,7 @@ func (execCtx *executionContext) pollForCompletion(ctx context.Context) error { } -func buildExecutionContext(ctx context.Context, commandCtx *commandContext, j *job.Job, c *cluster.Cluster, runtime *plugin.Runtime) (*executionContext, error) { +func buildExecutionContext(commandCtx *ecsCommandContext, j *job.Job, c *cluster.Cluster, runtime *plugin.Runtime) (*executionContext, error) { execCtx := &executionContext{ tasks: make(map[string]*taskTracker), @@ -354,7 +355,7 @@ func buildExecutionContext(ctx context.Context, commandCtx *commandContext, j *j } // Create a context from commandCtx and unmarshal onto execCtx (defaults) - commandContext := heimdallContext.New(commandCtx) + commandContext := context.New(commandCtx) if err := commandContext.Unmarshal(execCtx); err != nil { return nil, err } @@ -367,7 +368,7 @@ func buildExecutionContext(ctx context.Context, commandCtx *commandContext, j *j } // Add cluster config (no overlapping values) - clusterContext := &clusterContext{} + clusterContext := &ecsClusterContext{} if c.Context != nil { if err := c.Context.Unmarshal(clusterContext); err != nil { return nil, err @@ -463,7 +464,7 @@ func buildContainerOverrides(execCtx *executionContext) error { } // stopAllTasks stops all non-completed tasks with the given reason -func stopAllTasks(ctx context.Context, execCtx *executionContext, reason string) error { +func stopAllTasks(execCtx *executionContext, reason string) error { // AWS ECS has a 1024 character limit on the reason field if len(reason) > 1024 { reason = reason[:1021] + "..." @@ -536,7 +537,7 @@ func loadTaskDefinitionTemplate(templatePath string) (*taskDefinitionWrapper, er } // runTask runs a single task and returns the task ARN -func runTask(ctx context.Context, execCtx *executionContext, startedBy string, taskNum int) (string, error) { +func runTask(execCtx *executionContext, startedBy string, taskNum int) (string, error) { // Create a copy of the overrides and add TASK_NAME and TASK_NUM env variables finalOverrides := append([]types.ContainerOverride{}, execCtx.ContainerOverrides...) 
@@ -602,7 +603,7 @@ func isTaskSuccessful(task types.Task, execCtx *executionContext) bool { } // We pull logs from cloudwatch for all containers in a single task that represents the job outcome -func (execCtx *executionContext) retrieveLogs(ctx context.Context) error { +func (execCtx *executionContext) retrieveLogs() error { var selectedTask *taskTracker var writer *os.File @@ -654,7 +655,7 @@ func (execCtx *executionContext) retrieveLogs(ctx context.Context) error { case types.LogDriverAwslogs: logGroup := logInfo.options["awslogs-group"] logStream := fmt.Sprintf("%s/%s/%s", logInfo.options["awslogs-stream-prefix"], logInfo.containerName, taskID) - if err := heimdallAws.PullLogs(ctx, writer, logGroup, logStream, maxLogChunkSize, maxLogMemoryBytes); err != nil { + if err := heimdallAws.PullLogs(writer, logGroup, logStream, maxLogChunkSize, maxLogMemoryBytes); err != nil { return err } default: diff --git a/internal/pkg/object/command/glue/glue.go b/internal/pkg/object/command/glue/glue.go index 63503b6..63128a9 100644 --- a/internal/pkg/object/command/glue/glue.go +++ b/internal/pkg/object/command/glue/glue.go @@ -1,30 +1,28 @@ package glue import ( - "context" - "github.com/patterninc/heimdall/internal/pkg/aws" - heimdallContext "github.com/patterninc/heimdall/pkg/context" + "github.com/patterninc/heimdall/pkg/context" "github.com/patterninc/heimdall/pkg/object/cluster" "github.com/patterninc/heimdall/pkg/object/job" "github.com/patterninc/heimdall/pkg/plugin" "github.com/patterninc/heimdall/pkg/result" ) -type commandContext struct { +type glueCommandContext struct { CatalogID string `yaml:"catalog_id,omitempty" json:"catalog_id,omitempty"` } -type jobContext struct { +type glueJobContext struct { TableName string `yaml:"table_name,omitempty" json:"table_name,omitempty"` } -func New(commandCtx *heimdallContext.Context) (plugin.Handler, error) { +func New(commandContext *context.Context) (plugin.Handler, error) { - g := &commandContext{} + g := &glueCommandContext{} - if commandCtx != nil { - if err := commandCtx.Unmarshal(g); err != nil { + if commandContext != nil { + if err := commandContext.Unmarshal(g); err != nil { return nil, err } } @@ -33,10 +31,10 @@ func New(commandCtx *heimdallContext.Context) (plugin.Handler, error) { } -func (g *commandContext) handler(ctx context.Context, _ *plugin.Runtime, j *job.Job, _ *cluster.Cluster) (err error) { +func (g *glueCommandContext) handler(_ *plugin.Runtime, j *job.Job, _ *cluster.Cluster) (err error) { // let's unmarshal job context - jc := &jobContext{} + jc := &glueJobContext{} if j.Context != nil { if err = j.Context.Unmarshal(jc); err != nil { return @@ -44,7 +42,7 @@ func (g *commandContext) handler(ctx context.Context, _ *plugin.Runtime, j *job. 
} // let's get our metadata - metadata, err := aws.GetTableMetadata(ctx, g.CatalogID, jc.TableName) + metadata, err := aws.GetTableMetadata(g.CatalogID, jc.TableName) if err != nil { return } diff --git a/internal/pkg/object/command/ping/ping.go b/internal/pkg/object/command/ping/ping.go index b09a22c..64fd333 100644 --- a/internal/pkg/object/command/ping/ping.go +++ b/internal/pkg/object/command/ping/ping.go @@ -1,10 +1,9 @@ package ping import ( - "context" "fmt" - heimdallContext "github.com/patterninc/heimdall/pkg/context" + "github.com/patterninc/heimdall/pkg/context" "github.com/patterninc/heimdall/pkg/object/cluster" "github.com/patterninc/heimdall/pkg/object/job" "github.com/patterninc/heimdall/pkg/plugin" @@ -15,16 +14,16 @@ const ( messageFormat = `Hello, %s!` ) -type commandContext struct{} +type pingCommandContext struct{} -func New(_ *heimdallContext.Context) (plugin.Handler, error) { +func New(_ *context.Context) (plugin.Handler, error) { - p := &commandContext{} + p := &pingCommandContext{} return p.handler, nil } -func (p *commandContext) handler(ctx context.Context, _ *plugin.Runtime, j *job.Job, _ *cluster.Cluster) (err error) { +func (p *pingCommandContext) handler(_ *plugin.Runtime, j *job.Job, _ *cluster.Cluster) (err error) { j.Result, err = result.FromMessage(fmt.Sprintf(messageFormat, j.User)) return diff --git a/internal/pkg/object/command/shell/shell.go b/internal/pkg/object/command/shell/shell.go index aa34040..d16b20d 100644 --- a/internal/pkg/object/command/shell/shell.go +++ b/internal/pkg/object/command/shell/shell.go @@ -1,13 +1,12 @@ package shell import ( - "context" "encoding/json" "os" "os/exec" "path" - heimdallContext "github.com/patterninc/heimdall/pkg/context" + "github.com/patterninc/heimdall/pkg/context" "github.com/patterninc/heimdall/pkg/object/cluster" "github.com/patterninc/heimdall/pkg/object/job" "github.com/patterninc/heimdall/pkg/plugin" @@ -19,27 +18,27 @@ const ( contextFilename = `context.json` ) -type commandContext struct { +type shellCommandContext struct { Command []string `yaml:"command,omitempty" json:"command,omitempty"` } -type jobContext struct { +type shellJobContext struct { Arguments []string `yaml:"arguments,omitempty" json:"arguments,omitempty"` } type runtimeContext struct { - Job *job.Job `yaml:"job,omitempty" json:"job,omitempty"` - Command *commandContext `yaml:"command,omitempty" json:"command,omitempty"` - Cluster *cluster.Cluster `yaml:"cluster,omitempty" json:"cluster,omitempty"` - Runtime *plugin.Runtime `yaml:"runtime,omitempty" json:"runtime,omitempty"` + Job *job.Job `yaml:"job,omitempty" json:"job,omitempty"` + Command *shellCommandContext `yaml:"command,omitempty" json:"command,omitempty"` + Cluster *cluster.Cluster `yaml:"cluster,omitempty" json:"cluster,omitempty"` + Runtime *plugin.Runtime `yaml:"runtime,omitempty" json:"runtime,omitempty"` } -func New(commandCtx *heimdallContext.Context) (plugin.Handler, error) { +func New(commandContext *context.Context) (plugin.Handler, error) { - s := &commandContext{} + s := &shellCommandContext{} - if commandCtx != nil { - if err := commandCtx.Unmarshal(s); err != nil { + if commandContext != nil { + if err := commandContext.Unmarshal(s); err != nil { return nil, err } } @@ -48,10 +47,10 @@ func New(commandCtx *heimdallContext.Context) (plugin.Handler, error) { } -func (s *commandContext) handler(ctx context.Context, r *plugin.Runtime, j *job.Job, c *cluster.Cluster) error { +func (s *shellCommandContext) handler(r *plugin.Runtime, j *job.Job, c *cluster.Cluster) 
error { // let's unmarshal job context - jc := &jobContext{} + jc := &shellJobContext{} if j.Context != nil { if err := j.Context.Unmarshal(jc); err != nil { return err @@ -83,7 +82,7 @@ func (s *commandContext) handler(ctx context.Context, r *plugin.Runtime, j *job. commandWithArguments = append(commandWithArguments, jc.Arguments...) // configure command - cmd := exec.CommandContext(ctx, commandWithArguments[0], commandWithArguments[1:]...) + cmd := exec.Command(commandWithArguments[0], commandWithArguments[1:]...) cmd.Stdout = r.Stdout cmd.Stderr = r.Stderr diff --git a/internal/pkg/object/command/snowflake/snowflake.go b/internal/pkg/object/command/snowflake/snowflake.go index cf850db..be267a7 100644 --- a/internal/pkg/object/command/snowflake/snowflake.go +++ b/internal/pkg/object/command/snowflake/snowflake.go @@ -1,7 +1,6 @@ package snowflake import ( - "context" "crypto/rsa" "crypto/x509" "database/sql" @@ -11,7 +10,7 @@ import ( sf "github.com/snowflakedb/gosnowflake" - heimdallContext "github.com/patterninc/heimdall/pkg/context" + "github.com/patterninc/heimdall/pkg/context" "github.com/patterninc/heimdall/pkg/object/cluster" "github.com/patterninc/heimdall/pkg/object/job" "github.com/patterninc/heimdall/pkg/plugin" @@ -28,13 +27,13 @@ var ( ErrInvalidKeyType = fmt.Errorf(`invalida key type`) ) -type commandContext struct{} +type snowflakeCommandContext struct{} -type jobContext struct { +type snowflakeJobContext struct { Query string `yaml:"query,omitempty" json:"query,omitempty"` } -type clusterContext struct { +type snowflakeClusterContext struct { Account string `yaml:"account,omitempty" json:"account,omitempty"` User string `yaml:"user,omitempty" json:"user,omitempty"` Database string `yaml:"database,omitempty" json:"database,omitempty"` @@ -66,14 +65,14 @@ func parsePrivateKey(privateKeyBytes []byte) (*rsa.PrivateKey, error) { } -func New(_ *heimdallContext.Context) (plugin.Handler, error) { - s := &commandContext{} +func New(_ *context.Context) (plugin.Handler, error) { + s := &snowflakeCommandContext{} return s.handler, nil } -func (s *commandContext) handler(ctx context.Context, r *plugin.Runtime, j *job.Job, c *cluster.Cluster) error { +func (s *snowflakeCommandContext) handler(r *plugin.Runtime, j *job.Job, c *cluster.Cluster) error { - clusterContext := &clusterContext{} + clusterContext := &snowflakeClusterContext{} if c.Context != nil { if err := c.Context.Unmarshal(clusterContext); err != nil { return err @@ -81,7 +80,7 @@ func (s *commandContext) handler(ctx context.Context, r *plugin.Runtime, j *job. } // let's unmarshal job context - jobContext := &jobContext{} + jobContext := &snowflakeJobContext{} if err := j.Context.Unmarshal(jobContext); err != nil { return err } @@ -118,7 +117,7 @@ func (s *commandContext) handler(ctx context.Context, r *plugin.Runtime, j *job. 
} defer db.Close() - rows, err := db.QueryContext(ctx, jobContext.Query) + rows, err := db.Query(jobContext.Query) if err != nil { return err } diff --git a/internal/pkg/object/command/spark/spark.go b/internal/pkg/object/command/spark/spark.go index 0583a04..9d0f382 100644 --- a/internal/pkg/object/command/spark/spark.go +++ b/internal/pkg/object/command/spark/spark.go @@ -1,7 +1,7 @@ package spark import ( - "context" + ct "context" "encoding/json" "fmt" "os" @@ -19,7 +19,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/sts" "github.com/babourine/x/pkg/set" - heimdallContext "github.com/patterninc/heimdall/pkg/context" + "github.com/patterninc/heimdall/pkg/context" "github.com/patterninc/heimdall/pkg/object/cluster" "github.com/patterninc/heimdall/pkg/object/job" "github.com/patterninc/heimdall/pkg/plugin" @@ -32,7 +32,7 @@ type sparkSubmitParameters struct { } // spark represents the Spark command context -type commandContext struct { +type sparkCommandContext struct { QueriesURI string `yaml:"queries_uri,omitempty" json:"queries_uri,omitempty"` ResultsURI string `yaml:"results_uri,omitempty" json:"results_uri,omitempty"` LogsURI *string `yaml:"logs_uri,omitempty" json:"logs_uri,omitempty"` @@ -41,7 +41,7 @@ type commandContext struct { } // sparkJobContext represents the context for a spark job -type jobContext struct { +type sparkJobContext struct { Query string `yaml:"query,omitempty" json:"query,omitempty"` Arguments []string `yaml:"arguments,omitempty" json:"arguments,omitempty"` Parameters *sparkSubmitParameters `yaml:"parameters,omitempty" json:"parameters,omitempty"` @@ -49,7 +49,7 @@ type jobContext struct { } // sparkClusterContext represents the context for a spark cluster -type clusterContext struct { +type sparkClusterContext struct { ExecutionRoleArn *string `yaml:"execution_role_arn,omitempty" json:"execution_role_arn,omitempty"` EMRReleaseLabel *string `yaml:"emr_release_label,omitempty" json:"emr_release_label,omitempty"` RoleARN *string `yaml:"role_arn,omitempty" json:"role_arn,omitempty"` @@ -64,6 +64,7 @@ const ( ) var ( + ctx = ct.Background() sparkDefaults = aws.String(`spark-defaults`) assumeRoleSession = aws.String("AssumeRoleSession") runtimeStates = set.New([]types.JobRunState{types.JobRunStateCompleted, types.JobRunStateFailed, types.JobRunStateCancelled}) @@ -76,12 +77,12 @@ var ( ) // New creates a new Spark plugin handler. -func New(commandCtx *heimdallContext.Context) (plugin.Handler, error) { +func New(commandContext *context.Context) (plugin.Handler, error) { - s := &commandContext{} + s := &sparkCommandContext{} - if commandCtx != nil { - if err := commandCtx.Unmarshal(s); err != nil { + if commandContext != nil { + if err := commandContext.Unmarshal(s); err != nil { return nil, err } } @@ -91,10 +92,10 @@ func New(commandCtx *heimdallContext.Context) (plugin.Handler, error) { } // Handler for the Spark job submission. -func (s *commandContext) handler(ctx context.Context, r *plugin.Runtime, j *job.Job, c *cluster.Cluster) (err error) { +func (s *sparkCommandContext) handler(r *plugin.Runtime, j *job.Job, c *cluster.Cluster) (err error) { // let's unmarshal job context - jobContext := &jobContext{} + jobContext := &sparkJobContext{} if j.Context != nil { if err := j.Context.Unmarshal(jobContext); err != nil { return err @@ -102,7 +103,7 @@ func (s *commandContext) handler(ctx context.Context, r *plugin.Runtime, j *job. 
} // let's unmarshal cluster context - clusterContext := &clusterContext{} + clusterContext := &sparkClusterContext{} if c.Context != nil { if err := c.Context.Unmarshal(clusterContext); err != nil { return err @@ -164,7 +165,7 @@ func (s *commandContext) handler(ctx context.Context, r *plugin.Runtime, j *job. svc := emrcontainers.NewFromConfig(awsConfig, assumeRoleOptions) // let's get the cluster ID - clusterID, err := getClusterID(ctx, svc, c.Name) + clusterID, err := getClusterID(svc, c.Name) if err != nil { return err } @@ -174,7 +175,7 @@ func (s *commandContext) handler(ctx context.Context, r *plugin.Runtime, j *job. // upload query to s3 here... queryURI := fmt.Sprintf("%s/%s/query.sql", s.QueriesURI, j.ID) - if err := uploadFileToS3(ctx, queryURI, jobContext.Query); err != nil { + if err := uploadFileToS3(queryURI, jobContext.Query); err != nil { return err } @@ -261,7 +262,7 @@ timeoutLoop: } -func (s *commandContext) setJobDriver(jobContext *jobContext, jobDriver *types.JobDriver, queryURI string, resultURI string) { +func (s *sparkCommandContext) setJobDriver(jobContext *sparkJobContext, jobDriver *types.JobDriver, queryURI string, resultURI string) { jobParameters := getSparkSubmitParameters(jobContext) if jobContext.Arguments != nil { jobDriver.SparkSubmitJobDriver = &types.SparkSubmitJobDriver{ @@ -287,7 +288,7 @@ func (s *commandContext) setJobDriver(jobContext *jobContext, jobDriver *types.J } -func getClusterID(ctx context.Context, svc *emrcontainers.Client, clusterName string) (*string, error) { +func getClusterID(svc *emrcontainers.Client, clusterName string) (*string, error) { // let's get the cluster ID outputListClusters, err := svc.ListVirtualClusters(ctx, &emrcontainers.ListVirtualClustersInput{ @@ -308,7 +309,7 @@ func getClusterID(ctx context.Context, svc *emrcontainers.Client, clusterName st } -func getSparkSubmitParameters(context *jobContext) *string { +func getSparkSubmitParameters(context *sparkJobContext) *string { properties := context.Parameters.Properties conf := make([]string, 0, len(properties)) @@ -326,7 +327,7 @@ func printState(stdout *os.File, state types.JobRunState) { stdout.WriteString(fmt.Sprintf("%v - job is still running. 
latest status: %v\n", time.Now(), state)) } -func uploadFileToS3(ctx context.Context, fileURI, content string) error { +func uploadFileToS3(fileURI, content string) error { // get bucket name and prefix s3Parts := rxS3.FindAllStringSubmatch(fileURI, -1) diff --git a/internal/pkg/object/command/sparkeks/sparkeks.go b/internal/pkg/object/command/sparkeks/sparkeks.go index ac9e8a1..4750889 100644 --- a/internal/pkg/object/command/sparkeks/sparkeks.go +++ b/internal/pkg/object/command/sparkeks/sparkeks.go @@ -73,6 +73,7 @@ const ( ) var ( + ctx = context.Background() rxS3 = regexp.MustCompile(`^s3://([^/]+)/(.*)$`) runtimeStates = []v1beta2.ApplicationStateType{ v1beta2.ApplicationStateCompleted, @@ -90,24 +91,24 @@ var ( ErrSparkApplicationFile = fmt.Errorf("failed to read SparkApplication application template file: check file path and permissions") ) -type commandContext struct { +type sparkEksCommandContext struct { JobsURI string `yaml:"jobs_uri,omitempty" json:"jobs_uri,omitempty"` WrapperURI string `yaml:"wrapper_uri,omitempty" json:"wrapper_uri,omitempty"` Properties map[string]string `yaml:"properties,omitempty" json:"properties,omitempty"` KubeNamespace string `yaml:"kube_namespace,omitempty" json:"kube_namespace,omitempty"` } -type jobParameters struct { +type sparkEksJobParameters struct { Properties map[string]string `yaml:"properties,omitempty" json:"properties,omitempty"` } -type jobContext struct { - Query string `yaml:"query,omitempty" json:"query,omitempty"` - Parameters *jobParameters `yaml:"parameters,omitempty" json:"parameters,omitempty"` - ReturnResult bool `yaml:"return_result,omitempty" json:"return_result,omitempty"` +type sparkEksJobContext struct { + Query string `yaml:"query,omitempty" json:"query,omitempty"` + Parameters *sparkEksJobParameters `yaml:"parameters,omitempty" json:"parameters,omitempty"` + ReturnResult bool `yaml:"return_result,omitempty" json:"return_result,omitempty"` } -type clusterContext struct { +type sparkEksClusterContext struct { RoleARN *string `yaml:"role_arn,omitempty" json:"role_arn,omitempty"` Properties map[string]string `yaml:"properties,omitempty" json:"properties,omitempty"` Image *string `yaml:"image,omitempty" json:"image,omitempty"` @@ -121,9 +122,9 @@ type executionContext struct { runtime *plugin.Runtime job *job.Job cluster *cluster.Cluster - commandContext *commandContext - jobContext *jobContext - clusterContext *clusterContext + commandContext *sparkEksCommandContext + jobContext *sparkEksJobContext + clusterContext *sparkEksClusterContext sparkClient *sparkClientSet.Clientset kubeClient *kubernetes.Clientset @@ -139,13 +140,13 @@ type executionContext struct { } // New creates a new Spark EKS plugin handler. -func New(commandCtx *heimdallContext.Context) (plugin.Handler, error) { - s := &commandContext{ +func New(commandContext *heimdallContext.Context) (plugin.Handler, error) { + s := &sparkEksCommandContext{ KubeNamespace: defaultNamespace, } - if commandCtx != nil { - if err := commandCtx.Unmarshal(s); err != nil { + if commandContext != nil { + if err := commandContext.Unmarshal(s); err != nil { return nil, err } } @@ -154,24 +155,23 @@ func New(commandCtx *heimdallContext.Context) (plugin.Handler, error) { } // handler executes the Spark EKS job submission and execution. -func (s *commandContext) handler(ctx context.Context, r *plugin.Runtime, j *job.Job, c *cluster.Cluster) error { - +func (s *sparkEksCommandContext) handler(r *plugin.Runtime, j *job.Job, c *cluster.Cluster) error { // 1. 
Build execution context, create URIs, and upload query - execCtx, err := buildExecutionContextAndURI(ctx, r, j, c, s) + execCtx, err := buildExecutionContextAndURI(r, j, c, s) if err != nil { return fmt.Errorf("failed to build execution context: %w", err) } // 2. Submit the Spark Application to the cluster - if err := execCtx.submitSparkApp(ctx); err != nil { + if err := execCtx.submitSparkApp(); err != nil { return err } // 3. Monitor the job until completion and collect logs - monitorErr := execCtx.monitorJobAndCollectLogs(ctx) + monitorErr := execCtx.monitorJobAndCollectLogs() // 4. Cleanup any resources that are still pending - if err := execCtx.cleanupSparkApp(ctx); err != nil { + if err := execCtx.cleanupSparkApp(); err != nil { // Log cleanup error but don't override the main monitoring error execCtx.runtime.Stderr.WriteString(fmt.Sprintf("Warning: failed to cleanup application %s: %v\n", execCtx.submittedApp.Name, err)) } @@ -182,7 +182,7 @@ func (s *commandContext) handler(ctx context.Context, r *plugin.Runtime, j *job. } // 5. Get and store results if required - if err := execCtx.getAndStoreResults(ctx); err != nil { + if err := execCtx.getAndStoreResults(); err != nil { return err } @@ -190,7 +190,7 @@ func (s *commandContext) handler(ctx context.Context, r *plugin.Runtime, j *job. } // buildExecutionContextAndURI prepares the context, merges configurations, and uploads the query. -func buildExecutionContextAndURI(ctx context.Context, r *plugin.Runtime, j *job.Job, c *cluster.Cluster, s *commandContext) (*executionContext, error) { +func buildExecutionContextAndURI(r *plugin.Runtime, j *job.Job, c *cluster.Cluster, s *sparkEksCommandContext) (*executionContext, error) { execCtx := &executionContext{ runtime: r, job: j, @@ -199,7 +199,7 @@ func buildExecutionContextAndURI(ctx context.Context, r *plugin.Runtime, j *job. } // Parse job context - jobContext := &jobContext{} + jobContext := &sparkEksJobContext{} if j.Context != nil { if err := j.Context.Unmarshal(jobContext); err != nil { return nil, fmt.Errorf("failed to unmarshal job context: %w", err) @@ -208,7 +208,7 @@ func buildExecutionContextAndURI(ctx context.Context, r *plugin.Runtime, j *job. execCtx.jobContext = jobContext // Parse cluster context - clusterContext := &clusterContext{} + clusterContext := &sparkEksClusterContext{} if c.Context != nil { if err := c.Context.Unmarshal(clusterContext); err != nil { return nil, fmt.Errorf("failed to unmarshal cluster context: %w", err) @@ -218,7 +218,7 @@ func buildExecutionContextAndURI(ctx context.Context, r *plugin.Runtime, j *job. // Initialize and merge properties from command -> job if execCtx.jobContext.Parameters == nil { - execCtx.jobContext.Parameters = &jobParameters{ + execCtx.jobContext.Parameters = &sparkEksJobParameters{ Properties: make(map[string]string), } } @@ -248,12 +248,12 @@ func buildExecutionContextAndURI(ctx context.Context, r *plugin.Runtime, j *job. 
execCtx.logURI = fmt.Sprintf("%s/%s/%s", s.JobsURI, j.ID, logsPath) // Upload query to S3 - if err := uploadFileToS3(ctx, execCtx.awsConfig, execCtx.queryURI, execCtx.jobContext.Query); err != nil { + if err := uploadFileToS3(execCtx.awsConfig, execCtx.queryURI, execCtx.jobContext.Query); err != nil { return nil, fmt.Errorf("failed to upload query to S3: %w", err) } // create empty log s3 directory to avoid spark event log dir errors - if err := uploadFileToS3(ctx, execCtx.awsConfig, fmt.Sprintf("%s/.keepdir", execCtx.logURI), ""); err != nil { + if err := uploadFileToS3(execCtx.awsConfig, fmt.Sprintf("%s/.keepdir", execCtx.logURI), ""); err != nil { return nil, fmt.Errorf("failed to create log directory in S3: %w", err) } @@ -261,9 +261,9 @@ func buildExecutionContextAndURI(ctx context.Context, r *plugin.Runtime, j *job. } // submitSparkApp creates clients, generates the spec, and submits it to Kubernetes. -func (e *executionContext) submitSparkApp(ctx context.Context) error { +func (e *executionContext) submitSparkApp() error { // Create Kubernetes and Spark Operator clients - if err := createSparkClients(ctx, e); err != nil { + if err := createSparkClients(e); err != nil { return fmt.Errorf("failed to create Spark Operator client: %w", err) } @@ -297,7 +297,7 @@ func (e *executionContext) submitSparkApp(ctx context.Context) error { } // cleanupSparkApp removes the SparkApplication from the cluster if it still exists. -func (e *executionContext) cleanupSparkApp(ctx context.Context) error { +func (e *executionContext) cleanupSparkApp() error { if e.submittedApp == nil { return nil } @@ -316,12 +316,12 @@ func (e *executionContext) cleanupSparkApp(ctx context.Context) error { } // getAndStoreResults fetches the job output from S3 and stores it. -func (e *executionContext) getAndStoreResults(ctx context.Context) error { +func (e *executionContext) getAndStoreResults() error { if !e.jobContext.ReturnResult { return nil } - returnResultFileURI, err := getS3FileURI(ctx, e.awsConfig, e.resultURI, avroFileExtension) + returnResultFileURI, err := getS3FileURI(e.awsConfig, e.resultURI, avroFileExtension) if err != nil { e.runtime.Stdout.WriteString(fmt.Sprintf("failed to find .avro file in results directory %s: %s", e.resultURI, err)) return fmt.Errorf("failed to find .avro file in results directory %s: %w", e.resultURI, err) @@ -335,7 +335,7 @@ func (e *executionContext) getAndStoreResults(ctx context.Context) error { } // uploadFileToS3 uploads content to S3. -func uploadFileToS3(ctx context.Context, awsConfig aws.Config, fileURI, content string) error { +func uploadFileToS3(awsConfig aws.Config, fileURI, content string) error { s3Parts := rxS3.FindAllStringSubmatch(fileURI, -1) if len(s3Parts) == 0 || len(s3Parts[0]) < 3 { return fmt.Errorf("unexpected S3 URI format: %s", fileURI) @@ -358,7 +358,7 @@ func updateS3ToS3aURI(uri string) string { } // getS3FileURI finds a file in an S3 directory that matches the given extension. -func getS3FileURI(ctx context.Context, awsConfig aws.Config, directoryURI, matchingExtension string) (string, error) { +func getS3FileURI(awsConfig aws.Config, directoryURI, matchingExtension string) (string, error) { s3Parts := rxS3.FindAllStringSubmatch(directoryURI, -1) if len(s3Parts) == 0 || len(s3Parts[0]) < 3 { return "", fmt.Errorf("invalid S3 URI format: %s", directoryURI) @@ -390,7 +390,7 @@ func printState(writer io.Writer, state v1beta2.ApplicationStateType) { } // getSparkSubmitParameters returns Spark submit parameters as a string. 
-func getSparkSubmitParameters(context *jobContext) *string { +func getSparkSubmitParameters(context *sparkEksJobContext) *string { if context.Parameters == nil || len(context.Parameters.Properties) == 0 { return nil } @@ -403,7 +403,7 @@ func getSparkSubmitParameters(context *jobContext) *string { } // getSparkApplicationPods returns the list of pods associated with a Spark application. -func getSparkApplicationPods(ctx context.Context, kubeClient *kubernetes.Clientset, appName, namespace string) ([]corev1.Pod, error) { +func getSparkApplicationPods(kubeClient *kubernetes.Clientset, appName, namespace string) ([]corev1.Pod, error) { labelSelector := fmt.Sprintf("%s=%s", sparkAppLabelSelectorFormat, appName) podList, err := kubeClient.CoreV1().Pods(namespace).List(ctx, metav1.ListOptions{LabelSelector: labelSelector}) if err != nil { @@ -437,7 +437,7 @@ func writeDriverLogsToStderr(execCtx *executionContext, pod corev1.Pod, logConte } // getAndUploadPodContainerLogs fetches logs from a specific container in a pod and uploads them to S3. -func getAndUploadPodContainerLogs(ctx context.Context, execCtx *executionContext, pod corev1.Pod, container corev1.Container, previous bool, logType string, writeToStderr bool) { +func getAndUploadPodContainerLogs(execCtx *executionContext, pod corev1.Pod, container corev1.Container, previous bool, logType string, writeToStderr bool) { logOptions := &corev1.PodLogOptions{Container: container.Name, Previous: previous} req := execCtx.kubeClient.CoreV1().Pods(pod.Namespace).GetLogs(pod.Name, logOptions) logs, err := req.Stream(ctx) @@ -457,31 +457,31 @@ func getAndUploadPodContainerLogs(ctx context.Context, execCtx *executionContext } logURI := fmt.Sprintf("%s/%s-%s", execCtx.logURI, pod.Name, logType) - if err := uploadFileToS3(ctx, execCtx.awsConfig, logURI, string(logContent)); err != nil { + if err := uploadFileToS3(execCtx.awsConfig, logURI, string(logContent)); err != nil { execCtx.runtime.Stderr.WriteString(fmt.Sprintf("Pod %s, container %s: %s upload error: %v\n", pod.Name, container.Name, logType, err)) } } } // getSparkApplicationPodLogs fetches logs from pods and uploads them to S3. -func getSparkApplicationPodLogs(ctx context.Context, execCtx *executionContext, pods []corev1.Pod, writeToStderr bool) error { +func getSparkApplicationPodLogs(execCtx *executionContext, pods []corev1.Pod, writeToStderr bool) error { for _, pod := range pods { if !isPodInValidPhase(pod) { continue } for _, container := range pod.Spec.Containers { // Get current logs and upload - getAndUploadPodContainerLogs(ctx, execCtx, pod, container, false, stdoutLogSuffix, writeToStderr) + getAndUploadPodContainerLogs(execCtx, pod, container, false, stdoutLogSuffix, writeToStderr) // Get logs from previous (failed) runs and upload - getAndUploadPodContainerLogs(ctx, execCtx, pod, container, true, stderrLogSuffix, false) + getAndUploadPodContainerLogs(execCtx, pod, container, true, stderrLogSuffix, false) } } return nil } // createSparkClients creates Kubernetes and Spark clients for the EKS cluster. 
-func createSparkClients(ctx context.Context, execCtx *executionContext) error { - kubeconfigPath, err := updateKubeConfig(ctx, execCtx) +func createSparkClients(execCtx *executionContext) error { + kubeconfigPath, err := updateKubeConfig(execCtx) if err != nil { return fmt.Errorf("failed to update kubeconfig: %w", err) } @@ -513,7 +513,7 @@ func createSparkClients(ctx context.Context, execCtx *executionContext) error { return nil } -func updateKubeConfig(ctx context.Context, execCtx *executionContext) (string, error) { +func updateKubeConfig(execCtx *executionContext) (string, error) { region := os.Getenv(awsRegionEnvVar) if execCtx.clusterContext.Region != nil { region = *execCtx.clusterContext.Region @@ -762,7 +762,7 @@ func loadTemplate(execCtx *executionContext) (*v1beta2.SparkApplication, error) } // monitorJobAndCollectLogs monitors the Spark job until completion and collects logs. -func (e *executionContext) monitorJobAndCollectLogs(ctx context.Context) error { +func (e *executionContext) monitorJobAndCollectLogs() error { appName, namespace := e.submittedApp.Name, e.submittedApp.Namespace e.runtime.Stdout.WriteString(fmt.Sprintf("Monitoring Spark application: %s\n", appName)) @@ -774,7 +774,7 @@ func (e *executionContext) monitorJobAndCollectLogs(ctx context.Context) error { for { if monitorCtx.Err() != nil { if finalSparkApp != nil { - collectSparkApplicationLogs(ctx, e, finalSparkApp, true) + collectSparkApplicationLogs(e, finalSparkApp, true) } if monitorCtx.Err() == context.DeadlineExceeded { return fmt.Errorf("spark job timed out after %v", jobTimeout) @@ -788,7 +788,7 @@ func (e *executionContext) monitorJobAndCollectLogs(ctx context.Context) error { if err != nil { e.runtime.Stderr.WriteString(fmt.Sprintf("Spark application %s/%s not found or deleted externally: %v\n", namespace, appName, err)) if finalSparkApp != nil { - collectSparkApplicationLogs(ctx, e, finalSparkApp, true) + collectSparkApplicationLogs(e, finalSparkApp, true) } return fmt.Errorf("spark application %s/%s not found: %w", namespace, appName, err) } @@ -802,17 +802,17 @@ func (e *executionContext) monitorJobAndCollectLogs(ctx context.Context) error { } if state == v1beta2.ApplicationStateRunning { - collectSparkApplicationLogs(ctx, e, sparkApp, false) + collectSparkApplicationLogs(e, sparkApp, false) continue } switch state { case v1beta2.ApplicationStateCompleted: - collectSparkApplicationLogs(ctx, e, sparkApp, false) + collectSparkApplicationLogs(e, sparkApp, false) e.runtime.Stdout.WriteString("Spark job completed successfully\n") return nil case v1beta2.ApplicationStateFailed: - collectSparkApplicationLogs(ctx, e, sparkApp, true) + collectSparkApplicationLogs(e, sparkApp, true) errorMessage := sparkApp.Status.AppState.ErrorMessage if errorMessage == "" { errorMessage = unknownErrorMsg @@ -820,7 +820,7 @@ func (e *executionContext) monitorJobAndCollectLogs(ctx context.Context) error { e.runtime.Stderr.WriteString(fmt.Sprintf("Spark job failed: %s\n", errorMessage)) return fmt.Errorf("spark job failed: %s", errorMessage) case v1beta2.ApplicationStateFailedSubmission, v1beta2.ApplicationStateUnknown: - collectSparkApplicationLogs(ctx, e, finalSparkApp, true) + collectSparkApplicationLogs(e, finalSparkApp, true) msg := sparkJobSubmissionFailedMsg if state == v1beta2.ApplicationStateUnknown { msg = sparkAppUnknownStateMsg @@ -831,17 +831,17 @@ func (e *executionContext) monitorJobAndCollectLogs(ctx context.Context) error { } // collectSparkApplicationLogs collects logs from Spark application pods. 
-func collectSparkApplicationLogs(ctx context.Context, execCtx *executionContext, sparkApp *v1beta2.SparkApplication, writeToStderr bool) { +func collectSparkApplicationLogs(execCtx *executionContext, sparkApp *v1beta2.SparkApplication, writeToStderr bool) { if sparkApp == nil { return } - pods, err := getSparkApplicationPods(ctx, execCtx.kubeClient, sparkApp.Name, sparkApp.Namespace) + pods, err := getSparkApplicationPods(execCtx.kubeClient, sparkApp.Name, sparkApp.Namespace) if err != nil { execCtx.runtime.Stderr.WriteString(fmt.Sprintf("Warning: failed to get Spark application pods: %v\n", err)) return } - if err := getSparkApplicationPodLogs(ctx, execCtx, pods, writeToStderr); err != nil { + if err := getSparkApplicationPodLogs(execCtx, pods, writeToStderr); err != nil { execCtx.runtime.Stderr.WriteString(fmt.Sprintf("Warning: failed to collect pod logs: %v\n", err)) } } diff --git a/internal/pkg/object/command/sparkeks/sparkeks_test.go b/internal/pkg/object/command/sparkeks/sparkeks_test.go index dc823c4..ad56b69 100644 --- a/internal/pkg/object/command/sparkeks/sparkeks_test.go +++ b/internal/pkg/object/command/sparkeks/sparkeks_test.go @@ -1,7 +1,6 @@ package sparkeks import ( - "context" "os" "strings" "testing" @@ -25,8 +24,8 @@ func TestUpdateS3ToS3aURI(t *testing.T) { } func TestGetSparkSubmitParameters(t *testing.T) { - ctx := &jobContext{ - Parameters: &jobParameters{ + ctx := &sparkEksJobContext{ + Parameters: &sparkEksJobParameters{ Properties: map[string]string{ "spark.executor.memory": "4g", "spark.driver.cores": "2", @@ -40,8 +39,8 @@ func TestGetSparkSubmitParameters(t *testing.T) { } func TestGetSparkSubmitParameters_Empty(t *testing.T) { - ctx := &jobContext{ - Parameters: &jobParameters{ + ctx := &sparkEksJobContext{ + Parameters: &sparkEksJobParameters{ Properties: map[string]string{}, }, } @@ -52,7 +51,7 @@ func TestGetSparkSubmitParameters_Empty(t *testing.T) { } func TestGetSparkSubmitParameters_NilParameters(t *testing.T) { - ctx := &jobContext{ + ctx := &sparkEksJobContext{ Parameters: nil, } params := getSparkSubmitParameters(ctx) @@ -62,7 +61,7 @@ func TestGetSparkSubmitParameters_NilParameters(t *testing.T) { } func TestGetS3FileURI_InvalidFormat(t *testing.T) { - _, err := getS3FileURI(context.Background(), aws.Config{}, "invalid-uri", "avro") + _, err := getS3FileURI(aws.Config{}, "invalid-uri", "avro") if err == nil { t.Error("Expected error for invalid S3 URI format") } @@ -71,7 +70,7 @@ func TestGetS3FileURI_InvalidFormat(t *testing.T) { func TestGetS3FileURI_ValidFormat(t *testing.T) { // This test only checks parsing, not actual AWS interaction uri := "s3://bucket/path/" - _, err := getS3FileURI(context.Background(), aws.Config{}, uri, "avro") + _, err := getS3FileURI(aws.Config{}, uri, "avro") // Should not error on parsing, but will error on AWS call (which is fine for unit test context) if err == nil || !strings.Contains(err.Error(), "failed to list S3 objects") { t.Logf("Expected AWS list objects error, got: %v", err) diff --git a/internal/pkg/object/command/trino/client.go b/internal/pkg/object/command/trino/client.go index ce39bd5..793ddf9 100644 --- a/internal/pkg/object/command/trino/client.go +++ b/internal/pkg/object/command/trino/client.go @@ -2,7 +2,6 @@ package trino import ( "bytes" - "context" "encoding/json" "fmt" "io" @@ -46,7 +45,7 @@ type response struct { Error map[string]any `json:"error"` } -func newRequest(ctx context.Context, r *plugin.Runtime, j *job.Job, c *cluster.Cluster, jobCtx *jobContext) (*request, error) { +func 
newRequest(r *plugin.Runtime, j *job.Job, c *cluster.Cluster, jobCtx *jobContext) (*request, error) { // get cluster context clusterCtx := &clusterContext{} @@ -73,7 +72,7 @@ func newRequest(ctx context.Context, r *plugin.Runtime, j *job.Job, c *cluster.C } // submit query - if err := req.submit(ctx, jobCtx.Query); err != nil { + if err := req.submit(jobCtx.Query); err != nil { return nil, err } @@ -81,9 +80,9 @@ func newRequest(ctx context.Context, r *plugin.Runtime, j *job.Job, c *cluster.C } -func (r *request) submit(ctx context.Context, query string) error { +func (r *request) submit(query string) error { - req, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("%s/v1/statement", r.endpoint), bytes.NewBuffer([]byte(query))) + req, err := http.NewRequest(http.MethodPost, fmt.Sprintf("%s/v1/statement", r.endpoint), bytes.NewBuffer([]byte(query))) if err != nil { return err } @@ -92,9 +91,9 @@ func (r *request) submit(ctx context.Context, query string) error { } -func (r *request) poll(ctx context.Context) error { +func (r *request) poll() error { - req, err := http.NewRequestWithContext(ctx, http.MethodGet, r.nextUri, nil) + req, err := http.NewRequest(http.MethodGet, r.nextUri, nil) if err != nil { return err } @@ -175,4 +174,4 @@ func (r *request) api(req *http.Request) error { func normalizeTrinoQuery(query string) string { // Trino does not support semicolon at the end of the query, so we remove it if present return strings.TrimSuffix(query, ";") -} +} \ No newline at end of file diff --git a/internal/pkg/object/command/trino/trino.go b/internal/pkg/object/command/trino/trino.go index 8d246da..ce1891c 100644 --- a/internal/pkg/object/command/trino/trino.go +++ b/internal/pkg/object/command/trino/trino.go @@ -1,13 +1,12 @@ package trino import ( - "context" "fmt" "log" "time" "github.com/hladush/go-telemetry/pkg/telemetry" - heimdallContext "github.com/patterninc/heimdall/pkg/context" + "github.com/patterninc/heimdall/pkg/context" "github.com/patterninc/heimdall/pkg/object/cluster" "github.com/patterninc/heimdall/pkg/object/job" "github.com/patterninc/heimdall/pkg/plugin" @@ -37,14 +36,14 @@ type jobContext struct { } // New creates a new trino plugin handler -func New(commandCtx *heimdallContext.Context) (plugin.Handler, error) { +func New(ctx *context.Context) (plugin.Handler, error) { t := &commandContext{ PollInterval: defaultPollInterval, } - if commandCtx != nil { - if err := commandCtx.Unmarshal(t); err != nil { + if ctx != nil { + if err := ctx.Unmarshal(t); err != nil { return nil, err } } @@ -53,7 +52,7 @@ func New(commandCtx *heimdallContext.Context) (plugin.Handler, error) { } -func (t *commandContext) handler(ctx context.Context, r *plugin.Runtime, j *job.Job, c *cluster.Cluster) error { +func (t *commandContext) handler(r *plugin.Runtime, j *job.Job, c *cluster.Cluster) error { // get job context jobCtx := &jobContext{} @@ -69,7 +68,7 @@ func (t *commandContext) handler(ctx context.Context, r *plugin.Runtime, j *job. // this code will be enabled in prod after some testing } // let's submit our query to trino - req, err := newRequest(ctx, r, j, c, jobCtx) + req, err := newRequest(r, j, c, jobCtx) if err != nil { return err } @@ -77,7 +76,7 @@ func (t *commandContext) handler(ctx context.Context, r *plugin.Runtime, j *job. // now let's keep pooling until we get the full result... 
for req.nextUri != `` { time.Sleep(time.Duration(t.PollInterval) * time.Millisecond) - if err := req.poll(ctx); err != nil { + if err := req.poll(); err != nil { return err } }
@@ -113,7 +112,7 @@ func canQueryBeExecuted(query, user, id string, c *cluster.Cluster) bool { return false } } - + canBeExecutedMethod.CountSuccess() return true }
diff --git a/internal/pkg/pool/pool.go b/internal/pkg/pool/pool.go
index dec05db..368a12d 100644
--- a/internal/pkg/pool/pool.go
+++ b/internal/pkg/pool/pool.go
@@ -1,7 +1,6 @@ package pool import ( - "context" "fmt" "time" )
@@ -17,7 +16,7 @@ type Pool[T any] struct { queue chan T } -func (p *Pool[T]) Start(worker func(context.Context, T) error, getWork func(int) ([]T, error)) error { +func (p *Pool[T]) Start(worker func(T) error, getWork func(int) ([]T, error)) error { // do we have the size set? if p.Size <= 0 {
@@ -54,9 +53,7 @@ func (p *Pool[T]) Start(worker func(context.Context, T) error, getWork func(int) } // do the work.... - err := worker(context.Background(), w) - - if err != nil { + if err := worker(w); err != nil { // TODO: implement proper error logging fmt.Println(`worker:`, err) }
@@ -109,4 +106,5 @@ func (p *Pool[T]) Start(worker func(context.Context, T) error, getWork func(int) }(tokens) return nil + }
diff --git a/pkg/object/job/job.go b/pkg/object/job/job.go
index 22b9c83..b0d6855 100644
--- a/pkg/object/job/job.go
+++ b/pkg/object/job/job.go
@@ -19,9 +19,8 @@ type Job struct { ClusterCriteria *set.Set[string] `yaml:"cluster_criteria,omitempty" json:"cluster_criteria,omitempty"` CommandID string `yaml:"command_id,omitempty" json:"command_id,omitempty"` CommandName string `yaml:"command_name,omitempty" json:"command_name,omitempty"` - ClusterID string `yaml:"cluster_id,omitempty" json:"cluster_id,omitempty"` + ClusterID string `yaml:"cluster_id,omitempty" json:"cluster_id,omitempty"` ClusterName string `yaml:"cluster_name,omitempty" json:"cluster_name,omitempty"` - CancelledBy string `yaml:"cancelled_by,omitempty" json:"cancelled_by,omitempty"` Result *result.Result `yaml:"result,omitempty" json:"result,omitempty"` }
diff --git a/pkg/object/job/status/status.go b/pkg/object/job/status/status.go
index 375c540..efe449a 100644
--- a/pkg/object/job/status/status.go
+++ b/pkg/object/job/status/status.go
@@ -11,14 +11,12 @@ import ( type Status status.Status const ( - New Status = 1 - Accepted Status = 2 - Running Status = 3 - Failed Status = 4 - Killed Status = 5 - Succeeded Status = 6 - Cancelling Status = 7 - Cancelled Status = 8 + New Status = 1 + Accepted Status = 2 + Running Status = 3 + Failed Status = 4 + Killed Status = 5 + Succeeded Status = 6 ) const (
@@ -27,14 +25,12 @@ const ( var ( statusMapping = map[string]status.Status{ - `new`: status.Status(New), - `accepted`: status.Status(Accepted), - `running`: status.Status(Running), - `failed`: status.Status(Failed), - `killed`: status.Status(Killed), - `succeeded`: status.Status(Succeeded), - `cancelling`: status.Status(Cancelling), - `cancelled`: status.Status(Cancelled), + `new`: status.Status(New), + `accepted`: status.Status(Accepted), + `running`: status.Status(Running), + `failed`: status.Status(Failed), + `killed`: status.Status(Killed), + `succeeded`: status.Status(Succeeded), } )
diff --git a/pkg/plugin/plugin.go b/pkg/plugin/plugin.go
index cfb6db6..3512cbf 100644
--- a/pkg/plugin/plugin.go
+++ b/pkg/plugin/plugin.go
@@ -1,10 +1,8 @@ package plugin import ( - "context" - "github.com/patterninc/heimdall/pkg/object/cluster" "github.com/patterninc/heimdall/pkg/object/job" )
-type Handler func(context.Context, *Runtime, *job.Job, *cluster.Cluster) error +type Handler func(*Runtime, *job.Job, *cluster.Cluster) error diff --git a/pkg/plugin/runtime.go b/pkg/plugin/runtime.go index 179fb81..6c0d616 100644 --- a/pkg/plugin/runtime.go +++ b/pkg/plugin/runtime.go @@ -1,7 +1,6 @@ package plugin import ( - "context" "fmt" "io/fs" "os" @@ -98,9 +97,7 @@ func copyDir(src, dst string) error { // if we have local filesystem, crete directory as appropriate writeFileFunc := os.WriteFile if strings.HasPrefix(dst, s3Prefix) { - writeFileFunc = func(name string, data []byte, perm os.FileMode) error { - return aws.WriteToS3(context.Background(), name, data, perm) - } + writeFileFunc = aws.WriteToS3 } else { if _, err := os.Stat(dst); os.IsNotExist(err) { os.MkdirAll(dst, jobDirectoryPermissions) diff --git a/web/src/modules/Jobs/Helper.tsx b/web/src/modules/Jobs/Helper.tsx index c47f205..50ce30d 100644 --- a/web/src/modules/Jobs/Helper.tsx +++ b/web/src/modules/Jobs/Helper.tsx @@ -36,7 +36,6 @@ export type JobType = { command_name: string cluster_id: string cluster_name: string - cancelled_by: string error?: string context?: { properties: {