Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
331 changes: 299 additions & 32 deletions aws/sqs/sqs.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,34 +127,40 @@ func (s *SQS) ReceiveWithContextAttributes(ctx context.Context, queueURL string,
WaitTimeSeconds: 20,
AttributeNames: attrs,
}
return s.receiveMessage(ctx, &input)
msgs, err := s.receiveMessages(ctx, &input)
if err != nil {
return Raw{}, err
}
return msgs[0], err
}

// receiveMessage is the common code used internally to receive an SQS message based
// receiveMessages is the common code used internally to receive an SQS messages based
// on the provided input.
func (s *SQS) receiveMessage(ctx context.Context, input *sqs.ReceiveMessageInput) (Raw, error) {
func (s *SQS) receiveMessages(ctx context.Context, input *sqs.ReceiveMessageInput) ([]Raw, error) {
r, err := s.client.ReceiveMessage(ctx, input)
if err != nil {
return Raw{}, err
return []Raw{}, err
}

switch {
case r == nil || len(r.Messages) == 0:
// no message received
return Raw{}, ErrNoMessages
return []Raw{}, ErrNoMessages

case len(r.Messages) == 1:
raw := r.Messages[0]
case len(r.Messages) >= 1:

m := Raw{
Body: aws.ToString(raw.Body),
ReceiptHandle: aws.ToString(raw.ReceiptHandle),
Attributes: raw.Attributes,
messages := make([]Raw, len(r.Messages))
for i := range r.Messages {
messages[i] = Raw{
Body: aws.ToString(r.Messages[i].Body),
ReceiptHandle: aws.ToString(r.Messages[i].ReceiptHandle),
Attributes: r.Messages[i].Attributes,
}
}
return m, nil
return messages, nil

default:
return Raw{}, fmt.Errorf("received unexpected messages: %d", len(r.Messages))
return []Raw{}, fmt.Errorf("received unexpected number of messages: %d", len(r.Messages)) // Probably an impossible case
}
}

Expand All @@ -169,7 +175,28 @@ func (s *SQS) ReceiveWithContext(ctx context.Context, queueURL string, visibilit
VisibilityTimeout: visibilityTimeout,
WaitTimeSeconds: 20,
}
return s.receiveMessage(ctx, &input)
msgs, err := s.receiveMessages(ctx, &input)
if err != nil {
return Raw{}, err
}
return msgs[0], err
}

// ReceiveBatch is similar to Receive, however it can return up to 10 messages.
func (s *SQS) ReceiveBatch(ctx context.Context, queueURL string, visibilityTimeout int32) ([]Raw, error) {

input := sqs.ReceiveMessageInput{
QueueUrl: aws.String(queueURL),
MaxNumberOfMessages: 10,
VisibilityTimeout: visibilityTimeout,
WaitTimeSeconds: 20,
}

msgs, err := s.receiveMessages(ctx, &input)
if err != nil {
return []Raw{}, err
}
return msgs, nil
}

// Delete deletes the message referred to by receiptHandle from the queue.
Expand Down Expand Up @@ -231,44 +258,284 @@ func (s *SQS) SendFifoMessage(queue, group, dedupe string, msg []byte) (string,
return "", nil
}

// Leverage the sendbatch api for uploading large numbers of messages
func (s *SQS) SendBatch(ctx context.Context, queueURL string, bodies []string) error {
if len(bodies) > 11 {
return errors.New("too many messages to batch")
type SendBatchError struct {
Err error
Info []SendBatchErrorEntry
}
type SendBatchErrorEntry struct {
Entry types.BatchResultErrorEntry
Index int
}

func (s *SendBatchError) Error() string {
return fmt.Sprintf("%v: %v messages failed to send", s.Err, len(s.Info))
}
func (s *SendBatchError) Unwrap() error {
return s.Err
}

type SendNBatchError struct {
Errors []error
Info []SendBatchErrorEntry
}

func (s *SendNBatchError) Error() string {
var allErrors string
for _, err := range s.Errors {
allErrors += fmt.Sprintf("%s,", err.Error())
}
allErrors = strings.TrimSuffix(allErrors, ",")
return fmt.Sprintf("%v error(s) sending batches: %s", len(s.Errors), allErrors)
}

// SendBatch sends up to 10 messages to a given SQS queue with one API call.
// If an error occurs on any or all messages, a SendBatchError is returned that lets
// the caller know the index of the message/s in bodies that failed.
func (s *SQS) SendBatch(ctx context.Context, queueURL string, bodies []string) error {

var err error
entries := make([]types.SendMessageBatchRequestEntry, len(bodies))
for j, body := range bodies {
entries[j] = types.SendMessageBatchRequestEntry{
Id: aws.String(fmt.Sprintf("gamitjob%d", j)),
Id: aws.String(fmt.Sprintf("message-%d", j)),
MessageBody: aws.String(body),
}
}
_, err = s.client.SendMessageBatch(ctx, &sqs.SendMessageBatchInput{
output, err := s.client.SendMessageBatch(ctx, &sqs.SendMessageBatchInput{
Entries: entries,
QueueUrl: &queueURL,
})
return err
if err != nil {
info := make([]SendBatchErrorEntry, len(entries))
for i := range entries {
info[i] = SendBatchErrorEntry{
Index: i,
}
}
return &SendBatchError{Err: err, Info: info}
}
if len(output.Failed) > 0 {
info := make([]SendBatchErrorEntry, len(output.Failed))
for i, entry := range output.Failed {
for j, msg := range entries {
if aws.ToString(msg.Id) == aws.ToString(entry.Id) {
info[i] = SendBatchErrorEntry{
Entry: entry,
Index: j,
}
break
}
}
}
return &SendBatchError{Err: errors.New("partial message failure"), Info: info}
}
return nil
}

func (s *SQS) SendNBatch(ctx context.Context, queueURL string, bodies []string) error {
// SendNBatch sends any number of messages to a given SQS queue via a series of SendBatch calls.
// If an error occurs on any or all messages, a SendNBatchError is returned that lets
// the caller know the index of the message/s in bodies that failed.
// Returns the number of API calls to SendBatch made.
func (s *SQS) SendNBatch(ctx context.Context, queueURL string, bodies []string) (int, error) {

const (
maxCount = 10
maxSize = 262144 // 256 KiB
)

allErrors := make([]error, 0)
allInfo := make([]SendBatchErrorEntry, 0)

batchesSent := 0

batch := make([]int, 0)
totalSize := 0

sendBatch := func() {
batchBodies := make([]string, len(batch))

for i, batchIndex := range batch {
batchBodies[i] = bodies[batchIndex]
}

err := s.SendBatch(ctx, queueURL, batchBodies)
var sbe *SendBatchError
if errors.As(err, &sbe) {
allErrors = append(allErrors, err)

// Update index so that index refers to the position in given bodies slice.
for i := range sbe.Info {
sbe.Info[i].Index = batch[sbe.Info[i].Index]
}

allInfo = append(allInfo, sbe.Info...)
}

batchesSent++
batch = batch[:0]
totalSize = 0
}

for i, body := range bodies {

// Check if any single message is too big
if len(body) > maxSize {
allErrors = append(allErrors, errors.New("message too big to send"))
allInfo = append(allInfo, SendBatchErrorEntry{
Index: i,
})
continue
}
// If adding the current message would exceed the batch max size or count, send the current batch.
if totalSize+len(body) > maxSize || len(batch) == maxCount {
sendBatch()
}
batch = append(batch, i)
totalSize += len(body)
}

if len(batch) > 0 {
sendBatch()
}

if len(allErrors) > 0 {
return batchesSent, &SendNBatchError{
Errors: allErrors,
Info: allInfo,
}
}

return batchesSent, nil
}

type DeleteBatchError struct {
Err error
Info []DeleteBatchErrorEntry
}

type DeleteBatchErrorEntry struct {
Entry types.BatchResultErrorEntry
Index int
}

func (d *DeleteBatchError) Error() string {
return fmt.Sprintf("%v: %v messages failed to delete", d.Err, len(d.Info))
}

func (d *DeleteBatchError) Unwrap() error {
return d.Err
}

type DeleteNBatchError struct {
Errors []error
Info []DeleteBatchErrorEntry
}

func (s *DeleteNBatchError) Error() string {
var allErrors string
for _, err := range s.Errors {
allErrors += fmt.Sprintf("%s,", err.Error())
}
allErrors = strings.TrimSuffix(allErrors, ",")
return fmt.Sprintf("%v error(s) deleting batches: %s", len(s.Errors), allErrors)
}

// DeleteBatch deletes up to 10 messages from an SQS queue in a single batch.
// If an error occurs on any or all messages, a DeleteBatchError is returned that lets
// the caller know the indice/s in receiptHandles that failed.
func (s *SQS) DeleteBatch(ctx context.Context, queueURL string, receiptHandles []string) error {
entries := make([]types.DeleteMessageBatchRequestEntry, len(receiptHandles))
for i, receipt := range receiptHandles {
entries[i] = types.DeleteMessageBatchRequestEntry{
Id: aws.String(fmt.Sprintf("delete-message-%d", i)),
ReceiptHandle: aws.String(receipt),
}
}

output, err := s.client.DeleteMessageBatch(ctx, &sqs.DeleteMessageBatchInput{
Entries: entries,
QueueUrl: &queueURL,
})
if err != nil {
info := make([]DeleteBatchErrorEntry, len(entries))
for i := range entries {
info[i] = DeleteBatchErrorEntry{
Index: i,
}
}
return &DeleteBatchError{Err: err, Info: info}
}
if len(output.Failed) > 0 {
info := make([]DeleteBatchErrorEntry, len(output.Failed))
for i, errorEntry := range output.Failed {
for j, requestEntry := range entries {
if aws.ToString(requestEntry.Id) == aws.ToString(errorEntry.Id) {
info[i] = DeleteBatchErrorEntry{
Entry: errorEntry,
Index: j,
}
break
}
}
}
return &DeleteBatchError{Info: info}
}
return nil
}

// DeleteNBatch deletes any number of messages from a given SQS queue via a series of DeleteBatch calls.
// If an error occurs on any or all messages, a DeleteNBatchError is returned that lets
// the caller know the receipt handles that failed.
// Returns the number of API calls to DeleteBatch made.
func (s *SQS) DeleteNBatch(ctx context.Context, queueURL string, receiptHandles []string) (int, error) {

var (
bodiesLen = len(bodies)
maxlen = 10
times = int(math.Ceil(float64(bodiesLen) / float64(maxlen)))
receiptCount = len(receiptHandles)
maxlen = 10
times = int(math.Ceil(float64(receiptCount) / float64(maxlen)))
)

allErrors := make([]error, 0)
allInfo := make([]DeleteBatchErrorEntry, 0)

batchesDeleted := 0

for i := 0; i < times; i++ {
batch_end := maxlen * (i + 1)
if maxlen*(i+1) > bodiesLen {
batch_end = bodiesLen
if maxlen*(i+1) > receiptCount {
batch_end = receiptCount
}
var bodies_batch = bodies[maxlen*i : batch_end]
err := s.SendBatch(ctx, queueURL, bodies_batch)
if err != nil {
return err
var receipt_batch = receiptHandles[maxlen*i : batch_end]

indexMap := make(map[int]int, 0)
count := 0
for j := maxlen * i; j < batch_end; j++ {
indexMap[count] = j
count++
}

err := s.DeleteBatch(ctx, queueURL, receipt_batch)
var dbe *DeleteBatchError
if errors.As(err, &dbe) {
allErrors = append(allErrors, err)

// Update index so that index refers to the position in given receiptHandles slice.
for i := range dbe.Info {
dbe.Info[i].Index = indexMap[dbe.Info[i].Index]
}

allInfo = append(allInfo, dbe.Info...)
}
batchesDeleted++
}
return nil

if len(allErrors) > 0 {
return batchesDeleted, &DeleteNBatchError{
Errors: allErrors,
Info: allInfo,
}
}
return batchesDeleted, nil
}

// GetQueueUrl returns an AWS SQS queue URL given its name.
Expand Down
Loading