diff --git a/examples/common/common.hpp b/examples/common/common.hpp index bf38379d2..b9ac7edc1 100644 --- a/examples/common/common.hpp +++ b/examples/common/common.hpp @@ -863,7 +863,7 @@ static bool is_absolute_path(const std::string& p) { struct SDGenerationParams { std::string prompt; - std::string prompt_with_lora; // for metadata record only + std::string prompt_with_lora; // for metadata record only std::string negative_prompt; int clip_skip = -1; // <= 0 represents unspecified int width = 512; diff --git a/flux.hpp b/flux.hpp index 1df2874ae..ff364ece4 100644 --- a/flux.hpp +++ b/flux.hpp @@ -233,14 +233,17 @@ namespace Flux { __STATIC_INLINE__ struct ggml_tensor* modulate(struct ggml_context* ctx, struct ggml_tensor* x, struct ggml_tensor* shift, - struct ggml_tensor* scale) { + struct ggml_tensor* scale, + bool skip_reshape = false) { // x: [N, L, C] // scale: [N, C] // shift: [N, C] - scale = ggml_reshape_3d(ctx, scale, scale->ne[0], 1, scale->ne[1]); // [N, 1, C] - shift = ggml_reshape_3d(ctx, shift, shift->ne[0], 1, shift->ne[1]); // [N, 1, C] - x = ggml_add(ctx, x, ggml_mul(ctx, x, scale)); - x = ggml_add(ctx, x, shift); + if (!skip_reshape) { + scale = ggml_reshape_3d(ctx, scale, scale->ne[0], 1, scale->ne[1]); // [N, 1, C] + shift = ggml_reshape_3d(ctx, shift, shift->ne[0], 1, shift->ne[1]); // [N, 1, C] + } + x = ggml_add(ctx, x, ggml_mul(ctx, x, scale)); + x = ggml_add(ctx, x, shift); return x; } diff --git a/qwen_image.hpp b/qwen_image.hpp index eeb823d50..4952e5438 100644 --- a/qwen_image.hpp +++ b/qwen_image.hpp @@ -191,11 +191,16 @@ namespace Qwen { }; class QwenImageTransformerBlock : public GGMLBlock { + protected: + bool zero_cond_t; + public: QwenImageTransformerBlock(int64_t dim, int64_t num_attention_heads, int64_t attention_head_dim, - float eps = 1e-6) { + float eps = 1e-6, + bool zero_cond_t = false) + : zero_cond_t(zero_cond_t) { // img_mod.0 is nn.SiLU() blocks["img_mod.1"] = std::shared_ptr<GGMLBlock>(new Linear(dim, 6 * dim, true)); @@ 
-220,11 +225,36 @@ namespace Qwen { eps)); } + std::vector<ggml_tensor*> get_mod_params_vec(ggml_context* ctx, ggml_tensor* mod_params, ggml_tensor* index = nullptr) { + // index: [N, n_img_token] + // mod_params: [N, hidden_size * 12] + if (index == nullptr) { + return ggml_ext_chunk(ctx, mod_params, 6, 0); + } + auto mod_params_vec = ggml_ext_chunk(ctx, mod_params, 12, 0); + index = ggml_reshape_3d(ctx, index, 1, index->ne[0], index->ne[1]); // [N, n_img_token, 1] + index = ggml_repeat_4d(ctx, index, mod_params[0].ne[0], index->ne[0], index->ne[1], index->ne[2]); // [N, n_img_token, hidden_size] + std::vector<ggml_tensor*> mod_results; + for (int i = 0; i < 6; i++) { + auto mod_0 = mod_params_vec[2 * i]; + auto mod_1 = mod_params_vec[2 * i + 1]; + + // mod_result = torch.where(index == 0, mod_0, mod_1) + // mod_result = (1 - index)*mod_0 + index*mod_1 + mod_0 = ggml_sub(ctx, ggml_repeat(ctx, mod_0, index), ggml_mul(ctx, index, mod_0)); // [N, n_img_token, hidden_size] + mod_1 = ggml_mul(ctx, index, mod_1); // [N, n_img_token, hidden_size] + auto mod_result = ggml_add(ctx, mod_0, mod_1); + mod_results.push_back(mod_result); + } + return mod_results; + } + virtual std::pair<struct ggml_tensor*, struct ggml_tensor*> forward(GGMLRunnerContext* ctx, struct ggml_tensor* img, struct ggml_tensor* txt, struct ggml_tensor* t_emb, - struct ggml_tensor* pe) { + struct ggml_tensor* pe, + struct ggml_tensor* modulate_index = nullptr) { // img: [N, n_img_token, hidden_size] // txt: [N, n_txt_token, hidden_size] // pe: [n_img_token + n_txt_token, d_head/2, 2, 2] @@ -244,14 +274,18 @@ auto img_mod_params = ggml_silu(ctx->ggml_ctx, t_emb); img_mod_params = img_mod_1->forward(ctx, img_mod_params); - auto img_mod_param_vec = ggml_ext_chunk(ctx->ggml_ctx, img_mod_params, 6, 0); + auto img_mod_param_vec = get_mod_params_vec(ctx->ggml_ctx, img_mod_params, modulate_index); + + if (zero_cond_t) { + t_emb = ggml_ext_chunk(ctx->ggml_ctx, t_emb, 2, 0)[0]; + } auto txt_mod_params = ggml_silu(ctx->ggml_ctx, t_emb); txt_mod_params = 
txt_mod_1->forward(ctx, txt_mod_params); - auto txt_mod_param_vec = ggml_ext_chunk(ctx->ggml_ctx, txt_mod_params, 6, 0); + auto txt_mod_param_vec = get_mod_params_vec(ctx->ggml_ctx, txt_mod_params); auto img_normed = img_norm1->forward(ctx, img); - auto img_modulated = Flux::modulate(ctx->ggml_ctx, img_normed, img_mod_param_vec[0], img_mod_param_vec[1]); + auto img_modulated = Flux::modulate(ctx->ggml_ctx, img_normed, img_mod_param_vec[0], img_mod_param_vec[1], modulate_index != nullptr); auto img_gate1 = img_mod_param_vec[2]; auto txt_normed = txt_norm1->forward(ctx, txt); @@ -264,7 +298,7 @@ txt = ggml_add(ctx->ggml_ctx, txt, ggml_mul(ctx->ggml_ctx, txt_attn_output, txt_gate1)); auto img_normed2 = img_norm2->forward(ctx, img); - auto img_modulated2 = Flux::modulate(ctx->ggml_ctx, img_normed2, img_mod_param_vec[3], img_mod_param_vec[4]); + auto img_modulated2 = Flux::modulate(ctx->ggml_ctx, img_normed2, img_mod_param_vec[3], img_mod_param_vec[4], modulate_index != nullptr); auto img_gate2 = img_mod_param_vec[5]; auto txt_normed2 = txt_norm2->forward(ctx, txt); @@ -325,6 +359,7 @@ namespace Qwen { float theta = 10000; std::vector<int> axes_dim = {16, 56, 56}; int64_t axes_dim_sum = 128; + bool zero_cond_t = false; }; class QwenImageModel : public GGMLBlock { @@ -346,7 +381,8 @@ auto block = std::shared_ptr<GGMLBlock>(new QwenImageTransformerBlock(inner_dim, params.num_attention_heads, params.attention_head_dim, - 1e-6f)); + 1e-6f, + params.zero_cond_t)); blocks["transformer_blocks." 
+ std::to_string(i)] = block; } @@ -421,7 +457,8 @@ struct ggml_tensor* x, struct ggml_tensor* timestep, struct ggml_tensor* context, - struct ggml_tensor* pe) { + struct ggml_tensor* pe, + struct ggml_tensor* modulate_index = nullptr) { auto time_text_embed = std::dynamic_pointer_cast(blocks["time_text_embed"]); auto txt_norm = std::dynamic_pointer_cast(blocks["txt_norm"]); auto img_in = std::dynamic_pointer_cast(blocks["img_in"]); @@ -429,6 +466,10 @@ auto norm_out = std::dynamic_pointer_cast(blocks["norm_out"]); auto proj_out = std::dynamic_pointer_cast(blocks["proj_out"]); + if (params.zero_cond_t) { + timestep = ggml_concat(ctx->ggml_ctx, timestep, ggml_scale(ctx->ggml_ctx, timestep, 0.f), 0); + } + auto t_emb = time_text_embed->forward(ctx, timestep); auto img = img_in->forward(ctx, x); auto txt = txt_norm->forward(ctx, context); @@ -437,7 +478,7 @@ for (int i = 0; i < params.num_layers; i++) { auto block = std::dynamic_pointer_cast<QwenImageTransformerBlock>(blocks["transformer_blocks." + std::to_string(i)]); - auto result = block->forward(ctx, img, txt, t_emb, pe); + auto result = block->forward(ctx, img, txt, t_emb, pe, modulate_index); img = result.first; txt = result.second; } @@ -453,7 +494,8 @@ struct ggml_tensor* timestep, struct ggml_tensor* context, struct ggml_tensor* pe, - std::vector<ggml_tensor*> ref_latents = {}) { + std::vector<ggml_tensor*> ref_latents = {}, + struct ggml_tensor* modulate_index = nullptr) { // Forward pass of DiT. 
// x: [N, C, H, W] // timestep: [N,] @@ -479,7 +521,7 @@ int64_t h_len = ((H + (params.patch_size / 2)) / params.patch_size); int64_t w_len = ((W + (params.patch_size / 2)) / params.patch_size); - auto out = forward_orig(ctx, img, timestep, context, pe); // [N, h_len*w_len, ph*pw*C] + auto out = forward_orig(ctx, img, timestep, context, pe, modulate_index); // [N, h_len*w_len, ph*pw*C] if (out->ne[1] > img_tokens) { out = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, out, 0, 2, 1, 3)); // [num_tokens, N, C * patch_size * patch_size] @@ -502,6 +544,7 @@ QwenImageParams qwen_image_params; QwenImageModel qwen_image; std::vector<float> pe_vec; + std::vector<float> modulate_index_vec; SDVersion version; QwenImageRunner(ggml_backend_t backend, @@ -574,6 +617,31 @@ // pe->data = nullptr; set_backend_tensor_data(pe, pe_vec.data()); + ggml_tensor* modulate_index = nullptr; + if (qwen_image_params.zero_cond_t) { + modulate_index_vec.clear(); + + int64_t h_len = ((x->ne[1] + (qwen_image_params.patch_size / 2)) / qwen_image_params.patch_size); + int64_t w_len = ((x->ne[0] + (qwen_image_params.patch_size / 2)) / qwen_image_params.patch_size); + int64_t num_img_tokens = h_len * w_len; + + modulate_index_vec.insert(modulate_index_vec.end(), num_img_tokens, 0.f); + int64_t num_ref_img_tokens = 0; + for (ggml_tensor* ref : ref_latents) { + int64_t h_len = ((ref->ne[1] + (qwen_image_params.patch_size / 2)) / qwen_image_params.patch_size); + int64_t w_len = ((ref->ne[0] + (qwen_image_params.patch_size / 2)) / qwen_image_params.patch_size); + + num_ref_img_tokens += h_len * w_len; + } + + if (num_ref_img_tokens > 0) { + modulate_index_vec.insert(modulate_index_vec.end(), num_ref_img_tokens, 1.f); + } + + modulate_index = vector_to_ggml_tensor(compute_ctx, modulate_index_vec); + modulate_index = to_backend(modulate_index); + } + auto runner_ctx = get_context(); struct ggml_tensor* out = qwen_image.forward(&runner_ctx,