From 799ee9a15060b75dc4497a5f0a22782eff7476c9 Mon Sep 17 00:00:00 2001 From: flexthink Date: Wed, 16 Apr 2025 23:36:48 -0400 Subject: [PATCH 1/2] DASB: Add a resynthesis script --- benchmarks/DASB/LibriTTS/libritts_prepare.py | 504 ++++++++++++++++++ .../DASB/LibriTTS/resynthesis/evaluate.py | 247 +++++++++ .../resynthesis/hparams/discrete_ssl.yaml | 127 +++++ .../LibriTTS/resynthesis/libritts_prepare.py | 1 + .../DASB/LibriTTS/resynthesis/metric.py | 369 +++++++++++++ 5 files changed, 1248 insertions(+) create mode 100644 benchmarks/DASB/LibriTTS/libritts_prepare.py create mode 100644 benchmarks/DASB/LibriTTS/resynthesis/evaluate.py create mode 100644 benchmarks/DASB/LibriTTS/resynthesis/hparams/discrete_ssl.yaml create mode 120000 benchmarks/DASB/LibriTTS/resynthesis/libritts_prepare.py create mode 100644 benchmarks/DASB/LibriTTS/resynthesis/metric.py diff --git a/benchmarks/DASB/LibriTTS/libritts_prepare.py b/benchmarks/DASB/LibriTTS/libritts_prepare.py new file mode 100644 index 000000000..52594eaf9 --- /dev/null +++ b/benchmarks/DASB/LibriTTS/libritts_prepare.py @@ -0,0 +1,504 @@ +""" +LibriTTS data preparation + +Authors + * Pradnya Kandarkar 2022 +""" + +import json +import os +import random + +import torch +import torchaudio +import re +from tqdm import tqdm + +from speechbrain.inference.text import GraphemeToPhoneme +from speechbrain.utils.data_utils import get_all_files +from speechbrain.utils.logger import get_logger +from speechbrain.utils.text_to_sequence import _g2p_keep_punctuations +from pathlib import Path + +logger = get_logger(__name__) +LIBRITTS_URL_PREFIX = "https://www.openslr.org/resources/60/" + +DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + +def prepare_libritts( + data_folder, + save_json_train, + save_json_valid, + save_json_test, + sample_rate, + split_ratio=[80, 10, 10], + libritts_subsets=None, + train_split=None, + valid_split=None, + test_split=None, + seed=1234, + model_name=None, + max_valid_size=500, + alignments_folder=None, + skip_prep=False, + skip_resample=False, +): + """ + Prepares the json files for the LibriTTS dataset. + Downloads the dataset if it is not found in the `data_folder` as expected. + + Arguments + --------- + data_folder : str + Path to the folder where the LibriTTS dataset is stored. + save_json_train : str + Path where the train data specification file will be saved. + save_json_valid : str + Path where the validation data specification file will be saved. + save_json_test : str + Path where the test data specification file will be saved. + sample_rate : int + The sample rate to be used for the dataset + split_ratio : list + List composed of three integers that sets split ratios for train, valid, + and test sets, respectively. For instance split_ratio=[80, 10, 10] will + assign 80% of the sentences to training, 10% for validation, and 10% + for test. + libritts_subsets: list + List of librispeech subsets to use (e.g., dev-clean, train-clean-100, ...) for the experiment. + This parameter will be ignored if explicit data splits are provided. + Explicit data splits parameters: "train_split", "valid_split", "test_split" + train_split : list + List of librispeech subsets to use (e.g.,train-clean-100, train-clean-360) for the experiment training stage. + valid_split : list + List of librispeech subsets to use (e.g., dev-clean) for the experiment validation stage. + test_split : list + List of librispeech subsets to use (e.g., test-clean) for the experiment testing stage. 
+ seed : int + Seed value + model_name : str + Model name (used to prepare additional model specific data) + alignments_path : None + The path to alignments files + skip_prep: Bool + If True, skip preparation. + skip_resample: bool + If True, audio will not be resampled + + Returns + ------- + None + """ + + if skip_prep: + return + + # Setting the seed value + random.seed(seed) + + # Checks if this phase is already done (if so, skips it) + if skip(save_json_train, save_json_valid, save_json_test): + logger.info("Preparation completed in previous run, skipping.") + return + + logger.info( + f"Creating {save_json_train}, {save_json_valid}, and {save_json_test}" + ) + + # If specific splits are provided, creates data manifest files accordingly + if train_split: + wav_list = prepare_split(data_folder, train_split) + create_json(wav_list, save_json_train, sample_rate, data_folder, alignments_folder, model_name, skip_resample) + if valid_split: + wav_list = prepare_split(data_folder, valid_split) + # TODO add better way to speedup evaluation + if max_valid_size is not None and len(wav_list) > max_valid_size: + wav_list = random.sample(wav_list, max_valid_size) + create_json(wav_list, save_json_valid, sample_rate, data_folder, alignments_folder, model_name, skip_resample) + if test_split: + wav_list = prepare_split(data_folder, test_split) + create_json(wav_list, save_json_test, sample_rate, data_folder, alignments_folder, model_name, skip_resample) + + if skip(save_json_train, save_json_valid, save_json_test): + logger.info("Preparation completed.") + return + + # If specific splits are not provided, and a list of subsets if provided, creates train, valid, test splits + # Creates data manifest files according to the data splits + if libritts_subsets: + wav_list = prepare_split(data_folder, libritts_subsets) + # Random split the signal list into train, valid, and test sets. + data_split = split_sets(wav_list, split_ratio) + # Creating json files + create_json( + data_split["train"], save_json_train, sample_rate, alignments_folder, model_name, skip_resample + ) + create_json( + data_split["valid"], save_json_valid, sample_rate, alignments_folder, model_name, skip_resample + ) + create_json(data_split["test"], save_json_test, sample_rate, alignments_folder, model_name, skip_resample) + + +def prepare_split(data_folder, split_list): + """ + Processes the provided list of LibriTTS subsets and creates a list of all the .wav files present in the subsets. + Downloads the LibriTTS subsets as required. + + Arguments + --------- + data_folder : str + Path to the folder where the LibriTTS dataset is stored + split_list : list + List of librispeech subsets to process (e.g., dev-clean, train-clean-100, ...) + + Returns + ------- + wav_list : list + List of all .wav files to be processed + """ + extension = [".wav"] # The expected extension for audio files + wav_list = list() # Stores all audio file paths for the dataset + + # For every subset of the dataset, if it doesn't exist, downloads it + for subset_name in split_list: + subset_folder = os.path.join(data_folder, subset_name) + subset_archive = os.path.join(subset_folder, subset_name + ".tar.gz") + + if not check_folders(subset_folder): + logger.info( + f"No data found for {subset_name}. Checking for an archive file." + ) + if not os.path.isfile(subset_archive): + logger.info( + f"No archive file found for {subset_name}. Downloading and unpacking." 
+ ) + quit() + # Collects all files matching the provided extension + wav_list.extend(get_all_files(subset_folder, match_and=extension)) + + return wav_list + + +def create_json(wav_list, json_file, sample_rate, data_folder, alignments_folder=None, model_name=None, skip_resample=False): + """ + Creates the json file given a list of wav files. + Arguments + --------- + wav_list : list of str + The list of wav files. + json_file : str + The path of the output json file + sample_rate : int + The sample rate to be used for the dataset + data_folder : str + The path to LibriTTS + alignments_folder : str + The path to LibriTTS alignments + model_name : str + Model name (used to prepare additional model specific data) + skip_resample : int + Skips resampling - useful when large temporary storage + is absent. + """ + + # Downloads and initializes the G2P model to compute the phonemes if data is being prepared for Tacotron2 experiments + if model_name == "Tacotron2": + logger.info( + "Computing phonemes for labels using SpeechBrain G2P. This may take a while." + ) + g2p = GraphemeToPhoneme.from_hparams( + "speechbrain/soundchoice-g2p", run_opts={"device": DEVICE} + ) + else: + g2p = None + + json_dict = {} + + # Processes all the wav files in the list + for wav_file in tqdm(wav_list): + # Reads the signal + signal, sig_sr = torchaudio.load(wav_file) + duration = signal.shape[1] / sig_sr + + # TODO add better way to filter short utterances + if duration < 1.0: + continue + + # Manipulates path to get relative path and uttid + path_parts = wav_file.split(os.path.sep) + uttid, _ = os.path.splitext(path_parts[-1]) + # relative_path = os.path.join("{data_root}", *path_parts[-4:]) + + # Gets the path for the text files and extracts the input text + normalized_text_path = os.path.join( + "/", *path_parts[:-1], uttid + ".normalized.txt" + ) + try: + with open(normalized_text_path, encoding="utf-8") as f: + normalized_text = f.read() + if normalized_text.__contains__("{"): + normalized_text = normalized_text.replace("{", "") + if normalized_text.__contains__("}"): + normalized_text = normalized_text.replace("}", "") + except FileNotFoundError: + print(f"Warning: The file {normalized_text_path} does not exist.") + continue + + # Resamples the audio file if required + if sig_sr != sample_rate and not skip_resample: + resampled_signal = torchaudio.functional.resample( + signal, sig_sr, sample_rate + ) + os.unlink(wav_file) + torchaudio.save(wav_file, resampled_signal, sample_rate=sample_rate) + + # Gets the speaker-id from the utterance-id + spk_id = uttid.split("_")[0] + + # Creates an entry for the utterance + json_dict[uttid] = { + "uttid": uttid, + "wav": wav_file, + "duration": duration, + "spk_id": spk_id, + "label": normalized_text, + "segment": True if "train" in json_file else False, + } + if alignments_folder is not None: + alignments_file_name = get_alignment_path(data_folder, alignments_folder, wav_file) + alignments = parse_alignments(alignments_file_name) + json_dict[uttid].update(alignments) + + # Characters are used for Tacotron2, phonemes may be needed for other models + if model_name not in ["Tacotron2", "HiFi-GAN"] and g2p is not None: + # Computes phoneme labels using SpeechBrain G2P and keeps the punctuations + phonemes = _g2p_keep_punctuations(g2p, normalized_text) + json_dict[uttid].update({"label_phoneme": phonemes}) + + # Writes the dictionary to the json file + with open(json_file, mode="w", encoding="utf-8") as json_f: + json.dump(json_dict, json_f, indent=2) + + 
logger.info(f"{json_file} successfully created!") + + +def get_alignment_path(data_folder, alignments_folder, file_name): + """Returns the path in the LibriSpeech-Alignments dataset + corresponding to the specified file path in LibriSpeech + + Arguments + --------- + data_folder: str + the path to LibriSpeech + alignments_folder: str + the path to LibriSpeech-Alignments + file_name: str + the file name within LibriSpeech + + Returns + ------- + file_name: str + the alignment file path + """ + file_name = Path(file_name) + data_folder = Path(data_folder) + if file_name.parts[0] == "{data_root}": + file_name_rel = file_name.relative_to("{data_root}") + else: + file_name_rel = file_name.relative_to(data_folder) + data_slice = file_name_rel.parts[0] + + textgrid_folder = file_name_rel.relative_to(Path(data_slice) / "LibriTTS" / data_slice).parent.parent + textgrid_file_name = f"{file_name_rel.stem}.TextGrid" + textgrid_path = Path(alignments_folder) / data_slice / textgrid_folder / textgrid_file_name + + return textgrid_path + + +def skip(*filenames): + """ + Detects if the data preparation has been already done. + If the preparation has been done, we can skip it. + + Arguments + --------- + *filenames : tuple + Set of filenames to check for existence. + + Returns + ------- + bool + if True, the preparation phase can be skipped. + if False, it must be done. + """ + for filename in filenames: + if isinstance(filename, list): + if any(not os.path.isfile(item) for item in filename): + return False + else: + if not os.path.isfile(filename): + return False + return True + + +def split_sets(wav_list, split_ratio): + """Randomly splits the wav list into training, validation, and test lists. + + Arguments + --------- + wav_list : list + list of all the signals in the dataset + split_ratio: list + List composed of three integers that sets split ratios for train, valid, + and test sets, respectively. For instance split_ratio=[80, 10, 10] will + assign 80% of the sentences to training, 10% for validation, and 10% + for test. + + Returns + ------- + dictionary containing train, valid, and test splits. 
+ """ + # Random shuffles the list + random.shuffle(wav_list) + tot_split = sum(split_ratio) + tot_snts = len(wav_list) + data_split = {} + splits = ["train", "valid"] + + for i, split in enumerate(splits): + n_snts = int(tot_snts * split_ratio[i] / tot_split) + data_split[split] = wav_list[0:n_snts] + del wav_list[0:n_snts] + data_split["test"] = wav_list + + return data_split + + +def check_folders(*folders): + """Returns False if any passed folder does not exist.""" + for folder in folders: + if not os.path.exists(folder): + return False + return True + +def parse_alignments(file_name): + """Parses a given LibriSpeech-Alignments TextGrid file and + converts the results to the desired format (to be used in JSON + metadata) + + Arguments + --------- + file_name : path-like + the file name of the TextGrid file + + Returns + ------- + details: dict + the metadata details + """ + try: + import textgrids + except ImportError: + logger.error( + "Parsing LibriSpeech-alignments requires the" + "praat-textgrids package" + ) + raise + if not file_name.exists(): + return { + "has_alignments": False, + "phn": [], + "phn_stress": [], + "phn_start": [], + "phn_end": [], + "phn_count": 0, + "wrd": [], + "wrd_start": [], + "wrd_end": [], + "wrd_count": 0, + "unk_count": None + } + + text_grid = textgrids.TextGrid() + text_grid.read(file_name) + word_intervals = [ + {**word, "label": word["label"].upper()} + for word in text_grid.interval_tier_to_array("words") + ] + phn_intervals = text_grid.interval_tier_to_array("phones") + details = {} + details.update(intervals_to_dict(word_intervals, "wrd")) + phn = intervals_to_dict(phn_intervals, "phn") + phn_stress = phn["phn"] + phn_nostress = remove_stress_marks(phn_stress) + phn["phn"] = phn_nostress + phn["phn_stress"] = phn_stress + details.update(phn) + details["unk_count"] = sum(wrd == "" for wrd in details["wrd"]) + details["has_alignments"] = True + + return details + + +INTERVAL_MAP = [("label", ""), ("begin", "_start"), ("end", "_end")] +INTERVAL_EMPTY_LABELS = {"", "sil", "sp", "spn"} + + +def intervals_to_dict(intervals, prefix): + """ + Converts a parsed list of intervals from PRAAT TextGrid + to a learning-friendly array + + Arguments + --------- + intervals: list + A list of raw TextGrid intervals, as returned by + TextGrid.interval_tier_to_array + prefix: str + the prefix to add + + Returns + ------- + result: dict + A dictionary of the form + { + "{prefix}": , + "{prefix}_start": , + "{prefix}_end": , + "{prefix}_count: + } + + """ + # Remove meaningless labels + intervals_clean = [ + interval + for interval in intervals + if interval["label"] not in INTERVAL_EMPTY_LABELS + ] + result = { + f"{prefix}{suffix}": [interval[key] for interval in intervals_clean] + for key, suffix in INTERVAL_MAP + } + # This will map space labels to a single one + result[f"{prefix}_count"] = len(intervals_clean) + return result + + +RE_STRESS_MARK = re.compile(r"\d$") + + +def remove_stress_marks(phn): + """Removes stress marks from a phoneme annotation + + Arguments + --------- + phn: list + a list of phoneme annotations with or without stress marks + + Returns + ------- + result: list + a list of phoneme annotations without stress marks + """ + return [RE_STRESS_MARK.sub("", item) for item in phn] diff --git a/benchmarks/DASB/LibriTTS/resynthesis/evaluate.py b/benchmarks/DASB/LibriTTS/resynthesis/evaluate.py new file mode 100644 index 000000000..d70de8409 --- /dev/null +++ b/benchmarks/DASB/LibriTTS/resynthesis/evaluate.py @@ -0,0 +1,247 @@ 
+#!/usr/bin/env/python3 +"""Recipe for evaluating vocoders on resynthesis + +Authors + * Artem Ploujnikov 2024 +""" + + +import logging +import json +import csv +import sys +import torchaudio +import speechbrain as sb + +from types import SimpleNamespace +from tqdm.auto import tqdm +from pathlib import Path +from hyperpyyaml import load_hyperpyyaml +from speechbrain.dataio.batch import PaddedData +from speechbrain.utils.distributed import run_on_main +from torch import nn + + +logger = logging.getLogger(__name__) + + +class VocoderEvaluator: + """A standalone vocoder evaluator + + Arguments + --------- + hparams : dict + Hyperparameters + run_opts : dict + Run options + """ + def __init__(self, hparams, run_opts): + self.hparams = SimpleNamespace(**hparams, run_opts=None) + if run_opts is None: + run_opts = {} + self.device = run_opts.get("device", "cpu") + self.modules = nn.ModuleDict(self.hparams.modules).to(self.device) + + def on_evaluate_start(self): + """Invoked when evaluation starts""" + tokenizer = ( + self.modules.tokenizer.module + if hasattr(self.modules.tokenizer, "module") + else self.modules.tokenizer + ) + tokenizer.device = self.device + if hasattr(tokenizer, "codec_vocoder"): + tokenizer.codec_vocoder.to(self.device) + tokenizer.codec_vocoder.device = self.device + + if self.hparams.representation_mode == "continuous": + self.vocoder = self.hparams.vocoder( + run_opts={"device": self.device} + ) + if hasattr(self.vocoder, "device"): + self.vocoder.device = self.device + if hasattr(self.vocoder, "model"): + self.vocoder.model.device = self.device + self.metric = self.hparams.metric() + + def on_evaluate_end(self): + """Invoked when evaluation ends""" + summary = self.metric.summarize() + output_folder = Path(self.hparams.output_folder) + summary_file_name = output_folder / "vocoder" / "summary.json" + summary_file_name.parent.mkdir(parents=True, exist_ok=True) + with open(summary_file_name, "w") as summary_file: + json.dump(summary, summary_file, indent=4) + + mos_file_name = output_folder / "vocoder" / "mos.csv" + with open(mos_file_name, "w") as mos_file: + writer = csv.writer(mos_file) + writer.writerow(["id", "score"]) + for row in zip(self.metric.ids, self.metric.scores): + writer.writerow(row) + + def evaluate(self, dataset): + """Evaluates the vocoder on a dataset + + Arguments + --------- + dataset : DynamicItemDataset + a dataset + """ + self.on_evaluate_start() + dataloader = sb.dataio.dataloader.make_dataloader(dataset) + for batch in tqdm(dataloader): + self.evaluate_batch(batch) + self.on_evaluate_end() + + def evaluate_batch(self, batch): + """Evaluates a single batch + + Arguments + --------- + batch : PaddedBatch + a batch""" + batch = batch.to(self.device) + wav_rec = self.get_wav_rec(batch) + self.metric.append( + ids=batch.uttid, + wavs=wav_rec.squeeze(1), + length=batch.sig.lengths, + sample_rate=self.hparams.model_sample_rate + ) + + def get_wav_rec(self, batch): + """Retrieves audio features + + Arguments + --------- + batch : PaddedBatch + a batch + + Returns + ------- + audio: torch.Tensor + The audio representation + """ + if self.hparams.representation_mode == "discrete": + audio = self.modules.tokenizer.sig_to_tokens(batch.sig.data, batch.sig.lengths) + wav_rec = self.modules.tokenizer.tokens_to_sig(audio) + else: + audio = self.modules.ssl_model( + batch.sig.data, + batch.sig.lengths, + ) + audio = audio.permute(1, 2, 0, 3)[:, :, self.hparams.num_codebooks] + wav_rec = self.vocoder(audio) + return wav_rec + + +def dataio_prepare(hparams): + 
"""This function prepares the datasets to be used in the brain class. + It also defines the data processing pipeline through user-defined functions. + + + Arguments + --------- + hparams : dict + This dictionary is loaded from the `train.yaml` file, and it includes + all the hyperparameters needed for dataset construction and loading. + + Returns + ------- + datasets : dict + Dictionary containing "train", "valid", and "test" keys that correspond + to the DynamicItemDataset objects. + """ + + # Define datasets from json data manifest file + # Define datasets sorted by ascending lengths for efficiency + datasets = {} + data_folder = hparams["data_folder"] + data_info = { + "train": hparams["train_json"], + "valid": hparams["valid_json"], + "test": hparams["test_json"], + } + + @sb.utils.data_pipeline.takes("wav") + @sb.utils.data_pipeline.provides("sig") + def sig_pipeline(wav): + sig = sb.dataio.dataio.read_audio(wav) + sig = torchaudio.functional.resample( + sig, + hparams["sample_rate"], + hparams["model_sample_rate"], + ) + return sig + + dynamic_items = [sig_pipeline] + output_keys = ["uttid", "sig"] + + for dataset in data_info: + dataset_dynamic_items = list(dynamic_items) + dataset_output_keys = list(output_keys) + + dynamic_dataset = sb.dataio.dataset.DynamicItemDataset.from_json( + json_path=data_info[dataset], + replacements={"data_root": data_folder}, + dynamic_items=dataset_dynamic_items, + output_keys=dataset_output_keys, + ) + datasets[dataset] = dynamic_dataset + + hparams["dataloader_opts"]["shuffle"] = False + return datasets + + +if __name__ == "__main__": + + # Reading command line arguments + hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:]) + + # Initialize ddp (useful only for multi-GPU DDP training) + sb.utils.distributed.ddp_init_group(run_opts) + + # Load hyperparameters file with command-line overrides + with open(hparams_file, encoding="utf-8") as fin: + yaml = fin.read() + + hparams = load_hyperpyyaml(yaml, overrides, overrides_must_match=True) + + # Create experiment directory + sb.create_experiment_directory( + experiment_directory=hparams["output_folder"], + hyperparams_to_save=hparams_file, + overrides=overrides, + ) + + from libritts_prepare import prepare_libritts + + # Data preparation, to be run on only one process. 
+ if not hparams["skip_prep"]: + run_on_main( + prepare_libritts, + kwargs={ + "data_folder": hparams["data_folder"], + "train_split": hparams["train_splits"], + "valid_split": hparams["dev_splits"], + "test_split": hparams["test_splits"], + "save_json_train": hparams["train_json"], + "save_json_valid": hparams["valid_json"], + "save_json_test": hparams["test_json"], + "sample_rate": hparams["sample_rate"], + "skip_prep": hparams["skip_prep"], + "max_valid_size": None, + "skip_resample": hparams["skip_resample"], + }, + ) + # We can now directly create the datasets for training, valid, and test + datasets = dataio_prepare(hparams) + + # Evaluate + evaluator = VocoderEvaluator(hparams, run_opts) + eval_dataset_key = hparams["eval_dataset"] + eval_dataset = datasets[eval_dataset_key] + logger.info("Starting evaluation on %s", eval_dataset_key) + evaluator.evaluate(eval_dataset) + logger.info("Evaluation ended") diff --git a/benchmarks/DASB/LibriTTS/resynthesis/hparams/discrete_ssl.yaml b/benchmarks/DASB/LibriTTS/resynthesis/hparams/discrete_ssl.yaml new file mode 100644 index 000000000..e0e1ff065 --- /dev/null +++ b/benchmarks/DASB/LibriTTS/resynthesis/hparams/discrete_ssl.yaml @@ -0,0 +1,127 @@ +# ############################################################################ +# Auido Tokenizer: WavLM +# Extraction: Librispeech 960h +# Authors: Jarod Duret 2024 +# ############################################################################ +# Seed needs to be set at top of yaml, before objects with parameters are made + +seed: 1986 +__set_seed: !apply:torch.manual_seed [!ref ] +output_folder: !ref results/wavlm +save_folder: !ref /save +train_log: !ref /extraction_log.txt + + +# Data files +data_folder: !PLACEHOLDER # e.g., /path/to/LibriSpeech +cached_data_folder: !ref # e.g., path/to/cache +pretrained_model_save_folder: !ref /pretrained_models +train_splits: ["train-clean-100"] #, "train-clean-360", "train-other-500" +dev_splits: ["dev-clean"] +test_splits: ["test-clean", "test-other"] +skip_prep: False +skip_resample: True +train_json: !ref /train.json +valid_json: !ref /dev-clean.json +test_json: !ref /test.json + +batch_size: 1 +num_workers: 8 +src_key: wav +id_key: id + +# Dataloader options +dataloader_opts: + batch_size: !ref + shuffle: True + num_workers: !ref + +### Configuration for discrete SSL model +# | SSL Model | HF Encoder | K-Means Dataset | K-Means Size | SSL Layers | Vocoder Model | +# |------------|----------------------------------------|-----------------|--------------|----------------------|---------------------------------------------| +# | WavLM | microsoft/wavlm-large | LibriSpeech960 | 1000 | 1, 3, 7, 12, 18, 23 | speechbrain/hifigan-wavlm-k1000-LibriTTS | +# | HuBERT | facebook/hubert-large-ll60k | LibriSpeech960 | 1000 | 1, 3, 7, 12, 18, 23 | speechbrain/hifigan-hubert-k1000-LibriTTS | +# | Wav2Vec2 | facebook/wav2vec2-large | LibriSpeech960 | 1000 | 1, 3, 7, 12, 18, 23 | speechbrain/hifigan-wav2vec2-k1000-LibriTTS | + + +# ssl_model_type: HuBERT, WavLM, Wav2Vec2 +# ssl_hub: facebook/hubert-large-ll60k, microsoft/wavlm-large, facebook/wav2vec2-large +ssl_model_type: wavlm +ssl_hub: microsoft/wavlm-large +ssl_folder: !ref /ssl_checkpoint +kmeans_cache_dir: !ref /kmeans_checkpoint +kmeans_dataset: LibriSpeech +vocoder_repo_id: speechbrain/hifigan-wavlm-k1000-LibriTTS +freeze_ssl: True +freeze_feature_extractor: True +vocab_size: 1000 +save_embedding: False +utmos_source: chaanks/wav2vec2-small +utmos_sample_rate: 16000 +utmos_model_url: 
https://huggingface.co/chaanks/UTMOS/resolve/main +eval_dataset: valid +representation_mode: discrete + +### Config for Tokenizer +# Layer number should be among the supported layers for discrete SSL models(kmenas model should be available for that layer) +num_codebooks: [1, 3, 7, 12, 18, 23] +deduplicate: [False, False, False, False, False, False] +bpe_tokenizer_path: [null, null, null, null, null, null] +sample_rate: 24000 +model_sample_rate: 16000 +encoder_dim: 1024 + +ssl_model: !apply:speechbrain.utils.hparams.choice + value: !ref + choices: + wavlm: !new:speechbrain.lobes.models.huggingface_transformers.wavlm.WavLM + source: !ref + output_norm: False + freeze: !ref + freeze_feature_extractor: !ref + output_all_hiddens: True + save_path: !ref + hubert: !new:speechbrain.lobes.models.huggingface_transformers.hubert.HuBERT + source: !ref + output_norm: False + freeze: !ref + freeze_feature_extractor: !ref + output_all_hiddens: True + save_path: !ref + wav2vec2: !new:speechbrain.lobes.models.huggingface_transformers.wav2vec2.Wav2Vec2 + source: !ref + output_norm: False + freeze: !ref + freeze_feature_extractor: !ref + output_all_hiddens: True + save_path: !ref + + +vocoder_src: !apply:speechbrain.utils.hparams.choice + value: !ref + choices: + wavlm: !ref speechbrain/hifigan-wavlm-l1-3-7-12-18-23-continuous-LibriTTS + hubert: !ref chaanks/hifigan-hubert-l1-3-7-12-18-23-LibriTTS + + +vocoder: !name:speechbrain.inference.vocoders.HIFIGAN.from_hparams + source: !ref + savedir: !ref + + +tokenizer: !new:utils.tokenizer_interface.DiscreteSSLTokenizer + save_path: !ref + ssl_model: !ref + vocoder_repo_id: !ref + kmeans_dataset: !ref + num_clusters: !ref + +modules: + ssl_model: !ref + tokenizer: !ref + +metric: !name:metric.UTMOSMetric + source: !ref + save_path: !ref + sample_rate: !ref + model_url: !ref \ No newline at end of file diff --git a/benchmarks/DASB/LibriTTS/resynthesis/libritts_prepare.py b/benchmarks/DASB/LibriTTS/resynthesis/libritts_prepare.py new file mode 120000 index 000000000..39f1a78c2 --- /dev/null +++ b/benchmarks/DASB/LibriTTS/resynthesis/libritts_prepare.py @@ -0,0 +1 @@ +../libritts_prepare.py \ No newline at end of file diff --git a/benchmarks/DASB/LibriTTS/resynthesis/metric.py b/benchmarks/DASB/LibriTTS/resynthesis/metric.py new file mode 100644 index 000000000..182a4451e --- /dev/null +++ b/benchmarks/DASB/LibriTTS/resynthesis/metric.py @@ -0,0 +1,369 @@ +"""Resynthesis metircs + +Authors + * Artem Ploujnikov 2024 +""" + +import csv +import torch +import torchaudio +from pathlib import Path +from torch import nn + +from speechbrain.lobes.models.huggingface_transformers.wav2vec2 import Wav2Vec2 +from speechbrain.utils.fetching import fetch +from speechbrain.utils.metric_stats import MetricStats + + +UTMOS_SAMPLE_RATE = 16000 +UTMOS_DEFAULT_JUDGE_ID = 288 +UTMOS_DEFAULT_DOMAIN_ID = 0 +UTMOS_DEFAULT_MODEL_NAME = "utmos.ckpt" + + +class UTMOSModel(nn.Module): + """The UTMOS model wrapper + + Arguments + --------- + source : str + The WavLM source + save_path : str | path-like + The path where the model will be saved + features_dim : int, optional + The features dimension + num_domains : int, optional + The number of domains + domain_dim : int, optional + The dimension of each domain + num_judges : int, optional + The number of "judges" + judge_dim : int, optional + The dimension of each judge + decoder_hidden_size : int, optional + The size of the decoder hidden state + multiplier : float, optional + The number that the raw model output is multiplied by + to compute 
the score + offset : float, optional + The number that (raw output * multiplier) will be added + to in order to get the score + """ + + def __init__( + self, + source, + save_path, + features_dim=768, + num_domains=3, + domain_dim=128, + num_judges=3000, + judge_dim=128, + decoder_hidden_size=512, + multiplier=2.0, + offset=3.0, + ): + super().__init__() + + self.ssl_encoder = Wav2Vec2( + source, + save_path, + freeze=True, + output_norm=False, + freeze_feature_extractor=True, + output_all_hiddens=False, + ) + + self.domain_embedding = nn.Embedding(num_domains, domain_dim) + self.judge_embedding = nn.Embedding(num_judges, judge_dim) + + self.decoder = nn.LSTM( + input_size=features_dim + domain_dim + judge_dim, + hidden_size=decoder_hidden_size, + num_layers=1, + batch_first=True, + bidirectional=True, + ) + + self.classifier = nn.Sequential( + nn.Linear(decoder_hidden_size * 2, 2048), + torch.nn.ReLU(), + nn.Dropout(0.3), + nn.Linear(2048, 1), + ) + self.multiplier = multiplier + self.offset = offset + + def forward(self, wav, domain_id=None, judge_id=None): + """Computes the forward pass + + Arguments + --------- + wav : torch.Tensor + The raw waveforms + domain_id : torch.Tensor + The domain identifiers + judge_id : torch.Tensor + The judge identifier + + Returns + ------- + result : torch.Tensor + The predicted rating(s) + """ + + if domain_id is None: + domain_id = torch.zeros( + len(wav), dtype=torch.int, device=wav.device + ) + if judge_id is None: + judge_id = ( + torch.ones(len(wav), dtype=torch.int, device=wav.device) + * UTMOS_DEFAULT_JUDGE_ID + ) + + ssl_features = self.ssl_encoder(wav) + domain_emb = self.domain_embedding(domain_id) + judge_emb = self.judge_embedding(judge_id) + + domain_emb = domain_emb.unsqueeze(1).expand( + -1, ssl_features.size(1), -1 + ) + judge_emb = judge_emb.unsqueeze(1).expand(-1, ssl_features.size(1), -1) + concatenated_feature = torch.cat( + [ssl_features, domain_emb, judge_emb], dim=2 + ) + + decoder_output, _ = self.decoder(concatenated_feature) + pred = self.classifier(decoder_output) + + return pred.mean(dim=1).squeeze(1) * self.multiplier + self.offset + + +class UTMOSMetric(MetricStats): + """A metric implementing UTMOS + + Arguments + --------- + sample_rate : int + The audio sample rate + source : str`, optional + The HuggingFace hube name for the encoder + save_path : str | path-like, optional + The path where the model will be saved + model_name : str, optional + The name of the model + model_url : str, optional + The download URL for the model + features_dim : int, optional + The features dimension + num_domains : int, optional + The number of domains + domain_dim : int, optional + The dimension of each domain + num_judges : int, optional + The number of "judges" + judge_dim : int, optional + The dimension of each judge + decoder_hidden_size : int, optional + The size of the decoder hidden state + domain_id : int, optional + The domain identifier + judge_id : int, optional + The judge identifier + run_opts : dict + Run options when instantiating the metric + """ + + def __init__( + self, + sample_rate, + source, + save_path, + model_name=None, + model_url=None, + features_dim=768, + num_domains=3, + domain_dim=128, + num_judges=3000, + judge_dim=128, + decoder_hidden_size=512, + domain_id=None, + judge_id=None, + run_opts=None, + ): + self.sample_rate = sample_rate + self.clear() + + if model_name is None: + model_name = UTMOS_DEFAULT_MODEL_NAME + if domain_id is None: + domain_id = UTMOS_DEFAULT_DOMAIN_ID + if judge_id is None: + 
judge_id = UTMOS_DEFAULT_JUDGE_ID + if sample_rate is None: + sample_rate = UTMOS_SAMPLE_RATE + + encoder_path = Path(save_path) + self.model = UTMOSModel( + source=source, + save_path=encoder_path.as_posix(), + features_dim=features_dim, + num_domains=num_domains, + domain_dim=domain_dim, + num_judges=num_judges, + judge_dim=judge_dim, + decoder_hidden_size=decoder_hidden_size, + ) + + # Download utmos model checkpoint + fetch(model_name, model_url, save_path) + model_path = Path(save_path) / model_name + assert model_path.exists() + + # Load weights + state_dict = torch.load(model_path) + self.model.load_state_dict(state_dict) + self.model.eval() + self.domain_id = domain_id + self.judge_id = judge_id + + if run_opts: + device = run_opts.get("device") + if device: + self.model.to(device) + + def append( + self, + ids, + wavs, + length=None, + sample_rate=None, + domain_ids=None, + judge_ids=None, + **kwargs, + ): + """Computes the UTMOS metric for the provided audio + + Arguments + --------- + ids : list + The list of item IDs + wavs : torch.Tensor + The audio prediction to be evaluated (e.g. TTS output) + length : torch.Tensor, optional + Relative lengths + sample_rate : int + The sample rate + domain_ids : torch.Tensor, optional + The domain IDs. The default will be used if not provided + judge_ids : torch.Tensor + The judge IDs. The default will be used if not provided + **kwargs: : dict + Other arguments (ignored) + """ + if wavs.dim() > 2: + wavs = wavs.squeeze() + + # Resample + hyp_audio = torchaudio.functional.resample( + wavs, sample_rate, self.sample_rate + ) + + self.model.device = hyp_audio.device + self.model.to(hyp_audio.device) + + if domain_ids is None: + domain_ids = torch.zeros( + len(hyp_audio), dtype=torch.int, device=hyp_audio.device + ) + if judge_ids is None: + judge_ids = ( + torch.ones( + len(hyp_audio), dtype=torch.int, device=hyp_audio.device + ) + * self.judge_id + ) + + output = self.model(hyp_audio, domain_ids, judge_ids) + self.scores += output.cpu().tolist() + + self.ids += ids + + def summarize(self, field=None): + """Returns a dict containing detailed UTMOS statistics. UTMOS + itself produces only one score per utterance - but the summary + will obtain full descriptive statistics (see `descriptive_statistics`) + + Arguments + --------- + field : str, optional + The field to return, if you are only interested in one of them. + If specified, a single `float` is returned, otherwise, a dict is. + + Returns + ------- + dict from str to float, if `field is None` + A dictionary of the fields documented above. + float, if `field is not None` + The single field selected by `field`. 
+ """ + stats = descriptive_statistics(self.scores, result_key="utmos") + return stats[field] if field else stats + + def write_stats(self, filestream, verbose=False): + writer = csv.writer(filestream) + writer.writerow(["id", "utmos"]) + for uttid, row in zip(self.ids, self.scores): + writer.writerow([uttid, row]) + + +def descriptive_statistics(items, key=None, result_key=None): + """Computes descriptive statistics for the summary + + Arguments + --------- + items : list + a list of dictionaries with metric values for each item + key : str + The key of the metric for which the statistics will be computed + result_key : str + The key to use for results + + Returns + ------- + statistics : dict + The desccriptive statistics computed + _mean : the arithmetic mean + _std : the standard deviation + _min : the minimum value + _max : the maximum value + _median : the median value + _q1 : the first quartile + _q3 : the third quartile + _iqr : the interquartile ratio + """ + if not items: + return {} + if not result_key: + result_key = key + if key is None: + values = torch.tensor(items) + else: + values = torch.tensor([item[key] for item in items]) + quantiles = torch.tensor([0.25, 0.5, 0.75]) + q1, median, q3 = values.quantile(quantiles) + stats = { + "mean": values.mean(), + "std": values.std(), + "min": values.min(), + "max": values.max(), + "median": median, + "q1": q1, + "q3": q3, + "iqr": q3 - q1, + } + return { + f"{result_key}_{stat_key}": value.item() + for stat_key, value in stats.items() + } \ No newline at end of file From 2378ea1545685985911273550d1e4549fc8c02de Mon Sep 17 00:00:00 2001 From: flexthink Date: Thu, 17 Apr 2025 13:55:40 -0400 Subject: [PATCH 2/2] DASB: Resynthesis: Cosmetic changes --- benchmarks/DASB/LibriTTS/libritts_prepare.py | 81 ++++++++++++++++--- .../DASB/LibriTTS/resynthesis/evaluate.py | 17 ++-- .../resynthesis/hparams/discrete_ssl.yaml | 6 +- .../DASB/LibriTTS/resynthesis/metric.py | 2 +- 4 files changed, 81 insertions(+), 25 deletions(-) diff --git a/benchmarks/DASB/LibriTTS/libritts_prepare.py b/benchmarks/DASB/LibriTTS/libritts_prepare.py index 52594eaf9..dda10826d 100644 --- a/benchmarks/DASB/LibriTTS/libritts_prepare.py +++ b/benchmarks/DASB/LibriTTS/libritts_prepare.py @@ -109,16 +109,40 @@ def prepare_libritts( # If specific splits are provided, creates data manifest files accordingly if train_split: wav_list = prepare_split(data_folder, train_split) - create_json(wav_list, save_json_train, sample_rate, data_folder, alignments_folder, model_name, skip_resample) + create_json( + wav_list, + save_json_train, + sample_rate, + data_folder, + alignments_folder, + model_name, + skip_resample, + ) if valid_split: wav_list = prepare_split(data_folder, valid_split) # TODO add better way to speedup evaluation if max_valid_size is not None and len(wav_list) > max_valid_size: wav_list = random.sample(wav_list, max_valid_size) - create_json(wav_list, save_json_valid, sample_rate, data_folder, alignments_folder, model_name, skip_resample) + create_json( + wav_list, + save_json_valid, + sample_rate, + data_folder, + alignments_folder, + model_name, + skip_resample, + ) if test_split: wav_list = prepare_split(data_folder, test_split) - create_json(wav_list, save_json_test, sample_rate, data_folder, alignments_folder, model_name, skip_resample) + create_json( + wav_list, + save_json_test, + sample_rate, + data_folder, + alignments_folder, + model_name, + skip_resample, + ) if skip(save_json_train, save_json_valid, save_json_test): logger.info("Preparation completed.") 
@@ -132,12 +156,29 @@ def prepare_libritts( data_split = split_sets(wav_list, split_ratio) # Creating json files create_json( - data_split["train"], save_json_train, sample_rate, alignments_folder, model_name, skip_resample + data_split["train"], + save_json_train, + sample_rate, + alignments_folder, + model_name, + skip_resample, + ) + create_json( + data_split["valid"], + save_json_valid, + sample_rate, + alignments_folder, + model_name, + skip_resample, ) create_json( - data_split["valid"], save_json_valid, sample_rate, alignments_folder, model_name, skip_resample + data_split["test"], + save_json_test, + sample_rate, + alignments_folder, + model_name, + skip_resample, ) - create_json(data_split["test"], save_json_test, sample_rate, alignments_folder, model_name, skip_resample) def prepare_split(data_folder, split_list): @@ -180,7 +221,15 @@ def prepare_split(data_folder, split_list): return wav_list -def create_json(wav_list, json_file, sample_rate, data_folder, alignments_folder=None, model_name=None, skip_resample=False): +def create_json( + wav_list, + json_file, + sample_rate, + data_folder, + alignments_folder=None, + model_name=None, + skip_resample=False, +): """ Creates the json file given a list of wav files. Arguments @@ -266,7 +315,9 @@ def create_json(wav_list, json_file, sample_rate, data_folder, alignments_folder "segment": True if "train" in json_file else False, } if alignments_folder is not None: - alignments_file_name = get_alignment_path(data_folder, alignments_folder, wav_file) + alignments_file_name = get_alignment_path( + data_folder, alignments_folder, wav_file + ) alignments = parse_alignments(alignments_file_name) json_dict[uttid].update(alignments) @@ -309,9 +360,16 @@ def get_alignment_path(data_folder, alignments_folder, file_name): file_name_rel = file_name.relative_to(data_folder) data_slice = file_name_rel.parts[0] - textgrid_folder = file_name_rel.relative_to(Path(data_slice) / "LibriTTS" / data_slice).parent.parent + textgrid_folder = file_name_rel.relative_to( + Path(data_slice) / "LibriTTS" / data_slice + ).parent.parent textgrid_file_name = f"{file_name_rel.stem}.TextGrid" - textgrid_path = Path(alignments_folder) / data_slice / textgrid_folder / textgrid_file_name + textgrid_path = ( + Path(alignments_folder) + / data_slice + / textgrid_folder + / textgrid_file_name + ) return textgrid_path @@ -382,6 +440,7 @@ def check_folders(*folders): return False return True + def parse_alignments(file_name): """Parses a given LibriSpeech-Alignments TextGrid file and converts the results to the desired format (to be used in JSON @@ -417,7 +476,7 @@ def parse_alignments(file_name): "wrd_start": [], "wrd_end": [], "wrd_count": 0, - "unk_count": None + "unk_count": None, } text_grid = textgrids.TextGrid() diff --git a/benchmarks/DASB/LibriTTS/resynthesis/evaluate.py b/benchmarks/DASB/LibriTTS/resynthesis/evaluate.py index d70de8409..ba072994f 100644 --- a/benchmarks/DASB/LibriTTS/resynthesis/evaluate.py +++ b/benchmarks/DASB/LibriTTS/resynthesis/evaluate.py @@ -17,7 +17,6 @@ from tqdm.auto import tqdm from pathlib import Path from hyperpyyaml import load_hyperpyyaml -from speechbrain.dataio.batch import PaddedData from speechbrain.utils.distributed import run_on_main from torch import nn @@ -35,6 +34,7 @@ class VocoderEvaluator: run_opts : dict Run options """ + def __init__(self, hparams, run_opts): self.hparams = SimpleNamespace(**hparams, run_opts=None) if run_opts is None: @@ -107,7 +107,7 @@ def evaluate_batch(self, batch): ids=batch.uttid, 
wavs=wav_rec.squeeze(1), length=batch.sig.lengths, - sample_rate=self.hparams.model_sample_rate + sample_rate=self.hparams.model_sample_rate, ) def get_wav_rec(self, batch): @@ -124,13 +124,12 @@ def get_wav_rec(self, batch): The audio representation """ if self.hparams.representation_mode == "discrete": - audio = self.modules.tokenizer.sig_to_tokens(batch.sig.data, batch.sig.lengths) + audio = self.modules.tokenizer.sig_to_tokens( + batch.sig.data, batch.sig.lengths + ) wav_rec = self.modules.tokenizer.tokens_to_sig(audio) else: - audio = self.modules.ssl_model( - batch.sig.data, - batch.sig.lengths, - ) + audio = self.modules.ssl_model(batch.sig.data, batch.sig.lengths,) audio = audio.permute(1, 2, 0, 3)[:, :, self.hparams.num_codebooks] wav_rec = self.vocoder(audio) return wav_rec @@ -169,9 +168,7 @@ def dataio_prepare(hparams): def sig_pipeline(wav): sig = sb.dataio.dataio.read_audio(wav) sig = torchaudio.functional.resample( - sig, - hparams["sample_rate"], - hparams["model_sample_rate"], + sig, hparams["sample_rate"], hparams["model_sample_rate"], ) return sig diff --git a/benchmarks/DASB/LibriTTS/resynthesis/hparams/discrete_ssl.yaml b/benchmarks/DASB/LibriTTS/resynthesis/hparams/discrete_ssl.yaml index e0e1ff065..c1287670e 100644 --- a/benchmarks/DASB/LibriTTS/resynthesis/hparams/discrete_ssl.yaml +++ b/benchmarks/DASB/LibriTTS/resynthesis/hparams/discrete_ssl.yaml @@ -95,7 +95,7 @@ ssl_model: !apply:speechbrain.utils.hparams.choice freeze_feature_extractor: !ref output_all_hiddens: True save_path: !ref - + vocoder_src: !apply:speechbrain.utils.hparams.choice value: !ref @@ -103,7 +103,7 @@ vocoder_src: !apply:speechbrain.utils.hparams.choice wavlm: !ref speechbrain/hifigan-wavlm-l1-3-7-12-18-23-continuous-LibriTTS hubert: !ref chaanks/hifigan-hubert-l1-3-7-12-18-23-LibriTTS - + vocoder: !name:speechbrain.inference.vocoders.HIFIGAN.from_hparams source: !ref savedir: !ref @@ -124,4 +124,4 @@ metric: !name:metric.UTMOSMetric source: !ref save_path: !ref sample_rate: !ref - model_url: !ref \ No newline at end of file + model_url: !ref diff --git a/benchmarks/DASB/LibriTTS/resynthesis/metric.py b/benchmarks/DASB/LibriTTS/resynthesis/metric.py index 182a4451e..1c23a4968 100644 --- a/benchmarks/DASB/LibriTTS/resynthesis/metric.py +++ b/benchmarks/DASB/LibriTTS/resynthesis/metric.py @@ -366,4 +366,4 @@ def descriptive_statistics(items, key=None, result_key=None): return { f"{result_key}_{stat_key}": value.item() for stat_key, value in stats.items() - } \ No newline at end of file + }
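
Usage sketch (the paths below are illustrative, not prescribed by the recipe): evaluate.py uses the standard SpeechBrain recipe entry point (sb.parse_arguments), so it is normally launched from benchmarks/DASB/LibriTTS/resynthesis/ with the hparams file plus command-line overrides for the !PLACEHOLDER values, for example

    python evaluate.py hparams/discrete_ssl.yaml --data_folder=/path/to/LibriTTS

Adding --eval_dataset=test would score the test split instead of the default valid split.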
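
Expected outputs, as produced by VocoderEvaluator.on_evaluate_end: two files under <output_folder>/vocoder/. summary.json holds the aggregated UTMOS statistics returned by descriptive_statistics (utmos_mean, utmos_std, utmos_min, utmos_max, utmos_median, utmos_q1, utmos_q3, utmos_iqr), and mos.csv holds one row per evaluated utterance, e.g.

    id,score
    <uttid>,<per-utterance UTMOS>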