diff --git a/src/stac_auth_proxy/handlers/reverse_proxy.py b/src/stac_auth_proxy/handlers/reverse_proxy.py
index 3cae7137..203d103d 100644
--- a/src/stac_auth_proxy/handlers/reverse_proxy.py
+++ b/src/stac_auth_proxy/handlers/reverse_proxy.py
@@ -35,7 +35,13 @@ def __post_init__(self):
         )
 
     def _prepare_headers(self, request: Request) -> MutableHeaders:
-        """Prepare headers for the proxied request."""
+        """
+        Prepare headers for the proxied request. Construct a Forwarded header to inform
+        the upstream API about the original request context, which will allow it to
+        properly construct URLs in responses (namely, in the Links). If there are
+        existing X-Forwarded-*/Forwarded headers (typically, in situations where the
+        STAC Auth Proxy is behind a proxy like Traefik or NGINX), we use those values.
+        """
         headers = MutableHeaders(request.headers)
         headers.setdefault("Via", f"1.1 {self.proxy_name}")
 
@@ -44,18 +50,34 @@ def _prepare_headers(self, request: Request) -> MutableHeaders:
         )
         proxy_proto = headers.get("X-Forwarded-Proto", request.url.scheme)
         proxy_host = headers.get("X-Forwarded-Host", request.url.netloc)
+        # NOTE: request.url.port is None when the URL carries no explicit port, so
+        # normalize to "" instead of str(None), which would emit a literal "None".
+        proxy_port = str(headers.get("X-Forwarded-Port") or request.url.port or "")
         proxy_path = headers.get("X-Forwarded-Path", request.base_url.path)
+
+        # NOTE: If we don't include a port, it's possible that the upstream server may
+        # mistakenly use the port from the Host header (which may be the internal port
+        # of the upstream server) when constructing URLs.
+        forwarded_host = proxy_host
+        # Do not append a second port when the host already has one (request.url.netloc
+        # includes it). Only text after a closing "]" counts, so "[::1]" gets a port.
+        if proxy_port and ":" not in proxy_host.rpartition("]")[2]:
+            forwarded_host = f"{forwarded_host}:{proxy_port}"
+
         headers.setdefault(
             "Forwarded",
-            f"for={proxy_client};host={proxy_host};proto={proxy_proto};path={proxy_path}",
+            f"for={proxy_client};host={forwarded_host};proto={proxy_proto};path={proxy_path}",
         )
 
         # NOTE: This is useful if the upstream API does not support the Forwarded header
+        # and there were no existing X-Forwarded-* headers on the incoming request.
         if self.legacy_forwarded_headers:
             headers.setdefault("X-Forwarded-For", proxy_client)
             headers.setdefault("X-Forwarded-Host", proxy_host)
             headers.setdefault("X-Forwarded-Path", proxy_path)
             headers.setdefault("X-Forwarded-Proto", proxy_proto)
+            if proxy_port:
+                headers.setdefault("X-Forwarded-Port", proxy_port)
 
         # Set host to the upstream host
         if self.override_host:
diff --git a/tests/test_reverse_proxy.py b/tests/test_reverse_proxy.py
index bb840268..6dd525a3 100644
--- a/tests/test_reverse_proxy.py
+++ b/tests/test_reverse_proxy.py
@@ -282,3 +282,50 @@ async def test_nginx_headers_behavior(scope_overrides, headers, expected_forward
         assert f"{key}={expected_value}" in forwarded, (
             f"Expected {key}={expected_value} in {forwarded}"
         )
+
+
+@pytest.mark.parametrize("legacy_headers", [False, True])
+@pytest.mark.asyncio
+async def test_x_forwarded_port_in_forwarded_header(legacy_headers):
+    """Test that x-forwarded-port is included in the Forwarded header."""
+    headers = [
+        (b"host", b"localhost:8000"),
+        (b"user-agent", b"test-agent"),
+        (b"x-forwarded-port", b"443"),
+        (b"x-forwarded-proto", b"https"),
+        (b"x-forwarded-host", b"api.example.com"),
+    ]
+    request = create_request(headers=headers)
+    handler = ReverseProxyHandler(
+        upstream="http://upstream-api.com", legacy_forwarded_headers=legacy_headers
+    )
+    result_headers = handler._prepare_headers(request)
+
+    # Check that the Forwarded header includes the port
+    forwarded = result_headers["Forwarded"]
+    assert "host=api.example.com:443" in forwarded, (
+        f"Expected host=api.example.com:443 in {forwarded}"
+    )
+    assert "proto=https" in forwarded
+
+    # Check that the x-forwarded-port header is preserved
+    assert result_headers["X-Forwarded-Port"] == "443"
+
+
+@pytest.mark.parametrize("legacy_headers", [False, True])
+@pytest.mark.asyncio
+async def test_missing_port_is_not_forwarded_as_none(legacy_headers):
+    """Test that an absent port is never forwarded as the literal string 'None'."""
+    headers = [
+        (b"host", b"localhost"),
+        (b"user-agent", b"test-agent"),
+    ]
+    request = create_request(headers=headers)
+    handler = ReverseProxyHandler(
+        upstream="http://upstream-api.com", legacy_forwarded_headers=legacy_headers
+    )
+    result_headers = handler._prepare_headers(request)
+
+    # A missing port must be omitted entirely rather than stringified
+    assert ":None" not in result_headers["Forwarded"]
+    assert result_headers.get("X-Forwarded-Port") != "None"