diff --git a/aws/sqs/sqs.go b/aws/sqs/sqs.go index 9b36026..e7c963b 100644 --- a/aws/sqs/sqs.go +++ b/aws/sqs/sqs.go @@ -127,34 +127,40 @@ func (s *SQS) ReceiveWithContextAttributes(ctx context.Context, queueURL string, WaitTimeSeconds: 20, AttributeNames: attrs, } - return s.receiveMessage(ctx, &input) + msgs, err := s.receiveMessages(ctx, &input) + if err != nil { + return Raw{}, err + } + return msgs[0], err } -// receiveMessage is the common code used internally to receive an SQS message based +// receiveMessages is the common code used internally to receive an SQS messages based // on the provided input. -func (s *SQS) receiveMessage(ctx context.Context, input *sqs.ReceiveMessageInput) (Raw, error) { +func (s *SQS) receiveMessages(ctx context.Context, input *sqs.ReceiveMessageInput) ([]Raw, error) { r, err := s.client.ReceiveMessage(ctx, input) if err != nil { - return Raw{}, err + return []Raw{}, err } switch { case r == nil || len(r.Messages) == 0: // no message received - return Raw{}, ErrNoMessages + return []Raw{}, ErrNoMessages - case len(r.Messages) == 1: - raw := r.Messages[0] + case len(r.Messages) >= 1: - m := Raw{ - Body: aws.ToString(raw.Body), - ReceiptHandle: aws.ToString(raw.ReceiptHandle), - Attributes: raw.Attributes, + messages := make([]Raw, len(r.Messages)) + for i := range r.Messages { + messages[i] = Raw{ + Body: aws.ToString(r.Messages[i].Body), + ReceiptHandle: aws.ToString(r.Messages[i].ReceiptHandle), + Attributes: r.Messages[i].Attributes, + } } - return m, nil + return messages, nil default: - return Raw{}, fmt.Errorf("received unexpected messages: %d", len(r.Messages)) + return []Raw{}, fmt.Errorf("received unexpected number of messages: %d", len(r.Messages)) // Probably an impossible case } } @@ -169,7 +175,28 @@ func (s *SQS) ReceiveWithContext(ctx context.Context, queueURL string, visibilit VisibilityTimeout: visibilityTimeout, WaitTimeSeconds: 20, } - return s.receiveMessage(ctx, &input) + msgs, err := s.receiveMessages(ctx, &input) + if err != nil { + return Raw{}, err + } + return msgs[0], err +} + +// ReceiveBatch is similar to Receive, however it can return up to 10 messages. +func (s *SQS) ReceiveBatch(ctx context.Context, queueURL string, visibilityTimeout int32) ([]Raw, error) { + + input := sqs.ReceiveMessageInput{ + QueueUrl: aws.String(queueURL), + MaxNumberOfMessages: 10, + VisibilityTimeout: visibilityTimeout, + WaitTimeSeconds: 20, + } + + msgs, err := s.receiveMessages(ctx, &input) + if err != nil { + return []Raw{}, err + } + return msgs, nil } // Delete deletes the message referred to by receiptHandle from the queue. @@ -231,44 +258,284 @@ func (s *SQS) SendFifoMessage(queue, group, dedupe string, msg []byte) (string, return "", nil } -// Leverage the sendbatch api for uploading large numbers of messages -func (s *SQS) SendBatch(ctx context.Context, queueURL string, bodies []string) error { - if len(bodies) > 11 { - return errors.New("too many messages to batch") +type SendBatchError struct { + Err error + Info []SendBatchErrorEntry +} +type SendBatchErrorEntry struct { + Entry types.BatchResultErrorEntry + Index int +} + +func (s *SendBatchError) Error() string { + return fmt.Sprintf("%v: %v messages failed to send", s.Err, len(s.Info)) +} +func (s *SendBatchError) Unwrap() error { + return s.Err +} + +type SendNBatchError struct { + Errors []error + Info []SendBatchErrorEntry +} + +func (s *SendNBatchError) Error() string { + var allErrors string + for _, err := range s.Errors { + allErrors += fmt.Sprintf("%s,", err.Error()) } + allErrors = strings.TrimSuffix(allErrors, ",") + return fmt.Sprintf("%v error(s) sending batches: %s", len(s.Errors), allErrors) +} + +// SendBatch sends up to 10 messages to a given SQS queue with one API call. +// If an error occurs on any or all messages, a SendBatchError is returned that lets +// the caller know the index of the message/s in bodies that failed. +func (s *SQS) SendBatch(ctx context.Context, queueURL string, bodies []string) error { + var err error entries := make([]types.SendMessageBatchRequestEntry, len(bodies)) for j, body := range bodies { entries[j] = types.SendMessageBatchRequestEntry{ - Id: aws.String(fmt.Sprintf("gamitjob%d", j)), + Id: aws.String(fmt.Sprintf("message-%d", j)), MessageBody: aws.String(body), } } - _, err = s.client.SendMessageBatch(ctx, &sqs.SendMessageBatchInput{ + output, err := s.client.SendMessageBatch(ctx, &sqs.SendMessageBatchInput{ Entries: entries, QueueUrl: &queueURL, }) - return err + if err != nil { + info := make([]SendBatchErrorEntry, len(entries)) + for i := range entries { + info[i] = SendBatchErrorEntry{ + Index: i, + } + } + return &SendBatchError{Err: err, Info: info} + } + if len(output.Failed) > 0 { + info := make([]SendBatchErrorEntry, len(output.Failed)) + for i, entry := range output.Failed { + for j, msg := range entries { + if aws.ToString(msg.Id) == aws.ToString(entry.Id) { + info[i] = SendBatchErrorEntry{ + Entry: entry, + Index: j, + } + break + } + } + } + return &SendBatchError{Err: errors.New("partial message failure"), Info: info} + } + return nil } -func (s *SQS) SendNBatch(ctx context.Context, queueURL string, bodies []string) error { +// SendNBatch sends any number of messages to a given SQS queue via a series of SendBatch calls. +// If an error occurs on any or all messages, a SendNBatchError is returned that lets +// the caller know the index of the message/s in bodies that failed. +// Returns the number of API calls to SendBatch made. +func (s *SQS) SendNBatch(ctx context.Context, queueURL string, bodies []string) (int, error) { + + const ( + maxCount = 10 + maxSize = 262144 // 256 KiB + ) + + allErrors := make([]error, 0) + allInfo := make([]SendBatchErrorEntry, 0) + + batchesSent := 0 + + batch := make([]int, 0) + totalSize := 0 + + sendBatch := func() { + batchBodies := make([]string, len(batch)) + + for i, batchIndex := range batch { + batchBodies[i] = bodies[batchIndex] + } + + err := s.SendBatch(ctx, queueURL, batchBodies) + var sbe *SendBatchError + if errors.As(err, &sbe) { + allErrors = append(allErrors, err) + + // Update index so that index refers to the position in given bodies slice. + for i := range sbe.Info { + sbe.Info[i].Index = batch[sbe.Info[i].Index] + } + + allInfo = append(allInfo, sbe.Info...) + } + + batchesSent++ + batch = batch[:0] + totalSize = 0 + } + + for i, body := range bodies { + + // Check if any single message is too big + if len(body) > maxSize { + allErrors = append(allErrors, errors.New("message too big to send")) + allInfo = append(allInfo, SendBatchErrorEntry{ + Index: i, + }) + continue + } + // If adding the current message would exceed the batch max size or count, send the current batch. + if totalSize+len(body) > maxSize || len(batch) == maxCount { + sendBatch() + } + batch = append(batch, i) + totalSize += len(body) + } + + if len(batch) > 0 { + sendBatch() + } + + if len(allErrors) > 0 { + return batchesSent, &SendNBatchError{ + Errors: allErrors, + Info: allInfo, + } + } + + return batchesSent, nil +} + +type DeleteBatchError struct { + Err error + Info []DeleteBatchErrorEntry +} + +type DeleteBatchErrorEntry struct { + Entry types.BatchResultErrorEntry + Index int +} + +func (d *DeleteBatchError) Error() string { + return fmt.Sprintf("%v: %v messages failed to delete", d.Err, len(d.Info)) +} + +func (d *DeleteBatchError) Unwrap() error { + return d.Err +} + +type DeleteNBatchError struct { + Errors []error + Info []DeleteBatchErrorEntry +} + +func (s *DeleteNBatchError) Error() string { + var allErrors string + for _, err := range s.Errors { + allErrors += fmt.Sprintf("%s,", err.Error()) + } + allErrors = strings.TrimSuffix(allErrors, ",") + return fmt.Sprintf("%v error(s) deleting batches: %s", len(s.Errors), allErrors) +} + +// DeleteBatch deletes up to 10 messages from an SQS queue in a single batch. +// If an error occurs on any or all messages, a DeleteBatchError is returned that lets +// the caller know the indice/s in receiptHandles that failed. +func (s *SQS) DeleteBatch(ctx context.Context, queueURL string, receiptHandles []string) error { + entries := make([]types.DeleteMessageBatchRequestEntry, len(receiptHandles)) + for i, receipt := range receiptHandles { + entries[i] = types.DeleteMessageBatchRequestEntry{ + Id: aws.String(fmt.Sprintf("delete-message-%d", i)), + ReceiptHandle: aws.String(receipt), + } + } + + output, err := s.client.DeleteMessageBatch(ctx, &sqs.DeleteMessageBatchInput{ + Entries: entries, + QueueUrl: &queueURL, + }) + if err != nil { + info := make([]DeleteBatchErrorEntry, len(entries)) + for i := range entries { + info[i] = DeleteBatchErrorEntry{ + Index: i, + } + } + return &DeleteBatchError{Err: err, Info: info} + } + if len(output.Failed) > 0 { + info := make([]DeleteBatchErrorEntry, len(output.Failed)) + for i, errorEntry := range output.Failed { + for j, requestEntry := range entries { + if aws.ToString(requestEntry.Id) == aws.ToString(errorEntry.Id) { + info[i] = DeleteBatchErrorEntry{ + Entry: errorEntry, + Index: j, + } + break + } + } + } + return &DeleteBatchError{Info: info} + } + return nil +} + +// DeleteNBatch deletes any number of messages from a given SQS queue via a series of DeleteBatch calls. +// If an error occurs on any or all messages, a DeleteNBatchError is returned that lets +// the caller know the receipt handles that failed. +// Returns the number of API calls to DeleteBatch made. +func (s *SQS) DeleteNBatch(ctx context.Context, queueURL string, receiptHandles []string) (int, error) { + var ( - bodiesLen = len(bodies) - maxlen = 10 - times = int(math.Ceil(float64(bodiesLen) / float64(maxlen))) + receiptCount = len(receiptHandles) + maxlen = 10 + times = int(math.Ceil(float64(receiptCount) / float64(maxlen))) ) + + allErrors := make([]error, 0) + allInfo := make([]DeleteBatchErrorEntry, 0) + + batchesDeleted := 0 + for i := 0; i < times; i++ { batch_end := maxlen * (i + 1) - if maxlen*(i+1) > bodiesLen { - batch_end = bodiesLen + if maxlen*(i+1) > receiptCount { + batch_end = receiptCount } - var bodies_batch = bodies[maxlen*i : batch_end] - err := s.SendBatch(ctx, queueURL, bodies_batch) - if err != nil { - return err + var receipt_batch = receiptHandles[maxlen*i : batch_end] + + indexMap := make(map[int]int, 0) + count := 0 + for j := maxlen * i; j < batch_end; j++ { + indexMap[count] = j + count++ } + + err := s.DeleteBatch(ctx, queueURL, receipt_batch) + var dbe *DeleteBatchError + if errors.As(err, &dbe) { + allErrors = append(allErrors, err) + + // Update index so that index refers to the position in given receiptHandles slice. + for i := range dbe.Info { + dbe.Info[i].Index = indexMap[dbe.Info[i].Index] + } + + allInfo = append(allInfo, dbe.Info...) + } + batchesDeleted++ } - return nil + + if len(allErrors) > 0 { + return batchesDeleted, &DeleteNBatchError{ + Errors: allErrors, + Info: allInfo, + } + } + return batchesDeleted, nil } // GetQueueUrl returns an AWS SQS queue URL given its name. diff --git a/aws/sqs/sqs_integration_test.go b/aws/sqs/sqs_integration_test.go index 9f2a0f7..320a115 100644 --- a/aws/sqs/sqs_integration_test.go +++ b/aws/sqs/sqs_integration_test.go @@ -4,7 +4,9 @@ package sqs import ( + "context" "encoding/json" + "errors" "fmt" "os" "os/exec" @@ -111,6 +113,31 @@ func awsCmdReceiveMessage() string { } } +func awsCmdReceiveMessages() []string { + if out, err := exec.Command( //nolint:gosec + "aws", "sqs", + "receive-message", + "--queue-url", awsCmdQueueURL(), + "--attribute-names", "body", + "--region", awsRegion, + "--max-number-of-messages", "10", // AWS SQS allows up to 10 messages at a time + ).CombinedOutput(); err != nil { + + panic(err) + } else { + var payload map[string][]map[string]string + _ = json.Unmarshal(out, &payload) + + var bodies []string + for _, msg := range payload["Messages"] { + if body, ok := msg["Body"]; ok { + bodies = append(bodies, body) + } + } + return bodies + } +} + func awsCmdQueueCount() int { if out, err := exec.Command( "aws", "sqs", @@ -128,6 +155,25 @@ func awsCmdQueueCount() int { } } +func awsCmdQueueInFlightCount() int { + out, err := exec.Command( + "aws", "sqs", + "get-queue-attributes", + "--queue-url", awsCmdQueueURL(), + "--attribute-name", "ApproximateNumberOfMessagesNotVisible", + "--region", awsRegion).CombinedOutput() + + if err != nil { + panic(err) + } + + var payload map[string]map[string]string + json.Unmarshal(out, &payload) + + rvalue, _ := strconv.Atoi(payload["Attributes"]["ApproximateNumberOfMessagesNotVisible"]) + return rvalue +} + func awsCmdGetQueueArn(url string) string { arn, err := exec.Command( "aws", "sqs", @@ -312,6 +358,27 @@ func TestSQSReceiveWithAttributes(t *testing.T) { assert.True(t, len(receivedMessage.Attributes) > 0) } +func TestSQSReceiveBatch(t *testing.T) { + // ARRANGE + setup() + defer teardown() + + awsCmdSendMessage() + awsCmdSendMessage() + + client, err := New() + require.Nil(t, err, fmt.Sprintf("error creating sqs client: %v", err)) + + // ACTION + receivedMessages, err := client.ReceiveBatch(context.TODO(), awsCmdQueueURL(), 30) + + // ASSERT + assert.Nil(t, err) + for _, receivedMessage := range receivedMessages { + assert.Equal(t, testMessage, receivedMessage.Body) + } +} + func TestSQSDelete(t *testing.T) { // ARRANGE setup() @@ -323,8 +390,9 @@ func TestSQSDelete(t *testing.T) { client, err := New() require.Nil(t, err, fmt.Sprintf("Error creating sqs client: %v", err)) - receivedMessage, err := client.Receive(awsCmdQueueURL(), 1) + receivedMessage, err := client.Receive(awsCmdQueueURL(), 30) require.Nil(t, err, fmt.Sprintf("Error receiving test message: %v", err)) + require.Equal(t, 1, awsCmdQueueInFlightCount()) // ACTION err = client.Delete(awsCmdQueueURL(), receivedMessage.ReceiptHandle) @@ -332,6 +400,7 @@ func TestSQSDelete(t *testing.T) { // ASSERT assert.Nil(t, err) assert.Equal(t, 0, awsCmdQueueCount()) + assert.Equal(t, 0, awsCmdQueueInFlightCount()) } func TestSQSSend(t *testing.T) { @@ -377,6 +446,325 @@ func TestSQSSendWithDelay(t *testing.T) { assert.True(t, timeElapsed < timeout) } +func TestSendBatch(t *testing.T) { + // ARRANGE + setup() + defer teardown() + + client, err := New() + require.Nil(t, err, fmt.Sprintf("error creating sqs client: %v", err)) + + // ACTION + var maxBytes int = 262144 + maxSizeSingleMessage := "" + for range maxBytes { + maxSizeSingleMessage += "a" + } + err = client.SendBatch(context.TODO(), awsCmdQueueURL(), []string{maxSizeSingleMessage}) + + // ASSERT + assert.Nil(t, err) + assert.Equal(t, maxSizeSingleMessage, awsCmdReceiveMessage()) + + // ACTION + tooLargeSingleMessage := maxSizeSingleMessage + "a" + err = client.SendBatch(context.TODO(), awsCmdQueueURL(), []string{tooLargeSingleMessage}) + + // ASSERT + assert.NotNil(t, err) + + // ACTION + var maxHalfBytes int = 131072 + maxHalfSizeMessage := "" + for range maxHalfBytes { + maxHalfSizeMessage += "a" + } + err = client.SendBatch(context.TODO(), awsCmdQueueURL(), []string{maxHalfSizeMessage, maxHalfSizeMessage}) + + // ASSERT + assert.Nil(t, err) + assert.Equal(t, maxHalfSizeMessage, awsCmdReceiveMessage()) + assert.Equal(t, maxHalfSizeMessage, awsCmdReceiveMessage()) + + // ACTION + tooLargeHalfSizeMessage := maxHalfSizeMessage + "a" + err = client.SendBatch(context.TODO(), awsCmdQueueURL(), []string{maxHalfSizeMessage, tooLargeHalfSizeMessage}) + + // ASSERT + assert.NotNil(t, err) + + var sbe *SendBatchError + if errors.As(err, &sbe) { + assert.Equal(t, 2, len(sbe.Info)) + assert.Equal(t, 0, sbe.Info[0].Index) + assert.Equal(t, 1, sbe.Info[1].Index) + } else { + t.Error("unexpected error type") + } + + // ACTION + validMessage := "test" + invalidMessage := "\u0000" + + err = client.SendBatch(context.TODO(), awsCmdQueueURL(), []string{validMessage, invalidMessage}) + + // ASSERT + assert.NotNil(t, err) + + sbe = nil + if errors.As(err, &sbe) { + assert.Equal(t, 1, len(sbe.Info)) + assert.Equal(t, 1, sbe.Info[0].Index) + } else { + t.Error("unexpected error type") + } + + assert.Equal(t, validMessage, awsCmdReceiveMessage()) +} + +func TestSendNBatch(t *testing.T) { + // ARRANGE + setup() + defer teardown() + + client, err := New() + require.Nil(t, err, fmt.Sprintf("error creating sqs client: %v", err)) + + // ACTION + var maxBytes int = 262144 + maxSizeSingleMessage := "" + for range maxBytes { + maxSizeSingleMessage += "a" + } + batchesSent, err := client.SendNBatch(context.TODO(), awsCmdQueueURL(), []string{maxSizeSingleMessage, maxSizeSingleMessage}) + + // ASSERT + assert.Nil(t, err) + assert.Equal(t, 2, batchesSent) + assert.Equal(t, maxSizeSingleMessage, awsCmdReceiveMessage()) + assert.Equal(t, maxSizeSingleMessage, awsCmdReceiveMessage()) + + assert.Equal(t, 0, awsCmdQueueCount()) + + // ACTION + tooLargeSingleMessage := maxSizeSingleMessage + "a" + batchesSent, err = client.SendNBatch(context.TODO(), awsCmdQueueURL(), []string{maxSizeSingleMessage, tooLargeSingleMessage}) + + // ASSERT + assert.NotNil(t, err) + assert.Equal(t, 1, batchesSent) + + var sbe *SendNBatchError + if errors.As(err, &sbe) { + assert.Equal(t, 1, len(sbe.Info)) + assert.True(t, indexIsPresent(sbe.Info, 1)) + } else { + t.Error("unexpected error type") + } + + assert.Equal(t, maxSizeSingleMessage, awsCmdReceiveMessage()) + assert.Equal(t, 0, awsCmdQueueCount()) + + // ACTION + smallMessageText := "small" + smallMessageCount := 21 + smallMessages := make([]string, smallMessageCount) + for i := range smallMessageCount { + smallMessages[i] = smallMessageText + } + batchesSent, err = client.SendNBatch(context.TODO(), awsCmdQueueURL(), smallMessages) + + // ASSERT + assert.Nil(t, err) + assert.Equal(t, 3, batchesSent) + + receiveCount := 0 + for range batchesSent { + messages := awsCmdReceiveMessages() + for _, m := range messages { + if m == smallMessageText { + receiveCount++ + } + } + } + assert.Equal(t, smallMessageCount, receiveCount) + + // ACTION + invalidMessage := "\u0000" + + batchesSent, err = client.SendNBatch(context.TODO(), awsCmdQueueURL(), []string{ + tooLargeSingleMessage, + maxSizeSingleMessage, + smallMessageText, + maxSizeSingleMessage, + invalidMessage, + smallMessageText, + smallMessageText, + invalidMessage, + tooLargeSingleMessage, + }) + + // ASSERT + assert.NotNil(t, err) + assert.Equal(t, 4, batchesSent) + + sbe = nil + if errors.As(err, &sbe) { + assert.Equal(t, 4, len(sbe.Info)) + assert.True(t, indexIsPresent(sbe.Info, 0)) + assert.True(t, indexIsPresent(sbe.Info, 4)) + assert.True(t, indexIsPresent(sbe.Info, 7)) + assert.True(t, indexIsPresent(sbe.Info, 8)) + } else { + t.Error("unexpected error type") + } + + assert.Equal(t, 5, len(awsCmdReceiveMessages())) +} + +func indexIsPresent(info []SendBatchErrorEntry, index int) bool { + for _, entry := range info { + if entry.Index == index { + return true + } + } + return false +} + +func TestDeleteBatch(t *testing.T) { + // ARRANGE + setup() + defer teardown() + + client, err := New() + require.Nil(t, err, fmt.Sprintf("error creating sqs client: %v", err)) + + // Send and receive messages to get receipt handles + messages := []string{"message1", "message2", "message3"} + err = client.SendBatch(context.TODO(), awsCmdQueueURL(), messages) + require.Nil(t, err) + require.Equal(t, 3, awsCmdQueueCount()) + require.Equal(t, 0, awsCmdQueueInFlightCount()) + + receivedMessages, err := client.ReceiveBatch(context.TODO(), awsCmdQueueURL(), 30) + require.Nil(t, err) + require.Equal(t, 3, len(receivedMessages)) + require.Equal(t, 3, awsCmdQueueInFlightCount()) + receiptHandles := make([]string, 0) + for _, rm := range receivedMessages { + receiptHandles = append(receiptHandles, rm.ReceiptHandle) + } + + // ACTION + err = client.DeleteBatch(context.TODO(), awsCmdQueueURL(), receiptHandles) + + // ASSERT + assert.Nil(t, err) + assert.Equal(t, 0, awsCmdQueueCount()) + assert.Equal(t, 0, awsCmdQueueInFlightCount()) + + // ACTION + invalidReceiptHandle := "invalid-receipt-handle" + err = client.DeleteBatch(context.TODO(), awsCmdQueueURL(), []string{invalidReceiptHandle}) + + // ASSERT + assert.NotNil(t, err) + + var dbe *DeleteBatchError + if errors.As(err, &dbe) { + assert.Equal(t, 1, len(dbe.Info)) + assert.Equal(t, 0, dbe.Info[0].Index) + } else { + t.Error("unexpected error type") + } +} + +func TestDeleteNBatch(t *testing.T) { + // ARRANGE + setup() + defer teardown() + + client, err := New() + require.Nil(t, err, fmt.Sprintf("error creating sqs client: %v", err)) + + // Send and receive messages to get receipt handles + messages := []string{"msg1", "msg2", "msg3", "msg4", "msg5", "msg6", "msg7", "msg8", "msg9", "msg10", "msg11"} + batchesSent, err := client.SendNBatch(context.TODO(), awsCmdQueueURL(), messages) + + require.Nil(t, err) + require.Equal(t, 2, batchesSent) + + receivedMessages1, err := client.ReceiveBatch(context.TODO(), awsCmdQueueURL(), 30) + require.Nil(t, err) + require.Equal(t, 10, len(receivedMessages1)) + receivedMessages2, err := client.ReceiveBatch(context.TODO(), awsCmdQueueURL(), 30) + require.Nil(t, err) + require.Equal(t, 1, len(receivedMessages2)) + + require.Equal(t, 11, awsCmdQueueInFlightCount()) + + receiptHandles := make([]string, 0) + for _, rm := range receivedMessages1 { + receiptHandles = append(receiptHandles, rm.ReceiptHandle) + } + for _, rm := range receivedMessages2 { + receiptHandles = append(receiptHandles, rm.ReceiptHandle) + } + + // ACTION + batchesDeleted, err := client.DeleteNBatch(context.TODO(), awsCmdQueueURL(), receiptHandles) + + // ASSERT + assert.Nil(t, err) + assert.Equal(t, 2, batchesDeleted) + assert.Equal(t, 0, awsCmdQueueCount()) + assert.Equal(t, 0, awsCmdQueueInFlightCount()) + + // ARRANGE + + // Send and receive messages to get receipt handles + messages = []string{"msg1", "msg2", "msg3", "msg4", "msg5", "msg6", "msg7", "msg8", "msg9", "msg10", "msg11", "msg12"} + batchesSent, err = client.SendNBatch(context.TODO(), awsCmdQueueURL(), messages) + + require.Nil(t, err) + require.Equal(t, 2, batchesSent) + + receivedMessages1, err = client.ReceiveBatch(context.TODO(), awsCmdQueueURL(), 30) + require.Nil(t, err) + require.Equal(t, 10, len(receivedMessages1)) + receivedMessages2, err = client.ReceiveBatch(context.TODO(), awsCmdQueueURL(), 30) + require.Nil(t, err) + require.Equal(t, 2, len(receivedMessages2)) + require.Equal(t, 12, awsCmdQueueInFlightCount()) + + receiptHandles = make([]string, 0) + for _, rm := range receivedMessages1 { + receiptHandles = append(receiptHandles, rm.ReceiptHandle) + } + for _, rm := range receivedMessages2 { + receiptHandles = append(receiptHandles, rm.ReceiptHandle) + } + invalidReceiptHandle := "invalid-receipt-handle" + receiptHandles[0] = invalidReceiptHandle // Replace a valid receipt handle with an invalid one. + receiptHandles = append(receiptHandles, invalidReceiptHandle) // Append an invalid receipt handle (index 12) + + // ACTION + batchesDeleted, err = client.DeleteNBatch(context.TODO(), awsCmdQueueURL(), receiptHandles) + assert.NotNil(t, err) + + var dbe *DeleteNBatchError + if errors.As(err, &dbe) { + assert.Equal(t, 2, len(dbe.Info)) + assert.Equal(t, 0, dbe.Info[0].Index) + assert.Equal(t, 12, dbe.Info[1].Index) + } else { + t.Error("unexpected error type") + } + assert.Equal(t, 2, batchesDeleted) + assert.Equal(t, 0, awsCmdQueueCount()) + assert.Equal(t, 1, awsCmdQueueInFlightCount()) +} + func TestGetQueueUrl(t *testing.T) { // ARRANGE setup()