2 changes: 1 addition & 1 deletion pyproject.toml
@@ -35,7 +35,7 @@ no-matching-overload = "ignore" # mypy is more permissive with overloads
unresolved-reference = "ignore" # mypy is more permissive with references
possibly-unbound-import = "ignore"
missing-argument = "ignore"

possibly-unbound-attribute = "ignore"
[tool.coverage.run]
source = ["src/pruna"]

98 changes: 98 additions & 0 deletions src/pruna/evaluation/artifactsavers/artifactsaver.py
@@ -0,0 +1,98 @@
# Copyright 2025 - Pruna AI GmbH. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any


class ArtifactSaver(ABC):
"""
Abstract class for artifact savers.

The artifact saver is responsible for saving the inference outputs during evaluation.

There needs to be a subclass for each metric modality (e.g. video, image, text, etc.).

Parameters
----------
export_format: str | None
The format to export the artifacts in.
root: Path | str | None
The root directory to save the artifacts in.
"""

export_format: str | None = None
root: Path | str | None = None

@abstractmethod
def save_artifact(self, data: Any) -> Path:
"""
Implement this method to save the artifact.

Parameters
----------
data: Any
The data to save.

Returns
-------
Path
The full path to the saved artifact.
"""
pass

def create_alias(self, source_path: Path | str, filename: str) -> Path:
"""
Create an alias for the artifact.

The evaluation agent will save the inference outputs with a canonical file
formatting style that makes sense for the general case.

If your metric requires a different file naming convention for evaluation,
you can use this method to create an alias for the artifact.

This way we prevent duplicate artifacts from being saved and save storage space.

By default, the alias will be created as a hardlink to the source artifact.
If the hardlink fails, a symlink will be created.

Parameters
----------
source_path : Path | str
The path to the source artifact.
filename : str
The filename to create the alias for.

Returns
-------
Path
The full path to the alias.
"""
alias = Path(str(self.root)) / f"{filename}.{self.export_format}"
alias.parent.mkdir(parents=True, exist_ok=True)
try:
if alias.exists():
alias.unlink()
alias.hardlink_to(source_path)
except Exception:
            if alias.exists():
                alias.unlink()
            alias.symlink_to(source_path)
return alias
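
For orientation, here is a minimal sketch of how an additional modality could subclass this interface; the TextArtifactSaver name, its "txt" default format, and its file layout are illustrative assumptions, not part of this PR:

import secrets
import time
from pathlib import Path
from typing import Any

from pruna.evaluation.artifactsavers.artifactsaver import ArtifactSaver


class TextArtifactSaver(ArtifactSaver):
    """Hypothetical saver that writes model outputs as plain-text files."""

    def __init__(self, root: Path | str | None = None, export_format: str | None = "txt") -> None:
        self.root = Path(root) if root else Path.cwd()
        (self.root / "canonical").mkdir(parents=True, exist_ok=True)
        self.export_format = export_format or "txt"

    def save_artifact(self, data: Any) -> Path:
        # Write the output under the canonical directory with a unique filename,
        # matching the naming scheme used by the other savers in this PR.
        filename = f"{int(time.time())}_{secrets.token_hex(4)}.{self.export_format}"
        path = Path(str(self.root)) / "canonical" / filename
        path.write_text(str(data))
        return path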
84 changes: 84 additions & 0 deletions src/pruna/evaluation/artifactsavers/image_artifactsaver.py
@@ -0,0 +1,84 @@
# Copyright 2025 - Pruna AI GmbH. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import secrets
import time
from pathlib import Path
from typing import Any

import numpy as np
import torch
from PIL import Image

from pruna.evaluation.artifactsavers.artifactsaver import ArtifactSaver


class ImageArtifactSaver(ArtifactSaver):
"""
Save image artifacts.

Parameters
----------
root: Path | str | None = None
The root directory to save the artifacts.
export_format: str | None = "png"
The format to save the artifacts (e.g. "png", "jpg", "jpeg", "webp").
"""

export_format: str | None
root: Path | str | None

def __init__(self, root: Path | str | None = None, export_format: str | None = "png") -> None:
self.root = Path(root) if root else Path.cwd()
(self.root / "canonical").mkdir(parents=True, exist_ok=True)
self.export_format = export_format if export_format else "png"
if self.export_format not in ["png", "jpg", "jpeg", "webp"]:
raise ValueError(f"Invalid format: {self.export_format}. Valid formats are: png, jpg, jpeg, webp.")

def save_artifact(self, data: Any, saving_kwargs: dict = dict()) -> Path:
"""
Save the image artifact.

Parameters
----------
data: Any
The data to save.
saving_kwargs: dict
The additional kwargs to pass to the saving utility function.

Returns
-------
Path
The path to the saved artifact.
"""
canonical_filename = f"{int(time.time())}_{secrets.token_hex(4)}.{self.export_format}"
canonical_path = Path(str(self.root)) / "canonical" / canonical_filename

# We save the image as a PIL Image, so we need to convert the data to a PIL Image.
# Usually, the data is already a PIL.Image, so we don't need to convert it.
if isinstance(data, torch.Tensor):
data = np.transpose(data.cpu().numpy(), (1, 2, 0))
data = np.clip(data * 255, 0, 255).astype(np.uint8)
Review comment (bug): uint8 torch tensors incorrectly multiplied by 255

The torch tensor conversion unconditionally multiplies values by 255, assuming float data normalized to [0, 1]. When a uint8 tensor with values in [0, 255] is passed, multiplying by 255 produces values up to 65025, which after clipping means any pixel value ≥2 becomes 255. This corrupts the image, making it nearly all white/saturated. The code needs to check the tensor's dtype before scaling.
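
One possible fix, shown here as a hedged sketch rather than the PR's actual change, branches on the tensor's dtype so that only float data is rescaled (np and torch are already imported in this module):

if isinstance(data, torch.Tensor):
    array = np.transpose(data.cpu().numpy(), (1, 2, 0))
    # Only rescale floating-point data assumed to be normalized to [0, 1];
    # integer dtypes such as uint8 are treated as already being in [0, 255].
    if np.issubdtype(array.dtype, np.floating):
        array = np.clip(array * 255, 0, 255)
    data = array.astype(np.uint8)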

if isinstance(data, np.ndarray):
data = Image.fromarray(data.astype(np.uint8))
Review comment (bug): float numpy arrays not scaled before uint8 conversion

When a numpy.ndarray with float dtype (values in [0.0, 1.0]) is passed, it's directly cast to uint8 without first multiplying by 255. All float values less than 1.0 truncate to 0, resulting in an all-black image. Unlike the torch tensor path, there's no scaling applied here before the conversion to uint8.
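
A similar hedged sketch for the array branch, mirroring the scaling applied in the tensor path above (again not part of the PR's diff):

if isinstance(data, np.ndarray):
    # Rescale float arrays assumed to lie in [0.0, 1.0] before the uint8 cast;
    # integer arrays are passed through unchanged.
    if np.issubdtype(data.dtype, np.floating):
        data = np.clip(data * 255, 0, 255)
    data = Image.fromarray(data.astype(np.uint8))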

# Now data must be a PIL Image
if not isinstance(data, Image.Image):
raise ValueError("Model outputs must be torch.Tensor, numpy.ndarray, or PIL.Image.")

# Save the image (export format is determined by the file extension)
data.save(canonical_path, **saving_kwargs.copy())

return canonical_path
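
A brief usage sketch combining the saver with the alias mechanism from the base class; the root path and the alias filename are illustrative assumptions:

from PIL import Image

from pruna.evaluation.artifactsavers.image_artifactsaver import ImageArtifactSaver

saver = ImageArtifactSaver(root="/tmp/eval_artifacts", export_format="png")
image = Image.new("RGB", (64, 64), color="red")  # stand-in for a model output

canonical = saver.save_artifact(image)             # /tmp/eval_artifacts/canonical/<timestamp>_<hex>.png
alias = saver.create_alias(canonical, "sample_0")  # /tmp/eval_artifacts/sample_0.png (hardlink, or symlink on failure)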
49 changes: 49 additions & 0 deletions src/pruna/evaluation/artifactsavers/utils.py
@@ -0,0 +1,49 @@
# Copyright 2025 - Pruna AI GmbH. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from pathlib import Path

from pruna.evaluation.artifactsavers.artifactsaver import ArtifactSaver
from pruna.evaluation.artifactsavers.image_artifactsaver import ImageArtifactSaver
from pruna.evaluation.artifactsavers.video_artifactsaver import VideoArtifactSaver


def assign_artifact_saver(
modality: str, root: Path | str | None = None, export_format: str | None = None
) -> ArtifactSaver:
"""
Assign the appropriate artifact saver based on the modality.

Parameters
----------
modality: str
The modality of the data.
    root: Path | str | None
        The root directory to save the artifacts.
    export_format: str | None
        The format to save the artifacts.

Returns
-------
ArtifactSaver
The appropriate artifact saver.
"""
    if modality == "video":
        return VideoArtifactSaver(root=root, export_format=export_format)
    elif modality == "image":
        return ImageArtifactSaver(root=root, export_format=export_format)
    else:
        raise ValueError(f"Modality {modality} is not supported")
82 changes: 82 additions & 0 deletions src/pruna/evaluation/artifactsavers/video_artifactsaver.py
@@ -0,0 +1,82 @@
# Copyright 2025 - Pruna AI GmbH. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import secrets
import time
from pathlib import Path
from typing import Any

import numpy as np
import torch
from diffusers.utils import export_to_gif, export_to_video
from PIL import Image

from pruna.evaluation.artifactsavers.artifactsaver import ArtifactSaver


class VideoArtifactSaver(ArtifactSaver):
"""
Save video artifacts.

Parameters
----------
root: Path | str | None = None
The root directory to save the artifacts.
export_format: str | None = "mp4"
The format to save the artifacts.
"""

export_format: str | None
root: Path | str | None

def __init__(self, root: Path | str | None = None, export_format: str | None = "mp4") -> None:
self.root = Path(root) if root else Path.cwd()
(self.root / "canonical").mkdir(parents=True, exist_ok=True)
self.export_format = export_format if export_format else "mp4"

def save_artifact(self, data: Any, saving_kwargs: dict = dict()) -> Path:
"""
Save the video artifact.

Parameters
----------
data: Any
The data to save.
saving_kwargs: dict
The additional kwargs to pass to the saving utility function.

Returns
-------
Path
The path to the saved artifact.
"""
canonical_filename = f"{int(time.time())}_{secrets.token_hex(4)}.{self.export_format}"
canonical_path = Path(str(self.root)) / "canonical" / canonical_filename

# all diffusers saving utility functions accept a list of PIL.Images, so we convert to PIL to be safe.
if isinstance(data, torch.Tensor):
data = np.transpose(data.cpu().numpy(), (0, 2, 3, 1))
data = np.clip(data * 255, 0, 255).astype(np.uint8)
if isinstance(data, np.ndarray):
data = [Image.fromarray(frame.astype(np.uint8)) for frame in data]

if self.export_format == "mp4":
export_to_video(data, str(canonical_path), **saving_kwargs.copy())
elif self.export_format == "gif":
export_to_gif(data, str(canonical_path), **saving_kwargs.copy())
else:
raise ValueError(f"Invalid format: {self.export_format}")
return canonical_path
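
A short usage sketch, assuming a channels-first float clip with values in [0, 1]; the root path and the fps value (forwarded to diffusers' export_to_video) are illustrative:

import torch

from pruna.evaluation.artifactsavers.video_artifactsaver import VideoArtifactSaver

saver = VideoArtifactSaver(root="/tmp/eval_artifacts", export_format="mp4")

# A toy clip of 8 frames with shape (T, C, H, W) and values in [0, 1].
frames = torch.rand(8, 3, 64, 64)
path = saver.save_artifact(frames, saving_kwargs={"fps": 8})  # -> .../canonical/<timestamp>_<hex>.mp4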