From 74edf5c8527dd213d5aef2d3c6592f7a1cf1482a Mon Sep 17 00:00:00 2001 From: Callum Morris Date: Mon, 17 Mar 2025 15:43:13 +1300 Subject: [PATCH 1/8] fix: remove reference to gamit in batch id This is a shared library --- aws/sqs/sqs.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aws/sqs/sqs.go b/aws/sqs/sqs.go index 9b36026..65e4d45 100644 --- a/aws/sqs/sqs.go +++ b/aws/sqs/sqs.go @@ -240,7 +240,7 @@ func (s *SQS) SendBatch(ctx context.Context, queueURL string, bodies []string) e entries := make([]types.SendMessageBatchRequestEntry, len(bodies)) for j, body := range bodies { entries[j] = types.SendMessageBatchRequestEntry{ - Id: aws.String(fmt.Sprintf("gamitjob%d", j)), + Id: aws.String(fmt.Sprintf("message-%d", j)), MessageBody: aws.String(body), } } From 24da0d5f1c785bb199d92b2882af17fac5e4cc47 Mon Sep 17 00:00:00 2001 From: Callum Morris Date: Fri, 11 Apr 2025 12:55:12 +1200 Subject: [PATCH 2/8] feat: add inflight message count check Added to test that test delete function --- aws/sqs/sqs_integration_test.go | 23 ++++++++++++++++++++++- 1 file changed, 22 insertions(+), 1 deletion(-) diff --git a/aws/sqs/sqs_integration_test.go b/aws/sqs/sqs_integration_test.go index 9f2a0f7..526dde0 100644 --- a/aws/sqs/sqs_integration_test.go +++ b/aws/sqs/sqs_integration_test.go @@ -128,6 +128,25 @@ func awsCmdQueueCount() int { } } +func awsCmdQueueInFlightCount() int { + out, err := exec.Command( + "aws", "sqs", + "get-queue-attributes", + "--queue-url", awsCmdQueueURL(), + "--attribute-name", "ApproximateNumberOfMessagesNotVisible", + "--region", awsRegion).CombinedOutput() + + if err != nil { + panic(err) + } + + var payload map[string]map[string]string + json.Unmarshal(out, &payload) + + rvalue, _ := strconv.Atoi(payload["Attributes"]["ApproximateNumberOfMessagesNotVisible"]) + return rvalue +} + func awsCmdGetQueueArn(url string) string { arn, err := exec.Command( "aws", "sqs", @@ -323,8 +342,9 @@ func TestSQSDelete(t *testing.T) { client, err := New() require.Nil(t, err, fmt.Sprintf("Error creating sqs client: %v", err)) - receivedMessage, err := client.Receive(awsCmdQueueURL(), 1) + receivedMessage, err := client.Receive(awsCmdQueueURL(), 30) require.Nil(t, err, fmt.Sprintf("Error receiving test message: %v", err)) + require.Equal(t, 1, awsCmdQueueInFlightCount()) // ACTION err = client.Delete(awsCmdQueueURL(), receivedMessage.ReceiptHandle) @@ -332,6 +352,7 @@ func TestSQSDelete(t *testing.T) { // ASSERT assert.Nil(t, err) assert.Equal(t, 0, awsCmdQueueCount()) + assert.Equal(t, 0, awsCmdQueueInFlightCount()) } func TestSQSSend(t *testing.T) { From 0b19e42767d54a78c8b27fb9f03c0b965ebad7f3 Mon Sep 17 00:00:00 2001 From: Callum Morris Date: Fri, 11 Apr 2025 12:56:09 +1200 Subject: [PATCH 3/8] feat: add receiveMessages cmd to test suite --- aws/sqs/sqs_integration_test.go | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/aws/sqs/sqs_integration_test.go b/aws/sqs/sqs_integration_test.go index 526dde0..8b2d89b 100644 --- a/aws/sqs/sqs_integration_test.go +++ b/aws/sqs/sqs_integration_test.go @@ -111,6 +111,31 @@ func awsCmdReceiveMessage() string { } } +func awsCmdReceiveMessages() []string { + if out, err := exec.Command( //nolint:gosec + "aws", "sqs", + "receive-message", + "--queue-url", awsCmdQueueURL(), + "--attribute-names", "body", + "--region", awsRegion, + "--max-number-of-messages", "10", // AWS SQS allows up to 10 messages at a time + ).CombinedOutput(); err != nil { + + panic(err) + } else { + var payload map[string][]map[string]string + _ = json.Unmarshal(out, &payload) + + var bodies []string + for _, msg := range payload["Messages"] { + if body, ok := msg["Body"]; ok { + bodies = append(bodies, body) + } + } + return bodies + } +} + func awsCmdQueueCount() int { if out, err := exec.Command( "aws", "sqs", From c26183f0aedd20893a229de884e37b14a55aeabf Mon Sep 17 00:00:00 2001 From: Callum Morris Date: Fri, 11 Apr 2025 13:03:46 +1200 Subject: [PATCH 4/8] feat: update SendBatch to fix deficiencies in error handling Includes adding test for this method --- aws/sqs/sqs.go | 52 +++++++++++++++++++--- aws/sqs/sqs_integration_test.go | 78 +++++++++++++++++++++++++++++++++ 2 files changed, 124 insertions(+), 6 deletions(-) diff --git a/aws/sqs/sqs.go b/aws/sqs/sqs.go index 65e4d45..5e554f4 100644 --- a/aws/sqs/sqs.go +++ b/aws/sqs/sqs.go @@ -231,11 +231,27 @@ func (s *SQS) SendFifoMessage(queue, group, dedupe string, msg []byte) (string, return "", nil } -// Leverage the sendbatch api for uploading large numbers of messages +type SendBatchError struct { + Err error + Info []SendBatchErrorEntry +} +type SendBatchErrorEntry struct { + Entry types.BatchResultErrorEntry + Index int +} + +func (s *SendBatchError) Error() string { + return fmt.Sprintf("%v: %v messages failed to send", s.Err, len(s.Info)) +} +func (s *SendBatchError) Unwrap() error { + return s.Err +} + +// SendBatch sends up to 10 messages to a given SQS queue with one API call. +// If an error occurs on any or all messages, a SendBatchError is returned that lets +// the caller know the index of the message/s in bodies that failed. func (s *SQS) SendBatch(ctx context.Context, queueURL string, bodies []string) error { - if len(bodies) > 11 { - return errors.New("too many messages to batch") - } + var err error entries := make([]types.SendMessageBatchRequestEntry, len(bodies)) for j, body := range bodies { @@ -244,11 +260,35 @@ func (s *SQS) SendBatch(ctx context.Context, queueURL string, bodies []string) e MessageBody: aws.String(body), } } - _, err = s.client.SendMessageBatch(ctx, &sqs.SendMessageBatchInput{ + output, err := s.client.SendMessageBatch(ctx, &sqs.SendMessageBatchInput{ Entries: entries, QueueUrl: &queueURL, }) - return err + if err != nil { + info := make([]SendBatchErrorEntry, len(entries)) + for i := range entries { + info[i] = SendBatchErrorEntry{ + Index: i, + } + } + return &SendBatchError{Err: err, Info: info} + } + if len(output.Failed) > 0 { + info := make([]SendBatchErrorEntry, len(output.Failed)) + for i, entry := range output.Failed { + for j, msg := range entries { + if aws.ToString(msg.Id) == aws.ToString(entry.Id) { + info[i] = SendBatchErrorEntry{ + Entry: entry, + Index: j, + } + break + } + } + } + return &SendBatchError{Err: errors.New("partial message failure"), Info: info} + } + return nil } func (s *SQS) SendNBatch(ctx context.Context, queueURL string, bodies []string) error { diff --git a/aws/sqs/sqs_integration_test.go b/aws/sqs/sqs_integration_test.go index 8b2d89b..e5cfa06 100644 --- a/aws/sqs/sqs_integration_test.go +++ b/aws/sqs/sqs_integration_test.go @@ -4,7 +4,9 @@ package sqs import ( + "context" "encoding/json" + "errors" "fmt" "os" "os/exec" @@ -423,6 +425,82 @@ func TestSQSSendWithDelay(t *testing.T) { assert.True(t, timeElapsed < timeout) } +func TestSendBatch(t *testing.T) { + // ARRANGE + setup() + defer teardown() + + client, err := New() + require.Nil(t, err, fmt.Sprintf("error creating sqs client: %v", err)) + + // ACTION + var maxBytes int = 262144 + maxSizeSingleMessage := "" + for range maxBytes { + maxSizeSingleMessage += "a" + } + err = client.SendBatch(context.TODO(), awsCmdQueueURL(), []string{maxSizeSingleMessage}) + + // ASSERT + assert.Nil(t, err) + assert.Equal(t, maxSizeSingleMessage, awsCmdReceiveMessage()) + + // ACTION + tooLargeSingleMessage := maxSizeSingleMessage + "a" + err = client.SendBatch(context.TODO(), awsCmdQueueURL(), []string{tooLargeSingleMessage}) + + // ASSERT + assert.NotNil(t, err) + + // ACTION + var maxHalfBytes int = 131072 + maxHalfSizeMessage := "" + for range maxHalfBytes { + maxHalfSizeMessage += "a" + } + err = client.SendBatch(context.TODO(), awsCmdQueueURL(), []string{maxHalfSizeMessage, maxHalfSizeMessage}) + + // ASSERT + assert.Nil(t, err) + assert.Equal(t, maxHalfSizeMessage, awsCmdReceiveMessage()) + assert.Equal(t, maxHalfSizeMessage, awsCmdReceiveMessage()) + + // ACTION + tooLargeHalfSizeMessage := maxHalfSizeMessage + "a" + err = client.SendBatch(context.TODO(), awsCmdQueueURL(), []string{maxHalfSizeMessage, tooLargeHalfSizeMessage}) + + // ASSERT + assert.NotNil(t, err) + + var sbe *SendBatchError + if errors.As(err, &sbe) { + assert.Equal(t, 2, len(sbe.Info)) + assert.Equal(t, 0, sbe.Info[0].Index) + assert.Equal(t, 1, sbe.Info[1].Index) + } else { + t.Error("unexpected error type") + } + + // ACTION + validMessage := "test" + invalidMessage := "\u0000" + + err = client.SendBatch(context.TODO(), awsCmdQueueURL(), []string{validMessage, invalidMessage}) + + // ASSERT + assert.NotNil(t, err) + + sbe = nil + if errors.As(err, &sbe) { + assert.Equal(t, 1, len(sbe.Info)) + assert.Equal(t, 1, sbe.Info[0].Index) + } else { + t.Error("unexpected error type") + } + + assert.Equal(t, validMessage, awsCmdReceiveMessage()) +} + func TestGetQueueUrl(t *testing.T) { // ARRANGE setup() From a64ca1b47d87dd2c40618b27e3d77b7d39f0f0b8 Mon Sep 17 00:00:00 2001 From: Callum Morris Date: Fri, 11 Apr 2025 13:09:16 +1200 Subject: [PATCH 5/8] feat!: update SendNBatch to fix deficiencies in error handling Includes adding a test for this method BREAKING CHANGE: now returns the number of API calls to SendBatch made --- aws/sqs/sqs.go | 99 ++++++++++++++++++++++++----- aws/sqs/sqs_integration_test.go | 109 ++++++++++++++++++++++++++++++++ 2 files changed, 193 insertions(+), 15 deletions(-) diff --git a/aws/sqs/sqs.go b/aws/sqs/sqs.go index 5e554f4..38a8b93 100644 --- a/aws/sqs/sqs.go +++ b/aws/sqs/sqs.go @@ -5,7 +5,6 @@ import ( "context" "errors" "fmt" - "math" "os" "strings" @@ -247,6 +246,20 @@ func (s *SendBatchError) Unwrap() error { return s.Err } +type SendNBatchError struct { + Errors []error + Info []SendBatchErrorEntry +} + +func (s *SendNBatchError) Error() string { + var allErrors string + for _, err := range s.Errors { + allErrors += fmt.Sprintf("%s,", err.Error()) + } + allErrors = strings.TrimSuffix(allErrors, ",") + return fmt.Sprintf("%v error(s) sending batches: %s", len(s.Errors), allErrors) +} + // SendBatch sends up to 10 messages to a given SQS queue with one API call. // If an error occurs on any or all messages, a SendBatchError is returned that lets // the caller know the index of the message/s in bodies that failed. @@ -291,24 +304,80 @@ func (s *SQS) SendBatch(ctx context.Context, queueURL string, bodies []string) e return nil } -func (s *SQS) SendNBatch(ctx context.Context, queueURL string, bodies []string) error { - var ( - bodiesLen = len(bodies) - maxlen = 10 - times = int(math.Ceil(float64(bodiesLen) / float64(maxlen))) +// SendNBatch sends any number of messages to a given SQS queue via a series of SendBatch calls. +// If an error occurs on any or all messages, a SendNBatchError is returned that lets +// the caller know the index of the message/s in bodies that failed. +// Returns the number of API calls to SendBatch made. +func (s *SQS) SendNBatch(ctx context.Context, queueURL string, bodies []string) (int, error) { + + const ( + maxCount = 10 + maxSize = 262144 // 256 KiB ) - for i := 0; i < times; i++ { - batch_end := maxlen * (i + 1) - if maxlen*(i+1) > bodiesLen { - batch_end = bodiesLen + + allErrors := make([]error, 0) + allInfo := make([]SendBatchErrorEntry, 0) + + batchesSent := 0 + + batch := make([]int, 0) + totalSize := 0 + + sendBatch := func() { + batchBodies := make([]string, len(batch)) + + for i, batchIndex := range batch { + batchBodies[i] = bodies[batchIndex] + } + + err := s.SendBatch(ctx, queueURL, batchBodies) + var sbe *SendBatchError + if errors.As(err, &sbe) { + allErrors = append(allErrors, err) + + // Update index so that index refers to the position in given bodies slice. + for i := range sbe.Info { + sbe.Info[i].Index = batch[sbe.Info[i].Index] + } + + allInfo = append(allInfo, sbe.Info...) + } + + batchesSent++ + batch = batch[:0] + totalSize = 0 + } + + for i, body := range bodies { + + // Check if any single message is too big + if len(body) > maxSize { + allErrors = append(allErrors, errors.New("message too big to send")) + allInfo = append(allInfo, SendBatchErrorEntry{ + Index: i, + }) + continue } - var bodies_batch = bodies[maxlen*i : batch_end] - err := s.SendBatch(ctx, queueURL, bodies_batch) - if err != nil { - return err + // If adding the current message would exceed the batch max size or count, send the current batch. + if totalSize+len(body) > maxSize || len(batch) == maxCount { + sendBatch() } + batch = append(batch, i) + totalSize += len(body) } - return nil + + if len(batch) > 0 { + sendBatch() + } + + if len(allErrors) > 0 { + return batchesSent, &SendNBatchError{ + Errors: allErrors, + Info: allInfo, + } + } + + return batchesSent, nil } // GetQueueUrl returns an AWS SQS queue URL given its name. diff --git a/aws/sqs/sqs_integration_test.go b/aws/sqs/sqs_integration_test.go index e5cfa06..c616b0b 100644 --- a/aws/sqs/sqs_integration_test.go +++ b/aws/sqs/sqs_integration_test.go @@ -501,6 +501,115 @@ func TestSendBatch(t *testing.T) { assert.Equal(t, validMessage, awsCmdReceiveMessage()) } +func TestSendNBatch(t *testing.T) { + // ARRANGE + setup() + defer teardown() + + client, err := New() + require.Nil(t, err, fmt.Sprintf("error creating sqs client: %v", err)) + + // ACTION + var maxBytes int = 262144 + maxSizeSingleMessage := "" + for range maxBytes { + maxSizeSingleMessage += "a" + } + batchesSent, err := client.SendNBatch(context.TODO(), awsCmdQueueURL(), []string{maxSizeSingleMessage, maxSizeSingleMessage}) + + // ASSERT + assert.Nil(t, err) + assert.Equal(t, 2, batchesSent) + assert.Equal(t, maxSizeSingleMessage, awsCmdReceiveMessage()) + assert.Equal(t, maxSizeSingleMessage, awsCmdReceiveMessage()) + + assert.Equal(t, 0, awsCmdQueueCount()) + + // ACTION + tooLargeSingleMessage := maxSizeSingleMessage + "a" + batchesSent, err = client.SendNBatch(context.TODO(), awsCmdQueueURL(), []string{maxSizeSingleMessage, tooLargeSingleMessage}) + + // ASSERT + assert.NotNil(t, err) + assert.Equal(t, 1, batchesSent) + + var sbe *SendNBatchError + if errors.As(err, &sbe) { + assert.Equal(t, 1, len(sbe.Info)) + assert.True(t, indexIsPresent(sbe.Info, 1)) + } else { + t.Error("unexpected error type") + } + + assert.Equal(t, maxSizeSingleMessage, awsCmdReceiveMessage()) + assert.Equal(t, 0, awsCmdQueueCount()) + + // ACTION + smallMessageText := "small" + smallMessageCount := 21 + smallMessages := make([]string, smallMessageCount) + for i := range smallMessageCount { + smallMessages[i] = smallMessageText + } + batchesSent, err = client.SendNBatch(context.TODO(), awsCmdQueueURL(), smallMessages) + + // ASSERT + assert.Nil(t, err) + assert.Equal(t, 3, batchesSent) + + receiveCount := 0 + for range batchesSent { + messages := awsCmdReceiveMessages() + for _, m := range messages { + if m == smallMessageText { + receiveCount++ + } + } + } + assert.Equal(t, smallMessageCount, receiveCount) + + // ACTION + invalidMessage := "\u0000" + + batchesSent, err = client.SendNBatch(context.TODO(), awsCmdQueueURL(), []string{ + tooLargeSingleMessage, + maxSizeSingleMessage, + smallMessageText, + maxSizeSingleMessage, + invalidMessage, + smallMessageText, + smallMessageText, + invalidMessage, + tooLargeSingleMessage, + }) + + // ASSERT + assert.NotNil(t, err) + assert.Equal(t, 4, batchesSent) + + sbe = nil + if errors.As(err, &sbe) { + assert.Equal(t, 4, len(sbe.Info)) + assert.True(t, indexIsPresent(sbe.Info, 0)) + assert.True(t, indexIsPresent(sbe.Info, 4)) + assert.True(t, indexIsPresent(sbe.Info, 7)) + assert.True(t, indexIsPresent(sbe.Info, 8)) + } else { + t.Error("unexpected error type") + } + + assert.Equal(t, 5, len(awsCmdReceiveMessages())) +} + +func indexIsPresent(info []SendBatchErrorEntry, index int) bool { + for _, entry := range info { + if entry.Index == index { + return true + } + } + return false +} + func TestGetQueueUrl(t *testing.T) { // ARRANGE setup() From d9bb5e4ece38be5b320af89acaa4e600dffb202e Mon Sep 17 00:00:00 2001 From: Callum Morris Date: Fri, 11 Apr 2025 13:31:44 +1200 Subject: [PATCH 6/8] feat: add ReceiveBatch wrapper function Changes the underlying receiveMethod function to return all messages rather than just the first --- aws/sqs/sqs.go | 55 ++++++++++++++++++++++++--------- aws/sqs/sqs_integration_test.go | 21 +++++++++++++ 2 files changed, 62 insertions(+), 14 deletions(-) diff --git a/aws/sqs/sqs.go b/aws/sqs/sqs.go index 38a8b93..66fd5c9 100644 --- a/aws/sqs/sqs.go +++ b/aws/sqs/sqs.go @@ -126,34 +126,40 @@ func (s *SQS) ReceiveWithContextAttributes(ctx context.Context, queueURL string, WaitTimeSeconds: 20, AttributeNames: attrs, } - return s.receiveMessage(ctx, &input) + msgs, err := s.receiveMessages(ctx, &input) + if err != nil { + return Raw{}, err + } + return msgs[0], err } -// receiveMessage is the common code used internally to receive an SQS message based +// receiveMessages is the common code used internally to receive an SQS messages based // on the provided input. -func (s *SQS) receiveMessage(ctx context.Context, input *sqs.ReceiveMessageInput) (Raw, error) { +func (s *SQS) receiveMessages(ctx context.Context, input *sqs.ReceiveMessageInput) ([]Raw, error) { r, err := s.client.ReceiveMessage(ctx, input) if err != nil { - return Raw{}, err + return []Raw{}, err } switch { case r == nil || len(r.Messages) == 0: // no message received - return Raw{}, ErrNoMessages + return []Raw{}, ErrNoMessages - case len(r.Messages) == 1: - raw := r.Messages[0] + case len(r.Messages) >= 1: - m := Raw{ - Body: aws.ToString(raw.Body), - ReceiptHandle: aws.ToString(raw.ReceiptHandle), - Attributes: raw.Attributes, + messages := make([]Raw, len(r.Messages)) + for i := range r.Messages { + messages[i] = Raw{ + Body: aws.ToString(r.Messages[i].Body), + ReceiptHandle: aws.ToString(r.Messages[i].ReceiptHandle), + Attributes: r.Messages[i].Attributes, + } } - return m, nil + return messages, nil default: - return Raw{}, fmt.Errorf("received unexpected messages: %d", len(r.Messages)) + return []Raw{}, fmt.Errorf("received unexpected number of messages: %d", len(r.Messages)) // Probably an impossible case } } @@ -168,7 +174,28 @@ func (s *SQS) ReceiveWithContext(ctx context.Context, queueURL string, visibilit VisibilityTimeout: visibilityTimeout, WaitTimeSeconds: 20, } - return s.receiveMessage(ctx, &input) + msgs, err := s.receiveMessages(ctx, &input) + if err != nil { + return Raw{}, err + } + return msgs[0], err +} + +// ReceiveBatch is similar to Receive, however it can return up to 10 messages. +func (s *SQS) ReceiveBatch(ctx context.Context, queueURL string, visibilityTimeout int32) ([]Raw, error) { + + input := sqs.ReceiveMessageInput{ + QueueUrl: aws.String(queueURL), + MaxNumberOfMessages: 10, + VisibilityTimeout: visibilityTimeout, + WaitTimeSeconds: 20, + } + + msgs, err := s.receiveMessages(ctx, &input) + if err != nil { + return []Raw{}, err + } + return msgs, nil } // Delete deletes the message referred to by receiptHandle from the queue. diff --git a/aws/sqs/sqs_integration_test.go b/aws/sqs/sqs_integration_test.go index c616b0b..929fc8d 100644 --- a/aws/sqs/sqs_integration_test.go +++ b/aws/sqs/sqs_integration_test.go @@ -358,6 +358,27 @@ func TestSQSReceiveWithAttributes(t *testing.T) { assert.True(t, len(receivedMessage.Attributes) > 0) } +func TestSQSReceiveBatch(t *testing.T) { + // ARRANGE + setup() + defer teardown() + + awsCmdSendMessage() + awsCmdSendMessage() + + client, err := New() + require.Nil(t, err, fmt.Sprintf("error creating sqs client: %v", err)) + + // ACTION + receivedMessages, err := client.ReceiveBatch(context.TODO(), awsCmdQueueURL(), 30) + + // ASSERT + assert.Nil(t, err) + for _, receivedMessage := range receivedMessages { + assert.Equal(t, testMessage, receivedMessage.Body) + } +} + func TestSQSDelete(t *testing.T) { // ARRANGE setup() From 81a579cfee2e1af14fddc123a1194598d4dc598e Mon Sep 17 00:00:00 2001 From: Callum Morris Date: Fri, 11 Apr 2025 13:12:08 +1200 Subject: [PATCH 7/8] feat: add DeleteBatch function --- aws/sqs/sqs.go | 61 +++++++++++++++++++++++++++++++++ aws/sqs/sqs_integration_test.go | 48 ++++++++++++++++++++++++++ 2 files changed, 109 insertions(+) diff --git a/aws/sqs/sqs.go b/aws/sqs/sqs.go index 66fd5c9..949ff28 100644 --- a/aws/sqs/sqs.go +++ b/aws/sqs/sqs.go @@ -407,6 +407,67 @@ func (s *SQS) SendNBatch(ctx context.Context, queueURL string, bodies []string) return batchesSent, nil } +type DeleteBatchError struct { + Err error + Info []DeleteBatchErrorEntry +} + +type DeleteBatchErrorEntry struct { + Entry types.BatchResultErrorEntry + Index int +} + +func (d *DeleteBatchError) Error() string { + return fmt.Sprintf("%v: %v messages failed to delete", d.Err, len(d.Info)) +} + +func (d *DeleteBatchError) Unwrap() error { + return d.Err +} + +// DeleteBatch deletes up to 10 messages from an SQS queue in a single batch. +// If an error occurs on any or all messages, a DeleteBatchError is returned that lets +// the caller know the indice/s in receiptHandles that failed. +func (s *SQS) DeleteBatch(ctx context.Context, queueURL string, receiptHandles []string) error { + entries := make([]types.DeleteMessageBatchRequestEntry, len(receiptHandles)) + for i, receipt := range receiptHandles { + entries[i] = types.DeleteMessageBatchRequestEntry{ + Id: aws.String(fmt.Sprintf("delete-message-%d", i)), + ReceiptHandle: aws.String(receipt), + } + } + + output, err := s.client.DeleteMessageBatch(ctx, &sqs.DeleteMessageBatchInput{ + Entries: entries, + QueueUrl: &queueURL, + }) + if err != nil { + info := make([]DeleteBatchErrorEntry, len(entries)) + for i := range entries { + info[i] = DeleteBatchErrorEntry{ + Index: i, + } + } + return &DeleteBatchError{Err: err, Info: info} + } + if len(output.Failed) > 0 { + info := make([]DeleteBatchErrorEntry, len(output.Failed)) + for i, errorEntry := range output.Failed { + for j, requestEntry := range entries { + if aws.ToString(requestEntry.Id) == aws.ToString(errorEntry.Id) { + info[i] = DeleteBatchErrorEntry{ + Entry: errorEntry, + Index: j, + } + break + } + } + } + return &DeleteBatchError{Info: info} + } + return nil +} + // GetQueueUrl returns an AWS SQS queue URL given its name. func (s *SQS) GetQueueUrl(name string) (string, error) { params := sqs.GetQueueUrlInput{ diff --git a/aws/sqs/sqs_integration_test.go b/aws/sqs/sqs_integration_test.go index 929fc8d..8c4d4d8 100644 --- a/aws/sqs/sqs_integration_test.go +++ b/aws/sqs/sqs_integration_test.go @@ -631,6 +631,54 @@ func indexIsPresent(info []SendBatchErrorEntry, index int) bool { return false } +func TestDeleteBatch(t *testing.T) { + // ARRANGE + setup() + defer teardown() + + client, err := New() + require.Nil(t, err, fmt.Sprintf("error creating sqs client: %v", err)) + + // Send and receive messages to get receipt handles + messages := []string{"message1", "message2", "message3"} + err = client.SendBatch(context.TODO(), awsCmdQueueURL(), messages) + require.Nil(t, err) + require.Equal(t, 3, awsCmdQueueCount()) + require.Equal(t, 0, awsCmdQueueInFlightCount()) + + receivedMessages, err := client.ReceiveBatch(context.TODO(), awsCmdQueueURL(), 30) + require.Nil(t, err) + require.Equal(t, 3, len(receivedMessages)) + require.Equal(t, 3, awsCmdQueueInFlightCount()) + receiptHandles := make([]string, 0) + for _, rm := range receivedMessages { + receiptHandles = append(receiptHandles, rm.ReceiptHandle) + } + + // ACTION + err = client.DeleteBatch(context.TODO(), awsCmdQueueURL(), receiptHandles) + + // ASSERT + assert.Nil(t, err) + assert.Equal(t, 0, awsCmdQueueCount()) + assert.Equal(t, 0, awsCmdQueueInFlightCount()) + + // ACTION + invalidReceiptHandle := "invalid-receipt-handle" + err = client.DeleteBatch(context.TODO(), awsCmdQueueURL(), []string{invalidReceiptHandle}) + + // ASSERT + assert.NotNil(t, err) + + var dbe *DeleteBatchError + if errors.As(err, &dbe) { + assert.Equal(t, 1, len(dbe.Info)) + assert.Equal(t, 0, dbe.Info[0].Index) + } else { + t.Error("unexpected error type") + } +} + func TestGetQueueUrl(t *testing.T) { // ARRANGE setup() From d2a8033c157987e44b3725abd32474004d954161 Mon Sep 17 00:00:00 2001 From: Callum Morris Date: Fri, 11 Apr 2025 13:16:21 +1200 Subject: [PATCH 8/8] feat: add DeleteNBatch function --- aws/sqs/sqs.go | 70 +++++++++++++++++++++++++++ aws/sqs/sqs_integration_test.go | 86 +++++++++++++++++++++++++++++++++ 2 files changed, 156 insertions(+) diff --git a/aws/sqs/sqs.go b/aws/sqs/sqs.go index 949ff28..e7c963b 100644 --- a/aws/sqs/sqs.go +++ b/aws/sqs/sqs.go @@ -5,6 +5,7 @@ import ( "context" "errors" "fmt" + "math" "os" "strings" @@ -425,6 +426,20 @@ func (d *DeleteBatchError) Unwrap() error { return d.Err } +type DeleteNBatchError struct { + Errors []error + Info []DeleteBatchErrorEntry +} + +func (s *DeleteNBatchError) Error() string { + var allErrors string + for _, err := range s.Errors { + allErrors += fmt.Sprintf("%s,", err.Error()) + } + allErrors = strings.TrimSuffix(allErrors, ",") + return fmt.Sprintf("%v error(s) deleting batches: %s", len(s.Errors), allErrors) +} + // DeleteBatch deletes up to 10 messages from an SQS queue in a single batch. // If an error occurs on any or all messages, a DeleteBatchError is returned that lets // the caller know the indice/s in receiptHandles that failed. @@ -468,6 +483,61 @@ func (s *SQS) DeleteBatch(ctx context.Context, queueURL string, receiptHandles [ return nil } +// DeleteNBatch deletes any number of messages from a given SQS queue via a series of DeleteBatch calls. +// If an error occurs on any or all messages, a DeleteNBatchError is returned that lets +// the caller know the receipt handles that failed. +// Returns the number of API calls to DeleteBatch made. +func (s *SQS) DeleteNBatch(ctx context.Context, queueURL string, receiptHandles []string) (int, error) { + + var ( + receiptCount = len(receiptHandles) + maxlen = 10 + times = int(math.Ceil(float64(receiptCount) / float64(maxlen))) + ) + + allErrors := make([]error, 0) + allInfo := make([]DeleteBatchErrorEntry, 0) + + batchesDeleted := 0 + + for i := 0; i < times; i++ { + batch_end := maxlen * (i + 1) + if maxlen*(i+1) > receiptCount { + batch_end = receiptCount + } + var receipt_batch = receiptHandles[maxlen*i : batch_end] + + indexMap := make(map[int]int, 0) + count := 0 + for j := maxlen * i; j < batch_end; j++ { + indexMap[count] = j + count++ + } + + err := s.DeleteBatch(ctx, queueURL, receipt_batch) + var dbe *DeleteBatchError + if errors.As(err, &dbe) { + allErrors = append(allErrors, err) + + // Update index so that index refers to the position in given receiptHandles slice. + for i := range dbe.Info { + dbe.Info[i].Index = indexMap[dbe.Info[i].Index] + } + + allInfo = append(allInfo, dbe.Info...) + } + batchesDeleted++ + } + + if len(allErrors) > 0 { + return batchesDeleted, &DeleteNBatchError{ + Errors: allErrors, + Info: allInfo, + } + } + return batchesDeleted, nil +} + // GetQueueUrl returns an AWS SQS queue URL given its name. func (s *SQS) GetQueueUrl(name string) (string, error) { params := sqs.GetQueueUrlInput{ diff --git a/aws/sqs/sqs_integration_test.go b/aws/sqs/sqs_integration_test.go index 8c4d4d8..320a115 100644 --- a/aws/sqs/sqs_integration_test.go +++ b/aws/sqs/sqs_integration_test.go @@ -679,6 +679,92 @@ func TestDeleteBatch(t *testing.T) { } } +func TestDeleteNBatch(t *testing.T) { + // ARRANGE + setup() + defer teardown() + + client, err := New() + require.Nil(t, err, fmt.Sprintf("error creating sqs client: %v", err)) + + // Send and receive messages to get receipt handles + messages := []string{"msg1", "msg2", "msg3", "msg4", "msg5", "msg6", "msg7", "msg8", "msg9", "msg10", "msg11"} + batchesSent, err := client.SendNBatch(context.TODO(), awsCmdQueueURL(), messages) + + require.Nil(t, err) + require.Equal(t, 2, batchesSent) + + receivedMessages1, err := client.ReceiveBatch(context.TODO(), awsCmdQueueURL(), 30) + require.Nil(t, err) + require.Equal(t, 10, len(receivedMessages1)) + receivedMessages2, err := client.ReceiveBatch(context.TODO(), awsCmdQueueURL(), 30) + require.Nil(t, err) + require.Equal(t, 1, len(receivedMessages2)) + + require.Equal(t, 11, awsCmdQueueInFlightCount()) + + receiptHandles := make([]string, 0) + for _, rm := range receivedMessages1 { + receiptHandles = append(receiptHandles, rm.ReceiptHandle) + } + for _, rm := range receivedMessages2 { + receiptHandles = append(receiptHandles, rm.ReceiptHandle) + } + + // ACTION + batchesDeleted, err := client.DeleteNBatch(context.TODO(), awsCmdQueueURL(), receiptHandles) + + // ASSERT + assert.Nil(t, err) + assert.Equal(t, 2, batchesDeleted) + assert.Equal(t, 0, awsCmdQueueCount()) + assert.Equal(t, 0, awsCmdQueueInFlightCount()) + + // ARRANGE + + // Send and receive messages to get receipt handles + messages = []string{"msg1", "msg2", "msg3", "msg4", "msg5", "msg6", "msg7", "msg8", "msg9", "msg10", "msg11", "msg12"} + batchesSent, err = client.SendNBatch(context.TODO(), awsCmdQueueURL(), messages) + + require.Nil(t, err) + require.Equal(t, 2, batchesSent) + + receivedMessages1, err = client.ReceiveBatch(context.TODO(), awsCmdQueueURL(), 30) + require.Nil(t, err) + require.Equal(t, 10, len(receivedMessages1)) + receivedMessages2, err = client.ReceiveBatch(context.TODO(), awsCmdQueueURL(), 30) + require.Nil(t, err) + require.Equal(t, 2, len(receivedMessages2)) + require.Equal(t, 12, awsCmdQueueInFlightCount()) + + receiptHandles = make([]string, 0) + for _, rm := range receivedMessages1 { + receiptHandles = append(receiptHandles, rm.ReceiptHandle) + } + for _, rm := range receivedMessages2 { + receiptHandles = append(receiptHandles, rm.ReceiptHandle) + } + invalidReceiptHandle := "invalid-receipt-handle" + receiptHandles[0] = invalidReceiptHandle // Replace a valid receipt handle with an invalid one. + receiptHandles = append(receiptHandles, invalidReceiptHandle) // Append an invalid receipt handle (index 12) + + // ACTION + batchesDeleted, err = client.DeleteNBatch(context.TODO(), awsCmdQueueURL(), receiptHandles) + assert.NotNil(t, err) + + var dbe *DeleteNBatchError + if errors.As(err, &dbe) { + assert.Equal(t, 2, len(dbe.Info)) + assert.Equal(t, 0, dbe.Info[0].Index) + assert.Equal(t, 12, dbe.Info[1].Index) + } else { + t.Error("unexpected error type") + } + assert.Equal(t, 2, batchesDeleted) + assert.Equal(t, 0, awsCmdQueueCount()) + assert.Equal(t, 1, awsCmdQueueInFlightCount()) +} + func TestGetQueueUrl(t *testing.T) { // ARRANGE setup()