From 132533f2f6a4c3d490286099e95f88722d4c3c8d Mon Sep 17 00:00:00 2001
From: B1ueber2y
Date: Thu, 22 Aug 2019 19:22:36 -0400
Subject: [PATCH] fix imbalanced gpu memory usage.

---
Review notes (not part of the commit; text between the three-dash line and
the diffstat is ignored when the patch is applied):

* Decoder.forward(input) is renamed to Decoder.inference(input). The new
  Decoder.forward(sdf_data, lat_vecs_idx, min_vec, max_vec, enforce_minmax)
  flattens sdf_data to rows of 4 values (xyz + sdf), repeats each latent row
  num_samp_per_scene times, runs inference(), optionally applies
  deep_sdf.utils.threshold_min_max to both the target and the prediction,
  and returns (pred_sdf, loss_l1, loss_l2_size) with both losses unreduced.
* deep_sdf/utils.py and reconstruct.py switch from decoder(inputs) to
  decoder.inference(inputs).
* train_deep_sdf.py passes the raw batch to decoder(...), reduces the
  returned losses with .mean(), and wraps the decoder in DataParallel only
  when torch.cuda.device_count() > 1.
* The "Description" spec entry is passed through ''.join(...) before it is
  logged.

NOTE(review): presumably the loss moves into forward() so that each
  DataParallel replica computes it on its own device - confirm.
NOTE(review): the removed "if latent_size == 0: inputs = xyz" branch has no
  counterpart in the new forward() - confirm it is no longer required.
NOTE(review): torch.cat(lat_vecs) now runs on every sub-batch iteration.
NOTE(review): the regulariser changes from latent_size_regul(lat_vecs,
  indices.numpy()) (definition not shown in this patch) to
  lat_vecs.pow(2).mean(1) followed by .mean() - TODO confirm the scale is
  as intended.
NOTE(review): callers of decoder.inference(...) assume an unwrapped
  Decoder; a DataParallel wrapper would not expose that method directly -
  verify against the code outside this patch.
NOTE(review): line breaks and indentation were restored from a
  whitespace-collapsed copy; check context lines against the target files
  before applying.

 deep_sdf/utils.py            |  2 +-
 networks/deep_sdf_decoder.py | 31 +++++++++++++++++++++++--
 reconstruct.py               |  4 ++--
 train_deep_sdf.py            | 45 ++++++++----------------------------
 4 files changed, 42 insertions(+), 40 deletions(-)

diff --git a/deep_sdf/utils.py b/deep_sdf/utils.py
index 22837d32..eb5faf61 100644
--- a/deep_sdf/utils.py
+++ b/deep_sdf/utils.py
@@ -70,6 +70,6 @@ def decode_sdf(decoder, latent_vector, queries):
         latent_repeat = latent_vector.expand(num_samples, -1)
         inputs = torch.cat([latent_repeat, queries], 1)
 
-    sdf = decoder(inputs)
+    sdf = decoder.inference(inputs)
 
     return sdf
diff --git a/networks/deep_sdf_decoder.py b/networks/deep_sdf_decoder.py
index 71a0b026..1a9b262f 100644
--- a/networks/deep_sdf_decoder.py
+++ b/networks/deep_sdf_decoder.py
@@ -4,7 +4,7 @@
 import torch.nn as nn
 import torch
 import torch.nn.functional as F
-
+import deep_sdf
 
 class Decoder(nn.Module):
     def __init__(
@@ -70,8 +70,33 @@ def make_sequence():
         self.dropout = dropout
         self.th = nn.Tanh()
 
+    def latent_size_regul(self, lat_vecs):
+        latent_loss = lat_vecs.pow(2).mean(1)
+        return latent_loss
+
+    def forward(self, sdf_data, lat_vecs_idx, min_vec, max_vec, enforce_minmax=True):
+        num_samp_per_scene = sdf_data.shape[1]
+        sdf_data = sdf_data.reshape(-1, 4)
+        xyz = sdf_data[:, 0:3]
+        sdf_gt = sdf_data[:, 3].unsqueeze(1)
+
+        latent_dim = lat_vecs_idx.shape[1]
+        latent_inputs = lat_vecs_idx.repeat(1, num_samp_per_scene).view(-1, latent_dim)
+
+        inputs = torch.cat([latent_inputs, xyz], 1)
+        pred_sdf = self.inference(inputs)
+
+        if enforce_minmax:
+            sdf_gt = deep_sdf.utils.threshold_min_max(sdf_gt, min_vec, max_vec)
+            pred_sdf = deep_sdf.utils.threshold_min_max(
+                pred_sdf, min_vec, max_vec
+            )
+        loss_l1 = torch.abs(pred_sdf - sdf_gt).squeeze(1)
+        loss_l2_size = self.latent_size_regul(lat_vecs_idx)
+        return pred_sdf, loss_l1, loss_l2_size
+
     # input: N x (L+3)
-    def forward(self, input):
+    def inference(self, input):
         xyz = input[:, -3:]
 
         if input.shape[1] > 3 and self.latent_dropout:
@@ -107,3 +132,5 @@ def forward(self, input):
             x = self.th(x)
 
         return x
+
+
diff --git a/reconstruct.py b/reconstruct.py
index 19dd5814..f632f290 100755
--- a/reconstruct.py
+++ b/reconstruct.py
@@ -65,11 +65,11 @@ def adjust_learning_rate(
 
         inputs = torch.cat([latent_inputs, xyz], 1).cuda()
 
-        pred_sdf = decoder(inputs)
+        pred_sdf = decoder.inference(inputs)
 
         # TODO: why is this needed?
         if e == 0:
-            pred_sdf = decoder(inputs)
+            pred_sdf = decoder.inference(inputs)
 
         pred_sdf = torch.clamp(pred_sdf, -clamp_dist, clamp_dist)
 
diff --git a/train_deep_sdf.py b/train_deep_sdf.py
index b8b86727..2c37325b 100755
--- a/train_deep_sdf.py
+++ b/train_deep_sdf.py
@@ -223,7 +223,7 @@ def main_function(experiment_directory, continue_from, batch_split):
 
     specs = ws.load_experiment_specifications(experiment_directory)
 
-    logging.info("Experiment description: \n" + specs["Description"])
+    logging.info("Experiment description: \n" + ''.join(specs["Description"]))
 
     data_source = specs["DataSource"]
     train_split_file = specs["TrainSplit"]
@@ -313,8 +313,8 @@ def empirical_stat(latent_vecs, indices):
 
     logging.info("training with {} GPU(s)".format(torch.cuda.device_count()))
 
-    # if torch.cuda.device_count() > 1:
-    decoder = torch.nn.DataParallel(decoder)
+    if torch.cuda.device_count() > 1:
+        decoder = torch.nn.DataParallel(decoder)
 
     num_epochs = specs["NumEpochs"]
     log_frequency = get_spec_with_default(specs, "LogFrequency", 10)
@@ -437,45 +437,20 @@ def empirical_stat(latent_vecs, indices):
             optimizer_all.zero_grad()
 
             for _subbatch in range(batch_split):
-
-                # Process the input datag
-                latent_inputs = torch.zeros(0).cuda()
+                # Process the input data
                 sdf_data.requires_grad = False
-
-                sdf_data = (sdf_data.cuda()).reshape(
-                    num_samp_per_scene * scene_per_subbatch, 4
-                )
-                xyz = sdf_data[:, 0:3]
-                sdf_gt = sdf_data[:, 3].unsqueeze(1)
-                for ind in indices.numpy():
-                    latent_ind = lat_vecs[ind]
-                    latent_repeat = latent_ind.expand(num_samp_per_scene, -1)
-                    latent_inputs = torch.cat([latent_inputs, latent_repeat], 0)
-                inputs = torch.cat([latent_inputs, xyz], 1)
-
-                if enforce_minmax:
-                    sdf_gt = deep_sdf.utils.threshold_min_max(sdf_gt, min_vec, max_vec)
-
-                if latent_size == 0:
-                    inputs = xyz
+                sdf_data = sdf_data.cuda()
+                lat_vecs_tensor = torch.cat(lat_vecs)
+                lat_vecs_idx = lat_vecs_tensor[indices.numpy()]
 
                 # NN optimization
-
-                pred_sdf = decoder(inputs)
-
-                if enforce_minmax:
-                    pred_sdf = deep_sdf.utils.threshold_min_max(
-                        pred_sdf, min_vec, max_vec
-                    )
-
-                loss = loss_l1(pred_sdf, sdf_gt)
+                pred_sdf, loss_l1, l2_size_loss = decoder(sdf_data, lat_vecs_idx, min_vec, max_vec, enforce_minmax=enforce_minmax)
+                loss = loss_l1.mean()
 
                 if do_code_regularization:
-                    l2_size_loss = latent_size_regul(lat_vecs, indices.numpy())
-                    loss += code_reg_lambda * min(1, epoch / 100) * l2_size_loss
+                    loss += code_reg_lambda * min(1, epoch / 100) * l2_size_loss.mean()
 
                 loss.backward()
-
                 batch_loss += loss.item()
 
             loss_log.append(batch_loss)