diff --git a/src/tir/schedule/primitive/compute_inline.cc b/src/tir/schedule/primitive/compute_inline.cc index 0ab6d7e2b699..f9942cdaab6e 100644 --- a/src/tir/schedule/primitive/compute_inline.cc +++ b/src/tir/schedule/primitive/compute_inline.cc @@ -986,15 +986,8 @@ void ReverseComputeInline(ScheduleState self, const StmtSRef& consumer_block_sre /*! * \brief Helper to fuse epilogue block into reduction block - * Analyzes epilogue pattern and transforms reduction init/update + * Uses generalized approach to handle any epilogue expression without pattern matching */ -// Epilogue type enumeration -enum class EpilogueType { - Bias, // temp + C - BiasReLU, // max(temp + C, 0) - Clipping, // min(max(temp, lower), upper) -}; - class ReductionEpilogueFuser : public BaseInliner { public: explicit ReductionEpilogueFuser(const Buffer& reduction_buffer, const BlockNode* reduction_block, @@ -1002,8 +995,7 @@ class ReductionEpilogueFuser : public BaseInliner { const StmtSRef& scope_root_sref) : BaseInliner(reduction_buffer, epilogue_block_realize->block, scope_root_sref), reduction_block_(reduction_block), - epilogue_block_(epilogue_block_realize->block.get()), - epilogue_type_(EpilogueType::Bias) { + epilogue_block_(epilogue_block_realize->block.get()) { // Disable opaque access check for epilogue fusion // Epilogue blocks can read multiple buffers (temp + bias), which is allowed has_opaque_access = false; @@ -1023,7 +1015,6 @@ class ReductionEpilogueFuser : public BaseInliner { const BlockRealizeNode* reduction_realize); private: - bool AnalyzeEpiloguePattern(const PrimExpr& value); bool IsReductionBlock(const BlockNode* block); void ExtractEpilogueInfo(); // Helper function to extract BufferLoad nodes from BufferStore @@ -1052,15 +1043,16 @@ class ReductionEpilogueFuser : public BaseInliner { const BlockNode* reduction_block_; const BlockNode* epilogue_block_; - PrimExpr epilogue_addend_{nullptr}; // C[vi, vj] in D = temp + C - Buffer epilogue_output_buffer_{nullptr}; // Output buffer D + // Generalized approach: store the entire epilogue expression + PrimExpr epilogue_expression_{ + nullptr}; // The entire epilogue expression (e.g., temp + C, max(temp + C, 0)) + const BufferLoadNode* reduction_buffer_load_{ + nullptr}; // The reduction buffer load in epilogue expression + Buffer epilogue_output_buffer_{nullptr}; // Output buffer D ffi::Array epilogue_output_indices_{nullptr}; // Indices of D[vi, vj] BufferRegion epilogue_output_region_{nullptr}; // Write region of D - Buffer epilogue_addend_buffer_{nullptr}; // Addend buffer C - BufferRegion epilogue_addend_region_{nullptr}; // Read region of C - EpilogueType epilogue_type_; // Type of epilogue operation - PrimExpr clipping_lower_{nullptr}; // Lower bound for clipping - PrimExpr clipping_upper_{nullptr}; // Upper bound for clipping + Buffer epilogue_addend_buffer_{nullptr}; // Additional buffer (e.g., bias buffer C) + BufferRegion epilogue_addend_region_{nullptr}; // Read region of additional buffer }; bool ReductionEpilogueFuser::BodyPatternAllowFusion(const BlockRealize& epilogue_block_realize) { @@ -1083,166 +1075,112 @@ bool ReductionEpilogueFuser::BodyPatternAllowFusion(const BlockRealize& epilogue return false; } - // 4. Analyze epilogue pattern: D[i,j] = temp[i,j] + C[i,j] or - // D[i,j] = min(max(temp[i,j], lower), upper) - if (!AnalyzeEpiloguePattern(inlined_store_->value)) { - // Failure: epilogue is not a supported pattern (Bias, BiasReLU, or Clipping) - return false; - } - - // 5. Verify temp appears exactly once in the epilogue pattern - // This ensures correctness for all supported patterns (Bias, BiasReLU, Clipping) - // The reduction result buffer must be used exactly once in the epilogue expression + // 4. Generalized approach: store the entire epilogue expression + // Verify reduction buffer appears exactly once (required for fusion correctness) if (loads.size() != 1) { // Failure: The reduction result (temp) must be used exactly once in the // epilogue expression for fusion. return false; } - // 6. Check if producer is a reduction block - if (!IsReductionBlock(reduction_block_)) { - // Failure: producer is not a reduction block - return false; - } - - // 7. Extract epilogue information (output buffer, indices, regions, etc.) - ExtractEpilogueInfo(); + // Store the epilogue expression and reduction buffer load + epilogue_expression_ = inlined_store_->value; + reduction_buffer_load_ = loads[0]; - return true; -} + // 5. Reject epilogues that scale the reduction result with non-additive ops + // For example, (reduce_out * 2.0) + C[i] is not a valid bias-style epilogue. + // We only allow the reduction result to be combined via Add/Min/Max shells. + class ScalingDetector : public ExprVisitor { + public: + explicit ScalingDetector(const Buffer& buffer) : buffer_(buffer) {} -bool ReductionEpilogueFuser::AnalyzeEpiloguePattern(const PrimExpr& value) { - // Pattern 1: temp[i,j] + C[i,j] or C[i,j] + temp[i,j] (Bias) - if (const auto* add = value.as()) { - const auto* load_a = add->a.as(); - const auto* load_b = add->b.as(); + bool HasScaling(const PrimExpr& expr) { + has_scaling_ = false; + VisitExpr(expr); + return has_scaling_; + } - bool a_is_target = load_a && load_a->buffer.same_as(inlined_buffer_); - bool b_is_target = load_b && load_b->buffer.same_as(inlined_buffer_); + private: + // Helper to check if a subtree contains a load from the reduction buffer + bool ContainsTarget(const PrimExpr& expr) { + class TargetFinder : public ExprVisitor { + public: + explicit TargetFinder(const Buffer& buffer) : buffer_(buffer) {} + + bool Find(const PrimExpr& e) { + found_ = false; + VisitExpr(e); + return found_; + } - // Ensure exactly one operand is from the reduction buffer - if (a_is_target != b_is_target) { - epilogue_addend_ = a_is_target ? add->b : add->a; - epilogue_type_ = EpilogueType::Bias; - return true; - } - } + private: + void VisitExpr_(const BufferLoadNode* op) final { + if (op->buffer.same_as(buffer_)) { + found_ = true; + return; + } + ExprVisitor::VisitExpr_(op); + } - // Pattern 2: min(max(temp[i,j], lower), upper) or max(min(temp[i,j], upper), lower) (Clipping) - // Handle all commutative variants of min/max at each level. + Buffer buffer_; + bool found_{false}; + }; - // Helper to check if an expression is a load from the reduction buffer, and - // return the other operand as `other` if so. - auto match_buffer_in_commutative_op = [this](const PrimExpr& a, const PrimExpr& b, - PrimExpr* other) -> bool { - if (const auto* load_a = a.as()) { - if (load_a->buffer.same_as(inlined_buffer_)) { - *other = b; - return true; - } - } - if (const auto* load_b = b.as()) { - if (load_b->buffer.same_as(inlined_buffer_)) { - *other = a; - return true; - } + TargetFinder finder(buffer_); + return finder.Find(expr); } - return false; - }; - // Check for min(max(temp, lower), upper) and commutative variants - if (const auto* min_node = value.as()) { - const MaxNode* max_node = nullptr; - PrimExpr upper; - // Try both (a, b) as possible positions of the inner max - if ((max_node = min_node->a.as())) { - upper = min_node->b; - } else if ((max_node = min_node->b.as())) { - upper = min_node->a; - } - if (max_node != nullptr) { - PrimExpr lower; - if (match_buffer_in_commutative_op(max_node->a, max_node->b, &lower)) { - clipping_lower_ = lower; - clipping_upper_ = upper; - epilogue_type_ = EpilogueType::Clipping; - return true; + void VisitExpr_(const MulNode* op) final { + if (has_scaling_) return; + // If either operand subtree contains the reduction buffer load, + // we treat this as invalid scaling of the reduction result. + if (ContainsTarget(op->a) || ContainsTarget(op->b)) { + has_scaling_ = true; + return; } + ExprVisitor::VisitExpr_(op); } - } - // Check for max(min(temp[i,j], upper), lower) and commutative variants - if (const auto* max_node = value.as()) { - const MinNode* min_node = nullptr; - PrimExpr lower; - // Try both (a, b) as possible positions of the inner min - if ((min_node = max_node->a.as())) { - lower = max_node->b; - } else if ((min_node = max_node->b.as())) { - lower = max_node->a; - } - if (min_node != nullptr) { - PrimExpr upper; - if (match_buffer_in_commutative_op(min_node->a, min_node->b, &upper)) { - clipping_lower_ = lower; - clipping_upper_ = upper; - epilogue_type_ = EpilogueType::Clipping; - return true; + void VisitExpr_(const DivNode* op) final { + if (has_scaling_) return; + if (ContainsTarget(op->a) || ContainsTarget(op->b)) { + has_scaling_ = true; + return; } + ExprVisitor::VisitExpr_(op); } - } - // Pattern 3: max(temp[i,j] + C[i,j], 0) or max(C[i,j] + temp[i,j], 0) (BiasReLU) - // Also handle max(0, temp[i,j] + C[i,j]) or max(0, C[i,j] + temp[i,j]) - if (const auto* max_node = value.as()) { - // Check if either operand is zero (ReLU: max(x, 0) or max(0, x)) - // Support both integer and float zero constants. - const PrimExpr* add_candidate = nullptr; - bool is_zero_const = false; - auto is_zero_expr = [](const PrimExpr& expr) -> bool { - if (tir::is_zero(expr)) { - return true; - } - if (const auto* float_imm = expr.as()) { - return float_imm->value == 0.0; + void VisitExpr_(const ModNode* op) final { + if (has_scaling_) return; + if (ContainsTarget(op->a) || ContainsTarget(op->b)) { + has_scaling_ = true; + return; } - return false; - }; - - if (is_zero_expr(max_node->a)) { - is_zero_const = true; - add_candidate = &max_node->b; - } else if (is_zero_expr(max_node->b)) { - is_zero_const = true; - add_candidate = &max_node->a; + ExprVisitor::VisitExpr_(op); } - if (is_zero_const && add_candidate != nullptr) { - if (const auto* add = add_candidate->as()) { - const auto* load_a = add->a.as(); - const auto* load_b = add->b.as(); - - bool a_is_target = load_a && load_a->buffer.same_as(inlined_buffer_); - bool b_is_target = load_b && load_b->buffer.same_as(inlined_buffer_); + Buffer buffer_; + bool has_scaling_{false}; + }; - // Ensure exactly one operand is from the reduction buffer - if (a_is_target != b_is_target) { - epilogue_addend_ = a_is_target ? add->b : add->a; - epilogue_type_ = EpilogueType::BiasReLU; - return true; - } - } else if (const auto* load = add_candidate->as()) { - // Handle bias-free ReLU: max(temp, 0) or max(0, temp) - if (load->buffer.same_as(inlined_buffer_)) { - epilogue_addend_ = tir::make_zero(load->dtype); - epilogue_type_ = EpilogueType::BiasReLU; - return true; - } - } + { + ScalingDetector detector(inlined_buffer_); + if (detector.HasScaling(inlined_store_->value)) { + // Failure: Non-additive scaling of the reduction result is not supported + return false; } } - return false; + // 6. Check if producer is a reduction block + if (!IsReductionBlock(reduction_block_)) { + // Failure: producer is not a reduction block + return false; + } + + // 7. Extract epilogue information (output buffer, indices, regions, etc.) + ExtractEpilogueInfo(); + + return true; } bool ReductionEpilogueFuser::IsReductionBlock(const BlockNode* block) { @@ -1268,12 +1206,29 @@ void ReductionEpilogueFuser::ExtractEpilogueInfo() { } } - // Extract epilogue addend buffer and region from epilogue_addend_ - if (const auto* load = epilogue_addend_.as()) { - epilogue_addend_buffer_ = load->buffer; + // Generalized approach: extract all non-reduction buffers from epilogue expression + // Find all buffers in epilogue expression (except the reduction buffer) + struct BufferExtractor : public ExprVisitor { + void VisitExpr_(const BufferLoadNode* load) final { + if (!load->buffer.same_as(reduction_buffer)) { + other_buffers.insert(load->buffer.get()); + } + ExprVisitor::VisitExpr_(load); + } + Buffer reduction_buffer; + std::unordered_set other_buffers; + } extractor; + extractor.reduction_buffer = inlined_buffer_; + extractor(epilogue_expression_); + + // Extract the first non-reduction buffer and its region + // In most cases, there's one additional buffer (e.g., bias buffer) + if (!extractor.other_buffers.empty()) { + const BufferNode* first_buffer = *extractor.other_buffers.begin(); + epilogue_addend_buffer_ = ffi::GetRef(first_buffer); // Find the read region from epilogue block reads for (const BufferRegion& read : epilogue_block_->reads) { - if (read->buffer.same_as(epilogue_addend_buffer_)) { + if (read->buffer.get() == first_buffer) { epilogue_addend_region_ = read; break; } @@ -1308,53 +1263,163 @@ Block ReductionEpilogueFuser::CreateFusedReductionBlock(const BlockNode* reducti var_map[epilogue_data_vars[i]] = reduction_data_vars[i]; } - // 2. Change init to epilogue value based on epilogue type - BufferStore new_init_store; - if (epilogue_type_ == EpilogueType::BiasReLU) { - // For ReLU, init should be max(C[vi, vj], 0) to match per-iteration ReLU semantics - PrimExpr init_value = Substitute(epilogue_addend_, var_map); - PrimExpr zero = tir::make_zero(init_value.dtype()); - new_init_store = BufferStore(epilogue_output_buffer_, Max(init_value, zero), - Substitute(epilogue_output_indices_, var_map)); - } else if (epilogue_type_ == EpilogueType::Clipping) { - // For Clipping, init should be min(max(init_value, lower), upper) - // Since init is typically 0, this becomes min(max(0, lower), upper) - PrimExpr init_value = tir::make_zero(epilogue_output_buffer_->dtype); - PrimExpr clipped_init = Min(Max(init_value, Substitute(clipping_lower_, var_map)), - Substitute(clipping_upper_, var_map)); - new_init_store = BufferStore(epilogue_output_buffer_, clipped_init, - Substitute(epilogue_output_indices_, var_map)); - } else { - // Bias: D[vi, vj] = C[vi, vj] - new_init_store = BufferStore(epilogue_output_buffer_, Substitute(epilogue_addend_, var_map), - Substitute(epilogue_output_indices_, var_map)); - } + // 2. Generalized init transformation: substitute reduction buffer load with identity element (0) + // Create a substituter to replace reduction_buffer_load_ with identity element + class InitSubstituter : public ExprMutator { + public: + InitSubstituter(const Buffer& target_buffer, PrimExpr identity_elem) + : target_buffer_(target_buffer), identity_elem_(identity_elem) {} + + PrimExpr VisitExpr_(const BufferLoadNode* op) final { + BufferLoad load = Downcast(ExprMutator::VisitExpr_(op)); + if (load->buffer.same_as(target_buffer_)) { + return identity_elem_; + } + return load; + } + + private: + Buffer target_buffer_; + PrimExpr identity_elem_; + }; + + // Identity element for reduction (assumed to be 0 for addition-based reductions) + PrimExpr identity_elem = tir::make_zero(epilogue_output_buffer_->dtype); + + // Substitute reduction buffer load with identity element + InitSubstituter init_subst(inlined_buffer_, identity_elem); + PrimExpr init_epilogue = init_subst(epilogue_expression_); + + // Apply index mapping + init_epilogue = Substitute(init_epilogue, var_map); + + // Simplify the expression (e.g., 0 + C[vi, vj] -> C[vi, vj]) + arith::Analyzer analyzer; + init_epilogue = analyzer.Simplify(init_epilogue); + + BufferStore new_init_store = BufferStore(epilogue_output_buffer_, init_epilogue, + Substitute(epilogue_output_indices_, var_map)); new_block->init = new_init_store; - // 3. Replace output buffer from temp to D in body - class BufferReplacer : public StmtExprMutator { + // 3. Generalized update transformation: apply epilogue expression with reduction buffer replaced + // If reduction buffer load's parent is Add and other operand is not a reduction buffer, + // remove that operand (bias addend) from update expression + class UpdateSubstituter : public StmtExprMutator { public: - BufferReplacer(Buffer old_buf, Buffer new_buf, EpilogueType epilogue_type, DataType dtype, - PrimExpr clipping_lower = PrimExpr(), PrimExpr clipping_upper = PrimExpr()) + UpdateSubstituter(const Buffer& old_buf, const Buffer& new_buf, const Buffer& reduction_buf, + const PrimExpr& epilogue_expr, const std::unordered_map& var_map) : old_buffer_(old_buf), new_buffer_(new_buf), - epilogue_type_(epilogue_type), - dtype_(dtype), - clipping_lower_(clipping_lower), - clipping_upper_(clipping_upper) {} + reduction_buffer_(reduction_buf), + epilogue_expression_(epilogue_expr), + var_map_(var_map) {} Stmt VisitStmt_(const BufferStoreNode* op) final { BufferStore store = Downcast(StmtExprMutator::VisitStmt_(op)); if (store->buffer.same_as(old_buffer_)) { - PrimExpr new_value = store->value; - // For ReLU, apply max per iteration to match per-iteration ReLU semantics - if (epilogue_type_ == EpilogueType::BiasReLU) { - PrimExpr zero = tir::make_zero(dtype_); - new_value = Max(new_value, zero); - } else if (epilogue_type_ == EpilogueType::Clipping) { - // For Clipping, apply min(max(value, lower), upper) per iteration - new_value = Min(Max(new_value, clipping_lower_), clipping_upper_); - } + // Replace old_buffer_ in store->value with new_buffer_ to get the reduction update + // expression This ensures store->value references new_buffer_ instead of old_buffer_ + class ReductionUpdateReplacer : public ExprMutator { + public: + ReductionUpdateReplacer(const Buffer& old_buf, const Buffer& new_buf) + : old_buffer_(old_buf), new_buffer_(new_buf) {} + + PrimExpr VisitExpr_(const BufferLoadNode* op) final { + BufferLoad load = Downcast(ExprMutator::VisitExpr_(op)); + if (load->buffer.same_as(old_buffer_)) { + return BufferLoad(new_buffer_, load->indices); + } + return load; + } + + private: + Buffer old_buffer_; + Buffer new_buffer_; + }; + + ReductionUpdateReplacer reduction_replacer(old_buffer_, new_buffer_); + PrimExpr reduction_update = reduction_replacer(store->value); + + // Generalized approach: apply epilogue expression with reduction buffer load replaced + // If reduction buffer load's direct parent is Add and the other operand is not a reduction + // buffer, remove that operand (bias addend) from the update expression + class GeneralizedEpilogueApplier : public ExprMutator { + public: + GeneralizedEpilogueApplier(const Buffer& target_buf, const Buffer& reduction_buf, + const PrimExpr& replacement) + : target_buffer_(target_buf), + reduction_buffer_(reduction_buf), + replacement_(replacement), + found_target_load_(false) {} + + PrimExpr VisitExpr_(const BufferLoadNode* op) final { + BufferLoad load = Downcast(ExprMutator::VisitExpr_(op)); + if (load->buffer.same_as(target_buffer_)) { + found_target_load_ = true; + // Check if parent is Add (will be checked in VisitExpr_(const AddNode*)) + return replacement_; + } + return load; + } + + PrimExpr VisitExpr_(const AddNode* op) final { + // Visit children first to see if we find the target buffer load + bool found_before = found_target_load_; + found_target_load_ = false; + + PrimExpr a = VisitExpr(op->a); + bool found_in_a = found_target_load_; + found_target_load_ = false; + + PrimExpr b = VisitExpr(op->b); + bool found_in_b = found_target_load_; + + // If target buffer load was found in this Add node + if (found_in_a || found_in_b) { + // Check if the other operand is NOT from the reduction buffer + // If so, it's likely a bias addend that should be removed in update + bool other_is_reduction = false; + if (found_in_a) { + // Check if b is from reduction buffer + if (const auto* load_b = b.as()) { + other_is_reduction = load_b->buffer.same_as(reduction_buffer_); + } + if (!other_is_reduction) { + // b is the bias addend, remove it + return a; + } + } else { // found_in_b + // Check if a is from reduction buffer + if (const auto* load_a = a.as()) { + other_is_reduction = load_a->buffer.same_as(reduction_buffer_); + } + if (!other_is_reduction) { + // a is the bias addend, remove it + return b; + } + } + // If other operand is also from reduction buffer, keep the Add + return Add(a, b); + } + + // Target buffer load not found in this Add, return as is + found_target_load_ = found_before; + return Add(a, b); + } + + private: + const Buffer& target_buffer_; + const Buffer& reduction_buffer_; + const PrimExpr& replacement_; + bool found_target_load_; + }; + + GeneralizedEpilogueApplier applier(old_buffer_, reduction_buffer_, reduction_update); + PrimExpr new_value = applier(epilogue_expression_); + + // Apply index mapping + new_value = Substitute(new_value, var_map_); + return BufferStore(new_buffer_, new_value, store->indices); } return store; @@ -1371,19 +1436,16 @@ Block ReductionEpilogueFuser::CreateFusedReductionBlock(const BlockNode* reducti private: Buffer old_buffer_; Buffer new_buffer_; - EpilogueType epilogue_type_; - DataType dtype_; - PrimExpr clipping_lower_; - PrimExpr clipping_upper_; + Buffer reduction_buffer_; + PrimExpr epilogue_expression_; + std::unordered_map var_map_; }; - DataType dtype = epilogue_output_buffer_->dtype; - PrimExpr clipping_lower_subst = - epilogue_type_ == EpilogueType::Clipping ? Substitute(clipping_lower_, var_map) : PrimExpr(); - PrimExpr clipping_upper_subst = - epilogue_type_ == EpilogueType::Clipping ? Substitute(clipping_upper_, var_map) : PrimExpr(); - BufferReplacer replacer(inlined_buffer_, epilogue_output_buffer_, epilogue_type_, dtype, - clipping_lower_subst, clipping_upper_subst); + // Apply index mapping to epilogue expression first + PrimExpr epilogue_expr_mapped = Substitute(epilogue_expression_, var_map); + + UpdateSubstituter replacer(inlined_buffer_, epilogue_output_buffer_, inlined_buffer_, + epilogue_expr_mapped, var_map); new_block->body = replacer(reduction_block->body); // 4. Update write regions @@ -1398,21 +1460,22 @@ Block ReductionEpilogueFuser::CreateFusedReductionBlock(const BlockNode* reducti } new_block->writes = new_writes; - // 5. Update read regions (C first, then A, B) + // 5. Update read regions: add all buffers from epilogue expression (except reduction buffer) ffi::Array new_reads; std::unordered_set read_bufs; - // Add C buffer read first (used in init) - if (epilogue_addend_buffer_.defined()) { - new_reads.push_back(BufferRegion(epilogue_addend_buffer_, - Substitute(epilogue_addend_region_->region, var_map))); - read_bufs.insert(epilogue_addend_buffer_.get()); + // Add all non-reduction buffers from epilogue expression + for (const BufferRegion& read : epilogue_block_->reads) { + if (!read->buffer.same_as(inlined_buffer_)) { + new_reads.push_back(BufferRegion(read->buffer, Substitute(read->region, var_map))); + read_bufs.insert(read->buffer.get()); + } } - // Add existing read regions (A, B, etc.) + // Add existing read regions from reduction block (A, B, etc.) for (const BufferRegion& read : reduction_block->reads) { if (!read->buffer.same_as(inlined_buffer_)) { - // Only add non-temp buffers + // Only add non-temp buffers that haven't been added yet if (read_bufs.find(read->buffer.get()) == read_bufs.end()) { new_reads.push_back(read); read_bufs.insert(read->buffer.get()); diff --git a/tests/python/tir-schedule/test_tir_schedule_fuse_reduction_epilogue.py b/tests/python/tir-schedule/test_tir_schedule_fuse_reduction_epilogue.py index dc89f9df56a7..addb7efff323 100644 --- a/tests/python/tir-schedule/test_tir_schedule_fuse_reduction_epilogue.py +++ b/tests/python/tir-schedule/test_tir_schedule_fuse_reduction_epilogue.py @@ -214,5 +214,64 @@ def test_fuse_reduction_epilogue_multiple_epilogue(): assert mod is not None +@T.prim_func +def matmul_bias_invalid_multiple_use_before( + A: T.Buffer((16, 16), "int8"), + B: T.Buffer((16, 16), "int8"), + C1: T.Buffer((16, 16), "int32"), + C2: T.Buffer((16, 16), "int32"), + D: T.Buffer((16, 16), "int32"), +) -> None: + """Epilogue uses the reduction result twice; fusion must be rejected.""" + temp = T.alloc_buffer((16, 16), dtype="int32") + for i, j, k in T.grid(16, 16, 16): + with T.block("multiply"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + temp[vi, vj] = T.int32(0) + temp[vi, vj] = temp[vi, vj] + T.cast(A[vi, vk], "int32") * T.cast(B[vj, vk], "int32") + for i, j in T.grid(16, 16): + with T.block("bad_epilogue"): + vi, vj = T.axis.remap("SS", [i, j]) + # temp[vi, vj] is used twice in the epilogue expression + D[vi, vj] = (temp[vi, vj] + C1[vi, vj]) * (temp[vi, vj] + C2[vi, vj]) + + +def test_fuse_reduction_epilogue_reject_multiple_use(): + """fusion should be rejected when the reduction result appears more than once.""" + sch = tir.Schedule(matmul_bias_invalid_multiple_use_before, debug_mask="all") + with pytest.raises(tvm.tir.ScheduleError): + sch.fuse_reduction_epilogue("multiply", "bad_epilogue") + + +@T.prim_func +def matmul_bias_invalid_scaling_before( + A: T.Buffer((16, 16), "int8"), + B: T.Buffer((16, 16), "int8"), + C: T.Buffer((16, 16), "int32"), + D: T.Buffer((16, 16), "int32"), +) -> None: + """Epilogue scales the reduction result; fusion must be rejected.""" + temp = T.alloc_buffer((16, 16), dtype="int32") + for i, j, k in T.grid(16, 16, 16): + with T.block("multiply"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + temp[vi, vj] = T.int32(0) + temp[vi, vj] = temp[vi, vj] + T.cast(A[vi, vk], "int32") * T.cast(B[vj, vk], "int32") + for i, j in T.grid(16, 16): + with T.block("scaled_epilogue"): + vi, vj = T.axis.remap("SS", [i, j]) + # temp[vi, vj] is scaled by 2 before adding bias; this must not be fused. + D[vi, vj] = temp[vi, vj] * T.int32(2) + C[vi, vj] + + +def test_fuse_reduction_epilogue_reject_scaling(): + """fusion should be rejected when the reduction result is scaled by Mul/Div/Mod.""" + sch = tir.Schedule(matmul_bias_invalid_scaling_before, debug_mask="all") + with pytest.raises(tvm.tir.ScheduleError): + sch.fuse_reduction_epilogue("multiply", "scaled_epilogue") + + if __name__ == "__main__": tvm.testing.main()