From c7f5fd48119c5eefa80c634cae380363433518bd Mon Sep 17 00:00:00 2001 From: Pierre-Antoine Bannier Date: Sat, 29 Jun 2024 16:13:04 +0200 Subject: [PATCH 1/7] doc: fix `Bark` -> `Encodec` --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index cd7be33..534cd28 100644 --- a/README.md +++ b/README.md @@ -32,7 +32,7 @@ https://github.com/PABannier/encodec.cpp/assets/12958149/d11561be-98e9-4504-bba7 ## Usage -Here are the steps for the bark model. +Here are the steps for the Encodec model. ### Get the code From 12058a39d911995b6444147a4b7eb4f74c0388b3 Mon Sep 17 00:00:00 2001 From: Pierre-Antoine Bannier Date: Sat, 29 Jun 2024 16:13:28 +0200 Subject: [PATCH 2/7] python: convert Vocos weights to GGML format --- examples/vocos/convert.py | 159 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 159 insertions(+) create mode 100644 examples/vocos/convert.py diff --git a/examples/vocos/convert.py b/examples/vocos/convert.py new file mode 100644 index 0000000..cb88630 --- /dev/null +++ b/examples/vocos/convert.py @@ -0,0 +1,159 @@ +"""Convert Vocos model checkpoint into the GGML format. + +The bytes are packed in a binary file in the following order: + - Magic (`ggml` in binary format) + - Tensors + +For each tensor, the bytes are packed as follows: + - Number of dimensions (int) + - Name length (int) + - Dimensions (int[n_dims]) + - Name (char[name_length]) + - Data (float[n_dims]) + +Usage +----- + +```bash + python convert.py \ + --dir-model \ + --out-dir ./ \ + --use-f16 +``` +""" +import argparse +from pathlib import Path +import struct +import re +import yaml + +import numpy as np +import torch + +parser = argparse.ArgumentParser() +parser.add_argument("--dir-model", type=str, required=True) +parser.add_argument("--out-dir", type=str, required=True) +parser.add_argument("--use-f16", action="store_true") + + +def parse_vocos_weights(checkpoint, fout, use_f16): + """Dump Vocos model checkpoint.""" + def _clean_name(name): + module_name = name.split(".")[0] + if module_name not in ["feature_extractor", "backbone", "head"]: + raise Exception("Unknown module name") + + # Backbone + if re.match(r"backbone\.convnext\.\d+\.dwconv\.(weight|bias)", name): + i = re.findall(r"\d+", name)[0] + ttype = re.findall(r"(weight|bias)", name)[0][0] + return f"{module_name}/convnext/{i}/dwconv/{ttype}" + elif re.match(r"backbone\.convnext\.\d+\.gamma", name): + i = re.findall(r"\d+", name)[0] + return f"{module_name}/convnext/{i}/gamma" + elif re.match(r"backbone\.convnext\.\d+\.norm\.(scale|shift)\.weight", name): + i = re.findall(r"\d+", name)[0] + scale_or_shift = re.findall(r"(scale|shift)", name)[0] + return f"{module_name}/convnext/{i}/norm/{scale_or_shift}" + elif re.match(r"backbone\.convnext\.\d+\.pwconv(\d+)\.(weight|bias)", name): + matches = re.findall(r"\d+", name) + i, j = matches[0], matches[1] + ttype = re.findall(r"(weight|bias)", name)[0][0] + return f"{module_name}/convnext/{i}/pwconv/{j}/{ttype}" + elif re.match(r"backbone\.(embed|final_layer_norm)\.(weight|bias)", name): + ltype = re.findall(r"(embed|final_layer_norm)", name)[0] + ttype = re.findall(r"(weight|bias)", name)[0][0] + return f"{module_name}/{ltype}/{ttype}" + elif re.match(r"backbone\.norm\.(scale|shift)\.weight", name): + ltype = re.findall(r"(scale|shift)", name)[0] + return f"{module_name}/norm/{ltype}/w" + # Feature extractor + elif name == "feature_extractor.codebook_weights": + return f"{module_name}/codebook_weights" + # Head + elif name == "head.istft.window": + return f"{module_name}/istft/window" + elif re.match(r"head\.out\.(weight|bias)", name): + ttype = re.findall(r"(weight|bias)", name)[0][0] + return f"{module_name}/out/{ttype}" + # Unknown + else: + raise Exception(f"Unknown variable name: {name}") + + n_f16, n_f32 = 0, 0 + + for name in checkpoint.keys(): + var_data = checkpoint[name].cpu().numpy() + clean_name = _clean_name(name) + + print(f"{name : <40} -> {clean_name}") + print(f" {var_data.shape}") + + if use_f16: + if clean_name.endswith("/w"): + # Only weight matrices are cast to float16 + var_data = var_data.astype(np.float16) + ftype_cur = 1 + n_f16 += 1 + else: + var_data = var_data.astype(np.float32) + ftype_cur = 0 + n_f32 += 1 + else: + var_data = var_data.astype(np.float32) + ftype_cur = 0 + n_f32 += 1 + + n_dims = len(var_data.shape) + encoded_name = clean_name.encode("utf-8") + fout.write(struct.pack("iii", n_dims, len(encoded_name), ftype_cur)) + + for i in range(n_dims): + fout.write(struct.pack("i", var_data.shape[n_dims - 1 - i])) + fout.write(encoded_name) + + var_data.tofile(fout) + + print("\n") + print(f"n_f16: {n_f16} ({n_f16 / (n_f16 + n_f32) * 100:.0f}%)") + print(f"n_f32: {n_f32} ({n_f32 / (n_f16 + n_f32) * 100:.0f}%)") + + +def parse_hparams(fout, config, use_f16): + # Backbone + bb_config = config["backbone"]["init_args"] + fout.write(struct.pack("i", bb_config["input_channels"])) + fout.write(struct.pack("i", bb_config["dim"])) + fout.write(struct.pack("i", bb_config["intermediate_dim"])) + fout.write(struct.pack("i", bb_config["num_layers"])) + fout.write(struct.pack("i", bb_config["adanorm_num_embeddings"])) + # Head (padding is assumed to be `same`) + head_config = config["head"]["init_args"] + fout.write(struct.pack("i", head_config["dim"])) + fout.write(struct.pack("i", head_config["n_fft"])) + fout.write(struct.pack("i", head_config["hop_length"])) + # General + fout.write(struct.pack("i", int(use_f16))) + + +if __name__ == "__main__": + args = parser.parse_args() + + dir_model = Path(args.dir_model) + + out_dir = Path(args.out_dir) + out_dir.mkdir(exist_ok=True, parents=True) + + outfile = Path(out_dir / "ggml_weights.bin") + fout = open(outfile, "wb") + fout.write(struct.pack("i", 0x67676d6c)) + + with open(dir_model / "config.yaml", "rb") as f: + config = yaml.safe_load(f) + parse_hparams(fout, config, args.use_f16) + + checkpoint = torch.load(dir_model / "pytorch_model.bin", map_location="cpu") + parse_vocos_weights(checkpoint, fout, args.use_f16) + + fout.close() + print("Done.") From 5f99e3fab5df94078b31623a64b541d22068fee0 Mon Sep 17 00:00:00 2001 From: Pierre-Antoine Bannier Date: Sat, 29 Jun 2024 16:13:42 +0200 Subject: [PATCH 3/7] example: add Vocos executable --- examples/CMakeLists.txt | 1 + examples/vocos/CMakeLists.txt | 4 + examples/vocos/vocos.cpp | 504 ++++++++++++++++++++++++++++++++++ 3 files changed, 509 insertions(+) create mode 100644 examples/vocos/CMakeLists.txt create mode 100644 examples/vocos/vocos.cpp diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index 0e7ed19..c3694ca 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -5,3 +5,4 @@ target_compile_features(common PRIVATE cxx_std_11) add_subdirectory(main) add_subdirectory(compress) add_subdirectory(decompress) +add_subdirectory(vocos) \ No newline at end of file diff --git a/examples/vocos/CMakeLists.txt b/examples/vocos/CMakeLists.txt new file mode 100644 index 0000000..ac6d444 --- /dev/null +++ b/examples/vocos/CMakeLists.txt @@ -0,0 +1,4 @@ +set(TARGET vocos) +add_executable(${TARGET} vocos.cpp) +target_link_libraries(${TARGET} PRIVATE encodec common ggml) +target_compile_features(${TARGET} PRIVATE cxx_std_11) diff --git a/examples/vocos/vocos.cpp b/examples/vocos/vocos.cpp new file mode 100644 index 0000000..2ffa105 --- /dev/null +++ b/examples/vocos/vocos.cpp @@ -0,0 +1,504 @@ +/* This demonstrates how to use the Vocos encoder with Encodec code features +to reconstruct an audio. + +Author: Pierre-Antoine Bannier +*/ +#include "ggml-alloc.h" +#include "ggml-backend.h" +#include "ggml.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "encodec.h" + +#define VOCOS_FILE_MAGIC 'ggml' + +static const size_t MB = 1024 * 1024; + +struct vocos_hparams { + // Number of input channels in backbone + int32_t input_channels; + // Inner dimension in backbone + int32_t dim; + // Intermediate dimension in backbone + int32_t dim_intermediate; + // Number of layers in backbone + int32_t n_layers; + // Number of codes + int32_t adanorm_num_embeddings; + // Dimension in head + int32_t head_dim; + // Number of FFT bins + int32_t n_fft; + // Hop length + int32_t hop_length; + + // File type of model weights + int32_t ftype; +}; + +struct vocos_backbone_layer { + struct ggml_tensor * dwconv_w; + struct ggml_tensor * dwconv_b; + + struct ggml_tensor * gamma; + + struct ggml_tensor * norm_scale; + struct ggml_tensor * norm_shift; + + struct ggml_tensor * pwconv1_w; + struct ggml_tensor * pwconv1_b; + + struct ggml_tensor * pwconv2_w; + struct ggml_tensor * pwconv2_b; +}; + +struct vocos_backbone { + struct ggml_tensor * embed_w; + struct ggml_tensor * embed_b; + + struct ggml_tensor * norm_scale; + struct ggml_tensor * norm_shift; + + struct ggml_tensor * final_ln_w; + struct ggml_tensor * final_ln_b; + + std::vector layers; +}; + +struct vocos_feature_extractor { + struct ggml_tensor *codebook_weights; +}; + +struct vocos_head { + struct ggml_tensor *istft_window; + struct ggml_tensor *proj_out_w; + struct ggml_tensor *proj_out_b; +}; + +struct vocos_model { + struct vocos_hparams hparams; + + struct vocos_backbone backbone; + struct vocos_feature_extractor feature_extractor; + struct vocos_head head; + + // context + struct ggml_context * ctx; + int n_loaded; + + ggml_backend_t backend = NULL; + ggml_backend_buffer_t buffer_w; + + std::map tensors; +}; + +struct vocos_statistics { + // The time taken to load the model. + int64_t t_load_us; + // The time taken to compute the model. + int64_t t_compute_us; +}; + +struct vocos_context { + struct vocos_model model; + + // buffer for model evaluation + ggml_backend_buffer_t buf_compute; + + // custom allocator + struct ggml_allocr * allocr = NULL; + + // intermediate steps + struct ggml_tensor * features = NULL; + struct ggml_tensor * codes = NULL; + struct ggml_tensor * decoded = NULL; + + std::vector out_codes; + std::vector out_audio; + + // statistics + struct vocos_statistics stats; +}; + +typedef enum { + // Run the end-to-end encoder-decoder pipeline + full = 0, + // Encode an audio + encode = 1, + // Decode an audio from codes + decode = 2, +} vocos_run_mode; + +template +static void read_safe(std::ifstream &fin, T &dest) { + fin.read((char *)&dest, sizeof(T)); +} + +bool vocos_load_model_weights(std::ifstream &fin, struct vocos_model &model) { + // verify magic + { + uint32_t magic; + read_safe(fin, magic); + if (magic != VOCOS_FILE_MAGIC) { + std::cerr << "Invalid file magic" << std::endl; + return false; + } + } + + // load hparams + { + auto &hparams = model.hparams; + + read_safe(fin, hparams.input_channels); + read_safe(fin, hparams.dim); + read_safe(fin, hparams.dim_intermediate); + read_safe(fin, hparams.n_layers); + read_safe(fin, hparams.adanorm_num_embeddings); + read_safe(fin, hparams.head_dim); + read_safe(fin, hparams.n_fft); + read_safe(fin, hparams.hop_length); + read_safe(fin, hparams.ftype); + + printf("%s: input_channels = %d\n", __func__, hparams.input_channels); + printf("%s: dim = %d\n", __func__, hparams.dim); + printf("%s: dim_intermediate = %d\n", __func__, hparams.dim_intermediate); + printf("%s: n_layers = %d\n", __func__, hparams.n_layers); + printf("%s: adanorm_num_embeddings = %d\n", __func__, hparams.adanorm_num_embeddings); + printf("%s: head_dim = %d\n", __func__, hparams.head_dim); + printf("%s: n_fft = %d\n", __func__, hparams.n_fft); + printf("%s: hop_length = %d\n", __func__, hparams.hop_length); + printf("%s: ftype = %d\n", __func__, hparams.ftype); + } + + // for the big tensors, we have the option to store the data in 16-bit floats or quantized + // in order to save memory and also to speed up the computation + ggml_type wtype = ggml_ftype_to_ggml_type((ggml_ftype)(model.hparams.ftype)); + if (wtype == GGML_TYPE_COUNT) { + std::cerr << "Invalid model file (bad ftype value " << model.hparams.ftype << ")" << std::endl; + return false; + } + + auto &ctx = model.ctx; + + size_t buffer_size = 0; + size_t n_tensors = 0; + + // Evaluating context size + { + const auto &hparams = model.hparams; + + const int input_channels = hparams.input_channels; + const int dim = hparams.dim; + const int dim_intermediate = hparams.dim_intermediate; + const int n_layers = hparams.n_layers; + const int adanorm_num_embeddings = hparams.adanorm_num_embeddings; + const int head_dim = hparams.head_dim; + const int n_fft = hparams.n_fft; + const int hop_length = hparams.hop_length; + + // backbone + buffer_size += input_channels * dim * 7 * ggml_type_size(wtype); // embed_w + buffer_size += dim * ggml_type_size(GGML_TYPE_F32); // embed_b + + buffer_size += 2 * dim * ggml_type_size(GGML_TYPE_F32); // final_layer_norm + buffer_size += 2 * dim * adanorm_num_embeddings * ggml_type_size(GGML_TYPE_F32); // norm_scale and norm_shift + + buffer_size += n_layers * dim * dim * 7 * ggml_type_size(wtype); // dwconv_w + buffer_size += n_layers * dim * ggml_type_size(GGML_TYPE_F32); // dwconv_b + buffer_size += n_layers * 2 * dim * ggml_type_size(GGML_TYPE_F32); // gamma + buffer_size += n_layers * dim * adanorm_num_embeddings * ggml_type_size(wtype); // norm_scale and norm_shift + buffer_size += n_layers * 2 * dim * dim_intermediate * ggml_type_size(wtype); // pwconv1_w and pwconv2_w + buffer_size += n_layers * dim * ggml_type_size(GGML_TYPE_F32); // pwconv1_b + buffer_size += n_layers * dim_intermediate * ggml_type_size(GGML_TYPE_F32); // pwconv2_b + + n_tensors += 6 + n_layers * 9; + + // feature extactor + buffer_size += 16384 * input_channels * ggml_type_size(GGML_TYPE_F32); // TODO(PAB): hardcoded value! + n_tensors++; + + // head + buffer_size += dim * (n_fft + 2) * ggml_type_size(wtype); // proj_out_w + buffer_size += (n_fft + 2) * ggml_type_size(GGML_TYPE_F32); // proj_out_b + buffer_size += n_fft * ggml_type_size(GGML_TYPE_F32); // istft_window + + n_tensors += 3; + + buffer_size += 10ull * MB; // object overhead + + printf("%s: ggml tensor size = %d bytes\n", __func__, (int)sizeof(ggml_tensor)); + printf("%s: backend buffer size = %6.2f MB\n", __func__, buffer_size / (1024.0 * 1024.0)); + } + + // create the ggml context + { + struct ggml_init_params params = { + /* .mem_size = */ ggml_tensor_overhead() * n_tensors, + /* .mem_buffer = */ NULL, + /* .no_alloc = */ true, + }; + + model.ctx = ggml_init(params); + if (!model.ctx) { + std::cerr << __func__ << ": ggml_init() failed" << std::endl; + return false; + } + } + + if (!model.backend) { + // fallback to CPU backend + std::cerr << __func__ << ": using CPU backend" << std::endl; + model.backend = ggml_backend_cpu_init(); + } + + if (!model.backend) { + std::cerr << __func__ << ": ggml_backend_cpu_init() failed" << std::endl; + return false; + } + + // allocate weights buffer + model.buffer_w = ggml_backend_alloc_buffer(model.backend, buffer_size); + + // prepare memory for the weights + { + const auto &hparams = model.hparams; + + const int input_channels = hparams.input_channels; + const int dim = hparams.dim; + const int dim_intermediate = hparams.dim_intermediate; + const int n_layers = hparams.n_layers; + const int adanorm_num_embeddings = hparams.adanorm_num_embeddings; + const int head_dim = hparams.head_dim; + const int n_fft = hparams.n_fft; + const int hop_length = hparams.hop_length; + + // backbone + { + model.backbone.layers.resize(n_layers); + + model.backbone.embed_w = ggml_new_tensor_3d(ctx, wtype, 7, input_channels, dim); + model.backbone.embed_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, dim); + + model.tensors["backbone/embed/w"] = model.backbone.embed_w; + model.tensors["backbone/embed/b"] = model.backbone.embed_b; + + model.backbone.norm_scale = ggml_new_tensor_2d(ctx, wtype, dim, adanorm_num_embeddings); + model.backbone.norm_shift = ggml_new_tensor_2d(ctx, wtype, dim, adanorm_num_embeddings); + + model.tensors["backbone/norm/scale/w"] = model.backbone.norm_scale; + model.tensors["backbone/norm/shift/w"] = model.backbone.norm_shift; + + model.backbone.final_ln_w = ggml_new_tensor_1d(ctx, wtype, dim); + model.backbone.final_ln_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, dim); + + model.tensors["backbone/final_layer_norm/w"] = model.backbone.final_ln_w; + model.tensors["backbone/final_layer_norm/b"] = model.backbone.final_ln_b; + + for (int i = 0; i < n_layers; i++) { + auto &layer = model.backbone.layers[i]; + + layer.dwconv_w = ggml_new_tensor_3d(ctx, wtype, 7, 1, 384); + layer.dwconv_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, dim); + + model.tensors["backbone/convnext/" + std::to_string(i) + "/dwconv/w"] = layer.dwconv_w; + model.tensors["backbone/convnext/" + std::to_string(i) + "/dwconv/b"] = layer.dwconv_b; + + layer.gamma = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, dim); + + model.tensors["backbone/convnext/" + std::to_string(i) + "/gamma"] = layer.gamma; + + layer.norm_scale = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, dim, adanorm_num_embeddings); + layer.norm_shift = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, dim, adanorm_num_embeddings); + + model.tensors["backbone/convnext/" + std::to_string(i) + "/norm/scale"] = layer.norm_scale; + model.tensors["backbone/convnext/" + std::to_string(i) + "/norm/shift"] = layer.norm_shift; + + layer.pwconv1_w = ggml_new_tensor_2d(ctx, wtype, dim, dim_intermediate); + layer.pwconv1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, dim_intermediate); + + model.tensors["backbone/convnext/" + std::to_string(i) + "/pwconv/1/w"] = layer.pwconv1_w; + model.tensors["backbone/convnext/" + std::to_string(i) + "/pwconv/1/b"] = layer.pwconv1_b; + + layer.pwconv2_w = ggml_new_tensor_2d(ctx, wtype, dim_intermediate, dim); + layer.pwconv2_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, dim); + + model.tensors["backbone/convnext/" + std::to_string(i) + "/pwconv/2/w"] = layer.pwconv2_w; + model.tensors["backbone/convnext/" + std::to_string(i) + "/pwconv/2/b"] = layer.pwconv2_b; + } + } + + // feature extractor + { + // TODO (PAB): careful with hardcoded + model.feature_extractor.codebook_weights = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, input_channels, 16384); + model.tensors["feature_extractor/codebook_weights"] = model.feature_extractor.codebook_weights; + } + + // head + { + model.head.istft_window = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_fft); + model.tensors["head/istft/window"] = model.head.istft_window; + + model.head.proj_out_w = ggml_new_tensor_2d(ctx, wtype, dim, n_fft + 2); + model.head.proj_out_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_fft + 2); + + model.tensors["head/out/w"] = model.head.proj_out_w; + model.tensors["head/out/b"] = model.head.proj_out_b; + } + } + + // load weights + { + ggml_allocr *alloc = ggml_allocr_new_from_buffer(model.buffer_w); + + size_t total_size = 0; + model.n_loaded = 0; + + std::vector read_buf; + + while (true) { + int32_t n_dims; + int32_t length; + int32_t ftype; + + read_safe(fin, n_dims); + read_safe(fin, length); + read_safe(fin, ftype); + + if (fin.eof()) { + break; + } + + int32_t nelements = 1; + int32_t ne[3] = { 1, 1, 1 }; + for (int i = 0; i < n_dims; i++) { + read_safe(fin, ne[i]); + nelements *= ne[i]; + } + + std::string name; + std::vector buf(length); + fin.read(&buf[0], buf.size()); + name.assign(&buf[0], buf.size()); + + if (model.tensors.find(name.data()) == model.tensors.end()) { + std::cerr << "Unknown tensor name: " << name << std::endl; + return false; + } + + auto tensor = model.tensors[name.data()]; + ggml_set_name(tensor, name.c_str()); + if (ggml_nelements(tensor) != nelements) { + std::cerr << "Invalid number of elements for tensor " << name << " (" << ggml_nelements(tensor) << " != " << nelements << ")" << std::endl; + return false; + } + + if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1] || tensor->ne[2] != ne[2]) { + fprintf(stderr, "%s: tensor '%s' has wrong shape in model file: got [%lld, %lld, %lld], expected [%d, %d, %d]\n", + __func__, name.data(), tensor->ne[0], tensor->ne[1], tensor->ne[2], ne[0], ne[1], ne[2]); + return false; + } + + const size_t bpe = ggml_type_size(ggml_type(ftype)); + + if ((nelements * bpe) / ggml_blck_size(tensor->type) != ggml_nbytes(tensor)) { + fprintf(stderr, "%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n", + __func__, name.data(), ggml_nbytes(tensor), nelements * bpe); + return false; + } + + ggml_allocr_alloc(alloc, tensor); + + if (ggml_backend_is_cpu(model.backend)) { + // for the CPU and Metal backends, we can read directly into the device memory + fin.read(reinterpret_cast(tensor->data), ggml_nbytes(tensor)); + } else { + // read into a temporary buffer first, then copy to device memory + read_buf.resize(ggml_nbytes(tensor)); + fin.read(read_buf.data(), ggml_nbytes(tensor)); + ggml_backend_tensor_set(tensor, read_buf.data(), 0, ggml_nbytes(tensor)); + } + + total_size += ggml_nbytes(tensor); + model.n_loaded++; + } + + ggml_allocr_free(alloc); + printf("%s: model size = %8.2f MB\n", __func__, total_size / 1024.0 / 1024.0); + } + + fin.close(); + + return true; +} + +struct vocos_context *vocos_load_model(const std::string model_path) { + int64_t t_start_load_us = ggml_time_us(); + + auto fin = std::ifstream(model_path, std::ios::binary); + if (!fin) { + std::cerr << "Failed to open model file" << std::endl; + return nullptr; + } + + struct vocos_context *vctx = new vocos_context(); + + vctx->model = vocos_model(); + if (!vocos_load_model_weights(fin, vctx->model)) { + std::cerr << "Failed to load model weights" << std::endl; + delete vctx; + return nullptr; + } + + vctx->stats.t_load_us = ggml_time_us() - t_start_load_us; + + return vctx; +} + +void vocos_free(struct vocos_context *vctx) { + if (!vctx) { + return; + } + + if (vctx->model.ctx) { + ggml_free(vctx->model.ctx); + } + + if (vctx->buf_compute) { + ggml_backend_buffer_free(vctx->buf_compute); + } + + ggml_backend_buffer_free(vctx->model.buffer_w); + ggml_backend_free(vctx->model.backend); + + delete vctx; +} + +int main(int argc, char **argv) { + if (argc < 2) { + std::cerr << "Usage: " << argv[0] << " " << std::endl; + return 1; + } + + std::string model_path = argv[1]; + + struct vocos_context *vctx = vocos_load_model(model_path); + if (!vctx) { + std::cerr << "Failed to load model" << std::endl; + return 1; + } + + vocos_free(vctx); + + return 0; +} \ No newline at end of file From 25800ee312d84849ab0d7c8514e6df2fb76248e1 Mon Sep 17 00:00:00 2001 From: Pierre-Antoine Bannier Date: Sat, 29 Jun 2024 16:13:54 +0200 Subject: [PATCH 4/7] doc: add Vocos README.md with instructions --- examples/vocos/README.md | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) create mode 100644 examples/vocos/README.md diff --git a/examples/vocos/README.md b/examples/vocos/README.md new file mode 100644 index 0000000..ceb577c --- /dev/null +++ b/examples/vocos/README.md @@ -0,0 +1,20 @@ +# vocos.cpp + +High-performance inference of [Vocos]() vocoder using Encodec codes: + +- Plain C/C++ implementation without dependencies using [ggml](https://github.com/ggerganov/ggml) + +# Download Vocos weights from HuggingFace Hub + +```python + +``` + +## Build + +```bash +mkdir build +cd build +cmake .. +cmake --build . --config Release +``` From c9b851e3fb8b25cef43a8bde883d6bd41f5ee559 Mon Sep 17 00:00:00 2001 From: Pierre-Antoine Bannier Date: Sun, 30 Jun 2024 13:57:11 +0200 Subject: [PATCH 5/7] vocos: forward pass --- examples/vocos/vocos.cpp | 443 ++++++++++++++++++++++++++++++++++++++- 1 file changed, 436 insertions(+), 7 deletions(-) diff --git a/examples/vocos/vocos.cpp b/examples/vocos/vocos.cpp index 2ffa105..18e9adf 100644 --- a/examples/vocos/vocos.cpp +++ b/examples/vocos/vocos.cpp @@ -19,11 +19,32 @@ Author: Pierre-Antoine Bannier #include #include "encodec.h" +#include "common.h" #define VOCOS_FILE_MAGIC 'ggml' static const size_t MB = 1024 * 1024; +struct vocos_params { + // Number of threads used for inference + int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency()); + + // Target bandwidth + int32_t bandwidth_id = 2; + + // Input location + std::string input_path = "input.wav"; + + // Vocos weights location + std::string vocos_model_path = "./vocos/ggml-model.bin"; + + // Encodec weights location + std::string encodec_model_path = "./encodec/ggml-model.bin"; + + // Output location + std::string output_path = "output.wav"; +}; + struct vocos_hparams { // Number of input channels in backbone int32_t input_channels; @@ -42,6 +63,9 @@ struct vocos_hparams { // Hop length int32_t hop_length; + // Bandwidth identifier + int32_t bandwidth_id; + // File type of model weights int32_t ftype; }; @@ -119,8 +143,7 @@ struct vocos_context { struct ggml_allocr * allocr = NULL; // intermediate steps - struct ggml_tensor * features = NULL; - struct ggml_tensor * codes = NULL; + struct ggml_tensor * encoded = NULL; struct ggml_tensor * decoded = NULL; std::vector out_codes; @@ -128,6 +151,10 @@ struct vocos_context { // statistics struct vocos_statistics stats; + + // parameters + int32_t n_threads; + std::string encodec_path; }; typedef enum { @@ -144,6 +171,14 @@ static void read_safe(std::ifstream &fin, T &dest) { fin.read((char *)&dest, sizeof(T)); } +const struct vocos_statistics* vocos_get_statistics(struct vocos_context *vctx) { + if (!vctx) { + fprintf(stderr, "%s: null context\n", __func__); + return nullptr; + } + return &vctx->stats; +} + bool vocos_load_model_weights(std::ifstream &fin, struct vocos_model &model) { // verify magic { @@ -465,6 +500,318 @@ struct vocos_context *vocos_load_model(const std::string model_path) { return vctx; } +struct ggml_tensor *vocos_ada_layer_norm( + struct ggml_context *ctx0, + struct ggml_tensor *inp, + struct ggml_tensor *scale_w, + struct ggml_tensor *shift_w, + struct ggml_tensor *cond_embedding_id) { + + struct ggml_tensor * scale = ggml_get_rows(ctx0, scale_w, cond_embedding_id); + struct ggml_tensor * shift = ggml_get_rows(ctx0, shift_w, cond_embedding_id); + + struct ggml_tensor * norm = ggml_norm(ctx0, inp, 1e-5 /* eps */); + struct ggml_tensor * out = ggml_add(ctx0, ggml_mul(ctx0, norm, scale), shift); + + return out; +} + +struct ggml_tensor *vocos_forward_encoder( + struct vocos_context *vctx, + struct ggml_context *ctx0, + struct ggml_tensor *inp) { + if (!inp) { + fprintf(stderr, "%s: invalid input tensor\n", __func__); + return nullptr; + } + + const int T = inp->ne[0]; + const int n_q = inp->ne[1]; + + const int n_bins = 1024; // TODO (PAB): hardcoded + + const auto &model = vctx->model.feature_extractor; + const auto &allocr = vctx->allocr; + + // offsets: [n_q] + struct ggml_tensor *offsets = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_q); + if (!ggml_allocr_is_measure(allocr)) { + for (int32_t i = 0; i < n_q; i++) { + int32_t v = i * n_bins; + ggml_backend_tensor_set(offsets, &v, i * sizeof(int32_t), sizeof(i)); + } + } + + // inp: [n_bins, n_q] + // embeddings_idxs: [n_q, n_bins] + struct ggml_tensor *embeddings_idxs = ggml_add(ctx0, inp, offsets); + // [n_q, n_bins, dim] + struct ggml_tensor *features = ggml_get_rows(ctx0, model.codebook_weights, embeddings_idxs); + // [n_bins, dim] + features = ggml_sum_rows(ctx0, features); + + return features; +} + +struct ggml_tensor *vocos_forward_decoder( + struct vocos_context *vctx, + struct ggml_context *ctx0, + struct ggml_tensor *encoded, + struct ggml_tensor *bandwidth_id) { + if (!encoded) { + fprintf(stderr, "%s: invalid input tensor\n", __func__); + return nullptr; + } + + const auto &model = vctx->model; + const auto &backbone = model.backbone; + const auto &head = model.head; + + const auto &hparams = model.hparams; + const int n_layers = hparams.n_layers; + + // backbone + + struct ggml_tensor *emb = ggml_conv_1d( + ctx0, backbone.embed_w, encoded, 1 /* s0 */, 3 /* p0 */, 1 /* d0 */); + emb = ggml_add(ctx0, emb, backbone.embed_b); + + emb = vocos_ada_layer_norm(ctx0, emb, backbone.norm_scale, backbone.norm_shift, bandwidth_id); + + struct ggml_tensor *res = emb; + + for (int i = 0; i < n_layers; i++) { + auto &layer = backbone.layers[i]; + + // TODO (PAB): depth wise (groups=dim) + struct ggml_tensor *dwconv = ggml_conv_1d( + ctx0, layer.dwconv_w, res, 1 /* s0 */, 3 /* p0 */, 1 /* d0 */); + dwconv = ggml_add(ctx0, dwconv, layer.dwconv_b); + + dwconv = vocos_ada_layer_norm(ctx0, dwconv, layer.norm_scale, layer.norm_shift, bandwidth_id); + + struct ggml_tensor *pwconv1 = ggml_conv_1d( + ctx0, layer.pwconv1_w, dwconv, 1 /* s0 */, 0 /* p0 */, 1 /* d0 */); + pwconv1 = ggml_add(ctx0, pwconv1, layer.pwconv1_b); + + pwconv1 = ggml_gelu(ctx0, pwconv1); + + struct ggml_tensor *pwconv2 = ggml_conv_1d( + ctx0, layer.pwconv2_w, pwconv1, 1 /* s0 */, 0 /* p0 */, 1 /* d0 */); + pwconv2 = ggml_add(ctx0, pwconv2, layer.pwconv2_b); + + pwconv2 = ggml_mul(ctx0, pwconv2, layer.gamma); + + res = ggml_add(ctx0, res, pwconv2); + } + + struct ggml_tensor * out = ggml_norm(ctx0, res, 1e-5 /* eps */); + out = ggml_mul(ctx0, out, backbone.final_ln_w); + out = ggml_add(ctx0, out, backbone.final_ln_b); + + // head + // out = istft_head_forward(ctx0, out); + + return out; +} + +struct ggml_cgraph *vocos_build_graph( + struct vocos_context *vctx, + const std::vector codes, + const vocos_run_mode mode) { + assert(mode == vocos_run_mode::full || mode == vocos_run_mode::encode); + + const auto &model = vctx->model; + const auto &hparams = model.hparams; + const auto &allocr = vctx->allocr; + + const int n_q = 8; // TODO (PAB): hardcoded + const int T = codes.size() / n_q; + + // since we are using ggml-alloc, this buffer only needs enough space to hold the + // ggml_tensor and ggml_cgraph structs, but not the tensor data + static size_t buf_size = ggml_tensor_overhead() * GGML_MAX_NODES + ggml_graph_overhead(); + static std::vector buf(buf_size); + + struct ggml_init_params ggml_params = { + /*.mem_size =*/ buf_size, + /*.mem_buffer =*/ buf.data(), + /*.no_alloc =*/ true, // skip allocating as we use ggml_alloc to allocate exact memory requirements + }; + + struct ggml_context *ctx0 = ggml_init(ggml_params); + + struct ggml_cgraph *gf = ggml_new_graph(ctx0); + + struct ggml_tensor *inp = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, T, n_q); + ggml_allocr_alloc(allocr, inp); + + // avoid writing to tensors if we are only measuring the memory usage + if (!ggml_allocr_is_measure(allocr)) { + ggml_backend_tensor_set(inp, codes.data(), 0, codes.size() * ggml_element_size(inp)); + } + + struct ggml_tensor *encoded = vocos_forward_encoder(vctx, ctx0, inp); + // struct ggml_tensor *decoded = vocos_forward_decoder(vctx, ctx0, encoded); + + switch (mode) { + case vocos_run_mode::full: { + // ggml_build_forward_expand(gf, decoded); + } break; + case vocos_run_mode::encode: { + ggml_build_forward_expand(gf, encoded); + } break; + case vocos_run_mode::decode: { + return NULL; + } break; + default: { + fprintf(stderr, "%s: unknown run mode\n", __func__); + return NULL; + } break; + } + + ggml_free(ctx0); + + vctx->encoded = encoded; + // vctx->decoded = decoded; + + return gf; +} + +bool vocos_eval_internal( + struct vocos_context *vctx, + const std::vector codes, + const int n_threads, + const vocos_run_mode mode) { + auto &model = vctx->model; + auto &allocr = vctx->allocr; + + // reset the allocator to free all the memory allocated during the previous inference + ggml_allocr_reset(allocr); + + struct ggml_cgraph *gf = vocos_build_graph(vctx, codes, mode); + + // allocate tensors + ggml_allocr_alloc_graph(allocr, gf); + + // run the computation + if (ggml_backend_is_cpu(model.backend)) { + ggml_backend_cpu_set_n_threads(model.backend, n_threads); + } +#ifdef GGML_USE_METAL + if (ggml_backend_is_metal(model.backend)) { + ggml_backend_metal_set_n_cb(model.backend, n_threads); + } +#endif + ggml_backend_graph_compute(model.backend, gf); + + return true; +} + +std::vector get_encodec_codes(struct vocos_context *vctx, const float *raw_audio, int n_samples) { + struct encodec_context * ectx = encodec_load_model(vctx->encodec_path.c_str(), 0, 0); + if (!ectx) { + printf("%s: failed to load encodec model\n", __func__); + return std::vector(); + } + + const auto & hparams = vctx->model.hparams; + if (hparams.bandwidth_id < 0 || hparams.bandwidth_id > 4) { + printf("%s: invalid bandwidth id\n", __func__); + return std::vector(); + } + + // const float bandwidths[4] = { 1.5, 3.0, 6.0, 12.0 }; + // encodec_set_target_bandwidth(ectx, bandwidths[hparams.bandwidth_id]); + + encodec_set_target_bandwidth(ectx, 6); + + if (!encodec_compress_audio(ectx, raw_audio, n_samples, vctx->n_threads)) { + printf("%s: failed to compress audio\n", __func__); + return std::vector(); + } + + int32_t * codes_data = encodec_get_codes(ectx); + int n_codes = encodec_get_codes_size(ectx); + std::vector codes_arr(codes_data, codes_data + n_codes); + + return codes_arr; +} + +bool vocos_eval( + struct vocos_context *vctx, + const float *raw_audio, + const int n_samples, + const int n_threads, + const vocos_run_mode mode) { + const int64_t t_start_us = ggml_time_us(); + + // Encodec forward pass, shape [n_q, T] + // n_q depends on the bandwidth and the sample rate + std::vector codes = get_encodec_codes(vctx, raw_audio, n_samples); + + // allocate the compute buffer + { + // alignment required by the backend + size_t align = ggml_backend_get_alignment(vctx->model.backend); + vctx->allocr = ggml_allocr_new_measure(align); + + // create the graph for memory usage estimation + struct ggml_cgraph *gf = vocos_build_graph(vctx, codes, mode); + + // compute the required memory + size_t mem_size = ggml_allocr_alloc_graph(vctx->allocr, gf); + + // recreate the allocator with the required memory + ggml_allocr_free(vctx->allocr); + vctx->buf_compute = ggml_backend_alloc_buffer(vctx->model.backend, mem_size); + vctx->allocr = ggml_allocr_new_from_buffer(vctx->buf_compute); + + fprintf(stderr, "%s: compute buffer size: %.2f MB\n\n", __func__, mem_size / 1024.0 / 1024.0); + } + + // encodec eval + if (!vocos_eval_internal(vctx, codes, n_threads, mode)) { + fprintf(stderr, "%s: failed to run encodec eval\n", __func__); + return false; + } + + vctx->stats.t_compute_us = ggml_time_us() - t_start_us; + + return true; +} + +bool vocos_reconstruct_audio( + struct vocos_context *vctx, + const float *raw_audio, + const int n_samples, + int n_threads) { + if (raw_audio == nullptr) { + std::cerr << "Invalid raw audio buffer" << std::endl; + return false; + } + + if (!vocos_eval(vctx, raw_audio, n_samples, n_threads, vocos_run_mode::full)) { + std::cerr << "Failed to evaluate model" << std::endl; + return false; + } + + if (!vctx->decoded) { + std::cerr << "Failed to reconstruct audio" << std::endl; + return false; + } + + struct ggml_tensor *decoded = vctx->decoded; + auto &out_audio = vctx->out_audio; + + int out_length = decoded->ne[0]; + out_audio.resize(out_length); + + ggml_backend_tensor_get(decoded, out_audio.data(), 0, out_length * ggml_type_size(decoded->type)); + + return true; +} + void vocos_free(struct vocos_context *vctx) { if (!vctx) { return; @@ -484,19 +831,101 @@ void vocos_free(struct vocos_context *vctx) { delete vctx; } +void vocos_print_usage(char ** argv, const vocos_params ¶ms) { + fprintf(stderr, "usage: %s [options]\n", argv[0]); + fprintf(stderr, "\n"); + fprintf(stderr, "options:\n"); + fprintf(stderr, " -h, --help show this help message and exit\n"); + fprintf(stderr, " -t N, --threads N number of threads to use during computation (default: %d)\n", params.n_threads); + fprintf(stderr, " -b N, --bandwidth_id N Target bandwidth identifier (default: %d)\n", params.bandwidth_id); + fprintf(stderr, " -vm FNAME, --vocos_model FNAME\n"); + fprintf(stderr, " Vocos model path (default: %s)\n", params.vocos_model_path.c_str()); + fprintf(stderr, " -em FNAME, --encodec_model FNAME\n"); + fprintf(stderr, " Encodec model path (default: %s)\n", params.encodec_model_path.c_str()); + fprintf(stderr, " -i FNAME, --input FNAME\n"); + fprintf(stderr, " original audio wav (default: %s)\n", params.input_path.c_str()); + fprintf(stderr, " -o FNAME, --outwav FNAME\n"); + fprintf(stderr, " output generated wav (default: %s)\n", params.output_path.c_str()); + fprintf(stderr, "\n"); +} + +int vocos_params_parse(int argc, char ** argv, vocos_params ¶ms) { + for (int i = 1; i < argc; i++) { + std::string arg = argv[i]; + + if (arg == "-t" || arg == "--threads") { + params.n_threads = std::stoi(argv[++i]); + } else if (arg == "-b" || arg == "--bandwidth_id") { + params.bandwidth_id = std::stoi(argv[++i]); + } else if (arg == "-vm" || arg == "--vocos_model") { + params.vocos_model_path = argv[++i]; + } else if (arg == "-em" || arg == "--encodec_model") { + params.encodec_model_path = argv[++i]; + } else if (arg == "-o" || arg == "--outwav") { + params.output_path = argv[++i]; + } else if (arg == "-i" || arg == "--input") { + params.input_path = argv[++i]; + } else if (arg == "-h" || arg == "--help") { + vocos_print_usage(argv, params); + exit(0); + } else { + fprintf(stderr, "error: unknown argument: %s\n", arg.c_str()); + vocos_print_usage(argv, params); + exit(0); + } + } + + return 0; +} + int main(int argc, char **argv) { - if (argc < 2) { - std::cerr << "Usage: " << argv[0] << " " << std::endl; + ggml_time_init(); + const int64_t t_main_start_us = ggml_time_us(); + + vocos_params params; + + if (vocos_params_parse(argc, argv, params) > 0) { + fprintf(stderr, "%s: Could not parse arguments\n", __func__); return 1; } - std::string model_path = argv[1]; - - struct vocos_context *vctx = vocos_load_model(model_path); + struct vocos_context *vctx = vocos_load_model(params.vocos_model_path); if (!vctx) { std::cerr << "Failed to load model" << std::endl; return 1; } + vctx->encodec_path = params.encodec_model_path; + vctx->model.hparams.bandwidth_id = params.bandwidth_id; + + // read audio from disk + std::vector original_audio_arr; + if (!read_wav_from_disk(params.input_path, original_audio_arr)) { + std::cerr << "Failed to read audio from disk" << std::endl; + return 1; + } + + // reconstruct audio + if (!vocos_reconstruct_audio(vctx, original_audio_arr.data(), original_audio_arr.size(), params.n_threads)) { + std::cerr << "Failed to reconstruct audio" << std::endl; + return 1; + } + + // write reconstructed audio on disk + float * audio_data = vctx->out_audio.data(); + std::vector audio_arr(audio_data, audio_data + vctx->out_audio.size()); + audio_arr.resize(original_audio_arr.size()); + write_wav_on_disk(audio_arr, params.output_path); + + // report timing + { + const int64_t t_main_end_us = ggml_time_us(); + const vocos_statistics * stats = vocos_get_statistics(vctx); + + printf("\n\n"); + printf("%s: load time = %8.2f ms\n", __func__, stats->t_load_us/1000.0f); + printf("%s: eval time = %8.2f ms\n", __func__, stats->t_compute_us/1000.0f); + printf("%s: total time = %8.2f ms\n", __func__, (t_main_end_us - t_main_start_us)/1000.0f); + } vocos_free(vctx); From 55358ea70f26932170f103363c8e57643e62a3b6 Mon Sep 17 00:00:00 2001 From: Pierre-Antoine Bannier Date: Tue, 2 Jul 2024 17:37:04 +0200 Subject: [PATCH 6/7] forward pass --- examples/vocos/vocos.cpp | 152 +++++++++++++++++++-------------------- 1 file changed, 75 insertions(+), 77 deletions(-) diff --git a/examples/vocos/vocos.cpp b/examples/vocos/vocos.cpp index 18e9adf..27dec82 100644 --- a/examples/vocos/vocos.cpp +++ b/examples/vocos/vocos.cpp @@ -25,6 +25,10 @@ Author: Pierre-Antoine Bannier static const size_t MB = 1024 * 1024; +static void print_tensor(struct ggml_tensor *t) { + printf("tensor %s: %lld %lld %lld\n", t->name, t->ne[0], t->ne[1], t->ne[2]); +} + struct vocos_params { // Number of threads used for inference int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency()); @@ -146,7 +150,7 @@ struct vocos_context { struct ggml_tensor * encoded = NULL; struct ggml_tensor * decoded = NULL; - std::vector out_codes; + std::vector features; std::vector out_audio; // statistics @@ -517,37 +521,43 @@ struct ggml_tensor *vocos_ada_layer_norm( } struct ggml_tensor *vocos_forward_encoder( - struct vocos_context *vctx, - struct ggml_context *ctx0, - struct ggml_tensor *inp) { - if (!inp) { - fprintf(stderr, "%s: invalid input tensor\n", __func__); + struct vocos_context * vctx, + struct ggml_context * ctx0, + struct ggml_tensor * codes) { + + if (!codes) { + fprintf(stderr, "%s: invalid codes tensor\n", __func__); return nullptr; } - const int T = inp->ne[0]; - const int n_q = inp->ne[1]; + const auto & model = vctx->model.feature_extractor; + const auto & allocr = vctx->allocr; - const int n_bins = 1024; // TODO (PAB): hardcoded + const int seq_length = codes->ne[0]; + const int n_q = codes->ne[1]; + const int dim = model.codebook_weights->ne[0]; - const auto &model = vctx->model.feature_extractor; - const auto &allocr = vctx->allocr; + // codes: [seq_length, n_q] -> [n_q, seq_length] + codes = ggml_transpose(ctx0, codes); - // offsets: [n_q] - struct ggml_tensor *offsets = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_q); - if (!ggml_allocr_is_measure(allocr)) { - for (int32_t i = 0; i < n_q; i++) { - int32_t v = i * n_bins; - ggml_backend_tensor_set(offsets, &v, i * sizeof(int32_t), sizeof(i)); - } + struct ggml_tensor *features = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, dim, n_q, seq_length); + ggml_allocr_alloc(allocr, features); + + for (int t = 0; t < seq_length; t++) { + // [n_q] + size_t offset = t * codes->nb[1]; + struct ggml_tensor *idxs = ggml_view_1d(ctx0, codes, n_q, offset); + + // [dim, n_q] + struct ggml_tensor *f_t = ggml_get_rows(ctx0, model.codebook_weights, idxs); + + features = ggml_set_2d(ctx0, features, f_t, features->nb[1], t*features->nb[2]); } - // inp: [n_bins, n_q] - // embeddings_idxs: [n_q, n_bins] - struct ggml_tensor *embeddings_idxs = ggml_add(ctx0, inp, offsets); - // [n_q, n_bins, dim] - struct ggml_tensor *features = ggml_get_rows(ctx0, model.codebook_weights, embeddings_idxs); - // [n_bins, dim] + // [dim, n_q, seq_length] -> [n_q, dim, seq_length] + features = ggml_cont(ctx0, ggml_permute(ctx0, features, 1, 0, 2, 3)); + + // [1, dim, seq_length] features = ggml_sum_rows(ctx0, features); return features; @@ -621,12 +631,13 @@ struct ggml_cgraph *vocos_build_graph( const vocos_run_mode mode) { assert(mode == vocos_run_mode::full || mode == vocos_run_mode::encode); - const auto &model = vctx->model; - const auto &hparams = model.hparams; - const auto &allocr = vctx->allocr; + const auto & model = vctx->model; + const auto & hparams = model.hparams; + const auto & allocr = vctx->allocr; - const int n_q = 8; // TODO (PAB): hardcoded - const int T = codes.size() / n_q; + const int n_q = 8; // TODO (PAB): hardcoded + const int n_bins = 1024; // TODO (PAB): hardcoded + const int seq_length = codes.size() / n_q; // since we are using ggml-alloc, this buffer only needs enough space to hold the // ggml_tensor and ggml_cgraph structs, but not the tensor data @@ -636,19 +647,25 @@ struct ggml_cgraph *vocos_build_graph( struct ggml_init_params ggml_params = { /*.mem_size =*/ buf_size, /*.mem_buffer =*/ buf.data(), - /*.no_alloc =*/ true, // skip allocating as we use ggml_alloc to allocate exact memory requirements + /*.no_alloc =*/ true, }; struct ggml_context *ctx0 = ggml_init(ggml_params); struct ggml_cgraph *gf = ggml_new_graph(ctx0); - struct ggml_tensor *inp = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, T, n_q); + struct ggml_tensor *inp = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, seq_length, n_q); ggml_allocr_alloc(allocr, inp); - // avoid writing to tensors if we are only measuring the memory usage if (!ggml_allocr_is_measure(allocr)) { - ggml_backend_tensor_set(inp, codes.data(), 0, codes.size() * ggml_element_size(inp)); + ggml_backend_tensor_set(inp, codes.data(), 0, codes.size() * sizeof(int32_t)); + + // add offsets of shape [n_q] broadcasted to inp + // inp + offsets + for (int i = 0; i < seq_length*n_q; i++) { + int32_t v = (i / seq_length) * n_bins; + ggml_backend_tensor_set(inp, &v, i * sizeof(int32_t), sizeof(int32_t)); + } } struct ggml_tensor *encoded = vocos_forward_encoder(vctx, ctx0, inp); @@ -698,17 +715,13 @@ bool vocos_eval_internal( if (ggml_backend_is_cpu(model.backend)) { ggml_backend_cpu_set_n_threads(model.backend, n_threads); } -#ifdef GGML_USE_METAL - if (ggml_backend_is_metal(model.backend)) { - ggml_backend_metal_set_n_cb(model.backend, n_threads); - } -#endif + ggml_backend_graph_compute(model.backend, gf); return true; } -std::vector get_encodec_codes(struct vocos_context *vctx, const float *raw_audio, int n_samples) { +std::vector get_encodec_codes(struct vocos_context *vctx, const std::vector raw_audio) { struct encodec_context * ectx = encodec_load_model(vctx->encodec_path.c_str(), 0, 0); if (!ectx) { printf("%s: failed to load encodec model\n", __func__); @@ -721,12 +734,9 @@ std::vector get_encodec_codes(struct vocos_context *vctx, const float * return std::vector(); } - // const float bandwidths[4] = { 1.5, 3.0, 6.0, 12.0 }; - // encodec_set_target_bandwidth(ectx, bandwidths[hparams.bandwidth_id]); - encodec_set_target_bandwidth(ectx, 6); - if (!encodec_compress_audio(ectx, raw_audio, n_samples, vctx->n_threads)) { + if (!encodec_compress_audio(ectx, raw_audio.data(), raw_audio.size(), vctx->n_threads)) { printf("%s: failed to compress audio\n", __func__); return std::vector(); } @@ -740,15 +750,14 @@ std::vector get_encodec_codes(struct vocos_context *vctx, const float * bool vocos_eval( struct vocos_context *vctx, - const float *raw_audio, - const int n_samples, + const std::vector raw_audio, const int n_threads, const vocos_run_mode mode) { const int64_t t_start_us = ggml_time_us(); // Encodec forward pass, shape [n_q, T] // n_q depends on the bandwidth and the sample rate - std::vector codes = get_encodec_codes(vctx, raw_audio, n_samples); + std::vector codes = get_encodec_codes(vctx, raw_audio); // allocate the compute buffer { @@ -770,45 +779,38 @@ bool vocos_eval( fprintf(stderr, "%s: compute buffer size: %.2f MB\n\n", __func__, mem_size / 1024.0 / 1024.0); } - // encodec eval if (!vocos_eval_internal(vctx, codes, n_threads, mode)) { fprintf(stderr, "%s: failed to run encodec eval\n", __func__); return false; } + auto &features = vctx->features; + + int out_length = ggml_nelements(vctx->encoded); + features.resize(out_length); + + ggml_backend_tensor_get(vctx->encoded, features.data(), 0, out_length * ggml_type_size(vctx->encoded->type)); + + float sum = 0.0f; + for (int i = 0; i < out_length; i++) { + sum += features[i]; + } + printf("%s: sum = %f\n", __func__, sum); + vctx->stats.t_compute_us = ggml_time_us() - t_start_us; return true; } bool vocos_reconstruct_audio( - struct vocos_context *vctx, - const float *raw_audio, - const int n_samples, - int n_threads) { - if (raw_audio == nullptr) { - std::cerr << "Invalid raw audio buffer" << std::endl; - return false; - } - - if (!vocos_eval(vctx, raw_audio, n_samples, n_threads, vocos_run_mode::full)) { + struct vocos_context *vctx, + const std::vector raw_audio, + int n_threads) { + if (!vocos_eval(vctx, raw_audio, n_threads, vocos_run_mode::encode)) { std::cerr << "Failed to evaluate model" << std::endl; return false; } - if (!vctx->decoded) { - std::cerr << "Failed to reconstruct audio" << std::endl; - return false; - } - - struct ggml_tensor *decoded = vctx->decoded; - auto &out_audio = vctx->out_audio; - - int out_length = decoded->ne[0]; - out_audio.resize(out_length); - - ggml_backend_tensor_get(decoded, out_audio.data(), 0, out_length * ggml_type_size(decoded->type)); - return true; } @@ -904,18 +906,14 @@ int main(int argc, char **argv) { return 1; } + original_audio_arr.resize(50000); + // reconstruct audio - if (!vocos_reconstruct_audio(vctx, original_audio_arr.data(), original_audio_arr.size(), params.n_threads)) { + if (!vocos_reconstruct_audio(vctx, original_audio_arr, params.n_threads)) { std::cerr << "Failed to reconstruct audio" << std::endl; return 1; } - // write reconstructed audio on disk - float * audio_data = vctx->out_audio.data(); - std::vector audio_arr(audio_data, audio_data + vctx->out_audio.size()); - audio_arr.resize(original_audio_arr.size()); - write_wav_on_disk(audio_arr, params.output_path); - // report timing { const int64_t t_main_end_us = ggml_time_us(); From 9165b9d25df4bcabbdffaf28eb9df942ee907e16 Mon Sep 17 00:00:00 2001 From: Pierre-Antoine Bannier Date: Sun, 13 Oct 2024 21:14:29 +0200 Subject: [PATCH 7/7] wip --- examples/vocos/vocos.cpp | 61 ++++++++++++++++++++++------------------ 1 file changed, 33 insertions(+), 28 deletions(-) diff --git a/examples/vocos/vocos.cpp b/examples/vocos/vocos.cpp index 27dec82..fa0ef29 100644 --- a/examples/vocos/vocos.cpp +++ b/examples/vocos/vocos.cpp @@ -147,11 +147,11 @@ struct vocos_context { struct ggml_allocr * allocr = NULL; // intermediate steps - struct ggml_tensor * encoded = NULL; - struct ggml_tensor * decoded = NULL; + struct ggml_tensor * features_t = NULL; + struct ggml_tensor * out_audio_t = NULL; - std::vector features; - std::vector out_audio; + std::vector features ; + std::vector out_audio; // statistics struct vocos_statistics stats; @@ -582,10 +582,14 @@ struct ggml_tensor *vocos_forward_decoder( // backbone + // [dim, seq_length] struct ggml_tensor *emb = ggml_conv_1d( ctx0, backbone.embed_w, encoded, 1 /* s0 */, 3 /* p0 */, 1 /* d0 */); + print_tensor(emb); + print_tensor(backbone.embed_b); emb = ggml_add(ctx0, emb, backbone.embed_b); + // [dim, seq_length] emb = vocos_ada_layer_norm(ctx0, emb, backbone.norm_scale, backbone.norm_shift, bandwidth_id); struct ggml_tensor *res = emb; @@ -593,6 +597,7 @@ struct ggml_tensor *vocos_forward_decoder( for (int i = 0; i < n_layers; i++) { auto &layer = backbone.layers[i]; + // [dim, seq_length] // TODO (PAB): depth wise (groups=dim) struct ggml_tensor *dwconv = ggml_conv_1d( ctx0, layer.dwconv_w, res, 1 /* s0 */, 3 /* p0 */, 1 /* d0 */); @@ -600,18 +605,20 @@ struct ggml_tensor *vocos_forward_decoder( dwconv = vocos_ada_layer_norm(ctx0, dwconv, layer.norm_scale, layer.norm_shift, bandwidth_id); - struct ggml_tensor *pwconv1 = ggml_conv_1d( - ctx0, layer.pwconv1_w, dwconv, 1 /* s0 */, 0 /* p0 */, 1 /* d0 */); + // [intermediate_dim, seq_length] + struct ggml_tensor * pwconv1 = ggml_mul_mat(ctx0, layer.pwconv1_w, dwconv); pwconv1 = ggml_add(ctx0, pwconv1, layer.pwconv1_b); pwconv1 = ggml_gelu(ctx0, pwconv1); - struct ggml_tensor *pwconv2 = ggml_conv_1d( - ctx0, layer.pwconv2_w, pwconv1, 1 /* s0 */, 0 /* p0 */, 1 /* d0 */); + // [dim, seq_length] + struct ggml_tensor *pwconv2 = ggml_mul_mat(ctx0, layer.pwconv2_w, pwconv1); pwconv2 = ggml_add(ctx0, pwconv2, layer.pwconv2_b); + // [dim, seq_length] pwconv2 = ggml_mul(ctx0, pwconv2, layer.gamma); + // [dim, seq_length], residual connection res = ggml_add(ctx0, res, pwconv2); } @@ -635,8 +642,8 @@ struct ggml_cgraph *vocos_build_graph( const auto & hparams = model.hparams; const auto & allocr = vctx->allocr; - const int n_q = 8; // TODO (PAB): hardcoded - const int n_bins = 1024; // TODO (PAB): hardcoded + const int n_q = 8; // TODO (PAB): hardcoded + const int n_bins = 1024; // TODO (PAB): hardcoded const int seq_length = codes.size() / n_q; // since we are using ggml-alloc, this buffer only needs enough space to hold the @@ -657,29 +664,35 @@ struct ggml_cgraph *vocos_build_graph( struct ggml_tensor *inp = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, seq_length, n_q); ggml_allocr_alloc(allocr, inp); + struct ggml_tensor *bandwidth_id = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, 1); + ggml_allocr_alloc(allocr, bandwidth_id); + if (!ggml_allocr_is_measure(allocr)) { ggml_backend_tensor_set(inp, codes.data(), 0, codes.size() * sizeof(int32_t)); // add offsets of shape [n_q] broadcasted to inp // inp + offsets + // TODO: can we ensure i / seq_length is floored? for (int i = 0; i < seq_length*n_q; i++) { int32_t v = (i / seq_length) * n_bins; ggml_backend_tensor_set(inp, &v, i * sizeof(int32_t), sizeof(int32_t)); } + + ggml_backend_tensor_set(bandwidth_id, &hparams.bandwidth_id, 0, sizeof(int32_t)); } - struct ggml_tensor *encoded = vocos_forward_encoder(vctx, ctx0, inp); - // struct ggml_tensor *decoded = vocos_forward_decoder(vctx, ctx0, encoded); + struct ggml_tensor * encoded = vocos_forward_encoder(vctx, ctx0, inp); + struct ggml_tensor * decoded = vocos_forward_decoder(vctx, ctx0, encoded, bandwidth_id); switch (mode) { case vocos_run_mode::full: { - // ggml_build_forward_expand(gf, decoded); + ggml_build_forward_expand(gf, decoded); } break; case vocos_run_mode::encode: { ggml_build_forward_expand(gf, encoded); } break; case vocos_run_mode::decode: { - return NULL; + ggml_build_forward_expand(gf, decoded); } break; default: { fprintf(stderr, "%s: unknown run mode\n", __func__); @@ -689,8 +702,8 @@ struct ggml_cgraph *vocos_build_graph( ggml_free(ctx0); - vctx->encoded = encoded; - // vctx->decoded = decoded; + vctx->features_t = encoded; + vctx->out_audio_t = decoded; return gf; } @@ -784,18 +797,10 @@ bool vocos_eval( return false; } - auto &features = vctx->features; + int32_t n_features = ggml_nelements(vctx->features_t); - int out_length = ggml_nelements(vctx->encoded); - features.resize(out_length); - - ggml_backend_tensor_get(vctx->encoded, features.data(), 0, out_length * ggml_type_size(vctx->encoded->type)); - - float sum = 0.0f; - for (int i = 0; i < out_length; i++) { - sum += features[i]; - } - printf("%s: sum = %f\n", __func__, sum); + vctx->features.resize(n_features); + ggml_backend_tensor_get(vctx->features_t, vctx->features.data(), 0, n_features * sizeof(float)); vctx->stats.t_compute_us = ggml_time_us() - t_start_us; @@ -806,7 +811,7 @@ bool vocos_reconstruct_audio( struct vocos_context *vctx, const std::vector raw_audio, int n_threads) { - if (!vocos_eval(vctx, raw_audio, n_threads, vocos_run_mode::encode)) { + if (!vocos_eval(vctx, raw_audio, n_threads, vocos_run_mode::full)) { std::cerr << "Failed to evaluate model" << std::endl; return false; }