diff --git a/requirements.txt b/requirements.txt
index ced32b9b34b..bcc7a984bac 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -29,4 +29,4 @@ pybind/pybind11@3e9dfa2866941655c56877882565e7577de6fc7b --build
 msgpack/msgpack-c@cpp-3.3.0 -DMSGPACK_BUILD_TESTS=Off -DMSGPACK_BUILD_EXAMPLES=Off -DCMAKE_POLICY_VERSION_MINIMUM=3.5
 sqlite3@3.50.4 -DCMAKE_POSITION_INDEPENDENT_CODE=On
 ROCm/composable_kernel@b7775add2d28251674d81e220cd4a857b90b997a -DCK_BUILD_JIT_LIB=On -DCMAKE_POSITION_INDEPENDENT_CODE=On
-ROCm/rocMLIR@16ccb523bb3d8af67796c12ab9c15c9b69d69b58 -DBUILD_FAT_LIBROCKCOMPILER=On -DLLVM_INCLUDE_TESTS=Off
+ROCm/rocMLIR@5878fd6a5f4ec92101f9cd3d279e6240086cb3b4 -DBUILD_FAT_LIBROCKCOMPILER=On -DLLVM_INCLUDE_TESTS=Off
diff --git a/src/fuse_attention.cpp b/src/fuse_attention.cpp
index 8912b186cdb..d403e4154cf 100644
--- a/src/fuse_attention.cpp
+++ b/src/fuse_attention.cpp
@@ -29,6 +29,7 @@
 #include
 #include
 #include
+#include
 #include
 #include
 #include
@@ -137,6 +138,34 @@ inline auto pointwise_inputs()
     };
 }
 
+// Helper function to extract the two gemm operations from an attention submodule
+// Returns {gemm1, gemm2} where gemm1 is Q@K and gemm2 is P@V
+inline std::pair<instruction_ref, instruction_ref> get_attention_gemms(module_ref submod)
+{
+    std::vector<instruction_ref> gemms;
+    for(auto it = submod->begin(); it != submod->end(); ++it)
+    {
+        if(it->name() == "dot")
+            gemms.push_back(it);
+    }
+    assert(gemms.size() == 2 and "Expected exactly 2 gemm operations in attention submodule");
+
+    // gemms[0] is Q@K, gemms[1] is P@V
+    // gemms are in order since we iterate from begin to end
+    return {gemms[0], gemms[1]};
+}
+
+// Helper function to map submodule parameters to main module inputs
+inline std::unordered_map<instruction_ref, instruction_ref>
+map_submod_params_to_inputs(module_ref submod, const std::vector<instruction_ref>& group_inputs)
+{
+    auto map_param_to_main = submod->get_ins_param_map(group_inputs, true);
+    // verify the mapping is correct
+    auto expected_inputs = submod->get_inputs(map_param_to_main);
+    assert(expected_inputs == group_inputs and "Mapped inputs don't match group inputs");
+    return map_param_to_main;
+}
+
 struct find_attention
 {
     std::size_t* counter;
@@ -288,6 +317,733 @@ struct find_attention
     }
 };
 
+struct find_gqa_flash_decoding
+{
+    std::size_t groups;
+
+    // Struct to hold all attention dimensions
+    struct attention_dims
+    {
+        std::size_t batch_size;
+        std::size_t num_heads;    // Q heads
+        std::size_t kv_heads;     // K and V heads
+        std::size_t concat_heads; // total heads in QKV tensor
+        std::size_t sequence_length;
+        std::size_t max_seq_length;
+        std::size_t head_dim;
+        std::size_t seq_length_per_group; // sequence length per group after splitting max sequence length
+
+        // constructor from parameters
+        attention_dims(instruction_ref q_param, instruction_ref k_param, std::size_t num_groups)
+        {
+            auto q_shape = q_param->get_shape();
+            auto k_shape = k_param->get_shape();
+
+            batch_size      = q_shape.lens()[0];
+            concat_heads    = q_shape.lens()[1];
+            sequence_length = q_shape.lens()[2];
+            head_dim        = q_shape.lens()[3];
+
+            kv_heads       = k_shape.lens()[1];
+            max_seq_length = k_shape.lens()[2];
+
+            // calculate Q heads from concat_heads = num_heads + 2 * kv_heads
+            num_heads = concat_heads - 2 * kv_heads;
+
+            // calculate sequence length per group
+            if(max_seq_length % num_groups != 0)
+            {
+                std::cout << "Max sequence length " << max_seq_length << " not divisible by "
+                          << num_groups << " groups" << std::endl;
+                // TODO: add autosplitting (padding won't be needed)
+                seq_length_per_group = 0;
+            }
+            else
+            {
+                seq_length_per_group = max_seq_length / num_groups;
+            }
+        }
+    };
+
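    // Worked example (editorial illustration only, hypothetical shapes): for
    // Q = {2, 40, 1, 64}, K = {2, 8, 4096, 64} and groups = 4:
    //   concat_heads = 40, kv_heads = 8  ->  num_heads = 40 - 2 * 8 = 24
    //   max_seq_length = 4096            ->  seq_length_per_group = 4096 / 4 = 1024
    // A max_seq_length that is not divisible by the group count currently disables
    // the rewrite (seq_length_per_group = 0), per the TODO above.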
// Helper function: adjust transpose permutation when inserting a group dimension + // insert_pos_from_end: position where group dimension is inserted in ORIGINAL tensor + // (negative values count from end, e.g., -2 means second-to-last) + // original_perm: the original permutation vector + // Returns: adjusted permutation vector + std::vector adjust_transpose_perm(const std::vector& original_perm, + int insert_pos_from_end) const + { + // Convert negative position to actual position in the ORIGINAL tensor + // For a 4D tensor, insert_pos_from_end=-2 means position 2 (0-indexed) + size_t original_rank = original_perm.size(); + int actual_insert_pos = insert_pos_from_end; + if(insert_pos_from_end < 0) { + actual_insert_pos = static_cast(original_rank) + insert_pos_from_end; + } + + // Adjust permutation values: any dimension >= insert_pos shifts up by 1 + std::vector new_perm; + for(auto idx : original_perm) { + if(idx >= actual_insert_pos) { + new_perm.push_back(idx + 1); + } else { + new_perm.push_back(idx); + } + } + + // Insert the group dimension at its natural position in the output + // The group dimension itself appears at position actual_insert_pos + new_perm.insert(new_perm.begin() + actual_insert_pos, actual_insert_pos); + + return new_perm; + } + + // Helper function: adjust axes when a group dimension is inserted + // axes: original axes vector + // group_dim_pos: position where group dimension is inserted (in ORIGINAL tensor coordinates) + // Returns: adjusted axes + std::vector adjust_axes(const std::vector& axes, int group_dim_pos) const + { + std::vector adjusted; + for(auto axis : axes) { + // If axis >= group_dim_pos, shift it by 1 + if(axis >= group_dim_pos) { + adjusted.push_back(axis + 1); + } else { + adjusted.push_back(axis); + } + } + return adjusted; + } + + auto matcher() const + { + return match::name("group")(match::has_op_value("tag", "kv_cache_attention")).bind("group"); + } + + // Helper to extract Q, K, V parameters from the attention submodule's gemm inputs + struct qkv_params { + instruction_ref q_param; // Parameter containing Q (full QKV tensor) + instruction_ref k_param; // Parameter for K (concat_past_present output) + instruction_ref v_param; // Parameter for V (concat_past_present output) + + // factory method to extract Q, K, V parameters from gemm operations + static std::optional from_gemms(instruction_ref gemm1, instruction_ref gemm2) + { + auto trace_back_to_param = [](instruction_ref ins) -> std::optional { + instruction_ref current = ins; + while(current->name() != "@param") { + if(current->inputs().empty()) { + return std::nullopt; + } + current = current->inputs()[0]; + } + return current; + }; + + auto q_input = gemm1->inputs()[0]; + auto k_input = gemm1->inputs()[1]; + auto v_input = gemm2->inputs()[1]; + + // trace back Q, K, V to find the parameters they originate from + auto q_param_opt = trace_back_to_param(q_input); + auto k_param_opt = trace_back_to_param(k_input); + auto v_param_opt = trace_back_to_param(v_input); + + if(not q_param_opt or not k_param_opt or not v_param_opt) return std::nullopt; + return qkv_params{*q_param_opt, *k_param_opt, *v_param_opt}; + } + }; + + void rebuild_gqa_attention(module& target_mod, + const module& source_mod, + const std::unordered_map& param_map, + instruction_ref gemm2, + const attention_dims& dims, + std::size_t num_groups) const + { + + std::cout << "Rebuilding GQA attention with inserter..." 
<< std::endl; + std::cout << "Second gemm (will stop after): " << gemm2->name() << std::endl; + + // Calculate the group dimension position in the original tensor (4D) + // Group is inserted at position -2, which is index 2 in a 4D tensor + int group_dim_pos = 2; + + // Define the BNGSM shape for broadcasts (needed in the inserter) + std::vector bngsm{dims.batch_size, dims.num_heads, num_groups, dims.sequence_length, dims.seq_length_per_group}; + std::vector bnsm{dims.batch_size, dims.num_heads, dims.sequence_length, dims.max_seq_length}; + + // Track reduce operations for LSE calculation + instruction_ref reduce_max_ref; + instruction_ref reduce_sum_ref; + bool found_reduce_max = false; + bool found_reduce_sum = false; + + // Will get the second dot result after copying + instruction_ref second_dot_result; + + // Create the inserter function that transforms operations + auto inserter = [&](module& m, + instruction_ref ins, + const operation& op, + const std::vector& inputs, + const std::vector& mod_args) -> instruction_ref { + + auto op_name = op.name(); + + // Helper to print shape + auto print_shape = [](const std::vector& lens) { + std::cout << "{"; + for(size_t i = 0; i < lens.size(); ++i) { + std::cout << lens[i]; + if(i < lens.size() - 1) std::cout << ","; + } + std::cout << "}"; + }; + + auto print_output = [&print_shape](instruction_ref result) { + std::cout << " Output shape: "; + print_shape(result->get_shape().lens()); + std::cout << std::endl; + }; + + // Helper to print operation attributes + auto print_op_attrs = [](const operation& o) { + try { + auto val = o.to_value(); + if(!val.empty()) { + std::stringstream ss; + ss << val; + auto str = ss.str(); + if(!str.empty()) { + std::cout << " " << str; + } + } + } catch(...) {} + }; + + // Debug: print operation being processed + std::cout << "\n>>> Processing op: " << op_name; + print_op_attrs(op); + std::cout << std::endl; + std::cout << " Input shapes: "; + for(size_t i = 0; i < inputs.size(); ++i) { + print_shape(inputs[i]->get_shape().lens()); + if(i < inputs.size() - 1) std::cout << ", "; + } + std::cout << std::endl; + + // Transpose: adjust permutation + if(op_name == "transpose") { + auto perm = op.to_value()["permutation"].to_vector(); + auto new_perm = adjust_transpose_perm(perm, -2); + + std::cout << " Adjusted perm: ["; + for(size_t i = 0; i < perm.size(); ++i) + std::cout << perm[i] << (i < perm.size()-1 ? "," : ""); + std::cout << "] -> ["; + for(size_t i = 0; i < new_perm.size(); ++i) + std::cout << new_perm[i] << (i < new_perm.size()-1 ? 
"," : ""); + std::cout << "]" << std::endl; + + auto new_op = make_op("transpose", {{"permutation", new_perm}}); + std::cout << " Creating: transpose"; + print_op_attrs(new_op); + std::cout << std::endl; + auto result = m.insert_instruction(ins, new_op, inputs, mod_args); + print_output(result); + return result; + } + + // Reduce operations: adjust axes + if(op_name == "reduce_max" || op_name == "reduce_sum") { + auto axes = op.to_value()["axes"].to_vector(); + auto new_axes = adjust_axes(axes, group_dim_pos); + + std::cout << " Adjusted axes: [" << axes[0] + << "] -> [" << new_axes[0] << "]" << std::endl; + + auto new_op = make_op(op_name, {{"axes", new_axes}}); + std::cout << " Creating: " << op_name; + print_op_attrs(new_op); + std::cout << std::endl; + auto result = m.insert_instruction(ins, new_op, inputs, mod_args); + + // Track these for LSE calculation + if(op_name == "reduce_max") { + reduce_max_ref = result; + found_reduce_max = true; + } + if(op_name == "reduce_sum") { + reduce_sum_ref = result; + found_reduce_sum = true; + } + + print_output(result); + return result; + } + + // Broadcast: if output is 2D {batch, max_seq_length}, add reshape to 3D + if(op_name == "broadcast") { + auto result = m.insert_instruction(ins, op, inputs, mod_args); + auto result_lens = result->get_shape().lens(); + + // Check if result is {batch, max_seq_length} + if(result_lens.size() == 2 && + result_lens[0] == dims.batch_size && + result_lens[1] == dims.max_seq_length) { + std::cout << " Creating: broadcast (unchanged)"; + print_op_attrs(op); + std::cout << std::endl; + print_output(result); + + // Add reshape to split sequence dimension into groups + std::vector reshape_dims = {dims.batch_size, num_groups, dims.seq_length_per_group}; + std::cout << "\n>>> Auto-inserting reshape after 2D broadcast" << std::endl; + std::cout << " Reshape dims: "; + print_shape(reshape_dims); + std::cout << std::endl; + + auto reshape_result = m.insert_instruction(ins, make_op("reshape", {{"dims", reshape_dims}}), result); + std::cout << " Creating: reshape"; + auto reshape_op = make_op("reshape", {{"dims", reshape_dims}}); + print_op_attrs(reshape_op); + std::cout << std::endl; + print_output(reshape_result); + return reshape_result; + } + + std::cout << " Creating: broadcast (unchanged)"; + print_op_attrs(op); + std::cout << std::endl; + print_output(result); + return result; + } + + // Multibroadcast: adjust output shape if it matches BNSM pattern + if(op_name == "multibroadcast") { + auto out_lens = op.to_value()["out_lens"].to_vector(); + + // Check if output is 2D {batch, max_seq_length} - needs reshape + if(out_lens.size() == 2 && + out_lens[0] == dims.batch_size && + out_lens[1] == dims.max_seq_length) { + std::cout << " 2D multibroadcast - will add reshape" << std::endl; + auto result = m.insert_instruction(ins, op, inputs, mod_args); + std::cout << " Creating: multibroadcast (unchanged)"; + print_op_attrs(op); + std::cout << std::endl; + print_output(result); + + // Add reshape to split sequence dimension into groups + std::vector reshape_dims = {dims.batch_size, num_groups, dims.seq_length_per_group}; + std::cout << "\n>>> Auto-inserting reshape after 2D multibroadcast" << std::endl; + std::cout << " Reshape dims: "; + print_shape(reshape_dims); + std::cout << std::endl; + + auto reshape_result = m.insert_instruction(ins, make_op("reshape", {{"dims", reshape_dims}}), result); + std::cout << " Creating: reshape"; + auto reshape_op = make_op("reshape", {{"dims", reshape_dims}}); + print_op_attrs(reshape_op); 
+ std::cout << std::endl; + print_output(reshape_result); + return reshape_result; + } + + // Check if this is a 4D multibroadcast with max_seq_length at last dimension + // Insert group dimension at position -2 and split last dimension + if(out_lens.size() == 4 && out_lens[3] == dims.max_seq_length) { + std::cout << " 4D multibroadcast with max_seq_length - inserting group dim" << std::endl; + + // Build new shape: insert group at -2, change last dim to seq_per_group + std::vector new_lens = {out_lens[0], out_lens[1], num_groups, out_lens[2], dims.seq_length_per_group}; + + std::cout << " Adjusted: "; + print_shape(out_lens); + std::cout << " -> "; + print_shape(new_lens); + std::cout << std::endl; + + auto new_op = make_op("multibroadcast", {{"out_lens", new_lens}}); + std::cout << " Creating: multibroadcast"; + print_op_attrs(new_op); + std::cout << std::endl; + auto result = m.insert_instruction(ins, new_op, inputs, mod_args); + print_output(result); + return result; + } + } + + // Unsqueeze: adjust axes (need to account for group dimension) + if(op_name == "unsqueeze") { + auto axes = op.to_value()["axes"].to_vector(); + // For unsqueeze, we need to adjust axes that are >= group_dim_pos + // But unsqueeze operates on the tensor BEFORE the dimensions are added + // So we need different logic + std::vector new_axes; + for(auto axis : axes) { + // If inserting group at position 2 in a 3D tensor (after greater/convert) + // and we want to unsqueeze at {1, 2}, it becomes {1, 3} (skip the group position) + if(axis >= group_dim_pos) { + new_axes.push_back(axis + 1); + } else { + new_axes.push_back(axis); + } + } + + std::cout << " Adjusted axes: ["; + for(size_t i = 0; i < axes.size(); ++i) + std::cout << axes[i] << (i < axes.size()-1 ? "," : ""); + std::cout << "] -> ["; + for(size_t i = 0; i < new_axes.size(); ++i) + std::cout << new_axes[i] << (i < new_axes.size()-1 ? 
"," : ""); + std::cout << "]" << std::endl; + + auto new_op = make_op("unsqueeze", {{"axes", new_axes}}); + std::cout << " Creating: unsqueeze"; + print_op_attrs(new_op); + std::cout << std::endl; + auto result = m.insert_instruction(ins, new_op, inputs, mod_args); + print_output(result); + return result; + } + + // Reshape: need to adjust dims if they span the group dimension + // This is tricky - for now, check if it's the mask reshape + // TODO + if(op_name == "reshape") { + auto dims_vec = op.to_value()["dims"].to_vector(); + + // Check if this is a mask reshape that needs to split the sequence dimension + // e.g., {batch, max_seq_length} -> {batch, num_groups, seq_per_group} + if(dims_vec.size() == 2 && dims_vec[1] == dims.max_seq_length) { + std::vector new_dims = {dims.batch_size, num_groups, dims.seq_length_per_group}; + std::cout << " Adjusted dims: "; + print_shape(dims_vec); + std::cout << " -> "; + print_shape(new_dims); + std::cout << std::endl; + + auto new_op = make_op("reshape", {{"dims", new_dims}}); + std::cout << " Creating: reshape"; + print_op_attrs(new_op); + std::cout << std::endl; + auto result = m.insert_instruction(ins, new_op, inputs, mod_args); + print_output(result); + return result; + } + } + + // Default: copy operation as-is + std::cout << " Creating: " << op_name << " (unchanged)"; + print_op_attrs(op); + std::cout << std::endl; + auto result = m.insert_instruction(ins, op, inputs, mod_args); + print_output(result); + return result; + }; + + // Find the instruction after gemm2 to use as the 'last' parameter + instruction_ref stop_point = std::next(gemm2); + + std::cout << "Copying instructions from source module up to (not including) instruction after gemm2..." << std::endl; + + // Use add_instructions with range [begin, stop_point) to copy up to and including gemm2 + std::unordered_map map_old_to_new = param_map; + target_mod.add_instructions(source_mod.begin(), stop_point, &map_old_to_new, inserter); + + // Get the transformed gemm2 + if(!contains(map_old_to_new, gemm2)) { + std::cout << "ERROR: gemm2 not found in map!" 
<< std::endl; + return; + } + second_dot_result = map_old_to_new.at(gemm2); + + std::cout << "\n=== Instructions copied and transformed ===" << std::endl; + + std::cout << "Second dot shape: " << second_dot_result->get_shape() << std::endl; + + std::cout << "\n=== Adding final transpose and reshape for flash decoding ===" << std::endl; + + // Add correct flash decoding transpose: {B, N, G, S, D} -> {B, G, S, N, D} + std::cout << "Adding transpose with permutation {0, 2, 3, 1, 4}" << std::endl; + auto transpose_out = target_mod.add_instruction( + make_op("transpose", {{"permutation", {0, 2, 3, 1, 4}}}), + second_dot_result); + std::cout << "Transpose output shape: " << transpose_out->get_shape() << std::endl; + + // Add correct flash decoding reshape: {B, G, S, N, D} -> {B, G, S, N*D} + std::vector final_shape = {dims.batch_size, num_groups, dims.sequence_length, dims.num_heads * dims.head_dim}; + std::cout << "Adding reshape with dims {"; + for(size_t i = 0; i < final_shape.size(); ++i) { + std::cout << final_shape[i]; + if(i < final_shape.size() - 1) std::cout << ","; + } + std::cout << "}" << std::endl; + auto reshape_out = target_mod.add_instruction( + make_op("reshape", {{"dims", final_shape}}), + transpose_out); + std::cout << "Reshape output shape: " << reshape_out->get_shape() << std::endl; + + // Calculate LSE (log-sum-exp) from the tracked reduce operations + // LSE = log(sum_exp) + max + if(found_reduce_max && found_reduce_sum) { + auto log_sum = target_mod.add_instruction(make_op("log"), reduce_sum_ref); + auto lse = target_mod.add_instruction(make_op("add"), reduce_max_ref, log_sum); + std::cout << "LSE shape: " << lse->get_shape() << std::endl; + + target_mod.add_return({reshape_out, lse}); + } else { + std::cout << "WARNING: Could not find reduce_max or reduce_sum for LSE calculation" << std::endl; + target_mod.add_return({reshape_out}); + } + + // print the complete submodule + std::cout << "\n=== Complete GQA Flash Decoding Submodule ===" << std::endl; + std::cout << target_mod << std::endl; + std::cout << "=== End Submodule ===" << std::endl; + } + + void apply(module_pass_manager& mpm, const match::matcher_result& r) const + { + auto& mm = mpm.get_module(); + auto attn_group_ins = r.instructions["group"]; + auto* submod = attn_group_ins->module_inputs().front(); + + std::cout << "GQA flash decoding detected, here is the submodule: " << std::endl; + submod->debug_print(); + + // extract Q, K, V parameters from gemm inputs + auto [gemm1, gemm2] = get_attention_gemms(submod); + auto qkv_opt = qkv_params::from_gemms(gemm1, gemm2); + if(not qkv_opt) return; + auto [q_param, k_param, v_param] = *qkv_opt; + + // derive attention dims from Q, K, V parameters + attention_dims dims(q_param, k_param, groups); + + if(groups <= 1 or dims.seq_length_per_group == 0) { + return; + } + + // map submodule params to main module inputs + auto group_inputs = attn_group_ins->inputs(); + auto map_param_to_main = map_submod_params_to_inputs(submod, group_inputs); + + // get actual Q, K, V instructions from main module + auto q = map_param_to_main.at(q_param); // maps to the QKV tensor + auto k = map_param_to_main.at(k_param); // maps to K concat_past_present output + auto v = map_param_to_main.at(v_param); // maps to V concat_past_present output + + // GQA flash decoding: + // - Q (QKV tensor): add new group dim and broadcast + // - K: split sequence dimension into groups + // - V: split sequence dimension into groups + auto q_type = q->get_shape().type(); + auto k_type = k->get_shape().type(); 
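        // (descriptive note) The element types are captured so the replacement submodule's
        // parameters below can keep the same dtype with the group-split 5D shapes. Q is
        // expanded with unsqueeze + multibroadcast because every group attends over the same
        // query rows (the group axis is a stride-0 broadcast), while K and V are reshaped,
        // since their max_seq_length axis is genuinely split into (groups, seq_length_per_group).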
+ auto v_type = v->get_shape().type(); + + // insert group dimension at position -2 for all tensors + // K and V: [B, kv_heads, N, D] -> [B, kv_heads, G, N/G, D] + // build transformed shapes + std::vector q_transformed_shape; + std::vector k_transformed_shape; + std::vector v_transformed_shape; + + q_transformed_shape = {dims.batch_size, dims.concat_heads, groups, dims.sequence_length, dims.head_dim}; + k_transformed_shape = {dims.batch_size, dims.kv_heads, groups, dims.seq_length_per_group, dims.head_dim}; + v_transformed_shape = {dims.batch_size, dims.kv_heads, groups, dims.seq_length_per_group, dims.head_dim}; + + // insert reshape operations before the attention group + // [B, concat_heads, seq, head_dim] -> [B, concat_heads, 1, seq, head_dim] -> [B, concat_heads, G, seq, head_dim] + auto q_unsqueezed = mm.insert_instruction( + attn_group_ins, + make_op("unsqueeze", {{"axes", {-2}}}), + q); + + auto q_reshaped = mm.insert_instruction( + attn_group_ins, + make_op("multibroadcast", {{"out_lens", q_transformed_shape}}), + q_unsqueezed); + + // [B, kv_heads, N, D] -> [B, kv_heads, G, N/G, D] + auto k_reshaped = mm.insert_instruction( + attn_group_ins, + make_op("reshape", {{"dims", k_transformed_shape}}), + k); + + // [B, kv_heads, N, D] -> [B, kv_heads, G, N/G, D] + auto v_reshaped = mm.insert_instruction( + attn_group_ins, + make_op("reshape", {{"dims", v_transformed_shape}}), + v); + + // No need to reshape additional inputs + // We'll adjust broadcast patterns inside for masking + + // create new input list, starting with replacing Q, K, V with reshaped versions + std::vector new_group_inputs = group_inputs; + for(size_t i = 0; i < group_inputs.size(); ++i) { + if(group_inputs[i] == q) { + new_group_inputs[i] = q_reshaped; + } else if(group_inputs[i] == k) { + new_group_inputs[i] = k_reshaped; + } else if(group_inputs[i] == v) { + new_group_inputs[i] = v_reshaped; + } + } + + module m_flash_decode; + m_flash_decode.set_bypass(); + + // get parameter names from original submodule + auto get_param_name = [](instruction_ref param) -> std::string { + assert(param->name() == "@param"); + return param->get_operator().to_value()["parameter"].to(); + }; + + auto q_name = get_param_name(q_param); + auto k_name = get_param_name(k_param); + auto v_name = get_param_name(v_param); + + // Add parameters to new submodule with transformed shapes + auto new_q_param = m_flash_decode.add_parameter( + q_name, shape{q_type, q_transformed_shape}); + auto new_k_param = m_flash_decode.add_parameter( + k_name, shape{k_type, k_transformed_shape}); + auto new_v_param = m_flash_decode.add_parameter( + v_name, shape{v_type, v_transformed_shape}); + + // Build mapping from old params to new params + std::unordered_map map_old_params_to_new; + map_old_params_to_new[q_param] = new_q_param; + map_old_params_to_new[k_param] = new_k_param; + map_old_params_to_new[v_param] = new_v_param; + + // add the rest of the parameters + for(auto param : iterator_for(*submod)) { + if(param->name() == "@param") { + if(not contains(map_old_params_to_new, param)) { + auto param_name = get_param_name(param); + auto param_shape = param->get_shape(); + auto new_param = m_flash_decode.add_parameter(param_name, param_shape); + map_old_params_to_new[param] = new_param; + } + } + } + + // rebuild the attention operations in the flash decode submodule + rebuild_gqa_attention(m_flash_decode, *submod, map_old_params_to_new, + gemm2, dims, groups); + + // create the module in the module pass manager and insert the new group operation + auto 
orig_name = attn_group_ins->module_inputs().front()->name(); + std::string new_mod_name = orig_name + "_gqa_flash_decoding"; + + module_ref mpm_flash_mod = mpm.create_module(new_mod_name, std::move(m_flash_decode)); + mpm_flash_mod->set_bypass(); + + auto new_group_ins = mm.insert_instruction( + attn_group_ins, + make_op("group", {{"tag", "attention"}}), + new_group_inputs, + {mpm_flash_mod}); + + std::cout << "Created GQA flash decoding group" << std::endl; + std::cout << "Group output shape: " << new_group_ins->get_shape() << std::endl; + + // unpack O' and LSE + auto partial_output_o_prime = mm.insert_instruction( + attn_group_ins, make_op("get_tuple_elem", {{"index", 0}}), new_group_ins); + auto lse = mm.insert_instruction( + attn_group_ins, make_op("get_tuple_elem", {{"index", 1}}), new_group_ins); + + // LSE-weighted combination + std::cout << "\n=== Kernel 2: LSE-weighted combination ===" << std::endl; + std::cout << "Input LSE shape: " << lse->get_shape() << std::endl; // [B, N, G, S, 1] = [2, 2, 2, 1, 1] + std::cout << "Input O' shape: " << partial_output_o_prime->get_shape() << std::endl; // [B, G, S, N*D] = [2, 2, 1, 4] + + // align LSE with O' for proper weighting + // LSE is [B, N, G, S, 1], match group dimension of O' [B, G, S, N*D] + + // [B, N, G, S, 1] -> [B, G, N, S, 1] + auto lse_transposed = mm.insert_instruction( + attn_group_ins, make_op("transpose", {{"permutation", {0, 2, 1, 3, 4}}}), lse); + + // average across heads since all heads in a group share the same weight + // [B, G, N, S, 1] -> [B, G, 1, S, 1] + auto lse_avg = mm.insert_instruction( + attn_group_ins, make_op("reduce_mean", {{"axes", {2}}}), lse_transposed); + + // [B, G, 1, S, 1] -> [B, G, S] + auto lse_squeezed = mm.insert_instruction( + attn_group_ins, make_op("squeeze", {{"axes", {2, 4}}}), lse_avg); + + // softmax across groups for LSE weights + // find max across groups for numerical stability + // [B, G, S] -> [B, 1, S] + auto lse_max = mm.insert_instruction( + attn_group_ins, make_op("reduce_max", {{"axes", {1}}}), lse_squeezed); + + // broadcast max back to original shape + // [B, 1, S] -> [B, G, S] + auto lse_max_bcast = mm.insert_instruction( + attn_group_ins, + make_op("multibroadcast", {{"out_lens", lse_squeezed->get_shape().lens()}}), + lse_max); + + // exp(LSE - max_LSE) + auto lse_sub = mm.insert_instruction(attn_group_ins, make_op("sub"), lse_squeezed, lse_max_bcast); + auto lse_exp = mm.insert_instruction(attn_group_ins, make_op("exp"), lse_sub); + + // sum exp across groups + // [B, G, S] -> [B, 1, S] + auto lse_sum = mm.insert_instruction( + attn_group_ins, make_op("reduce_sum", {{"axes", {1}}}), lse_exp); + + // [B, 1, S] -> [B, G, S] + auto lse_sum_bcast = mm.insert_instruction( + attn_group_ins, + make_op("multibroadcast", {{"out_lens", lse_exp->get_shape().lens()}}), + lse_sum); + + // [B, G, S] -> [B, G, S] + auto weights = mm.insert_instruction(attn_group_ins, make_op("div"), lse_exp, lse_sum_bcast); + + // weights is [B, G, S], O' is [B, G, S, N*D] + // [B, G, S] -> [B, G, S, 1] + auto weights_unsqueezed = mm.insert_instruction( + attn_group_ins, make_op("unsqueeze", {{"axes", {3}}}), weights); + + // broadcast to match O' shape + // [B, G, S, 1] -> [B, G, S, N*D] + auto weights_bcast = mm.insert_instruction( + attn_group_ins, + make_op("multibroadcast", {{"out_lens", partial_output_o_prime->get_shape().lens()}}), + weights_unsqueezed); + + // convert weights to match O' type + auto output_type = partial_output_o_prime->get_shape().type(); + auto weights_converted = 
mm.insert_instruction( + attn_group_ins, make_op("convert", {{"target_type", output_type}}), weights_bcast); + + // multiply O' by weights + auto weighted_output = mm.insert_instruction( + attn_group_ins, make_op("mul"), partial_output_o_prime, weights_converted); + + // sum across groups to get final output + // [B, G, S, N*D] -> [B, 1, S, N*D] + auto final_output = mm.insert_instruction( + attn_group_ins, make_op("reduce_sum", {{"axes", {1}}}), weighted_output); + + // squeeze the reduced group dimension + // [B, 1, S, N*D] -> [B, S, N*D] + auto final_squeezed = mm.insert_instruction( + attn_group_ins, make_op("squeeze", {{"axes", {1}}}), final_output); + + mm.replace_instruction(attn_group_ins, final_squeezed); + } +}; + struct find_flash_decoding { // configuration from fuse_attention pass config @@ -301,21 +1057,6 @@ struct find_flash_decoding return match::name("group")(match::has_op_value("tag", "attention")).bind("group"); } - std::pair get_gemms(module_ref submod) const - { - std::vector gemms; - for(auto it = submod->begin(); it != submod->end(); ++it) - { - if(it->name() == "dot") - gemms.push_back(it); - } - assert(gemms.size() == 2 and "Expected exactly 2 gemm operations in attention submodule"); - - // gemms[0] is Q@K, gemms[1] is P@V - // gemms are in order since we iterate from begin to end - return {gemms[0], gemms[1]}; - } - std::vector get_qkv_shapes(instruction_ref q, instruction_ref k, instruction_ref v) const { std::vector qkv_shapes; @@ -408,17 +1149,6 @@ struct find_flash_decoding return result; } - std::unordered_map - map_submod_params_to_inputs(module_ref submod, - const std::vector& group_inputs) const - { - auto map_param_to_main = submod->get_ins_param_map(group_inputs, true); - // verify the mapping is correct - auto expected_inputs = submod->get_inputs(map_param_to_main); - assert(expected_inputs == group_inputs and "Mapped inputs don't match group inputs"); - return map_param_to_main; - } - void rebuild_attention_submodule( module& target_mod, const module& source_mod, @@ -517,7 +1247,7 @@ struct find_flash_decoding return; // get gemm1 and gemm2 - auto [gemm1, gemm2] = get_gemms(submod); + auto [gemm1, gemm2] = get_attention_gemms(submod); // TODO: for this first pass of flash decoding, assuming no input fusion / not supporting auto q_param = gemm1->inputs()[0]; @@ -930,11 +1660,17 @@ void fuse_attention::apply(module_pass_manager& mpm) const { std::size_t counter = 0; + std::cout << "original module before fusion: " << std::endl; + mpm.get_module().debug_print(); + // Fuse kv-cache attention by default match::find_matches(mpm, find_kv_cache_attention{.counter = &counter}); mpm.get_module().sort(); mpm.run_pass(dead_code_elimination{}); + std::cout << "module after kv-cache attention fusion: " << std::endl; + mpm.get_module().debug_print(); + // Only fuse plain attention when requested if(attn_enabled) { @@ -947,12 +1683,17 @@ void fuse_attention::apply(module_pass_manager& mpm) const std::size_t configured_splits = get_num_splits(flash_decoding_num_splits); if(configured_splits > 0 or flash_decoding_enabled) { + // flash decoding for regular attention, single & multi-headed match::find_matches( mpm, - find_flash_decoding{.configured_splits = flash_decoding_num_splits, + find_flash_decoding{.configured_splits = configured_splits, .configured_threshold = flash_decoding_threshold, .configured_max_splits = flash_decoding_max_splits, .configured_min_chunk_size = flash_decoding_min_chunk_size}); + + // flash decoding for GQA attention + match::find_matches( + 
mpm, find_gqa_flash_decoding{.groups = configured_splits}); mpm.run_pass(dead_code_elimination{}); } }
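Reviewer note: below is a minimal, standalone sketch (not part of this change) of the recombination math that the second stage above expresses with graph ops: each group's partial output O'_g is scaled by a softmax over the head-averaged per-group LSE values and summed across groups. The function and variable names and the scalar reference loop are illustrative assumptions only.

// combine_groups: illustrative reference for the LSE-weighted combination above.
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <iostream>
#include <vector>

// partial[g][d] : flash-decoding partial output O'_g for one (batch, seq) position
// lse[g]        : head-averaged log-sum-exp produced by group g
std::vector<double> combine_groups(const std::vector<std::vector<double>>& partial,
                                   const std::vector<double>& lse)
{
    const std::size_t groups = lse.size();
    const std::size_t dim    = partial.front().size();

    // softmax over the group axis, shifted by the max for numerical stability
    const double max_lse = *std::max_element(lse.begin(), lse.end());
    std::vector<double> w(groups);
    double denom = 0.0;
    for(std::size_t g = 0; g < groups; ++g)
    {
        w[g] = std::exp(lse[g] - max_lse);
        denom += w[g];
    }

    // weighted sum of the per-group partial outputs (mul + reduce_sum over the group axis)
    std::vector<double> out(dim, 0.0);
    for(std::size_t g = 0; g < groups; ++g)
        for(std::size_t d = 0; d < dim; ++d)
            out[d] += (w[g] / denom) * partial[g][d];
    return out;
}

int main()
{
    // hypothetical numbers: two groups, N*D flattened to four values
    const std::vector<std::vector<double>> partial = {{1.0, 2.0, 3.0, 4.0}, {0.5, 0.5, 0.5, 0.5}};
    const std::vector<double> lse                  = {2.0, 1.0};
    for(double v : combine_groups(partial, lse))
        std::cout << v << " ";
    std::cout << "\n";
}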