Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
9114d33
Update README.md
unnir May 26, 2021
6123386
Merge pull request #27 from unnir/patch-1
greentfrapp Aug 1, 2021
66b7332
reshape one-channel images for PIL
Jul 13, 2022
8a5ead4
Set numpy seed
dozed Sep 21, 2023
9bf8d0d
handle single channel images in tensor_to_img
alik-git Nov 24, 2023
ab8e855
Convert elements of scale_shape from numpy.int64 to int to ensure com…
Dec 14, 2023
4f25829
updated imports from Kornia
Tal-Golan Jan 11, 2024
558b10b
pytorch related update - make sure size arg to nn.Upsample is int
Tal-Golan Jan 11, 2024
3ae8c59
emulate Chrome request to bypass onedrive download link block
Tal-Golan Jan 11, 2024
c02a9c1
Merge pull request #46 from dozed/fixed-seeding
greentfrapp May 19, 2024
33f3b36
Merge pull request #48 from smarginatura/dev2
greentfrapp May 19, 2024
4da5e46
Merge pull request #39 from neoglez/dev
greentfrapp May 19, 2024
ecb5f38
Merge pull request #47 from alik-git/handle_1_channel_images
greentfrapp May 19, 2024
d82a93c
Remove conversion to uint8
greentfrapp May 19, 2024
ec474e9
Merge branch 'handle-1-channel-images' into dev
greentfrapp May 19, 2024
c4e4fd8
Merge pull request #49 from brainsandmachines/dev
greentfrapp May 19, 2024
9904e02
Fix #45 - Clear module hooks
greentfrapp May 19, 2024
fee0326
Make hook_model backward-compatible
greentfrapp May 19, 2024
fa5bfdc
Merge pull request #52 from greentfrapp/clear-module-hooks
greentfrapp May 19, 2024
d43479b
Refactor device and detect device where possible
greentfrapp Jun 8, 2024
a4fdb64
Detect model device in render_vis function
greentfrapp Jun 8, 2024
f101aea
Fix broken tests and model.device bug
greentfrapp Jun 8, 2024
f1c5550
Merge branch 'refactor-device' into dev
greentfrapp Jun 8, 2024
b17edc6
Add coverage CI
greentfrapp Mar 21, 2025
b666556
Add dep installation to coverage workflow
greentfrapp Mar 21, 2025
cb145e0
Add test result upload to workflow
greentfrapp Mar 21, 2025
2ccc434
Fix workflow config
greentfrapp Mar 21, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions .github/workflows/coverage.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
name: Run tests and upload coverage

on:
[workflow_dispatch, pull_request]

jobs:
test:
name: Run tests and collect coverage
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v4
with:
fetch-depth: 2

- name: Set up Python
uses: actions/setup-python@v4

- name: Install repo dependencies
run: pip install .

- name: Install test dependencies
run: pip install pytest pytest-cov

- name: Test with pytest
run: |
pytest --cov --junitxml=junit.xml -o junit_family=legacy

- name: Upload coverage to Codecov
uses: codecov/codecov-action@v5
with:
token: ${{ secrets.CODECOV_TOKEN }}

- name: Upload test results to Codecov
if: ${{ !cancelled() }}
uses: codecov/test-results-action@v1
with:
token: ${{ secrets.CODECOV_TOKEN }}
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ You can also clone this repository and run the notebooks locally with [Jupyter](

## Quickstart

```
```python
import torch

from lucent.optvis import render
Expand Down
4 changes: 2 additions & 2 deletions demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,12 @@ def main():

if CPPN:
# CPPN parameterization
param_f = lambda: param.cppn(224)
param_f = lambda: param.cppn(224, device=device)
opt = lambda params: torch.optim.Adam(params, 5e-3)
# Some objectives work better with CPPN than others
obj = "mixed4d_3x3_bottleneck_pre_relu_conv:139"
else:
param_f = lambda: param.image(224, fft=SPATIAL_DECORRELATION, decorrelate=CHANNEL_DECORRELATION)
param_f = lambda: param.image(224, fft=SPATIAL_DECORRELATION, decorrelate=CHANNEL_DECORRELATION, device=device)
opt = lambda params: torch.optim.Adam(params, 5e-2)
obj = "mixed4a:476"

Expand Down
3 changes: 1 addition & 2 deletions lucent/optvis/objectives.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,14 +231,13 @@ def inner(model):


def _torch_blur(tensor, out_c=3):
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
depth = tensor.shape[1]
weight = np.zeros([depth, depth, out_c, out_c])
for ch in range(depth):
weight_ch = weight[ch, ch, :, :]
weight_ch[ : , : ] = 0.5
weight_ch[1:-1, 1:-1] = 1.0
weight_t = torch.tensor(weight).float().to(device)
weight_t = torch.tensor(weight).float().to(tensor.device)
conv_f = lambda t: F.conv2d(t, weight_t, None, 1, 1)
return conv_f(tensor) / conv_f(torch.ones_like(tensor))

Expand Down
3 changes: 1 addition & 2 deletions lucent/optvis/param/color.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,8 @@


def _linear_decorrelate_color(tensor):
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
t_permute = tensor.permute(0, 2, 3, 1)
t_permute = torch.matmul(t_permute, torch.tensor(color_correlation_normalized.T).to(device))
t_permute = torch.matmul(t_permute, torch.tensor(color_correlation_normalized.T).to(t_permute.device))
tensor = t_permute.permute(0, 3, 1, 2)
return tensor

Expand Down
7 changes: 5 additions & 2 deletions lucent/optvis/param/cppn.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
import torch
import numpy as np

from lucent.util import DEFAULT_DEVICE


class CompositeActivation(torch.nn.Module):

Expand All @@ -28,15 +30,16 @@ def forward(self, x):


def cppn(size, num_output_channels=3, num_hidden_channels=24, num_layers=8,
activation_fn=CompositeActivation, normalize=False):
activation_fn=CompositeActivation, normalize=False,
device=DEFAULT_DEVICE):

r = 3 ** 0.5

coord_range = torch.linspace(-r, r, size)
x = coord_range.view(-1, 1).repeat(1, coord_range.size(0))
y = coord_range.view(1, -1).repeat(coord_range.size(0), 1)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device = torch.device(device)

input_tensor = torch.stack([x, y], dim=0).unsqueeze(0).to(device)

Expand Down
26 changes: 24 additions & 2 deletions lucent/optvis/param/gan.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,37 @@
"fc7": "https://onedrive.live.com/download?cid=9CFFF6BCB39F6829&resid=9CFFF6BCB39F6829%2145338&authkey=AJ0R-daUAVYjQIw",
"fc8": "https://onedrive.live.com/download?cid=9CFFF6BCB39F6829&resid=9CFFF6BCB39F6829%2145340&authkey=AKIfNk7s5MGrRkU"}

def download_url_to_file_fake_request(url, dst):
"""
Download object at the given URL to a local path, using browser-like HTTP GET request.
"""

import requests
from tqdm import tqdm

# Imitate Chrome browser
headers = {"User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_14_6) "
"AppleWebKit/537.36 (KHTML, like Gecko) "
"Chrome/76.0.3809.132 Safari/537.36"}

with requests.get(url, headers=headers, stream=True) as r:
r.raise_for_status()
with open(dst, 'wb') as f:
for chunk in tqdm(r.iter_content(chunk_size=8192)):
f.write(chunk)

def load_statedict_from_online(name="fc6"):
torchhome = torch.hub._get_torch_home()
ckpthome = join(torchhome, "checkpoints")
os.makedirs(ckpthome, exist_ok=True)
filepath = join(ckpthome, "upconvGAN_%s.pt"%name)
if not os.path.exists(filepath):
torch.hub.download_url_to_file(model_urls[name], filepath, hash_prefix=None,
progress=True)
print("Downloading %s"%model_urls[name])
download_url_to_file_fake_request(model_urls[name], filepath)

# this is blocked by onedrive
#torch.hub.download_url_to_file(model_urls[name], filepath, hash_prefix=None,
# progress=True)
SD = torch.load(filepath)
return SD

Expand Down
5 changes: 3 additions & 2 deletions lucent/optvis/param/images.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,17 @@

from lucent.optvis.param.spatial import pixel_image, fft_image
from lucent.optvis.param.color import to_valid_rgb
from lucent.util import DEFAULT_DEVICE


def image(w, h=None, sd=None, batch=None, decorrelate=True,
fft=True, channels=None):
fft=True, channels=None, device=DEFAULT_DEVICE):
h = h or w
batch = batch or 1
ch = channels or 3
shape = [batch, ch, h, w]
param_f = fft_image if fft else pixel_image
params, image_f = param_f(shape, sd=sd)
params, image_f = param_f(shape, sd=sd, device=device)
if channels:
output = to_valid_rgb(image_f, decorrelate=False)
else:
Expand Down
6 changes: 4 additions & 2 deletions lucent/optvis/param/lowres.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,11 @@
import torch.nn.functional as F

from lucent.optvis.param.resize_bilinear_nd import resize_bilinear_nd
from lucent.util import DEFAULT_DEVICE


def lowres_tensor(shape, underlying_shape, offset=None, sd=0.01):
def lowres_tensor(shape, underlying_shape, offset=None, sd=0.01,
device=DEFAULT_DEVICE):
"""Produces a tensor paramaterized by a interpolated lower resolution tensor.
This is like what is done in a laplacian pyramid, but a bit more general. It
can be a powerful way to describe images.
Expand All @@ -41,7 +43,7 @@ def lowres_tensor(shape, underlying_shape, offset=None, sd=0.01):
Returns:
A tensor paramaterized by a lower resolution tensorflow variable.
"""
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device = torch.device(device)
underlying_t = (torch.randn(*underlying_shape) * sd).to(device).requires_grad_(True)
if offset is not None:
# Deal with non-list offset
Expand Down
7 changes: 4 additions & 3 deletions lucent/optvis/param/spatial.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,13 @@
import torch
import numpy as np

from lucent.util import DEFAULT_DEVICE


device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
TORCH_VERSION = torch.__version__


def pixel_image(shape, sd=None):
def pixel_image(shape, sd=None, device=DEFAULT_DEVICE):
sd = sd or 0.01
tensor = (torch.randn(*shape) * sd).to(device).requires_grad_(True)
return [tensor], lambda: tensor
Expand All @@ -42,7 +43,7 @@ def rfft2d_freqs(h, w):
return np.sqrt(fx * fx + fy * fy)


def fft_image(shape, sd=None, decay_power=1):
def fft_image(shape, sd=None, decay_power=1, device=DEFAULT_DEVICE):
batch, channels, h, w = shape
freqs = rfft2d_freqs(h, w)
init_val_size = (batch, channels) + freqs.shape + (2,) # 2 for imaginary and real components
Expand Down
18 changes: 15 additions & 3 deletions lucent/optvis/render.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def render_vis(
fixed_image_size=None,
):
if param_f is None:
param_f = lambda: param.image(128)
param_f = lambda: param.image(128, device=next(model.parameters()).device)
# param_f is a function that should return two things
# params - parameters to update, which we pass to the optimizer
# image_f - a function that returns an image as a tensor
Expand Down Expand Up @@ -81,7 +81,7 @@ def render_vis(

transform_f = transform.compose(transforms)

hook = hook_model(model, image_f)
hook, features = hook_model(model, image_f, return_hooks=True)
objective_f = objectives.as_objective(objective_f)

if verbose:
Expand Down Expand Up @@ -124,6 +124,10 @@ def closure():
print("Loss at step {}: {:.3f}".format(i, objective_f(hook)))
images.append(tensor_to_img_array(image_f()))

# Clear hooks
for module_hook in features.values():
del module_hook.module._forward_hooks[module_hook.hook.id]

if save_image:
export(image_f(), image_name)
if show_inline:
Expand All @@ -136,6 +140,9 @@ def closure():
def tensor_to_img_array(tensor):
image = tensor.cpu().detach().numpy()
image = np.transpose(image, [0, 2, 3, 1])
# Check if the image is single channel and convert to 3-channel
if len(image.shape) == 4 and image.shape[3] == 1: # Single channel image
image = np.repeat(image, 3, axis=3)
return image


Expand All @@ -149,6 +156,8 @@ def view(tensor):
image = (image * 255).astype(np.uint8)
if len(image.shape) == 4:
image = np.concatenate(image, axis=1)
if len(image.shape) == 3 and image.shape[2] == 1:
image = image.squeeze(2)
Image.fromarray(image).show()


Expand Down Expand Up @@ -177,10 +186,11 @@ def hook_fn(self, module, input, output):
self.features = output

def close(self):
# This doesn't actually do anything
self.hook.remove()


def hook_model(model, image_f):
def hook_model(model, image_f, return_hooks=False):
features = OrderedDict()

# recursive hooking function
Expand All @@ -206,4 +216,6 @@ def hook(layer):
assert out is not None, "There are no saved feature maps. Make sure to put the model in eval mode, like so: `model.to(device).eval()`. See README for example."
return out

if return_hooks:
return hook, features
return hook
15 changes: 10 additions & 5 deletions lucent/optvis/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,13 @@
import kornia
from kornia.geometry.transform import translate

try:
from kornia import warp_affine, get_rotation_matrix2d
except ImportError:
from kornia.geometry.transform import warp_affine, get_rotation_matrix2d



device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
KORNIA_VERSION = kornia.__version__


Expand All @@ -33,7 +38,7 @@ def jitter(d):
def inner(image_t):
dx = np.random.choice(d)
dy = np.random.choice(d)
return translate(image_t, torch.tensor([[dx, dy]]).float().to(device))
return translate(image_t, torch.tensor([[dx, dy]]).float().to(image_t.device))

return inner

Expand All @@ -52,7 +57,7 @@ def random_scale(scales):
def inner(image_t):
scale = np.random.choice(scales)
shp = image_t.shape[2:]
scale_shape = [_roundup(scale * d) for d in shp]
scale_shape = [int(_roundup(scale * d)) for d in shp]
pad_x = max(0, _roundup((shp[1] - scale_shape[1]) / 2))
pad_y = max(0, _roundup((shp[0] - scale_shape[0]) / 2))
upsample = torch.nn.Upsample(
Expand All @@ -76,8 +81,8 @@ def inner(image_t):
center = torch.ones(b, 2)
center[..., 0] = (image_t.shape[3] - 1) / 2
center[..., 1] = (image_t.shape[2] - 1) / 2
M = kornia.get_rotation_matrix2d(center, angle, scale).to(device)
rotated_image = kornia.warp_affine(image_t.float(), M, dsize=(h, w))
M = get_rotation_matrix2d(center, angle, scale).to(image_t.device)
rotated_image = warp_affine(image_t.float(), M, dsize=(h, w))
return rotated_image

return inner
Expand Down
5 changes: 5 additions & 0 deletions lucent/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,18 @@

from __future__ import absolute_import, division, print_function

import numpy as np
import torch
import random


DEFAULT_DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


def set_seed(seed):
# Set global seeds to for reproducibility
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic=True
random.seed(seed)
np.random.seed(seed)
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
install_requires=[
"torch>=1.5.0",
"torchvision",
"kornia<=0.4.1",
"kornia>=0.4.1",
"tqdm",
"numpy",
"ipython",
Expand Down
5 changes: 1 addition & 4 deletions tests/optvis/param/test_cppn.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,8 @@

from __future__ import absolute_import, division, print_function

import pytest

import torch
import numpy as np
from lucent.optvis import param, render, objectives
from lucent.optvis import param, objectives


def xor_loss(T):
Expand Down
Loading
Loading