diff --git a/lucent/optvis/param/cppn.py b/lucent/optvis/param/cppn.py
index 65b08cd..df7bdff 100644
--- a/lucent/optvis/param/cppn.py
+++ b/lucent/optvis/param/cppn.py
@@ -27,7 +27,7 @@ def forward(self, x):
         return torch.cat([x/0.67, (x*x)/0.6], 1)
 
 
-def cppn(size, num_output_channels=3, num_hidden_channels=24, num_layers=8,
+def cppn(size, num_output_channels=3, batch=None, num_hidden_channels=24, num_layers=8,
          activation_fn=CompositeActivation, normalize=False):
 
     r = 3 ** 0.5
@@ -40,32 +40,37 @@ def cppn(size, num_output_channels=3, num_hidden_channels=24, num_layers=8,
 
     input_tensor = torch.stack([x, y], dim=0).unsqueeze(0).to(device)
 
-    layers = []
-    kernel_size = 1
-    for i in range(num_layers):
-        out_c = num_hidden_channels
-        in_c = out_c * 2 # * 2 for composite activation
-        if i == 0:
-            in_c = 2
-        if i == num_layers - 1:
-            out_c = num_output_channels
-        layers.append(('conv{}'.format(i), torch.nn.Conv2d(in_c, out_c, kernel_size)))
-        if normalize:
-            layers.append(('norm{}'.format(i), torch.nn.InstanceNorm2d(out_c)))
-        if i < num_layers - 1:
-            layers.append(('actv{}'.format(i), activation_fn()))
-        else:
-            layers.append(('output', torch.nn.Sigmoid()))
+    # batching is handled via a network per example in the batch
+    batch = 1 if batch is None else batch
+    nets = torch.nn.ModuleList()
+    for bi in range(batch):
+        layers = []
+        kernel_size = 1
+        for i in range(num_layers):
+            out_c = num_hidden_channels
+            in_c = out_c * 2 # * 2 for composite activation
+            if i == 0:
+                in_c = 2
+            if i == num_layers - 1:
+                out_c = num_output_channels
+            layers.append(('conv{}'.format(i), torch.nn.Conv2d(in_c, out_c, kernel_size)))
+            if normalize:
+                layers.append(('norm{}'.format(i), torch.nn.InstanceNorm2d(out_c)))
+            if i < num_layers - 1:
+                layers.append(('actv{}'.format(i), activation_fn()))
+            else:
+                layers.append(('output', torch.nn.Sigmoid()))
 
-    # Initialize model
-    net = torch.nn.Sequential(OrderedDict(layers)).to(device)
-    # Initialize weights
-    def weights_init(module):
-        if isinstance(module, torch.nn.Conv2d):
-            torch.nn.init.normal_(module.weight, 0, np.sqrt(1/module.in_channels))
-            if module.bias is not None:
-                torch.nn.init.zeros_(module.bias)
-    net.apply(weights_init)
-    # Set last conv2d layer's weights to 0
-    torch.nn.init.zeros_(dict(net.named_children())['conv{}'.format(num_layers - 1)].weight)
-    return net.parameters(), lambda: net(input_tensor)
+        # Initialize model
+        net = torch.nn.Sequential(OrderedDict(layers)).to(device)
+        # Initialize weights
+        def weights_init(module):
+            if isinstance(module, torch.nn.Conv2d):
+                torch.nn.init.normal_(module.weight, 0, np.sqrt(1/module.in_channels))
+                if module.bias is not None:
+                    torch.nn.init.zeros_(module.bias)
+        net.apply(weights_init)
+        # Set last conv2d layer's weights to 0
+        torch.nn.init.zeros_(dict(net.named_children())['conv{}'.format(num_layers - 1)].weight)
+        nets.append(net)
+    return nets.parameters(), lambda: torch.cat([net(input_tensor) for net in nets], dim=0)
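
Usage sketch (not part of the patch): a minimal example of exercising the new `batch` argument through lucent's `render.render_vis`. The model, objective string, and Adam learning rate below are illustrative assumptions, not something this change prescribes; the image function now returns a `(batch, num_output_channels, size, size)` tensor, one CPPN per batch element.

```python
import torch
from lucent.modelzoo import inceptionv1
from lucent.optvis import render, param

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = inceptionv1(pretrained=True).to(device).eval()

# One independent CPPN per batch example; their outputs are concatenated
# along the batch dimension by the returned image function.
param_f = lambda: param.cppn(128, batch=2)

# CPPN parameterizations are usually trained with a smaller learning rate
# than pixel/FFT parameterizations (illustrative value).
optimizer = lambda params: torch.optim.Adam(params, lr=5e-3)

# "mixed4a:476" is an illustrative objective; each batch element starts from
# a different random CPPN initialization, so the two results will differ.
render.render_vis(model, "mixed4a:476", param_f, optimizer)
```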