From 2542ab88b2b1023a5aba3c56762ceb3b4b75f0d1 Mon Sep 17 00:00:00 2001 From: quasar529 Date: Wed, 10 Jul 2024 20:23:09 +0900 Subject: [PATCH 01/12] train setting - logging and sampling --- .gitignore | 4 ++ improved_diffusion/train_util.py | 73 +++++++++++--------------------- scripts/image_sample.py | 27 +++++------- scripts/image_train.py | 6 +-- 4 files changed, 41 insertions(+), 69 deletions(-) diff --git a/.gitignore b/.gitignore index 2d97991d7e..110f2eccfa 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,7 @@ .DS_Store __pycache__/ +ckpt/ +*.ipynb +datasets/cifar_test +datasets/cifar_train \ No newline at end of file diff --git a/improved_diffusion/train_util.py b/improved_diffusion/train_util.py index 1867604145..e8e5a82fd2 100644 --- a/improved_diffusion/train_util.py +++ b/improved_diffusion/train_util.py @@ -20,6 +20,8 @@ from .nn import update_ema from .resample import LossAwareSampler, UniformSampler +import wandb + # For ImageNet experiments, this was a good default value. # We found that the lg_loss_scale quickly climbed to # 20-21 within the first ~1K steps of training. @@ -52,11 +54,7 @@ def __init__( self.batch_size = batch_size self.microbatch = microbatch if microbatch > 0 else batch_size self.lr = lr - self.ema_rate = ( - [ema_rate] - if isinstance(ema_rate, float) - else [float(x) for x in ema_rate.split(",")] - ) + self.ema_rate = [ema_rate] if isinstance(ema_rate, float) else [float(x) for x in ema_rate.split(",")] self.log_interval = log_interval self.save_interval = save_interval self.resume_checkpoint = resume_checkpoint @@ -84,13 +82,9 @@ def __init__( self._load_optimizer_state() # Model was resumed, either due to a restart or a checkpoint # being specified at the command line. - self.ema_params = [ - self._load_ema_parameters(rate) for rate in self.ema_rate - ] + self.ema_params = [self._load_ema_parameters(rate) for rate in self.ema_rate] else: - self.ema_params = [ - copy.deepcopy(self.master_params) for _ in range(len(self.ema_rate)) - ] + self.ema_params = [copy.deepcopy(self.master_params) for _ in range(len(self.ema_rate))] if th.cuda.is_available(): self.use_ddp = True @@ -104,13 +98,12 @@ def __init__( ) else: if dist.get_world_size() > 1: - logger.warn( - "Distributed training requires CUDA. " - "Gradients will not be synchronized properly!" - ) + logger.warn("Distributed training requires CUDA. 
" "Gradients will not be synchronized properly!") self.use_ddp = False self.ddp_model = self.model + wandb.init(project="imporved_diffusion", entity="quasar529") + def _load_and_sync_parameters(self): resume_checkpoint = find_resume_checkpoint() or self.resume_checkpoint @@ -118,11 +111,7 @@ def _load_and_sync_parameters(self): self.resume_step = parse_resume_step_from_filename(resume_checkpoint) if dist.get_rank() == 0: logger.log(f"loading model from checkpoint: {resume_checkpoint}...") - self.model.load_state_dict( - dist_util.load_state_dict( - resume_checkpoint, map_location=dist_util.dev() - ) - ) + self.model.load_state_dict(dist_util.load_state_dict(resume_checkpoint, map_location=dist_util.dev())) dist_util.sync_params(self.model.parameters()) @@ -134,9 +123,7 @@ def _load_ema_parameters(self, rate): if ema_checkpoint: if dist.get_rank() == 0: logger.log(f"loading EMA from checkpoint: {ema_checkpoint}...") - state_dict = dist_util.load_state_dict( - ema_checkpoint, map_location=dist_util.dev() - ) + state_dict = dist_util.load_state_dict(ema_checkpoint, map_location=dist_util.dev()) ema_params = self._state_dict_to_master_params(state_dict) dist_util.sync_params(ema_params) @@ -144,14 +131,10 @@ def _load_ema_parameters(self, rate): def _load_optimizer_state(self): main_checkpoint = find_resume_checkpoint() or self.resume_checkpoint - opt_checkpoint = bf.join( - bf.dirname(main_checkpoint), f"opt{self.resume_step:06}.pt" - ) + opt_checkpoint = bf.join(bf.dirname(main_checkpoint), f"opt{self.resume_step:06}.pt") if bf.exists(opt_checkpoint): logger.log(f"loading optimizer state from checkpoint: {opt_checkpoint}") - state_dict = dist_util.load_state_dict( - opt_checkpoint, map_location=dist_util.dev() - ) + state_dict = dist_util.load_state_dict(opt_checkpoint, map_location=dist_util.dev()) self.opt.load_state_dict(state_dict) def _setup_fp16(self): @@ -159,10 +142,7 @@ def _setup_fp16(self): self.model.convert_to_fp16() def run_loop(self): - while ( - not self.lr_anneal_steps - or self.step + self.resume_step < self.lr_anneal_steps - ): + while not self.lr_anneal_steps or self.step + self.resume_step < self.lr_anneal_steps: batch, cond = next(self.data) self.run_step(batch, cond) if self.step % self.log_interval == 0: @@ -189,10 +169,7 @@ def forward_backward(self, batch, cond): zero_grad(self.model_params) for i in range(0, batch.shape[0], self.microbatch): micro = batch[i : i + self.microbatch].to(dist_util.dev()) - micro_cond = { - k: v[i : i + self.microbatch].to(dist_util.dev()) - for k, v in cond.items() - } + micro_cond = {k: v[i : i + self.microbatch].to(dist_util.dev()) for k, v in cond.items()} last_batch = (i + self.microbatch) >= batch.shape[0] t, weights = self.schedule_sampler.sample(micro.shape[0], dist_util.dev()) @@ -211,16 +188,13 @@ def forward_backward(self, batch, cond): losses = compute_losses() if isinstance(self.schedule_sampler, LossAwareSampler): - self.schedule_sampler.update_with_local_losses( - t, losses["loss"].detach() - ) + self.schedule_sampler.update_with_local_losses(t, losses["loss"].detach()) loss = (losses["loss"] * weights).mean() - log_loss_dict( - self.diffusion, t, {k: v * weights for k, v in losses.items()} - ) + wandb.log({"train/loss": loss.item()}) + log_loss_dict(self.diffusion, t, {k: v * weights for k, v in losses.items()}) if self.use_fp16: - loss_scale = 2 ** self.lg_loss_scale + loss_scale = 2**self.lg_loss_scale (loss * loss_scale).backward() else: loss.backward() @@ -232,7 +206,7 @@ def optimize_fp16(self): return 
model_grads_to_master_grads(self.model_params, self.master_params) - self.master_params[0].grad.mul_(1.0 / (2 ** self.lg_loss_scale)) + self.master_params[0].grad.mul_(1.0 / (2**self.lg_loss_scale)) self._log_grad_norm() self._anneal_lr() self.opt.step() @@ -251,7 +225,8 @@ def optimize_normal(self): def _log_grad_norm(self): sqsum = 0.0 for p in self.master_params: - sqsum += (p.grad ** 2).sum().item() + sqsum += (p.grad**2).sum().item() + wandb.log({"train/grad_norm": np.sqrt(sqsum)}) logger.logkv_mean("grad_norm", np.sqrt(sqsum)) def _anneal_lr(self): @@ -261,10 +236,12 @@ def _anneal_lr(self): lr = self.lr * (1 - frac_done) for param_group in self.opt.param_groups: param_group["lr"] = lr + wandb.log({"train/lr": lr}) def log_step(self): logger.logkv("step", self.step + self.resume_step) logger.logkv("samples", (self.step + self.resume_step + 1) * self.global_batch) + wandb.log({"train/samples": (self.step + self.resume_step + 1) * self.global_batch}) if self.use_fp16: logger.logkv("lg_loss_scale", self.lg_loss_scale) @@ -295,9 +272,7 @@ def save_checkpoint(rate, params): def _master_params_to_state_dict(self, master_params): if self.use_fp16: - master_params = unflatten_master_params( - self.model.parameters(), master_params - ) + master_params = unflatten_master_params(self.model.parameters(), master_params) state_dict = self.model.state_dict() for i, (name, _value) in enumerate(self.model.named_parameters()): assert name in state_dict diff --git a/scripts/image_sample.py b/scripts/image_sample.py index 66326f6908..2e410be327 100644 --- a/scripts/image_sample.py +++ b/scripts/image_sample.py @@ -9,7 +9,10 @@ import numpy as np import torch as th import torch.distributed as dist +import sys +print(sys.path) +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) from improved_diffusion import dist_util, logger from improved_diffusion.script_util import ( NUM_CLASSES, @@ -27,12 +30,8 @@ def main(): logger.configure() logger.log("creating model and diffusion...") - model, diffusion = create_model_and_diffusion( - **args_to_dict(args, model_and_diffusion_defaults().keys()) - ) - model.load_state_dict( - dist_util.load_state_dict(args.model_path, map_location="cpu") - ) + model, diffusion = create_model_and_diffusion(**args_to_dict(args, model_and_diffusion_defaults().keys())) + model.load_state_dict(dist_util.load_state_dict(args.model_path, map_location="cpu")) model.to(dist_util.dev()) model.eval() @@ -42,13 +41,9 @@ def main(): while len(all_images) * args.batch_size < args.num_samples: model_kwargs = {} if args.class_cond: - classes = th.randint( - low=0, high=NUM_CLASSES, size=(args.batch_size,), device=dist_util.dev() - ) + classes = th.randint(low=0, high=NUM_CLASSES, size=(args.batch_size,), device=dist_util.dev()) model_kwargs["y"] = classes - sample_fn = ( - diffusion.p_sample_loop if not args.use_ddim else diffusion.ddim_sample_loop - ) + sample_fn = diffusion.p_sample_loop if not args.use_ddim else diffusion.ddim_sample_loop sample = sample_fn( model, (args.batch_size, 3, args.image_size, args.image_size), @@ -63,9 +58,7 @@ def main(): dist.all_gather(gathered_samples, sample) # gather not supported with NCCL all_images.extend([sample.cpu().numpy() for sample in gathered_samples]) if args.class_cond: - gathered_labels = [ - th.zeros_like(classes) for _ in range(dist.get_world_size()) - ] + gathered_labels = [th.zeros_like(classes) for _ in range(dist.get_world_size())] dist.all_gather(gathered_labels, classes) all_labels.extend([labels.cpu().numpy() 
for labels in gathered_labels]) logger.log(f"created {len(all_images) * args.batch_size} samples") @@ -91,9 +84,9 @@ def main(): def create_argparser(): defaults = dict( clip_denoised=True, - num_samples=10000, + num_samples=1000, batch_size=16, - use_ddim=False, + use_ddim=True, model_path="", ) defaults.update(model_and_diffusion_defaults()) diff --git a/scripts/image_train.py b/scripts/image_train.py index 76c0632608..339590a0ad 100644 --- a/scripts/image_train.py +++ b/scripts/image_train.py @@ -23,9 +23,8 @@ def main(): logger.configure() logger.log("creating model and diffusion...") - model, diffusion = create_model_and_diffusion( - **args_to_dict(args, model_and_diffusion_defaults().keys()) - ) + model, diffusion = create_model_and_diffusion(**args_to_dict(args, model_and_diffusion_defaults().keys())) + model.load_state_dict(dist_util.load_state_dict(args.model_path, map_location="cpu")) model.to(dist_util.dev()) schedule_sampler = create_named_schedule_sampler(args.schedule_sampler, diffusion) @@ -72,6 +71,7 @@ def create_argparser(): resume_checkpoint="", use_fp16=False, fp16_scale_growth=1e-3, + model_path="", ) defaults.update(model_and_diffusion_defaults()) parser = argparse.ArgumentParser() From d04c2af47dd4fe54c7342d8a56024b4b81a3f15e Mon Sep 17 00:00:00 2001 From: quasar529 Date: Thu, 11 Jul 2024 19:36:50 +0900 Subject: [PATCH 02/12] logging --- .gitignore | 6 +++++- improved_diffusion/train_util.py | 3 ++- scripts/image_train.py | 1 + 3 files changed, 8 insertions(+), 2 deletions(-) diff --git a/.gitignore b/.gitignore index 110f2eccfa..0dc057dc6d 100644 --- a/.gitignore +++ b/.gitignore @@ -4,4 +4,8 @@ __pycache__/ ckpt/ *.ipynb datasets/cifar_test -datasets/cifar_train \ No newline at end of file +datasets/cifar_train +.venv/ +anaconda.sh +improved_diffusion.egg-info/ +wandb/ diff --git a/improved_diffusion/train_util.py b/improved_diffusion/train_util.py index e8e5a82fd2..ece949c1da 100644 --- a/improved_diffusion/train_util.py +++ b/improved_diffusion/train_util.py @@ -241,7 +241,8 @@ def _anneal_lr(self): def log_step(self): logger.logkv("step", self.step + self.resume_step) logger.logkv("samples", (self.step + self.resume_step + 1) * self.global_batch) - wandb.log({"train/samples": (self.step + self.resume_step + 1) * self.global_batch}) + wandb.log({"train/step": self.step + self.resume_step}) + wandb.log({"train/samples": (self.step + self.resume_step + 1) * self.global_batch}) if self.use_fp16: logger.logkv("lg_loss_scale", self.lg_loss_scale) diff --git a/scripts/image_train.py b/scripts/image_train.py index 339590a0ad..c2e1065068 100644 --- a/scripts/image_train.py +++ b/scripts/image_train.py @@ -26,6 +26,7 @@ def main(): model, diffusion = create_model_and_diffusion(**args_to_dict(args, model_and_diffusion_defaults().keys())) model.load_state_dict(dist_util.load_state_dict(args.model_path, map_location="cpu")) model.to(dist_util.dev()) + print(dist_util.dev()) schedule_sampler = create_named_schedule_sampler(args.schedule_sampler, diffusion) logger.log("creating data loader...") From a33642475a961cc426d9489281a4af04fece0762 Mon Sep 17 00:00:00 2001 From: quasar529 Date: Thu, 11 Jul 2024 20:04:18 +0900 Subject: [PATCH 03/12] add: tqdm --- improved_diffusion/train_util.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/improved_diffusion/train_util.py b/improved_diffusion/train_util.py index ece949c1da..51e656b4f5 100644 --- a/improved_diffusion/train_util.py +++ b/improved_diffusion/train_util.py @@ -21,6 +21,7 @@ from .resample import 
LossAwareSampler, UniformSampler import wandb +from tqdm import tqdm # For ImageNet experiments, this was a good default value. # We found that the lg_loss_scale quickly climbed to @@ -167,7 +168,8 @@ def run_step(self, batch, cond): def forward_backward(self, batch, cond): zero_grad(self.model_params) - for i in range(0, batch.shape[0], self.microbatch): + progress_bar = tqdm(range(0, batch.shape[0], self.microbatch), desc="Processing Batches") + for i in progress_bar: micro = batch[i : i + self.microbatch].to(dist_util.dev()) micro_cond = {k: v[i : i + self.microbatch].to(dist_util.dev()) for k, v in cond.items()} last_batch = (i + self.microbatch) >= batch.shape[0] @@ -198,6 +200,9 @@ def forward_backward(self, batch, cond): (loss * loss_scale).backward() else: loss.backward() + progress_bar.set_postfix(loss=loss.item()) + + tqdm.write(f"Batch {i // self.microbatch + 1}/{len(progress_bar)} - Loss: {loss.item()}") def optimize_fp16(self): if any(not th.isfinite(p.grad).all() for p in self.model_params): @@ -242,7 +247,7 @@ def log_step(self): logger.logkv("step", self.step + self.resume_step) logger.logkv("samples", (self.step + self.resume_step + 1) * self.global_batch) wandb.log({"train/step": self.step + self.resume_step}) - wandb.log({"train/samples": (self.step + self.resume_step + 1) * self.global_batch}) + wandb.log({"train/samples": (self.step + self.resume_step + 1) * self.global_batch}) if self.use_fp16: logger.logkv("lg_loss_scale", self.lg_loss_scale) From d548d36a1b661a29bdaacb83c5b82ed6ce3a7945 Mon Sep 17 00:00:00 2001 From: quasar529 Date: Thu, 11 Jul 2024 20:06:50 +0900 Subject: [PATCH 04/12] wandb config --- improved_diffusion/train_util.py | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/improved_diffusion/train_util.py b/improved_diffusion/train_util.py index 51e656b4f5..ab409db350 100644 --- a/improved_diffusion/train_util.py +++ b/improved_diffusion/train_util.py @@ -103,7 +103,23 @@ def __init__( self.use_ddp = False self.ddp_model = self.model - wandb.init(project="imporved_diffusion", entity="quasar529") + wandb.init(project="imporved_diffusion", entity="quasar529") + wandb.config.update( + { + "batch_size": batch_size, + "microbatch": microbatch, + "lr": lr, + "ema_rate": ema_rate, + "log_interval": log_interval, + "save_interval": save_interval, + "resume_checkpoint": resume_checkpoint, + "use_fp16": use_fp16, + "fp16_scale_growth": fp16_scale_growth, + "schedule_sampler": schedule_sampler, + "weight_decay": weight_decay, + "lr_anneal_steps": lr_anneal_steps, + } + ) def _load_and_sync_parameters(self): resume_checkpoint = find_resume_checkpoint() or self.resume_checkpoint From 9cfa8ddbd8ca09a0de08c8611ee5cb9fe0ee7d9a Mon Sep 17 00:00:00 2001 From: quasar529 Date: Thu, 11 Jul 2024 20:25:20 +0900 Subject: [PATCH 05/12] train checked --- .gitignore | 1 + scripts/image_sample.py | 2 +- scripts/image_train.py | 6 +++--- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/.gitignore b/.gitignore index 0dc057dc6d..b1bc2bdf00 100644 --- a/.gitignore +++ b/.gitignore @@ -9,3 +9,4 @@ datasets/cifar_train anaconda.sh improved_diffusion.egg-info/ wandb/ +*.pt diff --git a/scripts/image_sample.py b/scripts/image_sample.py index 2e410be327..e6210b120f 100644 --- a/scripts/image_sample.py +++ b/scripts/image_sample.py @@ -84,7 +84,7 @@ def main(): def create_argparser(): defaults = dict( clip_denoised=True, - num_samples=1000, + num_samples=10, batch_size=16, use_ddim=True, model_path="", diff --git a/scripts/image_train.py 
b/scripts/image_train.py index c2e1065068..8b9157b711 100644 --- a/scripts/image_train.py +++ b/scripts/image_train.py @@ -24,7 +24,7 @@ def main(): logger.log("creating model and diffusion...") model, diffusion = create_model_and_diffusion(**args_to_dict(args, model_and_diffusion_defaults().keys())) - model.load_state_dict(dist_util.load_state_dict(args.model_path, map_location="cpu")) + #model.load_state_dict(dist_util.load_state_dict(args.model_path, map_location="cpu")) model.to(dist_util.dev()) print(dist_util.dev()) schedule_sampler = create_named_schedule_sampler(args.schedule_sampler, diffusion) @@ -63,14 +63,14 @@ def create_argparser(): schedule_sampler="uniform", lr=1e-4, weight_decay=0.0, - lr_anneal_steps=0, + lr_anneal_steps=10000, batch_size=1, microbatch=-1, # -1 disables microbatches ema_rate="0.9999", # comma-separated list of EMA values log_interval=10, save_interval=10000, resume_checkpoint="", - use_fp16=False, + use_fp16=True, fp16_scale_growth=1e-3, model_path="", ) From 780fbc1621ec2ff9908405b778fec23ed4bf4591 Mon Sep 17 00:00:00 2001 From: quasar529 Date: Thu, 11 Jul 2024 20:29:33 +0900 Subject: [PATCH 06/12] fix: tqdm --- improved_diffusion/train_util.py | 30 ++++++++++++++++++------------ 1 file changed, 18 insertions(+), 12 deletions(-) diff --git a/improved_diffusion/train_util.py b/improved_diffusion/train_util.py index ab409db350..82a1f1fcf3 100644 --- a/improved_diffusion/train_util.py +++ b/improved_diffusion/train_util.py @@ -120,6 +120,7 @@ def __init__( "lr_anneal_steps": lr_anneal_steps, } ) + wandb.run.name = f"diffusion_step_resume_checkpoint_{self.resume_checkpoint}" def _load_and_sync_parameters(self): resume_checkpoint = find_resume_checkpoint() or self.resume_checkpoint @@ -159,17 +160,21 @@ def _setup_fp16(self): self.model.convert_to_fp16() def run_loop(self): - while not self.lr_anneal_steps or self.step + self.resume_step < self.lr_anneal_steps: - batch, cond = next(self.data) - self.run_step(batch, cond) - if self.step % self.log_interval == 0: - logger.dumpkvs() - if self.step % self.save_interval == 0: - self.save() - # Run for a finite amount of time in integration tests. - if os.environ.get("DIFFUSION_TRAINING_TEST", "") and self.step > 0: - return - self.step += 1 + total_steps = self.lr_anneal_steps if self.lr_anneal_steps else float("inf") + with tqdm(total=total_steps, desc="Training Progress") as pbar: + while not self.lr_anneal_steps or self.step + self.resume_step < self.lr_anneal_steps: + batch, cond = next(self.data) + self.run_step(batch, cond) + if self.step % self.log_interval == 0: + logger.dumpkvs() + if self.step % self.save_interval == 0: + self.save() + # Run for a finite amount of time in integration tests. + if os.environ.get("DIFFUSION_TRAINING_TEST", "") and self.step > 0: + return + self.step += 1 + pbar.update(1) # Update the progress bar + # Save the last checkpoint if it wasn't already saved. 
if (self.step - 1) % self.save_interval != 0: self.save() @@ -185,7 +190,7 @@ def run_step(self, batch, cond): def forward_backward(self, batch, cond): zero_grad(self.model_params) progress_bar = tqdm(range(0, batch.shape[0], self.microbatch), desc="Processing Batches") - for i in progress_bar: + for i in range(0, batch.shape[0], self.microbatch): micro = batch[i : i + self.microbatch].to(dist_util.dev()) micro_cond = {k: v[i : i + self.microbatch].to(dist_util.dev()) for k, v in cond.items()} last_batch = (i + self.microbatch) >= batch.shape[0] @@ -217,6 +222,7 @@ def forward_backward(self, batch, cond): else: loss.backward() progress_bar.set_postfix(loss=loss.item()) + progress_bar.update(1) # 배치 진행 상황 업데이트 tqdm.write(f"Batch {i // self.microbatch + 1}/{len(progress_bar)} - Loss: {loss.item()}") From d374cdcf34c3704b65ba421ea26684d73aae389b Mon Sep 17 00:00:00 2001 From: quasar529 Date: Thu, 11 Jul 2024 20:31:11 +0900 Subject: [PATCH 07/12] fix:tqdm bar --- improved_diffusion/train_util.py | 1 + 1 file changed, 1 insertion(+) diff --git a/improved_diffusion/train_util.py b/improved_diffusion/train_util.py index 82a1f1fcf3..6db8d952a4 100644 --- a/improved_diffusion/train_util.py +++ b/improved_diffusion/train_util.py @@ -225,6 +225,7 @@ def forward_backward(self, batch, cond): progress_bar.update(1) # 배치 진행 상황 업데이트 tqdm.write(f"Batch {i // self.microbatch + 1}/{len(progress_bar)} - Loss: {loss.item()}") + progress_bar.close() # 배치 진행 바 종료 def optimize_fp16(self): if any(not th.isfinite(p.grad).all() for p in self.model_params): From 0775a3b2015e2bc38388b6b2029ff877191acc36 Mon Sep 17 00:00:00 2001 From: quasar529 Date: Thu, 11 Jul 2024 20:38:59 +0900 Subject: [PATCH 08/12] fix: save and log --- improved_diffusion/train_util.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/improved_diffusion/train_util.py b/improved_diffusion/train_util.py index 6db8d952a4..7ad1cd1c84 100644 --- a/improved_diffusion/train_util.py +++ b/improved_diffusion/train_util.py @@ -189,7 +189,9 @@ def run_step(self, batch, cond): def forward_backward(self, batch, cond): zero_grad(self.model_params) - progress_bar = tqdm(range(0, batch.shape[0], self.microbatch), desc="Processing Batches") + progress_bar = tqdm( + range(0, batch.shape[0], self.microbatch), desc="Processing Batches", leave=False, dynamic_ncols=True + ) for i in range(0, batch.shape[0], self.microbatch): micro = batch[i : i + self.microbatch].to(dist_util.dev()) micro_cond = {k: v[i : i + self.microbatch].to(dist_util.dev()) for k, v in cond.items()} @@ -274,7 +276,7 @@ def log_step(self): if self.use_fp16: logger.logkv("lg_loss_scale", self.lg_loss_scale) - def save(self): + def save(self, save_dir="/home/jun/improved-diffusion/results"): def save_checkpoint(rate, params): state_dict = self._master_params_to_state_dict(params) if dist.get_rank() == 0: @@ -283,7 +285,8 @@ def save_checkpoint(rate, params): filename = f"model{(self.step+self.resume_step):06d}.pt" else: filename = f"ema_{rate}_{(self.step+self.resume_step):06d}.pt" - with bf.BlobFile(bf.join(get_blob_logdir(), filename), "wb") as f: + save_path = bf.join(save_dir, filename) + with bf.BlobFile(save_path, "wb") as f: th.save(state_dict, f) save_checkpoint(0, self.master_params) @@ -291,10 +294,9 @@ def save_checkpoint(rate, params): save_checkpoint(rate, params) if dist.get_rank() == 0: - with bf.BlobFile( - bf.join(get_blob_logdir(), f"opt{(self.step+self.resume_step):06d}.pt"), - "wb", - ) as f: + opt_filename = 
f"opt{(self.step+self.resume_step):06d}.pt" + opt_save_path = bf.join(save_dir, opt_filename) + with bf.BlobFile(opt_save_path, "wb") as f: th.save(self.opt.state_dict(), f) dist.barrier() From eae9d4a872f1b96d4d5788bc209b00798d15d0da Mon Sep 17 00:00:00 2001 From: quasar529 Date: Thu, 11 Jul 2024 20:43:05 +0900 Subject: [PATCH 09/12] update gitignore --- .gitignore | 1 + scripts/image_train.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index b1bc2bdf00..8f6aef9c7e 100644 --- a/.gitignore +++ b/.gitignore @@ -10,3 +10,4 @@ anaconda.sh improved_diffusion.egg-info/ wandb/ *.pt +results/ diff --git a/scripts/image_train.py b/scripts/image_train.py index 8b9157b711..e3c0024714 100644 --- a/scripts/image_train.py +++ b/scripts/image_train.py @@ -68,7 +68,7 @@ def create_argparser(): microbatch=-1, # -1 disables microbatches ema_rate="0.9999", # comma-separated list of EMA values log_interval=10, - save_interval=10000, + save_interval=5000, resume_checkpoint="", use_fp16=True, fp16_scale_growth=1e-3, From 59dd508506bda553df9985d5c2cacdcd04dc6468 Mon Sep 17 00:00:00 2001 From: quasar529 Date: Thu, 11 Jul 2024 20:46:18 +0900 Subject: [PATCH 10/12] fix: tqdm --- improved_diffusion/train_util.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/improved_diffusion/train_util.py b/improved_diffusion/train_util.py index 7ad1cd1c84..e4ca57017f 100644 --- a/improved_diffusion/train_util.py +++ b/improved_diffusion/train_util.py @@ -161,7 +161,7 @@ def _setup_fp16(self): def run_loop(self): total_steps = self.lr_anneal_steps if self.lr_anneal_steps else float("inf") - with tqdm(total=total_steps, desc="Training Progress") as pbar: + with tqdm(total=total_steps, desc="Training Progress", dynamic_ncols=True) as pbar: while not self.lr_anneal_steps or self.step + self.resume_step < self.lr_anneal_steps: batch, cond = next(self.data) self.run_step(batch, cond) @@ -189,9 +189,9 @@ def run_step(self, batch, cond): def forward_backward(self, batch, cond): zero_grad(self.model_params) - progress_bar = tqdm( - range(0, batch.shape[0], self.microbatch), desc="Processing Batches", leave=False, dynamic_ncols=True - ) + total_batches = (batch.shape[0] + self.microbatch - 1) // self.microbatch # 전체 배치 수 계산 + progress_bar = tqdm(total=total_batches, desc="Processing Batches", leave=False, dynamic_ncols=True) + for i in range(0, batch.shape[0], self.microbatch): micro = batch[i : i + self.microbatch].to(dist_util.dev()) micro_cond = {k: v[i : i + self.microbatch].to(dist_util.dev()) for k, v in cond.items()} @@ -218,6 +218,7 @@ def forward_backward(self, batch, cond): loss = (losses["loss"] * weights).mean() wandb.log({"train/loss": loss.item()}) log_loss_dict(self.diffusion, t, {k: v * weights for k, v in losses.items()}) + if self.use_fp16: loss_scale = 2**self.lg_loss_scale (loss * loss_scale).backward() @@ -226,7 +227,6 @@ def forward_backward(self, batch, cond): progress_bar.set_postfix(loss=loss.item()) progress_bar.update(1) # 배치 진행 상황 업데이트 - tqdm.write(f"Batch {i // self.microbatch + 1}/{len(progress_bar)} - Loss: {loss.item()}") progress_bar.close() # 배치 진행 바 종료 def optimize_fp16(self): From e9cacc9cdf55b2e3533a47ab982d696e1e59a609 Mon Sep 17 00:00:00 2001 From: quasar529 Date: Thu, 11 Jul 2024 20:47:47 +0900 Subject: [PATCH 11/12] fix: tqdm --- improved_diffusion/train_util.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/improved_diffusion/train_util.py b/improved_diffusion/train_util.py index e4ca57017f..f3ead243c9 100644 --- 
a/improved_diffusion/train_util.py +++ b/improved_diffusion/train_util.py @@ -193,6 +193,11 @@ def forward_backward(self, batch, cond): progress_bar = tqdm(total=total_batches, desc="Processing Batches", leave=False, dynamic_ncols=True) for i in range(0, batch.shape[0], self.microbatch): + current_batch = i // self.microbatch + 1 # 현재 배치 번호 계산 + progress_bar.set_description( + f"Processing Batches {current_batch}/{total_batches}" + ) # 진행 바 설명 업데이트 + micro = batch[i : i + self.microbatch].to(dist_util.dev()) micro_cond = {k: v[i : i + self.microbatch].to(dist_util.dev()) for k, v in cond.items()} last_batch = (i + self.microbatch) >= batch.shape[0] From 8b3bc07faa6c0b24623645dc74628e6f35800112 Mon Sep 17 00:00:00 2001 From: quasar529 Date: Thu, 11 Jul 2024 20:48:58 +0900 Subject: [PATCH 12/12] fix: wandb name --- improved_diffusion/train_util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/improved_diffusion/train_util.py b/improved_diffusion/train_util.py index f3ead243c9..9bd2990614 100644 --- a/improved_diffusion/train_util.py +++ b/improved_diffusion/train_util.py @@ -120,7 +120,7 @@ def __init__( "lr_anneal_steps": lr_anneal_steps, } ) - wandb.run.name = f"diffusion_step_resume_checkpoint_{self.resume_checkpoint}" + wandb.run.name = f"diffusion_batch_{batch_size}_steps{lr_anneal_steps}_{self.resume_checkpoint}" def _load_and_sync_parameters(self): resume_checkpoint = find_resume_checkpoint() or self.resume_checkpoint
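The commits above incrementally wire wandb metric logging and tqdm progress bars into improved_diffusion/train_util.py. The minimal sketch below shows that logging/progress pattern in isolation; it is not part of the patches themselves — the project, entity, and run-name strings are placeholders, the loop sizes are illustrative, and the loss is a stand-in value rather than the diffusion loss computed in TrainLoop.forward_backward. It assumes only that wandb and tqdm are installed.

import numpy as np
import wandb
from tqdm import tqdm

def train(steps=100, microbatches_per_step=4, lr=1e-4):
    # mode="offline" lets the sketch run without a wandb account; drop it to sync online.
    wandb.init(project="improved_diffusion", entity="your-entity", mode="offline")
    wandb.config.update({"lr": lr, "lr_anneal_steps": steps})
    wandb.run.name = f"diffusion_batch_{microbatches_per_step}_steps{steps}"
    # Outer bar tracks optimization steps, mirroring the run_loop() change.
    with tqdm(total=steps, desc="Training Progress", dynamic_ncols=True) as pbar:
        for step in range(steps):
            # Inner bar tracks microbatches, mirroring the forward_backward() change.
            inner = tqdm(total=microbatches_per_step, desc="Processing Batches",
                         leave=False, dynamic_ncols=True)
            for _ in range(microbatches_per_step):
                loss = float(np.random.rand())  # stand-in for the per-microbatch diffusion loss
                wandb.log({"train/loss": loss})
                inner.set_postfix(loss=loss)
                inner.update(1)
            inner.close()
            wandb.log({"train/step": step, "train/lr": lr})
            pbar.update(1)

if __name__ == "__main__":
    train()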