From 2542ab88b2b1023a5aba3c56762ceb3b4b75f0d1 Mon Sep 17 00:00:00 2001 From: quasar529 Date: Wed, 10 Jul 2024 20:23:09 +0900 Subject: [PATCH 01/13] train setting - logging and sampling --- .gitignore | 4 ++ improved_diffusion/train_util.py | 73 +++++++++++--------------------- scripts/image_sample.py | 27 +++++------- scripts/image_train.py | 6 +-- 4 files changed, 41 insertions(+), 69 deletions(-) diff --git a/.gitignore b/.gitignore index 2d97991d7e..110f2eccfa 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,7 @@ .DS_Store __pycache__/ +ckpt/ +*.ipynb +datasets/cifar_test +datasets/cifar_train \ No newline at end of file diff --git a/improved_diffusion/train_util.py b/improved_diffusion/train_util.py index 1867604145..e8e5a82fd2 100644 --- a/improved_diffusion/train_util.py +++ b/improved_diffusion/train_util.py @@ -20,6 +20,8 @@ from .nn import update_ema from .resample import LossAwareSampler, UniformSampler +import wandb + # For ImageNet experiments, this was a good default value. # We found that the lg_loss_scale quickly climbed to # 20-21 within the first ~1K steps of training. @@ -52,11 +54,7 @@ def __init__( self.batch_size = batch_size self.microbatch = microbatch if microbatch > 0 else batch_size self.lr = lr - self.ema_rate = ( - [ema_rate] - if isinstance(ema_rate, float) - else [float(x) for x in ema_rate.split(",")] - ) + self.ema_rate = [ema_rate] if isinstance(ema_rate, float) else [float(x) for x in ema_rate.split(",")] self.log_interval = log_interval self.save_interval = save_interval self.resume_checkpoint = resume_checkpoint @@ -84,13 +82,9 @@ def __init__( self._load_optimizer_state() # Model was resumed, either due to a restart or a checkpoint # being specified at the command line. - self.ema_params = [ - self._load_ema_parameters(rate) for rate in self.ema_rate - ] + self.ema_params = [self._load_ema_parameters(rate) for rate in self.ema_rate] else: - self.ema_params = [ - copy.deepcopy(self.master_params) for _ in range(len(self.ema_rate)) - ] + self.ema_params = [copy.deepcopy(self.master_params) for _ in range(len(self.ema_rate))] if th.cuda.is_available(): self.use_ddp = True @@ -104,13 +98,12 @@ def __init__( ) else: if dist.get_world_size() > 1: - logger.warn( - "Distributed training requires CUDA. " - "Gradients will not be synchronized properly!" - ) + logger.warn("Distributed training requires CUDA. 
" "Gradients will not be synchronized properly!") self.use_ddp = False self.ddp_model = self.model + wandb.init(project="imporved_diffusion", entity="quasar529") + def _load_and_sync_parameters(self): resume_checkpoint = find_resume_checkpoint() or self.resume_checkpoint @@ -118,11 +111,7 @@ def _load_and_sync_parameters(self): self.resume_step = parse_resume_step_from_filename(resume_checkpoint) if dist.get_rank() == 0: logger.log(f"loading model from checkpoint: {resume_checkpoint}...") - self.model.load_state_dict( - dist_util.load_state_dict( - resume_checkpoint, map_location=dist_util.dev() - ) - ) + self.model.load_state_dict(dist_util.load_state_dict(resume_checkpoint, map_location=dist_util.dev())) dist_util.sync_params(self.model.parameters()) @@ -134,9 +123,7 @@ def _load_ema_parameters(self, rate): if ema_checkpoint: if dist.get_rank() == 0: logger.log(f"loading EMA from checkpoint: {ema_checkpoint}...") - state_dict = dist_util.load_state_dict( - ema_checkpoint, map_location=dist_util.dev() - ) + state_dict = dist_util.load_state_dict(ema_checkpoint, map_location=dist_util.dev()) ema_params = self._state_dict_to_master_params(state_dict) dist_util.sync_params(ema_params) @@ -144,14 +131,10 @@ def _load_ema_parameters(self, rate): def _load_optimizer_state(self): main_checkpoint = find_resume_checkpoint() or self.resume_checkpoint - opt_checkpoint = bf.join( - bf.dirname(main_checkpoint), f"opt{self.resume_step:06}.pt" - ) + opt_checkpoint = bf.join(bf.dirname(main_checkpoint), f"opt{self.resume_step:06}.pt") if bf.exists(opt_checkpoint): logger.log(f"loading optimizer state from checkpoint: {opt_checkpoint}") - state_dict = dist_util.load_state_dict( - opt_checkpoint, map_location=dist_util.dev() - ) + state_dict = dist_util.load_state_dict(opt_checkpoint, map_location=dist_util.dev()) self.opt.load_state_dict(state_dict) def _setup_fp16(self): @@ -159,10 +142,7 @@ def _setup_fp16(self): self.model.convert_to_fp16() def run_loop(self): - while ( - not self.lr_anneal_steps - or self.step + self.resume_step < self.lr_anneal_steps - ): + while not self.lr_anneal_steps or self.step + self.resume_step < self.lr_anneal_steps: batch, cond = next(self.data) self.run_step(batch, cond) if self.step % self.log_interval == 0: @@ -189,10 +169,7 @@ def forward_backward(self, batch, cond): zero_grad(self.model_params) for i in range(0, batch.shape[0], self.microbatch): micro = batch[i : i + self.microbatch].to(dist_util.dev()) - micro_cond = { - k: v[i : i + self.microbatch].to(dist_util.dev()) - for k, v in cond.items() - } + micro_cond = {k: v[i : i + self.microbatch].to(dist_util.dev()) for k, v in cond.items()} last_batch = (i + self.microbatch) >= batch.shape[0] t, weights = self.schedule_sampler.sample(micro.shape[0], dist_util.dev()) @@ -211,16 +188,13 @@ def forward_backward(self, batch, cond): losses = compute_losses() if isinstance(self.schedule_sampler, LossAwareSampler): - self.schedule_sampler.update_with_local_losses( - t, losses["loss"].detach() - ) + self.schedule_sampler.update_with_local_losses(t, losses["loss"].detach()) loss = (losses["loss"] * weights).mean() - log_loss_dict( - self.diffusion, t, {k: v * weights for k, v in losses.items()} - ) + wandb.log({"train/loss": loss.item()}) + log_loss_dict(self.diffusion, t, {k: v * weights for k, v in losses.items()}) if self.use_fp16: - loss_scale = 2 ** self.lg_loss_scale + loss_scale = 2**self.lg_loss_scale (loss * loss_scale).backward() else: loss.backward() @@ -232,7 +206,7 @@ def optimize_fp16(self): return 
model_grads_to_master_grads(self.model_params, self.master_params) - self.master_params[0].grad.mul_(1.0 / (2 ** self.lg_loss_scale)) + self.master_params[0].grad.mul_(1.0 / (2**self.lg_loss_scale)) self._log_grad_norm() self._anneal_lr() self.opt.step() @@ -251,7 +225,8 @@ def optimize_normal(self): def _log_grad_norm(self): sqsum = 0.0 for p in self.master_params: - sqsum += (p.grad ** 2).sum().item() + sqsum += (p.grad**2).sum().item() + wandb.log({"train/grad_norm": np.sqrt(sqsum)}) logger.logkv_mean("grad_norm", np.sqrt(sqsum)) def _anneal_lr(self): @@ -261,10 +236,12 @@ def _anneal_lr(self): lr = self.lr * (1 - frac_done) for param_group in self.opt.param_groups: param_group["lr"] = lr + wandb.log({"train/lr": lr}) def log_step(self): logger.logkv("step", self.step + self.resume_step) logger.logkv("samples", (self.step + self.resume_step + 1) * self.global_batch) + wandb.log({"train/samples": (self.step + self.resume_step + 1) * self.global_batch}) if self.use_fp16: logger.logkv("lg_loss_scale", self.lg_loss_scale) @@ -295,9 +272,7 @@ def save_checkpoint(rate, params): def _master_params_to_state_dict(self, master_params): if self.use_fp16: - master_params = unflatten_master_params( - self.model.parameters(), master_params - ) + master_params = unflatten_master_params(self.model.parameters(), master_params) state_dict = self.model.state_dict() for i, (name, _value) in enumerate(self.model.named_parameters()): assert name in state_dict diff --git a/scripts/image_sample.py b/scripts/image_sample.py index 66326f6908..2e410be327 100644 --- a/scripts/image_sample.py +++ b/scripts/image_sample.py @@ -9,7 +9,10 @@ import numpy as np import torch as th import torch.distributed as dist +import sys +print(sys.path) +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) from improved_diffusion import dist_util, logger from improved_diffusion.script_util import ( NUM_CLASSES, @@ -27,12 +30,8 @@ def main(): logger.configure() logger.log("creating model and diffusion...") - model, diffusion = create_model_and_diffusion( - **args_to_dict(args, model_and_diffusion_defaults().keys()) - ) - model.load_state_dict( - dist_util.load_state_dict(args.model_path, map_location="cpu") - ) + model, diffusion = create_model_and_diffusion(**args_to_dict(args, model_and_diffusion_defaults().keys())) + model.load_state_dict(dist_util.load_state_dict(args.model_path, map_location="cpu")) model.to(dist_util.dev()) model.eval() @@ -42,13 +41,9 @@ def main(): while len(all_images) * args.batch_size < args.num_samples: model_kwargs = {} if args.class_cond: - classes = th.randint( - low=0, high=NUM_CLASSES, size=(args.batch_size,), device=dist_util.dev() - ) + classes = th.randint(low=0, high=NUM_CLASSES, size=(args.batch_size,), device=dist_util.dev()) model_kwargs["y"] = classes - sample_fn = ( - diffusion.p_sample_loop if not args.use_ddim else diffusion.ddim_sample_loop - ) + sample_fn = diffusion.p_sample_loop if not args.use_ddim else diffusion.ddim_sample_loop sample = sample_fn( model, (args.batch_size, 3, args.image_size, args.image_size), @@ -63,9 +58,7 @@ def main(): dist.all_gather(gathered_samples, sample) # gather not supported with NCCL all_images.extend([sample.cpu().numpy() for sample in gathered_samples]) if args.class_cond: - gathered_labels = [ - th.zeros_like(classes) for _ in range(dist.get_world_size()) - ] + gathered_labels = [th.zeros_like(classes) for _ in range(dist.get_world_size())] dist.all_gather(gathered_labels, classes) all_labels.extend([labels.cpu().numpy() 
for labels in gathered_labels]) logger.log(f"created {len(all_images) * args.batch_size} samples") @@ -91,9 +84,9 @@ def main(): def create_argparser(): defaults = dict( clip_denoised=True, - num_samples=10000, + num_samples=1000, batch_size=16, - use_ddim=False, + use_ddim=True, model_path="", ) defaults.update(model_and_diffusion_defaults()) diff --git a/scripts/image_train.py b/scripts/image_train.py index 76c0632608..339590a0ad 100644 --- a/scripts/image_train.py +++ b/scripts/image_train.py @@ -23,9 +23,8 @@ def main(): logger.configure() logger.log("creating model and diffusion...") - model, diffusion = create_model_and_diffusion( - **args_to_dict(args, model_and_diffusion_defaults().keys()) - ) + model, diffusion = create_model_and_diffusion(**args_to_dict(args, model_and_diffusion_defaults().keys())) + model.load_state_dict(dist_util.load_state_dict(args.model_path, map_location="cpu")) model.to(dist_util.dev()) schedule_sampler = create_named_schedule_sampler(args.schedule_sampler, diffusion) @@ -72,6 +71,7 @@ def create_argparser(): resume_checkpoint="", use_fp16=False, fp16_scale_growth=1e-3, + model_path="", ) defaults.update(model_and_diffusion_defaults()) parser = argparse.ArgumentParser() From d04c2af47dd4fe54c7342d8a56024b4b81a3f15e Mon Sep 17 00:00:00 2001 From: quasar529 Date: Thu, 11 Jul 2024 19:36:50 +0900 Subject: [PATCH 02/13] logging --- .gitignore | 6 +++++- improved_diffusion/train_util.py | 3 ++- scripts/image_train.py | 1 + 3 files changed, 8 insertions(+), 2 deletions(-) diff --git a/.gitignore b/.gitignore index 110f2eccfa..0dc057dc6d 100644 --- a/.gitignore +++ b/.gitignore @@ -4,4 +4,8 @@ __pycache__/ ckpt/ *.ipynb datasets/cifar_test -datasets/cifar_train \ No newline at end of file +datasets/cifar_train +.venv/ +anaconda.sh +improved_diffusion.egg-info/ +wandb/ diff --git a/improved_diffusion/train_util.py b/improved_diffusion/train_util.py index e8e5a82fd2..ece949c1da 100644 --- a/improved_diffusion/train_util.py +++ b/improved_diffusion/train_util.py @@ -241,7 +241,8 @@ def _anneal_lr(self): def log_step(self): logger.logkv("step", self.step + self.resume_step) logger.logkv("samples", (self.step + self.resume_step + 1) * self.global_batch) - wandb.log({"train/samples": (self.step + self.resume_step + 1) * self.global_batch}) + wandb.log({"train/step": self.step + self.resume_step}) + wandb.log({"train/samples": (self.step + self.resume_step + 1) * self.global_batch}) if self.use_fp16: logger.logkv("lg_loss_scale", self.lg_loss_scale) diff --git a/scripts/image_train.py b/scripts/image_train.py index 339590a0ad..c2e1065068 100644 --- a/scripts/image_train.py +++ b/scripts/image_train.py @@ -26,6 +26,7 @@ def main(): model, diffusion = create_model_and_diffusion(**args_to_dict(args, model_and_diffusion_defaults().keys())) model.load_state_dict(dist_util.load_state_dict(args.model_path, map_location="cpu")) model.to(dist_util.dev()) + print(dist_util.dev()) schedule_sampler = create_named_schedule_sampler(args.schedule_sampler, diffusion) logger.log("creating data loader...") From a33642475a961cc426d9489281a4af04fece0762 Mon Sep 17 00:00:00 2001 From: quasar529 Date: Thu, 11 Jul 2024 20:04:18 +0900 Subject: [PATCH 03/13] add: tqdm --- improved_diffusion/train_util.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/improved_diffusion/train_util.py b/improved_diffusion/train_util.py index ece949c1da..51e656b4f5 100644 --- a/improved_diffusion/train_util.py +++ b/improved_diffusion/train_util.py @@ -21,6 +21,7 @@ from .resample import 
LossAwareSampler, UniformSampler import wandb +from tqdm import tqdm # For ImageNet experiments, this was a good default value. # We found that the lg_loss_scale quickly climbed to @@ -167,7 +168,8 @@ def run_step(self, batch, cond): def forward_backward(self, batch, cond): zero_grad(self.model_params) - for i in range(0, batch.shape[0], self.microbatch): + progress_bar = tqdm(range(0, batch.shape[0], self.microbatch), desc="Processing Batches") + for i in progress_bar: micro = batch[i : i + self.microbatch].to(dist_util.dev()) micro_cond = {k: v[i : i + self.microbatch].to(dist_util.dev()) for k, v in cond.items()} last_batch = (i + self.microbatch) >= batch.shape[0] @@ -198,6 +200,9 @@ def forward_backward(self, batch, cond): (loss * loss_scale).backward() else: loss.backward() + progress_bar.set_postfix(loss=loss.item()) + + tqdm.write(f"Batch {i // self.microbatch + 1}/{len(progress_bar)} - Loss: {loss.item()}") def optimize_fp16(self): if any(not th.isfinite(p.grad).all() for p in self.model_params): @@ -242,7 +247,7 @@ def log_step(self): logger.logkv("step", self.step + self.resume_step) logger.logkv("samples", (self.step + self.resume_step + 1) * self.global_batch) wandb.log({"train/step": self.step + self.resume_step}) - wandb.log({"train/samples": (self.step + self.resume_step + 1) * self.global_batch}) + wandb.log({"train/samples": (self.step + self.resume_step + 1) * self.global_batch}) if self.use_fp16: logger.logkv("lg_loss_scale", self.lg_loss_scale) From d548d36a1b661a29bdaacb83c5b82ed6ce3a7945 Mon Sep 17 00:00:00 2001 From: quasar529 Date: Thu, 11 Jul 2024 20:06:50 +0900 Subject: [PATCH 04/13] wandb config --- improved_diffusion/train_util.py | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/improved_diffusion/train_util.py b/improved_diffusion/train_util.py index 51e656b4f5..ab409db350 100644 --- a/improved_diffusion/train_util.py +++ b/improved_diffusion/train_util.py @@ -103,7 +103,23 @@ def __init__( self.use_ddp = False self.ddp_model = self.model - wandb.init(project="imporved_diffusion", entity="quasar529") + wandb.init(project="imporved_diffusion", entity="quasar529") + wandb.config.update( + { + "batch_size": batch_size, + "microbatch": microbatch, + "lr": lr, + "ema_rate": ema_rate, + "log_interval": log_interval, + "save_interval": save_interval, + "resume_checkpoint": resume_checkpoint, + "use_fp16": use_fp16, + "fp16_scale_growth": fp16_scale_growth, + "schedule_sampler": schedule_sampler, + "weight_decay": weight_decay, + "lr_anneal_steps": lr_anneal_steps, + } + ) def _load_and_sync_parameters(self): resume_checkpoint = find_resume_checkpoint() or self.resume_checkpoint From 9cfa8ddbd8ca09a0de08c8611ee5cb9fe0ee7d9a Mon Sep 17 00:00:00 2001 From: quasar529 Date: Thu, 11 Jul 2024 20:25:20 +0900 Subject: [PATCH 05/13] train checked --- .gitignore | 1 + scripts/image_sample.py | 2 +- scripts/image_train.py | 6 +++--- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/.gitignore b/.gitignore index 0dc057dc6d..b1bc2bdf00 100644 --- a/.gitignore +++ b/.gitignore @@ -9,3 +9,4 @@ datasets/cifar_train anaconda.sh improved_diffusion.egg-info/ wandb/ +*.pt diff --git a/scripts/image_sample.py b/scripts/image_sample.py index 2e410be327..e6210b120f 100644 --- a/scripts/image_sample.py +++ b/scripts/image_sample.py @@ -84,7 +84,7 @@ def main(): def create_argparser(): defaults = dict( clip_denoised=True, - num_samples=1000, + num_samples=10, batch_size=16, use_ddim=True, model_path="", diff --git a/scripts/image_train.py 
b/scripts/image_train.py index c2e1065068..8b9157b711 100644 --- a/scripts/image_train.py +++ b/scripts/image_train.py @@ -24,7 +24,7 @@ def main(): logger.log("creating model and diffusion...") model, diffusion = create_model_and_diffusion(**args_to_dict(args, model_and_diffusion_defaults().keys())) - model.load_state_dict(dist_util.load_state_dict(args.model_path, map_location="cpu")) + #model.load_state_dict(dist_util.load_state_dict(args.model_path, map_location="cpu")) model.to(dist_util.dev()) print(dist_util.dev()) schedule_sampler = create_named_schedule_sampler(args.schedule_sampler, diffusion) @@ -63,14 +63,14 @@ def create_argparser(): schedule_sampler="uniform", lr=1e-4, weight_decay=0.0, - lr_anneal_steps=0, + lr_anneal_steps=10000, batch_size=1, microbatch=-1, # -1 disables microbatches ema_rate="0.9999", # comma-separated list of EMA values log_interval=10, save_interval=10000, resume_checkpoint="", - use_fp16=False, + use_fp16=True, fp16_scale_growth=1e-3, model_path="", ) From 780fbc1621ec2ff9908405b778fec23ed4bf4591 Mon Sep 17 00:00:00 2001 From: quasar529 Date: Thu, 11 Jul 2024 20:29:33 +0900 Subject: [PATCH 06/13] fix: tqdm --- improved_diffusion/train_util.py | 30 ++++++++++++++++++------------ 1 file changed, 18 insertions(+), 12 deletions(-) diff --git a/improved_diffusion/train_util.py b/improved_diffusion/train_util.py index ab409db350..82a1f1fcf3 100644 --- a/improved_diffusion/train_util.py +++ b/improved_diffusion/train_util.py @@ -120,6 +120,7 @@ def __init__( "lr_anneal_steps": lr_anneal_steps, } ) + wandb.run.name = f"diffusion_step_resume_checkpoint_{self.resume_checkpoint}" def _load_and_sync_parameters(self): resume_checkpoint = find_resume_checkpoint() or self.resume_checkpoint @@ -159,17 +160,21 @@ def _setup_fp16(self): self.model.convert_to_fp16() def run_loop(self): - while not self.lr_anneal_steps or self.step + self.resume_step < self.lr_anneal_steps: - batch, cond = next(self.data) - self.run_step(batch, cond) - if self.step % self.log_interval == 0: - logger.dumpkvs() - if self.step % self.save_interval == 0: - self.save() - # Run for a finite amount of time in integration tests. - if os.environ.get("DIFFUSION_TRAINING_TEST", "") and self.step > 0: - return - self.step += 1 + total_steps = self.lr_anneal_steps if self.lr_anneal_steps else float("inf") + with tqdm(total=total_steps, desc="Training Progress") as pbar: + while not self.lr_anneal_steps or self.step + self.resume_step < self.lr_anneal_steps: + batch, cond = next(self.data) + self.run_step(batch, cond) + if self.step % self.log_interval == 0: + logger.dumpkvs() + if self.step % self.save_interval == 0: + self.save() + # Run for a finite amount of time in integration tests. + if os.environ.get("DIFFUSION_TRAINING_TEST", "") and self.step > 0: + return + self.step += 1 + pbar.update(1) # Update the progress bar + # Save the last checkpoint if it wasn't already saved. 
         if (self.step - 1) % self.save_interval != 0:
             self.save()
@@ -189,7 +190,7 @@ def run_step(self, batch, cond):
     def forward_backward(self, batch, cond):
         zero_grad(self.model_params)
         progress_bar = tqdm(range(0, batch.shape[0], self.microbatch), desc="Processing Batches")
-        for i in progress_bar:
+        for i in range(0, batch.shape[0], self.microbatch):
             micro = batch[i : i + self.microbatch].to(dist_util.dev())
             micro_cond = {k: v[i : i + self.microbatch].to(dist_util.dev()) for k, v in cond.items()}
             last_batch = (i + self.microbatch) >= batch.shape[0]
@@ -217,6 +218,7 @@ def forward_backward(self, batch, cond):
             else:
                 loss.backward()
             progress_bar.set_postfix(loss=loss.item())
+            progress_bar.update(1)  # update the batch progress bar
             tqdm.write(f"Batch {i // self.microbatch + 1}/{len(progress_bar)} - Loss: {loss.item()}")

From d374cdcf34c3704b65ba421ea26684d73aae389b Mon Sep 17 00:00:00 2001
From: quasar529
Date: Thu, 11 Jul 2024 20:31:11 +0900
Subject: [PATCH 07/13] fix:tqdm bar

---
 improved_diffusion/train_util.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/improved_diffusion/train_util.py b/improved_diffusion/train_util.py
index 82a1f1fcf3..6db8d952a4 100644
--- a/improved_diffusion/train_util.py
+++ b/improved_diffusion/train_util.py
@@ -225,6 +225,7 @@ def forward_backward(self, batch, cond):
             progress_bar.set_postfix(loss=loss.item())
             progress_bar.update(1)  # update the batch progress bar
             tqdm.write(f"Batch {i // self.microbatch + 1}/{len(progress_bar)} - Loss: {loss.item()}")
+        progress_bar.close()  # close the batch progress bar
 
     def optimize_fp16(self):
         if any(not th.isfinite(p.grad).all() for p in self.model_params):

From 0775a3b2015e2bc38388b6b2029ff877191acc36 Mon Sep 17 00:00:00 2001
From: quasar529
Date: Thu, 11 Jul 2024 20:38:59 +0900
Subject: [PATCH 08/13] fix: save and log

---
 improved_diffusion/train_util.py | 16 +++++++++-------
 1 file changed, 9 insertions(+), 7 deletions(-)

diff --git a/improved_diffusion/train_util.py b/improved_diffusion/train_util.py
index 6db8d952a4..7ad1cd1c84 100644
--- a/improved_diffusion/train_util.py
+++ b/improved_diffusion/train_util.py
@@ -189,7 +189,9 @@ def run_step(self, batch, cond):
     def forward_backward(self, batch, cond):
         zero_grad(self.model_params)
-        progress_bar = tqdm(range(0, batch.shape[0], self.microbatch), desc="Processing Batches")
+        progress_bar = tqdm(
+            range(0, batch.shape[0], self.microbatch), desc="Processing Batches", leave=False, dynamic_ncols=True
+        )
         for i in range(0, batch.shape[0], self.microbatch):
             micro = batch[i : i + self.microbatch].to(dist_util.dev())
             micro_cond = {k: v[i : i + self.microbatch].to(dist_util.dev()) for k, v in cond.items()}
@@ -274,7 +276,7 @@ def log_step(self):
         if self.use_fp16:
             logger.logkv("lg_loss_scale", self.lg_loss_scale)
 
-    def save(self):
+    def save(self, save_dir="/home/jun/improved-diffusion/results"):
         def save_checkpoint(rate, params):
             state_dict = self._master_params_to_state_dict(params)
             if dist.get_rank() == 0:
@@ -283,7 +285,8 @@ def save_checkpoint(rate, params):
                     filename = f"model{(self.step+self.resume_step):06d}.pt"
                 else:
                     filename = f"ema_{rate}_{(self.step+self.resume_step):06d}.pt"
-                with bf.BlobFile(bf.join(get_blob_logdir(), filename), "wb") as f:
+                save_path = bf.join(save_dir, filename)
+                with bf.BlobFile(save_path, "wb") as f:
                     th.save(state_dict, f)
         save_checkpoint(0, self.master_params)
@@ -291,10 +294,9 @@ def save_checkpoint(rate, params):
             save_checkpoint(rate, params)
         if dist.get_rank() == 0:
-            with bf.BlobFile(
-                bf.join(get_blob_logdir(), f"opt{(self.step+self.resume_step):06d}.pt"),
-                "wb",
-            ) as f:
+            opt_filename = f"opt{(self.step+self.resume_step):06d}.pt"
+            opt_save_path = bf.join(save_dir, opt_filename)
+            with bf.BlobFile(opt_save_path, "wb") as f:
                 th.save(self.opt.state_dict(), f)
         dist.barrier()

From eae9d4a872f1b96d4d5788bc209b00798d15d0da Mon Sep 17 00:00:00 2001
From: quasar529
Date: Thu, 11 Jul 2024 20:43:05 +0900
Subject: [PATCH 09/13] update gitignore

---
 .gitignore             | 1 +
 scripts/image_train.py | 2 +-
 2 files changed, 2 insertions(+), 1 deletion(-)

diff --git a/.gitignore b/.gitignore
index b1bc2bdf00..8f6aef9c7e 100644
--- a/.gitignore
+++ b/.gitignore
@@ -10,3 +10,4 @@ anaconda.sh
 improved_diffusion.egg-info/
 wandb/
 *.pt
+results/
diff --git a/scripts/image_train.py b/scripts/image_train.py
index 8b9157b711..e3c0024714 100644
--- a/scripts/image_train.py
+++ b/scripts/image_train.py
@@ -68,7 +68,7 @@ def create_argparser():
         microbatch=-1,  # -1 disables microbatches
         ema_rate="0.9999",  # comma-separated list of EMA values
         log_interval=10,
-        save_interval=10000,
+        save_interval=5000,
         resume_checkpoint="",
         use_fp16=True,
         fp16_scale_growth=1e-3,
         model_path="",

From 59dd508506bda553df9985d5c2cacdcd04dc6468 Mon Sep 17 00:00:00 2001
From: quasar529
Date: Thu, 11 Jul 2024 20:46:18 +0900
Subject: [PATCH 10/13] fix: tqdm

---
 improved_diffusion/train_util.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/improved_diffusion/train_util.py b/improved_diffusion/train_util.py
index 7ad1cd1c84..e4ca57017f 100644
--- a/improved_diffusion/train_util.py
+++ b/improved_diffusion/train_util.py
@@ -161,7 +161,7 @@ def _setup_fp16(self):
     def run_loop(self):
         total_steps = self.lr_anneal_steps if self.lr_anneal_steps else float("inf")
-        with tqdm(total=total_steps, desc="Training Progress") as pbar:
+        with tqdm(total=total_steps, desc="Training Progress", dynamic_ncols=True) as pbar:
             while not self.lr_anneal_steps or self.step + self.resume_step < self.lr_anneal_steps:
                 batch, cond = next(self.data)
                 self.run_step(batch, cond)
@@ -189,9 +189,9 @@ def run_step(self, batch, cond):
     def forward_backward(self, batch, cond):
         zero_grad(self.model_params)
-        progress_bar = tqdm(
-            range(0, batch.shape[0], self.microbatch), desc="Processing Batches", leave=False, dynamic_ncols=True
-        )
+        total_batches = (batch.shape[0] + self.microbatch - 1) // self.microbatch  # total number of batches
+        progress_bar = tqdm(total=total_batches, desc="Processing Batches", leave=False, dynamic_ncols=True)
+
         for i in range(0, batch.shape[0], self.microbatch):
             micro = batch[i : i + self.microbatch].to(dist_util.dev())
             micro_cond = {k: v[i : i + self.microbatch].to(dist_util.dev()) for k, v in cond.items()}
@@ -218,6 +218,7 @@ def forward_backward(self, batch, cond):
             loss = (losses["loss"] * weights).mean()
             wandb.log({"train/loss": loss.item()})
             log_loss_dict(self.diffusion, t, {k: v * weights for k, v in losses.items()})
+
             if self.use_fp16:
                 loss_scale = 2**self.lg_loss_scale
                 (loss * loss_scale).backward()
@@ -226,7 +227,6 @@ def forward_backward(self, batch, cond):
             progress_bar.set_postfix(loss=loss.item())
             progress_bar.update(1)  # update the batch progress bar
-            tqdm.write(f"Batch {i // self.microbatch + 1}/{len(progress_bar)} - Loss: {loss.item()}")
         progress_bar.close()  # close the batch progress bar
 
     def optimize_fp16(self):

From e9cacc9cdf55b2e3533a47ab982d696e1e59a609 Mon Sep 17 00:00:00 2001
From: quasar529
Date: Thu, 11 Jul 2024 20:47:47 +0900
Subject: [PATCH 11/13] fix: tqdm

---
 improved_diffusion/train_util.py | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/improved_diffusion/train_util.py b/improved_diffusion/train_util.py
index e4ca57017f..f3ead243c9 100644
--- a/improved_diffusion/train_util.py
+++ b/improved_diffusion/train_util.py
@@ -193,6 +193,11 @@ def forward_backward(self, batch, cond):
         progress_bar = tqdm(total=total_batches, desc="Processing Batches", leave=False, dynamic_ncols=True)
         for i in range(0, batch.shape[0], self.microbatch):
+            current_batch = i // self.microbatch + 1  # current batch number
+            progress_bar.set_description(
+                f"Processing Batches {current_batch}/{total_batches}"
+            )  # update the progress-bar description
+
             micro = batch[i : i + self.microbatch].to(dist_util.dev())
             micro_cond = {k: v[i : i + self.microbatch].to(dist_util.dev()) for k, v in cond.items()}
             last_batch = (i + self.microbatch) >= batch.shape[0]

From 8b3bc07faa6c0b24623645dc74628e6f35800112 Mon Sep 17 00:00:00 2001
From: quasar529
Date: Thu, 11 Jul 2024 20:48:58 +0900
Subject: [PATCH 12/13] fix: wandb name

---
 improved_diffusion/train_util.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/improved_diffusion/train_util.py b/improved_diffusion/train_util.py
index f3ead243c9..9bd2990614 100644
--- a/improved_diffusion/train_util.py
+++ b/improved_diffusion/train_util.py
@@ -120,7 +120,7 @@ def __init__(
                 "lr_anneal_steps": lr_anneal_steps,
             }
         )
-        wandb.run.name = f"diffusion_step_resume_checkpoint_{self.resume_checkpoint}"
+        wandb.run.name = f"diffusion_batch_{batch_size}_steps{lr_anneal_steps}_{self.resume_checkpoint}"
 
     def _load_and_sync_parameters(self):
         resume_checkpoint = find_resume_checkpoint() or self.resume_checkpoint

From 6f14bb9c57fc3f94cf08b5483d3f610be3004a8c Mon Sep 17 00:00:00 2001
From: quasar529
Date: Thu, 18 Jul 2024 11:43:08 +0900
Subject: [PATCH 13/13] feat: calculate fid score for validation

---
 improved_diffusion/fid_evaluation.py | 137 +++++++++++++++++++++++++++
 improved_diffusion/train_util.py     |  78 ++++++++++++++-
 scripts/image_train.py               |  24 ++++-
 3 files changed, 236 insertions(+), 3 deletions(-)
 create mode 100644 improved_diffusion/fid_evaluation.py

diff --git a/improved_diffusion/fid_evaluation.py b/improved_diffusion/fid_evaluation.py
new file mode 100644
index 0000000000..4b4a8a4efe
--- /dev/null
+++ b/improved_diffusion/fid_evaluation.py
@@ -0,0 +1,137 @@
+"""
+Code from https://github.com/lucidrains/denoising-diffusion-pytorch/blob/main/denoising_diffusion_pytorch/fid_evaluation.py
+"""
+
+import math
+import os
+
+import numpy as np
+import torch
+from einops import rearrange, repeat
+from pytorch_fid.fid_score import calculate_frechet_distance
+from pytorch_fid.inception import InceptionV3
+from torch.nn.functional import adaptive_avg_pool2d
+from tqdm.auto import tqdm
+
+from datetime import datetime
+
+
+def num_to_groups(num, divisor):
+    groups = num // divisor
+    remainder = num % divisor
+    arr = [divisor] * groups
+    if remainder > 0:
+        arr.append(remainder)
+    return arr
+
+
+class FIDEvaluation:
+    def __init__(
+        self,
+        batch_size,
+        dl,
+        sampler,
+        channels=3,
+        accelerator=None,
+        stats_dir="./results",
+        device="cuda",
+        num_fid_samples=10000,
+        inception_block_idx=2048,
+    ):
+        self.batch_size = batch_size
+        self.n_samples = num_fid_samples
+        self.device = device
+        self.channels = channels
+        self.dl = dl
+        self.sampler = sampler
+        self.stats_dir = stats_dir
+        self.print_fn = print if accelerator is None else accelerator.print
+        self.samples_path = None  # new attribute that stores the path to the generated samples
+
+        assert inception_block_idx in InceptionV3.BLOCK_INDEX_BY_DIM
+        block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[inception_block_idx]
+        self.inception_v3 = InceptionV3([block_idx]).to(device)
+        self.dataset_stats_loaded = False
+
+    def set_samples_path(self, path):
"""생성된 샘플의 경로를 설정하는 메서드""" + self.samples_path = path + + def calculate_inception_features(self, samples): + if self.channels == 1: + samples = repeat(samples, "b 1 ... -> b c ...", c=3) + + self.inception_v3.eval() + features = self.inception_v3(samples)[0] + + if features.size(2) != 1 or features.size(3) != 1: + features = adaptive_avg_pool2d(features, output_size=(1, 1)) + features = rearrange(features, "... 1 1 -> ...") + return features + + def load_or_precalc_dataset_stats(self): + path = os.path.join(self.stats_dir, f"{datetime.now().strftime('%Y%m%d_%H%M')}_dataset_stats") + try: + ckpt = np.load(path + ".npz") + self.m2, self.s2 = ckpt["m2"], ckpt["s2"] + self.print_fn("Dataset stats loaded from disk.") + ckpt.close() + except OSError: + num_batches = int(math.ceil(self.n_samples / self.batch_size)) + stacked_real_features = [] + self.print_fn(f"Stacking Inception features for {self.n_samples} samples from the real dataset.") + for _ in tqdm(range(num_batches)): + try: + real_samples = next(self.dl) + except StopIteration: + break + if isinstance(real_samples, (list, tuple)): + real_samples = real_samples[0] + real_samples = real_samples.to(self.device) + real_features = self.calculate_inception_features(real_samples) + stacked_real_features.append(real_features) + stacked_real_features = torch.cat(stacked_real_features, dim=0).cpu().numpy() + m2 = np.mean(stacked_real_features, axis=0) + s2 = np.cov(stacked_real_features, rowvar=False) + np.savez_compressed(path, m2=m2, s2=s2) + self.print_fn(f"Dataset stats cached to {path}.npz for future use.") + self.m2, self.s2 = m2, s2 + self.dataset_stats_loaded = True + + @torch.inference_mode() + def fid_score(self): + if not self.dataset_stats_loaded: + self.load_or_precalc_dataset_stats() + # self.sampler.eval() + # batches = num_to_groups(self.n_samples, self.batch_size) + if self.samples_path is None: + raise ValueError("Samples path is not set. Call set_samples_path() before calculating FID score.") + + # 저장된 샘플 로드 + samples_data = np.load(self.samples_path) + if samples_data.files[0] == "arr_0": + samples = samples_data["arr_0"] + else: + samples = samples_data[samples_data.files[0]] + + if len(samples) < self.n_samples: + raise ValueError(f"Not enough samples. Found {len(samples)}, need {self.n_samples}") + + samples = samples[: self.n_samples] # 필요한 만큼만 사용 + + stacked_fake_features = [] + self.print_fn(f"Stacking Inception features for {self.n_samples} generated samples.") + # for batch in tqdm(batches): + for i in tqdm(range(0, self.n_samples, self.batch_size)): + batch = samples[i : i + self.batch_size] + batch = torch.from_numpy(batch).float().to(self.device) / 127.5 - 1 # 정규화 + batch = batch.permute(0, 3, 1, 2) # NHWC to NCHW + fake_features = self.calculate_inception_features(batch) + # fake_samples = self.sampler.sample(batch_size=batch) + # fake_features = self.calculate_inception_features(fake_samples) + stacked_fake_features.append(fake_features) + stacked_fake_features = torch.cat(stacked_fake_features, dim=0).cpu().numpy() + m1 = np.mean(stacked_fake_features, axis=0) + s1 = np.cov(stacked_fake_features, rowvar=False) + + return calculate_frechet_distance(m1, s1, self.m2, self.s2) diff --git a/improved_diffusion/train_util.py b/improved_diffusion/train_util.py index 9bd2990614..cb1bc2a4d6 100644 --- a/improved_diffusion/train_util.py +++ b/improved_diffusion/train_util.py @@ -22,6 +22,7 @@ import wandb from tqdm import tqdm +from datetime import datetime # For ImageNet experiments, this was a good default value. 
 # We found that the lg_loss_scale quickly climbed to
@@ -48,6 +49,8 @@ def __init__(
         schedule_sampler=None,
         weight_decay=0.0,
         lr_anneal_steps=0,
+        sample_interval=2500,
+        fid_evaluator=None,
     ):
         self.model = model
         self.diffusion = diffusion
@@ -64,6 +67,8 @@ def __init__(
         self.schedule_sampler = schedule_sampler or UniformSampler(diffusion)
         self.weight_decay = weight_decay
         self.lr_anneal_steps = lr_anneal_steps
+        self.sample_interval = sample_interval
+        self.fid_evaluator = fid_evaluator
 
         self.step = 0
         self.resume_step = 0
@@ -159,6 +164,66 @@ def _setup_fp16(self):
         self.master_params = make_master_params(self.model_params)
         self.model.convert_to_fp16()
+    def generate_samples(self):
+        all_images = []
+        all_labels = []
+        num_samples = self.fid_evaluator.n_samples  # or any desired number of samples
+        batch_size = self.batch_size  # or any desired batch size
+
+        with th.no_grad():
+            while len(all_images) * batch_size < num_samples:
+                model_kwargs = {}
+                # if self.args.class_cond:
+                #     classes = th.randint(low=0, high=self.num_classes, size=(batch_size,), device=self.device)
+                #     model_kwargs["y"] = classes
+
+                sample_fn = (
+                    self.diffusion.ddim_sample_loop
+                    # self.diffusion.p_sample_loop if not self.diffusion.use_ddim else self.diffusion.ddim_sample_loop
+                )
+                sample = sample_fn(
+                    self.model,
+                    (batch_size, 3, 64, 64),
+                    # (batch_size, 3, self.diffusion.image_size, self.diffusion.image_size),
+                    clip_denoised=True,
+                    # clip_denoised=self.args.clip_denoised,
+                    model_kwargs=model_kwargs,
+                )
+                sample = ((sample + 1) * 127.5).clamp(0, 255).to(th.uint8)
+                sample = sample.permute(0, 2, 3, 1)
+                sample = sample.contiguous()
+
+                all_images.extend([sample.cpu().numpy()])
+                # if self.args.class_cond:
+                #     all_labels.extend([classes.cpu().numpy()])
+
+                logger.log(f"created {len(all_images) * batch_size} samples")
+
+        arr = np.concatenate(all_images, axis=0)
+        arr = arr[:num_samples]
+        # if self.args.class_cond:
+        #     label_arr = np.concatenate(all_labels, axis=0)
+        #     label_arr = label_arr[:num_samples]
+
+        # save the samples
+        shape_str = "x".join([str(x) for x in arr.shape])
+        # out_path = os.path.join(logger.get_dir(), f"samples_{shape_str}.npz")
+        if not os.path.exists(self.fid_evaluator.stats_dir):
+            os.makedirs(self.fid_evaluator.stats_dir)
+            print(f"{self.fid_evaluator.stats_dir} created")
+        else:
+            print(f"{self.fid_evaluator.stats_dir} already exists")
+        out_path = os.path.join(self.fid_evaluator.stats_dir, f"{self.step}_samples_{shape_str}.npz")
+
+        logger.log(f"saving samples to {out_path}")
+        # if self.args.class_cond:
+        #     np.savez(out_path, arr, label_arr)
+        # else:
+        #     np.savez(out_path, arr)
+        np.savez(out_path, arr)
+        logger.log("sampling complete")
+        return out_path  # path where the generated samples were saved
+
     def run_loop(self):
         total_steps = self.lr_anneal_steps if self.lr_anneal_steps else float("inf")
         with tqdm(total=total_steps, desc="Training Progress", dynamic_ncols=True) as pbar:
             while not self.lr_anneal_steps or self.step + self.resume_step < self.lr_anneal_steps:
                 batch, cond = next(self.data)
                 self.run_step(batch, cond)
                 if self.step % self.log_interval == 0:
                     logger.dumpkvs()
                 if self.step % self.save_interval == 0:
                     self.save()
                 # Run for a finite amount of time in integration tests.
if os.environ.get("DIFFUSION_TRAINING_TEST", "") and self.step > 0: return + + if self.step % self.sample_interval == 0: + logger.log(f"Generating samples at step {self.step}...") + samples_path = self.generate_samples() + self.fid_evaluator.set_samples_path(samples_path) + + logger.log(f"Calculating FID score...") + fid_score = self.fid_evaluator.fid_score() + + logger.log(f"FID score at step {self.step}: {fid_score}") + wandb.log({"valid/fid_score": fid_score}) self.step += 1 pbar.update(1) # Update the progress bar @@ -281,7 +357,7 @@ def log_step(self): if self.use_fp16: logger.logkv("lg_loss_scale", self.lg_loss_scale) - def save(self, save_dir="/home/jun/improved-diffusion/results"): + def save(self, save_dir=f"/home/jun/improved-diffusion/results/{datetime.now().strftime('%Y%m%d_%H%M')}"): def save_checkpoint(rate, params): state_dict = self._master_params_to_state_dict(params) if dist.get_rank() == 0: diff --git a/scripts/image_train.py b/scripts/image_train.py index e3c0024714..da9bdba941 100644 --- a/scripts/image_train.py +++ b/scripts/image_train.py @@ -14,6 +14,9 @@ add_dict_to_argparser, ) from improved_diffusion.train_util import TrainLoop +from improved_diffusion.fid_evaluation import FIDEvaluation + +from datetime import datetime def main(): @@ -24,7 +27,7 @@ def main(): logger.log("creating model and diffusion...") model, diffusion = create_model_and_diffusion(**args_to_dict(args, model_and_diffusion_defaults().keys())) - #model.load_state_dict(dist_util.load_state_dict(args.model_path, map_location="cpu")) + # model.load_state_dict(dist_util.load_state_dict(args.model_path, map_location="cpu")) model.to(dist_util.dev()) print(dist_util.dev()) schedule_sampler = create_named_schedule_sampler(args.schedule_sampler, diffusion) @@ -37,6 +40,18 @@ def main(): class_cond=args.class_cond, ) + logger.log("creating FID evaluator...") + + fid_evaluator = FIDEvaluation( + batch_size=args.batch_size, + dl=data, + sampler=diffusion, + channels=3, + device=dist_util.dev(), + num_fid_samples=args.num_fid_samples, + stats_dir=args.stats_dir, + ) + logger.log("training...") TrainLoop( model=model, @@ -54,6 +69,8 @@ def main(): schedule_sampler=schedule_sampler, weight_decay=args.weight_decay, lr_anneal_steps=args.lr_anneal_steps, + sample_interval=args.sample_interval, # Validation을 위한 sample interval + fid_evaluator=fid_evaluator, ).run_loop() @@ -63,7 +80,7 @@ def create_argparser(): schedule_sampler="uniform", lr=1e-4, weight_decay=0.0, - lr_anneal_steps=10000, + lr_anneal_steps=30000, batch_size=1, microbatch=-1, # -1 disables microbatches ema_rate="0.9999", # comma-separated list of EMA values @@ -73,6 +90,9 @@ def create_argparser(): use_fp16=True, fp16_scale_growth=1e-3, model_path="", + num_fid_samples=32, + sample_interval=2500, + stats_dir=f"./results/{datetime.now().strftime('%Y%m%d_%H%M')}/stats_dir", ) defaults.update(model_and_diffusion_defaults()) parser = argparse.ArgumentParser()