diff --git a/.gitignore b/.gitignore index 2d97991d7e..8f6aef9c7e 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,13 @@ .DS_Store __pycache__/ +ckpt/ +*.ipynb +datasets/cifar_test +datasets/cifar_train +.venv/ +anaconda.sh +improved_diffusion.egg-info/ +wandb/ +*.pt +results/ diff --git a/improved_diffusion/train_util.py b/improved_diffusion/train_util.py index 1867604145..9bd2990614 100644 --- a/improved_diffusion/train_util.py +++ b/improved_diffusion/train_util.py @@ -20,6 +20,9 @@ from .nn import update_ema from .resample import LossAwareSampler, UniformSampler +import wandb +from tqdm import tqdm + # For ImageNet experiments, this was a good default value. # We found that the lg_loss_scale quickly climbed to # 20-21 within the first ~1K steps of training. @@ -52,11 +55,7 @@ def __init__( self.batch_size = batch_size self.microbatch = microbatch if microbatch > 0 else batch_size self.lr = lr - self.ema_rate = ( - [ema_rate] - if isinstance(ema_rate, float) - else [float(x) for x in ema_rate.split(",")] - ) + self.ema_rate = [ema_rate] if isinstance(ema_rate, float) else [float(x) for x in ema_rate.split(",")] self.log_interval = log_interval self.save_interval = save_interval self.resume_checkpoint = resume_checkpoint @@ -84,13 +83,9 @@ def __init__( self._load_optimizer_state() # Model was resumed, either due to a restart or a checkpoint # being specified at the command line. - self.ema_params = [ - self._load_ema_parameters(rate) for rate in self.ema_rate - ] + self.ema_params = [self._load_ema_parameters(rate) for rate in self.ema_rate] else: - self.ema_params = [ - copy.deepcopy(self.master_params) for _ in range(len(self.ema_rate)) - ] + self.ema_params = [copy.deepcopy(self.master_params) for _ in range(len(self.ema_rate))] if th.cuda.is_available(): self.use_ddp = True @@ -104,13 +99,29 @@ def __init__( ) else: if dist.get_world_size() > 1: - logger.warn( - "Distributed training requires CUDA. " - "Gradients will not be synchronized properly!" - ) + logger.warn("Distributed training requires CUDA. " "Gradients will not be synchronized properly!") self.use_ddp = False self.ddp_model = self.model + wandb.init(project="imporved_diffusion", entity="quasar529") + wandb.config.update( + { + "batch_size": batch_size, + "microbatch": microbatch, + "lr": lr, + "ema_rate": ema_rate, + "log_interval": log_interval, + "save_interval": save_interval, + "resume_checkpoint": resume_checkpoint, + "use_fp16": use_fp16, + "fp16_scale_growth": fp16_scale_growth, + "schedule_sampler": schedule_sampler, + "weight_decay": weight_decay, + "lr_anneal_steps": lr_anneal_steps, + } + ) + wandb.run.name = f"diffusion_batch_{batch_size}_steps{lr_anneal_steps}_{self.resume_checkpoint}" + def _load_and_sync_parameters(self): resume_checkpoint = find_resume_checkpoint() or self.resume_checkpoint @@ -118,11 +129,7 @@ def _load_and_sync_parameters(self): self.resume_step = parse_resume_step_from_filename(resume_checkpoint) if dist.get_rank() == 0: logger.log(f"loading model from checkpoint: {resume_checkpoint}...") - self.model.load_state_dict( - dist_util.load_state_dict( - resume_checkpoint, map_location=dist_util.dev() - ) - ) + self.model.load_state_dict(dist_util.load_state_dict(resume_checkpoint, map_location=dist_util.dev())) dist_util.sync_params(self.model.parameters()) @@ -134,9 +141,7 @@ def _load_ema_parameters(self, rate): if ema_checkpoint: if dist.get_rank() == 0: logger.log(f"loading EMA from checkpoint: {ema_checkpoint}...") - state_dict = dist_util.load_state_dict( - ema_checkpoint, map_location=dist_util.dev() - ) + state_dict = dist_util.load_state_dict(ema_checkpoint, map_location=dist_util.dev()) ema_params = self._state_dict_to_master_params(state_dict) dist_util.sync_params(ema_params) @@ -144,14 +149,10 @@ def _load_ema_parameters(self, rate): def _load_optimizer_state(self): main_checkpoint = find_resume_checkpoint() or self.resume_checkpoint - opt_checkpoint = bf.join( - bf.dirname(main_checkpoint), f"opt{self.resume_step:06}.pt" - ) + opt_checkpoint = bf.join(bf.dirname(main_checkpoint), f"opt{self.resume_step:06}.pt") if bf.exists(opt_checkpoint): logger.log(f"loading optimizer state from checkpoint: {opt_checkpoint}") - state_dict = dist_util.load_state_dict( - opt_checkpoint, map_location=dist_util.dev() - ) + state_dict = dist_util.load_state_dict(opt_checkpoint, map_location=dist_util.dev()) self.opt.load_state_dict(state_dict) def _setup_fp16(self): @@ -159,20 +160,21 @@ def _setup_fp16(self): self.model.convert_to_fp16() def run_loop(self): - while ( - not self.lr_anneal_steps - or self.step + self.resume_step < self.lr_anneal_steps - ): - batch, cond = next(self.data) - self.run_step(batch, cond) - if self.step % self.log_interval == 0: - logger.dumpkvs() - if self.step % self.save_interval == 0: - self.save() - # Run for a finite amount of time in integration tests. - if os.environ.get("DIFFUSION_TRAINING_TEST", "") and self.step > 0: - return - self.step += 1 + total_steps = self.lr_anneal_steps if self.lr_anneal_steps else float("inf") + with tqdm(total=total_steps, desc="Training Progress", dynamic_ncols=True) as pbar: + while not self.lr_anneal_steps or self.step + self.resume_step < self.lr_anneal_steps: + batch, cond = next(self.data) + self.run_step(batch, cond) + if self.step % self.log_interval == 0: + logger.dumpkvs() + if self.step % self.save_interval == 0: + self.save() + # Run for a finite amount of time in integration tests. + if os.environ.get("DIFFUSION_TRAINING_TEST", "") and self.step > 0: + return + self.step += 1 + pbar.update(1) # Update the progress bar + # Save the last checkpoint if it wasn't already saved. if (self.step - 1) % self.save_interval != 0: self.save() @@ -187,12 +189,17 @@ def run_step(self, batch, cond): def forward_backward(self, batch, cond): zero_grad(self.model_params) + total_batches = (batch.shape[0] + self.microbatch - 1) // self.microbatch # 전체 배치 수 계산 + progress_bar = tqdm(total=total_batches, desc="Processing Batches", leave=False, dynamic_ncols=True) + for i in range(0, batch.shape[0], self.microbatch): + current_batch = i // self.microbatch + 1 # 현재 배치 번호 계산 + progress_bar.set_description( + f"Processing Batches {current_batch}/{total_batches}" + ) # 진행 바 설명 업데이트 + micro = batch[i : i + self.microbatch].to(dist_util.dev()) - micro_cond = { - k: v[i : i + self.microbatch].to(dist_util.dev()) - for k, v in cond.items() - } + micro_cond = {k: v[i : i + self.microbatch].to(dist_util.dev()) for k, v in cond.items()} last_batch = (i + self.microbatch) >= batch.shape[0] t, weights = self.schedule_sampler.sample(micro.shape[0], dist_util.dev()) @@ -211,19 +218,21 @@ def forward_backward(self, batch, cond): losses = compute_losses() if isinstance(self.schedule_sampler, LossAwareSampler): - self.schedule_sampler.update_with_local_losses( - t, losses["loss"].detach() - ) + self.schedule_sampler.update_with_local_losses(t, losses["loss"].detach()) loss = (losses["loss"] * weights).mean() - log_loss_dict( - self.diffusion, t, {k: v * weights for k, v in losses.items()} - ) + wandb.log({"train/loss": loss.item()}) + log_loss_dict(self.diffusion, t, {k: v * weights for k, v in losses.items()}) + if self.use_fp16: - loss_scale = 2 ** self.lg_loss_scale + loss_scale = 2**self.lg_loss_scale (loss * loss_scale).backward() else: loss.backward() + progress_bar.set_postfix(loss=loss.item()) + progress_bar.update(1) # 배치 진행 상황 업데이트 + + progress_bar.close() # 배치 진행 바 종료 def optimize_fp16(self): if any(not th.isfinite(p.grad).all() for p in self.model_params): @@ -232,7 +241,7 @@ def optimize_fp16(self): return model_grads_to_master_grads(self.model_params, self.master_params) - self.master_params[0].grad.mul_(1.0 / (2 ** self.lg_loss_scale)) + self.master_params[0].grad.mul_(1.0 / (2**self.lg_loss_scale)) self._log_grad_norm() self._anneal_lr() self.opt.step() @@ -251,7 +260,8 @@ def optimize_normal(self): def _log_grad_norm(self): sqsum = 0.0 for p in self.master_params: - sqsum += (p.grad ** 2).sum().item() + sqsum += (p.grad**2).sum().item() + wandb.log({"train/grad_norm": np.sqrt(sqsum)}) logger.logkv_mean("grad_norm", np.sqrt(sqsum)) def _anneal_lr(self): @@ -261,14 +271,17 @@ def _anneal_lr(self): lr = self.lr * (1 - frac_done) for param_group in self.opt.param_groups: param_group["lr"] = lr + wandb.log({"train/lr": lr}) def log_step(self): logger.logkv("step", self.step + self.resume_step) logger.logkv("samples", (self.step + self.resume_step + 1) * self.global_batch) + wandb.log({"train/step": self.step + self.resume_step}) + wandb.log({"train/samples": (self.step + self.resume_step + 1) * self.global_batch}) if self.use_fp16: logger.logkv("lg_loss_scale", self.lg_loss_scale) - def save(self): + def save(self, save_dir="/home/jun/improved-diffusion/results"): def save_checkpoint(rate, params): state_dict = self._master_params_to_state_dict(params) if dist.get_rank() == 0: @@ -277,7 +290,8 @@ def save_checkpoint(rate, params): filename = f"model{(self.step+self.resume_step):06d}.pt" else: filename = f"ema_{rate}_{(self.step+self.resume_step):06d}.pt" - with bf.BlobFile(bf.join(get_blob_logdir(), filename), "wb") as f: + save_path = bf.join(save_dir, filename) + with bf.BlobFile(save_path, "wb") as f: th.save(state_dict, f) save_checkpoint(0, self.master_params) @@ -285,19 +299,16 @@ def save_checkpoint(rate, params): save_checkpoint(rate, params) if dist.get_rank() == 0: - with bf.BlobFile( - bf.join(get_blob_logdir(), f"opt{(self.step+self.resume_step):06d}.pt"), - "wb", - ) as f: + opt_filename = f"opt{(self.step+self.resume_step):06d}.pt" + opt_save_path = bf.join(save_dir, opt_filename) + with bf.BlobFile(opt_save_path, "wb") as f: th.save(self.opt.state_dict(), f) dist.barrier() def _master_params_to_state_dict(self, master_params): if self.use_fp16: - master_params = unflatten_master_params( - self.model.parameters(), master_params - ) + master_params = unflatten_master_params(self.model.parameters(), master_params) state_dict = self.model.state_dict() for i, (name, _value) in enumerate(self.model.named_parameters()): assert name in state_dict diff --git a/scripts/image_sample.py b/scripts/image_sample.py index 66326f6908..e6210b120f 100644 --- a/scripts/image_sample.py +++ b/scripts/image_sample.py @@ -9,7 +9,10 @@ import numpy as np import torch as th import torch.distributed as dist +import sys +print(sys.path) +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) from improved_diffusion import dist_util, logger from improved_diffusion.script_util import ( NUM_CLASSES, @@ -27,12 +30,8 @@ def main(): logger.configure() logger.log("creating model and diffusion...") - model, diffusion = create_model_and_diffusion( - **args_to_dict(args, model_and_diffusion_defaults().keys()) - ) - model.load_state_dict( - dist_util.load_state_dict(args.model_path, map_location="cpu") - ) + model, diffusion = create_model_and_diffusion(**args_to_dict(args, model_and_diffusion_defaults().keys())) + model.load_state_dict(dist_util.load_state_dict(args.model_path, map_location="cpu")) model.to(dist_util.dev()) model.eval() @@ -42,13 +41,9 @@ def main(): while len(all_images) * args.batch_size < args.num_samples: model_kwargs = {} if args.class_cond: - classes = th.randint( - low=0, high=NUM_CLASSES, size=(args.batch_size,), device=dist_util.dev() - ) + classes = th.randint(low=0, high=NUM_CLASSES, size=(args.batch_size,), device=dist_util.dev()) model_kwargs["y"] = classes - sample_fn = ( - diffusion.p_sample_loop if not args.use_ddim else diffusion.ddim_sample_loop - ) + sample_fn = diffusion.p_sample_loop if not args.use_ddim else diffusion.ddim_sample_loop sample = sample_fn( model, (args.batch_size, 3, args.image_size, args.image_size), @@ -63,9 +58,7 @@ def main(): dist.all_gather(gathered_samples, sample) # gather not supported with NCCL all_images.extend([sample.cpu().numpy() for sample in gathered_samples]) if args.class_cond: - gathered_labels = [ - th.zeros_like(classes) for _ in range(dist.get_world_size()) - ] + gathered_labels = [th.zeros_like(classes) for _ in range(dist.get_world_size())] dist.all_gather(gathered_labels, classes) all_labels.extend([labels.cpu().numpy() for labels in gathered_labels]) logger.log(f"created {len(all_images) * args.batch_size} samples") @@ -91,9 +84,9 @@ def main(): def create_argparser(): defaults = dict( clip_denoised=True, - num_samples=10000, + num_samples=10, batch_size=16, - use_ddim=False, + use_ddim=True, model_path="", ) defaults.update(model_and_diffusion_defaults()) diff --git a/scripts/image_train.py b/scripts/image_train.py index 76c0632608..e3c0024714 100644 --- a/scripts/image_train.py +++ b/scripts/image_train.py @@ -23,10 +23,10 @@ def main(): logger.configure() logger.log("creating model and diffusion...") - model, diffusion = create_model_and_diffusion( - **args_to_dict(args, model_and_diffusion_defaults().keys()) - ) + model, diffusion = create_model_and_diffusion(**args_to_dict(args, model_and_diffusion_defaults().keys())) + #model.load_state_dict(dist_util.load_state_dict(args.model_path, map_location="cpu")) model.to(dist_util.dev()) + print(dist_util.dev()) schedule_sampler = create_named_schedule_sampler(args.schedule_sampler, diffusion) logger.log("creating data loader...") @@ -63,15 +63,16 @@ def create_argparser(): schedule_sampler="uniform", lr=1e-4, weight_decay=0.0, - lr_anneal_steps=0, + lr_anneal_steps=10000, batch_size=1, microbatch=-1, # -1 disables microbatches ema_rate="0.9999", # comma-separated list of EMA values log_interval=10, - save_interval=10000, + save_interval=5000, resume_checkpoint="", - use_fp16=False, + use_fp16=True, fp16_scale_growth=1e-3, + model_path="", ) defaults.update(model_and_diffusion_defaults()) parser = argparse.ArgumentParser()