Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 34 additions & 29 deletions lucent/optvis/param/cppn.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def forward(self, x):
return torch.cat([x/0.67, (x*x)/0.6], 1)


def cppn(size, num_output_channels=3, num_hidden_channels=24, num_layers=8,
def cppn(size, num_output_channels=3, batch=None, num_hidden_channels=24, num_layers=8,
activation_fn=CompositeActivation, normalize=False):

r = 3 ** 0.5
Expand All @@ -40,32 +40,37 @@ def cppn(size, num_output_channels=3, num_hidden_channels=24, num_layers=8,

input_tensor = torch.stack([x, y], dim=0).unsqueeze(0).to(device)

layers = []
kernel_size = 1
for i in range(num_layers):
out_c = num_hidden_channels
in_c = out_c * 2 # * 2 for composite activation
if i == 0:
in_c = 2
if i == num_layers - 1:
out_c = num_output_channels
layers.append(('conv{}'.format(i), torch.nn.Conv2d(in_c, out_c, kernel_size)))
if normalize:
layers.append(('norm{}'.format(i), torch.nn.InstanceNorm2d(out_c)))
if i < num_layers - 1:
layers.append(('actv{}'.format(i), activation_fn()))
else:
layers.append(('output', torch.nn.Sigmoid()))
# batching is handled via a network per example in the batch
batch = 1 if batch is None else batch
nets = torch.nn.ModuleList()
for bi in range(batch):
layers = []
kernel_size = 1
for i in range(num_layers):
out_c = num_hidden_channels
in_c = out_c * 2 # * 2 for composite activation
if i == 0:
in_c = 2
if i == num_layers - 1:
out_c = num_output_channels
layers.append(('conv{}'.format(i), torch.nn.Conv2d(in_c, out_c, kernel_size)))
if normalize:
layers.append(('norm{}'.format(i), torch.nn.InstanceNorm2d(out_c)))
if i < num_layers - 1:
layers.append(('actv{}'.format(i), activation_fn()))
else:
layers.append(('output', torch.nn.Sigmoid()))

# Initialize model
net = torch.nn.Sequential(OrderedDict(layers)).to(device)
# Initialize weights
def weights_init(module):
if isinstance(module, torch.nn.Conv2d):
torch.nn.init.normal_(module.weight, 0, np.sqrt(1/module.in_channels))
if module.bias is not None:
torch.nn.init.zeros_(module.bias)
net.apply(weights_init)
# Set last conv2d layer's weights to 0
torch.nn.init.zeros_(dict(net.named_children())['conv{}'.format(num_layers - 1)].weight)
return net.parameters(), lambda: net(input_tensor)
# Initialize model
net = torch.nn.Sequential(OrderedDict(layers)).to(device)
# Initialize weights
def weights_init(module):
if isinstance(module, torch.nn.Conv2d):
torch.nn.init.normal_(module.weight, 0, np.sqrt(1/module.in_channels))
if module.bias is not None:
torch.nn.init.zeros_(module.bias)
net.apply(weights_init)
# Set last conv2d layer's weights to 0
torch.nn.init.zeros_(dict(net.named_children())['conv{}'.format(num_layers - 1)].weight)
nets.append(net)
return nets.parameters(), lambda: torch.cat([net(input_tensor) for net in nets], dim=0)