diff --git a/go/core/x/streaming/streaming.go b/go/core/x/streaming/streaming.go new file mode 100644 index 0000000000..26f0c72444 --- /dev/null +++ b/go/core/x/streaming/streaming.go @@ -0,0 +1,382 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +// Package streaming provides experimental durable streaming APIs for Genkit. +// +// APIs in this package are under active development and may change in any +// minor version release. Use with caution in production environments. +// +// When these APIs stabilize, they will be moved to their parent packages +// (e.g., core and genkit) and these exports will be deprecated. +package streaming + +import ( + "context" + "encoding/json" + "sync" + "time" + + "github.com/firebase/genkit/go/core" +) + +// StreamEventType indicates the type of stream event. +type StreamEventType int + +const ( + StreamEventChunk StreamEventType = iota + StreamEventDone + StreamEventError +) + +// StreamEvent represents an event in a durable stream. +type StreamEvent struct { + Type StreamEventType + Chunk json.RawMessage // set when Type == StreamEventChunk + Output json.RawMessage // set when Type == StreamEventDone + Err error // set when Type == StreamEventError +} + +// StreamInput provides methods for writing to a durable stream. +type StreamInput interface { + // Write sends a chunk to the stream and notifies all subscribers. + Write(chunk json.RawMessage) error + // Done marks the stream as successfully completed with the given output. + Done(output json.RawMessage) error + // Error marks the stream as failed with the given error. + Error(err error) error + // Close releases resources without marking the stream as done or errored. + Close() error +} + +// StreamManager manages durable streams, allowing creation and subscription. +// Implementations can provide different storage backends (e.g., in-memory, database, cache). +type StreamManager interface { + // Open creates a new stream for writing. + // Returns an error if a stream with the given ID already exists. + Open(ctx context.Context, streamID string) (StreamInput, error) + // Subscribe subscribes to an existing stream. + // Returns a channel that receives stream events, an unsubscribe function, and an error. + // If the stream has already completed, all buffered events are sent before the done/error event. + // Returns NOT_FOUND error if the stream doesn't exist. + Subscribe(ctx context.Context, streamID string) (<-chan StreamEvent, func(), error) +} + +// inMemoryStreamBufferSize is the buffer size for subscriber event channels. +const inMemoryStreamBufferSize = 100 + +// streamStatus represents the current state of a stream. +type streamStatus int + +const ( + streamStatusOpen streamStatus = iota + streamStatusDone + streamStatusError +) + +// streamState holds the internal state of a single stream. 
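+// All fields are guarded by mu. Chunks are retained for the lifetime of the
+// stream so that late subscribers can replay the full history before they
+// start receiving live events.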
+type streamState struct { + status streamStatus + chunks []json.RawMessage + output json.RawMessage + err error + subscribers []chan StreamEvent + lastTouched time.Time + mu sync.RWMutex +} + +// InMemoryStreamManager is an in-memory implementation of StreamManager. +// Useful for testing or single-instance deployments where persistence is not required. +// Call Close to stop the background cleanup goroutine when the manager is no longer needed. +type InMemoryStreamManager struct { + streams map[string]*streamState + mu sync.RWMutex + ttl time.Duration + stopCh chan struct{} + doneCh chan struct{} +} + +// StreamManagerOption configures an InMemoryStreamManager. +type StreamManagerOption interface { + applyInMemoryStreamManager(*streamManagerOptions) +} + +// streamManagerOptions holds configuration for InMemoryStreamManager. +type streamManagerOptions struct { + TTL time.Duration // Time-to-live for completed streams. +} + +func (o *streamManagerOptions) applyInMemoryStreamManager(opts *streamManagerOptions) { + if o.TTL > 0 { + opts.TTL = o.TTL + } +} + +// WithTTL sets the time-to-live for completed streams. +// Streams that have completed (done or error) will be cleaned up after this duration. +// Default is 5 minutes. +func WithTTL(ttl time.Duration) StreamManagerOption { + return &streamManagerOptions{TTL: ttl} +} + +// NewInMemoryStreamManager creates a new InMemoryStreamManager. +// A background goroutine is started to periodically clean up expired streams. +// Call Close to stop the goroutine when the manager is no longer needed. +func NewInMemoryStreamManager(opts ...StreamManagerOption) *InMemoryStreamManager { + options := &streamManagerOptions{ + TTL: 5 * time.Minute, + } + for _, opt := range opts { + opt.applyInMemoryStreamManager(options) + } + m := &InMemoryStreamManager{ + streams: make(map[string]*streamState), + ttl: options.TTL, + stopCh: make(chan struct{}), + doneCh: make(chan struct{}), + } + go m.cleanupLoop() + return m +} + +// cleanupLoop runs periodically to remove expired streams. +func (m *InMemoryStreamManager) cleanupLoop() { + ticker := time.NewTicker(time.Minute) + defer ticker.Stop() + defer close(m.doneCh) + + for { + select { + case <-m.stopCh: + return + case <-ticker.C: + m.cleanupExpiredStreams() + } + } +} + +// cleanupExpiredStreams removes streams that have completed and exceeded the TTL. +func (m *InMemoryStreamManager) cleanupExpiredStreams() { + now := time.Now() + m.mu.Lock() + defer m.mu.Unlock() + + for id, state := range m.streams { + state.mu.RLock() + shouldDelete := state.status != streamStatusOpen && now.Sub(state.lastTouched) > m.ttl + state.mu.RUnlock() + if shouldDelete { + delete(m.streams, id) + } + } +} + +// Close stops the background cleanup goroutine and releases resources. +// This method blocks until the cleanup goroutine has stopped. +func (m *InMemoryStreamManager) Close() { + close(m.stopCh) + <-m.doneCh +} + +// Open creates a new stream for writing. 
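+// The returned StreamInput should be completed with Done or Error; Close only
+// releases the writer and leaves the stream open, so the stream is not
+// reclaimed by the TTL cleanup (only completed streams expire). A minimal
+// writer-side sketch, with a hypothetical stream ID and error handling elided:
+//
+//	in, _ := m.Open(ctx, "job-123")
+//	_ = in.Write(json.RawMessage(`"partial"`))
+//	_ = in.Done(json.RawMessage(`"final result"`))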
+func (m *InMemoryStreamManager) Open(ctx context.Context, streamID string) (StreamInput, error) { + m.mu.Lock() + defer m.mu.Unlock() + + if _, exists := m.streams[streamID]; exists { + return nil, core.NewPublicError(core.ALREADY_EXISTS, "stream already exists", nil) + } + + state := &streamState{ + status: streamStatusOpen, + chunks: make([]json.RawMessage, 0), + subscribers: make([]chan StreamEvent, 0), + lastTouched: time.Now(), + } + m.streams[streamID] = state + + return &inMemoryStreamInput{ + manager: m, + streamID: streamID, + state: state, + }, nil +} + +// Subscribe subscribes to an existing stream. +func (m *InMemoryStreamManager) Subscribe(ctx context.Context, streamID string) (<-chan StreamEvent, func(), error) { + m.mu.RLock() + state, exists := m.streams[streamID] + m.mu.RUnlock() + + if !exists { + return nil, nil, core.NewPublicError(core.NOT_FOUND, "stream not found", nil) + } + + ch := make(chan StreamEvent, inMemoryStreamBufferSize) + + state.mu.Lock() + defer state.mu.Unlock() + + // Send all buffered chunks + for _, chunk := range state.chunks { + select { + case ch <- StreamEvent{Type: StreamEventChunk, Chunk: chunk}: + case <-ctx.Done(): + close(ch) + return nil, nil, ctx.Err() + } + } + + // Handle completed streams + switch state.status { + case streamStatusDone: + ch <- StreamEvent{Type: StreamEventDone, Output: state.output} + close(ch) + return ch, func() {}, nil + case streamStatusError: + ch <- StreamEvent{Type: StreamEventError, Err: state.err} + close(ch) + return ch, func() {}, nil + } + + // Stream is still open, add subscriber + state.subscribers = append(state.subscribers, ch) + + unsubscribe := func() { + state.mu.Lock() + defer state.mu.Unlock() + for i, sub := range state.subscribers { + if sub == ch { + state.subscribers = append(state.subscribers[:i], state.subscribers[i+1:]...) + close(ch) + break + } + } + } + + return ch, unsubscribe, nil +} + +// inMemoryStreamInput implements ActionStreamInput for the in-memory manager. 
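+// It is the StreamInput returned by InMemoryStreamManager.Open. Each Write
+// appends the chunk to the stream's replay buffer and fans it out to the
+// current subscribers; a subscriber whose channel is full misses the live
+// delivery but can recover the chunk by resubscribing.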
+type inMemoryStreamInput struct { + manager *InMemoryStreamManager + streamID string + state *streamState + closed bool + mu sync.Mutex +} + +func (s *inMemoryStreamInput) Write(chunk json.RawMessage) error { + s.mu.Lock() + if s.closed { + s.mu.Unlock() + return core.NewPublicError(core.FAILED_PRECONDITION, "stream writer is closed", nil) + } + s.mu.Unlock() + + s.state.mu.Lock() + defer s.state.mu.Unlock() + + if s.state.status != streamStatusOpen { + return core.NewPublicError(core.FAILED_PRECONDITION, "stream has already completed", nil) + } + + s.state.chunks = append(s.state.chunks, chunk) + s.state.lastTouched = time.Now() + + event := StreamEvent{Type: StreamEventChunk, Chunk: chunk} + for _, ch := range s.state.subscribers { + select { + case ch <- event: + default: + // Channel full, skip (subscriber is slow) + } + } + + return nil +} + +func (s *inMemoryStreamInput) Done(output json.RawMessage) error { + s.mu.Lock() + if s.closed { + s.mu.Unlock() + return core.NewPublicError(core.FAILED_PRECONDITION, "stream writer is closed", nil) + } + s.closed = true + s.mu.Unlock() + + s.state.mu.Lock() + defer s.state.mu.Unlock() + + if s.state.status != streamStatusOpen { + return core.NewPublicError(core.FAILED_PRECONDITION, "stream has already completed", nil) + } + + s.state.status = streamStatusDone + s.state.output = output + s.state.lastTouched = time.Now() + + event := StreamEvent{Type: StreamEventDone, Output: output} + for _, ch := range s.state.subscribers { + select { + case ch <- event: + default: + } + close(ch) + } + s.state.subscribers = nil + + return nil +} + +func (s *inMemoryStreamInput) Error(err error) error { + s.mu.Lock() + if s.closed { + s.mu.Unlock() + return core.NewPublicError(core.FAILED_PRECONDITION, "stream writer is closed", nil) + } + s.closed = true + s.mu.Unlock() + + s.state.mu.Lock() + defer s.state.mu.Unlock() + + if s.state.status != streamStatusOpen { + return core.NewPublicError(core.FAILED_PRECONDITION, "stream has already completed", nil) + } + + s.state.status = streamStatusError + s.state.err = err + s.state.lastTouched = time.Now() + + event := StreamEvent{Type: StreamEventError, Err: err} + for _, ch := range s.state.subscribers { + select { + case ch <- event: + default: + } + close(ch) + } + s.state.subscribers = nil + + return nil +} + +func (s *inMemoryStreamInput) Close() error { + s.mu.Lock() + defer s.mu.Unlock() + s.closed = true + return nil +} diff --git a/go/core/x/streaming/streaming_test.go b/go/core/x/streaming/streaming_test.go new file mode 100644 index 0000000000..884c56b9d6 --- /dev/null +++ b/go/core/x/streaming/streaming_test.go @@ -0,0 +1,789 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +// SPDX-License-Identifier: Apache-2.0 + +package streaming + +import ( + "context" + "encoding/json" + "errors" + "sync" + "testing" + "time" + + "github.com/firebase/genkit/go/core" +) + +func TestInMemoryStreamManager_OpenAndSubscribe(t *testing.T) { + m := NewInMemoryStreamManager() + defer m.Close() + + ctx := context.Background() + streamID := "test-stream-1" + + // Open a new stream + writer, err := m.Open(ctx, streamID) + if err != nil { + t.Fatalf("Open failed: %v", err) + } + if writer == nil { + t.Fatal("Open returned nil writer") + } + + // Subscribe to the stream + events, unsubscribe, err := m.Subscribe(ctx, streamID) + if err != nil { + t.Fatalf("Subscribe failed: %v", err) + } + defer unsubscribe() + + if events == nil { + t.Fatal("Subscribe returned nil channel") + } +} + +func TestInMemoryStreamManager_OpenDuplicateFails(t *testing.T) { + m := NewInMemoryStreamManager() + defer m.Close() + + ctx := context.Background() + streamID := "test-stream-dup" + + // Open first stream + _, err := m.Open(ctx, streamID) + if err != nil { + t.Fatalf("First Open failed: %v", err) + } + + // Try to open duplicate + _, err = m.Open(ctx, streamID) + if err == nil { + t.Fatal("Expected error when opening duplicate stream") + } + + var ufErr *core.UserFacingError + if !errors.As(err, &ufErr) { + t.Fatalf("Expected UserFacingError, got %T", err) + } + if ufErr.Status != core.ALREADY_EXISTS { + t.Errorf("Expected ALREADY_EXISTS status, got %v", ufErr.Status) + } +} + +func TestInMemoryStreamManager_SubscribeNonExistent(t *testing.T) { + m := NewInMemoryStreamManager() + defer m.Close() + + ctx := context.Background() + + _, _, err := m.Subscribe(ctx, "non-existent") + if err == nil { + t.Fatal("Expected error when subscribing to non-existent stream") + } + + var ufErr *core.UserFacingError + if !errors.As(err, &ufErr) { + t.Fatalf("Expected UserFacingError, got %T", err) + } + if ufErr.Status != core.NOT_FOUND { + t.Errorf("Expected NOT_FOUND status, got %v", ufErr.Status) + } +} + +func TestInMemoryStreamManager_WriteAndReceiveChunks(t *testing.T) { + m := NewInMemoryStreamManager() + defer m.Close() + + ctx := context.Background() + streamID := "test-stream-chunks" + + writer, err := m.Open(ctx, streamID) + if err != nil { + t.Fatalf("Open failed: %v", err) + } + + events, unsubscribe, err := m.Subscribe(ctx, streamID) + if err != nil { + t.Fatalf("Subscribe failed: %v", err) + } + defer unsubscribe() + + // Write chunks + chunks := []string{"chunk1", "chunk2", "chunk3"} + for _, chunk := range chunks { + if err := writer.Write(json.RawMessage(`"` + chunk + `"`)); err != nil { + t.Fatalf("Write failed: %v", err) + } + } + + // Read chunks + for i, expected := range chunks { + select { + case event := <-events: + if event.Type != StreamEventChunk { + t.Errorf("Expected chunk event, got %v", event.Type) + } + var got string + if err := json.Unmarshal(event.Chunk, &got); err != nil { + t.Fatalf("Failed to unmarshal chunk: %v", err) + } + if got != expected { + t.Errorf("Chunk %d: expected %q, got %q", i, expected, got) + } + case <-time.After(time.Second): + t.Fatalf("Timeout waiting for chunk %d", i) + } + } +} + +func TestInMemoryStreamManager_Done(t *testing.T) { + m := NewInMemoryStreamManager() + defer m.Close() + + ctx := context.Background() + streamID := "test-stream-done" + + writer, err := m.Open(ctx, streamID) + if err != nil { + t.Fatalf("Open failed: %v", err) + } + + events, unsubscribe, err := m.Subscribe(ctx, streamID) + if err != nil { + t.Fatalf("Subscribe failed: %v", 
err) + } + defer unsubscribe() + + // Write a chunk + if err := writer.Write(json.RawMessage(`"test-chunk"`)); err != nil { + t.Fatalf("Write failed: %v", err) + } + + // Mark as done + output := json.RawMessage(`{"result": "success"}`) + if err := writer.Done(output); err != nil { + t.Fatalf("Done failed: %v", err) + } + + // Should receive chunk then done + select { + case event := <-events: + if event.Type != StreamEventChunk { + t.Errorf("Expected chunk event first, got %v", event.Type) + } + case <-time.After(time.Second): + t.Fatal("Timeout waiting for chunk") + } + + select { + case event := <-events: + if event.Type != StreamEventDone { + t.Errorf("Expected done event, got %v", event.Type) + } + if string(event.Output) != string(output) { + t.Errorf("Expected output %s, got %s", output, event.Output) + } + case <-time.After(time.Second): + t.Fatal("Timeout waiting for done event") + } +} + +func TestInMemoryStreamManager_Error(t *testing.T) { + m := NewInMemoryStreamManager() + defer m.Close() + + ctx := context.Background() + streamID := "test-stream-error" + + writer, err := m.Open(ctx, streamID) + if err != nil { + t.Fatalf("Open failed: %v", err) + } + + events, unsubscribe, err := m.Subscribe(ctx, streamID) + if err != nil { + t.Fatalf("Subscribe failed: %v", err) + } + defer unsubscribe() + + // Mark as error + streamErr := core.NewPublicError(core.INTERNAL, "test error", nil) + if err := writer.Error(streamErr); err != nil { + t.Fatalf("Error failed: %v", err) + } + + select { + case event := <-events: + if event.Type != StreamEventError { + t.Errorf("Expected error event, got %v", event.Type) + } + if event.Err == nil { + t.Error("Expected error to be set") + } + case <-time.After(time.Second): + t.Fatal("Timeout waiting for error event") + } +} + +func TestInMemoryStreamManager_WriteAfterDone(t *testing.T) { + m := NewInMemoryStreamManager() + defer m.Close() + + ctx := context.Background() + streamID := "test-stream-write-after-done" + + writer, err := m.Open(ctx, streamID) + if err != nil { + t.Fatalf("Open failed: %v", err) + } + + if err := writer.Done(json.RawMessage(`"done"`)); err != nil { + t.Fatalf("Done failed: %v", err) + } + + // Try to write after done + err = writer.Write(json.RawMessage(`"chunk"`)) + if err == nil { + t.Fatal("Expected error when writing after done") + } + + var ufErr *core.UserFacingError + if !errors.As(err, &ufErr) { + t.Fatalf("Expected UserFacingError, got %T", err) + } + if ufErr.Status != core.FAILED_PRECONDITION { + t.Errorf("Expected FAILED_PRECONDITION status, got %v", ufErr.Status) + } +} + +func TestInMemoryStreamManager_WriteAfterClose(t *testing.T) { + m := NewInMemoryStreamManager() + defer m.Close() + + ctx := context.Background() + streamID := "test-stream-write-after-close" + + writer, err := m.Open(ctx, streamID) + if err != nil { + t.Fatalf("Open failed: %v", err) + } + + if err := writer.Close(); err != nil { + t.Fatalf("Close failed: %v", err) + } + + // Try to write after close + err = writer.Write(json.RawMessage(`"chunk"`)) + if err == nil { + t.Fatal("Expected error when writing after close") + } + + var ufErr *core.UserFacingError + if !errors.As(err, &ufErr) { + t.Fatalf("Expected UserFacingError, got %T", err) + } + if ufErr.Status != core.FAILED_PRECONDITION { + t.Errorf("Expected FAILED_PRECONDITION status, got %v", ufErr.Status) + } +} + +func TestInMemoryStreamManager_DoneAfterError(t *testing.T) { + m := NewInMemoryStreamManager() + defer m.Close() + + ctx := context.Background() + streamID := 
"test-stream-done-after-error" + + writer, err := m.Open(ctx, streamID) + if err != nil { + t.Fatalf("Open failed: %v", err) + } + + if err := writer.Error(core.NewPublicError(core.INTERNAL, "test", nil)); err != nil { + t.Fatalf("Error failed: %v", err) + } + + // Try to mark done after error + err = writer.Done(json.RawMessage(`"done"`)) + if err == nil { + t.Fatal("Expected error when calling Done after Error") + } +} + +func TestInMemoryStreamManager_MultipleSubscribers(t *testing.T) { + m := NewInMemoryStreamManager() + defer m.Close() + + ctx := context.Background() + streamID := "test-stream-multi-sub" + + writer, err := m.Open(ctx, streamID) + if err != nil { + t.Fatalf("Open failed: %v", err) + } + + // Create multiple subscribers + events1, unsub1, err := m.Subscribe(ctx, streamID) + if err != nil { + t.Fatalf("Subscribe 1 failed: %v", err) + } + defer unsub1() + + events2, unsub2, err := m.Subscribe(ctx, streamID) + if err != nil { + t.Fatalf("Subscribe 2 failed: %v", err) + } + defer unsub2() + + // Write a chunk + chunk := json.RawMessage(`"shared-chunk"`) + if err := writer.Write(chunk); err != nil { + t.Fatalf("Write failed: %v", err) + } + + // Both subscribers should receive the chunk + for i, events := range []<-chan StreamEvent{events1, events2} { + select { + case event := <-events: + if event.Type != StreamEventChunk { + t.Errorf("Subscriber %d: expected chunk event, got %v", i+1, event.Type) + } + if string(event.Chunk) != string(chunk) { + t.Errorf("Subscriber %d: expected chunk %s, got %s", i+1, chunk, event.Chunk) + } + case <-time.After(time.Second): + t.Fatalf("Subscriber %d: timeout waiting for chunk", i+1) + } + } +} + +func TestInMemoryStreamManager_LateSubscriberGetsBufferedChunks(t *testing.T) { + m := NewInMemoryStreamManager() + defer m.Close() + + ctx := context.Background() + streamID := "test-stream-late-sub" + + writer, err := m.Open(ctx, streamID) + if err != nil { + t.Fatalf("Open failed: %v", err) + } + + // Write chunks before any subscriber + chunks := []string{"early1", "early2"} + for _, chunk := range chunks { + if err := writer.Write(json.RawMessage(`"` + chunk + `"`)); err != nil { + t.Fatalf("Write failed: %v", err) + } + } + + // Late subscriber joins + events, unsubscribe, err := m.Subscribe(ctx, streamID) + if err != nil { + t.Fatalf("Subscribe failed: %v", err) + } + defer unsubscribe() + + // Should receive buffered chunks + for i, expected := range chunks { + select { + case event := <-events: + if event.Type != StreamEventChunk { + t.Errorf("Expected chunk event, got %v", event.Type) + } + var got string + if err := json.Unmarshal(event.Chunk, &got); err != nil { + t.Fatalf("Failed to unmarshal chunk: %v", err) + } + if got != expected { + t.Errorf("Chunk %d: expected %q, got %q", i, expected, got) + } + case <-time.After(time.Second): + t.Fatalf("Timeout waiting for buffered chunk %d", i) + } + } +} + +func TestInMemoryStreamManager_SubscribeToCompletedStream(t *testing.T) { + m := NewInMemoryStreamManager() + defer m.Close() + + ctx := context.Background() + streamID := "test-stream-completed" + + writer, err := m.Open(ctx, streamID) + if err != nil { + t.Fatalf("Open failed: %v", err) + } + + // Write and complete before subscribing + if err := writer.Write(json.RawMessage(`"chunk1"`)); err != nil { + t.Fatalf("Write failed: %v", err) + } + if err := writer.Write(json.RawMessage(`"chunk2"`)); err != nil { + t.Fatalf("Write failed: %v", err) + } + output := json.RawMessage(`{"final": true}`) + if err := writer.Done(output); err != 
nil { + t.Fatalf("Done failed: %v", err) + } + + // Subscribe after completion + events, unsubscribe, err := m.Subscribe(ctx, streamID) + if err != nil { + t.Fatalf("Subscribe failed: %v", err) + } + defer unsubscribe() + + // Should receive all buffered chunks + for i := 0; i < 2; i++ { + select { + case event := <-events: + if event.Type != StreamEventChunk { + t.Errorf("Expected chunk event %d, got %v", i, event.Type) + } + case <-time.After(time.Second): + t.Fatalf("Timeout waiting for chunk %d", i) + } + } + + // Should receive done event + select { + case event := <-events: + if event.Type != StreamEventDone { + t.Errorf("Expected done event, got %v", event.Type) + } + if string(event.Output) != string(output) { + t.Errorf("Expected output %s, got %s", output, event.Output) + } + case <-time.After(time.Second): + t.Fatal("Timeout waiting for done event") + } + + // Channel should be closed + select { + case _, ok := <-events: + if ok { + t.Error("Expected channel to be closed") + } + case <-time.After(100 * time.Millisecond): + t.Error("Channel not closed after done") + } +} + +func TestInMemoryStreamManager_SubscribeToErroredStream(t *testing.T) { + m := NewInMemoryStreamManager() + defer m.Close() + + ctx := context.Background() + streamID := "test-stream-errored" + + writer, err := m.Open(ctx, streamID) + if err != nil { + t.Fatalf("Open failed: %v", err) + } + + // Write and error before subscribing + if err := writer.Write(json.RawMessage(`"chunk1"`)); err != nil { + t.Fatalf("Write failed: %v", err) + } + streamErr := core.NewPublicError(core.INTERNAL, "test error", nil) + if err := writer.Error(streamErr); err != nil { + t.Fatalf("Error failed: %v", err) + } + + // Subscribe after error + events, unsubscribe, err := m.Subscribe(ctx, streamID) + if err != nil { + t.Fatalf("Subscribe failed: %v", err) + } + defer unsubscribe() + + // Should receive buffered chunk + select { + case event := <-events: + if event.Type != StreamEventChunk { + t.Errorf("Expected chunk event, got %v", event.Type) + } + case <-time.After(time.Second): + t.Fatal("Timeout waiting for chunk") + } + + // Should receive error event + select { + case event := <-events: + if event.Type != StreamEventError { + t.Errorf("Expected error event, got %v", event.Type) + } + case <-time.After(time.Second): + t.Fatal("Timeout waiting for error event") + } +} + +func TestInMemoryStreamManager_Unsubscribe(t *testing.T) { + m := NewInMemoryStreamManager() + defer m.Close() + + ctx := context.Background() + streamID := "test-stream-unsub" + + writer, err := m.Open(ctx, streamID) + if err != nil { + t.Fatalf("Open failed: %v", err) + } + + events, unsubscribe, err := m.Subscribe(ctx, streamID) + if err != nil { + t.Fatalf("Subscribe failed: %v", err) + } + + // Unsubscribe + unsubscribe() + + // Write a chunk - should not panic + if err := writer.Write(json.RawMessage(`"chunk"`)); err != nil { + t.Fatalf("Write failed: %v", err) + } + + // Events channel should be closed + select { + case _, ok := <-events: + if ok { + t.Error("Expected channel to be closed after unsubscribe") + } + case <-time.After(100 * time.Millisecond): + t.Error("Channel not closed after unsubscribe") + } +} + +func TestInMemoryStreamManager_WithTTL(t *testing.T) { + m := NewInMemoryStreamManager(WithTTL(10 * time.Millisecond)) + defer m.Close() + + if m.ttl != 10*time.Millisecond { + t.Errorf("Expected TTL 10ms, got %v", m.ttl) + } +} + +func TestInMemoryStreamManager_ConcurrentOperations(t *testing.T) { + m := NewInMemoryStreamManager() + defer 
m.Close() + + ctx := context.Background() + streamID := "test-stream-concurrent" + + writer, err := m.Open(ctx, streamID) + if err != nil { + t.Fatalf("Open failed: %v", err) + } + + const numSubscribers = 5 + const numChunks = 10 + + var wg sync.WaitGroup + errors := make(chan error, numSubscribers*numChunks) + + // Start subscribers + for i := 0; i < numSubscribers; i++ { + wg.Add(1) + go func(subID int) { + defer wg.Done() + + events, unsubscribe, err := m.Subscribe(ctx, streamID) + if err != nil { + errors <- err + return + } + defer unsubscribe() + + received := 0 + for event := range events { + if event.Type == StreamEventChunk { + received++ + } else if event.Type == StreamEventDone { + break + } + } + + if received != numChunks { + errors <- core.NewPublicError(core.INTERNAL, "subscriber %d received %d chunks, expected %d", nil) + } + }(i) + } + + // Give subscribers time to set up + time.Sleep(50 * time.Millisecond) + + // Write chunks concurrently + for i := 0; i < numChunks; i++ { + if err := writer.Write(json.RawMessage(`"chunk"`)); err != nil { + t.Fatalf("Write failed: %v", err) + } + } + + // Complete the stream + if err := writer.Done(json.RawMessage(`"done"`)); err != nil { + t.Fatalf("Done failed: %v", err) + } + + wg.Wait() + close(errors) + + for err := range errors { + t.Errorf("Subscriber error: %v", err) + } +} + +func TestInMemoryStreamManager_Close(t *testing.T) { + m := NewInMemoryStreamManager() + + // Close should not block + done := make(chan struct{}) + go func() { + m.Close() + close(done) + }() + + select { + case <-done: + // Success + case <-time.After(time.Second): + t.Fatal("Close blocked") + } +} + +func TestInMemoryStreamManager_CleanupExpiredStreams(t *testing.T) { + m := NewInMemoryStreamManager(WithTTL(10 * time.Millisecond)) + defer m.Close() + + ctx := context.Background() + + // Create and complete a stream + writer, err := m.Open(ctx, "expired-stream") + if err != nil { + t.Fatalf("Open failed: %v", err) + } + if err := writer.Done(json.RawMessage(`"done"`)); err != nil { + t.Fatalf("Done failed: %v", err) + } + + // Wait for TTL to expire + time.Sleep(20 * time.Millisecond) + + // Trigger cleanup + m.cleanupExpiredStreams() + + // Stream should be gone + _, _, err = m.Subscribe(ctx, "expired-stream") + if err == nil { + t.Fatal("Expected error subscribing to expired stream") + } + + var ufErr *core.UserFacingError + if !errors.As(err, &ufErr) { + t.Fatalf("Expected UserFacingError, got %T", err) + } + if ufErr.Status != core.NOT_FOUND { + t.Errorf("Expected NOT_FOUND status, got %v", ufErr.Status) + } +} + +func TestInMemoryStreamManager_OpenStreamsNotCleanedUp(t *testing.T) { + m := NewInMemoryStreamManager(WithTTL(10 * time.Millisecond)) + defer m.Close() + + ctx := context.Background() + + // Create an open stream (not completed) + _, err := m.Open(ctx, "open-stream") + if err != nil { + t.Fatalf("Open failed: %v", err) + } + + // Wait longer than TTL + time.Sleep(20 * time.Millisecond) + + // Trigger cleanup + m.cleanupExpiredStreams() + + // Stream should still exist + _, _, err = m.Subscribe(ctx, "open-stream") + if err != nil { + t.Fatalf("Subscribe failed: %v", err) + } +} + +func TestInMemoryStreamManager_ErrorAfterClose(t *testing.T) { + m := NewInMemoryStreamManager() + defer m.Close() + + ctx := context.Background() + streamID := "test-stream-error-after-close" + + writer, err := m.Open(ctx, streamID) + if err != nil { + t.Fatalf("Open failed: %v", err) + } + + if err := writer.Close(); err != nil { + t.Fatalf("Close failed: %v", 
err) + } + + // Try to error after close + err = writer.Error(core.NewPublicError(core.INTERNAL, "test", nil)) + if err == nil { + t.Fatal("Expected error when calling Error after Close") + } + + var ufErr *core.UserFacingError + if !errors.As(err, &ufErr) { + t.Fatalf("Expected UserFacingError, got %T", err) + } + if ufErr.Status != core.FAILED_PRECONDITION { + t.Errorf("Expected FAILED_PRECONDITION status, got %v", ufErr.Status) + } +} + +func TestInMemoryStreamManager_DoneAfterClose(t *testing.T) { + m := NewInMemoryStreamManager() + defer m.Close() + + ctx := context.Background() + streamID := "test-stream-done-after-close" + + writer, err := m.Open(ctx, streamID) + if err != nil { + t.Fatalf("Open failed: %v", err) + } + + if err := writer.Close(); err != nil { + t.Fatalf("Close failed: %v", err) + } + + // Try to done after close + err = writer.Done(json.RawMessage(`"done"`)) + if err == nil { + t.Fatal("Expected error when calling Done after Close") + } + + var ufErr *core.UserFacingError + if !errors.As(err, &ufErr) { + t.Fatalf("Expected UserFacingError, got %T", err) + } + if ufErr.Status != core.FAILED_PRECONDITION { + t.Errorf("Expected FAILED_PRECONDITION status, got %v", ufErr.Status) + } +} diff --git a/go/genkit/servers.go b/go/genkit/servers.go index d48c11ffd6..fe69b1d7ff 100644 --- a/go/genkit/servers.go +++ b/go/genkit/servers.go @@ -31,23 +31,37 @@ import ( "github.com/firebase/genkit/go/core" "github.com/firebase/genkit/go/core/api" "github.com/firebase/genkit/go/core/logger" + "github.com/firebase/genkit/go/core/x/streaming" + "github.com/google/uuid" ) +// HandlerOption configures a Handler. type HandlerOption interface { - apply(params *handlerParams) + applyHandler(*handlerOptions) error } -// handlerParams are the parameters for an action HTTP handler. -type handlerParams struct { - ContextProviders []core.ContextProvider // Providers for action context that may be used during runtime. +// handlerOptions are options for an action HTTP handler. +type handlerOptions struct { + ContextProviders []core.ContextProvider // Providers for action context that may be used during runtime. + StreamManager streaming.StreamManager // Optional manager for durable stream storage. } -// apply applies the options to the handler params. -func (p *handlerParams) apply(params *handlerParams) { - if params.ContextProviders != nil { - panic("genkit.WithContextProviders: cannot set ContextProviders more than once") +func (o *handlerOptions) applyHandler(opts *handlerOptions) error { + if o.ContextProviders != nil { + if opts.ContextProviders != nil { + return errors.New("cannot set ContextProviders more than once (WithContextProviders)") + } + opts.ContextProviders = o.ContextProviders + } + + if o.StreamManager != nil { + if opts.StreamManager != nil { + return errors.New("cannot set StreamManager more than once (WithStreamManager)") + } + opts.StreamManager = o.StreamManager } - params.ContextProviders = p.ContextProviders + + return nil } // requestID is a unique ID for each request. @@ -56,7 +70,16 @@ var requestID atomic.Int64 // WithContextProviders adds providers for action context that may be used during runtime. // They are called in the order added and may overwrite previous context. func WithContextProviders(ctxProviders ...core.ContextProvider) HandlerOption { - return &handlerParams{ContextProviders: ctxProviders} + return &handlerOptions{ContextProviders: ctxProviders} +} + +// WithStreamManager enables durable streaming with the provided StreamManager. 
+// When enabled, streaming responses include an x-genkit-stream-id header that clients +// can use to reconnect to in-progress or completed streams. +// +// EXPERIMENTAL: This API is subject to change. +func WithStreamManager(manager streaming.StreamManager) HandlerOption { + return &handlerOptions{StreamManager: manager} } // Handler returns an HTTP handler function that serves the action with the provided options. @@ -67,12 +90,14 @@ func WithContextProviders(ctxProviders ...core.ContextProvider) HandlerOption { // return api.ActionContext{"myKey": "myValue"}, nil // })) func Handler(a api.Action, opts ...HandlerOption) http.HandlerFunc { - params := &handlerParams{} + options := &handlerOptions{} for _, opt := range opts { - opt.apply(params) + if err := opt.applyHandler(options); err != nil { + panic(fmt.Errorf("genkit.Handler: error applying options: %w", err)) + } } - return wrapHandler(handler(a, params)) + return wrapHandler(handler(a, options)) } // wrapHandler wraps an HTTP handler function with common logging and error handling. @@ -101,8 +126,9 @@ func wrapHandler(h func(http.ResponseWriter, *http.Request) error) http.HandlerF } } -// handler returns an HTTP handler function that serves the action with the provided params. Responses are written in server-sent events (SSE) format. -func handler(a api.Action, params *handlerParams) func(http.ResponseWriter, *http.Request) error { +// handler returns an HTTP handler function that serves the action with the provided options. +// Streaming responses are written in server-sent events (SSE) format. +func handler(a api.Action, opts *handlerOptions) func(http.ResponseWriter, *http.Request) error { return func(w http.ResponseWriter, r *http.Request) error { if a == nil { return errors.New("action is nil; cannot serve") @@ -124,29 +150,9 @@ func handler(a api.Action, params *handlerParams) func(http.ResponseWriter, *htt } stream = stream || r.Header.Get("Accept") == "text/event-stream" - var callback streamingCallback[json.RawMessage] - if stream { - w.Header().Set("Content-Type", "text/event-stream") - w.Header().Set("Cache-Control", "no-cache") - w.Header().Set("Connection", "keep-alive") - w.Header().Set("Transfer-Encoding", "chunked") - callback = func(ctx context.Context, msg json.RawMessage) error { - _, err := fmt.Fprintf(w, "data: {\"message\": %s}\n\n", msg) - if err != nil { - return err - } - if f, ok := w.(http.Flusher); ok { - f.Flush() - } - return nil - } - } else { - w.Header().Set("Content-Type", "application/json") - } - ctx := r.Context() - if params.ContextProviders != nil { - for _, ctxProvider := range params.ContextProviders { + if opts.ContextProviders != nil { + for _, ctxProvider := range opts.ContextProviders { headers := make(map[string]string, len(r.Header)) for k, v := range r.Header { headers[strings.ToLower(k)] = strings.Join(v, " ") @@ -170,22 +176,214 @@ func handler(a api.Action, params *handlerParams) func(http.ResponseWriter, *htt } } - out, err := a.RunJSON(ctx, body.Data, callback) - if err != nil { - if stream { - _, err = fmt.Fprintf(w, "data: {\"error\": {\"status\": \"INTERNAL\", \"message\": \"stream flow error\", \"details\": \"%v\"}}\n\n", err) - return err + if stream { + streamID := r.Header.Get("X-Genkit-Stream-Id") + + if streamID != "" && opts.StreamManager != nil { + return subscribeToStream(ctx, w, opts.StreamManager, streamID) } + + w.Header().Set("Content-Type", "text/event-stream") + w.Header().Set("Cache-Control", "no-cache") + w.Header().Set("Connection", "keep-alive") + 
w.Header().Set("Transfer-Encoding", "chunked") + + if opts.StreamManager != nil { + return runWithDurableStreaming(ctx, w, a, opts.StreamManager, body.Data) + } + + return runWithStreaming(ctx, w, a, body.Data) + } + + w.Header().Set("Content-Type", "application/json") + out, err := a.RunJSON(ctx, body.Data, nil) + if err != nil { return err } - if stream { - _, err = fmt.Fprintf(w, "data: {\"result\": %s}\n\n", out) + return writeResultResponse(w, out) + } +} + +// runWithStreaming executes the action with standard HTTP streaming (no durability). +func runWithStreaming(ctx context.Context, w http.ResponseWriter, a api.Action, input json.RawMessage) error { + callback := func(ctx context.Context, msg json.RawMessage) error { + if err := writeSSEMessage(w, msg); err != nil { + return err + } + if f, ok := w.(http.Flusher); ok { + f.Flush() + } + return nil + } + + out, err := a.RunJSON(ctx, input, callback) + if err != nil { + if werr := writeSSEError(w, err); werr != nil { + return werr + } + return nil + } + return writeSSEResult(w, out) +} + +// runWithDurableStreaming executes the action with durable streaming support. +// Chunks are written to both the HTTP response and the stream manager for later replay. +func runWithDurableStreaming(ctx context.Context, w http.ResponseWriter, a api.Action, sm streaming.StreamManager, input json.RawMessage) error { + streamID := uuid.New().String() + + durableStream, err := sm.Open(ctx, streamID) + if err != nil { + return err + } + defer durableStream.Close() + + w.Header().Set("X-Genkit-Stream-Id", streamID) + + callback := func(ctx context.Context, msg json.RawMessage) error { + durableStream.Write(msg) + if err := writeSSEMessage(w, msg); err != nil { return err } + if f, ok := w.(http.Flusher); ok { + f.Flush() + } + return nil + } + + out, err := a.RunJSON(ctx, input, callback) + if err != nil { + durableStream.Error(err) + if werr := writeSSEError(w, err); werr != nil { + return werr + } + return nil + } + + durableStream.Done(out) + return writeSSEResult(w, out) +} + +// subscribeToStream subscribes to an existing durable stream and writes events to the HTTP response. +func subscribeToStream(ctx context.Context, w http.ResponseWriter, sm streaming.StreamManager, streamID string) error { + events, unsubscribe, err := sm.Subscribe(ctx, streamID) + if err != nil { + var ufErr *core.UserFacingError + if errors.As(err, &ufErr) && ufErr.Status == core.NOT_FOUND { + w.WriteHeader(http.StatusNoContent) + return nil + } + return err + } + defer unsubscribe() + + w.Header().Set("Content-Type", "text/event-stream") + w.Header().Set("Cache-Control", "no-cache") + w.Header().Set("Connection", "keep-alive") + w.Header().Set("Transfer-Encoding", "chunked") + + for event := range events { + switch event.Type { + case streaming.StreamEventChunk: + if err := writeSSEMessage(w, event.Chunk); err != nil { + return err + } + if f, ok := w.(http.Flusher); ok { + f.Flush() + } + case streaming.StreamEventDone: + if err := writeSSEResult(w, event.Output); err != nil { + return err + } + return nil + case streaming.StreamEventError: + streamErr := event.Err + if streamErr == nil { + streamErr = errors.New("unknown error") + } + if err := writeSSEError(w, streamErr); err != nil { + return err + } + return nil + } + } - _, err = fmt.Fprintf(w, "{\"result\": %s}\n", out) + return nil +} + +// flowResultResponse wraps a final action result for JSON serialization. 
+type flowResultResponse struct { + Result json.RawMessage `json:"result"` +} + +// flowMessageResponse wraps a streaming chunk for JSON serialization. +type flowMessageResponse struct { + Message json.RawMessage `json:"message"` +} + +// flowErrorResponse wraps an error for JSON serialization in streaming responses. +type flowErrorResponse struct { + Error *flowError `json:"error"` +} + +// flowError represents the error payload in a streaming error response. +type flowError struct { + Status core.StatusName `json:"status"` + Message string `json:"message"` + Details string `json:"details,omitempty"` +} + +// writeResultResponse writes a JSON result response for non-streaming requests. +func writeResultResponse(w http.ResponseWriter, result json.RawMessage) error { + resp := flowResultResponse{Result: result} + data, err := json.Marshal(resp) + if err != nil { + return err + } + _, err = w.Write(data) + if err != nil { + return err + } + _, err = w.Write([]byte("\n")) + return err +} + +// writeSSEResult writes a JSON result as a server-sent event for streaming requests. +func writeSSEResult(w http.ResponseWriter, result json.RawMessage) error { + resp := flowResultResponse{Result: result} + data, err := json.Marshal(resp) + if err != nil { + return err + } + _, err = fmt.Fprintf(w, "data: %s\n\n", data) + return err +} + +// writeSSEMessage writes a streaming chunk as a server-sent event. +func writeSSEMessage(w http.ResponseWriter, msg json.RawMessage) error { + resp := flowMessageResponse{Message: msg} + data, err := json.Marshal(resp) + if err != nil { + return err + } + _, err = fmt.Fprintf(w, "data: %s\n\n", data) + return err +} + +// writeSSEError writes an error as a server-sent event for streaming requests. +func writeSSEError(w http.ResponseWriter, flowErr error) error { + resp := flowErrorResponse{ + Error: &flowError{ + Status: core.INTERNAL, + Message: "stream flow error", + Details: flowErr.Error(), + }, + } + data, err := json.Marshal(resp) + if err != nil { return err } + _, err = fmt.Fprintf(w, "data: %s\n\n", data) + return err } func parseBoolQueryParam(r *http.Request, name string) (bool, error) { diff --git a/go/genkit/servers_test.go b/go/genkit/servers_test.go index a0a07cc21b..b5a69d17ec 100644 --- a/go/genkit/servers_test.go +++ b/go/genkit/servers_test.go @@ -27,6 +27,7 @@ import ( "testing" "github.com/firebase/genkit/go/core" + "github.com/firebase/genkit/go/core/x/streaming" ) func FakeContextProvider(ctx context.Context, req core.RequestData) (core.ActionContext, error) { @@ -222,17 +223,17 @@ func TestStreamingHandler(t *testing.T) { t.Errorf("want status code %d, got %d", http.StatusOK, resp.StatusCode) } - expected := `data: {"message": "h"} + expected := `data: {"message":"h"} -data: {"message": "e"} +data: {"message":"e"} -data: {"message": "l"} +data: {"message":"l"} -data: {"message": "l"} +data: {"message":"l"} -data: {"message": "o"} +data: {"message":"o"} -data: {"result": "hello-end"} +data: {"result":"hello-end"} ` if string(body) != expected { @@ -256,7 +257,7 @@ data: {"result": "hello-end"} t.Errorf("want status code %d, got %d", http.StatusOK, resp.StatusCode) } - expected := `data: {"error": {"status": "INTERNAL", "message": "stream flow error", "details": "streaming error"}} + expected := `data: {"error":{"status":"INTERNAL_SERVER_ERROR","message":"stream flow error","details":"streaming error"}} ` if string(body) != expected { @@ -264,3 +265,121 @@ data: {"result": "hello-end"} } }) } + +func TestDurableStreamingHandler(t *testing.T) { + g 
:= Init(context.Background()) + + streamingFlow := DefineStreamingFlow(g, "durableStreaming", + func(ctx context.Context, input string, cb func(context.Context, string) error) (string, error) { + for _, c := range input { + if err := cb(ctx, string(c)); err != nil { + return "", err + } + } + return input + "-done", nil + }) + + t.Run("returns stream ID header", func(t *testing.T) { + sm := streaming.NewInMemoryStreamManager() + defer sm.Close() + handler := Handler(streamingFlow, WithStreamManager(sm)) + + req := httptest.NewRequest("POST", "/", strings.NewReader(`{"data":"hi"}`)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "text/event-stream") + w := httptest.NewRecorder() + + handler(w, req) + + resp := w.Result() + body, _ := io.ReadAll(resp.Body) + + if resp.StatusCode != http.StatusOK { + t.Errorf("want status code %d, got %d", http.StatusOK, resp.StatusCode) + } + + streamID := resp.Header.Get("X-Genkit-Stream-Id") + if streamID == "" { + t.Error("want X-Genkit-Stream-Id header to be set") + } + + expected := `data: {"message":"h"} + +data: {"message":"i"} + +data: {"result":"hi-done"} + +` + if string(body) != expected { + t.Errorf("want streaming body:\n%q\n\nGot:\n%q", expected, string(body)) + } + }) + + t.Run("subscribe to completed stream", func(t *testing.T) { + sm := streaming.NewInMemoryStreamManager() + defer sm.Close() + handler := Handler(streamingFlow, WithStreamManager(sm)) + + // First request - run the stream to completion + req1 := httptest.NewRequest("POST", "/", strings.NewReader(`{"data":"ab"}`)) + req1.Header.Set("Content-Type", "application/json") + req1.Header.Set("Accept", "text/event-stream") + w1 := httptest.NewRecorder() + + handler(w1, req1) + + resp1 := w1.Result() + streamID := resp1.Header.Get("X-Genkit-Stream-Id") + if streamID == "" { + t.Fatal("want X-Genkit-Stream-Id header to be set") + } + + // Second request - subscribe to the completed stream + req2 := httptest.NewRequest("POST", "/", strings.NewReader(`{"data":"ignored"}`)) + req2.Header.Set("Content-Type", "application/json") + req2.Header.Set("Accept", "text/event-stream") + req2.Header.Set("X-Genkit-Stream-Id", streamID) + w2 := httptest.NewRecorder() + + handler(w2, req2) + + resp2 := w2.Result() + body2, _ := io.ReadAll(resp2.Body) + + if resp2.StatusCode != http.StatusOK { + t.Errorf("want status code %d, got %d", http.StatusOK, resp2.StatusCode) + } + + // Should replay all chunks and the final result + expected := `data: {"message":"a"} + +data: {"message":"b"} + +data: {"result":"ab-done"} + +` + if string(body2) != expected { + t.Errorf("want replayed body:\n%q\n\nGot:\n%q", expected, string(body2)) + } + }) + + t.Run("subscribe to non-existent stream returns 204", func(t *testing.T) { + sm := streaming.NewInMemoryStreamManager() + defer sm.Close() + handler := Handler(streamingFlow, WithStreamManager(sm)) + + req := httptest.NewRequest("POST", "/", strings.NewReader(`{"data":"test"}`)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "text/event-stream") + req.Header.Set("X-Genkit-Stream-Id", "non-existent-stream-id") + w := httptest.NewRecorder() + + handler(w, req) + + resp := w.Result() + + if resp.StatusCode != http.StatusNoContent { + t.Errorf("want status code %d, got %d", http.StatusNoContent, resp.StatusCode) + } + }) +} diff --git a/go/samples/durable-streaming/main.go b/go/samples/durable-streaming/main.go new file mode 100644 index 0000000000..e326c17a27 --- /dev/null +++ 
b/go/samples/durable-streaming/main.go @@ -0,0 +1,100 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +// This sample demonstrates durable streaming, which allows clients to reconnect +// to in-progress or completed streams using a stream ID. +// +// Start the server: +// +// go run . +// +// Test streaming (get a stream ID back in X-Genkit-Stream-Id header): +// +// curl -N -i -H "Accept: text/event-stream" \ +// -d '{"data": 5}' \ +// http://localhost:8080/countdown +// +// Subscribe to an existing stream using the stream ID from the previous response: +// +// curl -N -H "Accept: text/event-stream" \ +// -H "X-Genkit-Stream-Id: " \ +// -d '{"data": 5}' \ +// http://localhost:8080/countdown +// +// The subscription will replay any buffered chunks and then continue with live updates. +// If the stream has already completed, all chunks plus the final result are returned. + +package main + +import ( + "context" + "fmt" + "log" + "net/http" + "time" + + "github.com/firebase/genkit/go/core/x/streaming" + "github.com/firebase/genkit/go/genkit" +) + +func main() { + ctx := context.Background() + g := genkit.Init(ctx) + + type CountdownChunk struct { + Count int `json:"count"` + Message string `json:"message"` + Timestamp string `json:"timestamp"` + } + + // Define a streaming flow that counts down with delays. + countdown := genkit.DefineStreamingFlow(g, "countdown", + func(ctx context.Context, count int, cb func(context.Context, CountdownChunk) error) (string, error) { + if count <= 0 { + count = 5 + } + + for i := count; i > 0; i-- { + select { + case <-ctx.Done(): + return "", ctx.Err() + case <-time.After(1 * time.Second): + } + + chunk := CountdownChunk{ + Count: i, + Message: fmt.Sprintf("T-%d...", i), + Timestamp: time.Now().Format(time.RFC3339), + } + + if cb != nil { + if err := cb(ctx, chunk); err != nil { + return "", err + } + } + } + + return "Liftoff!", nil + }) + + // Set up HTTP server with durable streaming enabled. + // Completed streams are kept for 10 minutes before cleanup (while server is running). + mux := http.NewServeMux() + mux.HandleFunc("POST /countdown", genkit.Handler(countdown, + genkit.WithStreamManager(streaming.NewInMemoryStreamManager(streaming.WithTTL(10*time.Minute))), + )) + log.Fatal(http.ListenAndServe("127.0.0.1:8080", mux)) +}
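+
+// subscribeExample is an illustrative sketch and is not wired into main: it
+// shows how code in the same process could attach to a durable stream directly
+// through the StreamManager rather than over HTTP. The stream ID is assumed to
+// come from an earlier Open call or from an X-Genkit-Stream-Id response header.
+func subscribeExample(ctx context.Context, sm streaming.StreamManager, streamID string) {
+	events, unsubscribe, err := sm.Subscribe(ctx, streamID)
+	if err != nil {
+		log.Printf("subscribe failed: %v", err)
+		return
+	}
+	defer unsubscribe()
+
+	for event := range events {
+		switch event.Type {
+		case streaming.StreamEventChunk:
+			fmt.Printf("chunk: %s\n", event.Chunk)
+		case streaming.StreamEventDone:
+			fmt.Printf("result: %s\n", event.Output)
+			return
+		case streaming.StreamEventError:
+			log.Printf("stream error: %v", event.Err)
+			return
+		}
+	}
+}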