From 559ffe91d48acdc6664c1ae18979e5d510892312 Mon Sep 17 00:00:00 2001 From: Adam Szokalski Date: Thu, 4 Dec 2025 14:40:11 +0100 Subject: [PATCH 1/2] Added `lucent.util.set_default_device(backend_name)` function to allow working with a `mps` backend. --- demo.py | 4 +++- lucent/optvis/param/cppn.py | 5 ++--- lucent/optvis/param/images.py | 4 ++-- lucent/optvis/param/lowres.py | 4 ++-- lucent/optvis/param/spatial.py | 6 +++--- lucent/util.py | 3 +++ tests/optvis/param/test_gan.py | 28 ++++++++++++------------- tests/optvis/test_integration.py | 4 ++-- tests/optvis/test_objectives.py | 4 ++-- tests/optvis/test_render.py | 4 ++-- tests/optvis/test_transform.py | 36 +++++++++++++++++--------------- 11 files changed, 54 insertions(+), 48 deletions(-) diff --git a/demo.py b/demo.py index f7b2aa7..039dd2a 100644 --- a/demo.py +++ b/demo.py @@ -2,10 +2,12 @@ from lucent.optvis import render, param from lucent.modelzoo import inceptionv1 +import lucent.util as ut def main(): - device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + # ut.set_default_device("cuda:0") # or mps or cpu + device = ut.DEFAULT_DEVICE model = inceptionv1(pretrained=True) model.to(device).eval() diff --git a/lucent/optvis/param/cppn.py b/lucent/optvis/param/cppn.py index 82622ad..a5f23bc 100644 --- a/lucent/optvis/param/cppn.py +++ b/lucent/optvis/param/cppn.py @@ -18,8 +18,7 @@ from collections import OrderedDict import torch import numpy as np - -from lucent.util import DEFAULT_DEVICE +import lucent.util as ut class CompositeActivation(torch.nn.Module): @@ -31,7 +30,7 @@ def forward(self, x): def cppn(size, num_output_channels=3, num_hidden_channels=24, num_layers=8, activation_fn=CompositeActivation, normalize=False, - device=DEFAULT_DEVICE): + device=ut.DEFAULT_DEVICE): r = 3 ** 0.5 diff --git a/lucent/optvis/param/images.py b/lucent/optvis/param/images.py index 5385b2a..0ff5a39 100644 --- a/lucent/optvis/param/images.py +++ b/lucent/optvis/param/images.py @@ -19,11 +19,11 @@ from lucent.optvis.param.spatial import pixel_image, fft_image from lucent.optvis.param.color import to_valid_rgb -from lucent.util import DEFAULT_DEVICE +import lucent.util as ut def image(w, h=None, sd=None, batch=None, decorrelate=True, - fft=True, channels=None, device=DEFAULT_DEVICE): + fft=True, channels=None, device=ut.DEFAULT_DEVICE): h = h or w batch = batch or 1 ch = channels or 3 diff --git a/lucent/optvis/param/lowres.py b/lucent/optvis/param/lowres.py index 04aed64..aad90da 100644 --- a/lucent/optvis/param/lowres.py +++ b/lucent/optvis/param/lowres.py @@ -22,11 +22,11 @@ import torch.nn.functional as F from lucent.optvis.param.resize_bilinear_nd import resize_bilinear_nd -from lucent.util import DEFAULT_DEVICE +import lucent.util as ut def lowres_tensor(shape, underlying_shape, offset=None, sd=0.01, - device=DEFAULT_DEVICE): + device=ut.DEFAULT_DEVICE): """Produces a tensor paramaterized by a interpolated lower resolution tensor. This is like what is done in a laplacian pyramid, but a bit more general. It can be a powerful way to describe images. diff --git a/lucent/optvis/param/spatial.py b/lucent/optvis/param/spatial.py index 414c9c2..a47473c 100644 --- a/lucent/optvis/param/spatial.py +++ b/lucent/optvis/param/spatial.py @@ -18,13 +18,13 @@ import torch import numpy as np -from lucent.util import DEFAULT_DEVICE +import lucent.util as ut TORCH_VERSION = torch.__version__ -def pixel_image(shape, sd=None, device=DEFAULT_DEVICE): +def pixel_image(shape, sd=None, device=ut.DEFAULT_DEVICE): sd = sd or 0.01 tensor = (torch.randn(*shape) * sd).to(device).requires_grad_(True) return [tensor], lambda: tensor @@ -43,7 +43,7 @@ def rfft2d_freqs(h, w): return np.sqrt(fx * fx + fy * fy) -def fft_image(shape, sd=None, decay_power=1, device=DEFAULT_DEVICE): +def fft_image(shape, sd=None, decay_power=1, device=ut.DEFAULT_DEVICE): batch, channels, h, w = shape freqs = rfft2d_freqs(h, w) init_val_size = (batch, channels) + freqs.shape + (2,) # 2 for imaginary and real components diff --git a/lucent/util.py b/lucent/util.py index ac0f40b..f80132d 100644 --- a/lucent/util.py +++ b/lucent/util.py @@ -24,6 +24,9 @@ DEFAULT_DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") +def set_default_device(backend_name: str): + global DEFAULT_DEVICE + DEFAULT_DEVICE = torch.device(backend_name) def set_seed(seed): # Set global seeds to for reproducibility diff --git a/tests/optvis/param/test_gan.py b/tests/optvis/param/test_gan.py index af0e17a..9443163 100644 --- a/tests/optvis/param/test_gan.py +++ b/tests/optvis/param/test_gan.py @@ -4,7 +4,7 @@ from lucent.optvis import render, objectives from lucent.optvis.param.gan import upconvGAN from lucent.modelzoo import inceptionv1 -from lucent.util import DEFAULT_DEVICE +import lucent.util as ut NUM_STEPS = 5 @@ -12,17 +12,17 @@ @pytest.fixture def inceptionv1_model(): - model = inceptionv1().to(DEFAULT_DEVICE).eval() + model = inceptionv1().to(ut.DEFAULT_DEVICE).eval() return model def test_fc6gan_load(): """ Test if gan could be downloaded and loaded It will download the model and store it locally """ - G = upconvGAN("fc6").to(DEFAULT_DEVICE) + G = upconvGAN("fc6").to(ut.DEFAULT_DEVICE) def GANparam(batch=1, sd=1): - code = (torch.randn((batch, G.codelen)) * sd).to(DEFAULT_DEVICE).requires_grad_(True) + code = (torch.randn((batch, G.codelen)) * sd).to(ut.DEFAULT_DEVICE).requires_grad_(True) imagef = lambda: G.visualize(code) return [code], imagef code, imagef = GANparam(batch=2, sd=1) @@ -30,10 +30,10 @@ def GANparam(batch=1, sd=1): assert img.shape == (2, 3, 256, 256), "Cannot forward fc6 GAN, shape incorrect." def test_fc7gan_load(): - G = upconvGAN("fc7").to(DEFAULT_DEVICE) + G = upconvGAN("fc7").to(ut.DEFAULT_DEVICE) def GANparam(batch=1, sd=1): - code = (torch.randn((batch, G.codelen)) * sd).to(DEFAULT_DEVICE).requires_grad_(True) + code = (torch.randn((batch, G.codelen)) * sd).to(ut.DEFAULT_DEVICE).requires_grad_(True) imagef = lambda: G.visualize(code) return [code], imagef code, imagef = GANparam(batch=2, sd=1) @@ -41,10 +41,10 @@ def GANparam(batch=1, sd=1): assert img.shape == (2, 3, 256, 256), "Cannot forward fc7 GAN, shape incorrect." def test_fc8gan_load(): - G = upconvGAN("fc8").to(DEFAULT_DEVICE) + G = upconvGAN("fc8").to(ut.DEFAULT_DEVICE) def GANparam(batch=1, sd=1): - code = (torch.randn((batch, G.codelen)) * sd).to(DEFAULT_DEVICE).requires_grad_(True) + code = (torch.randn((batch, G.codelen)) * sd).to(ut.DEFAULT_DEVICE).requires_grad_(True) imagef = lambda: G.visualize(code) return [code], imagef code, imagef = GANparam(batch=2, sd=1) @@ -52,10 +52,10 @@ def GANparam(batch=1, sd=1): assert img.shape == (2, 3, 256, 256), "Cannot forward fc8 GAN, shape incorrect." def test_pool5gan_load(): - G = upconvGAN("pool5").to(DEFAULT_DEVICE) + G = upconvGAN("pool5").to(ut.DEFAULT_DEVICE) def GANparam(batch=1, sd=1): - code = (torch.randn((batch, G.codelen, 6, 6)) * sd).to(DEFAULT_DEVICE).requires_grad_(True) + code = (torch.randn((batch, G.codelen, 6, 6)) * sd).to(ut.DEFAULT_DEVICE).requires_grad_(True) imagef = lambda: G.visualize(code) return [code], imagef code, imagef = GANparam(batch=2, sd=1) @@ -82,10 +82,10 @@ def assert_gan_gradient_descent(GANparam, objective, model): def test_gan_img_optim(inceptionv1_model): """ Test if GAN generated image could be optimized """ - G = upconvGAN("fc6").to(DEFAULT_DEVICE) + G = upconvGAN("fc6").to(ut.DEFAULT_DEVICE) def GANparam(batch=1, sd=1): - code = (torch.randn((batch, G.codelen)) * sd).to(DEFAULT_DEVICE).requires_grad_(True) + code = (torch.randn((batch, G.codelen)) * sd).to(ut.DEFAULT_DEVICE).requires_grad_(True) imagef = lambda: G.visualize(code) return [code], imagef objective = objectives.neuron("input", 0) @@ -94,10 +94,10 @@ def GANparam(batch=1, sd=1): def test_gan_deep_optim(inceptionv1_model): """ Test if GAN generated image could be optimized """ - G = upconvGAN("fc6").to(DEFAULT_DEVICE) + G = upconvGAN("fc6").to(ut.DEFAULT_DEVICE) def GANparam(batch=1, sd=1): - code = (torch.randn((batch, G.codelen)) * sd).to(DEFAULT_DEVICE).requires_grad_(True) + code = (torch.randn((batch, G.codelen)) * sd).to(ut.DEFAULT_DEVICE).requires_grad_(True) imagef = lambda: G.visualize(code) return [code], imagef diff --git a/tests/optvis/test_integration.py b/tests/optvis/test_integration.py index 63c24ce..1690449 100644 --- a/tests/optvis/test_integration.py +++ b/tests/optvis/test_integration.py @@ -20,12 +20,12 @@ import torch from lucent.optvis import param, render from lucent.modelzoo import inceptionv1 -from lucent.util import DEFAULT_DEVICE +import lucent.util as ut @pytest.fixture def inceptionv1_model(): - device = torch.device(DEFAULT_DEVICE) + device = torch.device(ut.DEFAULT_DEVICE) model = inceptionv1().to(device).eval() return model diff --git a/tests/optvis/test_objectives.py b/tests/optvis/test_objectives.py index 45d69ee..d2e6a91 100644 --- a/tests/optvis/test_objectives.py +++ b/tests/optvis/test_objectives.py @@ -21,14 +21,14 @@ from lucent.util import set_seed from lucent.optvis import objectives, param, render from lucent.modelzoo import inceptionv1 -from lucent.util import DEFAULT_DEVICE +import lucent.util as ut set_seed(137) NUM_STEPS = 5 -device = torch.device(DEFAULT_DEVICE) +device = torch.device(ut.DEFAULT_DEVICE) @pytest.fixture def inceptionv1_model(): diff --git a/tests/optvis/test_render.py b/tests/optvis/test_render.py index a9a7d3f..9e68599 100644 --- a/tests/optvis/test_render.py +++ b/tests/optvis/test_render.py @@ -19,12 +19,12 @@ from lucent.optvis import render, param from lucent.modelzoo import inceptionv1 -from lucent.util import DEFAULT_DEVICE +import lucent.util as ut @pytest.fixture def inceptionv1_model(): - model = inceptionv1().to(DEFAULT_DEVICE).eval() + model = inceptionv1().to(ut.DEFAULT_DEVICE).eval() return model diff --git a/tests/optvis/test_transform.py b/tests/optvis/test_transform.py index 8741711..1b20f78 100644 --- a/tests/optvis/test_transform.py +++ b/tests/optvis/test_transform.py @@ -18,39 +18,39 @@ import torch import numpy as np from lucent.optvis import transform -from lucent.util import DEFAULT_DEVICE +import lucent.util as ut def test_pad_reflect(): pad = transform.pad(1) - tensor = torch.ones(1, 3, 2, 2).to(DEFAULT_DEVICE) - assert torch.all(pad(tensor).eq(torch.ones(1, 3, 4, 4).to(DEFAULT_DEVICE))) + tensor = torch.ones(1, 3, 2, 2).to(ut.DEFAULT_DEVICE) + assert torch.all(pad(tensor).eq(torch.ones(1, 3, 4, 4).to(ut.DEFAULT_DEVICE))) def test_pad_constant(): pad = transform.pad(1, mode="constant") - tensor = torch.ones(1, 3, 2, 2).to(DEFAULT_DEVICE) + tensor = torch.ones(1, 3, 2, 2).to(ut.DEFAULT_DEVICE) assert torch.all(pad(tensor).eq(torch.tensor([[ [[0.5, 0.5, 0.5, 0.5], [0.5, 1, 1, 0.5], [0.5, 1, 1, 0.5], [0.5, 0.5, 0.5, 0.5]], [[0.5, 0.5, 0.5, 0.5], [0.5, 1, 1, 0.5], [0.5, 1, 1, 0.5], [0.5, 0.5, 0.5, 0.5]], [[0.5, 0.5, 0.5, 0.5], [0.5, 1, 1, 0.5], [0.5, 1, 1, 0.5], [0.5, 0.5, 0.5, 0.5]], - ]]).to(DEFAULT_DEVICE))) + ]]).to(ut.DEFAULT_DEVICE))) def test_random_scale_down(): scale = transform.random_scale([0.33]) - tensor = torch.ones(1, 3, 3, 3).to(DEFAULT_DEVICE) + tensor = torch.ones(1, 3, 3, 3).to(ut.DEFAULT_DEVICE) assert torch.all(scale(tensor).eq(torch.tensor([[ [[0, 0, 0], [0, 1, 0], [0, 0, 0]], [[0, 0, 0], [0, 1, 0], [0, 0, 0]], [[0, 0, 0], [0, 1, 0], [0, 0, 0]] - ]]).to(DEFAULT_DEVICE))) + ]]).to(ut.DEFAULT_DEVICE))) def test_random_scale_up(): scale = transform.random_scale([2]) - tensor = torch.ones(1, 3, 1, 1).to(DEFAULT_DEVICE) - assert torch.all(scale(tensor).eq(torch.ones(1, 3, 2, 2).to(DEFAULT_DEVICE))) + tensor = torch.ones(1, 3, 1, 1).to(ut.DEFAULT_DEVICE) + assert torch.all(scale(tensor).eq(torch.ones(1, 3, 2, 2).to(ut.DEFAULT_DEVICE))) def test_random_rotate_even_size(): @@ -59,12 +59,15 @@ def test_random_rotate_even_size(): [[0, 1], [0, 1]], [[0, 1], [0, 1]], [[0, 1], [0, 1]], - ]]).to(DEFAULT_DEVICE) - assert torch.all(rotate(tensor).eq(torch.tensor([[ + ]]).float().to(ut.DEFAULT_DEVICE) + + target = torch.tensor([[ [[1, 1], [0, 0]], [[1, 1], [0, 0]], [[1, 1], [0, 0]], - ]]).to(DEFAULT_DEVICE))) + ]]).float().to(ut.DEFAULT_DEVICE) + # Looks like on some backends there is some numerical instability here, so we allow a small tolerance + assert torch.allclose(rotate(tensor), target, atol=1e-5) def test_random_rotate_odd_size(): @@ -73,20 +76,19 @@ def test_random_rotate_odd_size(): [[0, 0, 1], [0, 0, 1], [0, 0, 1]], [[0, 0, 1], [0, 0, 1], [0, 0, 1]], [[0, 0, 1], [0, 0, 1], [0, 0, 1]] - ]]).to(DEFAULT_DEVICE) + ]]).to(ut.DEFAULT_DEVICE) assert torch.all(rotate(tensor).eq(torch.tensor([[ [[1, 1, 1], [0, 0, 0], [0, 0, 0]], [[1, 1, 1], [0, 0, 0], [0, 0, 0]], [[1, 1, 1], [0, 0, 0], [0, 0, 0]] - ]]).to(DEFAULT_DEVICE))) + ]]).to(ut.DEFAULT_DEVICE))) def test_normalize(): normalize = transform.normalize() - tensor = torch.zeros(1, 3, 1, 1).to(DEFAULT_DEVICE) - print(normalize(tensor)) + tensor = torch.zeros(1, 3, 1, 1).to(ut.DEFAULT_DEVICE) assert torch.allclose(normalize(tensor), torch.tensor([[ [[-0.485/0.229]], [[-0.456/0.224]], [[-0.406/0.225]] - ]]).to(DEFAULT_DEVICE)) + ]]).to(ut.DEFAULT_DEVICE)) From 0e12be68d1cb8c8cfd64b482b911e4990eb23b99 Mon Sep 17 00:00:00 2001 From: Adam Szokalski Date: Thu, 4 Dec 2025 14:57:52 +0100 Subject: [PATCH 2/2] Made set_default_device return the device --- lucent/util.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/lucent/util.py b/lucent/util.py index f80132d..0f248aa 100644 --- a/lucent/util.py +++ b/lucent/util.py @@ -24,9 +24,10 @@ DEFAULT_DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") -def set_default_device(backend_name: str): +def set_default_device(backend_name: str) -> torch.device: global DEFAULT_DEVICE DEFAULT_DEVICE = torch.device(backend_name) + return DEFAULT_DEVICE def set_seed(seed): # Set global seeds to for reproducibility