-
Notifications
You must be signed in to change notification settings - Fork 75
feat: Image artifactsaver #449
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
97e8e10
8f34df2
7594363
07a3cae
070108e
14cae0f
737947e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,98 @@ | ||
| # Copyright 2025 - Pruna AI GmbH. All rights reserved. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| from abc import ABC, abstractmethod | ||
| from pathlib import Path | ||
| from typing import Any | ||
|
|
||
|
|
||
class ArtifactSaver(ABC):
    """
    Abstract class for artifact savers.

    The artifact saver is responsible for saving the inference outputs during evaluation.

    There needs to be a subclass for each metric modality (e.g. video, image, text, etc.).

    Parameters
    ----------
    export_format: str | None
        The format to export the artifacts in.
    root: Path | str | None
        The root directory to save the artifacts in.
    """

    export_format: str | None = None
    root: Path | str | None = None

    @abstractmethod
    def save_artifact(self, data: Any) -> Path:
        """
        Implement this method to save the artifact.

        Parameters
        ----------
        data: Any
            The data to save.

        Returns
        -------
        Path
            The full path to the saved artifact.
        """
        pass

    def create_alias(self, source_path: Path | str, filename: str) -> Path:
        """
        Create an alias for the artifact.

        The evaluation agent will save the inference outputs with a canonical file
        formatting style that makes sense for the general case.

        If your metric requires a different file naming convention for evaluation,
        you can use this method to create an alias for the artifact.

        This way we prevent duplicate artifacts from being saved and save storage space.

        By default, the alias will be created as a hardlink to the source artifact.
        If the hardlink fails, a symlink to the absolute source path will be created.
        Anything already present at the alias path (including a dangling symlink)
        is replaced.

        Parameters
        ----------
        source_path : Path | str
            The path to the source artifact.
        filename : str
            The filename (without extension) to create the alias for.

        Returns
        -------
        Path
            The full path to the alias.

        Raises
        ------
        ValueError
            If ``root`` or ``export_format`` has not been set on the saver.
        OSError
            If neither a hardlink nor a symlink could be created.
        """
        if self.root is None or self.export_format is None:
            # Without this guard the alias would land in a directory literally named "None".
            raise ValueError("Both root and export_format must be set before creating an alias.")

        alias = Path(self.root) / f"{filename}.{self.export_format}"
        alias.parent.mkdir(parents=True, exist_ok=True)
        try:
            # unlink(missing_ok=True) also removes dangling symlinks, which exists() does not report.
            alias.unlink(missing_ok=True)
            alias.hardlink_to(source_path)
        except (OSError, AttributeError, NotImplementedError):
            # Hardlinks can fail across filesystems, on platforms without support,
            # or on Python versions where Path.hardlink_to is unavailable.
            alias.unlink(missing_ok=True)
            # A symlink target is interpreted relative to the link's own directory,
            # so store an absolute path to keep relative sources working.
            alias.symlink_to(Path(source_path).resolve())
        return alias
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,84 @@ | ||
| # Copyright 2025 - Pruna AI GmbH. All rights reserved. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| import secrets | ||
| import time | ||
| from pathlib import Path | ||
| from typing import Any | ||
|
|
||
| import numpy as np | ||
| import torch | ||
| from PIL import Image | ||
|
|
||
| from pruna.evaluation.artifactsavers.artifactsaver import ArtifactSaver | ||
|
|
||
|
|
||
class ImageArtifactSaver(ArtifactSaver):
    """
    Save image artifacts.

    Parameters
    ----------
    root: Path | str | None = None
        The root directory to save the artifacts. Falls back to the current
        working directory when not given.
    export_format: str | None = "png"
        The format to save the artifacts (e.g. "png", "jpg", "jpeg", "webp").
    """

    export_format: str | None
    root: Path | str | None

    def __init__(self, root: Path | str | None = None, export_format: str | None = "png") -> None:
        self.root = Path(root) if root else Path.cwd()
        (self.root / "canonical").mkdir(parents=True, exist_ok=True)
        self.export_format = export_format if export_format else "png"
        if self.export_format not in ["png", "jpg", "jpeg", "webp"]:
            raise ValueError(f"Invalid format: {self.export_format}. Valid formats are: png, jpg, jpeg, webp.")

    def save_artifact(self, data: Any, saving_kwargs: dict | None = None) -> Path:
        """
        Save the image artifact.

        Parameters
        ----------
        data: Any
            The data to save: a ``torch.Tensor``, a ``numpy.ndarray`` or a ``PIL.Image``.
            Floating point tensors are scaled from [0, 1] to [0, 255]; integer tensors
            and arrays are taken as pixel values as-is. Floating point arrays whose
            maximum is at most 1.0 are treated as normalized and scaled as well.
        saving_kwargs: dict | None
            The additional kwargs to pass to the saving utility function.

        Returns
        -------
        Path
            The path to the saved artifact.

        Raises
        ------
        ValueError
            If the data is not a torch.Tensor, numpy.ndarray, or PIL.Image.
        """
        canonical_filename = f"{int(time.time())}_{secrets.token_hex(4)}.{self.export_format}"
        canonical_path = Path(str(self.root)) / "canonical" / canonical_filename

        # We save the image as a PIL Image, so we need to convert the data to a PIL Image.
        # Usually, the data is already a PIL.Image, so we don't need to convert it.
        if isinstance(data, torch.Tensor):
            # NOTE(review): the transpose assumes a channel-first (C, H, W) tensor - confirm with callers.
            pixels = np.transpose(data.detach().cpu().numpy(), (1, 2, 0))
            if data.is_floating_point():
                # Only float tensors are normalized; scaling integer tensors would corrupt them.
                pixels = np.clip(pixels * 255, 0, 255).astype(np.uint8)
            data = pixels
        if isinstance(data, np.ndarray):
            if np.issubdtype(data.dtype, np.floating):
                # Heuristic: a float array that never exceeds 1.0 is normalized, otherwise
                # it already holds 0-255 values. A bare uint8 cast would truncate
                # normalized data to 0/1.
                if data.size and float(data.max()) <= 1.0:
                    data = data * 255
                data = np.clip(data, 0, 255)
            data = Image.fromarray(data.astype(np.uint8))

        # Now data must be a PIL Image
        if not isinstance(data, Image.Image):
            raise ValueError("Model outputs must be torch.Tensor, numpy.ndarray, or PIL.Image.")

        # Save the image (export format is determined by the file extension)
        data.save(canonical_path, **(saving_kwargs or {}))

        return canonical_path
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,49 @@ | ||
| # Copyright 2025 - Pruna AI GmbH. All rights reserved. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| from pathlib import Path | ||
|
|
||
| from pruna.evaluation.artifactsavers.artifactsaver import ArtifactSaver | ||
| from pruna.evaluation.artifactsavers.image_artifactsaver import ImageArtifactSaver | ||
| from pruna.evaluation.artifactsavers.video_artifactsaver import VideoArtifactSaver | ||
|
|
||
|
|
||
def assign_artifact_saver(
    modality: str, root: Path | str | None = None, export_format: str | None = None
) -> ArtifactSaver:
    """
    Build the artifact saver that matches the given modality.

    Parameters
    ----------
    modality: str
        The modality of the data ("video" or "image").
    root: Path | str | None
        The root directory to save the artifacts, forwarded to the saver.
    export_format: str | None
        The format to save the artifacts, forwarded to the saver.

    Returns
    -------
    ArtifactSaver
        The appropriate artifact saver.

    Raises
    ------
    ValueError
        If the modality is neither "video" nor "image".
    """
    # Pick the saver class first so the construction happens in exactly one place.
    if modality == "video":
        saver_cls = VideoArtifactSaver
    elif modality == "image":
        saver_cls = ImageArtifactSaver
    else:
        raise ValueError(f"Modality {modality} is not supported")
    return saver_cls(root=root, export_format=export_format)
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,82 @@ | ||
| # Copyright 2025 - Pruna AI GmbH. All rights reserved. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| import secrets | ||
| import time | ||
| from pathlib import Path | ||
| from typing import Any | ||
|
|
||
| import numpy as np | ||
| import torch | ||
| from diffusers.utils import export_to_gif, export_to_video | ||
| from PIL import Image | ||
|
|
||
| from pruna.evaluation.artifactsavers.artifactsaver import ArtifactSaver | ||
|
|
||
|
|
||
class VideoArtifactSaver(ArtifactSaver):
    """
    Save video artifacts.

    Parameters
    ----------
    root: Path | str | None = None
        The root directory to save the artifacts. Falls back to the current
        working directory when not given.
    export_format: str | None = "mp4"
        The format to save the artifacts ("mp4" or "gif").
    """

    export_format: str | None
    root: Path | str | None

    def __init__(self, root: Path | str | None = None, export_format: str | None = "mp4") -> None:
        self.export_format = export_format if export_format else "mp4"
        # Fail fast on unsupported formats instead of only at the first save.
        if self.export_format not in ("mp4", "gif"):
            raise ValueError(f"Invalid format: {self.export_format}")
        self.root = Path(root) if root else Path.cwd()
        (self.root / "canonical").mkdir(parents=True, exist_ok=True)

    def save_artifact(self, data: Any, saving_kwargs: dict | None = None) -> Path:
        """
        Save the video artifact.

        Parameters
        ----------
        data: Any
            The data to save. Tensors and arrays are converted to a list of PIL images;
            any other value is handed to the export utility unchanged.
            Floating point tensors are scaled from [0, 1] to [0, 255]; integer tensors
            and arrays are taken as pixel values as-is. Floating point arrays whose
            maximum is at most 1.0 are treated as normalized and scaled as well.
        saving_kwargs: dict | None
            The additional kwargs to pass to the saving utility function.

        Returns
        -------
        Path
            The path to the saved artifact.

        Raises
        ------
        ValueError
            If the export format is neither "mp4" nor "gif".
        """
        canonical_filename = f"{int(time.time())}_{secrets.token_hex(4)}.{self.export_format}"
        canonical_path = Path(str(self.root)) / "canonical" / canonical_filename

        # all diffusers saving utility functions accept a list of PIL.Images, so we convert to PIL to be safe.
        if isinstance(data, torch.Tensor):
            # NOTE(review): the transpose assumes a (frames, C, H, W) tensor - confirm with callers.
            frames = np.transpose(data.detach().cpu().numpy(), (0, 2, 3, 1))
            if data.is_floating_point():
                # Only float tensors are normalized; scaling integer tensors would corrupt them.
                frames = np.clip(frames * 255, 0, 255).astype(np.uint8)
            data = frames
        if isinstance(data, np.ndarray):
            if np.issubdtype(data.dtype, np.floating):
                # Heuristic: a float array that never exceeds 1.0 is normalized, otherwise
                # it already holds 0-255 values. A bare uint8 cast would truncate
                # normalized data to 0/1.
                if data.size and float(data.max()) <= 1.0:
                    data = data * 255
                data = np.clip(data, 0, 255)
            data = [Image.fromarray(frame) for frame in data.astype(np.uint8)]

        export_kwargs = dict(saving_kwargs) if saving_kwargs else {}
        if self.export_format == "mp4":
            export_to_video(data, str(canonical_path), **export_kwargs)
        elif self.export_format == "gif":
            export_to_gif(data, str(canonical_path), **export_kwargs)
        else:
            # Still reachable if export_format is reassigned after construction.
            raise ValueError(f"Invalid format: {self.export_format}")
        return canonical_path
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Bug: Uint8 torch tensors incorrectly multiplied by 255
The torch tensor conversion unconditionally multiplies values by 255, assuming float data normalized to [0, 1]. When a
`uint8` tensor with values in [0, 255] is passed, multiplying by 255 produces values up to 65025, which after clipping means any pixel value ≥2 becomes 255. This corrupts the image, making it nearly all white/saturated. The code needs to check the tensor's dtype before scaling.