From 97e8e1012061e746137b1c739c54a51549c4df7f Mon Sep 17 00:00:00 2001 From: Marius Graml Date: Tue, 25 Nov 2025 17:11:29 +0100 Subject: [PATCH 1/7] image-artifactsaver integration --- .../artifactsavers/image_artifactsaver.py | 84 +++++++++++++++++++ src/pruna/evaluation/artifactsavers/utils.py | 49 +++++++++++ 2 files changed, 133 insertions(+) create mode 100644 src/pruna/evaluation/artifactsavers/image_artifactsaver.py create mode 100644 src/pruna/evaluation/artifactsavers/utils.py diff --git a/src/pruna/evaluation/artifactsavers/image_artifactsaver.py b/src/pruna/evaluation/artifactsavers/image_artifactsaver.py new file mode 100644 index 00000000..eb0039f7 --- /dev/null +++ b/src/pruna/evaluation/artifactsavers/image_artifactsaver.py @@ -0,0 +1,84 @@ +# Copyright 2025 - Pruna AI GmbH. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import secrets +import time +from pathlib import Path +from typing import Any + +import numpy as np +import torch +from PIL import Image + +from pruna.evaluation.artifactsavers.artifactsaver import ArtifactSaver + +class ImageArtifactSaver(ArtifactSaver): + """ + Save image artifacts. + + Parameters + ---------- + root: Path | str | None = None + The root directory to save the artifacts. + export_format: str | None = "png" + The format to save the artifacts (e.g. "png", "jpg", "jpeg", "webp"). + """ + + export_format: str | None + root: Path | str | None + + def __init__(self, root: Path | str | None = None, export_format: str | None = "png") -> None: + self.root = Path(root) if root else Path.cwd() + (self.root / "canonical").mkdir(parents=True, exist_ok=True) + self.export_format = export_format if export_format else "png" + if self.export_format not in ["png", "jpg", "jpeg", "webp"]: + raise ValueError(f"Invalid format: {self.export_format}. Valid formats are: png, jpg, jpeg, webp.") + + def save_artifact(self, data: Any, saving_kwargs: dict = dict()) -> Path: + """ + Save the image artifact. + + Parameters + ---------- + data: Any + The data to save. + saving_kwargs: dict + The additional kwargs to pass to the saving utility function. + + Returns + ------- + Path + The path to the saved artifact. + """ + canonical_filename = f"{int(time.time())}_{secrets.token_hex(4)}.{self.export_format}" + canonical_path = Path(str(self.root)) / "canonical" / canonical_filename + + # We save the image as a PIL Image, so we need to convert the data to a PIL Image. + # Usually, the data is already a PIL.Image, so we don't need to convert it. 
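+        # Torch tensors are assumed to be channel-first (C, H, W) with values in [0, 1];
+        # they are moved to CPU, transposed to (H, W, C) and rescaled to uint8 before the PIL conversion below.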
+ if isinstance(data, torch.Tensor): + data = np.transpose(data.cpu().numpy(), (1, 2, 0)) + data = np.clip(data * 255, 0, 255).astype(np.uint8) + if isinstance(data, np.ndarray): + data = Image.fromarray(data.astype(np.uint8)) + # Now data must be a PIL Image + if not isinstance(data, Image.Image): + raise ValueError("Model outputs must be torch.Tensor, numpy.ndarray, or PIL.Image.") + + # Save the image (export format is determined by the file extension) + data.save(canonical_path, **saving_kwargs.copy()) + + return canonical_path + diff --git a/src/pruna/evaluation/artifactsavers/utils.py b/src/pruna/evaluation/artifactsavers/utils.py new file mode 100644 index 00000000..796f6e82 --- /dev/null +++ b/src/pruna/evaluation/artifactsavers/utils.py @@ -0,0 +1,49 @@ +# Copyright 2025 - Pruna AI GmbH. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from pathlib import Path + +from pruna.evaluation.artifactsavers.artifactsaver import ArtifactSaver +from pruna.evaluation.artifactsavers.video_artifactsaver import VideoArtifactSaver +from pruna.evaluation.artifactsavers.image_artifactsaver import ImageArtifactSaver + + +def assign_artifact_saver( + modality: str, root: Path | str | None = None, export_format: str | None = None +) -> ArtifactSaver: + """ + Assign the appropriate artifact saver based on the modality. + + Parameters + ---------- + modality: str + The modality of the data. + root: str + The root directory to save the artifacts. + export_format: str + The format to save the artifacts. + + Returns + ------- + ArtifactSaver + The appropriate artifact saver. + """ + if modality == "video": + return VideoArtifactSaver(root=root, export_format=export_format) + if modality == "image": + return ImageArtifactSaver(root=root, export_format=export_format) + else: + raise ValueError(f"Modality {modality} is not supported") From 8f34df237dd14f3bf12dfa12aee92e0b15d14bbb Mon Sep 17 00:00:00 2001 From: Marius Graml Date: Mon, 1 Dec 2025 17:39:08 +0100 Subject: [PATCH 2/7] Tests for image artifactsaver added, all tests passed locally. 
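
A minimal usage sketch of the saver exercised by these tests (directory and file names are
illustrative; assign_artifact_saver("image", ...) from the previous patch can be used in place
of the direct constructor):

    import numpy as np
    from pruna.evaluation.artifactsavers.image_artifactsaver import ImageArtifactSaver

    saver = ImageArtifactSaver(root="./artifacts", export_format="png")
    image = np.random.randint(0, 255, (16, 16, 3), dtype=np.uint8)
    canonical_path = saver.save_artifact(image)                  # ./artifacts/canonical/<timestamp>_<token>.png
    alias_path = saver.create_alias(canonical_path, "sample_0")  # ./artifacts/sample_0.png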
--- tests/evaluation/test_artifactsaver.py | 180 +++++++++++++++++++++++++ 1 file changed, 180 insertions(+) create mode 100644 tests/evaluation/test_artifactsaver.py diff --git a/tests/evaluation/test_artifactsaver.py b/tests/evaluation/test_artifactsaver.py new file mode 100644 index 00000000..7a765237 --- /dev/null +++ b/tests/evaluation/test_artifactsaver.py @@ -0,0 +1,180 @@ +import pytest +import tempfile +from pathlib import Path + +import numpy as np +import torch + +from pruna.evaluation.artifactsavers.video_artifactsaver import VideoArtifactSaver +from pruna.evaluation.artifactsavers.image_artifactsaver import ImageArtifactSaver +from pruna.evaluation.artifactsavers.utils import assign_artifact_saver +from pruna.evaluation.metrics.vbench_utils import load_video +from PIL import Image + + +def test_create_alias(): + """ Test that we can create an alias for an existing video and image.""" + with tempfile.TemporaryDirectory() as tmp_path: + + # --- Test video artifact saver --- + # First, we create a random video and save it. + saver = VideoArtifactSaver(root=tmp_path, export_format="mp4") + dummy_video = np.random.randint(0, 255, (10, 16, 16, 3), dtype=np.uint8) + source_filename = saver.save_artifact(dummy_video, saving_kwargs={"fps": 5}) + # Then, we create an alias for the video. + alias = saver.create_alias(source_filename, "alias_filename") + # Finally, we reload the alias and check that it is the same as the original video. + reloaded_alias_video = load_video(str(alias), return_type = "np") + assert(reloaded_alias_video.shape == dummy_video.shape) + assert alias.exists() + assert alias.name.endswith(".mp4") + + # --- Test image artifact saver --- + saver = ImageArtifactSaver(root=tmp_path, export_format="png") + dummy_image = np.random.randint(0, 255, (16, 16, 3), dtype=np.uint8) + source_filename = saver.save_artifact(dummy_image, saving_kwargs={"quality": 95}) + # Then, we create an alias for the image. + alias = saver.create_alias(source_filename, "alias_filename") + # Finally, we reload the alias and check that it is the same as the original image. + reloaded_alias_image = np.array(Image.open(str(alias))) + assert(reloaded_alias_image.shape == dummy_image.shape) + assert alias.exists() + assert alias.name.endswith(".png") + +def test_assign_all_artifact_savers(tmp_path: Path): + """ Test each artifact saver is assigned correctly.""" + saver = assign_artifact_saver("video", root=tmp_path, export_format="mp4") + assert isinstance(saver, VideoArtifactSaver) + assert saver.export_format == "mp4" + saver = assign_artifact_saver("image", root=tmp_path, export_format="png") + assert isinstance(saver, ImageArtifactSaver) + assert saver.export_format == "png" + +def test_assign_artifact_saver_invalid(): + """ Test that we raise an error if the artifact saver is assigned incorrectly.""" + with pytest.raises(ValueError): + assign_artifact_saver("nonexistent_modality") + +@pytest.mark.parametrize( + "export_format, save_from_type, save_from_dtype", + [pytest.param("gif", "np", "uint8"), + pytest.param("gif", "np", "float32"), + # Numpy doesn't have half precision, so we do not test for float16 + pytest.param("gif", "pt", "float32"), + pytest.param("gif", "pt", "float16"), + pytest.param("gif", "pt", "uint8"), + # PIL doesnot support creating images from float numpy arrays, so we only test uint8. 
+ pytest.param("gif", "pil", "uint8"), + pytest.param("mp4", "np", "uint8"), + pytest.param("mp4", "np", "float32"), + pytest.param("mp4", "pt", "float32"), + pytest.param("mp4", "pt", "float16"), + pytest.param("mp4", "pt", "uint8"), + # PIL doesnot support creating images from float numpy arrays, so we only test uint8. + pytest.param("mp4", "pil", "uint8"),] +) +def test_video_artifact_saver_tensor(export_format: str, save_from_type: str, save_from_dtype: str): + """ Test that we can save a video from numpy, torch and PIL in mp4 and gif formats. """ + with tempfile.TemporaryDirectory() as tmp_path: + saver = VideoArtifactSaver(root=tmp_path, export_format=export_format) + # create a fake video: + if save_from_type == "pt": + # Unfortunately, neither torch nor numpy have one random generator function that can support all dtypes. + # Therefore, we need to use different functions for int and float dtypes. + if save_from_dtype == "uint8": + dtype = getattr(torch, save_from_dtype) + dummy_video = torch.randint(0, 256, (2, 3, 16, 16), dtype=dtype) + else: + dtype = getattr(torch, save_from_dtype) + dummy_video = torch.randn(2, 3, 16, 16, dtype=dtype) + elif save_from_type == "np": + if save_from_dtype == "uint8": + dtype = getattr(np, save_from_dtype) + dummy_video = np.random.randint(0, 256, (2, 16, 16, 3), dtype=dtype) + else: + rng = np.random.default_rng() + dtype = getattr(np, save_from_dtype) + dummy_video = rng.random((2, 16, 16, 3), dtype=dtype) + elif save_from_type == "pil": + dtype = getattr(np, save_from_dtype) + dummy_video = np.random.randint(0, 256, (2, 16, 16, 3), dtype=dtype) + dummy_video = [Image.fromarray(frame.astype(np.uint8)) for frame in dummy_video] + path = saver.save_artifact(dummy_video) + assert path.exists() + assert path.suffix == f".{export_format}" + +@pytest.mark.parametrize( + "export_format, save_from_type, save_from_dtype", + [ + # --- Test png format --- + # numpy + pytest.param("png", "np", "uint8"), + pytest.param("png", "np", "float32"), + # torch + pytest.param("png", "pt", "float32"), + pytest.param("png", "pt", "float16"), + pytest.param("png", "pt", "uint8"), + # PIL + pytest.param("png", "pil", "uint8"), + # --- Test jpg format --- + # numpy + pytest.param("jpg", "np", "uint8"), + pytest.param("jpg", "np", "float32"), + # torch + pytest.param("jpg", "pt", "float32"), + pytest.param("jpg", "pt", "float16"), + pytest.param("jpg", "pt", "uint8"), + # PIL + pytest.param("jpg", "pil", "uint8"), + # --- Test webp format --- + # numpy + pytest.param("webp", "np", "uint8"), + pytest.param("webp", "np", "float32"), + # torch + pytest.param("webp", "pt", "float32"), + pytest.param("webp", "pt", "float16"), + pytest.param("webp", "pt", "uint8"), + # PIL + pytest.param("webp", "pil", "uint8"), + # --- Test jpeg format --- + # numpy + pytest.param("jpeg", "np", "uint8"), + pytest.param("jpeg", "np", "float32"), + # torch + pytest.param("jpeg", "pt", "float32"), + pytest.param("jpeg", "pt", "float16"), + pytest.param("jpeg", "pt", "uint8"), + # PIL + pytest.param("jpeg", "pil", "uint8"), + ] + ) +def test_image_artifact_saver_tensor(export_format: str, save_from_type: str, save_from_dtype: str): + """ Test that we can save an image from a tensor.""" + with tempfile.TemporaryDirectory() as tmp_path: + saver = ImageArtifactSaver(root=tmp_path, export_format=export_format) + # Create fake image: + if save_from_type == "pt": + # Note: torch convention is (C, H, W) + if save_from_dtype == "uint8": + dtype = getattr(torch, save_from_dtype) + dummy_image = 
torch.randint(0, 256, (3, 16, 16), dtype=dtype) + else: + dtype = getattr(torch, save_from_dtype) + dummy_image = torch.randn(3, 16, 16, dtype=dtype) + elif save_from_type == "np": + # Note: Numpy arrays as images follow the convention (H, W, C) + if save_from_dtype == "uint8": + dtype = getattr(np, save_from_dtype) + dummy_image = np.random.randint(0, 256, (16, 16, 3), dtype=dtype) + else: + rng = np.random.default_rng() + dtype = getattr(np, save_from_dtype) + dummy_image = rng.random((16, 16, 3), dtype=dtype) + elif save_from_type == "pil": + # Note: PIL images by default have shape (H, W, C) and are usually uint8 (standard for ".jpg", etc.) + dtype = getattr(np, save_from_dtype) + dummy_image = np.random.randint(0, 256, (16, 16, 3), dtype=dtype) + dummy_image = Image.fromarray(dummy_image.astype(np.uint8)) + path = saver.save_artifact(dummy_image) + assert path.exists() + assert path.suffix == f".{export_format}" \ No newline at end of file From 759436330f36a21893d6f46768ee274ef71abcf1 Mon Sep 17 00:00:00 2001 From: Marius Graml Date: Tue, 2 Dec 2025 15:02:32 +0100 Subject: [PATCH 3/7] Change formatting for passing test --- src/pruna/evaluation/artifactsavers/image_artifactsaver.py | 5 ++--- src/pruna/evaluation/artifactsavers/utils.py | 1 + 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/pruna/evaluation/artifactsavers/image_artifactsaver.py b/src/pruna/evaluation/artifactsavers/image_artifactsaver.py index eb0039f7..08385e96 100644 --- a/src/pruna/evaluation/artifactsavers/image_artifactsaver.py +++ b/src/pruna/evaluation/artifactsavers/image_artifactsaver.py @@ -22,9 +22,9 @@ import numpy as np import torch from PIL import Image - from pruna.evaluation.artifactsavers.artifactsaver import ArtifactSaver - + + class ImageArtifactSaver(ArtifactSaver): """ Save image artifacts. 
@@ -81,4 +81,3 @@ def save_artifact(self, data: Any, saving_kwargs: dict = dict()) -> Path: data.save(canonical_path, **saving_kwargs.copy()) return canonical_path - diff --git a/src/pruna/evaluation/artifactsavers/utils.py b/src/pruna/evaluation/artifactsavers/utils.py index 796f6e82..f30f886c 100644 --- a/src/pruna/evaluation/artifactsavers/utils.py +++ b/src/pruna/evaluation/artifactsavers/utils.py @@ -18,6 +18,7 @@ from pruna.evaluation.artifactsavers.artifactsaver import ArtifactSaver from pruna.evaluation.artifactsavers.video_artifactsaver import VideoArtifactSaver + from pruna.evaluation.artifactsavers.image_artifactsaver import ImageArtifactSaver From 07a3cae2d09a6c6e41863c849256e1042bb78ce7 Mon Sep 17 00:00:00 2001 From: Marius Graml Date: Tue, 2 Dec 2025 15:16:05 +0100 Subject: [PATCH 4/7] Change file format for passing tests --- src/pruna/evaluation/artifactsavers/image_artifactsaver.py | 1 + src/pruna/evaluation/artifactsavers/utils.py | 3 +-- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/pruna/evaluation/artifactsavers/image_artifactsaver.py b/src/pruna/evaluation/artifactsavers/image_artifactsaver.py index 08385e96..4d1c47f8 100644 --- a/src/pruna/evaluation/artifactsavers/image_artifactsaver.py +++ b/src/pruna/evaluation/artifactsavers/image_artifactsaver.py @@ -22,6 +22,7 @@ import numpy as np import torch from PIL import Image + from pruna.evaluation.artifactsavers.artifactsaver import ArtifactSaver diff --git a/src/pruna/evaluation/artifactsavers/utils.py b/src/pruna/evaluation/artifactsavers/utils.py index f30f886c..ef903b93 100644 --- a/src/pruna/evaluation/artifactsavers/utils.py +++ b/src/pruna/evaluation/artifactsavers/utils.py @@ -17,9 +17,8 @@ from pathlib import Path from pruna.evaluation.artifactsavers.artifactsaver import ArtifactSaver -from pruna.evaluation.artifactsavers.video_artifactsaver import VideoArtifactSaver - from pruna.evaluation.artifactsavers.image_artifactsaver import ImageArtifactSaver +from pruna.evaluation.artifactsavers.video_artifactsaver import VideoArtifactSaver def assign_artifact_saver( From 070108e84ab8b7cc9a8747c6d058b575216d98b2 Mon Sep 17 00:00:00 2001 From: Marius Graml Date: Tue, 2 Dec 2025 15:17:07 +0100 Subject: [PATCH 5/7] Add artifactsaver and video_artifactsaver from vbench-integration branch as image_artifactsaver depends on these files --- .../artifactsavers/artifactsaver.py | 98 +++++++++++++++++++ .../artifactsavers/video_artifactsaver.py | 82 ++++++++++++++++ 2 files changed, 180 insertions(+) create mode 100644 src/pruna/evaluation/artifactsavers/artifactsaver.py create mode 100644 src/pruna/evaluation/artifactsavers/video_artifactsaver.py diff --git a/src/pruna/evaluation/artifactsavers/artifactsaver.py b/src/pruna/evaluation/artifactsavers/artifactsaver.py new file mode 100644 index 00000000..51bbb3d5 --- /dev/null +++ b/src/pruna/evaluation/artifactsavers/artifactsaver.py @@ -0,0 +1,98 @@ +# Copyright 2025 - Pruna AI GmbH. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +from abc import ABC, abstractmethod +from pathlib import Path +from typing import Any + + +class ArtifactSaver(ABC): + """ + Abstract class for artifact savers. + + The artifact saver is responsible for saving the inference outputs during evaluation. + + There needs to be a subclass for each metric modality (e.g. video, image, text, etc.). + + Parameters + ---------- + export_format: str | None + The format to export the artifacts in. + root: Path | str | None + The root directory to save the artifacts in. + """ + + export_format: str | None = None + root: Path | str | None = None + + @abstractmethod + def save_artifact(self, data: Any) -> Path: + """ + Implement this method to save the artifact. + + Parameters + ---------- + data: Any + The data to save. + + Returns + ------- + Path + The full path to the saved artifact. + """ + pass + + def create_alias(self, source_path: Path | str, filename: str) -> Path: + """ + Create an alias for the artifact. + + The evaluation agent will save the inference outputs with a canonical file + formatting style that makes sense for the general case. + + If your metric requires a different file naming convention for evaluation, + you can use this method to create an alias for the artifact. + + This way we prevent duplicate artifacts from being saved and save storage space. + + By default, the alias will be created as a hardlink to the source artifact. + If the hardlink fails, a symlink will be created. + + Parameters + ---------- + source_path : Path | str + The path to the source artifact. + filename : str + The filename to create the alias for. + + Returns + ------- + Path + The full path to the alias. + """ + alias = Path(str(self.root)) / f"{filename}.{self.export_format}" + alias.parent.mkdir(parents=True, exist_ok=True) + try: + if alias.exists(): + alias.unlink() + alias.hardlink_to(source_path) + except Exception: + try: + if alias.exists(): + alias.unlink() + alias.symlink_to(source_path) + except Exception as e: + raise e + return alias diff --git a/src/pruna/evaluation/artifactsavers/video_artifactsaver.py b/src/pruna/evaluation/artifactsavers/video_artifactsaver.py new file mode 100644 index 00000000..4acae51c --- /dev/null +++ b/src/pruna/evaluation/artifactsavers/video_artifactsaver.py @@ -0,0 +1,82 @@ +# Copyright 2025 - Pruna AI GmbH. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import secrets +import time +from pathlib import Path +from typing import Any + +import numpy as np +import torch +from diffusers.utils import export_to_gif, export_to_video +from PIL import Image + +from pruna.evaluation.artifactsavers.artifactsaver import ArtifactSaver + + +class VideoArtifactSaver(ArtifactSaver): + """ + Save video artifacts. + + Parameters + ---------- + root: Path | str | None = None + The root directory to save the artifacts. + export_format: str | None = "mp4" + The format to save the artifacts. 
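+        Supported formats are "mp4" and "gif".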
+ """ + + export_format: str | None + root: Path | str | None + + def __init__(self, root: Path | str | None = None, export_format: str | None = "mp4") -> None: + self.root = Path(root) if root else Path.cwd() + (self.root / "canonical").mkdir(parents=True, exist_ok=True) + self.export_format = export_format if export_format else "mp4" + + def save_artifact(self, data: Any, saving_kwargs: dict = dict()) -> Path: + """ + Save the video artifact. + + Parameters + ---------- + data: Any + The data to save. + saving_kwargs: dict + The additional kwargs to pass to the saving utility function. + + Returns + ------- + Path + The path to the saved artifact. + """ + canonical_filename = f"{int(time.time())}_{secrets.token_hex(4)}.{self.export_format}" + canonical_path = Path(str(self.root)) / "canonical" / canonical_filename + + # all diffusers saving utility functions accept a list of PIL.Images, so we convert to PIL to be safe. + if isinstance(data, torch.Tensor): + data = np.transpose(data.cpu().numpy(), (0, 2, 3, 1)) + data = np.clip(data * 255, 0, 255).astype(np.uint8) + if isinstance(data, np.ndarray): + data = [Image.fromarray(frame.astype(np.uint8)) for frame in data] + + if self.export_format == "mp4": + export_to_video(data, str(canonical_path), **saving_kwargs.copy()) + elif self.export_format == "gif": + export_to_gif(data, str(canonical_path), **saving_kwargs.copy()) + else: + raise ValueError(f"Invalid format: {self.export_format}") + return canonical_path From 14cae0ffa2fe5981c363ee9a4986669c4edec585 Mon Sep 17 00:00:00 2001 From: Marius Graml Date: Tue, 2 Dec 2025 15:52:57 +0100 Subject: [PATCH 6/7] Change file format to pass test --- src/pruna/evaluation/metrics/vbench_utils.py | 511 +++++++++++++++++++ 1 file changed, 511 insertions(+) create mode 100644 src/pruna/evaluation/metrics/vbench_utils.py diff --git a/src/pruna/evaluation/metrics/vbench_utils.py b/src/pruna/evaluation/metrics/vbench_utils.py new file mode 100644 index 00000000..cf1ab085 --- /dev/null +++ b/src/pruna/evaluation/metrics/vbench_utils.py @@ -0,0 +1,511 @@ +# Copyright 2025 - Pruna AI GmbH. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import hashlib +import re +from pathlib import Path +from typing import Any, Callable, Iterable, List + +import numpy as np +import torch +from diffusers.utils import export_to_gif, export_to_video +from diffusers.utils import load_video as diffusers_load_video +from PIL.Image import Image +from torchvision.transforms import ToTensor + +from pruna.data.pruna_datamodule import PrunaDataModule +from pruna.data.utils import define_sample_size_for_dataset, stratify_dataset +from pruna.engine.utils import safe_memory_cleanup, set_to_best_available_device +from pruna.evaluation.metrics.metric_stateful import StatefulMetric +from pruna.evaluation.metrics.result import MetricResult +from pruna.logging.logger import pruna_logger + + +class VBenchMixin: + """ + Mixin class for VBench metrics. 
+ + Handles benchmark specific initilizations and artifact saving conventions. + """ + + def create_filename(self, prompt: str, idx: int, file_extension: str, special_str: str = "") -> str: + """ + Create filename according to VBench formatting conventions. + + Parameters + ---------- + prompt : str + The prompt to create the filename from. + idx : int + The index of the video. Vbench uses 5 seeds for each prompt. + file_extension : str + The file extension to use. Vbench supports mp4 and gif. + special_str : str + A special string to add to the filename if you wish to add a specific identifier. + + Returns + ------- + str + The filename. + """ + return create_vbench_file_name(sanitize_prompt(prompt), idx, special_str, file_extension) + + def validate_batch(self, batch: torch.Tensor) -> torch.Tensor: + """ + Make sure that the video tensor has correct dimensions. + + Parameters + ---------- + batch : torch.Tensor + The video tensor. + + Returns + ------- + torch.Tensor + The video tensor. + """ + if batch.ndim == 4: + return batch.unsqueeze(0) + elif batch.ndim != 5: + raise ValueError(f"Batch must be 4 or 5 dimensional video tensor with B,T,C,H,W, got {batch.ndim}") + return batch + + +def get_sample_seed(experiment_name: str, prompt: str, index: int) -> int: + """Get a sample seed for a given experiment name, prompt, and index.""" + key = f"{experiment_name}_{prompt}_{index}".encode('utf-8') + + return int(hashlib.sha256(key).hexdigest(), 16) % (2**32) + + +def is_file_exists(path: str | Path, filename: str) -> bool: + """Return True if the file with the given filename exists under the provided path.""" + folder = Path(path) + full_path = folder / filename + + return full_path.is_file() + + +def load_video(path: str | Path, return_type: str = "pt") -> List[Image] | np.ndarray | torch.Tensor: + """ + Load videos from a path. + + Parameters + ---------- + path : str | Path + The path to the videos. + return_type : str + The type to return the videos as. Can be "pt", "np", "pil". + + Returns + ------- + List[torch.Tensor] + The videos. + """ + video = diffusers_load_video(str(path)) + if return_type == "pt": + return torch.stack([ToTensor()(frame) for frame in video]) + elif return_type == "np": + return np.stack([np.array(frame) for frame in video]) + elif return_type == "pil": + return video + else: + raise ValueError(f"Invalid return_type: {return_type}. Use 'pt', 'np', or 'pil'.") + + +def load_videos_from_path(path: str | Path) -> torch.Tensor: + """ + Load entire directory of mp4 videos as a single tensor ready to be passed to evaluation. + + Parameters + ---------- + path : str | Path + The path to the directory of videos. + + Returns + ------- + torch.Tensor + The videos. + """ + path = Path(str(path)) + videos = torch.stack([load_video(p) for p in path.glob("*.mp4")]) + return videos + + +def sanitize_prompt(prompt: str) -> str: + """ + Return a filesystem-safe version of a prompt. + + Replaces characters illegal in filenames and collapses whitespace so that + generated files are portable across file systems. + + Parameters + ---------- + prompt : str + The prompt to sanitize. + + Returns + ------- + str + The sanitized prompt. 
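+
+    Examples
+    --------
+    >>> sanitize_prompt('a cat: "jumping" <over> a fence')
+    'a cat jumping over a fence'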
+ """ + prompt = re.sub(r"[\\/:*?\"<>|]", " ", prompt) # remove illegal chars + prompt = re.sub(r"\s+", " ", prompt) # collapse multiple spaces + prompt = prompt.strip() # remove leading and trailing whitespace + return prompt + + +def prepare_batch(batch: str | tuple[str | List[str], Any]) -> str: + """ + Prepare the batch to be used in the generate_videos function. + + Pruna datamodules are expected to yield tuples where the first element is + a sequence of inputs; this utility enforces batch_size == 1 for simplicity. + + Parameters + ---------- + batch : str | tuple[str | List[str], Any] + The batch to prepare. + + Returns + ------- + str + The prompt string. + """ + if isinstance(batch, str): + return batch + # for pruna datamodule. always returns a tuple where the first element is the input (list of prompts) to the model. + elif isinstance(batch, tuple): + if not hasattr(batch[0], "__len__"): + raise ValueError(f"Batch[0] is not a sequence (got {type(batch[0])})") + if len(batch[0]) != 1: + raise ValueError(f"Only batch size 1 is supported; got {len(batch[0])}") + return batch[0][0] + else: + raise ValueError(f"Invalid batch type: {type(batch)}") + + +def _normalize_save_format(save_format: str) -> tuple[str, Callable]: + """ + Normalize the save format to be used in the generate_videos function. + + Parameters + ---------- + save_format : str + The format to save the videos in. VBench supports mp4 and gif. + + Returns + ------- + tuple[str, Callable] + The normalized save format and the save function. + """ + save_format = save_format.lower().strip() + if save_format == "mp4": + return ".mp4", export_to_video + if save_format == "gif": + return ".gif", export_to_gif + raise ValueError(f"Invalid save_format: {save_format}. Use 'mp4' or 'gif'.") + + +def _normalize_prompts( + prompts: str | List[str] | PrunaDataModule, split: str = "test", batch_size: int = 1, num_samples: int | None = None, + fraction: float = 1.0, data_partition_strategy: str = "indexed", partition_index: int = 0, seed: int = 42 +) -> Iterable[str]: + """ + Normalize prompts to an iterable format to be used in the generate_videos function. + + Parameters + ---------- + prompts : str | List[str] | PrunaDataModule + The prompts to normalize. + split : str + The dataset split to sample from. + batch_size : int + The batch size to sample from. + num_samples : int | None + The number of samples to sample from. + fraction : float + The fraction of the dataset to sample from. + data_partition_strategy : str + The strategy to use for partitioning the dataset. Can be "indexed" or "random". + partition_index : int + The index to use for partitioning the dataset. + seed : int + The seed to use for partitioning the dataset. + + Returns + ------- + Iterable[str] + The normalized prompts. + """ + if isinstance(prompts, str): + return [prompts] + elif isinstance(prompts, PrunaDataModule): + target_dataset = getattr(prompts, f"{split}_dataset") + sample_size = define_sample_size_for_dataset(target_dataset, fraction, num_samples) + setattr(prompts, f"{split}_dataset", stratify_dataset(target_dataset, sample_size, seed, data_partition_strategy, + partition_index)) + return getattr(prompts, f"{split}_dataloader")(batch_size=batch_size) + else: # list of prompts, already iterable + return prompts + + +def _ensure_dir(p: Path) -> None: + """ + Ensure the directory exists. + + Parameters + ---------- + p : Path + The path to ensure the directory exists. 
+ """ + p.mkdir(parents=True, exist_ok=True) + + +def create_vbench_file_name( + prompt: str, idx: int, special_str: str = "", save_format: str = ".mp4", max_filename_length: int = 255 +) -> str: + """ + Create a file name for the video in accordance with the VBench format. + + Parameters + ---------- + prompt : str + The prompt to create the file name from. + idx : int + The index of the video. Vbench uses 5 seeds for each prompt. + special_str : str + A special string to add to the file name if you wish to add a specific identifier. + save_format : str + The format of the video file. Vbench supports mp4 and gif. + max_filename_length : int + The maximum length allowed for the file name. + + Returns + ------- + str + The file name for the video. + """ + filename = f"{prompt}{special_str}-{str(idx)}{save_format}" + if len(filename) > max_filename_length: + pruna_logger.debug( + f"File name {filename} is too long. Maximum length is {max_filename_length} characters. Truncating filename." + ) + filename = filename[:max_filename_length] + return filename + + +def sample_video_from_pipelines(pipeline: Any, seeder: Any, prompt: str, **kwargs): + """ + Sample a video from diffusers pipeline. + + Parameters + ---------- + pipeline : Any + The pipeline to sample from. + seeder : Any + The seeding generator. + prompt : str + The prompt to sample from. + **kwargs : Any + Additional keyword arguments to pass to the pipeline. + + Returns + ------- + torch.Tensor + The video tensor. + """ + is_return_dict = kwargs.pop("return_dict", True) + with torch.inference_mode(): + if is_return_dict: + out = pipeline(prompt=prompt, generator=seeder, **kwargs).frames[0] + else: + # If return_dict is False, the pipeline returns a tuple of (frames, metadata). + out = pipeline(prompt=prompt, generator=seeder, **kwargs)[0] + + return out + + +def _wrap_sampler(model: Any, sampling_fn: Callable[..., Any]) -> Callable[..., Any]: + """ + Wrap a user-provided sampling function into a uniform callable. + + The returned callable has a keyword-only signature: + sampler(*, prompt: str, seeder: Any, device: str|torch.device, **kwargs) + + This wrapper always passes `model` as the first positional argument, so + custom functions can name their first parameter `model` or `pipeline`, etc. + + Parameters + ---------- + model : Any + The model to sample from. + sampling_fn : Callable[..., Any] + The sampling function to wrap. + + Returns + ------- + Callable[..., Any] + The wrapped sampling function. + """ + if sampling_fn != sample_video_from_pipelines: + pruna_logger.info( + "Using custom sampling function. Ensure it accepts (model, *, prompt, seeder, device, **kwargs)." + ) + + # The sampling function may expect the model as "pipeline" so we pass it as an arg and not a kwarg. 
+ def sampler(*, prompt: str, seeder: Any, **kwargs: Any) -> Any: + return sampling_fn(model, prompt=prompt, seeder=seeder, **kwargs) + + return sampler + + +def generate_videos( + model: Any, + prompts: str | List[str] | PrunaDataModule, + num_samples: int | None = None, + samples_fraction: float = 1.0, + data_partition_strategy: str = "indexed", + partition_index: int = 0, + split: str = "test", + unique_sample_per_video_count: int = 1, + global_seed: int = 42, + sampling_fn: Callable[..., Any] = sample_video_from_pipelines, + fps: int = 16, + save_dir: str | Path = "./saved_videos", + save_format: str = "mp4", + filename_fn: Callable = create_vbench_file_name, + special_str: str = "", + device: str | torch.device = None, + experiment_name: str = "", + sampling_seed_fn: Callable[..., Any] = get_sample_seed, + **model_kwargs, +) -> None: + """ + Generate N samples per prompt and save them to disk with seed tracking. + + This function: + - Normalizes prompts (string, list, or datamodule). + - Uses an RNG seeded with `global_seed` for reproducibility across runs. + - Saves videos as MP4 or GIF. + + Parameters + ---------- + model : Any + The model to sample from. + prompts : str | List[str] | PrunaDataModule + The prompts to sample from. + split : str + The split to sample from. + Default is "test" since most benchmarking datamodules in Pruna are configured to use the test split. + unique_sample_per_video_count : int + The number of unique samples per video. Default is 5 by VBench requirements. + global_seed : int + The global seed to sample from. + sampling_fn : Callable[..., Any] + The sampling function to use. + fps : int + The frames per second of the video. + save_dir : str | Path + The directory to save the videos to. + save_format : str + The format to save the videos in. VBench supports mp4 and gif. + filename_fn : Callable + The function to create the file name. + special_str : str + A special string to add to the file name if you wish to add a specific identifier. + device : str | torch.device | None + The device to sample on. If None, the best available device will be used. + **model_kwargs : Any + Additional keyword arguments to pass to the sampling function. 
+ """ + file_extension, save_fn = _normalize_save_format(save_format) + + device = set_to_best_available_device(device) + + prompt_iterable = _normalize_prompts( + prompts, + split, + batch_size=1, + num_samples=num_samples, + fraction=samples_fraction, + data_partition_strategy=data_partition_strategy, + partition_index=partition_index, + seed=global_seed, +) + + save_dir = Path(save_dir) + _ensure_dir(save_dir) + + # set a run-level seed (VBench suggests this) (important for reproducibility) + def _seed_rng(x: int) -> torch.Generator: + """Create a CPU torch.Generator seeded with the given integer.""" + return torch.Generator("cpu").manual_seed(x) + + sampler = _wrap_sampler(model=model, sampling_fn=sampling_fn) + + for batch in prompt_iterable: + prompt = prepare_batch(batch) + for idx in range(unique_sample_per_video_count): + file_name = filename_fn(sanitize_prompt(prompt), idx, special_str, file_extension) + out_path = save_dir / file_name + + if is_file_exists(save_dir, file_name): + continue + else: + seed = sampling_seed_fn(experiment_name, prompt, idx) + vid = sampler(prompt=prompt, seeder=_seed_rng(seed), **model_kwargs) + save_fn(vid, out_path, fps=fps) + + del vid + safe_memory_cleanup() + + +def evaluate_videos( + data: Any, metrics: StatefulMetric | List[StatefulMetric], prompts: Any | None = None +) -> List[MetricResult]: + """ + Evaluation loop helper. + + Parameters + ---------- + data : Any + The data to evaluate. + metrics : StatefulMetric | List[StatefulMetric] + The metrics to evaluate. + prompts : Any | None + The prompts to evaluate. + + Returns + ------- + List[MetricResult] + The results of the evaluation. + """ + results = [] + if isinstance(metrics, StatefulMetric): + metrics = [metrics] + if any(metric.call_type != "y" for metric in metrics) and prompts is None: + raise ValueError( + "You are trying to evaluate metrics that require more than the outputs, but didn't provide prompts." + ) + for metric in metrics: + for batch in data: + if prompts is None: + prompts = batch + metric.update(prompts, batch, batch) + prompts = None + results.append(metric.compute()) + return results From 737947e8c774a340a874840ff7abb9da0179f6f6 Mon Sep 17 00:00:00 2001 From: Marius Graml Date: Wed, 3 Dec 2025 15:00:47 +0100 Subject: [PATCH 7/7] Added ignore to pyproject.toml --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 3aca3631..751b8715 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,7 +35,7 @@ no-matching-overload = "ignore" # mypy is more permissive with overloads unresolved-reference = "ignore" # mypy is more permissive with references possibly-unbound-import = "ignore" missing-argument = "ignore" - +possibly-unbound-attribute = "ignore" [tool.coverage.run] source = ["src/pruna"]