diff --git a/acp_test.go b/acp_test.go index 915bb87..c5a7c37 100644 --- a/acp_test.go +++ b/acp_test.go @@ -746,3 +746,76 @@ func TestPromptWaitsForSessionUpdatesComplete(t *testing.T) { "returns.", completed, numUpdates) } } + +// TestRequestHandlerCanMakeNestedRequest verifies that a request handler can make nested +// requests without deadlocking (e.g., Prompt handler calling RequestPermission). +func TestRequestHandlerCanMakeNestedRequest(t *testing.T) { + c2aR, c2aW := io.Pipe() + a2cR, a2cW := io.Pipe() + + c := NewClientSideConnection(&clientFuncs{ + WriteTextFileFunc: func(context.Context, WriteTextFileRequest) (WriteTextFileResponse, error) { + return WriteTextFileResponse{}, nil + }, + ReadTextFileFunc: func(context.Context, ReadTextFileRequest) (ReadTextFileResponse, error) { + return ReadTextFileResponse{Content: "test"}, nil + }, + RequestPermissionFunc: func(context.Context, RequestPermissionRequest) (RequestPermissionResponse, error) { + return RequestPermissionResponse{Outcome: RequestPermissionOutcome{Selected: &RequestPermissionOutcomeSelected{OptionId: "allow"}}}, nil + }, + SessionUpdateFunc: func(context.Context, SessionNotification) error { + return nil + }, + }, c2aW, a2cR) + + var ag *AgentSideConnection + ag = NewAgentSideConnection(agentFuncs{ + InitializeFunc: func(context.Context, InitializeRequest) (InitializeResponse, error) { + return InitializeResponse{ProtocolVersion: ProtocolVersionNumber, AgentCapabilities: AgentCapabilities{LoadSession: false}, AuthMethods: []AuthMethod{}}, nil + }, + NewSessionFunc: func(context.Context, NewSessionRequest) (NewSessionResponse, error) { + return NewSessionResponse{SessionId: "test-session"}, nil + }, + LoadSessionFunc: func(context.Context, LoadSessionRequest) (LoadSessionResponse, error) { + return LoadSessionResponse{}, nil + }, + AuthenticateFunc: func(context.Context, AuthenticateRequest) (AuthenticateResponse, error) { + return AuthenticateResponse{}, nil + }, + PromptFunc: func(ctx context.Context, p PromptRequest) (PromptResponse, error) { + _, err := ag.RequestPermission(ctx, RequestPermissionRequest{ + SessionId: p.SessionId, + ToolCall: RequestPermissionToolCall{ + ToolCallId: "call_1", + Title: Ptr("Test permission"), + }, + Options: []PermissionOption{ + {Kind: PermissionOptionKindAllowOnce, Name: "Allow", OptionId: "allow"}, + }, + }) + if err != nil { + return PromptResponse{}, err + } + return PromptResponse{StopReason: "end_turn"}, nil + }, + CancelFunc: func(context.Context, CancelNotification) error { return nil }, + }, a2cW, c2aR) + + if _, err := c.Initialize(context.Background(), InitializeRequest{ProtocolVersion: ProtocolVersionNumber}); err != nil { + t.Fatalf("initialize: %v", err) + } + sess, err := c.NewSession(context.Background(), NewSessionRequest{Cwd: "/", McpServers: []McpServer{}}) + if err != nil { + t.Fatalf("newSession: %v", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + if _, err := c.Prompt(ctx, PromptRequest{ + SessionId: sess.SessionId, + Prompt: []ContentBlock{TextBlock("test")}, + }); err != nil { + t.Fatalf("prompt failed: %v", err) + } +} diff --git a/connection.go b/connection.go index 5e4865b..fc114e5 100644 --- a/connection.go +++ b/connection.go @@ -98,11 +98,18 @@ func (c *Connection) receive() { case msg.ID != nil && msg.Method == "": c.handleResponse(&msg) case msg.Method != "": - c.notificationWg.Add(1) - go func(m *anyMessage) { - defer c.notificationWg.Done() + // Only track notifications (no ID) in the WaitGroup, not requests (with ID). + // This prevents deadlock when a request handler makes another request. + isNotification := msg.ID == nil + if isNotification { + c.notificationWg.Add(1) + } + go func(m *anyMessage, isNotif bool) { + if isNotif { + defer c.notificationWg.Done() + } c.handleInbound(m) - }(&msg) + }(&msg, isNotification) default: c.loggerOrDefault().Error("received message with neither id nor method", "raw", string(line)) }