diff --git a/acp_test.go b/acp_test.go index 915bb87..f31756c 100644 --- a/acp_test.go +++ b/acp_test.go @@ -2,6 +2,7 @@ package acp import ( "context" + "encoding/json" "io" "slices" "sync" @@ -467,6 +468,107 @@ func TestConnectionHandlesNotifications(t *testing.T) { } } +func TestConnection_DoesNotCancelInboundContextBeforeDrainingNotificationsOnDisconnect(t *testing.T) { + const n = 25 + + incomingR, incomingW := io.Pipe() + + var ( + wg sync.WaitGroup + canceledCount atomic.Int64 + ) + wg.Add(n) + + c := NewConnection(func(ctx context.Context, method string, _ json.RawMessage) (any, *RequestError) { + defer wg.Done() + // Slow down processing so some notifications are handled after the receive + // loop observes EOF and signals disconnect. + time.Sleep(10 * time.Millisecond) + if ctx.Err() != nil { + canceledCount.Add(1) + } + return nil, nil + }, io.Discard, incomingR) + + // Write notifications quickly and then close the stream to simulate a peer disconnect. + for i := 0; i < n; i++ { + if _, err := io.WriteString(incomingW, `{"jsonrpc":"2.0","method":"test/notify","params":{}}`+"\n"); err != nil { + t.Fatalf("write notification: %v", err) + } + } + _ = incomingW.Close() + + select { + case <-c.Done(): + // Expected: peer disconnect observed promptly. + case <-time.After(2 * time.Second): + t.Fatalf("timeout waiting for connection Done()") + } + + done := make(chan struct{}) + go func() { + wg.Wait() + close(done) + }() + select { + case <-done: + case <-time.After(3 * time.Second): + t.Fatalf("timeout waiting for notification handlers") + } + + if got := canceledCount.Load(); got != 0 { + t.Fatalf("inbound handler context was canceled for %d/%d notifications", got, n) + } +} + +func TestConnection_CancelsRequestHandlersOnDisconnectEvenWithNotificationBacklog(t *testing.T) { + const numNotifications = 200 + + incomingR, incomingW := io.Pipe() + + reqDone := make(chan struct{}) + + c := NewConnection(func(ctx context.Context, method string, _ json.RawMessage) (any, *RequestError) { + switch method { + case "test/notify": + // Slow down to create a backlog of queued notifications. + time.Sleep(5 * time.Millisecond) + return nil, nil + case "test/request": + // Requests should be canceled promptly on disconnect (uses c.ctx). + <-ctx.Done() + close(reqDone) + return nil, NewInternalError(map[string]any{"error": "canceled"}) + default: + return nil, nil + } + }, io.Discard, incomingR) + + for i := 0; i < numNotifications; i++ { + if _, err := io.WriteString(incomingW, `{"jsonrpc":"2.0","method":"test/notify","params":{}}`+"\n"); err != nil { + t.Fatalf("write notification: %v", err) + } + } + if _, err := io.WriteString(incomingW, `{"jsonrpc":"2.0","id":1,"method":"test/request","params":{}}`+"\n"); err != nil { + t.Fatalf("write request: %v", err) + } + _ = incomingW.Close() + + // Disconnect should be observed quickly. + select { + case <-c.Done(): + case <-time.After(2 * time.Second): + t.Fatalf("timeout waiting for connection Done()") + } + + // Even with a big notification backlog, the request handler should be canceled promptly. + select { + case <-reqDone: + case <-time.After(1 * time.Second): + t.Fatalf("timeout waiting for request handler cancellation") + } +} + // Test initialize method behavior func TestConnectionHandlesInitialize(t *testing.T) { c2aR, c2aW := io.Pipe() diff --git a/connection.go b/connection.go index 5e4865b..cae8b59 100644 --- a/connection.go +++ b/connection.go @@ -10,6 +10,7 @@ import ( "log/slog" "sync" "sync/atomic" + "time" ) type anyMessage struct { @@ -37,27 +38,45 @@ type Connection struct { nextID atomic.Uint64 pending map[string]*pendingResponse + // ctx/cancel govern connection lifetime and are used for Done() and for canceling + // callers waiting on responses when the peer disconnects. ctx context.Context cancel context.CancelCauseFunc + // inboundCtx/inboundCancel are used when invoking the inbound MethodHandler. + // This ctx is intentionally kept alive long enough to process notifications + // that were successfully received and queued just before a peer disconnect. + // Otherwise, handlers that respect context cancellation may drop end-of-connection + // messages that we already read off the wire. + inboundCtx context.Context + inboundCancel context.CancelCauseFunc + logger *slog.Logger // notificationWg tracks in-flight notification handlers. This ensures SendRequest waits // for all notifications received before the response to complete processing. notificationWg sync.WaitGroup + + // notificationQueue serializes notification processing to maintain order + notificationQueue chan *anyMessage } func NewConnection(handler MethodHandler, peerInput io.Writer, peerOutput io.Reader) *Connection { ctx, cancel := context.WithCancelCause(context.Background()) + inboundCtx, inboundCancel := context.WithCancelCause(context.Background()) c := &Connection{ - w: peerInput, - r: peerOutput, - handler: handler, - pending: make(map[string]*pendingResponse), - ctx: ctx, - cancel: cancel, + w: peerInput, + r: peerOutput, + handler: handler, + pending: make(map[string]*pendingResponse), + ctx: ctx, + cancel: cancel, + inboundCtx: inboundCtx, + inboundCancel: inboundCancel, + notificationQueue: make(chan *anyMessage, 1024), } go c.receive() + go c.processNotifications() return c } @@ -98,20 +117,74 @@ func (c *Connection) receive() { case msg.ID != nil && msg.Method == "": c.handleResponse(&msg) case msg.Method != "": + // Requests (method+id) must not be serialized behind notifications, otherwise + // a long-running request (e.g. session/prompt) can deadlock cancellation + // notifications (session/cancel) that are required to stop it. + if msg.ID != nil { + m := msg + go c.handleInbound(&m) + continue + } + c.notificationWg.Add(1) - go func(m *anyMessage) { - defer c.notificationWg.Done() - c.handleInbound(m) - }(&msg) + // Queue the notification for sequential processing. + // If the queue is full, fall back to concurrent processing (old behavior) + // to avoid blocking the receive loop and prevent requests from hanging. + m := msg + select { + case c.notificationQueue <- &m: + // Successfully queued for sequential processing + default: + // Queue is full - process concurrently to avoid blocking the receive loop. + // This maintains backward compatibility and prevents the receive loop from + // stalling, which could cause requests to hang waiting for responses. + go func(m *anyMessage) { + defer c.notificationWg.Done() + c.handleInbound(m) + }(&m) + } default: c.loggerOrDefault().Error("received message with neither id nor method", "raw", string(line)) } } - c.cancel(errors.New("peer connection closed")) + cause := errors.New("peer connection closed") + + // First, signal disconnect to callers waiting on responses. + c.cancel(cause) + + // Then close the notification queue so already-received messages can drain. + // IMPORTANT: Do not block this receive goroutine waiting for the drain to complete; + // notification handlers may legitimately block until their context is canceled. + close(c.notificationQueue) + + // Cancel inboundCtx after notifications finish, but ensure we don't leak forever if a + // handler blocks waiting for cancellation. + const drainTimeout = 5 * time.Second + go func() { + done := make(chan struct{}) + go func() { + c.notificationWg.Wait() + close(done) + }() + select { + case <-done: + case <-time.After(drainTimeout): + } + c.inboundCancel(cause) + }() + c.loggerOrDefault().Info("peer connection closed") } +// processNotifications processes notifications sequentially to maintain order +func (c *Connection) processNotifications() { + for msg := range c.notificationQueue { + c.handleInbound(msg) + c.notificationWg.Done() + } +} + func (c *Connection) handleResponse(msg *anyMessage) { idStr := string(*msg.ID) @@ -129,6 +202,15 @@ func (c *Connection) handleResponse(msg *anyMessage) { func (c *Connection) handleInbound(req *anyMessage) { res := anyMessage{JSONRPC: "2.0"} + + // Notifications are allowed a slightly longer-lived context during disconnect so we can + // process already-received end-of-connection messages. Requests, however, should be + // canceled promptly when the peer disconnects to avoid doing unnecessary work after + // the caller is gone. + ctx := c.ctx + if req.ID == nil { + ctx = c.inboundCtx + } // copy ID if present if req.ID != nil { res.ID = req.ID @@ -141,7 +223,7 @@ func (c *Connection) handleInbound(req *anyMessage) { return } - result, err := c.handler(c.ctx, req.Method, req.Params) + result, err := c.handler(ctx, req.Method, req.Params) if req.ID == nil { // Notification: no response is sent; log handler errors to surface decode failures. if err != nil {