Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion openhands-agent-server/openhands/agent_server/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,10 @@
from openhands.agent_server.event_router import event_router
from openhands.agent_server.file_router import file_router
from openhands.agent_server.git_router import git_router
from openhands.agent_server.middleware import LocalhostCORSMiddleware
from openhands.agent_server.middleware import (
ActivityTrackingMiddleware,
LocalhostCORSMiddleware,
)
from openhands.agent_server.server_details_router import (
get_server_info,
server_details_router,
Expand Down Expand Up @@ -322,6 +325,10 @@ def create_app(config: Config | None = None) -> FastAPI:
_add_api_routes(app, config)
_setup_static_files(app, config)
app.add_middleware(LocalhostCORSMiddleware, allow_origins=config.allow_cors_origins)
# Add activity tracking middleware to update last activity time on every
# HTTP request. This ensures runtime-api can accurately detect idle vs
# active runtimes.
app.add_middleware(ActivityTrackingMiddleware)
_add_exception_handlers(app)

return app
Expand Down
23 changes: 23 additions & 0 deletions openhands-agent-server/openhands/agent_server/middleware.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
from urllib.parse import urlparse

from fastapi.middleware.cors import CORSMiddleware
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import Response
from starlette.types import ASGIApp

from openhands.agent_server.server_details_router import update_last_execution_time


class LocalhostCORSMiddleware(CORSMiddleware):
"""Custom CORS middleware that allows any request from localhost/127.0.0.1 domains,
Expand Down Expand Up @@ -30,3 +35,21 @@ def is_allowed_origin(self, origin: str) -> bool:
# For missing origin or other origins, use the parent class's logic
result: bool = super().is_allowed_origin(origin)
return result


class ActivityTrackingMiddleware(BaseHTTPMiddleware):
    """Record activity on every HTTP request so idle detection sees real traffic.

    Each request passing through this middleware refreshes the last-activity
    timestamp via ``update_last_execution_time()`` before it is forwarded to
    the rest of the application.

    The motivation is that runtime-api previously judged idleness only from
    its own API calls, not from the HTTP traffic actually reaching the pods,
    and could therefore kill runtimes that were still actively serving
    requests.
    """

    async def dispatch(self, request: Request, call_next) -> Response:  # type: ignore[type-arg]
        """Refresh the activity timestamp, then delegate to the next handler."""
        # Stamp before delegating so the request counts as activity even when
        # the downstream handler is slow or raises.
        update_last_execution_time()
        return await call_next(request)
23 changes: 17 additions & 6 deletions openhands-agent-server/openhands/agent_server/sockets.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,13 +78,19 @@ async def events_socket(
await event_service.send_message(message, True)
except WebSocketDisconnect:
logger.info(f"Event websocket disconnected: {conversation_id}")
# Exit the loop when websocket disconnects
# Exit the loop when websocket disconnects - this is normal behavior
# and should NOT trigger server shutdown
return
except Exception as e:
logger.exception("error_in_subscription", stack_info=True)
# For critical errors that indicate the websocket is broken, exit
# For critical errors that indicate the websocket is broken,
# exit gracefully without re-raising to avoid server shutdown
if isinstance(e, (RuntimeError, ConnectionError)):
raise
logger.warning(
f"WebSocket connection error for {conversation_id}, "
"closing connection gracefully"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are we sure about not raising on RuntimeError? 🤔

Tiny nit: we could refactor these as except (exceptions) as e: like the WebSocketDisconnect code above, rather than isinstance

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Closed the PR and made 2 smaller PRs:
#1636
#1635

)
return
# For other exceptions, continue the loop
finally:
await event_service.unsubscribe_from_events(subscriber_id)
Expand Down Expand Up @@ -124,14 +130,19 @@ async def bash_events_socket(
request = ExecuteBashRequest.model_validate(data)
await bash_event_service.start_bash_command(request)
except WebSocketDisconnect:
# Exit the loop when websocket disconnects
# Exit the loop when websocket disconnects - this is normal behavior
# and should NOT trigger server shutdown
logger.info("Bash websocket disconnected")
return
except Exception as e:
logger.exception("error_in_bash_event_subscription", stack_info=True)
# For critical errors that indicate the websocket is broken, exit
# For critical errors that indicate the websocket is broken,
# exit gracefully without re-raising to avoid server shutdown
if isinstance(e, (RuntimeError, ConnectionError)):
raise
logger.warning(
"Bash WebSocket connection error, closing connection gracefully"
)
return
# For other exceptions, continue the loop
finally:
await bash_event_service.unsubscribe_from_events(subscriber_id)
Expand Down
16 changes: 10 additions & 6 deletions tests/agent_server/test_event_router_websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,11 @@ def side_effect():
async def test_websocket_unsubscribe_in_finally_when_no_disconnect(
self, mock_websocket, mock_event_service, sample_conversation_id
):
"""Test that unsubscription happens in finally block when no disconnect."""
"""Test that unsubscription happens in finally block when RuntimeError occurs.

RuntimeError is now handled gracefully (no exception raised) to avoid
server shutdown, but cleanup should still happen.
"""
# Simulate a different kind of exception that doesn't trigger disconnect handler
mock_websocket.receive_json.side_effect = RuntimeError("Unexpected error")

Expand All @@ -249,11 +253,11 @@ async def test_websocket_unsubscribe_in_finally_when_no_disconnect(

from openhands.agent_server.sockets import events_socket

# This should raise the RuntimeError but still clean up
with pytest.raises(RuntimeError):
await events_socket(
sample_conversation_id, mock_websocket, session_api_key=None
)
# RuntimeError is now handled gracefully (no exception raised)
# to avoid server shutdown when websocket connection errors occur
await events_socket(
sample_conversation_id, mock_websocket, session_api_key=None
)

# Should still unsubscribe in the finally block
mock_event_service.unsubscribe_from_events.assert_called_once()
Expand Down
121 changes: 121 additions & 0 deletions tests/agent_server/test_middleware.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
"""Tests for the agent server middleware functionality."""

import time

from fastapi import FastAPI
from fastapi.testclient import TestClient

from openhands.agent_server.middleware import ActivityTrackingMiddleware
from openhands.agent_server.server_details_router import (
_last_event_time,
update_last_execution_time,
)


def test_activity_tracking_middleware_updates_last_activity_time():
    """Test that ActivityTrackingMiddleware updates last activity time on requests.

    The baseline timestamp is read through the ``server_details_router``
    module attribute immediately before the request. A name bound with
    ``from module import _last_event_time`` is only a snapshot of the value
    at import time; using it as the baseline would let the final assertion
    pass whenever anything else had refreshed the timestamp since import,
    even if the middleware did nothing.
    """
    from openhands.agent_server import server_details_router

    # Create a simple FastAPI app with the middleware
    app = FastAPI()
    app.add_middleware(ActivityTrackingMiddleware)

    @app.get("/test")
    async def test_endpoint():
        return {"status": "ok"}

    client = TestClient(app)

    # Record the current last event time (live module value, not an
    # import-time copy)
    initial_time = server_details_router._last_event_time

    # Wait a small amount to ensure time difference
    time.sleep(0.01)

    # Make a request
    response = client.get("/test")
    assert response.status_code == 200

    # The last event time should have been updated by the request
    assert server_details_router._last_event_time > initial_time


def test_activity_tracking_middleware_updates_on_every_request():
    """Two consecutive requests must each advance the last-activity timestamp."""
    from openhands.agent_server import server_details_router

    # Minimal app exposing two routes behind the middleware
    app = FastAPI()
    app.add_middleware(ActivityTrackingMiddleware)

    @app.get("/test1")
    async def test_endpoint1():
        return {"status": "ok"}

    @app.get("/test2")
    async def test_endpoint2():
        return {"status": "ok"}

    client = TestClient(app)

    # The first request establishes the baseline timestamp
    assert client.get("/test1").status_code == 200
    stamp_after_first = server_details_router._last_event_time

    # Small pause so the second timestamp can be strictly greater
    time.sleep(0.01)

    # The second request, on a different route, must move the timestamp again
    assert client.get("/test2").status_code == 200
    assert server_details_router._last_event_time > stamp_after_first


def test_activity_tracking_middleware_updates_on_error_responses():
    """A request whose endpoint answers with an error still counts as activity."""
    from fastapi import HTTPException

    from openhands.agent_server import server_details_router

    # Minimal app whose only route always fails with HTTP 500
    app = FastAPI()
    app.add_middleware(ActivityTrackingMiddleware)

    @app.get("/error")
    async def error_endpoint():
        raise HTTPException(status_code=500, detail="Test error")

    # Have the client hand back the 500 response rather than re-raising
    client = TestClient(app, raise_server_exceptions=False)

    baseline = server_details_router._last_event_time

    # Small pause so the new timestamp can be strictly greater
    time.sleep(0.01)

    error_response = client.get("/error")
    assert error_response.status_code == 500

    # The timestamp must have moved despite the error response
    assert server_details_router._last_event_time > baseline


def test_update_last_execution_time_function():
    """Calling update_last_execution_time() directly advances the timestamp."""
    from openhands.agent_server import server_details_router

    before = server_details_router._last_event_time

    # Small pause so the new timestamp can be strictly greater
    time.sleep(0.01)

    update_last_execution_time()

    after = server_details_router._last_event_time
    assert after > before
Loading