Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 102 additions & 0 deletions acp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package acp

import (
"context"
"encoding/json"
"io"
"slices"
"sync"
Expand Down Expand Up @@ -467,6 +468,107 @@ func TestConnectionHandlesNotifications(t *testing.T) {
}
}

func TestConnection_DoesNotCancelInboundContextBeforeDrainingNotificationsOnDisconnect(t *testing.T) {
const n = 25

incomingR, incomingW := io.Pipe()

var (
wg sync.WaitGroup
canceledCount atomic.Int64
)
wg.Add(n)

c := NewConnection(func(ctx context.Context, method string, _ json.RawMessage) (any, *RequestError) {
defer wg.Done()
// Slow down processing so some notifications are handled after the receive
// loop observes EOF and signals disconnect.
time.Sleep(10 * time.Millisecond)
if ctx.Err() != nil {
canceledCount.Add(1)
}
return nil, nil
}, io.Discard, incomingR)

// Write notifications quickly and then close the stream to simulate a peer disconnect.
for i := 0; i < n; i++ {
if _, err := io.WriteString(incomingW, `{"jsonrpc":"2.0","method":"test/notify","params":{}}`+"\n"); err != nil {
t.Fatalf("write notification: %v", err)
}
}
_ = incomingW.Close()

select {
case <-c.Done():
// Expected: peer disconnect observed promptly.
case <-time.After(2 * time.Second):
t.Fatalf("timeout waiting for connection Done()")
}

done := make(chan struct{})
go func() {
wg.Wait()
close(done)
}()
select {
case <-done:
case <-time.After(3 * time.Second):
t.Fatalf("timeout waiting for notification handlers")
}

if got := canceledCount.Load(); got != 0 {
t.Fatalf("inbound handler context was canceled for %d/%d notifications", got, n)
}
}

func TestConnection_CancelsRequestHandlersOnDisconnectEvenWithNotificationBacklog(t *testing.T) {
const numNotifications = 200

incomingR, incomingW := io.Pipe()

reqDone := make(chan struct{})

c := NewConnection(func(ctx context.Context, method string, _ json.RawMessage) (any, *RequestError) {
switch method {
case "test/notify":
// Slow down to create a backlog of queued notifications.
time.Sleep(5 * time.Millisecond)
return nil, nil
case "test/request":
// Requests should be canceled promptly on disconnect (uses c.ctx).
<-ctx.Done()
close(reqDone)
return nil, NewInternalError(map[string]any{"error": "canceled"})
default:
return nil, nil
}
}, io.Discard, incomingR)

for i := 0; i < numNotifications; i++ {
if _, err := io.WriteString(incomingW, `{"jsonrpc":"2.0","method":"test/notify","params":{}}`+"\n"); err != nil {
t.Fatalf("write notification: %v", err)
}
}
if _, err := io.WriteString(incomingW, `{"jsonrpc":"2.0","id":1,"method":"test/request","params":{}}`+"\n"); err != nil {
t.Fatalf("write request: %v", err)
}
_ = incomingW.Close()

// Disconnect should be observed quickly.
select {
case <-c.Done():
case <-time.After(2 * time.Second):
t.Fatalf("timeout waiting for connection Done()")
}

// Even with a big notification backlog, the request handler should be canceled promptly.
select {
case <-reqDone:
case <-time.After(1 * time.Second):
t.Fatalf("timeout waiting for request handler cancellation")
}
}

// Test initialize method behavior
func TestConnectionHandlesInitialize(t *testing.T) {
c2aR, c2aW := io.Pipe()
Expand Down
106 changes: 94 additions & 12 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"log/slog"
"sync"
"sync/atomic"
"time"
)

type anyMessage struct {
Expand Down Expand Up @@ -37,27 +38,45 @@ type Connection struct {
nextID atomic.Uint64
pending map[string]*pendingResponse

// ctx/cancel govern connection lifetime and are used for Done() and for canceling
// callers waiting on responses when the peer disconnects.
ctx context.Context
cancel context.CancelCauseFunc

// inboundCtx/inboundCancel are used when invoking the inbound MethodHandler.
// This ctx is intentionally kept alive long enough to process notifications
// that were successfully received and queued just before a peer disconnect.
// Otherwise, handlers that respect context cancellation may drop end-of-connection
// messages that we already read off the wire.
inboundCtx context.Context
inboundCancel context.CancelCauseFunc

logger *slog.Logger

// notificationWg tracks in-flight notification handlers. This ensures SendRequest waits
// for all notifications received before the response to complete processing.
notificationWg sync.WaitGroup

// notificationQueue serializes notification processing to maintain order
notificationQueue chan *anyMessage
}

func NewConnection(handler MethodHandler, peerInput io.Writer, peerOutput io.Reader) *Connection {
ctx, cancel := context.WithCancelCause(context.Background())
inboundCtx, inboundCancel := context.WithCancelCause(context.Background())
c := &Connection{
w: peerInput,
r: peerOutput,
handler: handler,
pending: make(map[string]*pendingResponse),
ctx: ctx,
cancel: cancel,
w: peerInput,
r: peerOutput,
handler: handler,
pending: make(map[string]*pendingResponse),
ctx: ctx,
cancel: cancel,
inboundCtx: inboundCtx,
inboundCancel: inboundCancel,
notificationQueue: make(chan *anyMessage, 1024),
}
go c.receive()
go c.processNotifications()
return c
}

Expand Down Expand Up @@ -98,20 +117,74 @@ func (c *Connection) receive() {
case msg.ID != nil && msg.Method == "":
c.handleResponse(&msg)
case msg.Method != "":
// Requests (method+id) must not be serialized behind notifications, otherwise
// a long-running request (e.g. session/prompt) can deadlock cancellation
// notifications (session/cancel) that are required to stop it.
if msg.ID != nil {
m := msg
go c.handleInbound(&m)
continue
}

c.notificationWg.Add(1)
go func(m *anyMessage) {
defer c.notificationWg.Done()
c.handleInbound(m)
}(&msg)
// Queue the notification for sequential processing.
// If the queue is full, fall back to concurrent processing (old behavior)
// to avoid blocking the receive loop and prevent requests from hanging.
m := msg
select {
case c.notificationQueue <- &m:
// Successfully queued for sequential processing
default:
// Queue is full - process concurrently to avoid blocking the receive loop.
// This maintains backward compatibility and prevents the receive loop from
// stalling, which could cause requests to hang waiting for responses.
go func(m *anyMessage) {
defer c.notificationWg.Done()
c.handleInbound(m)
}(&m)
}
default:
c.loggerOrDefault().Error("received message with neither id nor method", "raw", string(line))
}
}

c.cancel(errors.New("peer connection closed"))
cause := errors.New("peer connection closed")

// First, signal disconnect to callers waiting on responses.
c.cancel(cause)

// Then close the notification queue so already-received messages can drain.
// IMPORTANT: Do not block this receive goroutine waiting for the drain to complete;
// notification handlers may legitimately block until their context is canceled.
close(c.notificationQueue)

// Cancel inboundCtx after notifications finish, but ensure we don't leak forever if a
// handler blocks waiting for cancellation.
const drainTimeout = 5 * time.Second
go func() {
done := make(chan struct{})
go func() {
c.notificationWg.Wait()
close(done)
}()
select {
case <-done:
case <-time.After(drainTimeout):
}
c.inboundCancel(cause)
}()

c.loggerOrDefault().Info("peer connection closed")
}

// processNotifications processes notifications sequentially to maintain order
func (c *Connection) processNotifications() {
for msg := range c.notificationQueue {
c.handleInbound(msg)
c.notificationWg.Done()
}
}

func (c *Connection) handleResponse(msg *anyMessage) {
idStr := string(*msg.ID)

Expand All @@ -129,6 +202,15 @@ func (c *Connection) handleResponse(msg *anyMessage) {

func (c *Connection) handleInbound(req *anyMessage) {
res := anyMessage{JSONRPC: "2.0"}

// Notifications are allowed a slightly longer-lived context during disconnect so we can
// process already-received end-of-connection messages. Requests, however, should be
// canceled promptly when the peer disconnects to avoid doing unnecessary work after
// the caller is gone.
ctx := c.ctx
if req.ID == nil {
ctx = c.inboundCtx
}
// copy ID if present
if req.ID != nil {
res.ID = req.ID
Expand All @@ -141,7 +223,7 @@ func (c *Connection) handleInbound(req *anyMessage) {
return
}

result, err := c.handler(c.ctx, req.Method, req.Params)
result, err := c.handler(ctx, req.Method, req.Params)
if req.ID == nil {
// Notification: no response is sent; log handler errors to surface decode failures.
if err != nil {
Expand Down