diff --git a/openhands-agent-server/openhands/agent_server/api.py b/openhands-agent-server/openhands/agent_server/api.py index abf4c24b35..ed52f94cde 100644 --- a/openhands-agent-server/openhands/agent_server/api.py +++ b/openhands-agent-server/openhands/agent_server/api.py @@ -23,7 +23,10 @@ from openhands.agent_server.event_router import event_router from openhands.agent_server.file_router import file_router from openhands.agent_server.git_router import git_router -from openhands.agent_server.middleware import LocalhostCORSMiddleware +from openhands.agent_server.middleware import ( + ActivityTrackingMiddleware, + LocalhostCORSMiddleware, +) from openhands.agent_server.server_details_router import ( get_server_info, server_details_router, @@ -322,6 +325,10 @@ def create_app(config: Config | None = None) -> FastAPI: _add_api_routes(app, config) _setup_static_files(app, config) app.add_middleware(LocalhostCORSMiddleware, allow_origins=config.allow_cors_origins) + # Add activity tracking middleware to update last activity time on every + # HTTP request. This ensures runtime-api can accurately detect idle vs + # active runtimes. + app.add_middleware(ActivityTrackingMiddleware) _add_exception_handlers(app) return app diff --git a/openhands-agent-server/openhands/agent_server/middleware.py b/openhands-agent-server/openhands/agent_server/middleware.py index 2de4309c05..2136dd4d0b 100644 --- a/openhands-agent-server/openhands/agent_server/middleware.py +++ b/openhands-agent-server/openhands/agent_server/middleware.py @@ -1,8 +1,13 @@ from urllib.parse import urlparse from fastapi.middleware.cors import CORSMiddleware +from starlette.middleware.base import BaseHTTPMiddleware +from starlette.requests import Request +from starlette.responses import Response from starlette.types import ASGIApp +from openhands.agent_server.server_details_router import update_last_execution_time + class LocalhostCORSMiddleware(CORSMiddleware): """Custom CORS middleware that allows any request from localhost/127.0.0.1 domains, @@ -30,3 +35,21 @@ def is_allowed_origin(self, origin: str) -> bool: # For missing origin or other origins, use the parent class's logic result: bool = super().is_allowed_origin(origin) return result + + +class ActivityTrackingMiddleware(BaseHTTPMiddleware): + """Middleware that tracks HTTP request activity for idle detection. + + This middleware updates the last activity timestamp on every HTTP request, + ensuring that the runtime-api can accurately detect when the server is + actually idle vs actively serving requests. + + This fixes the issue where runtime-api would kill active runtimes because + it only tracked its own API calls, not the actual HTTP traffic to the pods. + """ + + async def dispatch(self, request: Request, call_next) -> Response: # type: ignore[type-arg] + # Update activity timestamp before processing the request + update_last_execution_time() + response = await call_next(request) + return response diff --git a/openhands-agent-server/openhands/agent_server/sockets.py b/openhands-agent-server/openhands/agent_server/sockets.py index 5195ccf9c7..3e5feada53 100644 --- a/openhands-agent-server/openhands/agent_server/sockets.py +++ b/openhands-agent-server/openhands/agent_server/sockets.py @@ -78,13 +78,19 @@ async def events_socket( await event_service.send_message(message, True) except WebSocketDisconnect: logger.info(f"Event websocket disconnected: {conversation_id}") - # Exit the loop when websocket disconnects + # Exit the loop when websocket disconnects - this is normal behavior + # and should NOT trigger server shutdown return except Exception as e: logger.exception("error_in_subscription", stack_info=True) - # For critical errors that indicate the websocket is broken, exit + # For critical errors that indicate the websocket is broken, + # exit gracefully without re-raising to avoid server shutdown if isinstance(e, (RuntimeError, ConnectionError)): - raise + logger.warning( + f"WebSocket connection error for {conversation_id}, " + "closing connection gracefully" + ) + return # For other exceptions, continue the loop finally: await event_service.unsubscribe_from_events(subscriber_id) @@ -124,14 +130,19 @@ async def bash_events_socket( request = ExecuteBashRequest.model_validate(data) await bash_event_service.start_bash_command(request) except WebSocketDisconnect: - # Exit the loop when websocket disconnects + # Exit the loop when websocket disconnects - this is normal behavior + # and should NOT trigger server shutdown logger.info("Bash websocket disconnected") return except Exception as e: logger.exception("error_in_bash_event_subscription", stack_info=True) - # For critical errors that indicate the websocket is broken, exit + # For critical errors that indicate the websocket is broken, + # exit gracefully without re-raising to avoid server shutdown if isinstance(e, (RuntimeError, ConnectionError)): - raise + logger.warning( + "Bash WebSocket connection error, closing connection gracefully" + ) + return # For other exceptions, continue the loop finally: await bash_event_service.unsubscribe_from_events(subscriber_id) diff --git a/tests/agent_server/test_event_router_websocket.py b/tests/agent_server/test_event_router_websocket.py index eb24572dbe..04a9a26405 100644 --- a/tests/agent_server/test_event_router_websocket.py +++ b/tests/agent_server/test_event_router_websocket.py @@ -231,7 +231,11 @@ def side_effect(): async def test_websocket_unsubscribe_in_finally_when_no_disconnect( self, mock_websocket, mock_event_service, sample_conversation_id ): - """Test that unsubscription happens in finally block when no disconnect.""" + """Test that unsubscription happens in finally block when RuntimeError occurs. + + RuntimeError is now handled gracefully (no exception raised) to avoid + server shutdown, but cleanup should still happen. + """ # Simulate a different kind of exception that doesn't trigger disconnect handler mock_websocket.receive_json.side_effect = RuntimeError("Unexpected error") @@ -249,11 +253,11 @@ async def test_websocket_unsubscribe_in_finally_when_no_disconnect( from openhands.agent_server.sockets import events_socket - # This should raise the RuntimeError but still clean up - with pytest.raises(RuntimeError): - await events_socket( - sample_conversation_id, mock_websocket, session_api_key=None - ) + # RuntimeError is now handled gracefully (no exception raised) + # to avoid server shutdown when websocket connection errors occur + await events_socket( + sample_conversation_id, mock_websocket, session_api_key=None + ) # Should still unsubscribe in the finally block mock_event_service.unsubscribe_from_events.assert_called_once() diff --git a/tests/agent_server/test_middleware.py b/tests/agent_server/test_middleware.py new file mode 100644 index 0000000000..bdd03d0f8f --- /dev/null +++ b/tests/agent_server/test_middleware.py @@ -0,0 +1,121 @@ +"""Tests for the agent server middleware functionality.""" + +import time + +from fastapi import FastAPI +from fastapi.testclient import TestClient + +from openhands.agent_server.middleware import ActivityTrackingMiddleware +from openhands.agent_server.server_details_router import ( + _last_event_time, + update_last_execution_time, +) + + +def test_activity_tracking_middleware_updates_last_activity_time(): + """Test that ActivityTrackingMiddleware updates last activity time on requests.""" + # Create a simple FastAPI app with the middleware + app = FastAPI() + app.add_middleware(ActivityTrackingMiddleware) + + @app.get("/test") + async def test_endpoint(): + return {"status": "ok"} + + client = TestClient(app) + + # Record the initial last event time + initial_time = _last_event_time + + # Wait a small amount to ensure time difference + time.sleep(0.01) + + # Make a request + response = client.get("/test") + assert response.status_code == 200 + + # Import the module-level variable again to get the updated value + from openhands.agent_server import server_details_router + + # The last event time should have been updated + assert server_details_router._last_event_time > initial_time + + +def test_activity_tracking_middleware_updates_on_every_request(): + """Test that ActivityTrackingMiddleware updates on every HTTP request.""" + # Create a simple FastAPI app with the middleware + app = FastAPI() + app.add_middleware(ActivityTrackingMiddleware) + + @app.get("/test1") + async def test_endpoint1(): + return {"status": "ok"} + + @app.get("/test2") + async def test_endpoint2(): + return {"status": "ok"} + + client = TestClient(app) + + # Make first request + response1 = client.get("/test1") + assert response1.status_code == 200 + + from openhands.agent_server import server_details_router + + time_after_first = server_details_router._last_event_time + + # Wait a small amount + time.sleep(0.01) + + # Make second request + response2 = client.get("/test2") + assert response2.status_code == 200 + + # The last event time should have been updated again + assert server_details_router._last_event_time > time_after_first + + +def test_activity_tracking_middleware_updates_on_error_responses(): + """Test that ActivityTrackingMiddleware updates even when endpoint returns error.""" + # Create a simple FastAPI app with the middleware + app = FastAPI() + app.add_middleware(ActivityTrackingMiddleware) + + @app.get("/error") + async def error_endpoint(): + from fastapi import HTTPException + + raise HTTPException(status_code=500, detail="Test error") + + client = TestClient(app, raise_server_exceptions=False) + + from openhands.agent_server import server_details_router + + initial_time = server_details_router._last_event_time + + # Wait a small amount + time.sleep(0.01) + + # Make a request that will return an error + response = client.get("/error") + assert response.status_code == 500 + + # The last event time should still have been updated + assert server_details_router._last_event_time > initial_time + + +def test_update_last_execution_time_function(): + """Test that update_last_execution_time function works correctly.""" + from openhands.agent_server import server_details_router + + initial_time = server_details_router._last_event_time + + # Wait a small amount + time.sleep(0.01) + + # Call the function + update_last_execution_time() + + # The time should have been updated + assert server_details_router._last_event_time > initial_time