diff --git a/extract_feature_print.py b/extract_feature_print.py
index cfc6e75f6..cf9cf1c86 100644
--- a/extract_feature_print.py
+++ b/extract_feature_print.py
@@ -1,4 +1,5 @@
 import os, sys, traceback
+import tqdm
 
 # device=sys.argv[1]
 n_part = int(sys.argv[2])
@@ -87,37 +88,40 @@ def readwave(wav_path, normalize=False):
     printt("no-feature-todo")
 else:
     printt("all-feature-%s" % len(todo))
-    for idx, file in enumerate(todo):
-        try:
-            if file.endswith(".wav"):
-                wav_path = "%s/%s" % (wavPath, file)
-                out_path = "%s/%s" % (outPath, file.replace("wav", "npy"))
-
-                if os.path.exists(out_path):
-                    continue
-
-                feats = readwave(wav_path, normalize=saved_cfg.task.normalize)
-                padding_mask = torch.BoolTensor(feats.shape).fill_(False)
-                inputs = {
-                    "source": feats.half().to(device)
-                    if device not in ["mps", "cpu"]
-                    else feats.to(device),
-                    "padding_mask": padding_mask.to(device),
-                    "output_layer": 9 if version == "v1" else 12,  # layer 9
-                }
-                with torch.no_grad():
-                    logits = model.extract_features(**inputs)
-                    feats = (
-                        model.final_proj(logits[0]) if version == "v1" else logits[0]
-                    )
-
-                feats = feats.squeeze(0).float().cpu().numpy()
-                if np.isnan(feats).sum() == 0:
-                    np.save(out_path, feats, allow_pickle=False)
-                else:
-                    printt("%s-contains nan" % file)
-                if idx % n == 0:
-                    printt("now-%s,all-%s,%s,%s" % (idx, len(todo), file, feats.shape))
-        except:
-            printt(traceback.format_exc())
+    with tqdm.tqdm(total=len(todo)) as pbar:
+        for idx, file in enumerate(todo):
+            try:
+                if file.endswith(".wav"):
+                    wav_path = "%s/%s" % (wavPath, file)
+                    out_path = "%s/%s" % (outPath, file.replace("wav", "npy"))
+
+                    if os.path.exists(out_path):
+                        continue
+
+                    feats = readwave(wav_path, normalize=saved_cfg.task.normalize)
+                    padding_mask = torch.BoolTensor(feats.shape).fill_(False)
+                    inputs = {
+                        "source": feats.half().to(device)
+                        if device not in ["mps", "cpu"]
+                        else feats.to(device),
+                        "padding_mask": padding_mask.to(device),
+                        "output_layer": 9 if version == "v1" else 12,  # layer 9
+                    }
+                    with torch.no_grad():
+                        logits = model.extract_features(**inputs)
+                        feats = (
+                            model.final_proj(logits[0]) if version == "v1" else logits[0]
+                        )
+
+                    feats = feats.squeeze(0).float().cpu().numpy()
+                    if np.isnan(feats).sum() == 0:
+                        np.save(out_path, feats, allow_pickle=False)
+                    else:
+                        printt("%s-contains nan" % file)
+                    # if idx % n == 0:
+                    #     printt("now-%s,all-%s,%s,%s" % (idx, len(todo), file, feats.shape))
+                    pbar.set_description("file %s, shape %s" % (file, feats.shape))
+            except:
+                printt(traceback.format_exc())
+            pbar.update(1)
 printt("all-feature-done")
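Review note: the hunk above replaces the periodic `printt("now-...")` progress lines with a per-file tqdm bar that is updated even when a file is skipped or fails. The following is a minimal, self-contained sketch of that pattern only, not the project's code; `process_one` and the `.wav` listing are hypothetical stand-ins for the feature-extraction logic.

```python
# Minimal sketch of the tqdm pattern used above; `process_one` is a
# hypothetical stand-in for the HuBERT feature-extraction step.
import os
import traceback

import tqdm


def process_all(wav_dir: str, process_one) -> None:
    todo = sorted(f for f in os.listdir(wav_dir) if f.endswith(".wav"))
    with tqdm.tqdm(total=len(todo)) as pbar:
        for name in todo:
            try:
                shape = process_one(os.path.join(wav_dir, name))
                # Replaces the old "print every n files" logging.
                pbar.set_description("file %s, shape %s" % (name, shape))
            except Exception:
                traceback.print_exc()
            # Update in every case so the bar still reaches 100% on errors/skips.
            pbar.update(1)
```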
diff --git a/train/utils.py b/train/utils.py
index 8884e43e4..9644ecd6a 100644
--- a/train/utils.py
+++ b/train/utils.py
@@ -31,10 +31,7 @@ def go(model, bkey):
             try:
                 new_state_dict[k] = saved_state_dict[k]
                 if saved_state_dict[k].shape != state_dict[k].shape:
-                    print(
-                        "shape-%s-mismatch|need-%s|get-%s"
-                        % (k, state_dict[k].shape, saved_state_dict[k].shape)
-                    )  #
+                    print("shape-%s-mismatch|need-%s|get-%s" % (k, state_dict[k].shape, saved_state_dict[k].shape))  #
                     raise KeyError
             except:
                 # logger.info(traceback.format_exc())
@@ -52,9 +49,7 @@ def go(model, bkey):
 
     iteration = checkpoint_dict["iteration"]
     learning_rate = checkpoint_dict["learning_rate"]
-    if (
-        optimizer is not None and load_opt == 1
-    ):  ###加载不了,如果是空的的话,重新初始化,可能还会影响lr时间表的更新,因此在train文件最外围catch
+    if optimizer is not None and load_opt == 1:  ###加载不了,如果是空的的话,重新初始化,可能还会影响lr时间表的更新,因此在train文件最外围catch
         # try:
         optimizer.load_state_dict(checkpoint_dict["optimizer"])
         # except:
@@ -106,10 +101,7 @@ def load_checkpoint(checkpoint_path, model, optimizer=None, load_opt=1):
         try:
             new_state_dict[k] = saved_state_dict[k]
             if saved_state_dict[k].shape != state_dict[k].shape:
-                print(
-                    "shape-%s-mismatch|need-%s|get-%s"
-                    % (k, state_dict[k].shape, saved_state_dict[k].shape)
-                )  #
+                print("shape-%s-mismatch|need-%s|get-%s" % (k, state_dict[k].shape, saved_state_dict[k].shape))  #
                 raise KeyError
         except:
             # logger.info(traceback.format_exc())
@@ -123,9 +115,7 @@ def load_checkpoint(checkpoint_path, model, optimizer=None, load_opt=1):
 
     iteration = checkpoint_dict["iteration"]
     learning_rate = checkpoint_dict["learning_rate"]
-    if (
-        optimizer is not None and load_opt == 1
-    ):  ###加载不了,如果是空的的话,重新初始化,可能还会影响lr时间表的更新,因此在train文件最外围catch
+    if optimizer is not None and load_opt == 1:  ###加载不了,如果是空的的话,重新初始化,可能还会影响lr时间表的更新,因此在train文件最外围catch
         # try:
         optimizer.load_state_dict(checkpoint_dict["optimizer"])
         # except:
@@ -134,16 +124,19 @@ def load_checkpoint(checkpoint_path, model, optimizer=None, load_opt=1):
     return model, optimizer, learning_rate, iteration
 
 
-def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path):
-    logger.info(
-        "Saving model and optimizer state at epoch {} to {}".format(
-            iteration, checkpoint_path
-        )
-    )
+def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path, checkpoint_type, delete_old=False):
+    # logger.info(
+    #     "Saving model and optimizer state at epoch {} to {}".format(
+    #         iteration, checkpoint_path
+    #     )
+    # )
     if hasattr(model, "module"):
         state_dict = model.module.state_dict()
     else:
         state_dict = model.state_dict()
+    if delete_old:
+        latest_checkpoint = latest_checkpoint_path(checkpoint_path, regex=("G_*.pth" if checkpoint_type.startswith("G") else "D_*.pth"))
+
     torch.save(
         {
             "model": state_dict,
@@ -151,16 +144,19 @@ def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path)
             "optimizer": optimizer.state_dict(),
             "learning_rate": learning_rate,
         },
-        checkpoint_path,
+        os.path.join(checkpoint_path, checkpoint_type),
     )
+    # delete after saving new checkpoint to avoid loss if save fails
+    if delete_old and latest_checkpoint is not None:
+        os.remove(latest_checkpoint)
 
 
 def save_checkpoint_d(combd, sbd, optimizer, learning_rate, iteration, checkpoint_path):
-    logger.info(
-        "Saving model and optimizer state at epoch {} to {}".format(
-            iteration, checkpoint_path
-        )
-    )
+    # logger.info(
+    #     "Saving model and optimizer state at epoch {} to {}".format(
+    #         iteration, checkpoint_path
+    #     )
+
     if hasattr(combd, "module"):
         state_dict_combd = combd.module.state_dict()
     else:
         state_dict_combd = combd.state_dict()
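Review note: `save_checkpoint` now receives the checkpoint directory and file name separately and can rotate away the previous `G_*.pth`/`D_*.pth`, deleting it only after the new file has been written. Below is a hedged, standalone sketch of that save-then-delete rotation; `latest_checkpoint_path` mirrors the helper in `train/utils.py`, the remaining names are hypothetical.

```python
# Illustrative sketch of the "save new, then delete old" rotation above.
import glob
import os

import torch


def latest_checkpoint_path(dir_path, regex="G_*.pth"):
    f_list = glob.glob(os.path.join(dir_path, regex))
    f_list.sort(key=lambda f: int("".join(filter(str.isdigit, f))))
    return f_list[-1] if f_list else None


def save_rotating(state, ckpt_dir, ckpt_name, delete_old=False):
    pattern = "G_*.pth" if ckpt_name.startswith("G") else "D_*.pth"
    old = latest_checkpoint_path(ckpt_dir, pattern) if delete_old else None
    torch.save(state, os.path.join(ckpt_dir, ckpt_name))
    # Only remove the previous file once the new one is safely on disk.
    if old is not None and old != os.path.join(ckpt_dir, ckpt_name):
        os.remove(old)
```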
@@ -203,8 +199,10 @@ def summarize(
 
 
 def latest_checkpoint_path(dir_path, regex="G_*.pth"):
     f_list = glob.glob(os.path.join(dir_path, regex))
     f_list.sort(key=lambda f: int("".join(filter(str.isdigit, f))))
+    if len(f_list) == 0:
+        return None
     x = f_list[-1]
-    print(x)
+    # print(x)
     return x
@@ -247,9 +245,7 @@ def plot_alignment_to_numpy(alignment, info=None):
     import numpy as np
 
     fig, ax = plt.subplots(figsize=(6, 4))
-    im = ax.imshow(
-        alignment.transpose(), aspect="auto", origin="lower", interpolation="none"
-    )
+    im = ax.imshow(alignment.transpose(), aspect="auto", origin="lower", interpolation="none")
     fig.colorbar(im, ax=ax)
     xlabel = "Decoder timestep"
     if info is not None:
@@ -302,25 +298,13 @@ def get_hparams(init=True):
         required=True,
         help="checkpoint save frequency (epoch)",
     )
-    parser.add_argument(
-        "-te", "--total_epoch", type=int, required=True, help="total_epoch"
-    )
-    parser.add_argument(
-        "-pg", "--pretrainG", type=str, default="", help="Pretrained Discriminator path"
-    )
-    parser.add_argument(
-        "-pd", "--pretrainD", type=str, default="", help="Pretrained Generator path"
-    )
+    parser.add_argument("-te", "--total_epoch", type=int, required=True, help="total_epoch")
+    parser.add_argument("-pg", "--pretrainG", type=str, default="", help="Pretrained Discriminator path")
+    parser.add_argument("-pd", "--pretrainD", type=str, default="", help="Pretrained Generator path")
     parser.add_argument("-g", "--gpus", type=str, default="0", help="split by -")
-    parser.add_argument(
-        "-bs", "--batch_size", type=int, required=True, help="batch size"
-    )
-    parser.add_argument(
-        "-e", "--experiment_dir", type=str, required=True, help="experiment dir"
-    )  # -m
-    parser.add_argument(
-        "-sr", "--sample_rate", type=str, required=True, help="sample rate, 32k/40k/48k"
-    )
+    parser.add_argument("-bs", "--batch_size", type=int, required=True, help="batch size")
+    parser.add_argument("-e", "--experiment_dir", type=str, required=True, help="experiment dir")  # -m
+    parser.add_argument("-sr", "--sample_rate", type=str, required=True, help="sample rate, 32k/40k/48k")
     parser.add_argument(
         "-sw",
         "--save_every_weights",
@@ -328,9 +312,7 @@
         default="0",
         help="save the extracted model in weights directory when saving checkpoints",
    )
-    parser.add_argument(
-        "-v", "--version", type=str, required=True, help="model version"
-    )
+    parser.add_argument("-v", "--version", type=str, required=True, help="model version")
     parser.add_argument(
         "-f0",
         "--if_f0",
@@ -417,11 +399,7 @@ def get_hparams_from_file(config_path):
 def check_git_hash(model_dir):
     source_dir = os.path.dirname(os.path.realpath(__file__))
     if not os.path.exists(os.path.join(source_dir, ".git")):
-        logger.warn(
-            "{} is not a git repository, therefore hash value comparison will be ignored.".format(
-                source_dir
-            )
-        )
+        logger.warn("{} is not a git repository, therefore hash value comparison will be ignored.".format(source_dir))
         return
 
     cur_hash = subprocess.getoutput("git rev-parse HEAD")
@@ -430,11 +408,7 @@ def check_git_hash(model_dir):
     if os.path.exists(path):
         saved_hash = open(path).read()
         if saved_hash != cur_hash:
-            logger.warn(
-                "git hash values are different. {}(saved) != {}(current)".format(
-                    saved_hash[:8], cur_hash[:8]
-                )
-            )
+            logger.warn("git hash values are different. {}(saved) != {}(current)".format(saved_hash[:8], cur_hash[:8]))
     else:
         open(path, "w").write(cur_hash)
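Review note: with the guard above, `latest_checkpoint_path` now returns `None` when the experiment directory contains no matching checkpoints, so call sites that previously assumed a path need a check (the training script relies on its surrounding try/except). A hedged sketch of one way a caller might guard the resume path; the helpers are passed in as arguments only to keep the sketch self-contained, and the other names mirror the training script.

```python
# Hedged sketch only: `latest_checkpoint_path` and `load_checkpoint` are the
# helpers from train/utils.py, injected here so the example is importable.
def try_resume(model_dir, net_d, optim_d, latest_checkpoint_path, load_checkpoint):
    ckpt = latest_checkpoint_path(model_dir, "D_*.pth")
    if ckpt is None:
        # Fresh run: nothing to resume, caller falls back to pretrained weights.
        return None
    _, _, _, epoch_str = load_checkpoint(ckpt, net_d, optim_d)
    return epoch_str
```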
diff --git a/train_nsf_sim_cache_sid_load_pretrain.py b/train_nsf_sim_cache_sid_load_pretrain.py
index 2949bc478..d91084a43 100644
--- a/train_nsf_sim_cache_sid_load_pretrain.py
+++ b/train_nsf_sim_cache_sid_load_pretrain.py
@@ -32,6 +32,7 @@
     TextAudioCollate,
     DistributedBucketSampler,
 )
+import tqdm
 
 if hps.version == "v1":
     from infer_pack.models import (
@@ -97,9 +98,7 @@ def run(rank, n_gpus, hps):
         writer = SummaryWriter(log_dir=hps.model_dir)
         writer_eval = SummaryWriter(log_dir=os.path.join(hps.model_dir, "eval"))
 
-    dist.init_process_group(
-        backend="gloo", init_method="env://", world_size=n_gpus, rank=rank
-    )
+    dist.init_process_group(backend="gloo", init_method="env://", world_size=n_gpus, rank=rank)
     torch.manual_seed(hps.train.seed)
     if torch.cuda.is_available():
         torch.cuda.set_device(rank)
@@ -175,84 +174,62 @@ def run(rank, n_gpus, hps):
         net_d = DDP(net_d)
 
     try:  # 如果能加载自动resume
-        _, _, _, epoch_str = utils.load_checkpoint(
-            utils.latest_checkpoint_path(hps.model_dir, "D_*.pth"), net_d, optim_d
-        )  # D多半加载没事
+        _, _, _, epoch_str = utils.load_checkpoint(utils.latest_checkpoint_path(hps.model_dir, "D_*.pth"), net_d, optim_d)  # D多半加载没事
         if rank == 0:
             logger.info("loaded D")
         # _, _, _, epoch_str = utils.load_checkpoint(utils.latest_checkpoint_path(hps.model_dir, "G_*.pth"), net_g, optim_g,load_opt=0)
-        _, _, _, epoch_str = utils.load_checkpoint(
-            utils.latest_checkpoint_path(hps.model_dir, "G_*.pth"), net_g, optim_g
-        )
-        global_step = (epoch_str - 1) * len(train_loader)
+        _, _, _, epoch_str = utils.load_checkpoint(utils.latest_checkpoint_path(hps.model_dir, "G_*.pth"), net_g, optim_g)
+
+        scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2)
+        scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2)
+
         # epoch_str = 1
         # global_step = 0
     except:  # 如果首次不能加载,加载pretrain
         # traceback.print_exc()
-        epoch_str = 1
+        epoch_str = 0
         global_step = 0
         if hps.pretrainG != "":
             if rank == 0:
                 logger.info("loaded pretrained %s" % (hps.pretrainG))
-            print(
-                net_g.module.load_state_dict(
-                    torch.load(hps.pretrainG, map_location="cpu")["model"]
-                )
-            )  ##测试不加载优化器
+            print(net_g.module.load_state_dict(torch.load(hps.pretrainG, map_location="cpu")["model"]))  ##测试不加载优化器
         if hps.pretrainD != "":
             if rank == 0:
                 logger.info("loaded pretrained %s" % (hps.pretrainD))
-            print(
-                net_d.module.load_state_dict(
-                    torch.load(hps.pretrainD, map_location="cpu")["model"]
-                )
-            )
-
-    scheduler_g = torch.optim.lr_scheduler.ExponentialLR(
-        optim_g, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2
-    )
-    scheduler_d = torch.optim.lr_scheduler.ExponentialLR(
-        optim_d, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2
-    )
+            print(net_d.module.load_state_dict(torch.load(hps.pretrainD, map_location="cpu")["model"]))
+        scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=hps.train.lr_decay, last_epoch=epoch_str - 1)
+        scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=hps.train.lr_decay, last_epoch=epoch_str - 1)
+    global_step = epoch_str * len(train_loader)
 
     scaler = GradScaler(enabled=hps.train.fp16_run)
 
     cache = []
-    for epoch in range(epoch_str, hps.train.epochs + 1):
-        if rank == 0:
-            train_and_evaluate(
-                rank,
-                epoch,
-                hps,
-                [net_g, net_d],
-                [optim_g, optim_d],
-                [scheduler_g, scheduler_d],
-                scaler,
-                [train_loader, None],
-                logger,
-                [writer, writer_eval],
-                cache,
-            )
-        else:
-            train_and_evaluate(
-                rank,
-                epoch,
-                hps,
-                [net_g, net_d],
-                [optim_g, optim_d],
-                [scheduler_g, scheduler_d],
-                scaler,
-                [train_loader, None],
-                None,
-                None,
-                cache,
-            )
-        scheduler_g.step()
-        scheduler_d.step()
+    # for epoch in tqdm.tqdm(range(epoch_str, hps.train.epochs + 1), desc="Training progress", position=1, leave=True):
+
+    # disable if rank is not 0 to avoid duplicate main progress bars on each process
+    with tqdm.tqdm(
+        total=(hps.total_epoch),
+        desc=f"Training progress, last ckpt saved at epoch: {epoch_str}",
+        position=0,
+        leave=True,
+        initial=0 if global_step == 0 else epoch_str,
+        dynamic_ncols=True,
+        disable=(rank != 0),
+    ) as pbar:
+        for epoch in range(epoch_str + 1, hps.train.epochs + 1):
+            if rank == 0:
+                train_and_evaluate(
+                    rank, epoch, hps, [net_g, net_d], [optim_g, optim_d], [scheduler_g, scheduler_d], scaler, [train_loader, None], logger, [writer, writer_eval], cache, pbar
+                )
+            else:
+                train_and_evaluate(rank, epoch, hps, [net_g, net_d], [optim_g, optim_d], [scheduler_g, scheduler_d], scaler, [train_loader, None], None, None, cache, pbar)
+            pbar.update(1)
+            scheduler_g.step()
+            scheduler_d.step()
 
 
 def train_and_evaluate(
-    rank, epoch, hps, nets, optims, schedulers, scaler, loaders, logger, writers, cache
+    rank, epoch, hps, nets, optims, schedulers, scaler, loaders, logger, writers, cache, pbar
 ):
     net_g, net_d = nets
     optim_g, optim_d = optims
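Review note: the epoch loop above now draws a persistent outer bar (rank 0 only) and hands it to `train_and_evaluate` as `pbar`, while each epoch gets its own inner bar. The snippet below is a standalone sketch of that nested-bar layout, assuming dummy `rank`, `num_epochs`, and `batches` values; it is not the training code itself.

```python
# Standalone sketch of the nested tqdm layout: one outer bar for epochs
# (drawn only by rank 0) and one inner bar per epoch for batches.
import time

import tqdm


def train_with_bars(rank=0, num_epochs=3, batches=range(5)):
    with tqdm.tqdm(
        total=num_epochs,
        desc="Training progress",
        position=0,
        leave=True,
        dynamic_ncols=True,
        disable=(rank != 0),  # only rank 0 draws the outer bar
    ) as pbar:
        for epoch in range(1, num_epochs + 1):
            # Inner bar sits below the outer one; position is per-process.
            with tqdm.tqdm(
                total=len(batches),
                desc="Epoch progress",
                position=rank + 1,
                leave=True,
                dynamic_ncols=True,
            ) as piterbar:
                for _ in batches:
                    time.sleep(0.01)  # stands in for one optimization step
                    piterbar.update(1)
            pbar.update(1)


if __name__ == "__main__":
    train_with_bars()
```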
@@ -350,227 +327,224 @@ def train_and_evaluate(
     # Run steps
     epoch_recorder = EpochRecorder()
-    for batch_idx, info in data_iterator:
-        # Data
-        ## Unpack
-        if hps.if_f0 == 1:
-            (
-                phone,
-                phone_lengths,
-                pitch,
-                pitchf,
-                spec,
-                spec_lengths,
-                wave,
-                wave_lengths,
-                sid,
-            ) = info
-        else:
-            phone, phone_lengths, spec, spec_lengths, wave, wave_lengths, sid = info
-        ## Load on CUDA
-        if (hps.if_cache_data_in_gpu == False) and torch.cuda.is_available():
-            phone = phone.cuda(rank, non_blocking=True)
-            phone_lengths = phone_lengths.cuda(rank, non_blocking=True)
-            if hps.if_f0 == 1:
-                pitch = pitch.cuda(rank, non_blocking=True)
-                pitchf = pitchf.cuda(rank, non_blocking=True)
-            sid = sid.cuda(rank, non_blocking=True)
-            spec = spec.cuda(rank, non_blocking=True)
-            spec_lengths = spec_lengths.cuda(rank, non_blocking=True)
-            wave = wave.cuda(rank, non_blocking=True)
-            # wave_lengths = wave_lengths.cuda(rank, non_blocking=True)
-
-        # Calculate
-        with autocast(enabled=hps.train.fp16_run):
+    # for batch_idx, info in data_iterator:
+    with tqdm.tqdm(total=len(train_loader), desc="Epoch progress", position=rank + 1, leave=True, dynamic_ncols=True) as piterbar:
+        for batch_idx, info in data_iterator:
+            # Data
+            ## Unpack
+            if hps.if_f0 == 1:
+                (
+                    phone,
+                    phone_lengths,
+                    pitch,
+                    pitchf,
+                    spec,
+                    spec_lengths,
+                    wave,
+                    wave_lengths,
+                    sid,
+                ) = info
+            else:
+                phone, phone_lengths, spec, spec_lengths, wave, wave_lengths, sid = info
+            ## Load on CUDA
+            if (hps.if_cache_data_in_gpu == False) and torch.cuda.is_available():
+                phone = phone.cuda(rank, non_blocking=True)
+                phone_lengths = phone_lengths.cuda(rank, non_blocking=True)
+                if hps.if_f0 == 1:
+                    pitch = pitch.cuda(rank, non_blocking=True)
+                    pitchf = pitchf.cuda(rank, non_blocking=True)
+                sid = sid.cuda(rank, non_blocking=True)
+                spec = spec.cuda(rank, non_blocking=True)
+                spec_lengths = spec_lengths.cuda(rank, non_blocking=True)
+                wave = wave.cuda(rank, non_blocking=True)
+                # wave_lengths = wave_lengths.cuda(rank, non_blocking=True)
+
+            # Calculate
+            with autocast(enabled=hps.train.fp16_run):
-            if hps.if_f0 == 1:
-                (
-                    y_hat,
-                    ids_slice,
-                    x_mask,
-                    z_mask,
-                    (z, z_p, m_p, logs_p, m_q, logs_q),
-                ) = net_g(phone, phone_lengths, pitch, pitchf, spec, spec_lengths, sid)
-            else:
-                (
-                    y_hat,
-                    ids_slice,
-                    x_mask,
-                    z_mask,
-                    (z, z_p, m_p, logs_p, m_q, logs_q),
-                ) = net_g(phone, phone_lengths, spec, spec_lengths, sid)
-            mel = spec_to_mel_torch(
-                spec,
-                hps.data.filter_length,
-                hps.data.n_mel_channels,
-                hps.data.sampling_rate,
-                hps.data.mel_fmin,
-                hps.data.mel_fmax,
-            )
-            y_mel = commons.slice_segments(
-                mel, ids_slice, hps.train.segment_size // hps.data.hop_length
-            )
-            with autocast(enabled=False):
-                y_hat_mel = mel_spectrogram_torch(
-                    y_hat.float().squeeze(1),
-                    hps.data.filter_length,
-                    hps.data.n_mel_channels,
-                    hps.data.sampling_rate,
-                    hps.data.hop_length,
-                    hps.data.win_length,
-                    hps.data.mel_fmin,
-                    hps.data.mel_fmax,
-                )
-                if hps.train.fp16_run == True:
-                    y_hat_mel = y_hat_mel.half()
-            wave = commons.slice_segments(
-                wave, ids_slice * hps.data.hop_length, hps.train.segment_size
-            )  # slice
-
-            # Discriminator
-            y_d_hat_r, y_d_hat_g, _, _ = net_d(wave, y_hat.detach())
-            with autocast(enabled=False):
-                loss_disc, losses_disc_r, losses_disc_g = discriminator_loss(
-                    y_d_hat_r, y_d_hat_g
-                )
-        optim_d.zero_grad()
-        scaler.scale(loss_disc).backward()
-        scaler.unscale_(optim_d)
-        grad_norm_d = commons.clip_grad_value_(net_d.parameters(), None)
-        scaler.step(optim_d)
-
-        with autocast(enabled=hps.train.fp16_run):
-            # Generator
-            y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = net_d(wave, y_hat)
-            with autocast(enabled=False):
-                loss_mel = F.l1_loss(y_mel, y_hat_mel) * hps.train.c_mel
-                loss_kl = kl_loss(z_p, logs_q, m_p, logs_p, z_mask) * hps.train.c_kl
-                loss_fm = feature_loss(fmap_r, fmap_g)
-                loss_gen, losses_gen = generator_loss(y_d_hat_g)
-                loss_gen_all = loss_gen + loss_fm + loss_mel + loss_kl
-        optim_g.zero_grad()
-        scaler.scale(loss_gen_all).backward()
-        scaler.unscale_(optim_g)
-        grad_norm_g = commons.clip_grad_value_(net_g.parameters(), None)
-        scaler.step(optim_g)
-        scaler.update()
+                if hps.if_f0 == 1:
+                    (
+                        y_hat,
+                        ids_slice,
+                        x_mask,
+                        z_mask,
+                        (z, z_p, m_p, logs_p, m_q, logs_q),
+                    ) = net_g(phone, phone_lengths, pitch, pitchf, spec, spec_lengths, sid)
+                else:
+                    (
+                        y_hat,
+                        ids_slice,
+                        x_mask,
+                        z_mask,
+                        (z, z_p, m_p, logs_p, m_q, logs_q),
+                    ) = net_g(phone, phone_lengths, spec, spec_lengths, sid)
+                mel = spec_to_mel_torch(
+                    spec,
+                    hps.data.filter_length,
+                    hps.data.n_mel_channels,
+                    hps.data.sampling_rate,
+                    hps.data.mel_fmin,
+                    hps.data.mel_fmax,
+                )
+                y_mel = commons.slice_segments(
+                    mel, ids_slice, hps.train.segment_size // hps.data.hop_length
+                )
+                with autocast(enabled=False):
+                    y_hat_mel = mel_spectrogram_torch(
+                        y_hat.float().squeeze(1),
+                        hps.data.filter_length,
+                        hps.data.n_mel_channels,
+                        hps.data.sampling_rate,
+                        hps.data.hop_length,
+                        hps.data.win_length,
+                        hps.data.mel_fmin,
+                        hps.data.mel_fmax,
+                    )
+                    if hps.train.fp16_run == True:
+                        y_hat_mel = y_hat_mel.half()
+                wave = commons.slice_segments(
+                    wave, ids_slice * hps.data.hop_length, hps.train.segment_size
+                )  # slice
+
+                # Discriminator
+                y_d_hat_r, y_d_hat_g, _, _ = net_d(wave, y_hat.detach())
+                with autocast(enabled=False):
+                    loss_disc, losses_disc_r, losses_disc_g = discriminator_loss(
+                        y_d_hat_r, y_d_hat_g
+                    )
+            optim_d.zero_grad()
+            scaler.scale(loss_disc).backward()
+            scaler.unscale_(optim_d)
+            grad_norm_d = commons.clip_grad_value_(net_d.parameters(), None)
+            scaler.step(optim_d)
+
+            with autocast(enabled=hps.train.fp16_run):
+                # Generator
+                y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = net_d(wave, y_hat)
+                with autocast(enabled=False):
+                    loss_mel = F.l1_loss(y_mel, y_hat_mel) * hps.train.c_mel
+                    loss_kl = kl_loss(z_p, logs_q, m_p, logs_p, z_mask) * hps.train.c_kl
+                    loss_fm = feature_loss(fmap_r, fmap_g)
+                    loss_gen, losses_gen = generator_loss(y_d_hat_g)
+                    loss_gen_all = loss_gen + loss_fm + loss_mel + loss_kl
+            optim_g.zero_grad()
+            scaler.scale(loss_gen_all).backward()
+            scaler.unscale_(optim_g)
+            grad_norm_g = commons.clip_grad_value_(net_g.parameters(), None)
+            scaler.step(optim_g)
+            scaler.update()
 
-        if rank == 0:
-            if global_step % hps.train.log_interval == 0:
-                lr = optim_g.param_groups[0]["lr"]
-                logger.info(
-                    "Train Epoch: {} [{:.0f}%]".format(
-                        epoch, 100.0 * batch_idx / len(train_loader)
-                    )
-                )
-                # Amor For Tensorboard display
-                if loss_mel > 75:
-                    loss_mel = 75
-                if loss_kl > 9:
-                    loss_kl = 9
-
-                logger.info([global_step, lr])
-                logger.info(
-                    f"loss_disc={loss_disc:.3f}, loss_gen={loss_gen:.3f}, loss_fm={loss_fm:.3f},loss_mel={loss_mel:.3f}, loss_kl={loss_kl:.3f}"
-                )
-                scalar_dict = {
-                    "loss/g/total": loss_gen_all,
-                    "loss/d/total": loss_disc,
-                    "learning_rate": lr,
-                    "grad_norm_d": grad_norm_d,
-                    "grad_norm_g": grad_norm_g,
-                }
-                scalar_dict.update(
-                    {
-                        "loss/g/fm": loss_fm,
-                        "loss/g/mel": loss_mel,
-                        "loss/g/kl": loss_kl,
-                    }
-                )
-                scalar_dict.update(
-                    {"loss/g/{}".format(i): v for i, v in enumerate(losses_gen)}
-                )
-                scalar_dict.update(
-                    {"loss/d_r/{}".format(i): v for i, v in enumerate(losses_disc_r)}
-                )
-                scalar_dict.update(
-                    {"loss/d_g/{}".format(i): v for i, v in enumerate(losses_disc_g)}
-                )
-                image_dict = {
-                    "slice/mel_org": utils.plot_spectrogram_to_numpy(
-                        y_mel[0].data.cpu().numpy()
-                    ),
-                    "slice/mel_gen": utils.plot_spectrogram_to_numpy(
-                        y_hat_mel[0].data.cpu().numpy()
-                    ),
-                    "all/mel": utils.plot_spectrogram_to_numpy(
-                        mel[0].data.cpu().numpy()
-                    ),
-                }
-                utils.summarize(
-                    writer=writer,
-                    global_step=global_step,
-                    images=image_dict,
-                    scalars=scalar_dict,
-                )
-        global_step += 1
-    # /Run steps
+            if rank == 0:
+                pbar.refresh()
+                piterbar.update(1)
+                lr = optim_g.param_groups[0]["lr"]
+                piterbar.set_description(f"GPU{rank} l_disc={loss_disc:.3f}, l_gen={loss_gen:.3f}, l_fm={loss_fm:.3f}, l_mel={loss_mel:.3f}, l_kl={loss_kl:.3f}, iter={global_step}, lr{'{:.3e}'.format(lr)}")
+                if global_step % hps.train.log_interval == 0:
+                    # logger.info(
+                    #     "Train Epoch: {} [{:.0f}%]".format(
+                    #         epoch, 100.0 * batch_idx / len(train_loader)
+                    #     )
+                    # )
+
+                    # Amor For Tensorboard display
+                    if loss_mel > 75:
+                        loss_mel = 75
+                    if loss_kl > 9:
+                        loss_kl = 9
+
+                    # logger.info([global_step, lr])
+                    # logger.info(
+                    #     f"loss_disc={loss_disc:.3f}, loss_gen={loss_gen:.3f}, loss_fm={loss_fm:.3f},loss_mel={loss_mel:.3f}, loss_kl={loss_kl:.3f}"
+                    # )
+                    scalar_dict = {
+                        "loss/g/total": loss_gen_all,
+                        "loss/d/total": loss_disc,
+                        "learning_rate": lr,
+                        "grad_norm_d": grad_norm_d,
+                        "grad_norm_g": grad_norm_g,
+                    }
+                    scalar_dict.update(
+                        {
+                            "loss/g/fm": loss_fm,
+                            "loss/g/mel": loss_mel,
+                            "loss/g/kl": loss_kl,
+                        }
+                    )
+                    scalar_dict.update(
+                        {"loss/g/{}".format(i): v for i, v in enumerate(losses_gen)}
+                    )
+                    scalar_dict.update(
+                        {"loss/d_r/{}".format(i): v for i, v in enumerate(losses_disc_r)}
+                    )
+                    scalar_dict.update(
+                        {"loss/d_g/{}".format(i): v for i, v in enumerate(losses_disc_g)}
+                    )
+                    image_dict = {
+                        "slice/mel_org": utils.plot_spectrogram_to_numpy(
+                            y_mel[0].data.cpu().numpy()
+                        ),
+                        "slice/mel_gen": utils.plot_spectrogram_to_numpy(
+                            y_hat_mel[0].data.cpu().numpy()
+                        ),
+                        "all/mel": utils.plot_spectrogram_to_numpy(
+                            mel[0].data.cpu().numpy()
+                        ),
+                    }
+                    utils.summarize(
+                        writer=writer,
+                        global_step=global_step,
+                        images=image_dict,
+                        scalars=scalar_dict,
+                    )
+            global_step += 1
+    # /Run steps
 
     if epoch % hps.save_every_epoch == 0 and rank == 0:
-        if hps.if_latest == 0:
-            utils.save_checkpoint(
-                net_g,
-                optim_g,
-                hps.train.learning_rate,
-                epoch,
-                os.path.join(hps.model_dir, "G_{}.pth".format(global_step)),
-            )
-            utils.save_checkpoint(
-                net_d,
-                optim_d,
-                hps.train.learning_rate,
-                epoch,
-                os.path.join(hps.model_dir, "D_{}.pth".format(global_step)),
-            )
-        else:
-            utils.save_checkpoint(
-                net_g,
-                optim_g,
-                hps.train.learning_rate,
-                epoch,
-                os.path.join(hps.model_dir, "G_{}.pth".format(2333333)),
-            )
-            utils.save_checkpoint(
-                net_d,
-                optim_d,
-                hps.train.learning_rate,
-                epoch,
-                os.path.join(hps.model_dir, "D_{}.pth".format(2333333)),
-            )
+        utils.save_checkpoint(
+            net_g,
+            optim_g,
+            hps.train.learning_rate,
+            epoch,
+            hps.model_dir,
+            "G_{}.pth".format(global_step),
+            True if hps.if_latest == 1 else False,
+        )
+        utils.save_checkpoint(
+            net_d,
+            optim_d,
+            hps.train.learning_rate,
+            epoch,
+            hps.model_dir,
+            "D_{}.pth".format(global_step),
+            True if hps.if_latest == 1 else False,
+        )
+
+        pbar.set_description(f"Training progress, last ckpt saved at epoch {epoch}")
 
     if rank == 0 and hps.save_every_weights == "1":
         if hasattr(net_g, "module"):
             ckpt = net_g.module.state_dict()
         else:
             ckpt = net_g.state_dict()
-        logger.info(
-            "saving ckpt %s_e%s:%s"
-            % (
-                hps.name,
-                epoch,
-                savee(
-                    ckpt,
-                    hps.sample_rate,
-                    hps.if_f0,
-                    hps.name + "_e%s_s%s" % (epoch, global_step),
-                    epoch,
-                    hps.version,
-                    hps,
-                ),
-            )
-        )
-
-    if rank == 0:
-        logger.info("====> Epoch: {} {}".format(epoch, epoch_recorder.record()))
+        s = savee(
+            ckpt,
+            hps.sample_rate,
+            hps.if_f0,
+            hps.name + "_e%s_s%s" % (epoch, global_step),
+            epoch,
+            hps.version,
+            hps,
+        )
+        # logger.info(
+        #     "saving ckpt %s_e%s:%s"
+        #     % (
+        #         hps.name,
+        #         epoch,
+        #         s,
+        #     )
+        # )
+
+    # if rank == 0:
+    #     logger.info("====> Epoch: {} {}".format(epoch, epoch_recorder.record()))
 
     if epoch >= hps.total_epoch and rank == 0:
         logger.info("Training is done. The program is closed.")
@@ -578,14 +552,7 @@ def train_and_evaluate(
         ckpt = net_g.module.state_dict()
     else:
         ckpt = net_g.state_dict()
-    logger.info(
-        "saving final ckpt:%s"
-        % (
-            savee(
-                ckpt, hps.sample_rate, hps.if_f0, hps.name, epoch, hps.version, hps
-            )
-        )
-    )
+    logger.info("saving final ckpt:%s" % (savee(ckpt, hps.sample_rate, hps.if_f0, hps.name, epoch, hps.version, hps)))
     sleep(1)
     os._exit(2333333)
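Review note: callers of `utils.save_checkpoint` now pass the checkpoint directory, the file name, and a delete-old flag instead of a single joined path. A hedged usage sketch mirroring the generator call above; the surrounding objects (`utils`, `net_g`, `optim_g`, `hps`, `epoch`, `global_step`) are assumed to exist as in the training script.

```python
# Usage sketch of the new save_checkpoint call convention from train/utils.py.
def save_generator_checkpoint(utils, net_g, optim_g, hps, epoch, global_step):
    utils.save_checkpoint(
        net_g,
        optim_g,
        hps.train.learning_rate,
        epoch,
        hps.model_dir,                          # directory, no longer a full path
        "G_{}.pth".format(global_step),         # file name / checkpoint_type
        True if hps.if_latest == 1 else False,  # rotate away the previous G_*.pth
    )
```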