diff --git a/aws/s3/s3.go b/aws/s3/s3.go index ca10fe0..f2005cb 100644 --- a/aws/s3/s3.go +++ b/aws/s3/s3.go @@ -128,22 +128,40 @@ func (s *S3) Client() *s3.Client { // Get gets the object referred to by key and version from bucket and writes it into b. // Version can be empty. func (s *S3) Get(bucket, key, version string, b *bytes.Buffer) error { + _, err := s.GetWithContext(context.Background(), bucket, key, version, b) + return err +} + +// GetWithContext gets the object referred to by key and version from bucket and +// streams it into w using the provided context; it returns the number of bytes written. +// Version can be empty. +func (s *S3) GetWithContext( + ctx context.Context, + bucket, key, version string, + w io.Writer, +) (int64, error) { + + input := s3.GetObjectInput{ - Key: aws.String(key), Bucket: aws.String(bucket), + Key: aws.String(key), } if version != "" { input.VersionId = aws.String(version) } - result, err := s.client.GetObject(context.TODO(), &input) + + result, err := s.client.GetObject(ctx, &input) if err != nil { - return err + return 0, err } defer result.Body.Close() - _, err = b.ReadFrom(result.Body) + n, err := io.Copy(w, result.Body) - return err + // Surface the context error only when the copy itself failed; a copy that completed stays a success. + if ctxErr := ctx.Err(); err != nil && ctxErr != nil { + return n, ctxErr + } + return n, err } // GetByteRange gets the specified byte range of an object referred to by key and version diff --git a/aws/s3/s3_integration_test.go b/aws/s3/s3_integration_test.go index b285404..0ba6e64 100644 --- a/aws/s3/s3_integration_test.go +++ b/aws/s3/s3_integration_test.go @@ -43,6 +43,7 @@ func setAwsEnv() { os.Setenv("AWS_SECRET_ACCESS_KEY", "test") os.Setenv("AWS_ACCESS_KEY_ID", "test") os.Setenv("AWS_ENDPOINT_URL", customAWSEndpoint) + os.Setenv("AWS_S3_DISABLE_CHECKSUM", "true") } func setup() { @@ -50,13 +51,15 @@ func setup() { setAwsEnv() // create bucket - if err := exec.Command( //nolint:gosec + cmd := exec.Command( //nolint:gosec "aws", "s3api", "create-bucket", "--bucket", testBucket, "--create-bucket-configuration", 
fmt.Sprintf( "{\"LocationConstraint\": \"%v\"}", testRegion), - ).Run(); err != nil { + ) + if output, err := cmd.CombinedOutput(); err != nil { + fmt.Printf("Command failed: %v\nOutput: %s\n", err, string(output)) panic(err) } } @@ -146,11 +149,12 @@ func awsCmdPutKeys(keys []string) { testFile.Close() } // sync to bucket - if err := exec.Command( + cmd := exec.Command( "aws", "s3", "sync", tmpDir, fmt.Sprintf("s3://%v", testBucket), - ).Run(); err != nil { - + ) + if output, err := cmd.CombinedOutput(); err != nil { + fmt.Printf("Command failed: %v\nOutput: %s\n", err, string(output)) panic(err) } } @@ -360,6 +364,141 @@ func TestS3Get(t *testing.T) { assert.Equal(t, testObjectData, dataObject.String()) } +func TestS3GetWithContext(t *testing.T) { + // ARRANGE + setup() + defer teardown() + + awsCmdPopulateBucket() + + client, err := New() + require.NoError(t, err, "error creating s3 client") + + t.Run("normal", func(t *testing.T) { + var buf bytes.Buffer + ctx := context.Background() + + // ACTION + written, err := client.GetWithContext( + ctx, + testBucket, + testObjectKey, + "", + &buf, + ) + + // ASSERT + require.NoError(t, err) + assert.Equal(t, int64(len(testObjectData)), written) + assert.Equal(t, testObjectData, buf.String()) + }) + + t.Run("cancelled", func(t *testing.T) { + var buf bytes.Buffer + ctx, cancel := context.WithCancel(context.Background()) + cancel() // cancel immediately + + // ACTION + written, err := client.GetWithContext( + ctx, + testBucket, + testObjectKey, + "", + &buf, + ) + + // ASSERT + require.Error(t, err) + assert.ErrorIs(t, err, context.Canceled) + assert.Equal(t, int64(0), written) + }) + + t.Run("cancel-during-processing", func(t *testing.T) { + // We’ll cancel after a portion of the object has been written to the buffer. + ctx, cancel := context.WithCancel(context.Background()) + var buf bytes.Buffer + + // Choose a threshold smaller than the total size so we cancel mid-stream. 
+ sw := &cancelAfterNWriter{ + dst: &buf, + cancel: cancel, + limit: 4, // cancel after 4 bytes are written + sleep: 0 * time.Millisecond, // optional; set to >0 to slow per-write + } + + // ACTION + written, err := client.GetWithContext( + ctx, + testBucket, + testObjectKey, + "", + sw, + ) + t.Log("written bytes:", written) + // ASSERT: it should end early with a context error and partial bytes written + require.Error(t, err, "expected error due to mid-run cancellation") + assert.ErrorIs(t, err, context.Canceled) + assert.GreaterOrEqual(t, written, int64(1), "should write some bytes before cancel") + assert.Equal(t, written, int64(buf.Len()), "buffer length should match reported written") + assert.Less(t, written, int64(len(testObjectData)), "should not complete full object") + }) +} + +// cancelAfterNWriter writes at most limit bytes to dst. +// Once limit is reached, it cancels ctx and returns context.Canceled. +// If a single Write would exceed the limit, it performs a **partial write** +// and then returns context.Canceled so the copy loop stops immediately. +type cancelAfterNWriter struct { + dst io.Writer + cancel context.CancelFunc + limit int64 // total bytes allowed before we cancel & error + sleep time.Duration + wrote int64 +} + +func (w *cancelAfterNWriter) Write(p []byte) (int, error) { + if w.sleep > 0 { + time.Sleep(w.sleep) + } + + remaining := w.limit - w.wrote + if remaining <= 0 { + // Already reached the limit: cancel & error without writing. + if w.cancel != nil { + w.cancel() + w.cancel = nil + } + return 0, context.Canceled + } + + // If the incoming chunk exceeds the remaining budget, do a **partial write**. 
+ if int64(len(p)) > remaining { + // write only `remaining` bytes + n, err := w.dst.Write(p[:remaining]) + w.wrote += int64(n) + if err != nil { + return n, err + } + // Now cancel & return error to abort the transfer + if w.cancel != nil { + w.cancel() + w.cancel = nil + } + // dst accepted the bytes without error; report context.Canceled with the partial count so the copy loop stops + return n, context.Canceled + } + + // Normal path: whole chunk fits. + n, err := w.dst.Write(p) + w.wrote += int64(n) + // If this write reached the limit, cancel now; the next call returns context.Canceled. + if w.wrote >= w.limit && w.cancel != nil { + w.cancel() + w.cancel = nil + } + return n, err +} + func TestS3GetByteRange(t *testing.T) { // ARRANGE setup()