From c447ac662db6a73d7532de10b7cb9bc4f4d119bf Mon Sep 17 00:00:00 2001 From: Felix Delattre Date: Thu, 11 Dec 2025 19:59:57 +0100 Subject: [PATCH 1/2] Fixed link rewriting for localhost:PORT. --- .../middleware/ProcessLinksMiddleware.py | 61 +++++- tests/test_process_links.py | 177 ++++++++++++++++++ 2 files changed, 230 insertions(+), 8 deletions(-) diff --git a/src/stac_auth_proxy/middleware/ProcessLinksMiddleware.py b/src/stac_auth_proxy/middleware/ProcessLinksMiddleware.py index 79b0f21f..2faf8710 100644 --- a/src/stac_auth_proxy/middleware/ProcessLinksMiddleware.py +++ b/src/stac_auth_proxy/middleware/ProcessLinksMiddleware.py @@ -17,6 +17,43 @@ logger = logging.getLogger(__name__) +def _extract_hostname(netloc: str) -> str: + """ + Extract hostname from netloc, ignoring port number. + + Args: + netloc: Network location string (e.g., "localhost:8080" or "example.com") + + Returns: + Hostname without port (e.g., "localhost" or "example.com") + + """ + if ":" in netloc: + if netloc.startswith("["): + # IPv6 with port: [::1]:8080 + end_bracket = netloc.rfind("]") + if end_bracket != -1: + return netloc[: end_bracket + 1] + # Regular hostname with port: localhost:8080 + return netloc.split(":", 1)[0] + return netloc + + +def _hostnames_match(hostname1: str, hostname2: str) -> bool: + """ + Check if two hostnames match, ignoring case and port. 
+ + Args: + hostname1: First hostname (may include port) + hostname2: Second hostname (may include port) + + Returns: + True if hostnames match (case-insensitive, ignoring port) + + """ + return _extract_hostname(hostname1).lower() == _extract_hostname(hostname2).lower() + + @dataclass class ProcessLinksMiddleware(JsonResponseMiddleware): """ @@ -70,10 +107,14 @@ def _update_link( parsed_link = urlparse(link["href"]) - if parsed_link.netloc not in [ - request_url.netloc, - upstream_url.netloc, - ]: + link_hostname = _extract_hostname(parsed_link.netloc) + request_hostname = _extract_hostname(request_url.netloc) + upstream_hostname = _extract_hostname(upstream_url.netloc) + + if not ( + _hostnames_match(link_hostname, request_hostname) + or _hostnames_match(link_hostname, upstream_hostname) + ): logger.debug( "Ignoring link %s because it is not for an endpoint behind this proxy (%s or %s)", link["href"], @@ -94,10 +135,14 @@ def _update_link( return # Replace the upstream host with the client's host - if parsed_link.netloc == upstream_url.netloc: - parsed_link = parsed_link._replace(netloc=request_url.netloc)._replace( - scheme=request_url.scheme - ) + link_matches_upstream = _hostnames_match( + parsed_link.netloc, upstream_url.netloc + ) + parsed_link = parsed_link._replace(netloc=request_url.netloc) + if link_matches_upstream: + # Link hostname matches upstream: also replace scheme with request URL's scheme + parsed_link = parsed_link._replace(scheme=request_url.scheme) + # If link matches request hostname, scheme is preserved (e.g. an https://localhost:443 link stays https even when the request URL is http://localhost) # Remove the upstream prefix from the link path if upstream_url.path != "/" and parsed_link.path.startswith(upstream_url.path): diff --git a/tests/test_process_links.py b/tests/test_process_links.py index 02dce87c..748a82d6 100644 --- a/tests/test_process_links.py +++ b/tests/test_process_links.py @@ -597,3 +597,180 @@ def test_transform_with_forwarded_headers(headers, expected_base_url): # 
but not include the forwarded path in the response URLs assert transformed["links"][0]["href"] == f"{expected_base_url}/proxy/collections" assert transformed["links"][1]["href"] == f"{expected_base_url}/proxy" + + +@pytest.mark.parametrize( + "upstream_url,root_path,request_host,input_links,expected_links", + [ + # Basic localhost:PORT rewriting (common port 8080) + ( + "http://eoapi-stac:8080", + "/stac", + "localhost", + [ + {"rel": "data", "href": "http://localhost:8080/collections"}, + ], + [ + "http://localhost/stac/collections", + ], + ), + # Standard HTTP port + ( + "http://eoapi-stac:8080", + "/stac", + "localhost", + [ + {"rel": "self", "href": "http://localhost:80/collections"}, + ], + [ + "http://localhost/stac/collections", + ], + ), + # HTTPS port + ( + "http://eoapi-stac:8080", + "/stac", + "localhost", + [ + {"rel": "self", "href": "https://localhost:443/collections"}, + ], + [ + "https://localhost/stac/collections", + ], + ), + # Arbitrary port + ( + "http://eoapi-stac:8080", + "/stac", + "localhost", + [ + {"rel": "self", "href": "http://localhost:3000/collections"}, + ], + [ + "http://localhost/stac/collections", + ], + ), + # Multiple links with different ports + ( + "http://eoapi-stac:8080", + "/stac", + "localhost", + [ + {"rel": "self", "href": "http://localhost:8080/collections"}, + {"rel": "root", "href": "http://localhost:80/"}, + { + "rel": "items", + "href": "https://localhost:443/collections/test/items", + }, + ], + [ + "http://localhost/stac/collections", + "http://localhost/stac/", + "https://localhost/stac/collections/test/items", + ], + ), + # localhost:PORT with upstream path + ( + "http://eoapi-stac:8080/api", + "/stac", + "localhost", + [ + {"rel": "self", "href": "http://localhost:8080/api/collections"}, + ], + [ + "http://localhost/stac/collections", + ], + ), + # Request host with port should still work (link port is replaced by the request host's port) + ( + "http://eoapi-stac:8080", + "/stac", + "localhost:80", + [ + {"rel": "self", "href": 
"http://localhost:8080/collections"}, + ], + [ + "http://localhost:80/stac/collections", + ], + ), + ], +) +def test_transform_localhost_with_port( + upstream_url, root_path, request_host, input_links, expected_links +): + """Test transforming links with localhost:PORT (any port number).""" + middleware = ProcessLinksMiddleware( + app=None, upstream_url=upstream_url, root_path=root_path + ) + request_scope = { + "type": "http", + "path": "/test", + "headers": [ + (b"host", request_host.encode()), + (b"content-type", b"application/json"), + ], + } + + data = {"links": input_links} + transformed = middleware.transform_json(data, Request(request_scope)) + + for i, expected in enumerate(expected_links): + assert transformed["links"][i]["href"] == expected + + +def test_localhost_with_port_preserves_other_hostnames(): + """Test that links with other hostnames are not transformed.""" + middleware = ProcessLinksMiddleware( + app=None, + upstream_url="http://eoapi-stac:8080", + root_path="/stac", + ) + request_scope = { + "type": "http", + "path": "/test", + "headers": [ + (b"host", b"localhost"), + (b"content-type", b"application/json"), + ], + } + + data = { + "links": [ + {"rel": "external", "href": "http://example.com:8080/collections"}, + {"rel": "other", "href": "http://other-host:3000/collections"}, + ] + } + + transformed = middleware.transform_json(data, Request(request_scope)) + + # External hostnames should remain unchanged + assert transformed["links"][0]["href"] == "http://example.com:8080/collections" + assert transformed["links"][1]["href"] == "http://other-host:3000/collections" + + +def test_localhost_with_port_upstream_service_name_still_works(): + """Test that upstream service name matching still works.""" + middleware = ProcessLinksMiddleware( + app=None, + upstream_url="http://eoapi-stac:8080", + root_path="/stac", + ) + request_scope = { + "type": "http", + "path": "/test", + "headers": [ + (b"host", b"localhost"), + (b"content-type", 
b"application/json"), + ], + } + + data = { + "links": [ + {"rel": "self", "href": "http://eoapi-stac:8080/collections"}, + ] + } + + transformed = middleware.transform_json(data, Request(request_scope)) + + # Upstream service name should be rewritten to request hostname + assert transformed["links"][0]["href"] == "http://localhost/stac/collections" From d07a4d8dd18b70370177e6d932ebf572c9e1def9 Mon Sep 17 00:00:00 2001 From: Felix Delattre Date: Fri, 12 Dec 2025 00:04:03 +0100 Subject: [PATCH 2/2] Reduced logic to only fix omitted standard ports. --- .../middleware/ProcessLinksMiddleware.py | 71 +++++++++++-------- 1 file changed, 41 insertions(+), 30 deletions(-) diff --git a/src/stac_auth_proxy/middleware/ProcessLinksMiddleware.py b/src/stac_auth_proxy/middleware/ProcessLinksMiddleware.py index 2faf8710..adc9a274 100644 --- a/src/stac_auth_proxy/middleware/ProcessLinksMiddleware.py +++ b/src/stac_auth_proxy/middleware/ProcessLinksMiddleware.py @@ -18,40 +18,42 @@ def _extract_hostname(netloc: str) -> str: - """ - Extract hostname from netloc, ignoring port number. - - Args: - netloc: Network location string (e.g., "localhost:8080" or "example.com") - - Returns: - Hostname without port (e.g., "localhost" or "example.com") - - """ + """Extract hostname from netloc.""" if ":" in netloc: if netloc.startswith("["): # IPv6 with port: [::1]:8080 end_bracket = netloc.rfind("]") if end_bracket != -1: return netloc[: end_bracket + 1] - # Regular hostname with port: localhost:8080 return netloc.split(":", 1)[0] return netloc -def _hostnames_match(hostname1: str, hostname2: str) -> bool: +def _netlocs_match(netloc1: str, scheme1: str, netloc2: str, scheme2: str) -> bool: """ - Check if two hostnames match, ignoring case and port. - - Args: - hostname1: First hostname (may include port) - hostname2: Second hostname (may include port) - - Returns: - True if hostnames match (case-insensitive, ignoring port) - - """ + Check if two netlocs match. 
Ports must match exactly, but missing ports + are assumed to be standard ports (80 for http, 443 for https). """ - return _extract_hostname(hostname1).lower() == _extract_hostname(hostname2).lower() + if _extract_hostname(netloc1).lower() != _extract_hostname(netloc2).lower(): + return False + + def _get_port(netloc: str, scheme: str) -> int: + if ":" in netloc: + if netloc.startswith("["): + end_bracket = netloc.rfind("]") + if end_bracket != -1 and end_bracket + 1 < len(netloc): + try: + return int(netloc[end_bracket + 2 :]) + except ValueError: + pass + else: + try: + return int(netloc.split(":", 1)[1]) + except ValueError: + pass + return 443 if scheme == "https" else 80 + + return _get_port(netloc1, scheme1) == _get_port(netloc2, scheme2) @dataclass @@ -107,13 +109,19 @@ def _update_link( parsed_link = urlparse(link["href"]) - link_hostname = _extract_hostname(parsed_link.netloc) - request_hostname = _extract_hostname(request_url.netloc) - upstream_hostname = _extract_hostname(upstream_url.netloc) - if not ( - _hostnames_match(link_hostname, request_hostname) - or _hostnames_match(link_hostname, upstream_hostname) + _netlocs_match( + parsed_link.netloc, + parsed_link.scheme, + request_url.netloc, + request_url.scheme, + ) + or _netlocs_match( + parsed_link.netloc, + parsed_link.scheme, + upstream_url.netloc, + upstream_url.scheme, + ) ): logger.debug( "Ignoring link %s because it is not for an endpoint behind this proxy (%s or %s)", @@ -135,8 +143,11 @@ def _update_link( return # Replace the upstream host with the client's host - link_matches_upstream = _hostnames_match( - parsed_link.netloc, upstream_url.netloc + link_matches_upstream = _netlocs_match( + parsed_link.netloc, + parsed_link.scheme, + upstream_url.netloc, + upstream_url.scheme, ) parsed_link = parsed_link._replace(netloc=request_url.netloc) if link_matches_upstream: