diff --git a/example/natsmq/consumer/consumer.go b/example/natsmq/consumer/consumer.go index 12a5dee..3afcb0a 100644 --- a/example/natsmq/consumer/consumer.go +++ b/example/natsmq/consumer/consumer.go @@ -1,17 +1,19 @@ package main import ( + "context" "flag" - "github.com/nats-io/nats.go" - "github.com/nats-io/nats.go/jetstream" - "github.com/zeromicro/go-queue/natsmq/common" - "github.com/zeromicro/go-queue/natsmq/consumer" - "github.com/zeromicro/go-zero/core/conf" "log" "os" "os/signal" "syscall" "time" + + "github.com/nats-io/nats.go" + "github.com/nats-io/nats.go/jetstream" + "github.com/zeromicro/go-queue/natsmq/common" + "github.com/zeromicro/go-queue/natsmq/consumer" + "github.com/zeromicro/go-zero/core/conf" ) var configFile = flag.String("f", "config.yaml", "Specify the config file") @@ -28,7 +30,7 @@ type NatsConf struct { type MyConsumeHandler struct{} -func (h *MyConsumeHandler) Consume(msg jetstream.Msg) error { +func (h *MyConsumeHandler) Consume(ctx context.Context, msg jetstream.Msg) error { log.Printf("subject [%s] Received message: %s", msg.Subject(), string(msg.Data())) return nil } diff --git a/natsmq/consumer/config.go b/natsmq/consumer/config.go index a661066..5321a99 100644 --- a/natsmq/consumer/config.go +++ b/natsmq/consumer/config.go @@ -1,8 +1,10 @@ package consumer import ( - "github.com/nats-io/nats.go/jetstream" + "context" "time" + + "github.com/nats-io/nats.go/jetstream" ) // ConsumerConfig combines core consumer settings with advanced parameters. @@ -59,5 +61,5 @@ const ( // ConsumeHandler defines an interface for message processing. // Users need to implement the Consume method to handle individual messages. type ConsumeHandler interface { - Consume(msg jetstream.Msg) error + Consume(ctx context.Context, msg jetstream.Msg) error } diff --git a/natsmq/consumer/consumer.go b/natsmq/consumer/consumer.go index 701da1b..0c1334e 100644 --- a/natsmq/consumer/consumer.go +++ b/natsmq/consumer/consumer.go @@ -4,11 +4,15 @@ import ( "context" "errors" "fmt" + "log" + "github.com/nats-io/nats.go" "github.com/nats-io/nats.go/jetstream" "github.com/zeromicro/go-queue/natsmq/common" + "github.com/zeromicro/go-queue/natsmq/internal" + "github.com/zeromicro/go-zero/core/contextx" "github.com/zeromicro/go-zero/core/queue" - "log" + "go.opentelemetry.io/otel" ) // ConsumerManager manages consumer operations including NATS connection, JetStream stream initialization, @@ -205,7 +209,14 @@ func (cm *ConsumerManager) consumerSubscription(consumer jetstream.Consumer, cfg // cfg - pointer to ConsumerQueueConfig containing the message handler and acknowledgement settings // msg - the JetStream message to process func (cm *ConsumerManager) ackMessage(cfg *ConsumerQueueConfig, msg jetstream.Msg) { - if err := cfg.Handler.Consume(msg); err != nil { + headers := msg.Headers() + carrier := internal.NewHeaderCarrier(&headers) + // extract trace context from message + ctx := otel.GetTextMapPropagator().Extract(context.Background(), carrier) + // remove deadline and error control + ctx = contextx.ValueOnlyFrom(ctx) + + if err := cfg.Handler.Consume(ctx, msg); err != nil { log.Printf("message processing error: %v", err) return } diff --git a/natsmq/internal/trace.go b/natsmq/internal/trace.go new file mode 100644 index 0000000..320e6c2 --- /dev/null +++ b/natsmq/internal/trace.go @@ -0,0 +1,49 @@ +package internal + +import ( + "github.com/nats-io/nats.go" + "go.opentelemetry.io/otel/propagation" +) + +var _ propagation.TextMapCarrier = (*HeaderCarrier)(nil) + +// HeaderCarrier injects and extracts traces from NATS headers. +type HeaderCarrier struct { + headers *nats.Header +} + +// NewHeaderCarrier returns a new HeaderCarrier. +func NewHeaderCarrier(headers *nats.Header) HeaderCarrier { + return HeaderCarrier{headers: headers} +} + +// Get returns the value associated with the passed key. +func (h HeaderCarrier) Get(key string) string { + if h.headers == nil || *h.headers == nil { + return "" + } + return (*h.headers).Get(key) +} + +// Set stores the key-value pair. +func (h HeaderCarrier) Set(key string, value string) { + if h.headers == nil { + return + } + if *h.headers == nil { + *h.headers = nats.Header{} + } + (*h.headers).Set(key, value) +} + +// Keys lists the keys stored in this carrier. +func (h HeaderCarrier) Keys() []string { + if h.headers == nil || *h.headers == nil { + return []string{} + } + out := make([]string, 0, len(*h.headers)) + for key := range *h.headers { + out = append(out, key) + } + return out +} diff --git a/natsmq/internal/trace_test.go b/natsmq/internal/trace_test.go new file mode 100644 index 0000000..1426e02 --- /dev/null +++ b/natsmq/internal/trace_test.go @@ -0,0 +1,99 @@ +package internal + +import ( + "testing" + + "github.com/nats-io/nats.go" + "github.com/stretchr/testify/assert" +) + +func TestHeaderCarrierGet(t *testing.T) { + testCases := []struct { + name string + carrier HeaderCarrier + key string + expected string + }{ + { + name: "exists", + carrier: NewHeaderCarrier(&nats.Header{ + "foo": []string{"bar"}, + }), + key: "foo", + expected: "bar", + }, + { + name: "not exists", + carrier: NewHeaderCarrier(&nats.Header{}), + key: "foo", + expected: "", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + result := tc.carrier.Get(tc.key) + assert.Equal(t, tc.expected, result) + }) + } +} + +func TestHeaderCarrierWithNilPointer(t *testing.T) { + carrier := NewHeaderCarrier(nil) + + assert.Equal(t, "", carrier.Get("foo")) + + carrier.Set("foo", "bar") + + assert.Equal(t, []string{}, carrier.Keys()) +} + +func TestHeaderCarrierSet(t *testing.T) { + var headers nats.Header + carrier := NewHeaderCarrier(&headers) + + carrier.Set("foo", "bar") + carrier.Set("foo", "bar2") + carrier.Set("foo2", "bar3") + + assert.Equal(t, nats.Header{ + "foo": []string{"bar2"}, + "foo2": []string{"bar3"}, + }, headers) +} + +func TestHeaderCarrierKeys(t *testing.T) { + testCases := []struct { + name string + carrier HeaderCarrier + expected []string + }{ + { + name: "one", + carrier: NewHeaderCarrier(&nats.Header{ + "foo": []string{"bar"}, + }), + expected: []string{"foo"}, + }, + { + name: "none", + carrier: NewHeaderCarrier(&nats.Header{}), + expected: []string{}, + }, + { + name: "many", + carrier: NewHeaderCarrier(&nats.Header{ + "foo": []string{"bar"}, + "baz": []string{"quux"}, + }), + expected: []string{"foo", "baz"}, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + result := tc.carrier.Keys() + assert.ElementsMatch(t, tc.expected, result) + }) + } +} diff --git a/natsmq/publisher/publisher.go b/natsmq/publisher/publisher.go index d3fe201..613e0cb 100644 --- a/natsmq/publisher/publisher.go +++ b/natsmq/publisher/publisher.go @@ -3,10 +3,13 @@ package publisher import ( "context" "fmt" + "log" + "github.com/nats-io/nats.go" "github.com/nats-io/nats.go/jetstream" "github.com/zeromicro/go-queue/natsmq/common" - "log" + "github.com/zeromicro/go-queue/natsmq/internal" + "go.opentelemetry.io/otel" ) // JetStreamPublisher implements the Publisher interface by utilizing an internal JetStream context for message publishing. @@ -68,13 +71,46 @@ func (p *JetStreamPublisher) initJetStream() error { // Publish synchronously publishes a message to the specified subject and waits for a server acknowledgment. func (p *JetStreamPublisher) Publish(ctx context.Context, subject string, payload []byte) (*jetstream.PubAck, error) { - ack, err := p.js.Publish(ctx, subject, payload) + ack, err := p.PublishWithHeaders(ctx, subject, payload, nil) if err != nil { return nil, fmt.Errorf("failed to publish message on subject %s: %w", subject, err) } return ack, nil } +// PublishWithHeaders publishes a message with optional headers and waits for a server acknowledgment. +func (p *JetStreamPublisher) PublishWithHeaders(ctx context.Context, subject string, payload []byte, headers map[string]string) (*jetstream.PubAck, error) { + msg := &nats.Msg{ + Subject: subject, + Data: payload, + } + if len(headers) > 0 { + msg.Header = make(nats.Header, len(headers)) + for key, value := range headers { + msg.Header.Set(key, value) + } + } + + return p.PublishMsg(ctx, msg) +} + +// PublishMsg publishes a full NATS message and waits for a server acknowledgment. +func (p *JetStreamPublisher) PublishMsg(ctx context.Context, msg *nats.Msg) (*jetstream.PubAck, error) { + if msg == nil { + return nil, fmt.Errorf("message is nil") + } + + // inject trace context into message headers + mc := internal.NewHeaderCarrier(&msg.Header) + otel.GetTextMapPropagator().Inject(ctx, mc) + + ack, err := p.js.PublishMsg(ctx, msg) + if err != nil { + return nil, fmt.Errorf("failed to publish message on subject %s: %w", msg.Subject, err) + } + return ack, nil +} + // Close terminates the NATS connection used by the JetStreamPublisher and releases all associated resources. func (p *JetStreamPublisher) Close() { if p.conn != nil {