diff --git a/mapillary_tools/api_v4.py b/mapillary_tools/api_v4.py index d2258a110..ae38f2809 100644 --- a/mapillary_tools/api_v4.py +++ b/mapillary_tools/api_v4.py @@ -21,6 +21,17 @@ USE_SYSTEM_CERTS: bool = False +class HTTPContentError(Exception): + """ + Raised when the HTTP response is ok (200) but the content is not as expected + e.g. not JSON or not a valid response. + """ + + def __init__(self, message: str, response: requests.Response): + self.response = response + super().__init__(message) + + class ClusterFileType(enum.Enum): ZIP = "zip" BLACKVUE = "mly_blackvue_video" @@ -58,24 +69,25 @@ def cert_verify(self, *args, **kwargs): @T.overload -def _truncate(s: bytes, limit: int = 512) -> bytes: ... +def _truncate(s: bytes, limit: int = 256) -> bytes | str: ... @T.overload -def _truncate(s: str, limit: int = 512) -> str: ... +def _truncate(s: str, limit: int = 256) -> str: ... -def _truncate(s, limit=512): +def _truncate(s, limit=256): if limit < len(s): + if isinstance(s, bytes): + try: + s = s.decode("utf-8") + except UnicodeDecodeError: + pass remaining = len(s) - limit if isinstance(s, bytes): - return ( - s[:limit] - + b"..." - + f"({remaining} more bytes truncated)".encode("utf-8") - ) + return s[:limit] + f"...({remaining} bytes truncated)".encode("utf-8") else: - return str(s[:limit]) + f"...({remaining} more chars truncated)" + return str(s[:limit]) + f"...({remaining} chars truncated)" else: return s @@ -95,7 +107,10 @@ def _sanitize(headers: T.Mapping[T.Any, T.Any]) -> T.Mapping[T.Any, T.Any]: ]: new_headers[k] = "[REDACTED]" else: - new_headers[k] = _truncate(v) + if isinstance(v, (str, bytes)): + new_headers[k] = T.cast(T.Any, _truncate(v)) + else: + new_headers[k] = v return new_headers @@ -106,7 +121,6 @@ def _log_debug_request( json: dict | None = None, params: dict | None = None, headers: dict | None = None, - timeout: T.Any = None, ): if logging.getLogger().getEffectiveLevel() <= logging.DEBUG: return @@ -126,8 +140,7 @@ def _log_debug_request( if headers: msg += f" HEADERS={_sanitize(headers)}" - if timeout is not None: - msg += f" TIMEOUT={timeout}" + msg = msg.replace("\n", "\\n") LOG.debug(msg) @@ -136,26 +149,41 @@ def _log_debug_response(resp: requests.Response): if logging.getLogger().getEffectiveLevel() <= logging.DEBUG: return - data: str | bytes + elapsed = resp.elapsed.total_seconds() * 1000 # Convert to milliseconds + msg = f"HTTP {resp.status_code} {resp.reason} ({elapsed:.0f} ms): {str(_truncate_response_content(resp))}" + + LOG.debug(msg) + + +def _truncate_response_content(resp: requests.Response) -> str | bytes: try: - data = _truncate(dumps(_sanitize(resp.json()))) - except Exception: - data = _truncate(resp.content) + json_data = resp.json() + except requests.JSONDecodeError: + if resp.content is not None: + data = _truncate(resp.content) + else: + data = "" + else: + if isinstance(json_data, dict): + data = _truncate(dumps(_sanitize(json_data))) + else: + data = _truncate(str(json_data)) + + if isinstance(data, bytes): + return data.replace(b"\n", b"\\n") - LOG.debug(f"HTTP {resp.status_code} ({resp.reason}): %s", data) + elif isinstance(data, str): + return data.replace("\n", "\\n") + + return data def readable_http_error(ex: requests.HTTPError) -> str: - req = ex.request - resp = ex.response + return readable_http_response(ex.response) - data: str | bytes - try: - data = _truncate(dumps(_sanitize(resp.json()))) - except Exception: - data = _truncate(resp.content) - return f"{req.method} {resp.url} => {resp.status_code} ({resp.reason}): 
{str(data)}" +def readable_http_response(resp: requests.Response) -> str: + return f"{resp.request.method} {resp.url} => {resp.status_code} {resp.reason}: {str(_truncate_response_content(resp))}" def request_post( @@ -174,7 +202,6 @@ def request_post( json=json, params=kwargs.get("params"), headers=kwargs.get("headers"), - timeout=kwargs.get("timeout"), ) if USE_SYSTEM_CERTS: @@ -208,11 +235,7 @@ def request_get( if not disable_debug: _log_debug_request( - "GET", - url, - params=kwargs.get("params"), - headers=kwargs.get("headers"), - timeout=kwargs.get("timeout"), + "GET", url, params=kwargs.get("params"), headers=kwargs.get("headers") ) if USE_SYSTEM_CERTS: @@ -335,10 +358,7 @@ def fetch_user_or_me( def log_event(action_type: ActionType, properties: dict) -> requests.Response: resp = request_post( f"{MAPILLARY_GRAPH_API_ENDPOINT}/logging", - json={ - "action_type": action_type, - "properties": properties, - }, + json={"action_type": action_type, "properties": properties}, headers={ "Authorization": f"OAuth {MAPILLARY_CLIENT_TOKEN}", }, @@ -374,3 +394,13 @@ def finish_upload( resp.raise_for_status() return resp + + +def jsonify_response(resp: requests.Response) -> T.Any: + """ + Convert the response to JSON, raising HTTPContentError if the response is not JSON. + """ + try: + return resp.json() + except requests.JSONDecodeError as ex: + raise HTTPContentError("Invalid JSON response", resp) from ex diff --git a/mapillary_tools/authenticate.py b/mapillary_tools/authenticate.py index 35bafc3e7..c2f9b8f6c 100644 --- a/mapillary_tools/authenticate.py +++ b/mapillary_tools/authenticate.py @@ -1,7 +1,6 @@ from __future__ import annotations import getpass -import json import logging import re import sys @@ -131,15 +130,18 @@ def fetch_user_items( user_items = _verify_user_auth(_validate_profile(user_items)) LOG.info( - 'Uploading to profile "%s": %s', profile_name, api_v4._sanitize(user_items) + f'Uploading to profile "{profile_name}": {user_items.get("MAPSettingsUsername")} (ID: {user_items.get("MAPSettingsUserKey")})' ) if organization_key is not None: resp = api_v4.fetch_organization( user_items["user_upload_token"], organization_key ) - LOG.info("Uploading to Mapillary organization: %s", json.dumps(resp.json())) - user_items["MAPOrganizationKey"] = organization_key + data = api_v4.jsonify_response(resp) + LOG.info( + f"Uploading to organization: {data.get('name')} (ID: {data.get('id')})" + ) + user_items["MAPOrganizationKey"] = data.get("id") return user_items @@ -182,12 +184,12 @@ def _verify_user_auth(user_items: config.UserItem) -> config.UserItem: else: raise ex - user_json = resp.json() + data = api_v4.jsonify_response(resp) return { **user_items, - "MAPSettingsUsername": user_json.get("username"), - "MAPSettingsUserKey": user_json.get("id"), + "MAPSettingsUsername": data.get("username"), + "MAPSettingsUserKey": data.get("id"), } @@ -285,7 +287,7 @@ def _prompt_login( raise ex - data = resp.json() + data = api_v4.jsonify_response(resp) user_items: config.UserItem = { "user_upload_token": str(data["access_token"]), diff --git a/mapillary_tools/commands/__main__.py b/mapillary_tools/commands/__main__.py index e50ac2fc7..10fab9224 100644 --- a/mapillary_tools/commands/__main__.py +++ b/mapillary_tools/commands/__main__.py @@ -7,6 +7,7 @@ import requests from .. import api_v4, constants, exceptions, VERSION +from ..upload import log_exception from . 
import ( authenticate, process, @@ -162,14 +163,16 @@ def main(): try: args.func(argvars) except requests.HTTPError as ex: - LOG.error("%s: %s", ex.__class__.__name__, api_v4.readable_http_error(ex)) + log_exception(ex) # TODO: standardize exit codes as exceptions.MapillaryUserError sys.exit(16) + except api_v4.HTTPContentError as ex: + log_exception(ex) + sys.exit(17) + except exceptions.MapillaryUserError as ex: - LOG.error( - "%s: %s", ex.__class__.__name__, ex, exc_info=log_level == logging.DEBUG - ) + log_exception(ex) sys.exit(ex.exit_code) diff --git a/mapillary_tools/constants.py b/mapillary_tools/constants.py index ab6ce003d..09396a8ed 100644 --- a/mapillary_tools/constants.py +++ b/mapillary_tools/constants.py @@ -2,6 +2,7 @@ import functools import os +import tempfile import appdirs @@ -146,6 +147,10 @@ def _parse_scaled_integers( MAPILLARY_UPLOAD_HISTORY_PATH: str = os.getenv( "MAPILLARY_UPLOAD_HISTORY_PATH", os.path.join(USER_DATA_DIR, "upload_history") ) +UPLOAD_CACHE_DIR: str = os.getenv( + _ENV_PREFIX + "UPLOAD_CACHE_DIR", + os.path.join(tempfile.gettempdir(), "mapillary_tools", "upload_cache"), +) MAX_IMAGE_UPLOAD_WORKERS: int = int( os.getenv(_ENV_PREFIX + "MAX_IMAGE_UPLOAD_WORKERS", 64) ) diff --git a/mapillary_tools/history.py b/mapillary_tools/history.py index 9ec188956..78d6a22ac 100644 --- a/mapillary_tools/history.py +++ b/mapillary_tools/history.py @@ -1,8 +1,12 @@ from __future__ import annotations +import contextlib +import dbm import json import logging import string +import threading +import time import typing as T from pathlib import Path @@ -71,3 +75,100 @@ def write_history( ] with open(path, "w") as fp: fp.write(json.dumps(history)) + + +class PersistentCache: + _lock: contextlib.nullcontext | threading.Lock + + def __init__(self, file: str): + # SQLite3 backend supports concurrent access without a lock + if dbm.whichdb(file) == "dbm.sqlite3": + self._lock = contextlib.nullcontext() + else: + self._lock = threading.Lock() + self._file = file + + def get(self, key: str) -> str | None: + s = time.perf_counter() + + with self._lock: + with dbm.open(self._file, flag="c") as db: + value: bytes | None = db.get(key) + + if value is None: + return None + + payload = self._decode(value) + + if self._is_expired(payload): + return None + + file_handle = payload.get("file_handle") + + LOG.debug( + f"Found file handle for {key} in cache ({(time.perf_counter() - s) * 1000:.0f} ms)" + ) + + return T.cast(str, file_handle) + + def set(self, key: str, file_handle: str, expires_in: int = 3600 * 24 * 2) -> None: + s = time.perf_counter() + + payload = { + "expires_at": time.time() + expires_in, + "file_handle": file_handle, + } + + value: bytes = json.dumps(payload).encode("utf-8") + + with self._lock: + with dbm.open(self._file, flag="c") as db: + db[key] = value + + LOG.debug( + f"Cached file handle for {key} ({(time.perf_counter() - s) * 1000:.0f} ms)" + ) + + def clear_expired(self) -> list[str]: + s = time.perf_counter() + + expired_keys: list[str] = [] + + with self._lock: + with dbm.open(self._file, flag="c") as db: + if hasattr(db, "items"): + items: T.Iterable[tuple[str | bytes, bytes]] = db.items() + else: + items = ((key, db[key]) for key in db.keys()) + + for key, value in items: + payload = self._decode(value) + if self._is_expired(payload): + del db[key] + expired_keys.append(T.cast(str, key)) + + if expired_keys: + LOG.debug( + f"Cleared {len(expired_keys)} expired entries from the cache ({(time.perf_counter() - s) * 1000:.0f} ms)" + ) + + return expired_keys + 
+ def _is_expired(self, payload: JSONDict) -> bool: + expires_at = payload.get("expires_at") + if isinstance(expires_at, (int, float)): + return expires_at is None or expires_at <= time.time() + return False + + def _decode(self, value: bytes) -> JSONDict: + try: + payload = json.loads(value.decode("utf-8")) + except json.JSONDecodeError as ex: + LOG.warning(f"Failed to decode cache value: {ex}") + return {} + + if not isinstance(payload, dict): + LOG.warning(f"Invalid cache value format: {payload}") + return {} + + return payload diff --git a/mapillary_tools/upload.py b/mapillary_tools/upload.py index 73feaadc5..9303b813c 100644 --- a/mapillary_tools/upload.py +++ b/mapillary_tools/upload.py @@ -33,7 +33,7 @@ LOG = logging.getLogger(__name__) -class UploadedAlreadyError(uploader.SequenceError): +class UploadedAlready(uploader.SequenceError): pass @@ -96,23 +96,26 @@ def upload( upload_successes = 0 upload_errors: list[Exception] = [] - # The real upload happens sequentially here + # The real uploading happens sequentially here try: for _, result in results: if result.error is not None: - upload_errors.append(_continue_or_fail(result.error)) + upload_error = _continue_or_fail(result.error) + log_exception(upload_error) + upload_errors.append(upload_error) else: upload_successes += 1 except Exception as ex: # Fatal error: log and raise - if not dry_run: - _api_logging_failed(_summarize(stats), ex) + _api_logging_failed(_summarize(stats), ex, dry_run=dry_run) raise ex + except KeyboardInterrupt: + LOG.info("Upload interrupted by user...") + else: - if not dry_run: - _api_logging_finished(_summarize(stats)) + _api_logging_finished(_summarize(stats), dry_run=dry_run) finally: # We collected stats after every upload is finished @@ -141,6 +144,27 @@ def zip_images(import_path: Path, zip_dir: Path, desc_path: str | None = None): uploader.ZipUploader.zip_images(image_metadatas, zip_dir) +def log_exception(ex: Exception) -> None: + if LOG.getEffectiveLevel() <= logging.DEBUG: + exc_info = ex + else: + exc_info = None + + exc_name = ex.__class__.__name__ + + if isinstance(ex, UploadedAlready): + LOG.info(f"{exc_name}: {ex}") + elif isinstance(ex, requests.HTTPError): + LOG.error(f"{exc_name}: {api_v4.readable_http_error(ex)}", exc_info=exc_info) + elif isinstance(ex, api_v4.HTTPContentError): + LOG.error( + f"{exc_name}: {ex}: {api_v4.readable_http_response(ex.response)}", + exc_info=exc_info, + ) + else: + LOG.error(f"{exc_name}: {ex}", exc_info=exc_info) + + def _is_history_disabled(dry_run: bool) -> bool: # There is no way to read/write history if the path is not set if not constants.MAPILLARY_UPLOAD_HISTORY_PATH: @@ -195,14 +219,10 @@ def check_duplication(payload: uploader.Progress): ) else: if uploaded_at is not None: - LOG.info( - f"Skipping {name} (already uploaded at {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(uploaded_at))})" - ) + msg = f"Skipping {name} (previously uploaded at {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(uploaded_at))})" else: - LOG.info( - f"Skipping {name} (already uploaded, see {history_desc_path})" - ) - raise UploadedAlreadyError() + msg = f"Skipping {name} (already uploaded, see {history_desc_path})" + raise UploadedAlready(msg) @emitter.on("upload_finished") def write_history(payload: uploader.Progress): @@ -267,10 +287,20 @@ def upload_fetch_offset(payload: uploader.Progress) -> None: assert upload_pbar is not None, ( "progress_bar must be initialized in upload_start" ) - offset = payload.get("offset", 0) - if offset > 0: + begin_offset = 
payload.get("begin_offset", 0) + if begin_offset is not None and begin_offset > 0: + if upload_pbar.total is not None: + progress_percent = (begin_offset / upload_pbar.total) * 100 + upload_pbar.write( + f"Resuming upload at {begin_offset=} ({progress_percent:3.0f}%)", + file=sys.stderr, + ) + else: + upload_pbar.write( + f"Resuming upload at {begin_offset=}", file=sys.stderr + ) upload_pbar.reset() - upload_pbar.update(offset) + upload_pbar.update(begin_offset) upload_pbar.refresh() @emitter.on("upload_progress") @@ -282,6 +312,7 @@ def upload_progress(payload: uploader.Progress) -> None: upload_pbar.refresh() @emitter.on("upload_end") + @emitter.on("upload_failed") def upload_end(_: uploader.Progress) -> None: nonlocal upload_pbar if upload_pbar: @@ -429,7 +460,7 @@ def _show_upload_summary(stats: T.Sequence[_APIStats], errors: T.Sequence[Except errors_by_type.setdefault(error.__class__.__name__, []).append(error) for error_type, error_list in errors_by_type.items(): - if error_type == UploadedAlreadyError.__name__: + if error_type == UploadedAlready.__name__: LOG.info( "Skipped %d already uploaded sequences (use --reupload to force re-upload)", len(error_list), @@ -456,7 +487,10 @@ def _show_upload_summary(stats: T.Sequence[_APIStats], errors: T.Sequence[Except LOG.info("Nothing uploaded. Bye.") -def _api_logging_finished(summary: dict): +def _api_logging_finished(summary: dict, dry_run: bool = False): + if dry_run: + return + if constants.MAPILLARY_DISABLE_API_LOGGING: return @@ -465,15 +499,16 @@ def _api_logging_finished(summary: dict): api_v4.log_event(action, summary) except requests.HTTPError as exc: LOG.warning( - "HTTPError from API Logging for action %s: %s", - action, - api_v4.readable_http_error(exc), + f"HTTPError from logging action {action}: {api_v4.readable_http_error(exc)}" ) except Exception: - LOG.warning("Error from API Logging for action %s", action, exc_info=True) + LOG.warning(f"Error from logging action {action}", exc_info=True) -def _api_logging_failed(payload: dict, exc: Exception): +def _api_logging_failed(payload: dict, exc: Exception, dry_run: bool = False): + if dry_run: + return + if constants.MAPILLARY_DISABLE_API_LOGGING: return @@ -483,12 +518,10 @@ def _api_logging_failed(payload: dict, exc: Exception): api_v4.log_event(action, payload_with_reason) except requests.HTTPError as exc: LOG.warning( - "HTTPError from API Logging for action %s: %s", - action, - api_v4.readable_http_error(exc), + f"HTTPError from logging action {action}: {api_v4.readable_http_error(exc)}" ) except Exception: - LOG.warning("Error from API Logging for action %s", action, exc_info=True) + LOG.warning(f"Error from logging action {action}", exc_info=True) _M = T.TypeVar("_M", bound=types.Metadata) @@ -512,7 +545,9 @@ def _gen_upload_everything( (m for m in metadatas if isinstance(m, types.ImageMetadata)), utils.find_images(import_paths, skip_subfolders=skip_subfolders), ) - yield from uploader.ImageUploader.upload_images(mly_uploader, image_metadatas) + yield from uploader.ImageSequenceUploader.upload_images( + mly_uploader, image_metadatas + ) # Upload videos video_metadatas = _find_metadata_with_filename_existed_in( @@ -556,7 +591,7 @@ def _continue_or_fail(ex: Exception) -> Exception: return ex # Certain files not found or no permission - if isinstance(ex, OSError): + if isinstance(ex, (FileNotFoundError, PermissionError)): return ex # Certain metadatas are not valid diff --git a/mapillary_tools/upload_api_v4.py b/mapillary_tools/upload_api_v4.py index b8624d21a..b8c2e27bf 
100644 --- a/mapillary_tools/upload_api_v4.py +++ b/mapillary_tools/upload_api_v4.py @@ -17,18 +17,17 @@ import requests -from .api_v4 import request_get, request_post, REQUESTS_TIMEOUT +from .api_v4 import ( + HTTPContentError, + jsonify_response, + request_get, + request_post, + REQUESTS_TIMEOUT, +) MAPILLARY_UPLOAD_ENDPOINT = os.getenv( "MAPILLARY_UPLOAD_ENDPOINT", "https://rupload.facebook.com/mapillary_public_uploads" ) -# According to the docs, UPLOAD_REQUESTS_TIMEOUT can be a tuple of -# (connection_timeout, read_timeout): https://requests.readthedocs.io/en/latest/user/advanced/#timeouts -# In my test, however, the connection_timeout rules both connection timeout and read timeout. -# i.e. if your the server does not respond within this timeout, it will throw: -# ConnectionError: ('Connection aborted.', timeout('The write operation timed out')) -# So let us make sure the largest possible chunks can be uploaded before this timeout for now, -UPLOAD_REQUESTS_TIMEOUT = (30 * 60, 30 * 60) # 30 minutes class UploadService: @@ -49,9 +48,15 @@ def fetch_offset(self) -> int: } url = f"{MAPILLARY_UPLOAD_ENDPOINT}/{self.session_key}" resp = request_get(url, headers=headers, timeout=REQUESTS_TIMEOUT) + resp.raise_for_status() - data = resp.json() - return data["offset"] + + data = jsonify_response(resp) + + try: + return data["offset"] + except KeyError: + raise HTTPContentError("Offset not found in the response", resp) @classmethod def chunkize_byte_stream( @@ -148,19 +153,23 @@ def upload_shifted_chunks( "X-Entity-Name": self.session_key, } url = f"{MAPILLARY_UPLOAD_ENDPOINT}/{self.session_key}" + # TODO: Estimate read timeout based on the data size + read_timeout = None resp = request_post( - url, headers=headers, data=shifted_chunks, timeout=UPLOAD_REQUESTS_TIMEOUT + url, + headers=headers, + data=shifted_chunks, + timeout=(REQUESTS_TIMEOUT, read_timeout), ) resp.raise_for_status() - payload = resp.json() + data = jsonify_response(resp) + try: - return payload["h"] + return data["h"] except KeyError: - raise RuntimeError( - f"Upload server error: File handle not found in the upload response {resp.text}" - ) + raise HTTPContentError("File handle not found in the response", resp) # A mock class for testing only @@ -174,9 +183,9 @@ class FakeUploadService(UploadService): def __init__( self, + *args, upload_path: Path | None = None, transient_error_ratio: float = 0.0, - *args, **kwargs, ): super().__init__(*args, **kwargs) @@ -187,10 +196,6 @@ def __init__( self._upload_path = upload_path self._transient_error_ratio = transient_error_ratio - @property - def upload_path(self) -> Path: - return self._upload_path - @override def upload_shifted_chunks( self, shifted_chunks: T.Iterable[bytes], offset: int @@ -205,15 +210,9 @@ def upload_shifted_chunks( filename = self._upload_path.joinpath(self.session_key) with filename.open("ab") as fp: for chunk in shifted_chunks: - if random.random() <= self._transient_error_ratio: - raise requests.ConnectionError( - f"TEST ONLY: Failed to upload with error ratio {self._transient_error_ratio}" - ) + self._randomly_raise_transient_error() fp.write(chunk) - if random.random() <= self._transient_error_ratio: - raise requests.ConnectionError( - f"TEST ONLY: Partially uploaded with error ratio {self._transient_error_ratio}" - ) + self._randomly_raise_transient_error() file_handle_dir = self._upload_path.joinpath(self.FILE_HANDLE_DIR) file_handle_path = file_handle_dir.joinpath(self.session_key) @@ -226,13 +225,24 @@ def upload_shifted_chunks( @override def 
fetch_offset(self) -> int: - if random.random() <= self._transient_error_ratio: - raise requests.ConnectionError( - f"TEST ONLY: Partially uploaded with error ratio {self._transient_error_ratio}" - ) + self._randomly_raise_transient_error() filename = self._upload_path.joinpath(self.session_key) if not filename.exists(): return 0 with open(filename, "rb") as fp: fp.seek(0, io.SEEK_END) return fp.tell() + + @property + def upload_path(self) -> Path: + return self._upload_path + + def _randomly_raise_transient_error(self): + """ + Randomly raise a transient error based on the configured error ratio. + This is for testing purposes only. + """ + if random.random() <= self._transient_error_ratio: + raise requests.ConnectionError( + f"[TEST ONLY]: Transient error with ratio {self._transient_error_ratio}" + ) diff --git a/mapillary_tools/uploader.py b/mapillary_tools/uploader.py index fe7a58971..e71860861 100644 --- a/mapillary_tools/uploader.py +++ b/mapillary_tools/uploader.py @@ -31,6 +31,7 @@ constants, exif_write, geo, + history, telemetry, types, upload_api_v4, @@ -137,9 +138,10 @@ class InvalidMapillaryZipFileError(SequenceError): "upload_start", "upload_fetch_offset", "upload_progress", + "upload_interrupted", "upload_end", + "upload_failed", "upload_finished", - "upload_interrupted", ] @@ -172,7 +174,11 @@ class VideoUploader: def upload_videos( cls, mly_uploader: Uploader, video_metadatas: T.Sequence[types.VideoMetadata] ) -> T.Generator[tuple[types.VideoMetadata, UploadResult], None, None]: - for idx, video_metadata in enumerate(video_metadatas): + # If upload in a random order, then interrupted uploads has a higher chance to expire. + # Therefore sort videos to make sure interrupted uploads are resumed as early as possible + sorted_video_metadatas = sorted(video_metadatas, key=lambda m: m.filename) + + for idx, video_metadata in enumerate(sorted_video_metadatas): try: video_metadata.update_md5sum() except Exception as ex: @@ -182,7 +188,7 @@ def upload_videos( assert isinstance(video_metadata.md5sum, str), "md5sum should be updated" progress: SequenceProgress = { - "total_sequence_count": len(video_metadatas), + "total_sequence_count": len(sorted_video_metadatas), "sequence_idx": idx, "file_type": video_metadata.filetype.value, "import_path": str(video_metadata.filename), @@ -264,9 +270,13 @@ class ZipUploader: def upload_zipfiles( cls, mly_uploader: Uploader, zip_paths: T.Sequence[Path] ) -> T.Generator[tuple[Path, UploadResult], None, None]: - for idx, zip_path in enumerate(zip_paths): + # If upload in a random order, then interrupted uploads has a higher chance to expire. 
+ # Therefore sort zipfiles to make sure interrupted uploads are resumed as early as possible + sorted_zip_paths = sorted(zip_paths) + + for idx, zip_path in enumerate(sorted_zip_paths): progress: SequenceProgress = { - "total_sequence_count": len(zip_paths), + "total_sequence_count": len(sorted_zip_paths), "sequence_idx": idx, "import_path": str(zip_path), "file_type": types.FileType.ZIP.value, @@ -409,7 +419,7 @@ def _zip_sequence_fp( # Arcname should be unique, the name does not matter arcname = f"{idx}.jpg" zipinfo = zipfile.ZipInfo(arcname, date_time=(1980, 1, 1, 0, 0, 0)) - zipf.writestr(zipinfo, ImageUploader.dump_image_bytes(metadata)) + zipf.writestr(zipinfo, SingleImageUploader.dump_image_bytes(metadata)) assert len(sequence) == len(set(zipf.namelist())) zipf.comment = json.dumps( {"sequence_md5sum": sequence_md5sum}, @@ -471,7 +481,7 @@ def _wip_file_context(cls, wip_path: Path): pass -class ImageUploader: +class ImageSequenceUploader: @classmethod def upload_images( cls, uploader: Uploader, image_metadatas: T.Sequence[types.ImageMetadata] @@ -501,27 +511,6 @@ def upload_images( else: yield sequence_uuid, UploadResult(result=cluster_id) - @classmethod - def dump_image_bytes(cls, metadata: types.ImageMetadata) -> bytes: - try: - edit = exif_write.ExifEdit(metadata.filename) - except struct.error as ex: - raise ExifError(f"Failed to load EXIF: {ex}", metadata.filename) from ex - - # The cast is to fix the type checker error - edit.add_image_description( - T.cast( - T.Dict, desc_file_to_exif(DescriptionJSONSerializer.as_desc(metadata)) - ) - ) - - try: - return edit.dump_image_bytes() - except struct.error as ex: - raise ExifError( - f"Failed to dump EXIF bytes: {ex}", metadata.filename - ) from ex - @classmethod def _upload_sequence( cls, @@ -535,39 +524,43 @@ def _upload_sequence( # FIXME: This is a hack to disable the event emitter inside the uploader uploader_without_emitter = uploader.copy_uploader_without_emitter() - lock = threading.Lock() - - def _upload_image(image_metadata: types.ImageMetadata) -> str: - mutable_progress = { - **(progress or {}), - "filename": str(image_metadata.filename), - } - - bytes = cls.dump_image_bytes(image_metadata) - file_handle = uploader_without_emitter.upload_stream( - io.BytesIO(bytes), progress=mutable_progress - ) - - mutable_progress["chunk_size"] = image_metadata.filesize - - with lock: - uploader.emitter.emit("upload_progress", mutable_progress) - - return file_handle - _validate_metadatas(sequence) progress["entity_size"] = sum(m.filesize or 0 for m in sequence) # TODO: assert sequence is sorted + single_image_uploader = SingleImageUploader( + uploader, uploader_without_emitter, progress=progress + ) + uploader.emitter.emit("upload_start", progress) with concurrent.futures.ThreadPoolExecutor( max_workers=constants.MAX_IMAGE_UPLOAD_WORKERS ) as executor: - image_file_handles = list(executor.map(_upload_image, sequence)) + image_file_handles = list( + executor.map(single_image_uploader.upload, sequence) + ) + + manifest_file_handle = cls._upload_manifest( + uploader_without_emitter, image_file_handles + ) + + uploader.emitter.emit("upload_end", progress) + + cluster_id = uploader.finish_upload( + manifest_file_handle, + api_v4.ClusterFileType.MLY_BUNDLE_MANIFEST, + progress=progress, + ) + return cluster_id + + @classmethod + def _upload_manifest( + cls, uploader_without_emitter: Uploader, image_file_handles: T.Sequence[str] + ) -> str: manifest = { "version": "1", "upload_type": "images", @@ -581,19 +574,118 @@ def 
_upload_image(image_metadata: types.ImageMetadata) -> str: ) ) manifest_fp.seek(0, io.SEEK_SET) - manifest_file_handle = uploader_without_emitter.upload_stream( - manifest_fp, session_key=f"uuid_{uuid.uuid4().hex}.json" + return uploader_without_emitter.upload_stream( + manifest_fp, session_key=f"{_prefixed_uuid4()}.json" ) - uploader.emitter.emit("upload_end", progress) - cluster_id = uploader.finish_upload( - manifest_file_handle, - api_v4.ClusterFileType.MLY_BUNDLE_MANIFEST, - progress=progress, +class SingleImageUploader: + def __init__( + self, + uploader: Uploader, + uploader_without_emitter: Uploader, + progress: dict[str, T.Any] | None = None, + ): + self.uploader = uploader + self.uploader_without_emitter = uploader_without_emitter + self.progress = progress or {} + self.lock = threading.Lock() + self.cache = self._maybe_create_persistent_cache_instance(uploader.user_items) + + def upload(self, image_metadata: types.ImageMetadata) -> str: + mutable_progress = { + **(self.progress or {}), + "filename": str(image_metadata.filename), + } + + image_bytes = self.dump_image_bytes(image_metadata) + + session_key = self.uploader_without_emitter._gen_session_key( + io.BytesIO(image_bytes), mutable_progress ) - return cluster_id + file_handle = self._file_handle_cache_get(session_key) + + if file_handle is None: + file_handle = self.uploader_without_emitter.upload_stream( + io.BytesIO(image_bytes), + session_key=session_key, + progress=mutable_progress, + ) + self._file_handle_cache_set(session_key, file_handle) + + # Override chunk_size with the actual filesize + mutable_progress["chunk_size"] = image_metadata.filesize + + with self.lock: + self.uploader.emitter.emit("upload_progress", mutable_progress) + + return file_handle + + @classmethod + def dump_image_bytes(cls, metadata: types.ImageMetadata) -> bytes: + try: + edit = exif_write.ExifEdit(metadata.filename) + except struct.error as ex: + raise ExifError(f"Failed to load EXIF: {ex}", metadata.filename) from ex + + # The cast is to fix the type checker error + edit.add_image_description( + T.cast( + T.Dict, desc_file_to_exif(DescriptionJSONSerializer.as_desc(metadata)) + ) + ) + + try: + return edit.dump_image_bytes() + except struct.error as ex: + raise ExifError( + f"Failed to dump EXIF bytes: {ex}", metadata.filename + ) from ex + + @classmethod + def _maybe_create_persistent_cache_instance( + cls, user_items: config.UserItem + ) -> history.PersistentCache | None: + if not constants.UPLOAD_CACHE_DIR: + LOG.debug( + "Upload cache directory is set empty, skipping caching upload file handles" + ) + return None + + cache_path_dir = ( + Path(constants.UPLOAD_CACHE_DIR) + .joinpath(api_v4.MAPILLARY_CLIENT_TOKEN.replace("|", "_")) + .joinpath( + user_items.get("MAPSettingsUserKey", user_items["user_upload_token"]) + ) + ) + cache_path_dir.mkdir(parents=True, exist_ok=True) + cache_path = cache_path_dir.joinpath("cached_file_handles") + LOG.debug(f"File handle cache path: {cache_path}") + + cache = history.PersistentCache(str(cache_path.resolve())) + cache.clear_expired() + + return cache + + def _file_handle_cache_get(self, key: str) -> str | None: + if self.cache is None: + return None + + if _is_uuid(key): + return None + + return self.cache.get(key) + + def _file_handle_cache_set(self, key: str, value: str) -> None: + if self.cache is None: + return + + if _is_uuid(key): + return + + self.cache.set(key, value) class Uploader: @@ -627,23 +719,11 @@ def upload_stream( progress = {} if session_key is None: - if self.noresume: - # Generate 
a unique UUID for session_key when noresume is True - # to prevent resuming from previous uploads - session_key = f"uuid_{uuid.uuid4().hex}" - else: - fp.seek(0, io.SEEK_SET) - session_key = utils.md5sum_fp(fp).hexdigest() - - filetype = progress.get("file_type") - if filetype is not None: - session_key = _session_key(session_key, types.FileType(filetype)) + session_key = self._gen_session_key(fp, progress) fp.seek(0, io.SEEK_END) entity_size = fp.tell() - upload_service = self._create_upload_service(session_key) - progress["entity_size"] = entity_size progress["chunk_size"] = self.chunk_size progress["retries"] = 0 @@ -651,6 +731,8 @@ def upload_stream( self.emitter.emit("upload_start", progress) + upload_service = self._create_upload_service(session_key) + while True: try: file_handle = self._upload_stream_retryable( @@ -658,6 +740,9 @@ def upload_stream( ) except Exception as ex: self._handle_upload_exception(ex, T.cast(UploaderProgress, progress)) + except BaseException as ex: + self.emitter.emit("upload_failed", progress) + raise ex else: break @@ -687,10 +772,9 @@ def finish_upload( organization_id=self.user_items.get("MAPOrganizationKey"), ) - data = resp.json() - cluster_id = data.get("cluster_id") - - # TODO: validate cluster_id + body = api_v4.jsonify_response(resp) + # TODO: Validate cluster_id + cluster_id = body.get("cluster_id") progress["cluster_id"] = cluster_id self.emitter.emit("upload_finished", progress) @@ -732,35 +816,28 @@ def _create_upload_service(self, session_key: str) -> upload_api_v4.UploadServic def _handle_upload_exception( self, ex: Exception, progress: UploaderProgress ) -> None: - retries = progress["retries"] + retries = progress.get("retries", 0) begin_offset = progress.get("begin_offset") - chunk_size = progress["chunk_size"] + offset = progress.get("offset") if retries <= constants.MAX_UPLOAD_RETRIES and _is_retriable_exception(ex): self.emitter.emit("upload_interrupted", progress) LOG.warning( - # use %s instead of %d because offset could be None - "Error uploading chunk_size %d at begin_offset %s: %s: %s", - chunk_size, - begin_offset, - ex.__class__.__name__, - str(ex), + f"Error uploading at {offset=} since {begin_offset=}: {ex.__class__.__name__}: {ex}" ) # Keep things immutable here. 
Will increment retries in the caller retries += 1 - if _is_immediate_retry(ex): + if _is_immediate_retriable_exception(ex): sleep_for = 0 else: sleep_for = min(2**retries, 16) LOG.info( - "Retrying in %d seconds (%d/%d)", - sleep_for, - retries, - constants.MAX_UPLOAD_RETRIES, + f"Retrying in {sleep_for} seconds ({retries}/{constants.MAX_UPLOAD_RETRIES})" ) if sleep_for: time.sleep(sleep_for) else: + self.emitter.emit("upload_failed", progress) raise ex def _chunk_with_progress_emitted( @@ -801,6 +878,21 @@ def _upload_stream_retryable( return upload_service.upload_shifted_chunks(shifted_chunks, begin_offset) + def _gen_session_key(self, fp: T.IO[bytes], progress: dict[str, T.Any]) -> str: + if self.noresume: + # Generate a unique UUID for session_key when noresume is True + # to prevent resuming from previous uploads + session_key = f"{_prefixed_uuid4()}" + else: + fp.seek(0, io.SEEK_SET) + session_key = utils.md5sum_fp(fp).hexdigest() + + filetype = progress.get("file_type") + if filetype is not None: + session_key = _session_key(session_key, types.FileType(filetype)) + + return session_key + def _validate_metadatas(metadatas: T.Sequence[types.ImageMetadata]): for metadata in metadatas: @@ -809,7 +901,7 @@ def _validate_metadatas(metadatas: T.Sequence[types.ImageMetadata]): raise FileNotFoundError(f"No such file {metadata.filename}") -def _is_immediate_retry(ex: Exception): +def _is_immediate_retriable_exception(ex: Exception) -> bool: if ( isinstance(ex, requests.HTTPError) and isinstance(ex.response, requests.Response) @@ -822,8 +914,10 @@ def _is_immediate_retry(ex: Exception): # resp: {"debug_info":{"retriable":true,"type":"OffsetInvalidError","message":"Request starting offset is invalid"}} return resp.get("debug_info", {}).get("retriable", False) + return False + -def _is_retriable_exception(ex: Exception): +def _is_retriable_exception(ex: Exception) -> bool: if isinstance(ex, (requests.ConnectionError, requests.Timeout)): return True @@ -858,3 +952,13 @@ def _session_key( } return f"mly_tools_{upload_md5sum}{_SUFFIX_MAP[filetype]}" + + +def _prefixed_uuid4(): + prefixed = f"uuid_{uuid.uuid4().hex}" + assert _is_uuid(prefixed) + return prefixed + + +def _is_uuid(session_key: str) -> bool: + return session_key.startswith("uuid_") diff --git a/tests/cli/upload_api_v4.py b/tests/cli/upload_api_v4.py index 24d33c88f..3835bb062 100644 --- a/tests/cli/upload_api_v4.py +++ b/tests/cli/upload_api_v4.py @@ -65,7 +65,7 @@ def main(): user_access_token = user_items.get("user_upload_token", "") if parsed.dry_run: - service = FakeUploadService(user_access_token, session_key) + service = FakeUploadService(user_access_token="", session_key=session_key) else: service = UploadService(user_access_token, session_key) @@ -79,6 +79,9 @@ def main(): LOG.info("Entity size: %d", entity_size) LOG.info("Chunk size: %s MB", chunk_size / (1024 * 1024)) + if isinstance(service, FakeUploadService): + LOG.info(f"Uploading to {service.upload_path}") + def _update_pbar(chunks, pbar): for chunk in chunks: yield chunk diff --git a/tests/integration/fixtures.py b/tests/integration/fixtures.py index 5c86f9554..d11586021 100644 --- a/tests/integration/fixtures.py +++ b/tests/integration/fixtures.py @@ -66,6 +66,7 @@ def setup_upload(tmpdir: py.path.local): os.environ["MAPILLARY_TOOLS__AUTH_VERIFICATION_DISABLED"] = "YES" os.environ["MAPILLARY_TOOLS_PROMPT_DISABLED"] = "YES" os.environ["MAPILLARY__ENABLE_UPLOAD_HISTORY_FOR_DRY_RUN"] = "YES" + os.environ["MAPILLARY_TOOLS_UPLOAD_CACHE_DIR"] = 
str(tmpdir.mkdir("upload_cache")) history_path = tmpdir.join("history") os.environ["MAPILLARY_UPLOAD_HISTORY_PATH"] = str(history_path) yield upload_dir @@ -386,7 +387,9 @@ def assert_descs_exact_equal(left: list[dict], right: list[dict]): def run_command(params: list[str], command: str, **kwargs): - subprocess.run([*shlex.split(EXECUTABLE), command, *params], check=True, **kwargs) + subprocess.run( + [*shlex.split(EXECUTABLE), "--verbose", command, *params], check=True, **kwargs + ) def run_process_for_descs(params: list[str], command: str = "process", **kwargs): diff --git a/tests/unit/test_persistent_cache.py b/tests/unit/test_persistent_cache.py new file mode 100644 index 000000000..32fef006f --- /dev/null +++ b/tests/unit/test_persistent_cache.py @@ -0,0 +1,293 @@ +import dbm +import os +import threading +import time + +import pytest + +from mapillary_tools.history import PersistentCache + + +# DBM backends to test with +DBM_BACKENDS = ["dbm.sqlite3", "dbm.gnu", "dbm.ndbm", "dbm.dumb"] + + +@pytest.mark.parametrize("dbm_backend", DBM_BACKENDS) +def test_basic_operations_with_backend(tmpdir, dbm_backend): + """Test basic operations with different DBM backends. + + Note: This is a demonstration of pytest's parametrize feature. + The actual PersistentCache class might not support specifying backends. + """ + cache_file = os.path.join(tmpdir, dbm_backend) + # Here you would use the backend if the cache implementation supported it + cache = PersistentCache(cache_file) + + # Perform basic operations + cache.set("test_key", "test_value") + assert cache.get("test_key") == "test_value" + + # Add specific test logic for different backends if needed + # This is just a placeholder to demonstrate pytest's parametrization + + +@pytest.mark.parametrize("dbm_backend", DBM_BACKENDS) +def test_get_set(tmpdir, dbm_backend): + """Test basic get and set operations.""" + cache_file = os.path.join(tmpdir, f"cache_get_set_{dbm_backend}") + cache = PersistentCache(cache_file) + cache.set("key1", "value1") + assert cache.get("key1") == "value1" + assert cache.get("nonexistent_key") is None + + +@pytest.mark.parametrize("dbm_backend", DBM_BACKENDS) +def test_expiration(tmpdir, dbm_backend): + """Test that entries expire correctly.""" + cache_file = os.path.join(tmpdir, f"cache_expiration_{dbm_backend}") + cache = PersistentCache(cache_file) + + # Set with short expiration + cache.set("short_lived", "value", expires_in=1) + assert cache.get("short_lived") == "value" + + # Wait for expiration + time.sleep(1.1) + assert cache.get("short_lived") is None + + # Set with longer expiration + cache.set("long_lived", "value", expires_in=10) + assert cache.get("long_lived") == "value" + + # Should still be valid + time.sleep(1) + assert cache.get("long_lived") == "value" + + +@pytest.mark.parametrize("dbm_backend", DBM_BACKENDS) +@pytest.mark.parametrize( + "expire_time,sleep_time,should_exist", + [ + (1, 0.5, True), # Should not expire yet + (1, 1.5, False), # Should expire + (5, 2, True), # Should not expire yet + ], +) +def test_parametrized_expiration( + tmpdir, dbm_backend, expire_time, sleep_time, should_exist +): + """Test expiration with different timing combinations.""" + cache_file = os.path.join( + tmpdir, f"cache_param_exp_{dbm_backend}_{expire_time}_{sleep_time}" + ) + cache = PersistentCache(cache_file) + + key = f"key_expires_in_{expire_time}_sleeps_{sleep_time}" + cache.set(key, "test_value", expires_in=expire_time) + + time.sleep(sleep_time) + + if should_exist: + assert cache.get(key) == "test_value" + 
else: + assert cache.get(key) is None + + +@pytest.mark.parametrize("dbm_backend", DBM_BACKENDS) +def test_clear_expired(tmpdir, dbm_backend): + """Test clearing expired entries.""" + cache_file = os.path.join(tmpdir, f"cache_clear_expired_{dbm_backend}") + cache = PersistentCache(cache_file) + + # Test 1: Single expired key + cache.set("expired", "value1", expires_in=1) + cache.set("not_expired", "value2", expires_in=10) + + # Wait for first entry to expire + time.sleep(1.1) + + # Clear expired entries + expired_keys = cache.clear_expired() + + # Check that only the expired key was cleared + assert len(expired_keys) == 1 + assert expired_keys[0] == b"expired" + assert cache.get("expired") is None + assert cache.get("not_expired") == "value2" + + +@pytest.mark.parametrize("dbm_backend", DBM_BACKENDS) +def test_clear_expired_multiple(tmpdir, dbm_backend): + """Test clearing multiple expired entries.""" + cache_file = os.path.join(tmpdir, f"cache_clear_multiple_{dbm_backend}") + cache = PersistentCache(cache_file) + + # Test 2: Multiple expired keys + cache.set("expired1", "value1", expires_in=1) + cache.set("expired2", "value2", expires_in=1) + cache.set("not_expired", "value3", expires_in=10) + + # Wait for entries to expire + time.sleep(1.1) + + # Clear expired entries + expired_keys = cache.clear_expired() + + # Check that only expired keys were cleared + assert len(expired_keys) == 2 + assert b"expired1" in expired_keys + assert b"expired2" in expired_keys + assert cache.get("expired1") is None + assert cache.get("expired2") is None + assert cache.get("not_expired") == "value3" + + +@pytest.mark.parametrize("dbm_backend", DBM_BACKENDS) +def test_clear_expired_all(tmpdir, dbm_backend): + """Test clearing all expired entries.""" + cache_file = os.path.join(tmpdir, f"cache_clear_all_{dbm_backend}") + cache = PersistentCache(cache_file) + + # Test 3: All entries expired + cache.set("key1", "value1", expires_in=1) + cache.set("key2", "value2", expires_in=1) + + # Wait for entries to expire + time.sleep(1.1) + + # Clear expired entries + expired_keys = cache.clear_expired() + + # Check that all keys were cleared + assert len(expired_keys) == 2 + assert b"key1" in expired_keys + assert b"key2" in expired_keys + + +@pytest.mark.parametrize("dbm_backend", DBM_BACKENDS) +def test_clear_expired_none(tmpdir, dbm_backend): + """Test clearing when no entries are expired.""" + cache_file = os.path.join(tmpdir, f"cache_clear_none_{dbm_backend}") + cache = PersistentCache(cache_file) + + # Test 4: No entries expired + cache.set("key1", "value1", expires_in=10) + cache.set("key2", "value2", expires_in=10) + + # Clear expired entries + expired_keys = cache.clear_expired() + + # Check that no keys were cleared + assert len(expired_keys) == 0 + assert cache.get("key1") == "value1" + assert cache.get("key2") == "value2" + + +@pytest.mark.parametrize("dbm_backend", DBM_BACKENDS) +def test_clear_expired_empty(tmpdir, dbm_backend): + """Test clearing expired entries on an empty cache.""" + cache_file = os.path.join(tmpdir, f"cache_clear_empty_{dbm_backend}") + cache = PersistentCache(cache_file) + + # Test 5: Empty cache + expired_keys = cache.clear_expired() + + # Check that no keys were cleared + assert len(expired_keys) == 0 + + +@pytest.mark.parametrize("dbm_backend", DBM_BACKENDS) +def test_corrupted_data(tmpdir, dbm_backend): + """Test handling of corrupted data.""" + cache_file = os.path.join(tmpdir, f"cache_corrupted_{dbm_backend}") + cache = PersistentCache(cache_file) + + # Set valid entry + 
cache.set("key1", "value1") + + # Corrupt the data by directly writing invalid JSON + with dbm.open(cache_file, "c") as db: + db["corrupted"] = b"not valid json" + db["corrupted_dict"] = b'"not a dict"' + + # Check that corrupted entries are handled gracefully + assert cache.get("corrupted") is None + assert cache.get("corrupted_dict") is None + + # Valid entries should still work + assert cache.get("key1") == "value1" + + # Clear expired should not crash on corrupted entries + cache.clear_expired() + + +@pytest.mark.parametrize("dbm_backend", DBM_BACKENDS) +def test_concurrency(tmpdir, dbm_backend): + """Test concurrent access to the cache.""" + cache_file = os.path.join(tmpdir, f"cache_concurrency_{dbm_backend}") + cache = PersistentCache(cache_file) + num_threads = 10 + num_operations = 50 + + results = [] # Store assertion failures for pytest to check after threads complete + + def worker(thread_id): + for i in range(num_operations): + key = f"key_{thread_id}_{i}" + value = f"value_{thread_id}_{i}" + cache.set(key, value) + # Occasionally read a previously written value + if i > 0 and i % 5 == 0: + prev_key = f"key_{thread_id}_{i - 1}" + prev_value = cache.get(prev_key) + if prev_value != f"value_{thread_id}_{i - 1}": + results.append( + f"Expected {prev_key} to be value_{thread_id}_{i - 1}, got {prev_value}" + ) + + threads = [] + for i in range(num_threads): + t = threading.Thread(target=worker, args=(i,)) + threads.append(t) + t.start() + + for t in threads: + t.join() + + # Check for any failures in threads + assert not results, f"Thread assertions failed: {results}" + + # Verify all values were written correctly + for i in range(num_threads): + for j in range(num_operations): + key = f"key_{i}_{j}" + expected_value = f"value_{i}_{j}" + assert cache.get(key) == expected_value + + +@pytest.mark.parametrize("dbm_backend", DBM_BACKENDS) +def test_decode_invalid_data(tmpdir, dbm_backend): + """Test _decode method with invalid data.""" + cache_file = os.path.join(tmpdir, f"cache_decode_invalid_{dbm_backend}") + cache = PersistentCache(cache_file) + + # Test with various invalid inputs + result = cache._decode(b"not valid json") + assert result == {} + + result = cache._decode(b'"string instead of dict"') + assert result == {} + + +@pytest.mark.parametrize("dbm_backend", DBM_BACKENDS) +def test_is_expired(tmpdir, dbm_backend): + """Test _is_expired method.""" + cache_file = os.path.join(tmpdir, f"cache_is_expired_{dbm_backend}") + cache = PersistentCache(cache_file) + + # Test with various payloads + assert cache._is_expired({"expires_at": time.time() - 10}) is True + assert cache._is_expired({"expires_at": time.time() + 10}) is False + assert cache._is_expired({}) is False + assert cache._is_expired({"expires_at": "not a number"}) is False + assert cache._is_expired({"expires_at": None}) is False