diff --git a/train.py b/train.py
index eec6d78..531273d 100755
--- a/train.py
+++ b/train.py
@@ -233,16 +233,19 @@ def main():
     metrics_log = K.utils.CSVLogger(f'{args.name}_metrics.csv', ['step', 'fid', 'kid'])
     del train_iter
 
+    # sample demo noise once per training run
+    demo_n_per_proc = math.ceil(args.sample_n / accelerator.num_processes)
+    demo_x = torch.randn([demo_n_per_proc, model_config['input_channels'], size[0], size[1]], device=device) * sigma_max
+
     @torch.no_grad()
     @K.utils.eval_mode(model_ema)
     def demo():
         if accelerator.is_main_process:
             tqdm.write('Sampling...')
         filename = f'{args.name}_demo_{step:08}.png'
-        n_per_proc = math.ceil(args.sample_n / accelerator.num_processes)
-        x = torch.randn([n_per_proc, model_config['input_channels'], size[0], size[1]], device=device) * sigma_max
+
         sigmas = K.sampling.get_sigmas_karras(50, sigma_min, sigma_max, rho=7., device=device)
-        x_0 = K.sampling.sample_lms(model_ema, x, sigmas, disable=not accelerator.is_main_process)
+        x_0 = K.sampling.sample_lms(model_ema, demo_x, sigmas, disable=not accelerator.is_main_process)
         x_0 = accelerator.gather(x_0)[:args.sample_n]
         if accelerator.is_main_process:
             grid = utils.make_grid(x_0, nrow=math.ceil(args.sample_n ** 0.5), padding=0)