From 4d027a39e3ac18d313bda1f29b6d41d4dd7f2449 Mon Sep 17 00:00:00 2001
From: bpeng
Date: Sat, 10 Jan 2026 12:48:26 +1300
Subject: [PATCH 1/2] feat: GetAllConcurrentlyWithContext function

---
 aws/s3/s3_concurrent.go      | 110 +++++++++++++++++++++++++++++++++++
 aws/s3/s3_concurrent_test.go |  56 ++++++++++++++++++
 2 files changed, 166 insertions(+)

diff --git a/aws/s3/s3_concurrent.go b/aws/s3/s3_concurrent.go
index b3b4c0a..7be8b4b 100644
--- a/aws/s3/s3_concurrent.go
+++ b/aws/s3/s3_concurrent.go
@@ -188,6 +188,89 @@ func (s *S3Concurrent) GetAllConcurrently(bucket, version string, objects []type
     return s.manager.Process(processFunc, objects)
 }
 
+// GetAllConcurrently gets the objects with provided context, from specified bucket and writes the resulting HydratedFiles
+// to the returned output channel. The closure of this channel is handled, however it's the caller's
+// responsibility to purge the channel, and handle any errors present in the HydratedFiles.
+// If the ConcurrencyManager is not initialised before calling GetAllConcurrently, an output channel
+// containing a single HydratedFile with an error is returned.
+// Version can be empty, but must be the same for all objects.
+func (s *S3Concurrent) GetAllConcurrentlyWithContext(
+    ctx context.Context,
+    bucket, version string,
+    objects []types.Object,
+) chan HydratedFile {
+
+    output := make(chan HydratedFile, 1)
+
+    // Early cancel check
+    select {
+    case <-ctx.Done():
+        output <- HydratedFile{Error: ctx.Err()}
+        close(output)
+        return output
+    default:
+    }
+
+    if s.manager == nil {
+        output <- HydratedFile{
+            Error: errors.New("error getting files from S3, Concurrency Manager not initialised"),
+        }
+        close(output)
+        return output
+    }
+
+    if s.manager.memoryTotalSize < s.manager.calculateRequiredMemoryFor(objects) {
+        output <- HydratedFile{
+            Error: fmt.Errorf(
+                "error: bytes requested greater than max allowed by server (%v)",
+                s.manager.memoryTotalSize,
+            ),
+        }
+        close(output)
+        return output
+    }
+    // Secure memory for all objects upfront.
+    s.manager.secureMemory(objects) // 0.
+
+    // IMPORTANT: ensure memory is released if context cancels before processing finishes
+    go func() {
+        <-ctx.Done()
+        // Best-effort cleanup: release all secured memory
+        for _, o := range objects {
+            s.manager.releaseMemory(aws.ToInt64(o.Size))
+        }
+    }()
+
+    processFunc := func(input types.Object) HydratedFile {
+        // Respect cancellation before starting work
+        select {
+        case <-ctx.Done():
+            return HydratedFile{Error: ctx.Err()}
+        default:
+        }
+
+        buf := bytes.NewBuffer(make([]byte, 0, int(*input.Size)))
+        key := aws.ToString(input.Key)
+
+        // Prefer context-aware S3 call if available
+        _, err := s.GetWithContext(ctx, bucket, key, version, buf)
+
+        // If context was cancelled during S3 read, surface that
+        if ctx.Err() != nil {
+            return HydratedFile{Error: ctx.Err()}
+        }
+
+        return HydratedFile{
+            Key:   key,
+            Data:  buf.Bytes(),
+            Error: err,
+        }
+    }
+
+    // Process already accepts a context internally, so pass it through
+    return s.manager.ProcessWithContext(ctx, processFunc, objects)
+}
+
 // getWorker retrieves a number of workers from the manager's worker pool.
 func (cm *ConcurrencyManager) getWorkers(number int) []*worker {
     cm.workerPool.mutex.Lock()
@@ -259,6 +342,33 @@ func (cm *ConcurrencyManager) Process(asyncProcessor FileProcessor, objects []ty
     return workerGroup.returnOutput() // 2.
 }
 
+// Functions for providing a fan-out/fan-in operation with provided context. Workers are taken from the
+// worker pool and added to a WorkerGroup. All workers are returned to the pool once
+// the jobs have finished.
+func (cm *ConcurrencyManager) ProcessWithContext(
+    ctx context.Context,
+    asyncProcessor FileProcessor,
+    objects []types.Object,
+) chan HydratedFile {
+
+    workerGroup := cm.newWorkerGroup(ctx, asyncProcessor, cm.maxWorkersPerRequest)
+
+    go func() {
+        for _, obj := range objects {
+            select {
+            case <-ctx.Done():
+                workerGroup.stopWork()
+                return
+            default:
+                workerGroup.addWork(obj)
+            }
+        }
+        workerGroup.stopWork()
+    }()
+
+    return workerGroup.returnOutput()
+}
+
 // start begins a worker's process of making itself available for work, doing the work,
 // and repeat, until all work is done.
 func (w *worker) start(ctx context.Context, processor FileProcessor, roster chan *worker, wg *sync.WaitGroup) {
diff --git a/aws/s3/s3_concurrent_test.go b/aws/s3/s3_concurrent_test.go
index fa557f4..2298924 100644
--- a/aws/s3/s3_concurrent_test.go
+++ b/aws/s3/s3_concurrent_test.go
@@ -1,6 +1,7 @@
 package s3
 
 import (
+    "context"
     "fmt"
     "testing"
     "time"
@@ -121,3 +122,58 @@ func TestS3GetAllConcurrently(t *testing.T) {
         }
     }
 }
+
+// go test --run TestS3GetAllConcurrentlyWithContext_Cancel -v
+func TestS3GetAllConcurrentlyWithContext_Cancel(t *testing.T) {
+    // ARRANGE
+    setup()
+    defer teardown()
+
+    client, err := NewConcurrent(100, 10, 1000)
+    require.NoError(t, err)
+
+    total := 20
+    keys := make([]string, total)
+    for i := 0; i < total; i++ {
+        keys[i] = fmt.Sprintf("%s-%v", testObjectKey, i)
+    }
+    awsCmdPutKeys(keys)
+    objects, _ := client.ListAllObjects(testBucket, "")
+
+    t.Run("early-cancel-before-start", func(t *testing.T) {
+        ctx, cancel := context.WithCancel(context.Background())
+        cancel()
+
+        out := client.GetAllConcurrentlyWithContext(ctx, testBucket, "", objects)
+
+        var got []HydratedFile
+        for hf := range out {
+            got = append(got, hf)
+        }
+        require.Len(t, got, 1)
+        require.ErrorIs(t, got[0].Error, context.Canceled)
+
+        time.Sleep(200 * time.Millisecond)
+        assert.Equal(t, 100, len(client.manager.workerPool.channel))
+        assert.Equal(t, 100, len(client.manager.memoryPool.channel))
+    })
+    t.Run("cancel-during-processing", func(t *testing.T) {
+        ctx, cancel := context.WithCancel(context.Background())
+        defer cancel()
+        out := client.GetAllConcurrentlyWithContext(ctx, testBucket, "", objects)
+
+        collected := make([]HydratedFile, 0, len(objects))
+        cancelAfter := 3
+
+        for hf := range out {
+            collected = append(collected, hf)
+            if len(collected) == cancelAfter {
+                cancel()
+            }
+        }
+        // At least some work completed
+        require.GreaterOrEqual(t, len(collected), cancelAfter)
+        // But not all objects should be processed
+        require.Less(t, len(collected), len(objects))
+    })
+}

From 030fdc6860522dc4e9f366ac60fbc9e7e9fa65a8 Mon Sep 17 00:00:00 2001
From: bpeng
Date: Sat, 10 Jan 2026 12:16:00 +1300
Subject: [PATCH 2/2] feat: update GetAllConcurrentlyWithContext to fix test failure

---
 aws/s3/s3_concurrent.go       |  95 ++++++++++----------------
 aws/s3/s3_concurrent_test.go  | 122 +++++++++++++++++++++++++++++++++-
 aws/s3/s3_integration_test.go |  25 ++++---
 3 files changed, 169 insertions(+), 73 deletions(-)

diff --git a/aws/s3/s3_concurrent.go b/aws/s3/s3_concurrent.go
index 7be8b4b..38141c0 100644
--- a/aws/s3/s3_concurrent.go
+++ b/aws/s3/s3_concurrent.go
@@ -157,41 +157,13 @@ func newConcurrencyManager(maxWorkers, maxWorkersPerRequest, maxBytes int) *Conc
 // containing a single HydratedFile with an error is returned.
 // Version can be empty, but must be the same for all objects.
 func (s *S3Concurrent) GetAllConcurrently(bucket, version string, objects []types.Object) chan HydratedFile {
-
-    if s.manager == nil {
-        output := make(chan HydratedFile, 1)
-        output <- HydratedFile{Error: errors.New("error getting files from S3, Concurrency Manager not initialised")}
-        close(output)
-        return output
-    }
-
-    if s.manager.memoryTotalSize < s.manager.calculateRequiredMemoryFor(objects) {
-        output := make(chan HydratedFile, 1)
-        output <- HydratedFile{Error: fmt.Errorf("error: bytes requested greater than max allowed by server (%v)", s.manager.memoryTotalSize)}
-        close(output)
-        return output
-    }
-    // Secure memory for all objects upfront.
-    s.manager.secureMemory(objects) // 0.
-
-    processFunc := func(input types.Object) HydratedFile {
-        buf := bytes.NewBuffer(make([]byte, 0, int(*input.Size)))
-        key := aws.ToString(input.Key)
-        err := s.Get(bucket, key, version, buf)
-
-        return HydratedFile{
-            Key:   key,
-            Data:  buf.Bytes(),
-            Error: err,
-        }
-    }
-    return s.manager.Process(processFunc, objects)
+    return s.GetAllConcurrentlyWithContext(context.Background(), bucket, version, objects)
 }
 
-// GetAllConcurrently gets the objects with provided context, from specified bucket and writes the resulting HydratedFiles
+// GetAllConcurrentlyWithContext gets the objects with provided context, from specified bucket and writes the resulting HydratedFiles
 // to the returned output channel. The closure of this channel is handled, however it's the caller's
 // responsibility to purge the channel, and handle any errors present in the HydratedFiles.
-// If the ConcurrencyManager is not initialised before calling GetAllConcurrently, an output channel
+// If the ConcurrencyManager is not initialised before calling GetAllConcurrentlyWithContext, an output channel
 // containing a single HydratedFile with an error is returned.
 // Version can be empty, but must be the same for all objects.
 func (s *S3Concurrent) GetAllConcurrentlyWithContext(
@@ -229,13 +201,14 @@ func (s *S3Concurrent) GetAllConcurrentlyWithContext(
         close(output)
         return output
     }
+
     // Secure memory for all objects upfront.
     s.manager.secureMemory(objects) // 0.
 
-    // IMPORTANT: ensure memory is released if context cancels before processing finishes
+    // ensure memory is released if context cancels before processing finishes
     go func() {
         <-ctx.Done()
-        // Best-effort cleanup: release all secured memory
+        // release all secured memory
         for _, o := range objects {
             s.manager.releaseMemory(aws.ToInt64(o.Size))
         }
@@ -252,7 +225,6 @@ func (s *S3Concurrent) GetAllConcurrentlyWithContext(
         buf := bytes.NewBuffer(make([]byte, 0, int(*input.Size)))
         key := aws.ToString(input.Key)
 
-        // Prefer context-aware S3 call if available
         _, err := s.GetWithContext(ctx, bucket, key, version, buf)
 
         // If context was cancelled during S3 read, surface that
@@ -267,8 +239,8 @@ func (s *S3Concurrent) GetAllConcurrentlyWithContext(
         }
     }
 
-    // Process already accepts a context internally, so pass it through
-    return s.manager.ProcessWithContext(ctx, processFunc, objects)
+    // Process with a context
+    return s.manager.Process(ctx, processFunc, objects)
 }
 
 // getWorker retrieves a number of workers from the manager's worker pool.
@@ -327,25 +299,10 @@ func (cm *ConcurrencyManager) releaseMemory(size int64) {
     }
 }
 
-// Functions for providing a fan-out/fan-in operation. Workers are taken from the
-// worker pool and added to a WorkerGroup. All workers are returned to the pool once
-// the jobs have finished.
-func (cm *ConcurrencyManager) Process(asyncProcessor FileProcessor, objects []types.Object) chan HydratedFile {
-    workerGroup := cm.newWorkerGroup(context.Background(), asyncProcessor, cm.maxWorkersPerRequest) // 1.
-
-    go func() {
-        for _, obj := range objects {
-            workerGroup.addWork(obj)
-        }
-        workerGroup.stopWork() // 9.
-    }()
-    return workerGroup.returnOutput() // 2.
-}
-
 // Functions for providing a fan-out/fan-in operation with provided context. Workers are taken from the
 // worker pool and added to a WorkerGroup. All workers are returned to the pool once
 // the jobs have finished.
-func (cm *ConcurrencyManager) ProcessWithContext(
+func (cm *ConcurrencyManager) Process(
     ctx context.Context,
     asyncProcessor FileProcessor,
     objects []types.Object,
 ) chan HydratedFile {
 
     workerGroup := cm.newWorkerGroup(ctx, asyncProcessor, cm.maxWorkersPerRequest)
 
     go func() {
+        defer func() {
+            close(workerGroup.reception)
+            workerGroup.stopWork()
+        }()
+
         for _, obj := range objects {
             select {
             case <-ctx.Done():
-                workerGroup.stopWork()
                 return
             default:
-                workerGroup.addWork(obj)
+                if !workerGroup.addWork(ctx, obj) {
+                    return
+                }
             }
         }
-        workerGroup.stopWork()
     }()
 
     return workerGroup.returnOutput()
@@ -371,7 +333,12 @@
 
 // start begins a worker's process of making itself available for work, doing the work,
 // and repeat, until all work is done.
-func (w *worker) start(ctx context.Context, processor FileProcessor, roster chan *worker, wg *sync.WaitGroup) {
+func (w *worker) start(
+    ctx context.Context,
+    processor FileProcessor,
+    roster chan *worker,
+    wg *sync.WaitGroup,
+) {
     go func() {
         defer func() {
             wg.Done()
@@ -451,7 +418,7 @@ func (wg *workerGroup) startOutput() {
 func (wg *workerGroup) cleanUp(ctx context.Context) {
     <-ctx.Done()
     wg.group.Wait() // 9.
-    close(wg.reception)
+    //close(wg.reception)
     close(wg.roster)
 }
 
@@ -459,12 +426,18 @@
 // roster, and gives it an S3 Object to download. The worker's output
 // channel is registered to the workerGroup's reception so that
 // order is retained.
-func (wg *workerGroup) addWork(newWork types.Object) { // 4.
+func (wg *workerGroup) addWork(ctx context.Context, newWork types.Object) bool {
     for w := range wg.roster {
-        w.input <- newWork
-        wg.reception <- w.output
-        break
+        select {
+        case <-ctx.Done():
+            return false
+        default:
+            w.input <- newWork
+            wg.reception <- w.output
+            return true
+        }
     }
+    return false
 }
 
 // returnOutput returns the workerGroup's output channel.
diff --git a/aws/s3/s3_concurrent_test.go b/aws/s3/s3_concurrent_test.go
index 2298924..56c85b0 100644
--- a/aws/s3/s3_concurrent_test.go
+++ b/aws/s3/s3_concurrent_test.go
@@ -90,7 +90,7 @@ func TestS3GetAllConcurrently(t *testing.T) {
     }
 
     // ASSERT input and output order is the same.
-    require.Equal(t, len(outputKeys), total)
+    require.Equal(t, total, len(outputKeys))
     for i := 0; i < total; i++ {
         assert.Equal(t, aws.ToString(objects[i].Key), outputKeys[i])
     }
@@ -123,6 +123,117 @@ func TestS3GetAllConcurrently(t *testing.T) {
     }
 }
 
+// go test --run TestS3GetAllConcurrentlyWithContext -v
+func TestS3GetAllConcurrentlyWithContext(t *testing.T) {
+    // ARRANGE
+    setup()
+    defer teardown()
+
+    // ASSERT parameter errors.
+    _, err := NewConcurrent(0, 100, 1000)
+    assert.NotNil(t, err)
+    _, err = NewConcurrent(100, 0, 1000)
+    assert.NotNil(t, err)
+    _, err = NewConcurrent(100, 100, 0)
+    assert.NotNil(t, err)
+    _, err = NewConcurrent(100, 10, 99)
+    assert.NotNil(t, err)
+    _, err = NewConcurrent(100, 101, 1000)
+    assert.NotNil(t, err)
+
+    client, err := NewConcurrent(100, 10, 1000)
+    require.Nil(t, err, fmt.Sprintf("error creating s3 client concurrency manager: %v", err))
+
+    // ASSERT computed fields.
+    assert.Equal(t, 100, len(client.manager.workerPool.channel))
+    assert.Equal(t, 100, len(client.manager.memoryPool.channel))
+    assert.Equal(t, int64(10), client.manager.memoryChunkSize)
+    assert.Equal(t, int64(10*100), client.manager.memoryTotalSize)
+    assert.Equal(t, 10, client.manager.maxWorkersPerRequest)
+
+    // ASSERT memory chunk size is correct in memory pool.
+    chunk := <-client.manager.memoryPool.channel
+    assert.Equal(t, int64(10), chunk)
+    client.manager.memoryPool.channel <- chunk
+
+    // ASSERT worker get/release methods work expectedly.
+    w := client.manager.getWorkers(1)
+    assert.Equal(t, 99, len(client.manager.workerPool.channel))
+    client.manager.returnWorker(w[0])
+    assert.Equal(t, 100, len(client.manager.workerPool.channel))
+
+    // ASSERT memory get/release methods work expectedly.
+    elevenByteFile := types.Object{Size: aws.Int64(11)} // requires 2 memory chunks.
+    client.manager.secureMemory([]types.Object{elevenByteFile})
+    assert.Equal(t, 98, len(client.manager.memoryPool.channel))
+    client.manager.releaseMemory(20)
+    assert.Equal(t, 100, len(client.manager.memoryPool.channel))
+
+    // ARRANGE bucket with test objects.
+    total := 20
+    keys := make([]string, total)
+    for i := 0; i < total; i++ {
+        keys[i] = fmt.Sprintf("%s-%v", testObjectKey, i)
+    }
+    awsCmdPutKeys(keys)
+
+    // ACTION
+    objects, _ := client.ListAllObjects(testBucket, "")
+    tooManyBytes := make([]types.Object, 10*len(objects))
+    for _, o := range objects {
+        for i := 0; i < 10; i++ {
+            tooManyBytes = append(tooManyBytes, o)
+        }
+    }
+    output := client.GetAllConcurrentlyWithContext(context.Background(), testBucket, "", tooManyBytes)
+
+    // ASSERT error returned
+    for hf := range output {
+        assert.NotNil(t, hf.Error)
+    }
+
+    // ACTION
+    objects, _ = client.ListAllObjects(testBucket, "")
+    output = client.GetAllConcurrentlyWithContext(context.Background(), testBucket, "", objects)
+    outputKeys := make([]string, 0)
+    for hf := range output {
+        outputKeys = append(outputKeys, hf.Key)
+    }
+
+    // ASSERT input and output order is the same.
+    require.Equal(t, total, len(outputKeys))
+    for i := 0; i < total; i++ {
+        assert.Equal(t, aws.ToString(objects[i].Key), outputKeys[i])
+    }
+
+    // ASSERT all workers and memory returned to pools.
+    time.Sleep(2 * time.Second)
+    assert.Equal(t, 100, len(client.manager.workerPool.channel))
+    assert.Equal(t, 100, len(client.manager.memoryPool.channel))
+
+    // ASSERT that process blocked when all memory secured.
+    tenByteFile := types.Object{Size: aws.Int64(10)}
+    oneThousandBytesOfFiles := make([]types.Object, 100)
+    for i := 0; i < 100; i++ {
+        oneThousandBytesOfFiles[i] = tenByteFile
+    }
+    client.manager.secureMemory(oneThousandBytesOfFiles)
+    ch := make(chan chan HydratedFile)
+    go func() {
+        ch <- client.GetAllConcurrentlyWithContext(context.Background(), testBucket, "", objects)
+    }()
+
+    for {
+        select {
+        case <-ch:
+            t.Error("process was not blocked")
+        case <-time.After(time.Second):
+            // Timed out as expected
+            return
+        }
+    }
+}
+
 // go test --run TestS3GetAllConcurrentlyWithContext_Cancel -v
 func TestS3GetAllConcurrentlyWithContext_Cancel(t *testing.T) {
     // ARRANGE
@@ -175,5 +286,14 @@ func TestS3GetAllConcurrentlyWithContext_Cancel(t *testing.T) {
         require.GreaterOrEqual(t, len(collected), cancelAfter)
         // But not all objects should be processed
         require.Less(t, len(collected), len(objects))
+        // Pool recovery
+        require.Eventually(t, func() bool {
+            return len(client.manager.workerPool.channel) == 100
+        }, 5*time.Second, 10*time.Millisecond, fmt.Sprintf("workers pool not recovered, expected %d actual %d", 100, len(client.manager.workerPool.channel)))
+        require.Eventually(t, func() bool {
+            return len(client.manager.memoryPool.channel) == 100
+        }, 5*time.Second, 10*time.Millisecond, fmt.Sprintf("memory pool not recovered, expected %d actual %d", 100, len(client.manager.memoryPool.channel)))
+
     })
+
 }

diff --git a/aws/s3/s3_integration_test.go b/aws/s3/s3_integration_test.go
index 0ba6e64..3239695 100644
--- a/aws/s3/s3_integration_test.go
+++ b/aws/s3/s3_integration_test.go
@@ -50,17 +50,20 @@ func setup() {
     // setup environment variable to run AWS CLI/SDK
     setAwsEnv()
 
-    // create bucket
-    cmd := exec.Command( //nolint:gosec
-        "aws", "s3api",
-        "create-bucket",
-        "--bucket", testBucket,
-        "--create-bucket-configuration", fmt.Sprintf(
-            "{\"LocationConstraint\": \"%v\"}", testRegion),
-    )
-    if output, err := cmd.CombinedOutput(); err != nil {
-        fmt.Printf("Command failed: %v\nOutput: %s\n", err, string(output))
-        panic(err)
+    // check if bucket exists before creating
+    if !awsCmdBucketExists(testBucket) {
+        // create bucket
+        cmd := exec.Command( //nolint:gosec
+            "aws", "s3api",
+            "create-bucket",
+            "--bucket", testBucket,
+            "--create-bucket-configuration", fmt.Sprintf(
+                "{\"LocationConstraint\": \"%v\"}", testRegion),
+        )
+        if output, err := cmd.CombinedOutput(); err != nil {
+            fmt.Printf("Command failed: %v\nOutput: %s\n", err, string(output))
+            panic(err)
+        }
+    }
 }
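
Note for reviewers (not part of either patch): the doc comment on GetAllConcurrentlyWithContext says the library closes the returned channel and the caller must drain it and check each HydratedFile for an error. Below is a minimal caller sketch using only names exercised by the tests above (NewConcurrent, ListAllObjects, GetAllConcurrentlyWithContext, HydratedFile); the import path, bucket name, and timeout are placeholders, not values defined by this repository.

package main

import (
    "context"
    "fmt"
    "log"
    "time"

    s3 "example.com/your/module/aws/s3" // hypothetical import path
)

func main() {
    // NewConcurrent(maxWorkers, maxWorkersPerRequest, maxBytes), with the values used in the tests.
    client, err := s3.NewConcurrent(100, 10, 1000)
    if err != nil {
        log.Fatal(err)
    }

    objects, err := client.ListAllObjects("my-bucket", "")
    if err != nil {
        log.Fatal(err)
    }

    // Bound the whole fan-out with a timeout; on cancellation the remaining
    // HydratedFiles carry ctx.Err(), as in the cancel-during-processing test.
    ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
    defer cancel()

    // The library closes the channel; the caller drains it and checks every error.
    for hf := range client.GetAllConcurrentlyWithContext(ctx, "my-bucket", "", objects) {
        if hf.Error != nil {
            fmt.Printf("failed to fetch %s: %v\n", hf.Key, hf.Error)
            continue
        }
        fmt.Printf("fetched %s (%d bytes)\n", hf.Key, len(hf.Data))
    }
}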