From db313416e49bfe4d86a0cde7ea06a1b65f63453d Mon Sep 17 00:00:00 2001 From: Breanna Devore-McDonald Date: Wed, 10 Dec 2025 04:49:40 -0600 Subject: [PATCH 1/4] blargh --- src/fuse_attention.cpp | 975 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 975 insertions(+) diff --git a/src/fuse_attention.cpp b/src/fuse_attention.cpp index 8912b186cdb..b086bbf25c3 100644 --- a/src/fuse_attention.cpp +++ b/src/fuse_attention.cpp @@ -288,6 +288,976 @@ struct find_attention } }; +struct find_gqa_flash_decoding +{ + std::size_t groups; + + // Struct to hold all attention dimensions + struct attention_dims + { + std::size_t batch_size; + std::size_t num_heads; // Q heads + std::size_t kv_heads; // K and V heads + std::size_t concat_heads; // total heads in QKV tensor + std::size_t sequence_length; + std::size_t max_seq_length; + std::size_t head_dim; + std::size_t seq_length_per_group; // sequence length per group after splitting max sequence length + + // constructor from parameters + attention_dims(instruction_ref q_param, instruction_ref k_param, std::size_t num_groups) + { + auto q_shape = q_param->get_shape(); + auto k_shape = k_param->get_shape(); + + batch_size = q_shape.lens()[0]; + concat_heads = q_shape.lens()[1]; + sequence_length = q_shape.lens()[2]; + head_dim = q_shape.lens()[3]; + + kv_heads = k_shape.lens()[1]; + max_seq_length = k_shape.lens()[2]; + + // calculate Q heads from concat_heads = num_heads + 2 * kv_heads + num_heads = concat_heads - 2 * kv_heads; + + // calculate sequence length per group + if(max_seq_length % num_groups != 0) { + std::cout << "Max sequence length " << max_seq_length + << " not divisible by " << num_groups << " groups" << std::endl; + // TODO: Add padding support + seq_length_per_group = 0; // Set to 0 to indicate error + return; + } + seq_length_per_group = max_seq_length / num_groups; + } + }; + + auto matcher() const + { + return match::name("group")(match::has_op_value("tag", "kv_cache_attention")).bind("group"); + } + + std::pair get_gemms(module_ref submod) const + { + std::vector gemms; + for(auto it = submod->begin(); it != submod->end(); ++it) + { + if(it->name() == "dot") + gemms.push_back(it); + } + assert(gemms.size() == 2 and "Expected exactly 2 gemm operations in attention submodule"); + + // gemms[0] is Q@K, gemms[1] is P@V + // gemms are in order since we iterate from begin to end + return {gemms[0], gemms[1]}; + } + + // Helper to extract Q, K, V parameters from the attention submodule's gemm inputs + struct qkv_params { + instruction_ref q_param; // Parameter containing Q (full QKV tensor) + instruction_ref k_param; // Parameter for K (concat_past_present output) + instruction_ref v_param; // Parameter for V (concat_past_present output) + }; + + std::unordered_map + map_submod_params_to_inputs(module_ref submod, + const std::vector& group_inputs) const + { + auto map_param_to_main = submod->get_ins_param_map(group_inputs, true); + // verify the mapping is correct + auto expected_inputs = submod->get_inputs(map_param_to_main); + assert(expected_inputs == group_inputs and "Mapped inputs don't match group inputs"); + return map_param_to_main; + } + + // rebuild GQA attention operations in flash decoding submodule + // Helper to find early exit masking operations + struct early_exit_mask_ops { + instruction_ref pos_literal; // Literal with position indices {0,1,2,3...} + instruction_ref pos_broadcast; // Broadcast of position literal + instruction_ref seq_len_param; // Sequence length parameter + instruction_ref seq_multicast; // 
Multibroadcast of seq_len + instruction_ref greater_op; // Greater comparison + instruction_ref convert_op; // Convert to bool + instruction_ref unsqueeze_op; // Unsqueeze mask + instruction_ref mask_broadcast; // Final multibroadcast of mask + instruction_ref ninf_literal; // -inf literal for masking + instruction_ref ninf_broadcast; // Multibroadcast of -inf + instruction_ref where_op; // Where operation applying mask + + // Flags to track which operations were found + bool found = false; + bool has_pos_literal = false; + bool has_pos_broadcast = false; + bool has_seq_len_param = false; + bool has_seq_multicast = false; + bool has_greater = false; + bool has_convert = false; + bool has_unsqueeze = false; + bool has_mask_broadcast = false; + bool has_ninf_literal = false; + bool has_ninf_broadcast = false; + }; + + early_exit_mask_ops find_early_exit_masking_ops( + const module& source_mod, + instruction_ref scaled_scores, + const std::unordered_map& map_old_to_new) const + { + early_exit_mask_ops mask_ops; + + // Find the where operation that uses our scaled scores + for(auto ins : iterator_for(source_mod)) { + if(ins->name() == "where") { + // Check if one of its inputs is our scaled scores (through the mapping) + for(auto input : ins->inputs()) { + if(contains(map_old_to_new, input) && map_old_to_new.at(input) == scaled_scores) { + mask_ops.where_op = ins; + mask_ops.found = true; + break; + } + } + if(mask_ops.found) break; + } + } + + if(!mask_ops.found) { + return mask_ops; + } + + // Get the three inputs to where: mask, true_value (-inf), false_value (scores) + auto mask_input = mask_ops.where_op->inputs()[0]; + mask_ops.ninf_broadcast = mask_ops.where_op->inputs()[1]; + + // Trace back the mask to find multibroadcast -> unsqueeze -> convert -> greater + instruction_ref current = mask_input; + + // Should be multibroadcast + if(current->name() == "multibroadcast") { + mask_ops.mask_broadcast = current; + mask_ops.has_mask_broadcast = true; + current = current->inputs()[0]; + } + + // Should be unsqueeze + if(current->name() == "unsqueeze") { + mask_ops.unsqueeze_op = current; + mask_ops.has_unsqueeze = true; + current = current->inputs()[0]; + } + + // Should be convert + if(current->name() == "convert") { + mask_ops.convert_op = current; + mask_ops.has_convert = true; + current = current->inputs()[0]; + } + + // Should be greater + if(current->name() == "greater") { + mask_ops.greater_op = current; + mask_ops.has_greater = true; + + // Get inputs to greater + auto pos_input = mask_ops.greater_op->inputs()[0]; + auto seq_input = mask_ops.greater_op->inputs()[1]; + + // Position side: broadcast -> literal + if(pos_input->name() == "broadcast") { + mask_ops.pos_broadcast = pos_input; + mask_ops.has_pos_broadcast = true; + mask_ops.pos_literal = pos_input->inputs()[0]; + mask_ops.has_pos_literal = true; + } + + // Sequence length side: multibroadcast -> param + if(seq_input->name() == "multibroadcast") { + mask_ops.seq_multicast = seq_input; + mask_ops.has_seq_multicast = true; + mask_ops.seq_len_param = seq_input->inputs()[0]; + mask_ops.has_seq_len_param = true; + } + } + + // Find the -inf literal source + if(mask_ops.ninf_broadcast->name() == "multibroadcast") { + mask_ops.has_ninf_broadcast = true; + mask_ops.ninf_literal = mask_ops.ninf_broadcast->inputs()[0]; + mask_ops.has_ninf_literal = true; + } + + return mask_ops; + } + + void rebuild_gqa_attention(module& target_mod, + const module& source_mod, + const std::unordered_map& param_map, + instruction_ref q_param, + 
instruction_ref k_param, + instruction_ref v_param, + const attention_dims& dims, + std::size_t num_groups) const + { + // map from instructions in old module to new module + std::unordered_map map_old_to_new = param_map; + + // TODO can do this better, and also make it better for other flash decoding case + // track softmax components for LSE calculation + std::unordered_map softmax_parts; + + assert(contains(param_map, q_param) && "Q parameter must be mapped"); + assert(contains(param_map, k_param) && "K parameter must be mapped"); + assert(contains(param_map, v_param) && "V parameter must be mapped"); + (void)v_param; // Will be used later for V operations + + // handle Q extraction + // since we slice on axis 1 (concat_heads) and groups are at axis 2, no change needed + for(auto ins : iterator_for(source_mod)) { + if(ins->name() == "slice" && ins->inputs()[0] == q_param) { + auto op = ins->get_operator(); + auto new_q = map_old_to_new.at(q_param); + auto sliced_q = target_mod.add_instruction(op, new_q); + map_old_to_new[ins] = sliced_q; + std::cout << " Q slice created, shape: " << sliced_q->get_shape() << std::endl; + break; + } + } + + // handle K transpose + instruction_ref transposed_k; + for(auto ins : iterator_for(source_mod)) { + if(ins->name() == "transpose") { + auto transpose_input = ins->inputs()[0]; + if(transpose_input == k_param) { + auto op = ins->get_operator(); + auto perm = op.to_value()["permutation"].to_vector(); + + // dims.batch_size, dims.kv_heads, groups, dims.seq_length_per_group, dims.head_dim} + // perm is now [0, 1, 2, 4, 3] for [B, H, G, D, S] + std::vector new_perm = {0, 1, 2, 4, 3}; + auto new_transpose_op = make_op("transpose", {{"permutation", new_perm}}); + auto new_k = map_old_to_new.at(k_param); + transposed_k = target_mod.add_instruction(new_transpose_op, new_k); + map_old_to_new[ins] = transposed_k; + + break; + } + } + } + + // ninf is of shape + // {batch_size, num_heads, sequence_length, max_seq_len} + + + // handle literal constants and their broadcasts + for(auto ins : iterator_for(source_mod)) { + if(ins->name() == "@literal") { + // copy literals directly + auto lit_val = ins->get_literal(); + auto new_lit = target_mod.add_literal(lit_val); + map_old_to_new[ins] = new_lit; + std::cout << " Added literal with shape: " << new_lit->get_shape() << std::endl; + } + } + + // TODO handle when kv_heads != num_heads + // define expected broadcast shapes for literals + std::vector bnsm{dims.batch_size, dims.num_heads, dims.sequence_length, dims.max_seq_length}; + std::vector bngsm{dims.batch_size, dims.num_heads, num_groups, dims.sequence_length, dims.seq_length_per_group}; + + // update multibroadcast shapes for literals + // if broadcasting to [B, N, S, M] shape, change to [B, N, G, S, M/G] + // Keep track of specific broadcasts for -inf and scale + instruction_ref ninf_broadcast; + instruction_ref scale_broadcast; + + for(auto ins : iterator_for(source_mod)) { + if(ins->name() == "multibroadcast" && + !contains(map_old_to_new, ins)) { + auto input = ins->inputs()[0]; + + // check if input is a literal and is already mapped + if(contains(map_old_to_new, input) && input->name() == "@literal") { + auto op = ins->get_operator(); + auto out_lens = op.to_value()["out_lens"].to_vector(); + + // check if the shape matches [B, N, S, M] pattern for attention scores + if(out_lens == bnsm) { + // use the pre-defined bngsm shape + auto new_op = make_op("multibroadcast", {{"out_lens", bngsm}}); + auto new_input = map_old_to_new.at(input); + auto new_broadcast = 
target_mod.add_instruction(new_op, new_input); + map_old_to_new[ins] = new_broadcast; + + // Check the literal value to identify -inf or scale + auto lit = input->get_literal(); + if(lit.get_shape().type() == migraphx::shape::half_type) { + // Get the literal value as a string for comparison + auto lit_str = lit.to_string(); + if(lit_str == "-inf") { + ninf_broadcast = new_broadcast; + std::cout << "adjusted -inf multibroadcast from BNSM to BNGSM: " + << new_broadcast->get_shape() << std::endl; + } else { + // Assume it's the scale value + scale_broadcast = new_broadcast; + std::cout << "adjusted scale multibroadcast from BNSM to BNGSM: " + << new_broadcast->get_shape() << " (value: " << lit_str << ")" << std::endl; + } + } + } else { + // for other shapes, just copy the multibroadcast as-is + auto new_input = map_old_to_new.at(input); + auto new_broadcast = target_mod.add_instruction(op, new_input); + map_old_to_new[ins] = new_broadcast; + } + } + } + } + + // Q slice [batch, num heads, groups, sl, max sl] + // check if we found the specific broadcasts + bool has_scale = false; + bool has_ninf = false; + for(auto ins : iterator_for(target_mod)) { + if(ins == scale_broadcast) has_scale = true; + if(ins == ninf_broadcast) has_ninf = true; + } + + if(has_scale) { + std::cout << " found scale broadcast for attention scaling" << std::endl; + } + if(has_ninf) { + std::cout << " found -inf broadcast for masking" << std::endl; + } + + // handle first dot product (Q @ K^T) + std::cout << "rebuilding first dot product (Q @ K^T)..." << std::endl; + instruction_ref dot1; + for(auto ins : iterator_for(source_mod)) { + if(ins->name() == "dot") { + // check if this is the first dot (Q @ K^T) + // it should have Q (or sliced Q) as first input and K transpose as second + auto input0 = ins->inputs()[0]; + auto input1 = ins->inputs()[1]; + + // check if we've already mapped these inputs (Q slice and K transpose) + if(contains(map_old_to_new, input0) && contains(map_old_to_new, input1)) { + auto new_q = map_old_to_new.at(input0); + auto new_kt = map_old_to_new.at(input1); + + // create the dot product with transformed inputs + dot1 = target_mod.add_instruction(make_op("dot"), new_q, new_kt); + map_old_to_new[ins] = dot1; + + std::cout << " created dot1 (Q @ K^T) with shape: " << dot1->get_shape() << std::endl; + break; // assume first dot is Q @ K^T + } + } + } + + // handle scaling (mul with scale factor) + std::cout << "finding and rebuilding scale multiplication..." 
<< std::endl;
+        instruction_ref scaled_scores;
+
+        // check if we have both dot1 and scale_broadcast
+        bool has_dot1 = false;
+        bool has_scale_bc = false;
+        for(auto ins : iterator_for(target_mod)) {
+            if(ins == dot1) has_dot1 = true;
+            if(ins == scale_broadcast) has_scale_bc = true;
+        }
+
+        if(has_dot1 && has_scale_bc) {
+            scaled_scores = target_mod.add_instruction(make_op("mul"), dot1, scale_broadcast);
+            std::cout << "  created scaled scores with shape: " << scaled_scores->get_shape() << std::endl;
+        } else if(has_dot1) {
+            // if we don't have scale_broadcast, try to find the mul in the original
+            for(auto ins : iterator_for(source_mod)) {
+                if(ins->name() == "mul") {
+                    auto input0 = ins->inputs()[0];
+                    auto input1 = ins->inputs()[1];
+
+                    if(contains(map_old_to_new, input0) && contains(map_old_to_new, input1)) {
+                        bool input0_is_dot = (map_old_to_new.at(input0)->name() == "dot");
+                        bool input1_is_dot = (map_old_to_new.at(input1)->name() == "dot");
+
+                        if(input0_is_dot || input1_is_dot) {
+                            auto new_input0 = map_old_to_new.at(input0);
+                            auto new_input1 = map_old_to_new.at(input1);
+
+                            scaled_scores = target_mod.add_instruction(make_op("mul"), new_input0, new_input1);
+                            map_old_to_new[ins] = scaled_scores;
+
+                            std::cout << "  created scaled scores with shape: " << scaled_scores->get_shape() << std::endl;
+                            break;
+                        }
+                    }
+                }
+            }
+        }
+
+        // For kv_cache_attention, rebuild early exit masking with modified broadcast shapes
+        std::cout << "rebuilding early exit masking with adjusted broadcasts..." << std::endl;
+
+        // Find the range literal (position indices like {0,1,2,3...})
+        instruction_ref range_literal;
+        bool found_range = false;
+        for(auto ins : iterator_for(source_mod)) {
+            if(ins->name() == "@literal") {
+                auto shape = ins->get_literal().get_shape();
+                if(shape.type() == migraphx::shape::int32_type &&
+                   shape.lens().size() == 1 &&
+                   shape.lens()[0] == dims.max_seq_length) {
+                    range_literal = ins;
+                    found_range = true;
+                    std::cout << "  found range literal with shape: " << shape << std::endl;
+                    break;
+                }
+            }
+        }
+
+        // Find the past_sl parameter (past sequence length)
+        instruction_ref past_sl_param;
+        bool found_past_sl = false;
+        for(auto param : iterator_for(source_mod)) {
+            if(param->name() == "@param") {
+                auto shape = param->get_shape();
+                // past_sl is int32 type with batch_size elements (e.g., [2,1])
+                if(shape.type() == migraphx::shape::int32_type &&
+                   shape.elements() == dims.batch_size) {
+                    past_sl_param = param;
+                    found_past_sl = true;
+                    std::cout << "  found past_sl param with shape: " << shape << std::endl;
+                    break;
+                }
+            }
+        }
+
+        if(!found_range) {
+            std::cout << "  WARNING: Could not find range literal" << std::endl;
+        }
+        if(!found_past_sl) {
+            std::cout << "  WARNING: Could not find past_sl parameter" << std::endl;
+        }
+
+        // create broadcast shape vector
+        std::vector<std::size_t> broadcast_shape = {dims.batch_size, num_groups, dims.seq_length_per_group};
+        std::cout << "  broadcast shape: [" << dims.batch_size << ", " << num_groups << ", "
+                  << dims.seq_length_per_group << "]" << std::endl;
+
+        instruction_ref range_broadcast;
+        instruction_ref past_sl_reshaped;
+
+        if(found_range && found_past_sl && contains(param_map, past_sl_param)) {
+            // range literal broadcast to [batch_size, num_groups, seq_length_per_group]
+            if(!contains(map_old_to_new, range_literal)) {
+                auto lit_val = range_literal->get_literal();
+                auto new_lit = target_mod.add_literal(lit_val);
+                map_old_to_new[range_literal] = new_lit;
+            }
+
+            // Broadcast range literal directly, matching original pattern
+            // Use
axis=1 to match the original pattern (broadcast[axis=1] adds batch dimension at front) + std::vector intermediate_bc_shape = {dims.batch_size, dims.max_seq_length}; + auto range_broadcast_intermediate = target_mod.add_instruction( + make_op("broadcast", {{"axis", 1}, {"out_lens", intermediate_bc_shape}}), + map_old_to_new.at(range_literal)); + std::cout << " broadcasted range to: " << range_broadcast_intermediate->get_shape() << std::endl; + + // Then reshape to final shape [batch_size, num_groups, seq_length_per_group] + range_broadcast = target_mod.add_instruction( + make_op("reshape", {{"dims", broadcast_shape}}), + range_broadcast_intermediate); + std::cout << " reshaped range to final shape: " << range_broadcast->get_shape() << std::endl; + + // past_sl param from [batch_size, 1] to [batch_size, max_seq_length] + // then reshape to [batch_size, num_groups, seq_length_per_group] + auto past_sl_new = param_map.at(past_sl_param); + + std::vector intermediate_shape = {dims.batch_size, dims.max_seq_length}; + auto past_sl_broadcast = target_mod.add_instruction( + make_op("multibroadcast", {{"out_lens", intermediate_shape}}), + past_sl_new); + std::cout << " multibroadcasted past_sl to: " << past_sl_broadcast->get_shape() << std::endl; + + past_sl_reshaped = target_mod.add_instruction( + make_op("reshape", {{"dims", broadcast_shape}}), + past_sl_broadcast); + std::cout << " reshaped past_sl to: " << past_sl_reshaped->get_shape() << std::endl; + + auto greater = target_mod.add_instruction( + make_op("greater"), range_broadcast, past_sl_reshaped); + auto convert = target_mod.add_instruction( + make_op("convert", {{"target_type", migraphx::shape::bool_type}}), greater); + auto unsqueeze = target_mod.add_instruction( + make_op("unsqueeze", {{"axes", {1, 3}}}), convert); + auto multibroadcast = target_mod.add_instruction( + make_op("multibroadcast", {{"out_lens", bngsm}}), unsqueeze); + + // Check if we have the ninf_broadcast before using it + bool has_ninf_bc = false; + for(auto ins : iterator_for(target_mod)) { + if(ins == ninf_broadcast) { + has_ninf_bc = true; + break; + } + } + + if(has_ninf_bc) { + auto where = target_mod.add_instruction(make_op("where"), multibroadcast, ninf_broadcast, scaled_scores); + scaled_scores = where; // Update scaled_scores to the masked version + + } else { + std::cout << " WARNING: Could not find ninf_broadcast for where operation" << std::endl; + } + } + + // quick implementation of remaining ops for testing + + // convert to float for softmax computation + auto convert_to_float = target_mod.add_instruction( + make_op("convert", {{"target_type", migraphx::shape::float_type}}), scaled_scores); + + // Reduce max along last axis (axis 4 in BNGSM) + auto reduce_max = target_mod.add_instruction( + make_op("reduce_max", {{"axes", {4}}}), convert_to_float); + + // Broadcast max back to original shape + auto max_broadcast = target_mod.add_instruction( + make_op("multibroadcast", {{"out_lens", bngsm}}), reduce_max); + + // Subtract max for numerical stability + auto sub = target_mod.add_instruction( + make_op("sub"), convert_to_float, max_broadcast); + + // Exp + auto exp_scores = target_mod.add_instruction( + make_op("exp"), sub); + + // Reduce sum along last axis + auto reduce_sum = target_mod.add_instruction( + make_op("reduce_sum", {{"axes", {4}}}), exp_scores); + + // Broadcast sum back + auto sum_broadcast = target_mod.add_instruction( + make_op("multibroadcast", {{"out_lens", bngsm}}), reduce_sum); + + // Divide to get softmax + auto softmax = 
target_mod.add_instruction(
+            make_op("div"), exp_scores, sum_broadcast);
+
+        // Convert back to half
+        auto convert_to_half = target_mod.add_instruction(
+            make_op("convert", {{"target_type", migraphx::shape::half_type}}), softmax);
+
+        std::cout << "now doing the dot between the softmax output and V" << std::endl;
+        std::cout << "V param shape: " << v_param->get_shape() << std::endl;
+
+        // We need to use the mapped V parameter, not the original v_param;
+        // the v_param passed in is from the original submodule, so look it up in param_map
+        auto v_mapped = param_map.at(v_param);
+        std::cout << "V mapped shape: " << v_mapped->get_shape() << std::endl;
+
+        // Dot with V
+        auto dot2 = target_mod.add_instruction(
+            make_op("dot"), convert_to_half, v_mapped);
+        std::cout << "Dot2 shape: " << dot2->get_shape() << std::endl;
+
+        // for flash decoding, we keep the group dimension and return it
+        // kernel 2 will handle the LSE-weighted reduction
+        // dot2 is currently [B, N, G, S, D]
+
+        // transpose to [B, G, S, N, D] to match flash decoding output pattern
+        auto transpose_out = target_mod.add_instruction(
+            make_op("transpose", {{"permutation", {0, 2, 3, 1, 4}}}), dot2);
+        std::cout << "Transpose shape: " << transpose_out->get_shape() << std::endl;
+
+        // reshape to [B, G, S, N*D]
+        std::vector<std::size_t> final_shape = {dims.batch_size, num_groups, dims.sequence_length, dims.num_heads * dims.head_dim};
+        auto reshape_out = target_mod.add_instruction(
+            make_op("reshape", {{"dims", final_shape}}), transpose_out);
+        std::cout << "Final reshape shape (with groups): " << reshape_out->get_shape() << std::endl;
+
+        // for LSE (log-sum-exp), we need log(sum_exp) + max
+        // we already have reduce_max and reduce_sum from softmax computation
+        // LSE shape is [B, N, G, S, 1] which is correct for flash decoding
+        auto log_sum = target_mod.add_instruction(make_op("log"), reduce_sum);
+        auto lse = target_mod.add_instruction(make_op("add"), reduce_max, log_sum);
+        std::cout << "LSE shape: " << lse->get_shape() << std::endl;
+
+        target_mod.add_return({reshape_out, lse});
+
+        // print the complete submodule
+        std::cout << "\n=== Complete GQA Flash Decoding Submodule ===" << std::endl;
+        std::cout << target_mod << std::endl;
+        std::cout << "=== End Submodule ===" << std::endl;
+    }
+
+    std::optional<qkv_params> extract_qkv_params(instruction_ref gemm1, instruction_ref gemm2) const
+    {
+        qkv_params result;
+
+        // Q: gemm1's first input should be a slice from the QKV tensor
+        auto q_input = gemm1->inputs()[0];
+        if(q_input->name() == "slice") {
+            // trace back from slice to find the parameter
+            auto before_slice = q_input->inputs()[0];
+
+            instruction_ref current = before_slice;
+            while(current->name() != "@param") {
+                if(current->inputs().empty()) {
+                    std::cout << "Cannot trace Q back to parameter" << std::endl;
+                    return std::nullopt;
+                }
+                current = current->inputs()[0];
+            }
+            result.q_param = current;
+        } else {
+            std::cout << "Expected Q to come from slice, got: " << q_input->name() << std::endl;
+            return std::nullopt;
+        }
+
+        // K: gemm1's second input should be transposed K from concat_past_present
+        auto k_input = gemm1->inputs()[1];
+        if(k_input->name() == "transpose") {
+            result.k_param = k_input->inputs()[0];
+        } else {
+            result.k_param = k_input;
+        }
+
+        if(result.k_param->name() != "@param") {
+            std::cout << "Expected K to be a parameter, got: " << result.k_param->name() << std::endl;
+            return std::nullopt;
+        }
+
+        // V: gemm2's second input should be V from concat_past_present
+        result.v_param = gemm2->inputs()[1];
+        if(result.v_param->name() != "@param") {
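+            // Note (assumption): the kv_cache_attention pattern matched above feeds V
+            // into gemm2 directly as the concat_past_present output parameter; if a
+            // transpose or slice ever sits between them, V would need the same
+            // trace-back treatment as Q before bailing out here.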
+ std::cout << "Expected V to be a parameter, got: " << result.v_param->name() << std::endl; + return std::nullopt; + } + + return result; + } + + void apply(module_pass_manager& mpm, const match::matcher_result& r) const + { + auto& mm = mpm.get_module(); + auto attn_group_ins = r.instructions["group"]; + auto* submod = attn_group_ins->module_inputs().front(); + + std::cout << "GQA flash decoding detected" << std::endl; + + // check multiple returns + auto return_ins = std::prev(submod->end()); + assert(return_ins->name() == "@return" and + "Last instruction must be a @return instruction"); + if(return_ins->inputs().size() > 1) { + std::cout << "KV cache attention unexpected multiple returns" << std::endl; + return; + } + + // get gemm1 and gemm2 + auto [gemm1, gemm2] = get_gemms(submod); + + // Extract Q, K, V parameters from gemm inputs + auto qkv_opt = extract_qkv_params(gemm1, gemm2); + if(!qkv_opt) { + std::cout << "Failed to extract Q, K, V parameters" << std::endl; + return; + } + + auto [q_param, k_param, v_param] = *qkv_opt; + + std::cout << "Q attn module param shape: " << q_param->get_shape() << std::endl; + std::cout << "K attn module param shape: " << k_param->get_shape() << std::endl; + std::cout << "V attn module param shape: " << v_param->get_shape() << std::endl; + + // derive dim values + attention_dims dims(q_param, k_param, groups); + + std::cout << "Max sequence length: " << dims.max_seq_length << std::endl; + + if(groups <= 1) { + std::cout << "No splitting requested (groups=" << groups << ")" << std::endl; + return; + } + + // check if dimensions were calculated successfully + if(dims.seq_length_per_group == 0) { + std::cout << "Failed to calculate sequence length per group, returning" << std::endl; + return; + } + + // map submodule params to main module inputs + auto group_inputs = attn_group_ins->inputs(); + auto map_param_to_main = map_submod_params_to_inputs(submod, group_inputs); + + // get actual Q, K, V instructions from main module + auto q = map_param_to_main.at(q_param); // maps to the QKV tensor + auto k = map_param_to_main.at(k_param); // maps to K concat_past_present output + auto v = map_param_to_main.at(v_param); // maps to V concat_past_present output + + std::cout << "Main module Q shape: " << q->get_shape() << std::endl; + std::cout << "Main module K shape: " << k->get_shape() << std::endl; + std::cout << "Main module V shape: " << v->get_shape() << std::endl; + + // GQA flash decoding: + // - Q (QKV tensor): broadcast across groups (no split) + // - K: split sequence dimension into groups + // - V: split sequence dimension into groups + + // shapes before group transformation + auto q_shape = q->get_shape(); + auto k_shape_main = k->get_shape(); + auto v_shape_main = v->get_shape(); + + // insert group dimension at position -2 for all tensors + // K and V: [B, kv_heads, N, D] -> [B, kv_heads, G, N/G, D] (split) + // build transformed shapes + std::vector q_transformed_shape; + std::vector k_transformed_shape; + std::vector v_transformed_shape; + + // Q shape transformation (broadcast group dimension) + q_transformed_shape = {dims.batch_size, dims.concat_heads, groups, dims.sequence_length, dims.head_dim}; + k_transformed_shape = {dims.batch_size, dims.kv_heads, groups, dims.seq_length_per_group, dims.head_dim}; + v_transformed_shape = {dims.batch_size, dims.kv_heads, groups, dims.seq_length_per_group, dims.head_dim}; + + std::cout << "Q transformed shape: "; + for(auto d : q_transformed_shape) std::cout << d << " "; + std::cout << std::endl; + + 
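+        // Worked example with hypothetical sizes (B=2, kv_heads=2, G=4, M=4096, D=64):
+        //   K/V: [2, 2, 4096, 64] --reshape--> [2, 2, 4, 1024, 64]   (KV cache rows split across groups)
+        //   Q:   [2, concat_heads, S, 64] --unsqueeze+multibroadcast--> [2, concat_heads, 4, S, 64]
+        // Every group reads the same Q but only its 1/G slice of the KV cache,
+        // which is what lets kernel 1 run the G partial softmaxes independently.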
std::cout << "K transformed shape: "; + for(auto d : k_transformed_shape) std::cout << d << " "; + std::cout << std::endl; + + std::cout << "V transformed shape: "; + for(auto d : v_transformed_shape) std::cout << d << " "; + std::cout << std::endl; + + // insert reshape operations before the attention group + // [B, concat_heads, seq, head_dim] -> [B, concat_heads, 1, seq, head_dim] -> [B, concat_heads, G, seq, head_dim] + auto q_unsqueezed = mm.insert_instruction( + attn_group_ins, + make_op("unsqueeze", {{"axes", {2}}}), + q); + + auto q_reshaped = mm.insert_instruction( + attn_group_ins, + make_op("multibroadcast", {{"out_lens", q_transformed_shape}}), + q_unsqueezed); + + // K: reshape to split sequence dimension + auto k_reshaped = mm.insert_instruction( + attn_group_ins, + make_op("reshape", {{"dims", k_transformed_shape}}), + k); + + // V: reshape to split sequence dimension + auto v_reshaped = mm.insert_instruction( + attn_group_ins, + make_op("reshape", {{"dims", v_transformed_shape}}), + v); + + std::cout << "Q reshaped: " << q_reshaped->get_shape() << std::endl; + std::cout << "K reshaped: " << k_reshaped->get_shape() << std::endl; + std::cout << "V reshaped: " << v_reshaped->get_shape() << std::endl; + + // No need to handle positions outside the submodule + // We'll adjust broadcast patterns inside for early exit masking + + // TODO can probably do this simpler + // Create new input list by replacing Q, K, V with reshaped versions + std::vector new_group_inputs = group_inputs; + for(size_t i = 0; i < group_inputs.size(); ++i) { + if(group_inputs[i] == q) { + new_group_inputs[i] = q_reshaped; + } else if(group_inputs[i] == k) { + new_group_inputs[i] = k_reshaped; + } else if(group_inputs[i] == v) { + new_group_inputs[i] = v_reshaped; + } + // Other inputs (like seq_len) stay the same + } + + // Create new flash decoding submodule + module m_flash_decode; + m_flash_decode.set_bypass(); + + // Get parameter names from original submodule + auto get_param_name = [](instruction_ref param) -> std::string { + assert(param->name() == "@param"); + return param->get_operator().to_value()["parameter"].to(); + }; + + auto q_name = get_param_name(q_param); + auto k_name = get_param_name(k_param); + auto v_name = get_param_name(v_param); + + // Add parameters to new submodule with transformed shapes + auto new_q_param = m_flash_decode.add_parameter( + q_name, shape{q_shape.type(), q_transformed_shape}); + auto new_k_param = m_flash_decode.add_parameter( + k_name, shape{k_shape_main.type(), k_transformed_shape}); + auto new_v_param = m_flash_decode.add_parameter( + v_name, shape{v_shape_main.type(), v_transformed_shape}); + + // Build mapping from old params to new params + std::unordered_map map_old_params_to_new; + map_old_params_to_new[q_param] = new_q_param; + map_old_params_to_new[k_param] = new_k_param; + map_old_params_to_new[v_param] = new_v_param; + + // Add other parameters (like seq_len) that don't change shape + for(auto param : iterator_for(*submod)) { + if(param->name() == "@param") { + if(param != q_param && param != k_param && param != v_param) { + auto param_name = get_param_name(param); + auto param_shape = param->get_shape(); + auto new_param = m_flash_decode.add_parameter(param_name, param_shape); + map_old_params_to_new[param] = new_param; + std::cout << "Added unchanged param: " << param_name + << " with shape: " << param_shape << std::endl; + } + } + } + + // TODO all the param stuff before this can be simplified + + // rebuild the attention operations in the flash 
decode submodule + std::cout << "Rebuilding GQA attention operations..." << std::endl; + rebuild_gqa_attention(m_flash_decode, *submod, map_old_params_to_new, + q_param, k_param, v_param, dims, groups); + + // create the module in the module pass manager + auto original_submod_name = attn_group_ins->module_inputs().front()->name(); + std::string new_mod_name = original_submod_name + "_gqa_flash_decoding"; + + module_ref mpm_flash_mod = mpm.create_module(new_mod_name, std::move(m_flash_decode)); + mpm_flash_mod->set_bypass(); + + // insert the new group operation + auto new_group_ins = mm.insert_instruction( + attn_group_ins, + make_op("group", {{"tag", "attention"}}), + new_group_inputs, + {mpm_flash_mod}); + + std::cout << "Created GQA flash decoding group" << std::endl; + std::cout << "Group output shape: " << new_group_ins->get_shape() << std::endl; + + // unpack O' and LSE + auto partial_output_o_prime = mm.insert_instruction( + attn_group_ins, make_op("get_tuple_elem", {{"index", 0}}), new_group_ins); + auto lse = mm.insert_instruction( + attn_group_ins, make_op("get_tuple_elem", {{"index", 1}}), new_group_ins); + + // LSE-weighted combination + std::cout << "\n=== Kernel 2: LSE-weighted combination ===" << std::endl; + std::cout << "Input LSE shape: " << lse->get_shape() << std::endl; // [B, N, G, S, 1] = [2, 2, 2, 1, 1] + std::cout << "Input O' shape: " << partial_output_o_prime->get_shape() << std::endl; // [B, G, S, N*D] = [2, 2, 1, 4] + + // align LSE with O' for proper weighting + // LSE is [B, N, G, S, 1], match group dimension of O' [B, G, S, N*D] + + // [B, N, G, S, 1] -> [B, G, N, S, 1] + std::cout << "\n1. Transposing LSE to align group dimension..." << std::endl; + auto lse_transposed = mm.insert_instruction( + attn_group_ins, make_op("transpose", {{"permutation", {0, 2, 1, 3, 4}}}), lse); + std::cout << " LSE transposed shape: " << lse_transposed->get_shape() << std::endl; // [2, 2, 2, 1, 1] + + // average across heads (N) since all heads in a group share the same weight + // [B, G, N, S, 1] -> [B, G, S, 1] (reduce over axis 2, then squeeze) + std::cout << "\n2. Averaging LSE across heads within each group..." << std::endl; + auto lse_avg = mm.insert_instruction( + attn_group_ins, make_op("reduce_mean", {{"axes", {2}}}), lse_transposed); + std::cout << " LSE averaged shape: " << lse_avg->get_shape() << std::endl; // [2, 2, 1, 1, 1] + + // squeeze axes 2 and 4: [B, G, 1, S, 1] -> [B, G, S] + auto lse_squeezed = mm.insert_instruction( + attn_group_ins, make_op("squeeze", {{"axes", {2, 4}}}), lse_avg); + std::cout << " LSE squeezed shape: " << lse_squeezed->get_shape() << std::endl; // [2, 2, 1] + + // softmax across groups for LSE weights + std::cout << "\n3. Computing softmax of LSE across groups..." 
<< std::endl; + + // find max across groups for numerical stability + auto lse_max = mm.insert_instruction( + attn_group_ins, make_op("reduce_max", {{"axes", {1}}}), lse_squeezed); + std::cout << " Max LSE shape: " << lse_max->get_shape() << std::endl; // [2, 1, 1] + + // broadcast max back to original shape + auto lse_max_bcast = mm.insert_instruction( + attn_group_ins, + make_op("multibroadcast", {{"out_lens", lse_squeezed->get_shape().lens()}}), + lse_max); + + // exp(LSE - max_LSE) + auto lse_sub = mm.insert_instruction(attn_group_ins, make_op("sub"), lse_squeezed, lse_max_bcast); + auto lse_exp = mm.insert_instruction(attn_group_ins, make_op("exp"), lse_sub); + std::cout << " Exp(LSE) shape: " << lse_exp->get_shape() << std::endl; // [2, 2, 1] + + // sum exp across groups + auto lse_sum = mm.insert_instruction( + attn_group_ins, make_op("reduce_sum", {{"axes", {1}}}), lse_exp); + std::cout << " Sum exp shape: " << lse_sum->get_shape() << std::endl; // [2, 1, 1] + + // broadcast sum back + auto lse_sum_bcast = mm.insert_instruction( + attn_group_ins, + make_op("multibroadcast", {{"out_lens", lse_exp->get_shape().lens()}}), + lse_sum); + + auto weights = mm.insert_instruction(attn_group_ins, make_op("div"), lse_exp, lse_sum_bcast); + std::cout << " Softmax weights shape: " << weights->get_shape() << std::endl; // [2, 2, 1] + + // weights is [B, G, S], O' is [B, G, S, N*D] + // [B, G, S] -> [B, G, S, 1] + std::cout << "\n4. Preparing weights for multiplication with O'..." << std::endl; + auto weights_unsqueezed = mm.insert_instruction( + attn_group_ins, make_op("unsqueeze", {{"axes", {3}}}), weights); + std::cout << " Weights unsqueezed shape: " << weights_unsqueezed->get_shape() << std::endl; // [2, 2, 1, 1] + + // broadcast to match O' shape + auto weights_bcast = mm.insert_instruction( + attn_group_ins, + make_op("multibroadcast", {{"out_lens", partial_output_o_prime->get_shape().lens()}}), + weights_unsqueezed); + std::cout << " Weights broadcast shape: " << weights_bcast->get_shape() << std::endl; // [2, 2, 1, 4] + + // convert weights to match O' type + auto output_type = partial_output_o_prime->get_shape().type(); + auto weights_converted = mm.insert_instruction( + attn_group_ins, make_op("convert", {{"target_type", output_type}}), weights_bcast); + + // multiply O' by weights + std::cout << "\n5. Multiplying O' by softmax weights..." << std::endl; + auto weighted_output = mm.insert_instruction( + attn_group_ins, make_op("mul"), partial_output_o_prime, weights_converted); + std::cout << " Weighted output shape: " << weighted_output->get_shape() << std::endl; // [2, 2, 1, 4] + + // sum across groups to get final output + std::cout << "\n6. Summing weighted outputs across groups..." 
<< std::endl;
+        auto final_output = mm.insert_instruction(
+            attn_group_ins, make_op("reduce_sum", {{"axes", {1}}}), weighted_output);
+        std::cout << "  Final output shape: " << final_output->get_shape() << std::endl; // [2, 1, 1, 4]
+
+        // squeeze the reduced group dimension
+        auto final_squeezed = mm.insert_instruction(
+            attn_group_ins, make_op("squeeze", {{"axes", {1}}}), final_output);
+        std::cout << "  Final squeezed shape: " << final_squeezed->get_shape() << std::endl; // [2, 1, 4]
+
+        mm.replace_instruction(attn_group_ins, final_squeezed);
+
+        std::cout << "\n=== Kernel 2 complete: LSE-weighted combination successful ===" << std::endl;
+    }
+};
+
 struct find_flash_decoding
 {
     // configuration from fuse_attention pass config
@@ -947,12 +1917,17 @@ void fuse_attention::apply(module_pass_manager& mpm) const
     std::size_t configured_splits = get_num_splits(flash_decoding_num_splits);
     if(configured_splits > 0 or flash_decoding_enabled)
     {
+        // flash decoding for regular attention, single & multi-headed
         match::find_matches(
             mpm,
             find_flash_decoding{.configured_splits = flash_decoding_num_splits,
                                 .configured_threshold = flash_decoding_threshold,
                                 .configured_max_splits = flash_decoding_max_splits,
                                 .configured_min_chunk_size = flash_decoding_min_chunk_size});
+
+        // flash decoding for GQA attention
+        match::find_matches(
+            mpm, find_gqa_flash_decoding{.groups = configured_splits});
         mpm.run_pass(dead_code_elimination{});
     }
 }

From 250d0dba1ffde2d03611ca9d4e324ada4addb6ea Mon Sep 17 00:00:00 2001
From: Breanna Devore-McDonald
Date: Sat, 3 Jan 2026 22:24:43 -0600
Subject: [PATCH 2/4] update rocmlir commit

---
 requirements.txt | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/requirements.txt b/requirements.txt
index ced32b9b34b..bcc7a984bac 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -29,4 +29,4 @@ pybind/pybind11@3e9dfa2866941655c56877882565e7577de6fc7b --build
 msgpack/msgpack-c@cpp-3.3.0 -DMSGPACK_BUILD_TESTS=Off -DMSGPACK_BUILD_EXAMPLES=Off -DCMAKE_POLICY_VERSION_MINIMUM=3.5
 sqlite3@3.50.4 -DCMAKE_POSITION_INDEPENDENT_CODE=On
 ROCm/composable_kernel@b7775add2d28251674d81e220cd4a857b90b997a -DCK_BUILD_JIT_LIB=On -DCMAKE_POSITION_INDEPENDENT_CODE=On
-ROCm/rocMLIR@16ccb523bb3d8af67796c12ab9c15c9b69d69b58 -DBUILD_FAT_LIBROCKCOMPILER=On -DLLVM_INCLUDE_TESTS=Off
+ROCm/rocMLIR@5878fd6a5f4ec92101f9cd3d279e6240086cb3b4 -DBUILD_FAT_LIBROCKCOMPILER=On -DLLVM_INCLUDE_TESTS=Off

From f0061e711afcdc805f18fa47b12dd2624cc9114b Mon Sep 17 00:00:00 2001
From: Breanna Devore-McDonald
Date: Wed, 14 Jan 2026 13:00:26 -0600
Subject: [PATCH 3/4] AIMIGRAPHX-466; still need to consolidate
 broadcast/multibroadcast transformations

---
 src/fuse_attention.cpp | 1151 ++++++++++++++++------------------------
 1 file changed, 461 insertions(+), 690 deletions(-)

diff --git a/src/fuse_attention.cpp b/src/fuse_attention.cpp
index b086bbf25c3..02261336a3b 100644
--- a/src/fuse_attention.cpp
+++ b/src/fuse_attention.cpp
@@ -29,6 +29,7 @@
 #include
 #include
 #include
+#include
 #include
 #include
 #include
@@ -137,6 +138,34 @@ inline auto pointwise_inputs()
 };
 }

+// Helper function to extract the two gemm operations from an attention submodule
+// Returns {gemm1, gemm2} where gemm1 is Q@K and gemm2 is P@V
+inline std::pair<instruction_ref, instruction_ref> get_attention_gemms(module_ref submod)
+{
+    std::vector<instruction_ref> gemms;
+    for(auto it = submod->begin(); it != submod->end(); ++it)
+    {
+        if(it->name() == "dot")
+            gemms.push_back(it);
+    }
+    assert(gemms.size() == 2 and "Expected exactly 2 gemm operations in attention submodule");
+
+    // gemms[0] is Q@K,
gemms[1] is P@V + // gemms are in order since we iterate from begin to end + return {gemms[0], gemms[1]}; +} + +// Helper function to map submodule parameters to main module inputs +inline std::unordered_map +map_submod_params_to_inputs(module_ref submod, const std::vector& group_inputs) +{ + auto map_param_to_main = submod->get_ins_param_map(group_inputs, true); + // verify the mapping is correct + auto expected_inputs = submod->get_inputs(map_param_to_main); + assert(expected_inputs == group_inputs and "Mapped inputs don't match group inputs"); + return map_param_to_main; +} + struct find_attention { std::size_t* counter; @@ -309,48 +338,84 @@ struct find_gqa_flash_decoding { auto q_shape = q_param->get_shape(); auto k_shape = k_param->get_shape(); - + batch_size = q_shape.lens()[0]; concat_heads = q_shape.lens()[1]; sequence_length = q_shape.lens()[2]; head_dim = q_shape.lens()[3]; - + kv_heads = k_shape.lens()[1]; max_seq_length = k_shape.lens()[2]; - + // calculate Q heads from concat_heads = num_heads + 2 * kv_heads num_heads = concat_heads - 2 * kv_heads; - + // calculate sequence length per group if(max_seq_length % num_groups != 0) { std::cout << "Max sequence length " << max_seq_length << " not divisible by " << num_groups << " groups" << std::endl; // TODO: Add padding support - seq_length_per_group = 0; // Set to 0 to indicate error - return; + seq_length_per_group = 0; + } else { + seq_length_per_group = max_seq_length / num_groups; } - seq_length_per_group = max_seq_length / num_groups; } }; - auto matcher() const + // Helper function: adjust transpose permutation when inserting a group dimension + // insert_pos_from_end: position where group dimension is inserted in ORIGINAL tensor + // (negative values count from end, e.g., -2 means second-to-last) + // original_perm: the original permutation vector + // Returns: adjusted permutation vector + std::vector adjust_transpose_perm(const std::vector& original_perm, + int insert_pos_from_end) const { - return match::name("group")(match::has_op_value("tag", "kv_cache_attention")).bind("group"); + // Convert negative position to actual position in the ORIGINAL tensor + // For a 4D tensor, insert_pos_from_end=-2 means position 2 (0-indexed) + size_t original_rank = original_perm.size(); + int actual_insert_pos = insert_pos_from_end; + if(insert_pos_from_end < 0) { + actual_insert_pos = static_cast(original_rank) + insert_pos_from_end; + } + + // Adjust permutation values: any dimension >= insert_pos shifts up by 1 + std::vector new_perm; + for(auto idx : original_perm) { + if(idx >= actual_insert_pos) { + new_perm.push_back(idx + 1); + } else { + new_perm.push_back(idx); + } + } + + // Insert the group dimension at its natural position in the output + // The group dimension itself appears at position actual_insert_pos + new_perm.insert(new_perm.begin() + actual_insert_pos, actual_insert_pos); + + return new_perm; } - - std::pair get_gemms(module_ref submod) const + + // Helper function: adjust axes when a group dimension is inserted + // axes: original axes vector + // group_dim_pos: position where group dimension is inserted (in ORIGINAL tensor coordinates) + // Returns: adjusted axes + std::vector adjust_axes(const std::vector& axes, int group_dim_pos) const { - std::vector gemms; - for(auto it = submod->begin(); it != submod->end(); ++it) - { - if(it->name() == "dot") - gemms.push_back(it); + std::vector adjusted; + for(auto axis : axes) { + // If axis >= group_dim_pos, shift it by 1 + if(axis >= group_dim_pos) { + 
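+            // e.g. with group_dim_pos = 2, the softmax reduction axes {3} on the 4D
+            // scores [B, N, S, M] become {4} on the grouped 5D scores [B, N, G, S, M/G],
+            // while axes 0 and 1 pass through unchanged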
adjusted.push_back(axis + 1); + } else { + adjusted.push_back(axis); + } } - assert(gemms.size() == 2 and "Expected exactly 2 gemm operations in attention submodule"); + return adjusted; + } - // gemms[0] is Q@K, gemms[1] is P@V - // gemms are in order since we iterate from begin to end - return {gemms[0], gemms[1]}; + auto matcher() const + { + return match::name("group")(match::has_op_value("tag", "kv_cache_attention")).bind("group"); } // Helper to extract Q, K, V parameters from the attention submodule's gemm inputs @@ -358,538 +423,383 @@ struct find_gqa_flash_decoding instruction_ref q_param; // Parameter containing Q (full QKV tensor) instruction_ref k_param; // Parameter for K (concat_past_present output) instruction_ref v_param; // Parameter for V (concat_past_present output) - }; - - std::unordered_map - map_submod_params_to_inputs(module_ref submod, - const std::vector& group_inputs) const - { - auto map_param_to_main = submod->get_ins_param_map(group_inputs, true); - // verify the mapping is correct - auto expected_inputs = submod->get_inputs(map_param_to_main); - assert(expected_inputs == group_inputs and "Mapped inputs don't match group inputs"); - return map_param_to_main; - } - // rebuild GQA attention operations in flash decoding submodule - // Helper to find early exit masking operations - struct early_exit_mask_ops { - instruction_ref pos_literal; // Literal with position indices {0,1,2,3...} - instruction_ref pos_broadcast; // Broadcast of position literal - instruction_ref seq_len_param; // Sequence length parameter - instruction_ref seq_multicast; // Multibroadcast of seq_len - instruction_ref greater_op; // Greater comparison - instruction_ref convert_op; // Convert to bool - instruction_ref unsqueeze_op; // Unsqueeze mask - instruction_ref mask_broadcast; // Final multibroadcast of mask - instruction_ref ninf_literal; // -inf literal for masking - instruction_ref ninf_broadcast; // Multibroadcast of -inf - instruction_ref where_op; // Where operation applying mask - - // Flags to track which operations were found - bool found = false; - bool has_pos_literal = false; - bool has_pos_broadcast = false; - bool has_seq_len_param = false; - bool has_seq_multicast = false; - bool has_greater = false; - bool has_convert = false; - bool has_unsqueeze = false; - bool has_mask_broadcast = false; - bool has_ninf_literal = false; - bool has_ninf_broadcast = false; - }; - - early_exit_mask_ops find_early_exit_masking_ops( - const module& source_mod, - instruction_ref scaled_scores, - const std::unordered_map& map_old_to_new) const - { - early_exit_mask_ops mask_ops; - - // Find the where operation that uses our scaled scores - for(auto ins : iterator_for(source_mod)) { - if(ins->name() == "where") { - // Check if one of its inputs is our scaled scores (through the mapping) - for(auto input : ins->inputs()) { - if(contains(map_old_to_new, input) && map_old_to_new.at(input) == scaled_scores) { - mask_ops.where_op = ins; - mask_ops.found = true; - break; + // factory method to extract Q, K, V parameters from gemm operations + static std::optional from_gemms(instruction_ref gemm1, instruction_ref gemm2) + { + /* + - gemm1: Q@K + - gemm2: P@V + */ + auto trace_back_to_param = [](instruction_ref ins) -> std::optional { + instruction_ref current = ins; + while(current->name() != "@param") { + if(current->inputs().empty()) { + return std::nullopt; } + current = current->inputs()[0]; } - if(mask_ops.found) break; - } - } - - if(!mask_ops.found) { - return mask_ops; - } - - // Get the 
three inputs to where: mask, true_value (-inf), false_value (scores) - auto mask_input = mask_ops.where_op->inputs()[0]; - mask_ops.ninf_broadcast = mask_ops.where_op->inputs()[1]; - - // Trace back the mask to find multibroadcast -> unsqueeze -> convert -> greater - instruction_ref current = mask_input; - - // Should be multibroadcast - if(current->name() == "multibroadcast") { - mask_ops.mask_broadcast = current; - mask_ops.has_mask_broadcast = true; - current = current->inputs()[0]; - } - - // Should be unsqueeze - if(current->name() == "unsqueeze") { - mask_ops.unsqueeze_op = current; - mask_ops.has_unsqueeze = true; - current = current->inputs()[0]; - } - - // Should be convert - if(current->name() == "convert") { - mask_ops.convert_op = current; - mask_ops.has_convert = true; - current = current->inputs()[0]; - } - - // Should be greater - if(current->name() == "greater") { - mask_ops.greater_op = current; - mask_ops.has_greater = true; - - // Get inputs to greater - auto pos_input = mask_ops.greater_op->inputs()[0]; - auto seq_input = mask_ops.greater_op->inputs()[1]; - - // Position side: broadcast -> literal - if(pos_input->name() == "broadcast") { - mask_ops.pos_broadcast = pos_input; - mask_ops.has_pos_broadcast = true; - mask_ops.pos_literal = pos_input->inputs()[0]; - mask_ops.has_pos_literal = true; - } - - // Sequence length side: multibroadcast -> param - if(seq_input->name() == "multibroadcast") { - mask_ops.seq_multicast = seq_input; - mask_ops.has_seq_multicast = true; - mask_ops.seq_len_param = seq_input->inputs()[0]; - mask_ops.has_seq_len_param = true; - } - } - - // Find the -inf literal source - if(mask_ops.ninf_broadcast->name() == "multibroadcast") { - mask_ops.has_ninf_broadcast = true; - mask_ops.ninf_literal = mask_ops.ninf_broadcast->inputs()[0]; - mask_ops.has_ninf_literal = true; + return current; + }; + + auto q_input = gemm1->inputs()[0]; + auto k_input = gemm1->inputs()[1]; + auto v_input = gemm2->inputs()[1]; + + // trace back Q, K, V to find the parameters they originate from + auto q_param_opt = trace_back_to_param(q_input); + auto k_param_opt = trace_back_to_param(k_input); + auto v_param_opt = trace_back_to_param(v_input); + + if(not q_param_opt or not k_param_opt or not v_param_opt) return std::nullopt; + return qkv_params{*q_param_opt, *k_param_opt, *v_param_opt}; } - - return mask_ops; - } + }; void rebuild_gqa_attention(module& target_mod, const module& source_mod, const std::unordered_map& param_map, - instruction_ref q_param, - instruction_ref k_param, - instruction_ref v_param, + instruction_ref gemm2, const attention_dims& dims, std::size_t num_groups) const { - // map from instructions in old module to new module - std::unordered_map map_old_to_new = param_map; - // TODO can do this better, and also make it better for other flash decoding case - // track softmax components for LSE calculation - std::unordered_map softmax_parts; - - assert(contains(param_map, q_param) && "Q parameter must be mapped"); - assert(contains(param_map, k_param) && "K parameter must be mapped"); - assert(contains(param_map, v_param) && "V parameter must be mapped"); - (void)v_param; // Will be used later for V operations + std::cout << "Rebuilding GQA attention with inserter..." 
<< std::endl; + std::cout << "Second gemm (will stop after): " << gemm2->name() << std::endl; - // handle Q extraction - // since we slice on axis 1 (concat_heads) and groups are at axis 2, no change needed - for(auto ins : iterator_for(source_mod)) { - if(ins->name() == "slice" && ins->inputs()[0] == q_param) { - auto op = ins->get_operator(); - auto new_q = map_old_to_new.at(q_param); - auto sliced_q = target_mod.add_instruction(op, new_q); - map_old_to_new[ins] = sliced_q; - std::cout << " Q slice created, shape: " << sliced_q->get_shape() << std::endl; - break; - } - } + // Calculate the group dimension position in the original tensor (4D) + // Group is inserted at position -2, which is index 2 in a 4D tensor + int group_dim_pos = 2; - // handle K transpose - instruction_ref transposed_k; - for(auto ins : iterator_for(source_mod)) { - if(ins->name() == "transpose") { - auto transpose_input = ins->inputs()[0]; - if(transpose_input == k_param) { - auto op = ins->get_operator(); - auto perm = op.to_value()["permutation"].to_vector(); - - // dims.batch_size, dims.kv_heads, groups, dims.seq_length_per_group, dims.head_dim} - // perm is now [0, 1, 2, 4, 3] for [B, H, G, D, S] - std::vector new_perm = {0, 1, 2, 4, 3}; - auto new_transpose_op = make_op("transpose", {{"permutation", new_perm}}); - auto new_k = map_old_to_new.at(k_param); - transposed_k = target_mod.add_instruction(new_transpose_op, new_k); - map_old_to_new[ins] = transposed_k; - - break; - } - } - } - - // ninf is of shape - // {batch_size, num_heads, sequence_length, max_seq_len} - - - // handle literal constants and their broadcasts - for(auto ins : iterator_for(source_mod)) { - if(ins->name() == "@literal") { - // copy literals directly - auto lit_val = ins->get_literal(); - auto new_lit = target_mod.add_literal(lit_val); - map_old_to_new[ins] = new_lit; - std::cout << " Added literal with shape: " << new_lit->get_shape() << std::endl; - } - } - - // TODO handle when kv_heads != num_heads - // define expected broadcast shapes for literals - std::vector bnsm{dims.batch_size, dims.num_heads, dims.sequence_length, dims.max_seq_length}; + // Define the BNGSM shape for broadcasts (needed in the inserter) std::vector bngsm{dims.batch_size, dims.num_heads, num_groups, dims.sequence_length, dims.seq_length_per_group}; - - // update multibroadcast shapes for literals - // if broadcasting to [B, N, S, M] shape, change to [B, N, G, S, M/G] - // Keep track of specific broadcasts for -inf and scale - instruction_ref ninf_broadcast; - instruction_ref scale_broadcast; + std::vector bnsm{dims.batch_size, dims.num_heads, dims.sequence_length, dims.max_seq_length}; - for(auto ins : iterator_for(source_mod)) { - if(ins->name() == "multibroadcast" && - !contains(map_old_to_new, ins)) { - auto input = ins->inputs()[0]; - - // check if input is a literal and is already mapped - if(contains(map_old_to_new, input) && input->name() == "@literal") { - auto op = ins->get_operator(); - auto out_lens = op.to_value()["out_lens"].to_vector(); - - // check if the shape matches [B, N, S, M] pattern for attention scores - if(out_lens == bnsm) { - // use the pre-defined bngsm shape - auto new_op = make_op("multibroadcast", {{"out_lens", bngsm}}); - auto new_input = map_old_to_new.at(input); - auto new_broadcast = target_mod.add_instruction(new_op, new_input); - map_old_to_new[ins] = new_broadcast; - - // Check the literal value to identify -inf or scale - auto lit = input->get_literal(); - if(lit.get_shape().type() == migraphx::shape::half_type) { - // Get 
the literal value as a string for comparison - auto lit_str = lit.to_string(); - if(lit_str == "-inf") { - ninf_broadcast = new_broadcast; - std::cout << "adjusted -inf multibroadcast from BNSM to BNGSM: " - << new_broadcast->get_shape() << std::endl; - } else { - // Assume it's the scale value - scale_broadcast = new_broadcast; - std::cout << "adjusted scale multibroadcast from BNSM to BNGSM: " - << new_broadcast->get_shape() << " (value: " << lit_str << ")" << std::endl; - } + // Track reduce operations for LSE calculation + instruction_ref reduce_max_ref; + instruction_ref reduce_sum_ref; + bool found_reduce_max = false; + bool found_reduce_sum = false; + + // Will get the second dot result after copying + instruction_ref second_dot_result; + + // Create the inserter function that transforms operations + auto inserter = [&](module& m, + instruction_ref ins, + const operation& op, + const std::vector& inputs, + const std::vector& mod_args) -> instruction_ref { + + auto op_name = op.name(); + + // Helper to print shape + auto print_shape = [](const std::vector& lens) { + std::cout << "{"; + for(size_t i = 0; i < lens.size(); ++i) { + std::cout << lens[i]; + if(i < lens.size() - 1) std::cout << ","; + } + std::cout << "}"; + }; + + auto print_output = [&print_shape](instruction_ref result) { + std::cout << " Output shape: "; + print_shape(result->get_shape().lens()); + std::cout << std::endl; + }; + + // Helper to print operation attributes + auto print_op_attrs = [](const operation& o) { + try { + auto val = o.to_value(); + if(!val.empty()) { + std::stringstream ss; + ss << val; + auto str = ss.str(); + if(!str.empty()) { + std::cout << " " << str; } - } else { - // for other shapes, just copy the multibroadcast as-is - auto new_input = map_old_to_new.at(input); - auto new_broadcast = target_mod.add_instruction(op, new_input); - map_old_to_new[ins] = new_broadcast; } + } catch(...) {} + }; + + // Debug: print operation being processed + std::cout << "\n>>> Processing op: " << op_name; + print_op_attrs(op); + std::cout << std::endl; + std::cout << " Input shapes: "; + for(size_t i = 0; i < inputs.size(); ++i) { + print_shape(inputs[i]->get_shape().lens()); + if(i < inputs.size() - 1) std::cout << ", "; + } + std::cout << std::endl; + + // Transpose: adjust permutation + if(op_name == "transpose") { + auto perm = op.to_value()["permutation"].to_vector(); + auto new_perm = adjust_transpose_perm(perm, -2); + + std::cout << " Adjusted perm: ["; + for(size_t i = 0; i < perm.size(); ++i) + std::cout << perm[i] << (i < perm.size()-1 ? "," : ""); + std::cout << "] -> ["; + for(size_t i = 0; i < new_perm.size(); ++i) + std::cout << new_perm[i] << (i < new_perm.size()-1 ? 
"," : ""); + std::cout << "]" << std::endl; + + auto new_op = make_op("transpose", {{"permutation", new_perm}}); + std::cout << " Creating: transpose"; + print_op_attrs(new_op); + std::cout << std::endl; + auto result = m.insert_instruction(ins, new_op, inputs, mod_args); + print_output(result); + return result; + } + + // Reduce operations: adjust axes + if(op_name == "reduce_max" || op_name == "reduce_sum") { + auto axes = op.to_value()["axes"].to_vector(); + auto new_axes = adjust_axes(axes, group_dim_pos); + + std::cout << " Adjusted axes: [" << axes[0] + << "] -> [" << new_axes[0] << "]" << std::endl; + + auto new_op = make_op(op_name, {{"axes", new_axes}}); + std::cout << " Creating: " << op_name; + print_op_attrs(new_op); + std::cout << std::endl; + auto result = m.insert_instruction(ins, new_op, inputs, mod_args); + + // Track these for LSE calculation + if(op_name == "reduce_max") { + reduce_max_ref = result; + found_reduce_max = true; + } + if(op_name == "reduce_sum") { + reduce_sum_ref = result; + found_reduce_sum = true; } + + print_output(result); + return result; } - } - - // Q slice [batch, num heads, groups, sl, max sl] - // check if we found the specific broadcasts - bool has_scale = false; - bool has_ninf = false; - for(auto ins : iterator_for(target_mod)) { - if(ins == scale_broadcast) has_scale = true; - if(ins == ninf_broadcast) has_ninf = true; - } - - if(has_scale) { - std::cout << " found scale broadcast for attention scaling" << std::endl; - } - if(has_ninf) { - std::cout << " found -inf broadcast for masking" << std::endl; - } - - // handle first dot product (Q @ K^T) - std::cout << "rebuilding first dot product (Q @ K^T)..." << std::endl; - instruction_ref dot1; - for(auto ins : iterator_for(source_mod)) { - if(ins->name() == "dot") { - // check if this is the first dot (Q @ K^T) - // it should have Q (or sliced Q) as first input and K transpose as second - auto input0 = ins->inputs()[0]; - auto input1 = ins->inputs()[1]; + + // Broadcast: if output is 2D {batch, max_seq_length}, add reshape to 3D + if(op_name == "broadcast") { + auto result = m.insert_instruction(ins, op, inputs, mod_args); + auto result_lens = result->get_shape().lens(); - // check if we've already mapped these inputs (Q slice and K transpose) - if(contains(map_old_to_new, input0) && contains(map_old_to_new, input1)) { - auto new_q = map_old_to_new.at(input0); - auto new_kt = map_old_to_new.at(input1); + // Check if result is {batch, max_seq_length} + if(result_lens.size() == 2 && + result_lens[0] == dims.batch_size && + result_lens[1] == dims.max_seq_length) { + std::cout << " Creating: broadcast (unchanged)"; + print_op_attrs(op); + std::cout << std::endl; + print_output(result); - // create the dot product with transformed inputs - dot1 = target_mod.add_instruction(make_op("dot"), new_q, new_kt); - map_old_to_new[ins] = dot1; + // Add reshape to split sequence dimension into groups + std::vector reshape_dims = {dims.batch_size, num_groups, dims.seq_length_per_group}; + std::cout << "\n>>> Auto-inserting reshape after 2D broadcast" << std::endl; + std::cout << " Reshape dims: "; + print_shape(reshape_dims); + std::cout << std::endl; - std::cout << " created dot1 (Q @ K^T) with shape: " << dot1->get_shape() << std::endl; - break; // assume first dot is Q @ K^T + auto reshape_result = m.insert_instruction(ins, make_op("reshape", {{"dims", reshape_dims}}), result); + std::cout << " Creating: reshape"; + auto reshape_op = make_op("reshape", {{"dims", reshape_dims}}); + print_op_attrs(reshape_op); 
-            }
-
-            // Q slice [batch, num heads, groups, sl, max sl]
-            // check if we found the specific broadcasts
-            bool has_scale = false;
-            bool has_ninf  = false;
-            for(auto ins : iterator_for(target_mod)) {
-                if(ins == scale_broadcast) has_scale = true;
-                if(ins == ninf_broadcast) has_ninf = true;
-            }
-
-            if(has_scale) {
-                std::cout << "  found scale broadcast for attention scaling" << std::endl;
-            }
-            if(has_ninf) {
-                std::cout << "  found -inf broadcast for masking" << std::endl;
-            }
-
-            // handle first dot product (Q @ K^T)
-            std::cout << "rebuilding first dot product (Q @ K^T)..." << std::endl;
-            instruction_ref dot1;
-            for(auto ins : iterator_for(source_mod)) {
-                if(ins->name() == "dot") {
-                    // check if this is the first dot (Q @ K^T)
-                    // it should have Q (or sliced Q) as first input and K transpose as second
-                    auto input0 = ins->inputs()[0];
-                    auto input1 = ins->inputs()[1];
+
+        // Broadcast: if output is 2D {batch, max_seq_length}, add reshape to 3D
+        if(op_name == "broadcast") {
+            auto result      = m.insert_instruction(ins, op, inputs, mod_args);
+            auto result_lens = result->get_shape().lens();
-                    // check if we've already mapped these inputs (Q slice and K transpose)
-                    if(contains(map_old_to_new, input0) && contains(map_old_to_new, input1)) {
-                        auto new_q  = map_old_to_new.at(input0);
-                        auto new_kt = map_old_to_new.at(input1);
+
+            // Check if result is {batch, max_seq_length}
+            if(result_lens.size() == 2 &&
+               result_lens[0] == dims.batch_size &&
+               result_lens[1] == dims.max_seq_length) {
+                std::cout << "  Creating: broadcast (unchanged)";
+                print_op_attrs(op);
+                std::cout << std::endl;
+                print_output(result);
-                        // create the dot product with transformed inputs
-                        dot1 = target_mod.add_instruction(make_op("dot"), new_q, new_kt);
-                        map_old_to_new[ins] = dot1;
+
+                // Add reshape to split sequence dimension into groups
+                std::vector<std::size_t> reshape_dims = {dims.batch_size, num_groups, dims.seq_length_per_group};
+                std::cout << "\n>>> Auto-inserting reshape after 2D broadcast" << std::endl;
+                std::cout << "  Reshape dims: ";
+                print_shape(reshape_dims);
+                std::cout << std::endl;
-                        std::cout << "  created dot1 (Q @ K^T) with shape: " << dot1->get_shape() << std::endl;
-                        break; // assume first dot is Q @ K^T
+
+                auto reshape_result = m.insert_instruction(ins, make_op("reshape", {{"dims", reshape_dims}}), result);
+                std::cout << "  Creating: reshape";
+                auto reshape_op = make_op("reshape", {{"dims", reshape_dims}});
+                print_op_attrs(reshape_op);
+                std::cout << std::endl;
+                print_output(reshape_result);
+                return reshape_result;
             }
+
+            std::cout << "  Creating: broadcast (unchanged)";
+            print_op_attrs(op);
+            std::cout << std::endl;
+            print_output(result);
+            return result;
         }
-        }
-
-        // handle scaling (mul with scale factor)
-        std::cout << "finding and rebuilding scale multiplication..." << std::endl;
-        instruction_ref scaled_scores;
-
-        // check if we have both dot1 and scale_broadcast
-        bool has_dot1     = false;
-        bool has_scale_bc = false;
-        for(auto ins : iterator_for(target_mod)) {
-            if(ins == dot1) has_dot1 = true;
-            if(ins == scale_broadcast) has_scale_bc = true;
-        }
-
-        if(has_dot1 && has_scale_bc) {
-            scaled_scores = target_mod.add_instruction(make_op("mul"), dot1, scale_broadcast);
-            std::cout << "  created scaled scores with shape: " << scaled_scores->get_shape() << std::endl;
-        } else if(has_dot1) {
-            // if we don't have scale_broadcast, try to find the mul in the original
-            for(auto ins : iterator_for(source_mod)) {
-                if(ins->name() == "mul") {
-                    auto input0 = ins->inputs()[0];
-                    auto input1 = ins->inputs()[1];
+
+        // Multibroadcast: adjust output shape if it matches BNSM pattern
+        if(op_name == "multibroadcast") {
+            auto out_lens = op.to_value()["out_lens"].to_vector<std::size_t>();
+
+            // Check if output is 2D {batch, max_seq_length} - needs reshape
+            if(out_lens.size() == 2 &&
+               out_lens[0] == dims.batch_size &&
+               out_lens[1] == dims.max_seq_length) {
+                std::cout << "  2D multibroadcast - will add reshape" << std::endl;
+                auto result = m.insert_instruction(ins, op, inputs, mod_args);
+                std::cout << "  Creating: multibroadcast (unchanged)";
+                print_op_attrs(op);
+                std::cout << std::endl;
+                print_output(result);
-                    if(contains(map_old_to_new, input0) && contains(map_old_to_new, input1)) {
-                        bool input0_is_dot = (map_old_to_new.at(input0)->name() == "dot");
-                        bool input1_is_dot = (map_old_to_new.at(input1)->name() == "dot");
-
-                        if(input0_is_dot || input1_is_dot) {
-                            auto new_input0 = map_old_to_new.at(input0);
-                            auto new_input1 = map_old_to_new.at(input1);
-
-                            scaled_scores = target_mod.add_instruction(make_op("mul"), new_input0, new_input1);
-                            map_old_to_new[ins] = scaled_scores;
-
-                            std::cout << "  created scaled scores with shape: " << scaled_scores->get_shape() << std::endl;
-                            break;
-                        }
-                    }
-                }
-            }
-        }
-
-        // For kv_cache_attention, rebuild early exit masking with modified broadcast shapes
-        std::cout << "rebuilding early exit masking with adjusted broadcasts..." << std::endl;
-
-        // Find the range literal (position indices like {0,1,2,3...})
-        instruction_ref range_literal;
-        bool found_range = false;
-        for(auto ins : iterator_for(source_mod)) {
-            if(ins->name() == "@literal") {
-                auto shape = ins->get_literal().get_shape();
-                if(shape.type() == migraphx::shape::int32_type &&
-                   shape.lens().size() == 1 &&
-                   shape.lens()[0] == dims.max_seq_length) {
-                    range_literal = ins;
-                    found_range   = true;
-                    std::cout << "  found range literal with shape: " << shape << std::endl;
-                    break;
+
+            // Add reshape to split sequence dimension into groups
+            std::vector<std::size_t> reshape_dims = {dims.batch_size, num_groups, dims.seq_length_per_group};
+            std::cout << "\n>>> Auto-inserting reshape after 2D multibroadcast" << std::endl;
+            std::cout << "  Reshape dims: ";
+            print_shape(reshape_dims);
+            std::cout << std::endl;
+
+            auto reshape_result = m.insert_instruction(ins, make_op("reshape", {{"dims", reshape_dims}}), result);
+            std::cout << "  Creating: reshape";
+            auto reshape_op = make_op("reshape", {{"dims", reshape_dims}});
+            print_op_attrs(reshape_op);
+            std::cout << std::endl;
+            print_output(reshape_result);
+            return reshape_result;
             }
-            }
-        }
-
-        // Find the past_sl parameter (past sequence length)
-        instruction_ref past_sl_param;
-        bool found_past_sl = false;
-        for(auto param : iterator_for(source_mod)) {
-            if(param->name() == "@param") {
-                auto shape = param->get_shape();
-                // past_sl is int32 type with batch_size elements (e.g., [2,1])
-                if(shape.type() == migraphx::shape::int32_type &&
-                   shape.elements() == dims.batch_size) {
-                    past_sl_param = param;
-                    found_past_sl = true;
-                    std::cout << "  found past_sl param with shape: " << shape << std::endl;
-                    break;
+
+            // Check if this is a 4D multibroadcast with max_seq_length at last dimension
+            // Insert group dimension at position -2 and split last dimension
+            if(out_lens.size() == 4 && out_lens[3] == dims.max_seq_length) {
+                std::cout << "  4D multibroadcast with max_seq_length - inserting group dim" << std::endl;
+
+                // Build new shape: insert group at -2, change last dim to seq_per_group
+                std::vector<std::size_t> new_lens = {out_lens[0], out_lens[1], num_groups, out_lens[2], dims.seq_length_per_group};
+
+                std::cout << "  Adjusted: ";
+                print_shape(out_lens);
+                std::cout << " -> ";
+                print_shape(new_lens);
+                std::cout << std::endl;
+
+                auto new_op = make_op("multibroadcast", {{"out_lens", new_lens}});
+                std::cout << "  Creating: multibroadcast";
+                print_op_attrs(new_op);
+                std::cout << std::endl;
+                auto result = m.insert_instruction(ins, new_op, inputs, mod_args);
+                print_output(result);
+                return result;
             }
         }
-        }
-
-        if(!found_range) {
-            std::cout << "  WARNING: Could not find range literal" << std::endl;
-        }
-        if(!found_past_sl) {
-            std::cout << "  WARNING: Could not find past_sl parameter" << std::endl;
-        }
-
-        // create broadcast shape vector
-        std::vector<std::size_t> broadcast_shape = {dims.batch_size, num_groups, dims.seq_length_per_group};
-        std::cout << "  broadcast shape: [" << dims.batch_size << ", " << num_groups << ", "
-                  << dims.seq_length_per_group << "]" << std::endl;
-
-        instruction_ref range_broadcast;
-        instruction_ref past_sl_reshaped;
-
-        if(found_range && found_past_sl && contains(param_map, past_sl_param)) {
-            // range literal broadcast to [batch_size, num_groups, seq_length_per_group]
-            if(!contains(map_old_to_new, range_literal)) {
-                auto lit_val = range_literal->get_literal();
-                auto new_lit = target_mod.add_literal(lit_val);
-                map_old_to_new[range_literal] = new_lit;
-            }
-
-            // Broadcast range literal directly, matching original pattern
-            // Use axis=1 to match the original pattern (broadcast[axis=1] adds batch dimension at front)
-            std::vector<std::size_t> intermediate_bc_shape = {dims.batch_size, dims.max_seq_length};
-            auto range_broadcast_intermediate = target_mod.add_instruction(
-                make_op("broadcast", {{"axis", 1}, {"out_lens", intermediate_bc_shape}}),
-                map_old_to_new.at(range_literal));
-            std::cout << "  broadcasted range to: " << range_broadcast_intermediate->get_shape() << std::endl;
-            // Then reshape to final shape [batch_size, num_groups, seq_length_per_group]
-            range_broadcast = target_mod.add_instruction(
-                make_op("reshape", {{"dims", broadcast_shape}}),
-                range_broadcast_intermediate);
-            std::cout << "  reshaped range to final shape: " << range_broadcast->get_shape() << std::endl;
-
-            // past_sl param from [batch_size, 1] to [batch_size, max_seq_length]
-            // then reshape to [batch_size, num_groups, seq_length_per_group]
-            auto past_sl_new = param_map.at(past_sl_param);
-
-            std::vector<std::size_t> intermediate_shape = {dims.batch_size, dims.max_seq_length};
-            auto past_sl_broadcast = target_mod.add_instruction(
-                make_op("multibroadcast", {{"out_lens", intermediate_shape}}),
-                past_sl_new);
-            std::cout << "  multibroadcasted past_sl to: " << past_sl_broadcast->get_shape() << std::endl;
-
-            past_sl_reshaped = target_mod.add_instruction(
-                make_op("reshape", {{"dims", broadcast_shape}}),
-                past_sl_broadcast);
-            std::cout << "  reshaped past_sl to: " << past_sl_reshaped->get_shape() << std::endl;
-
-            auto greater = target_mod.add_instruction(
-                make_op("greater"), range_broadcast, past_sl_reshaped);
-            auto convert = target_mod.add_instruction(
-                make_op("convert", {{"target_type", migraphx::shape::bool_type}}), greater);
-            auto unsqueeze = target_mod.add_instruction(
-                make_op("unsqueeze", {{"axes", {1, 3}}}), convert);
-            auto multibroadcast = target_mod.add_instruction(
-                make_op("multibroadcast", {{"out_lens", bngsm}}), unsqueeze);
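The removed block above built the early-exit mask by broadcasting a position-index range, comparing it against the past sequence length, and routing the result through where(mask, -inf, scores). A standalone scalar model of that predicate under the grouped [groups][seq_per_group] layout, with all names illustrative:

    // The grouped layout only reshapes [max_seq] into [groups][seq_per_group];
    // the per-element predicate "position > past_len => -inf" is unchanged.
    #include <cstdio>
    #include <limits>

    int main()
    {
        const int groups = 2, seq_per_group = 4; // max_seq_length = 8
        const int past_len = 5;                  // valid positions: 0..5
        const float ninf = -std::numeric_limits<float>::infinity();
        float scores[groups][seq_per_group] = {{0.1f, 0.2f, 0.3f, 0.4f},
                                               {0.5f, 0.6f, 0.7f, 0.8f}};
        for(int g = 0; g < groups; ++g)
            for(int s = 0; s < seq_per_group; ++s)
            {
                int pos = g * seq_per_group + s; // flat position in [0, max_seq)
                if(pos > past_len)               // mirrors greater(range, past_sl)
                    scores[g][s] = ninf;         // mirrors where(mask, -inf, scores)
            }
        std::printf("scores[1] = {%g, %g, %g, %g}\n",
                    scores[1][0], scores[1][1], scores[1][2], scores[1][3]);
        // prints: scores[1] = {0.5, 0.6, -inf, -inf}
    }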
"," : ""); + std::cout << "]" << std::endl; + + auto new_op = make_op("unsqueeze", {{"axes", new_axes}}); + std::cout << " Creating: unsqueeze"; + print_op_attrs(new_op); + std::cout << std::endl; + auto result = m.insert_instruction(ins, new_op, inputs, mod_args); + print_output(result); + return result; } - if(has_ninf_bc) { - auto where = target_mod.add_instruction(make_op("where"), multibroadcast, ninf_broadcast, scaled_scores); - scaled_scores = where; // Update scaled_scores to the masked version + // Reshape: need to adjust dims if they span the group dimension + // This is tricky - for now, check if it's the mask reshape + // TODO + if(op_name == "reshape") { + auto dims_vec = op.to_value()["dims"].to_vector(); - } else { - std::cout << " WARNING: Could not find ninf_broadcast for where operation" << std::endl; + // Check if this is a mask reshape that needs to split the sequence dimension + // e.g., {batch, max_seq_length} -> {batch, num_groups, seq_per_group} + if(dims_vec.size() == 2 && dims_vec[1] == dims.max_seq_length) { + std::vector new_dims = {dims.batch_size, num_groups, dims.seq_length_per_group}; + std::cout << " Adjusted dims: "; + print_shape(dims_vec); + std::cout << " -> "; + print_shape(new_dims); + std::cout << std::endl; + + auto new_op = make_op("reshape", {{"dims", new_dims}}); + std::cout << " Creating: reshape"; + print_op_attrs(new_op); + std::cout << std::endl; + auto result = m.insert_instruction(ins, new_op, inputs, mod_args); + print_output(result); + return result; + } } - } - - // quick implementation of remaining ops for testing - - // convert to float for softmax computation - auto convert_to_float = target_mod.add_instruction( - make_op("convert", {{"target_type", migraphx::shape::float_type}}), scaled_scores); - - // Reduce max along last axis (axis 4 in BNGSM) - auto reduce_max = target_mod.add_instruction( - make_op("reduce_max", {{"axes", {4}}}), convert_to_float); - - // Broadcast max back to original shape - auto max_broadcast = target_mod.add_instruction( - make_op("multibroadcast", {{"out_lens", bngsm}}), reduce_max); - - // Subtract max for numerical stability - auto sub = target_mod.add_instruction( - make_op("sub"), convert_to_float, max_broadcast); - - // Exp - auto exp_scores = target_mod.add_instruction( - make_op("exp"), sub); + + // Default: copy operation as-is + std::cout << " Creating: " << op_name << " (unchanged)"; + print_op_attrs(op); + std::cout << std::endl; + auto result = m.insert_instruction(ins, op, inputs, mod_args); + print_output(result); + return result; + }; - // Reduce sum along last axis - auto reduce_sum = target_mod.add_instruction( - make_op("reduce_sum", {{"axes", {4}}}), exp_scores); + // Find the instruction after gemm2 to use as the 'last' parameter + instruction_ref stop_point = std::next(gemm2); - // Broadcast sum back - auto sum_broadcast = target_mod.add_instruction( - make_op("multibroadcast", {{"out_lens", bngsm}}), reduce_sum); + std::cout << "Copying instructions from source module up to (not including) instruction after gemm2..." 
-        // Divide to get softmax
-        auto softmax = target_mod.add_instruction(
-            make_op("div"), exp_scores, sum_broadcast);
+    // Use add_instructions with range [begin, stop_point) to copy up to and including gemm2
+    std::unordered_map<instruction_ref, instruction_ref> map_old_to_new = param_map;
+    target_mod.add_instructions(source_mod.begin(), stop_point, &map_old_to_new, inserter);
-        // Convert back to half
-        auto convert_to_half = target_mod.add_instruction(
-            make_op("convert", {{"target_type", migraphx::shape::half_type}}), softmax);
-
-        std::cout << "now doing the dot between the convert and v";
-        std::cout << v_param->get_shape() << std::endl;
-
-        // We need to use the mapped V parameter, not the original v_param
-        // The v_param passed in is from the original submodule, we need the one in param_map
-        auto v_mapped = param_map.at(v_param);
-        std::cout << "V mapped shape: " << v_mapped->get_shape() << std::endl;
+    // Get the transformed gemm2
+    if(!contains(map_old_to_new, gemm2)) {
+        std::cout << "ERROR: gemm2 not found in map!" << std::endl;
+        return;
+    }
+    second_dot_result = map_old_to_new.at(gemm2);
-        // Dot with V
-        auto dot2 = target_mod.add_instruction(
-            make_op("dot"), convert_to_half, v_mapped);
-        std::cout << "Dot2 shape: " << dot2->get_shape() << std::endl;
+    std::cout << "\n=== Instructions copied and transformed ===" << std::endl;
-
+    std::cout << "Second dot shape: " << second_dot_result->get_shape() << std::endl;
-        // for flash decoding, we keep the group dimension and return it
-        // kernel 2 will handle the LSE-weighted reduction
-        // dot2 is currently [B, N, G, S, D]
+    std::cout << "\n=== Adding final transpose and reshape for flash decoding ===" << std::endl;
-        // transpose to [B, G, S, N, D] to match flash decoding output pattern
+    // Add correct flash decoding transpose: {B, N, G, S, D} -> {B, G, S, N, D}
+    std::cout << "Adding transpose with permutation {0, 2, 3, 1, 4}" << std::endl;
     auto transpose_out = target_mod.add_instruction(
-        make_op("transpose", {{"permutation", {0, 2, 3, 1, 4}}}), dot2);
-    std::cout << "Transpose shape: " << transpose_out->get_shape() << std::endl;
+        make_op("transpose", {{"permutation", {0, 2, 3, 1, 4}}}),
+        second_dot_result);
+    std::cout << "Transpose output shape: " << transpose_out->get_shape() << std::endl;
-    // reshape to [B, G, S, N*D]
+    // Add correct flash decoding reshape: {B, G, S, N, D} -> {B, G, S, N*D}
     std::vector<std::size_t> final_shape = {dims.batch_size, num_groups, dims.sequence_length, dims.num_heads * dims.head_dim};
+    std::cout << "Adding reshape with dims {";
+    for(size_t i = 0; i < final_shape.size(); ++i) {
+        std::cout << final_shape[i];
+        if(i < final_shape.size() - 1) std::cout << ",";
+    }
+    std::cout << "}" << std::endl;
     auto reshape_out = target_mod.add_instruction(
-        make_op("reshape", {{"dims", final_shape}}), transpose_out);
-    std::cout << "Final reshape shape (with groups): " << reshape_out->get_shape() << std::endl;
-
-    // for LSE (log-sum-exp), we need log(sum_exp) + max
-    // we already have reduce_max and reduce_sum from softmax computation
-    // LSE shape is [B, N, G, S, 1] which is correct for flash decoding
-    auto log_sum = target_mod.add_instruction(make_op("log"), reduce_sum);
-    auto lse = target_mod.add_instruction(make_op("add"), reduce_max, log_sum);
+        make_op("reshape", {{"dims", final_shape}}),
+        transpose_out);
+    std::cout << "Reshape output shape: " << reshape_out->get_shape() << std::endl;
+
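The LSE assembled just below is built from the reduce_max and reduce_sum instructions tracked by the inserter. A scalar sketch of the standard numerically stable softmax bookkeeping those two reductions implement, which is background math rather than patch code: with m = max(x) and s = sum(exp(x - m)), softmax(x)_i = exp(x_i - m)/s and logsumexp(x) = m + log(s).

    #include <algorithm>
    #include <cmath>
    #include <cstdio>
    #include <vector>

    int main()
    {
        std::vector<double> x = {1.0, 2.0, 4.0};
        double m = *std::max_element(x.begin(), x.end());
        double s = 0.0;
        for(double v : x) s += std::exp(v - m); // sum of shifted exponentials
        double lse = m + std::log(s);           // exact logsumexp, no overflow
        std::printf("lse = %f (direct: %f)\n", lse,
                    std::log(std::exp(1.0) + std::exp(2.0) + std::exp(4.0)));
    }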
+    // Calculate LSE (log-sum-exp) from the tracked reduce operations
+    // LSE = log(sum_exp) + max
+    if(found_reduce_max && found_reduce_sum) {
+        auto log_sum = target_mod.add_instruction(make_op("log"), reduce_sum_ref);
+        auto lse     = target_mod.add_instruction(make_op("add"), reduce_max_ref, log_sum);
         std::cout << "LSE shape: " << lse->get_shape() << std::endl;
         target_mod.add_return({reshape_out, lse});
+    } else {
+        std::cout << "WARNING: Could not find reduce_max or reduce_sum for LSE calculation" << std::endl;
+        target_mod.add_return({reshape_out});
+    }
 
     // print the complete submodule
     std::cout << "\n=== Complete GQA Flash Decoding Submodule ===" << std::endl;
@@ -897,99 +807,25 @@ struct find_gqa_flash_decoding
         std::cout << "=== End Submodule ===" << std::endl;
     }
 
-    std::optional<qkv_params> extract_qkv_params(instruction_ref gemm1, instruction_ref gemm2) const
-    {
-        qkv_params result;
-
-        // Q: gemm1's first input should be a slice from the QKV tensor
-        auto q_input = gemm1->inputs()[0];
-        if(q_input->name() == "slice") {
-            // trace back from slice to find the parameter
-            auto before_slice = q_input->inputs()[0];
-
-            instruction_ref current = before_slice;
-            while(current->name() != "@param") {
-                if(current->inputs().empty()) {
-                    std::cout << "Cannot trace Q back to parameter" << std::endl;
-                    return std::nullopt;
-                }
-                current = current->inputs()[0];
-            }
-            result.q_param = current;
-        } else {
-            std::cout << "Expected Q to come from slice, got: " << q_input->name() << std::endl;
-            return std::nullopt;
-        }
-
-        // K: gemm1's second input should be transposed K from concat_past_present
-        auto k_input = gemm1->inputs()[1];
-        if(k_input->name() == "transpose") {
-            result.k_param = k_input->inputs()[0];
-        } else {
-            result.k_param = k_input;
-        }
-
-        if(result.k_param->name() != "@param") {
-            std::cout << "Expected K to be a parameter, got: " << result.k_param->name() << std::endl;
-            return std::nullopt;
-        }
-
-        // V: gemm2's second input should be V from concat_past_present
-        result.v_param = gemm2->inputs()[1];
-        if(result.v_param->name() != "@param") {
-            std::cout << "Expected V to be a parameter, got: " << result.v_param->name() << std::endl;
-            return std::nullopt;
-        }
-
-        return result;
-    }
-
     void apply(module_pass_manager& mpm, const match::matcher_result& r) const
     {
         auto& mm            = mpm.get_module();
         auto attn_group_ins = r.instructions["group"];
         auto* submod        = attn_group_ins->module_inputs().front();
-        std::cout << "GQA flash decoding detected" << std::endl;
+        std::cout << "GQA flash decoding detected, here is the submodule: " << std::endl;
+        submod->debug_print();
 
-        // check multiple returns
-        auto return_ins = std::prev(submod->end());
-        assert(return_ins->name() == "@return" and
-               "Last instruction must be a @return instruction");
-        if(return_ins->inputs().size() > 1) {
-            std::cout << "KV cache attention unexpected multiple returns" << std::endl;
-            return;
-        }
-
-        // get gemm1 and gemm2
-        auto [gemm1, gemm2] = get_gemms(submod);
-
-        // Extract Q, K, V parameters from gemm inputs
-        auto qkv_opt = extract_qkv_params(gemm1, gemm2);
-        if(!qkv_opt) {
-            std::cout << "Failed to extract Q, K, V parameters" << std::endl;
-            return;
-        }
-
+        // extract Q, K, V parameters from gemm inputs
+        auto [gemm1, gemm2] = get_attention_gemms(submod);
+        auto qkv_opt        = qkv_params::from_gemms(gemm1, gemm2);
+        if(not qkv_opt) return;
         auto [q_param, k_param, v_param] = *qkv_opt;
-
-        std::cout << "Q attn module param shape: " << q_param->get_shape() << std::endl;
-        std::cout << "K attn module param shape: " << k_param->get_shape() << std::endl;
-        std::cout << "V attn module param shape: " << v_param->get_shape() << std::endl;
 
-        // derive dim values
+        // derive attention dims from Q, K, V parameters
         attention_dims dims(q_param, k_param, groups);
-
-        std::cout << "Max sequence length: " << dims.max_seq_length << std::endl;
-        if(groups <= 1) {
-            std::cout << "No splitting requested (groups=" << groups << ")" << std::endl;
-            return;
-        }
-
-        // check if dimensions were calculated successfully
-        if(dims.seq_length_per_group == 0) {
-            std::cout << "Failed to calculate sequence length per group, returning" << std::endl;
+        if(groups <= 1 or dims.seq_length_per_group == 0) {
             return;
         }
@@ -1001,50 +837,32 @@ struct find_gqa_flash_decoding
         auto q = map_param_to_main.at(q_param); // maps to the QKV tensor
         auto k = map_param_to_main.at(k_param); // maps to K concat_past_present output
         auto v = map_param_to_main.at(v_param); // maps to V concat_past_present output
-
-        std::cout << "Main module Q shape: " << q->get_shape() << std::endl;
-        std::cout << "Main module K shape: " << k->get_shape() << std::endl;
-        std::cout << "Main module V shape: " << v->get_shape() << std::endl;
 
         // GQA flash decoding:
-        // - Q (QKV tensor): broadcast across groups (no split)
+        // - Q (QKV tensor): add new group dim and broadcast
         // - K: split sequence dimension into groups
         // - V: split sequence dimension into groups
 
-        // shapes before group transformation
-        auto q_shape      = q->get_shape();
-        auto k_shape_main = k->get_shape();
-        auto v_shape_main = v->get_shape();
+        auto q_type = q->get_shape().type();
+        auto k_type = k->get_shape().type();
+        auto v_type = v->get_shape().type();
 
         // insert group dimension at position -2 for all tensors
-        // K and V: [B, kv_heads, N, D] -> [B, kv_heads, G, N/G, D] (split)
+        // K and V: [B, kv_heads, N, D] -> [B, kv_heads, G, N/G, D]
 
         // build transformed shapes
         std::vector<std::size_t> q_transformed_shape;
         std::vector<std::size_t> k_transformed_shape;
         std::vector<std::size_t> v_transformed_shape;
-
-        // Q shape transformation (broadcast group dimension)
+
         q_transformed_shape = {dims.batch_size, dims.concat_heads, groups, dims.sequence_length, dims.head_dim};
         k_transformed_shape = {dims.batch_size, dims.kv_heads, groups, dims.seq_length_per_group, dims.head_dim};
         v_transformed_shape = {dims.batch_size, dims.kv_heads, groups, dims.seq_length_per_group, dims.head_dim};
 
-        std::cout << "Q transformed shape: ";
-        for(auto d : q_transformed_shape) std::cout << d << " ";
-        std::cout << std::endl;
-
-        std::cout << "K transformed shape: ";
-        for(auto d : k_transformed_shape) std::cout << d << " ";
-        std::cout << std::endl;
-
-        std::cout << "V transformed shape: ";
-        for(auto d : v_transformed_shape) std::cout << d << " ";
-        std::cout << std::endl;
-
         // insert reshape operations before the attention group
         // [B, concat_heads, seq, head_dim] -> [B, concat_heads, 1, seq, head_dim] -> [B, concat_heads, G, seq, head_dim]
         auto q_unsqueezed = mm.insert_instruction(
             attn_group_ins,
-            make_op("unsqueeze", {{"axes", {2}}}),
+            make_op("unsqueeze", {{"axes", {-2}}}),
             q);
 
         auto q_reshaped = mm.insert_instruction(
@@ -1052,27 +870,22 @@ struct find_gqa_flash_decoding
             make_op("multibroadcast", {{"out_lens", q_transformed_shape}}),
             q_unsqueezed);
 
-        // K: reshape to split sequence dimension
+        // [B, kv_heads, N, D] -> [B, kv_heads, G, N/G, D]
         auto k_reshaped = mm.insert_instruction(
             attn_group_ins,
             make_op("reshape", {{"dims", k_transformed_shape}}),
             k);
 
-        // V: reshape to split sequence dimension
+        // [B, kv_heads, N, D] -> [B, kv_heads, G, N/G, D]
         auto v_reshaped = mm.insert_instruction(
             attn_group_ins,
            make_op("reshape", {{"dims", v_transformed_shape}}),
            v);
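A quick standalone check of the shape algebra used above: the K/V reshape is only valid because the group count divides the cached sequence length (which is why the pass bails out earlier on non-divisible lengths), while Q is not split and instead gains a size-G broadcast dimension. Sizes here are hypothetical.

    #include <cassert>
    #include <cstddef>
    #include <functional>
    #include <numeric>
    #include <vector>

    int main()
    {
        std::size_t B = 2, H = 4, N = 4096, D = 128, G = 8;
        assert(N % G == 0); // precondition checked by the pass
        std::vector<std::size_t> kv       = {B, H, N, D};
        std::vector<std::size_t> kv_split = {B, H, G, N / G, D};
        auto count = [](const std::vector<std::size_t>& lens) {
            return std::accumulate(lens.begin(), lens.end(), std::size_t{1},
                                   std::multiplies<>{});
        };
        assert(count(kv) == count(kv_split)); // reshape preserves element count
        // Q keeps its full sequence and is multibroadcast over G, so each group
        // attends the whole query against its own chunk of the KV cache.
    }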
-        std::cout << "Q reshaped: " << q_reshaped->get_shape() << std::endl;
-        std::cout << "K reshaped: " << k_reshaped->get_shape() << std::endl;
-        std::cout << "V reshaped: " << v_reshaped->get_shape() << std::endl;
-
-        // No need to handle positions outside the submodule
-        // We'll adjust broadcast patterns inside for early exit masking
+        // No need to reshape additional inputs
+        // We'll adjust broadcast patterns inside for masking
 
-        // TODO can probably do this simpler
-        // Create new input list by replacing Q, K, V with reshaped versions
+        // create new input list, starting with replacing Q, K, V with reshaped versions
         std::vector<instruction_ref> new_group_inputs = group_inputs;
         for(size_t i = 0; i < group_inputs.size(); ++i) {
             if(group_inputs[i] == q) {
@@ -1082,14 +895,12 @@ struct find_gqa_flash_decoding
             } else if(group_inputs[i] == v) {
                 new_group_inputs[i] = v_reshaped;
             }
-            // Other inputs (like seq_len) stay the same
         }
 
-        // Create new flash decoding submodule
         module m_flash_decode;
         m_flash_decode.set_bypass();
 
-        // Get parameter names from original submodule
+        // get parameter names from original submodule
         auto get_param_name = [](instruction_ref param) -> std::string {
             assert(param->name() == "@param");
             return param->get_operator().to_value()["parameter"].to<std::string>();
         };
@@ -1101,11 +912,11 @@ struct find_gqa_flash_decoding
 
         // Add parameters to new submodule with transformed shapes
         auto new_q_param = m_flash_decode.add_parameter(
-            q_name, shape{q_shape.type(), q_transformed_shape});
+            q_name, shape{q_type, q_transformed_shape});
         auto new_k_param = m_flash_decode.add_parameter(
-            k_name, shape{k_shape_main.type(), k_transformed_shape});
+            k_name, shape{k_type, k_transformed_shape});
         auto new_v_param = m_flash_decode.add_parameter(
-            v_name, shape{v_shape_main.type(), v_transformed_shape});
+            v_name, shape{v_type, v_transformed_shape});
 
         // Build mapping from old params to new params
         std::unordered_map<instruction_ref, instruction_ref> map_old_params_to_new;
@@ -1113,35 +924,29 @@ struct find_gqa_flash_decoding
         map_old_params_to_new[k_param] = new_k_param;
         map_old_params_to_new[v_param] = new_v_param;
 
-        // Add other parameters (like seq_len) that don't change shape
+        // add the rest of the parameters
         for(auto param : iterator_for(*submod)) {
             if(param->name() == "@param") {
-                if(param != q_param && param != k_param && param != v_param) {
+                if(not contains(map_old_params_to_new, param)) {
                     auto param_name  = get_param_name(param);
                     auto param_shape = param->get_shape();
                     auto new_param   = m_flash_decode.add_parameter(param_name, param_shape);
                     map_old_params_to_new[param] = new_param;
-                    std::cout << "Added unchanged param: " << param_name
-                              << " with shape: " << param_shape << std::endl;
                 }
             }
         }
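The parameter-rebuilding loop above re-declares Q/K/V with their grouped shapes and clones every other parameter unchanged. A generic std-only analogue of that pattern, with illustrative names and types rather than the MIGraphX API:

    #include <cstddef>
    #include <map>
    #include <string>
    #include <vector>

    using lens = std::vector<std::size_t>;

    // Entries with an overridden shape are re-declared; the rest copy through,
    // so the resulting map covers every original parameter exactly once.
    std::map<std::string, lens>
    rebuild_params(const std::map<std::string, lens>& old_params,
                   const std::map<std::string, lens>& overrides)
    {
        std::map<std::string, lens> out;
        for(const auto& [name, shape] : old_params)
        {
            auto it   = overrides.find(name);
            out[name] = (it != overrides.end()) ? it->second : shape;
        }
        return out;
    }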
-        // TODO all the param stuff before this can be simplified
-        // rebuild the attention operations in the flash decode submodule
-        std::cout << "Rebuilding GQA attention operations..." << std::endl;
         rebuild_gqa_attention(m_flash_decode, *submod, map_old_params_to_new,
-                              q_param, k_param, v_param, dims, groups);
-
-        // create the module in the module pass manager
-        auto original_submod_name = attn_group_ins->module_inputs().front()->name();
-        std::string new_mod_name  = original_submod_name + "_gqa_flash_decoding";
+                              gemm2, dims, groups);
+
+        // create the module in the module pass manager and insert the new group operation
+        auto orig_name           = attn_group_ins->module_inputs().front()->name();
+        std::string new_mod_name = orig_name + "_gqa_flash_decoding";
         module_ref mpm_flash_mod = mpm.create_module(new_mod_name, std::move(m_flash_decode));
         mpm_flash_mod->set_bypass();
 
-        // insert the new group operation
         auto new_group_ins = mm.insert_instruction(
             attn_group_ins,
             make_op("group", {{"tag", "attention"}}),
@@ -1166,32 +971,26 @@ struct find_gqa_flash_decoding
 
         // LSE is [B, N, G, S, 1], match group dimension of O' [B, G, S, N*D]
         // [B, N, G, S, 1] -> [B, G, N, S, 1]
-        std::cout << "\n1. Transposing LSE to align group dimension..." << std::endl;
         auto lse_transposed = mm.insert_instruction(
             attn_group_ins,
             make_op("transpose", {{"permutation", {0, 2, 1, 3, 4}}}),
             lse);
-        std::cout << "  LSE transposed shape: " << lse_transposed->get_shape() << std::endl; // [2, 2, 2, 1, 1]
 
-        // average across heads (N) since all heads in a group share the same weight
-        // [B, G, N, S, 1] -> [B, G, S, 1] (reduce over axis 2, then squeeze)
-        std::cout << "\n2. Averaging LSE across heads within each group..." << std::endl;
+        // average across heads since all heads in a group share the same weight
+        // [B, G, N, S, 1] -> [B, G, 1, S, 1]
         auto lse_avg = mm.insert_instruction(
             attn_group_ins,
             make_op("reduce_mean", {{"axes", {2}}}),
             lse_transposed);
-        std::cout << "  LSE averaged shape: " << lse_avg->get_shape() << std::endl; // [2, 2, 1, 1, 1]
 
-        // squeeze axes 2 and 4: [B, G, 1, S, 1] -> [B, G, S]
+        // [B, G, 1, S, 1] -> [B, G, S]
         auto lse_squeezed = mm.insert_instruction(
             attn_group_ins,
             make_op("squeeze", {{"axes", {2, 4}}}),
             lse_avg);
-        std::cout << "  LSE squeezed shape: " << lse_squeezed->get_shape() << std::endl; // [2, 2, 1]
 
         // softmax across groups for LSE weights
-        std::cout << "\n3. Computing softmax of LSE across groups..." << std::endl;
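The instructions that follow implement the standard flash-decoding combination identity (background math, not something introduced by this patch): each group g contributes a partial output O'_g and a log-sum-exp L_g, and the exact attention output is recovered as an LSE-softmax-weighted average, computed stably by first subtracting the max:

\[
O \;=\; \sum_{g=1}^{G} w_g \, O'_g,
\qquad
w_g \;=\; \frac{e^{L_g - \max_h L_h}}{\sum_{h=1}^{G} e^{L_h - \max_h L_h}}
\;=\; \mathrm{softmax}(L)_g .
\]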
-        // find max across groups for numerical stability
+        // [B, G, S] -> [B, 1, S]
         auto lse_max = mm.insert_instruction(
             attn_group_ins,
             make_op("reduce_max", {{"axes", {1}}}),
             lse_squeezed);
-        std::cout << "  Max LSE shape: " << lse_max->get_shape() << std::endl; // [2, 1, 1]
 
         // broadcast max back to original shape
+        // [B, 1, S] -> [B, G, S]
         auto lse_max_bcast = mm.insert_instruction(
             attn_group_ins,
             make_op("multibroadcast", {{"out_lens", lse_squeezed->get_shape().lens()}}),
@@ -1200,35 +999,32 @@ struct find_gqa_flash_decoding
 
         // exp(LSE - max_LSE)
         auto lse_sub = mm.insert_instruction(attn_group_ins, make_op("sub"), lse_squeezed, lse_max_bcast);
         auto lse_exp = mm.insert_instruction(attn_group_ins, make_op("exp"), lse_sub);
-        std::cout << "  Exp(LSE) shape: " << lse_exp->get_shape() << std::endl; // [2, 2, 1]
 
         // sum exp across groups
+        // [B, G, S] -> [B, 1, S]
         auto lse_sum = mm.insert_instruction(
             attn_group_ins,
             make_op("reduce_sum", {{"axes", {1}}}),
             lse_exp);
-        std::cout << "  Sum exp shape: " << lse_sum->get_shape() << std::endl; // [2, 1, 1]
 
-        // broadcast sum back
+        // [B, 1, S] -> [B, G, S]
         auto lse_sum_bcast = mm.insert_instruction(
             attn_group_ins,
             make_op("multibroadcast", {{"out_lens", lse_exp->get_shape().lens()}}),
             lse_sum);
 
+        // elementwise div; the softmax weights keep shape [B, G, S]
         auto weights = mm.insert_instruction(attn_group_ins, make_op("div"), lse_exp, lse_sum_bcast);
-        std::cout << "  Softmax weights shape: " << weights->get_shape() << std::endl; // [2, 2, 1]
 
         // weights is [B, G, S], O' is [B, G, S, N*D]
         // [B, G, S] -> [B, G, S, 1]
-        std::cout << "\n4. Preparing weights for multiplication with O'..." << std::endl;
         auto weights_unsqueezed = mm.insert_instruction(
             attn_group_ins,
             make_op("unsqueeze", {{"axes", {3}}}),
             weights);
-        std::cout << "  Weights unsqueezed shape: " << weights_unsqueezed->get_shape() << std::endl; // [2, 2, 1, 1]
 
         // broadcast to match O' shape
+        // [B, G, S, 1] -> [B, G, S, N*D]
         auto weights_bcast = mm.insert_instruction(
             attn_group_ins,
             make_op("multibroadcast", {{"out_lens", partial_output_o_prime->get_shape().lens()}}),
             weights_unsqueezed);
-        std::cout << "  Weights broadcast shape: " << weights_bcast->get_shape() << std::endl; // [2, 2, 1, 4]
 
         // convert weights to match O' type
         auto output_type = partial_output_o_prime->get_shape().type();
@@ -1236,25 +1032,20 @@ struct find_gqa_flash_decoding
             attn_group_ins, make_op("convert", {{"target_type", output_type}}), weights_bcast);
 
         // multiply O' by weights
-        std::cout << "\n5. Multiplying O' by softmax weights..." << std::endl;
         auto weighted_output = mm.insert_instruction(
             attn_group_ins, make_op("mul"), partial_output_o_prime, weights_converted);
-        std::cout << "  Weighted output shape: " << weighted_output->get_shape() << std::endl; // [2, 2, 1, 4]
 
         // sum across groups to get final output
-        std::cout << "\n6. Summing weighted outputs across groups..." << std::endl;
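A scalar spot-check of this whole kernel-2 pipeline (hypothetical two-group example with head_dim = 1): splitting softmax attention over two KV chunks and recombining the partials with LSE-softmax weights reproduces the unsplit result exactly.

    #include <cmath>
    #include <cstdio>

    int main()
    {
        double s[4] = {0.1, 0.9, 0.3, 0.7}; // scores; groups: {s0,s1}, {s2,s3}
        double v[4] = {1.0, 2.0, 3.0, 4.0}; // values

        // unsplit reference: softmax(s) . v
        double z = 0, ref = 0;
        for(int i = 0; i < 4; ++i) z += std::exp(s[i]);
        for(int i = 0; i < 4; ++i) ref += std::exp(s[i]) / z * v[i];

        // per-group partials: O'_g and LSE_g
        double o[2], lse[2];
        for(int g = 0; g < 2; ++g)
        {
            double zg = std::exp(s[2 * g]) + std::exp(s[2 * g + 1]);
            o[g]   = (std::exp(s[2 * g]) * v[2 * g]
                      + std::exp(s[2 * g + 1]) * v[2 * g + 1]) / zg;
            lse[g] = std::log(zg);
        }

        // combine with softmax(LSE) weights, mirroring the reduce_max/exp/sum/div chain
        double m   = std::fmax(lse[0], lse[1]);
        double w0  = std::exp(lse[0] - m), w1 = std::exp(lse[1] - m);
        double out = (w0 * o[0] + w1 * o[1]) / (w0 + w1);

        std::printf("ref=%.12f combined=%.12f\n", ref, out); // the two agree
    }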
+        // [B, G, S, N*D] -> [B, 1, S, N*D]
         auto final_output = mm.insert_instruction(
             attn_group_ins,
             make_op("reduce_sum", {{"axes", {1}}}),
             weighted_output);
-        std::cout << "  Final output shape: " << final_output->get_shape() << std::endl; // [2, 1, 1, 4]
 
         // squeeze the reduced group dimension
+        // [B, 1, S, N*D] -> [B, S, N*D]
         auto final_squeezed = mm.insert_instruction(
             attn_group_ins,
             make_op("squeeze", {{"axes", {1}}}),
             final_output);
-        std::cout << "  Final squeezed shape: " << final_squeezed->get_shape() << std::endl; // [2, 1, 4]
 
         mm.replace_instruction(attn_group_ins, final_squeezed);
-
-        std::cout << "\n=== Kernel 2 complete: LSE-weighted combination successful ===" << std::endl;
     }
 };
 
@@ -1271,21 +1062,6 @@ struct find_flash_decoding
         return match::name("group")(match::has_op_value("tag", "attention")).bind("group");
     }
 
-    std::pair<instruction_ref, instruction_ref> get_gemms(module_ref submod) const
-    {
-        std::vector<instruction_ref> gemms;
-        for(auto it = submod->begin(); it != submod->end(); ++it)
-        {
-            if(it->name() == "dot")
-                gemms.push_back(it);
-        }
-        assert(gemms.size() == 2 and "Expected exactly 2 gemm operations in attention submodule");
-
-        // gemms[0] is Q@K, gemms[1] is P@V
-        // gemms are in order since we iterate from begin to end
-        return {gemms[0], gemms[1]};
-    }
-
     std::vector<shape> get_qkv_shapes(instruction_ref q, instruction_ref k, instruction_ref v) const
     {
         std::vector<shape> qkv_shapes;
@@ -1378,17 +1154,6 @@ struct find_flash_decoding
         return result;
     }
 
-    std::unordered_map<instruction_ref, instruction_ref>
-    map_submod_params_to_inputs(module_ref submod,
-                                const std::vector<instruction_ref>& group_inputs) const
-    {
-        auto map_param_to_main = submod->get_ins_param_map(group_inputs, true);
-        // verify the mapping is correct
-        auto expected_inputs = submod->get_inputs(map_param_to_main);
-        assert(expected_inputs == group_inputs and "Mapped inputs don't match group inputs");
-        return map_param_to_main;
-    }
-
     void rebuild_attention_submodule(
         module& target_mod,
         const module& source_mod,
@@ -1487,7 +1252,7 @@ struct find_flash_decoding
             return;
 
         // get gemm1 and gemm2
-        auto [gemm1, gemm2] = get_gemms(submod);
+        auto [gemm1, gemm2] = get_attention_gemms(submod);
 
         // TODO: for this first pass of flash decoding, assuming no input fusion / not supporting
         auto q_param = gemm1->inputs()[0];
@@ -1900,11 +1665,17 @@ void fuse_attention::apply(module_pass_manager& mpm) const
 {
     std::size_t counter = 0;
 
+    std::cout << "original module before fusion: " << std::endl;
+    mpm.get_module().debug_print();
+
     // Fuse kv-cache attention by default
     match::find_matches(mpm, find_kv_cache_attention{.counter = &counter});
     mpm.get_module().sort();
     mpm.run_pass(dead_code_elimination{});
 
+    std::cout << "module after kv-cache attention fusion: " << std::endl;
+    mpm.get_module().debug_print();
+
     // Only fuse plain attention when requested
     if(attn_enabled)
     {

From a095770c65b5c951a9ac9cf77fc6496b6424e9da Mon Sep 17 00:00:00 2001
From: Breanna Devore-McDonald
Date: Wed, 14 Jan 2026 13:57:35 -0600
Subject: [PATCH 4/4] bug from changing git head

---
 src/fuse_attention.cpp | 15 +++++----------
 1 file changed, 5 insertions(+), 10 deletions(-)

diff --git a/src/fuse_attention.cpp b/src/fuse_attention.cpp
index 02261336a3b..d403e4154cf 100644
--- a/src/fuse_attention.cpp
+++ b/src/fuse_attention.cpp
@@ -354,7 +354,7 @@ struct find_gqa_flash_decoding
             if(max_seq_length % num_groups != 0) {
                 std::cout << "Max sequence length " << max_seq_length
                           << " not divisible by " << num_groups << " groups" << std::endl;
-                // TODO: Add padding support
+                // TODO: add autosplitting (padding won't be needed)
                 seq_length_per_group = 0;
             } else {
                 seq_length_per_group = max_seq_length / num_groups;
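One possible reading of the autosplitting TODO above, sketched as a divisor search; this is a hypothetical policy, not what the patch implements:

    #include <cstddef>

    // Pick the largest split count <= requested that divides max_seq_length,
    // so the per-group sequence length stays exact and no padding is needed.
    std::size_t pick_num_groups(std::size_t max_seq_length, std::size_t requested)
    {
        for(std::size_t g = requested; g > 1; --g)
        {
            if(max_seq_length % g == 0)
                return g;
        }
        return 1; // fall back to no split
    }
    // e.g. pick_num_groups(4096, 6) returns 4, since 6 and 5 do not divide 4096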
@@ -427,10 +427,6 @@ struct find_gqa_flash_decoding
         // factory method to extract Q, K, V parameters from gemm operations
         static std::optional<qkv_params> from_gemms(instruction_ref gemm1, instruction_ref gemm2)
         {
-            /*
-             - gemm1: Q@K
-             - gemm2: P@V
-            */
             auto trace_back_to_param = [](instruction_ref ins) -> std::optional<instruction_ref> {
                 instruction_ref current = ins;
                 while(current->name() != "@param") {
@@ -832,17 +828,16 @@ struct find_gqa_flash_decoding
         // map submodule params to main module inputs
         auto group_inputs      = attn_group_ins->inputs();
         auto map_param_to_main = map_submod_params_to_inputs(submod, group_inputs);
-        
+
         // get actual Q, K, V instructions from main module
         auto q = map_param_to_main.at(q_param); // maps to the QKV tensor
         auto k = map_param_to_main.at(k_param); // maps to K concat_past_present output
         auto v = map_param_to_main.at(v_param); // maps to V concat_past_present output
-        
+
         // GQA flash decoding:
         // - Q (QKV tensor): add new group dim and broadcast
         // - K: split sequence dimension into groups
         // - V: split sequence dimension into groups
-
         auto q_type = q->get_shape().type();
         auto k_type = k->get_shape().type();
         auto v_type = v->get_shape().type();
@@ -1691,14 +1686,14 @@ void fuse_attention::apply(module_pass_manager& mpm) const
     // flash decoding for regular attention, single & multi-headed
     match::find_matches(
         mpm,
-        find_flash_decoding{.configured_splits = flash_decoding_num_splits,
+        find_flash_decoding{.configured_splits = configured_splits,
                             .configured_threshold = flash_decoding_threshold,
                             .configured_max_splits = flash_decoding_max_splits,
                             .configured_min_chunk_size = flash_decoding_min_chunk_size});
 
     // flash decoding for GQA attention
     match::find_matches(
-        mpm, find_gqa_flash_decoding{.groups = num_splits});
+        mpm, find_gqa_flash_decoding{.groups = configured_splits});
 
     mpm.run_pass(dead_code_elimination{});
 }
 }