Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 19 additions & 2 deletions src/stac_auth_proxy/handlers/reverse_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,13 @@ def __post_init__(self):
)

def _prepare_headers(self, request: Request) -> MutableHeaders:
"""Prepare headers for the proxied request."""
"""
Prepare headers for the proxied request. Construct a Forwarded header to inform
the upstream API about the original request context, which will allow it to
properly construct URLs in responses (namely, in the Links). If there are
existing X-Forwarded-*/Forwarded headers (typically, in situations where the
STAC Auth Proxy is behind a proxy like Traefik or NGINX), we use those values.
"""
headers = MutableHeaders(request.headers)
headers.setdefault("Via", f"1.1 {self.proxy_name}")

Expand All @@ -44,18 +50,29 @@ def _prepare_headers(self, request: Request) -> MutableHeaders:
)
proxy_proto = headers.get("X-Forwarded-Proto", request.url.scheme)
proxy_host = headers.get("X-Forwarded-Host", request.url.netloc)
proxy_port = str(headers.get("X-Forwarded-Port", request.url.port))
proxy_path = headers.get("X-Forwarded-Path", request.base_url.path)

# NOTE: If we don't include a port, it's possible that the upstream server may
# mistakenly use the port from the Host header (which may be the internal port
# of the upstream server) when constructing URLs.
forwarded_host = proxy_host
if proxy_port:
forwarded_host = f"{forwarded_host}:{proxy_port}"

headers.setdefault(
"Forwarded",
f"for={proxy_client};host={proxy_host};proto={proxy_proto};path={proxy_path}",
f"for={proxy_client};host={forwarded_host};proto={proxy_proto};path={proxy_path}",
)

# NOTE: This is useful if the upstream API does not support the Forwarded header
# and there were no existing X-Forwarded-* headers on the incoming request.
if self.legacy_forwarded_headers:
headers.setdefault("X-Forwarded-For", proxy_client)
headers.setdefault("X-Forwarded-Host", proxy_host)
headers.setdefault("X-Forwarded-Path", proxy_path)
headers.setdefault("X-Forwarded-Proto", proxy_proto)
headers.setdefault("X-Forwarded-Port", proxy_port)

# Set host to the upstream host
if self.override_host:
Expand Down
28 changes: 28 additions & 0 deletions tests/test_reverse_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,3 +282,31 @@ async def test_nginx_headers_behavior(scope_overrides, headers, expected_forward
assert f"{key}={expected_value}" in forwarded, (
f"Expected {key}={expected_value} in {forwarded}"
)


@pytest.mark.parametrize("legacy_headers", [False, True])
@pytest.mark.asyncio
async def test_x_forwarded_port_in_forwarded_header(legacy_headers):
"""Test that x-forwarded-port is included in the Forwarded header."""
headers = [
(b"host", b"localhost:8000"),
(b"user-agent", b"test-agent"),
(b"x-forwarded-port", b"443"),
(b"x-forwarded-proto", b"https"),
(b"x-forwarded-host", b"api.example.com"),
]
request = create_request(headers=headers)
handler = ReverseProxyHandler(
upstream="http://upstream-api.com", legacy_forwarded_headers=legacy_headers
)
result_headers = handler._prepare_headers(request)

# Check that the Forwarded header includes the port
forwarded = result_headers["Forwarded"]
assert "host=api.example.com:443" in forwarded, (
f"Expected host=api.example.com:443 in {forwarded}"
)
assert "proto=https" in forwarded

# Check that the x-forwarded-port header is preserved
assert result_headers["X-Forwarded-Port"] == "443"
Loading