From 6714ac49bfaa7297b6adc4edf993572647af74c4 Mon Sep 17 00:00:00 2001 From: Hyun Gyu Kim Date: Mon, 5 Jan 2026 16:00:50 +0900 Subject: [PATCH 1/9] Solve Issue 1 --- src/tir/schedule/primitive/compute_inline.cc | 211 ++++++++++++------- 1 file changed, 132 insertions(+), 79 deletions(-) diff --git a/src/tir/schedule/primitive/compute_inline.cc b/src/tir/schedule/primitive/compute_inline.cc index 0ab6d7e2b699..77ff91f4df32 100644 --- a/src/tir/schedule/primitive/compute_inline.cc +++ b/src/tir/schedule/primitive/compute_inline.cc @@ -1052,15 +1052,18 @@ class ReductionEpilogueFuser : public BaseInliner { const BlockNode* reduction_block_; const BlockNode* epilogue_block_; - PrimExpr epilogue_addend_{nullptr}; // C[vi, vj] in D = temp + C + // Generalized approach: store the entire epilogue expression + PrimExpr epilogue_expression_{nullptr}; // The entire epilogue expression (e.g., temp + C, max(temp + C, 0)) + const BufferLoadNode* reduction_buffer_load_{nullptr}; // The reduction buffer load in epilogue expression + PrimExpr epilogue_addend_{nullptr}; // C[vi, vj] in D = temp + C (kept for backward compatibility) Buffer epilogue_output_buffer_{nullptr}; // Output buffer D ffi::Array epilogue_output_indices_{nullptr}; // Indices of D[vi, vj] BufferRegion epilogue_output_region_{nullptr}; // Write region of D Buffer epilogue_addend_buffer_{nullptr}; // Addend buffer C BufferRegion epilogue_addend_region_{nullptr}; // Read region of C - EpilogueType epilogue_type_; // Type of epilogue operation - PrimExpr clipping_lower_{nullptr}; // Lower bound for clipping - PrimExpr clipping_upper_{nullptr}; // Upper bound for clipping + EpilogueType epilogue_type_; // Type of epilogue operation (kept for backward compatibility) + PrimExpr clipping_lower_{nullptr}; // Lower bound for clipping (kept for backward compatibility) + PrimExpr clipping_upper_{nullptr}; // Upper bound for clipping (kept for backward compatibility) }; bool ReductionEpilogueFuser::BodyPatternAllowFusion(const BlockRealize& epilogue_block_realize) { @@ -1083,21 +1086,20 @@ bool ReductionEpilogueFuser::BodyPatternAllowFusion(const BlockRealize& epilogue return false; } - // 4. Analyze epilogue pattern: D[i,j] = temp[i,j] + C[i,j] or - // D[i,j] = min(max(temp[i,j], lower), upper) - if (!AnalyzeEpiloguePattern(inlined_store_->value)) { - // Failure: epilogue is not a supported pattern (Bias, BiasReLU, or Clipping) - return false; - } - - // 5. Verify temp appears exactly once in the epilogue pattern - // This ensures correctness for all supported patterns (Bias, BiasReLU, Clipping) - // The reduction result buffer must be used exactly once in the epilogue expression + // 4. Generalized approach: store the entire epilogue expression + // Verify reduction buffer appears exactly once (required for fusion correctness) if (loads.size() != 1) { // Failure: The reduction result (temp) must be used exactly once in the // epilogue expression for fusion. return false; } + + // Store the epilogue expression and reduction buffer load + epilogue_expression_ = inlined_store_->value; + reduction_buffer_load_ = loads[0]; + + // For backward compatibility, try to analyze pattern (optional) + AnalyzeEpiloguePattern(inlined_store_->value); // 6. Check if producer is a reduction block if (!IsReductionBlock(reduction_block_)) { @@ -1268,17 +1270,47 @@ void ReductionEpilogueFuser::ExtractEpilogueInfo() { } } - // Extract epilogue addend buffer and region from epilogue_addend_ - if (const auto* load = epilogue_addend_.as()) { - epilogue_addend_buffer_ = load->buffer; + // Generalized approach: extract all non-reduction buffers from epilogue expression + // Find all buffers in epilogue expression (except the reduction buffer) + struct BufferExtractor : public ExprVisitor { + void VisitExpr_(const BufferLoadNode* load) final { + if (!load->buffer.same_as(reduction_buffer)) { + other_buffers.insert(load->buffer.get()); + } + ExprVisitor::VisitExpr_(load); + } + Buffer reduction_buffer; + std::unordered_set other_buffers; + } extractor; + extractor.reduction_buffer = inlined_buffer_; + extractor(epilogue_expression_); + + // Extract the first non-reduction buffer and its region (for backward compatibility) + // In most cases, there's one additional buffer (e.g., bias buffer) + if (!extractor.other_buffers.empty()) { + const BufferNode* first_buffer = *extractor.other_buffers.begin(); + epilogue_addend_buffer_ = ffi::GetRef(first_buffer); // Find the read region from epilogue block reads for (const BufferRegion& read : epilogue_block_->reads) { - if (read->buffer.same_as(epilogue_addend_buffer_)) { + if (read->buffer.get() == first_buffer) { epilogue_addend_region_ = read; break; } } } + + // For backward compatibility, also try to extract from epilogue_addend_ + if (!epilogue_addend_buffer_.defined() && epilogue_addend_.defined()) { + if (const auto* load = epilogue_addend_.as()) { + epilogue_addend_buffer_ = load->buffer; + for (const BufferRegion& read : epilogue_block_->reads) { + if (read->buffer.same_as(epilogue_addend_buffer_)) { + epilogue_addend_region_ = read; + break; + } + } + } + } } Block ReductionEpilogueFuser::CreateFusedReductionBlock(const BlockNode* reduction_block, @@ -1308,54 +1340,83 @@ Block ReductionEpilogueFuser::CreateFusedReductionBlock(const BlockNode* reducti var_map[epilogue_data_vars[i]] = reduction_data_vars[i]; } - // 2. Change init to epilogue value based on epilogue type - BufferStore new_init_store; - if (epilogue_type_ == EpilogueType::BiasReLU) { - // For ReLU, init should be max(C[vi, vj], 0) to match per-iteration ReLU semantics - PrimExpr init_value = Substitute(epilogue_addend_, var_map); - PrimExpr zero = tir::make_zero(init_value.dtype()); - new_init_store = BufferStore(epilogue_output_buffer_, Max(init_value, zero), - Substitute(epilogue_output_indices_, var_map)); - } else if (epilogue_type_ == EpilogueType::Clipping) { - // For Clipping, init should be min(max(init_value, lower), upper) - // Since init is typically 0, this becomes min(max(0, lower), upper) - PrimExpr init_value = tir::make_zero(epilogue_output_buffer_->dtype); - PrimExpr clipped_init = Min(Max(init_value, Substitute(clipping_lower_, var_map)), - Substitute(clipping_upper_, var_map)); - new_init_store = BufferStore(epilogue_output_buffer_, clipped_init, - Substitute(epilogue_output_indices_, var_map)); - } else { - // Bias: D[vi, vj] = C[vi, vj] - new_init_store = BufferStore(epilogue_output_buffer_, Substitute(epilogue_addend_, var_map), - Substitute(epilogue_output_indices_, var_map)); - } + // 2. Generalized init transformation: substitute reduction buffer load with identity element (0) + // Create a substituter to replace reduction_buffer_load_ with identity element + class InitSubstituter : public ExprMutator { + public: + InitSubstituter(const Buffer& target_buffer, PrimExpr identity_elem) + : target_buffer_(target_buffer), identity_elem_(identity_elem) {} + + PrimExpr VisitExpr_(const BufferLoadNode* op) final { + BufferLoad load = Downcast(ExprMutator::VisitExpr_(op)); + if (load->buffer.same_as(target_buffer_)) { + return identity_elem_; + } + return load; + } + + private: + Buffer target_buffer_; + PrimExpr identity_elem_; + }; + + // Identity element for reduction (assumed to be 0 for addition-based reductions) + PrimExpr identity_elem = tir::make_zero(epilogue_output_buffer_->dtype); + + // Substitute reduction buffer load with identity element + InitSubstituter init_subst(inlined_buffer_, identity_elem); + PrimExpr init_epilogue = init_subst(epilogue_expression_); + + // Apply index mapping + init_epilogue = Substitute(init_epilogue, var_map); + + // Simplify the expression (e.g., 0 + C[vi, vj] -> C[vi, vj]) + arith::Analyzer analyzer; + init_epilogue = analyzer.Simplify(init_epilogue); + + BufferStore new_init_store = BufferStore(epilogue_output_buffer_, init_epilogue, + Substitute(epilogue_output_indices_, var_map)); new_block->init = new_init_store; - // 3. Replace output buffer from temp to D in body - class BufferReplacer : public StmtExprMutator { + // 3. Generalized update transformation: replace reduction buffer with output buffer + // The epilogue addend should only be in the init, not in the update + // So we just replace the buffer reference, without applying the epilogue expression + class UpdateSubstituter : public StmtExprMutator { public: - BufferReplacer(Buffer old_buf, Buffer new_buf, EpilogueType epilogue_type, DataType dtype, - PrimExpr clipping_lower = PrimExpr(), PrimExpr clipping_upper = PrimExpr()) - : old_buffer_(old_buf), - new_buffer_(new_buf), - epilogue_type_(epilogue_type), - dtype_(dtype), - clipping_lower_(clipping_lower), - clipping_upper_(clipping_upper) {} + UpdateSubstituter(const Buffer& old_buf, const Buffer& new_buf, + const std::unordered_map& var_map) + : old_buffer_(old_buf), new_buffer_(new_buf), var_map_(var_map) {} Stmt VisitStmt_(const BufferStoreNode* op) final { BufferStore store = Downcast(StmtExprMutator::VisitStmt_(op)); if (store->buffer.same_as(old_buffer_)) { - PrimExpr new_value = store->value; - // For ReLU, apply max per iteration to match per-iteration ReLU semantics - if (epilogue_type_ == EpilogueType::BiasReLU) { - PrimExpr zero = tir::make_zero(dtype_); - new_value = Max(new_value, zero); - } else if (epilogue_type_ == EpilogueType::Clipping) { - // For Clipping, apply min(max(value, lower), upper) per iteration - new_value = Min(Max(new_value, clipping_lower_), clipping_upper_); - } - return BufferStore(new_buffer_, new_value, store->indices); + // Replace old_buffer_ in store->value with new_buffer_ to get the reduction update expression + // This ensures store->value references new_buffer_ instead of old_buffer_ + class ReductionUpdateReplacer : public ExprMutator { + public: + ReductionUpdateReplacer(const Buffer& old_buf, const Buffer& new_buf) + : old_buffer_(old_buf), new_buffer_(new_buf) {} + + PrimExpr VisitExpr_(const BufferLoadNode* op) final { + BufferLoad load = Downcast(ExprMutator::VisitExpr_(op)); + if (load->buffer.same_as(old_buffer_)) { + return BufferLoad(new_buffer_, load->indices); + } + return load; + } + + private: + Buffer old_buffer_; + Buffer new_buffer_; + }; + + ReductionUpdateReplacer reduction_replacer(old_buffer_, new_buffer_); + PrimExpr reduction_update = reduction_replacer(store->value); + + // Apply index mapping + reduction_update = Substitute(reduction_update, var_map_); + + return BufferStore(new_buffer_, reduction_update, store->indices); } return store; } @@ -1371,19 +1432,10 @@ Block ReductionEpilogueFuser::CreateFusedReductionBlock(const BlockNode* reducti private: Buffer old_buffer_; Buffer new_buffer_; - EpilogueType epilogue_type_; - DataType dtype_; - PrimExpr clipping_lower_; - PrimExpr clipping_upper_; + std::unordered_map var_map_; }; - - DataType dtype = epilogue_output_buffer_->dtype; - PrimExpr clipping_lower_subst = - epilogue_type_ == EpilogueType::Clipping ? Substitute(clipping_lower_, var_map) : PrimExpr(); - PrimExpr clipping_upper_subst = - epilogue_type_ == EpilogueType::Clipping ? Substitute(clipping_upper_, var_map) : PrimExpr(); - BufferReplacer replacer(inlined_buffer_, epilogue_output_buffer_, epilogue_type_, dtype, - clipping_lower_subst, clipping_upper_subst); + + UpdateSubstituter replacer(inlined_buffer_, epilogue_output_buffer_, var_map); new_block->body = replacer(reduction_block->body); // 4. Update write regions @@ -1398,21 +1450,22 @@ Block ReductionEpilogueFuser::CreateFusedReductionBlock(const BlockNode* reducti } new_block->writes = new_writes; - // 5. Update read regions (C first, then A, B) + // 5. Update read regions: add all buffers from epilogue expression (except reduction buffer) ffi::Array new_reads; std::unordered_set read_bufs; - // Add C buffer read first (used in init) - if (epilogue_addend_buffer_.defined()) { - new_reads.push_back(BufferRegion(epilogue_addend_buffer_, - Substitute(epilogue_addend_region_->region, var_map))); - read_bufs.insert(epilogue_addend_buffer_.get()); + // Add all non-reduction buffers from epilogue expression + for (const BufferRegion& read : epilogue_block_->reads) { + if (!read->buffer.same_as(inlined_buffer_)) { + new_reads.push_back(BufferRegion(read->buffer, Substitute(read->region, var_map))); + read_bufs.insert(read->buffer.get()); + } } - // Add existing read regions (A, B, etc.) + // Add existing read regions from reduction block (A, B, etc.) for (const BufferRegion& read : reduction_block->reads) { if (!read->buffer.same_as(inlined_buffer_)) { - // Only add non-temp buffers + // Only add non-temp buffers that haven't been added yet if (read_bufs.find(read->buffer.get()) == read_bufs.end()) { new_reads.push_back(read); read_bufs.insert(read->buffer.get()); From 6ae855ddf510b913462cfb83a8565a9e25857597 Mon Sep 17 00:00:00 2001 From: Hyun Gyu Kim Date: Mon, 5 Jan 2026 16:11:39 +0900 Subject: [PATCH 2/9] Solve Issue relu, clipping --- src/tir/schedule/primitive/compute_inline.cc | 106 +++++++++++++++++-- 1 file changed, 100 insertions(+), 6 deletions(-) diff --git a/src/tir/schedule/primitive/compute_inline.cc b/src/tir/schedule/primitive/compute_inline.cc index 77ff91f4df32..c6c6c15dfc86 100644 --- a/src/tir/schedule/primitive/compute_inline.cc +++ b/src/tir/schedule/primitive/compute_inline.cc @@ -1379,13 +1379,20 @@ Block ReductionEpilogueFuser::CreateFusedReductionBlock(const BlockNode* reducti new_block->init = new_init_store; // 3. Generalized update transformation: replace reduction buffer with output buffer - // The epilogue addend should only be in the init, not in the update - // So we just replace the buffer reference, without applying the epilogue expression + // For Bias pattern: epilogue addend should only be in the init, not in the update + // For BiasReLU and Clipping patterns: epilogue expression must be applied per-iteration class UpdateSubstituter : public StmtExprMutator { public: UpdateSubstituter(const Buffer& old_buf, const Buffer& new_buf, + const PrimExpr& epilogue_expr, EpilogueType epilogue_type, + const PrimExpr& epilogue_addend, const std::unordered_map& var_map) - : old_buffer_(old_buf), new_buffer_(new_buf), var_map_(var_map) {} + : old_buffer_(old_buf), + new_buffer_(new_buf), + epilogue_expression_(epilogue_expr), + epilogue_type_(epilogue_type), + epilogue_addend_(epilogue_addend), + var_map_(var_map) {} Stmt VisitStmt_(const BufferStoreNode* op) final { BufferStore store = Downcast(StmtExprMutator::VisitStmt_(op)); @@ -1413,10 +1420,87 @@ Block ReductionEpilogueFuser::CreateFusedReductionBlock(const BlockNode* reducti ReductionUpdateReplacer reduction_replacer(old_buffer_, new_buffer_); PrimExpr reduction_update = reduction_replacer(store->value); + PrimExpr new_value; + if (epilogue_type_ == EpilogueType::Bias) { + // For Bias pattern: just use the reduction update without epilogue expression + new_value = reduction_update; + } else if (epilogue_type_ == EpilogueType::BiasReLU) { + // For BiasReLU pattern: apply ReLU only (without bias addend) per-iteration + // The epilogue expression is max(temp + C, 0), but update should be max(reduction_update, 0) + // So we need to remove the bias addend from the epilogue expression + class ReLUApplier : public ExprMutator { + public: + ReLUApplier(const Buffer& target_buf, const PrimExpr& replacement, + const PrimExpr& addend_to_remove) + : target_buffer_(target_buf), + replacement_(replacement), + addend_to_remove_(addend_to_remove) {} + + PrimExpr VisitExpr_(const BufferLoadNode* op) final { + BufferLoad load = Downcast(ExprMutator::VisitExpr_(op)); + if (load->buffer.same_as(target_buffer_)) { + return replacement_; + } + return load; + } + + PrimExpr VisitExpr_(const AddNode* op) final { + // Remove the bias addend from the addition + PrimExpr a = VisitExpr(op->a); + PrimExpr b = VisitExpr(op->b); + + // Check if either operand matches the addend to remove + arith::Analyzer analyzer; + if (analyzer.CanProveEqual(a, addend_to_remove_)) { + return b; + } + if (analyzer.CanProveEqual(b, addend_to_remove_)) { + return a; + } + + // If neither matches, return the original addition + return Add(a, b); + } + + private: + Buffer target_buffer_; + PrimExpr replacement_; + PrimExpr addend_to_remove_; + }; + + // Get the bias addend with index mapping applied + PrimExpr addend_mapped = Substitute(epilogue_addend_, var_map_); + ReLUApplier applier(old_buffer_, reduction_update, addend_mapped); + new_value = applier(epilogue_expression_); + } else { + // For Clipping pattern: apply epilogue expression per-iteration + // Substitute reduction buffer load with the reduction update in epilogue expression + class EpilogueApplier : public ExprMutator { + public: + EpilogueApplier(const Buffer& target_buf, const PrimExpr& replacement) + : target_buffer_(target_buf), replacement_(replacement) {} + + PrimExpr VisitExpr_(const BufferLoadNode* op) final { + BufferLoad load = Downcast(ExprMutator::VisitExpr_(op)); + if (load->buffer.same_as(target_buffer_)) { + return replacement_; + } + return load; + } + + private: + Buffer target_buffer_; + PrimExpr replacement_; + }; + + EpilogueApplier applier(old_buffer_, reduction_update); + new_value = applier(epilogue_expression_); + } + // Apply index mapping - reduction_update = Substitute(reduction_update, var_map_); + new_value = Substitute(new_value, var_map_); - return BufferStore(new_buffer_, reduction_update, store->indices); + return BufferStore(new_buffer_, new_value, store->indices); } return store; } @@ -1432,10 +1516,20 @@ Block ReductionEpilogueFuser::CreateFusedReductionBlock(const BlockNode* reducti private: Buffer old_buffer_; Buffer new_buffer_; + PrimExpr epilogue_expression_; + EpilogueType epilogue_type_; + PrimExpr epilogue_addend_; std::unordered_map var_map_; }; + + // Apply index mapping to epilogue expression first + PrimExpr epilogue_expr_mapped = Substitute(epilogue_expression_, var_map); + PrimExpr epilogue_addend_mapped = epilogue_addend_.defined() + ? Substitute(epilogue_addend_, var_map) + : PrimExpr(nullptr); - UpdateSubstituter replacer(inlined_buffer_, epilogue_output_buffer_, var_map); + UpdateSubstituter replacer(inlined_buffer_, epilogue_output_buffer_, epilogue_expr_mapped, + epilogue_type_, epilogue_addend_mapped, var_map); new_block->body = replacer(reduction_block->body); // 4. Update write regions From 0713850100e56c5e761fc84c7d627d0f7409be1f Mon Sep 17 00:00:00 2001 From: Hyun Gyu Kim Date: Mon, 5 Jan 2026 16:24:33 +0900 Subject: [PATCH 3/9] general form --- src/tir/schedule/primitive/compute_inline.cc | 158 +++++++++---------- 1 file changed, 78 insertions(+), 80 deletions(-) diff --git a/src/tir/schedule/primitive/compute_inline.cc b/src/tir/schedule/primitive/compute_inline.cc index c6c6c15dfc86..b1aa4991670f 100644 --- a/src/tir/schedule/primitive/compute_inline.cc +++ b/src/tir/schedule/primitive/compute_inline.cc @@ -1384,14 +1384,12 @@ Block ReductionEpilogueFuser::CreateFusedReductionBlock(const BlockNode* reducti class UpdateSubstituter : public StmtExprMutator { public: UpdateSubstituter(const Buffer& old_buf, const Buffer& new_buf, - const PrimExpr& epilogue_expr, EpilogueType epilogue_type, - const PrimExpr& epilogue_addend, + const Buffer& reduction_buf, const PrimExpr& epilogue_expr, const std::unordered_map& var_map) : old_buffer_(old_buf), new_buffer_(new_buf), + reduction_buffer_(reduction_buf), epilogue_expression_(epilogue_expr), - epilogue_type_(epilogue_type), - epilogue_addend_(epilogue_addend), var_map_(var_map) {} Stmt VisitStmt_(const BufferStoreNode* op) final { @@ -1420,82 +1418,86 @@ Block ReductionEpilogueFuser::CreateFusedReductionBlock(const BlockNode* reducti ReductionUpdateReplacer reduction_replacer(old_buffer_, new_buffer_); PrimExpr reduction_update = reduction_replacer(store->value); - PrimExpr new_value; - if (epilogue_type_ == EpilogueType::Bias) { - // For Bias pattern: just use the reduction update without epilogue expression - new_value = reduction_update; - } else if (epilogue_type_ == EpilogueType::BiasReLU) { - // For BiasReLU pattern: apply ReLU only (without bias addend) per-iteration - // The epilogue expression is max(temp + C, 0), but update should be max(reduction_update, 0) - // So we need to remove the bias addend from the epilogue expression - class ReLUApplier : public ExprMutator { - public: - ReLUApplier(const Buffer& target_buf, const PrimExpr& replacement, - const PrimExpr& addend_to_remove) - : target_buffer_(target_buf), - replacement_(replacement), - addend_to_remove_(addend_to_remove) {} - - PrimExpr VisitExpr_(const BufferLoadNode* op) final { - BufferLoad load = Downcast(ExprMutator::VisitExpr_(op)); - if (load->buffer.same_as(target_buffer_)) { - return replacement_; - } - return load; - } + // Generalized approach: apply epilogue expression with reduction buffer load replaced + // If reduction buffer load's direct parent is Add and the other operand is not a reduction buffer, + // remove that operand (bias addend) from the update expression + class GeneralizedEpilogueApplier : public ExprMutator { + public: + GeneralizedEpilogueApplier(const Buffer& target_buf, const Buffer& reduction_buf, + const PrimExpr& replacement) + : target_buffer_(target_buf), + reduction_buffer_(reduction_buf), + replacement_(replacement), + found_target_load_(false), + parent_is_add_(false), + addend_to_remove_(nullptr) {} - PrimExpr VisitExpr_(const AddNode* op) final { - // Remove the bias addend from the addition - PrimExpr a = VisitExpr(op->a); - PrimExpr b = VisitExpr(op->b); - - // Check if either operand matches the addend to remove - arith::Analyzer analyzer; - if (analyzer.CanProveEqual(a, addend_to_remove_)) { - return b; - } - if (analyzer.CanProveEqual(b, addend_to_remove_)) { - return a; - } - - // If neither matches, return the original addition - return Add(a, b); + PrimExpr VisitExpr_(const BufferLoadNode* op) final { + BufferLoad load = Downcast(ExprMutator::VisitExpr_(op)); + if (load->buffer.same_as(target_buffer_)) { + found_target_load_ = true; + // Check if parent is Add (will be checked in VisitExpr_(const AddNode*)) + return replacement_; } + return load; + } - private: - Buffer target_buffer_; - PrimExpr replacement_; - PrimExpr addend_to_remove_; - }; - - // Get the bias addend with index mapping applied - PrimExpr addend_mapped = Substitute(epilogue_addend_, var_map_); - ReLUApplier applier(old_buffer_, reduction_update, addend_mapped); - new_value = applier(epilogue_expression_); - } else { - // For Clipping pattern: apply epilogue expression per-iteration - // Substitute reduction buffer load with the reduction update in epilogue expression - class EpilogueApplier : public ExprMutator { - public: - EpilogueApplier(const Buffer& target_buf, const PrimExpr& replacement) - : target_buffer_(target_buf), replacement_(replacement) {} - - PrimExpr VisitExpr_(const BufferLoadNode* op) final { - BufferLoad load = Downcast(ExprMutator::VisitExpr_(op)); - if (load->buffer.same_as(target_buffer_)) { - return replacement_; + PrimExpr VisitExpr_(const AddNode* op) final { + // Visit children first to see if we find the target buffer load + bool found_before = found_target_load_; + found_target_load_ = false; + + PrimExpr a = VisitExpr(op->a); + bool found_in_a = found_target_load_; + found_target_load_ = false; + + PrimExpr b = VisitExpr(op->b); + bool found_in_b = found_target_load_; + + // If target buffer load was found in this Add node + if (found_in_a || found_in_b) { + // Check if the other operand is NOT from the reduction buffer + // If so, it's likely a bias addend that should be removed in update + bool other_is_reduction = false; + if (found_in_a) { + // Check if b is from reduction buffer + if (const auto* load_b = b.as()) { + other_is_reduction = load_b->buffer.same_as(reduction_buffer_); + } + if (!other_is_reduction) { + // b is the bias addend, remove it + return a; + } + } else { // found_in_b + // Check if a is from reduction buffer + if (const auto* load_a = a.as()) { + other_is_reduction = load_a->buffer.same_as(reduction_buffer_); + } + if (!other_is_reduction) { + // a is the bias addend, remove it + return b; + } } - return load; + // If other operand is also from reduction buffer, keep the Add + return Add(a, b); } + + // Target buffer load not found in this Add, return as is + found_target_load_ = found_before; + return Add(a, b); + } - private: - Buffer target_buffer_; - PrimExpr replacement_; - }; + private: + const Buffer& target_buffer_; + const Buffer& reduction_buffer_; + const PrimExpr& replacement_; + bool found_target_load_; + bool parent_is_add_; + PrimExpr addend_to_remove_; + }; - EpilogueApplier applier(old_buffer_, reduction_update); - new_value = applier(epilogue_expression_); - } + GeneralizedEpilogueApplier applier(old_buffer_, reduction_buffer_, reduction_update); + PrimExpr new_value = applier(epilogue_expression_); // Apply index mapping new_value = Substitute(new_value, var_map_); @@ -1516,20 +1518,16 @@ Block ReductionEpilogueFuser::CreateFusedReductionBlock(const BlockNode* reducti private: Buffer old_buffer_; Buffer new_buffer_; + Buffer reduction_buffer_; PrimExpr epilogue_expression_; - EpilogueType epilogue_type_; - PrimExpr epilogue_addend_; std::unordered_map var_map_; }; // Apply index mapping to epilogue expression first PrimExpr epilogue_expr_mapped = Substitute(epilogue_expression_, var_map); - PrimExpr epilogue_addend_mapped = epilogue_addend_.defined() - ? Substitute(epilogue_addend_, var_map) - : PrimExpr(nullptr); - UpdateSubstituter replacer(inlined_buffer_, epilogue_output_buffer_, epilogue_expr_mapped, - epilogue_type_, epilogue_addend_mapped, var_map); + UpdateSubstituter replacer(inlined_buffer_, epilogue_output_buffer_, inlined_buffer_, + epilogue_expr_mapped, var_map); new_block->body = replacer(reduction_block->body); // 4. Update write regions From d10c376db65e0f7e0a8cc35ad529a901c434187e Mon Sep 17 00:00:00 2001 From: Hyun Gyu Kim Date: Mon, 5 Jan 2026 16:36:39 +0900 Subject: [PATCH 4/9] general 2 --- src/tir/schedule/primitive/compute_inline.cc | 185 +------------------ 1 file changed, 9 insertions(+), 176 deletions(-) diff --git a/src/tir/schedule/primitive/compute_inline.cc b/src/tir/schedule/primitive/compute_inline.cc index b1aa4991670f..919571e1ff85 100644 --- a/src/tir/schedule/primitive/compute_inline.cc +++ b/src/tir/schedule/primitive/compute_inline.cc @@ -986,15 +986,8 @@ void ReverseComputeInline(ScheduleState self, const StmtSRef& consumer_block_sre /*! * \brief Helper to fuse epilogue block into reduction block - * Analyzes epilogue pattern and transforms reduction init/update + * Uses generalized approach to handle any epilogue expression without pattern matching */ -// Epilogue type enumeration -enum class EpilogueType { - Bias, // temp + C - BiasReLU, // max(temp + C, 0) - Clipping, // min(max(temp, lower), upper) -}; - class ReductionEpilogueFuser : public BaseInliner { public: explicit ReductionEpilogueFuser(const Buffer& reduction_buffer, const BlockNode* reduction_block, @@ -1002,8 +995,7 @@ class ReductionEpilogueFuser : public BaseInliner { const StmtSRef& scope_root_sref) : BaseInliner(reduction_buffer, epilogue_block_realize->block, scope_root_sref), reduction_block_(reduction_block), - epilogue_block_(epilogue_block_realize->block.get()), - epilogue_type_(EpilogueType::Bias) { + epilogue_block_(epilogue_block_realize->block.get()) { // Disable opaque access check for epilogue fusion // Epilogue blocks can read multiple buffers (temp + bias), which is allowed has_opaque_access = false; @@ -1023,7 +1015,6 @@ class ReductionEpilogueFuser : public BaseInliner { const BlockRealizeNode* reduction_realize); private: - bool AnalyzeEpiloguePattern(const PrimExpr& value); bool IsReductionBlock(const BlockNode* block); void ExtractEpilogueInfo(); // Helper function to extract BufferLoad nodes from BufferStore @@ -1055,15 +1046,11 @@ class ReductionEpilogueFuser : public BaseInliner { // Generalized approach: store the entire epilogue expression PrimExpr epilogue_expression_{nullptr}; // The entire epilogue expression (e.g., temp + C, max(temp + C, 0)) const BufferLoadNode* reduction_buffer_load_{nullptr}; // The reduction buffer load in epilogue expression - PrimExpr epilogue_addend_{nullptr}; // C[vi, vj] in D = temp + C (kept for backward compatibility) Buffer epilogue_output_buffer_{nullptr}; // Output buffer D ffi::Array epilogue_output_indices_{nullptr}; // Indices of D[vi, vj] BufferRegion epilogue_output_region_{nullptr}; // Write region of D - Buffer epilogue_addend_buffer_{nullptr}; // Addend buffer C - BufferRegion epilogue_addend_region_{nullptr}; // Read region of C - EpilogueType epilogue_type_; // Type of epilogue operation (kept for backward compatibility) - PrimExpr clipping_lower_{nullptr}; // Lower bound for clipping (kept for backward compatibility) - PrimExpr clipping_upper_{nullptr}; // Upper bound for clipping (kept for backward compatibility) + Buffer epilogue_addend_buffer_{nullptr}; // Additional buffer (e.g., bias buffer C) + BufferRegion epilogue_addend_region_{nullptr}; // Read region of additional buffer }; bool ReductionEpilogueFuser::BodyPatternAllowFusion(const BlockRealize& epilogue_block_realize) { @@ -1098,9 +1085,6 @@ bool ReductionEpilogueFuser::BodyPatternAllowFusion(const BlockRealize& epilogue epilogue_expression_ = inlined_store_->value; reduction_buffer_load_ = loads[0]; - // For backward compatibility, try to analyze pattern (optional) - AnalyzeEpiloguePattern(inlined_store_->value); - // 6. Check if producer is a reduction block if (!IsReductionBlock(reduction_block_)) { // Failure: producer is not a reduction block @@ -1113,140 +1097,6 @@ bool ReductionEpilogueFuser::BodyPatternAllowFusion(const BlockRealize& epilogue return true; } -bool ReductionEpilogueFuser::AnalyzeEpiloguePattern(const PrimExpr& value) { - // Pattern 1: temp[i,j] + C[i,j] or C[i,j] + temp[i,j] (Bias) - if (const auto* add = value.as()) { - const auto* load_a = add->a.as(); - const auto* load_b = add->b.as(); - - bool a_is_target = load_a && load_a->buffer.same_as(inlined_buffer_); - bool b_is_target = load_b && load_b->buffer.same_as(inlined_buffer_); - - // Ensure exactly one operand is from the reduction buffer - if (a_is_target != b_is_target) { - epilogue_addend_ = a_is_target ? add->b : add->a; - epilogue_type_ = EpilogueType::Bias; - return true; - } - } - - // Pattern 2: min(max(temp[i,j], lower), upper) or max(min(temp[i,j], upper), lower) (Clipping) - // Handle all commutative variants of min/max at each level. - - // Helper to check if an expression is a load from the reduction buffer, and - // return the other operand as `other` if so. - auto match_buffer_in_commutative_op = [this](const PrimExpr& a, const PrimExpr& b, - PrimExpr* other) -> bool { - if (const auto* load_a = a.as()) { - if (load_a->buffer.same_as(inlined_buffer_)) { - *other = b; - return true; - } - } - if (const auto* load_b = b.as()) { - if (load_b->buffer.same_as(inlined_buffer_)) { - *other = a; - return true; - } - } - return false; - }; - - // Check for min(max(temp, lower), upper) and commutative variants - if (const auto* min_node = value.as()) { - const MaxNode* max_node = nullptr; - PrimExpr upper; - // Try both (a, b) as possible positions of the inner max - if ((max_node = min_node->a.as())) { - upper = min_node->b; - } else if ((max_node = min_node->b.as())) { - upper = min_node->a; - } - if (max_node != nullptr) { - PrimExpr lower; - if (match_buffer_in_commutative_op(max_node->a, max_node->b, &lower)) { - clipping_lower_ = lower; - clipping_upper_ = upper; - epilogue_type_ = EpilogueType::Clipping; - return true; - } - } - } - - // Check for max(min(temp[i,j], upper), lower) and commutative variants - if (const auto* max_node = value.as()) { - const MinNode* min_node = nullptr; - PrimExpr lower; - // Try both (a, b) as possible positions of the inner min - if ((min_node = max_node->a.as())) { - lower = max_node->b; - } else if ((min_node = max_node->b.as())) { - lower = max_node->a; - } - if (min_node != nullptr) { - PrimExpr upper; - if (match_buffer_in_commutative_op(min_node->a, min_node->b, &upper)) { - clipping_lower_ = lower; - clipping_upper_ = upper; - epilogue_type_ = EpilogueType::Clipping; - return true; - } - } - } - - // Pattern 3: max(temp[i,j] + C[i,j], 0) or max(C[i,j] + temp[i,j], 0) (BiasReLU) - // Also handle max(0, temp[i,j] + C[i,j]) or max(0, C[i,j] + temp[i,j]) - if (const auto* max_node = value.as()) { - // Check if either operand is zero (ReLU: max(x, 0) or max(0, x)) - // Support both integer and float zero constants. - const PrimExpr* add_candidate = nullptr; - bool is_zero_const = false; - auto is_zero_expr = [](const PrimExpr& expr) -> bool { - if (tir::is_zero(expr)) { - return true; - } - if (const auto* float_imm = expr.as()) { - return float_imm->value == 0.0; - } - return false; - }; - - if (is_zero_expr(max_node->a)) { - is_zero_const = true; - add_candidate = &max_node->b; - } else if (is_zero_expr(max_node->b)) { - is_zero_const = true; - add_candidate = &max_node->a; - } - - if (is_zero_const && add_candidate != nullptr) { - if (const auto* add = add_candidate->as()) { - const auto* load_a = add->a.as(); - const auto* load_b = add->b.as(); - - bool a_is_target = load_a && load_a->buffer.same_as(inlined_buffer_); - bool b_is_target = load_b && load_b->buffer.same_as(inlined_buffer_); - - // Ensure exactly one operand is from the reduction buffer - if (a_is_target != b_is_target) { - epilogue_addend_ = a_is_target ? add->b : add->a; - epilogue_type_ = EpilogueType::BiasReLU; - return true; - } - } else if (const auto* load = add_candidate->as()) { - // Handle bias-free ReLU: max(temp, 0) or max(0, temp) - if (load->buffer.same_as(inlined_buffer_)) { - epilogue_addend_ = tir::make_zero(load->dtype); - epilogue_type_ = EpilogueType::BiasReLU; - return true; - } - } - } - } - - return false; -} - bool ReductionEpilogueFuser::IsReductionBlock(const BlockNode* block) { // Check if block has reduction iter vars for (const IterVar& iter : block->iter_vars) { @@ -1285,7 +1135,7 @@ void ReductionEpilogueFuser::ExtractEpilogueInfo() { extractor.reduction_buffer = inlined_buffer_; extractor(epilogue_expression_); - // Extract the first non-reduction buffer and its region (for backward compatibility) + // Extract the first non-reduction buffer and its region // In most cases, there's one additional buffer (e.g., bias buffer) if (!extractor.other_buffers.empty()) { const BufferNode* first_buffer = *extractor.other_buffers.begin(); @@ -1298,19 +1148,6 @@ void ReductionEpilogueFuser::ExtractEpilogueInfo() { } } } - - // For backward compatibility, also try to extract from epilogue_addend_ - if (!epilogue_addend_buffer_.defined() && epilogue_addend_.defined()) { - if (const auto* load = epilogue_addend_.as()) { - epilogue_addend_buffer_ = load->buffer; - for (const BufferRegion& read : epilogue_block_->reads) { - if (read->buffer.same_as(epilogue_addend_buffer_)) { - epilogue_addend_region_ = read; - break; - } - } - } - } } Block ReductionEpilogueFuser::CreateFusedReductionBlock(const BlockNode* reduction_block, @@ -1378,9 +1215,9 @@ Block ReductionEpilogueFuser::CreateFusedReductionBlock(const BlockNode* reducti Substitute(epilogue_output_indices_, var_map)); new_block->init = new_init_store; - // 3. Generalized update transformation: replace reduction buffer with output buffer - // For Bias pattern: epilogue addend should only be in the init, not in the update - // For BiasReLU and Clipping patterns: epilogue expression must be applied per-iteration + // 3. Generalized update transformation: apply epilogue expression with reduction buffer replaced + // If reduction buffer load's parent is Add and other operand is not a reduction buffer, + // remove that operand (bias addend) from update expression class UpdateSubstituter : public StmtExprMutator { public: UpdateSubstituter(const Buffer& old_buf, const Buffer& new_buf, @@ -1428,9 +1265,7 @@ Block ReductionEpilogueFuser::CreateFusedReductionBlock(const BlockNode* reducti : target_buffer_(target_buf), reduction_buffer_(reduction_buf), replacement_(replacement), - found_target_load_(false), - parent_is_add_(false), - addend_to_remove_(nullptr) {} + found_target_load_(false) {} PrimExpr VisitExpr_(const BufferLoadNode* op) final { BufferLoad load = Downcast(ExprMutator::VisitExpr_(op)); @@ -1492,8 +1327,6 @@ Block ReductionEpilogueFuser::CreateFusedReductionBlock(const BlockNode* reducti const Buffer& reduction_buffer_; const PrimExpr& replacement_; bool found_target_load_; - bool parent_is_add_; - PrimExpr addend_to_remove_; }; GeneralizedEpilogueApplier applier(old_buffer_, reduction_buffer_, reduction_update); From 98fd3316bfe966be4ee72d68f86b8283ebd03278 Mon Sep 17 00:00:00 2001 From: Hyun Gyu Kim Date: Mon, 5 Jan 2026 16:48:31 +0900 Subject: [PATCH 5/9] solve whitespace issue --- src/tir/schedule/primitive/compute_inline.cc | 28 ++++++++++---------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/src/tir/schedule/primitive/compute_inline.cc b/src/tir/schedule/primitive/compute_inline.cc index 919571e1ff85..1aed9a3d70ea 100644 --- a/src/tir/schedule/primitive/compute_inline.cc +++ b/src/tir/schedule/primitive/compute_inline.cc @@ -1080,11 +1080,11 @@ bool ReductionEpilogueFuser::BodyPatternAllowFusion(const BlockRealize& epilogue // epilogue expression for fusion. return false; } - + // Store the epilogue expression and reduction buffer load epilogue_expression_ = inlined_store_->value; reduction_buffer_load_ = loads[0]; - + // 6. Check if producer is a reduction block if (!IsReductionBlock(reduction_block_)) { // Failure: producer is not a reduction block @@ -1199,18 +1199,18 @@ Block ReductionEpilogueFuser::CreateFusedReductionBlock(const BlockNode* reducti // Identity element for reduction (assumed to be 0 for addition-based reductions) PrimExpr identity_elem = tir::make_zero(epilogue_output_buffer_->dtype); - + // Substitute reduction buffer load with identity element InitSubstituter init_subst(inlined_buffer_, identity_elem); PrimExpr init_epilogue = init_subst(epilogue_expression_); - + // Apply index mapping init_epilogue = Substitute(init_epilogue, var_map); - + // Simplify the expression (e.g., 0 + C[vi, vj] -> C[vi, vj]) arith::Analyzer analyzer; init_epilogue = analyzer.Simplify(init_epilogue); - + BufferStore new_init_store = BufferStore(epilogue_output_buffer_, init_epilogue, Substitute(epilogue_output_indices_, var_map)); new_block->init = new_init_store; @@ -1254,7 +1254,7 @@ Block ReductionEpilogueFuser::CreateFusedReductionBlock(const BlockNode* reducti ReductionUpdateReplacer reduction_replacer(old_buffer_, new_buffer_); PrimExpr reduction_update = reduction_replacer(store->value); - + // Generalized approach: apply epilogue expression with reduction buffer load replaced // If reduction buffer load's direct parent is Add and the other operand is not a reduction buffer, // remove that operand (bias addend) from the update expression @@ -1281,14 +1281,14 @@ Block ReductionEpilogueFuser::CreateFusedReductionBlock(const BlockNode* reducti // Visit children first to see if we find the target buffer load bool found_before = found_target_load_; found_target_load_ = false; - + PrimExpr a = VisitExpr(op->a); bool found_in_a = found_target_load_; found_target_load_ = false; - + PrimExpr b = VisitExpr(op->b); bool found_in_b = found_target_load_; - + // If target buffer load was found in this Add node if (found_in_a || found_in_b) { // Check if the other operand is NOT from the reduction buffer @@ -1316,7 +1316,7 @@ Block ReductionEpilogueFuser::CreateFusedReductionBlock(const BlockNode* reducti // If other operand is also from reduction buffer, keep the Add return Add(a, b); } - + // Target buffer load not found in this Add, return as is found_target_load_ = found_before; return Add(a, b); @@ -1331,10 +1331,10 @@ Block ReductionEpilogueFuser::CreateFusedReductionBlock(const BlockNode* reducti GeneralizedEpilogueApplier applier(old_buffer_, reduction_buffer_, reduction_update); PrimExpr new_value = applier(epilogue_expression_); - + // Apply index mapping new_value = Substitute(new_value, var_map_); - + return BufferStore(new_buffer_, new_value, store->indices); } return store; @@ -1358,7 +1358,7 @@ Block ReductionEpilogueFuser::CreateFusedReductionBlock(const BlockNode* reducti // Apply index mapping to epilogue expression first PrimExpr epilogue_expr_mapped = Substitute(epilogue_expression_, var_map); - + UpdateSubstituter replacer(inlined_buffer_, epilogue_output_buffer_, inlined_buffer_, epilogue_expr_mapped, var_map); new_block->body = replacer(reduction_block->body); From 8d0a07860364bc78dae8f923ed2a25d437fac8a7 Mon Sep 17 00:00:00 2001 From: Hyun Gyu Kim Date: Mon, 5 Jan 2026 16:54:26 +0900 Subject: [PATCH 6/9] clang format --- src/tir/schedule/primitive/compute_inline.cc | 25 ++++++++++---------- 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/src/tir/schedule/primitive/compute_inline.cc b/src/tir/schedule/primitive/compute_inline.cc index 1aed9a3d70ea..d9a457ff9ef6 100644 --- a/src/tir/schedule/primitive/compute_inline.cc +++ b/src/tir/schedule/primitive/compute_inline.cc @@ -1044,13 +1044,15 @@ class ReductionEpilogueFuser : public BaseInliner { const BlockNode* reduction_block_; const BlockNode* epilogue_block_; // Generalized approach: store the entire epilogue expression - PrimExpr epilogue_expression_{nullptr}; // The entire epilogue expression (e.g., temp + C, max(temp + C, 0)) - const BufferLoadNode* reduction_buffer_load_{nullptr}; // The reduction buffer load in epilogue expression - Buffer epilogue_output_buffer_{nullptr}; // Output buffer D + PrimExpr epilogue_expression_{ + nullptr}; // The entire epilogue expression (e.g., temp + C, max(temp + C, 0)) + const BufferLoadNode* reduction_buffer_load_{ + nullptr}; // The reduction buffer load in epilogue expression + Buffer epilogue_output_buffer_{nullptr}; // Output buffer D ffi::Array epilogue_output_indices_{nullptr}; // Indices of D[vi, vj] BufferRegion epilogue_output_region_{nullptr}; // Write region of D - Buffer epilogue_addend_buffer_{nullptr}; // Additional buffer (e.g., bias buffer C) - BufferRegion epilogue_addend_region_{nullptr}; // Read region of additional buffer + Buffer epilogue_addend_buffer_{nullptr}; // Additional buffer (e.g., bias buffer C) + BufferRegion epilogue_addend_region_{nullptr}; // Read region of additional buffer }; bool ReductionEpilogueFuser::BodyPatternAllowFusion(const BlockRealize& epilogue_block_realize) { @@ -1220,9 +1222,8 @@ Block ReductionEpilogueFuser::CreateFusedReductionBlock(const BlockNode* reducti // remove that operand (bias addend) from update expression class UpdateSubstituter : public StmtExprMutator { public: - UpdateSubstituter(const Buffer& old_buf, const Buffer& new_buf, - const Buffer& reduction_buf, const PrimExpr& epilogue_expr, - const std::unordered_map& var_map) + UpdateSubstituter(const Buffer& old_buf, const Buffer& new_buf, const Buffer& reduction_buf, + const PrimExpr& epilogue_expr, const std::unordered_map& var_map) : old_buffer_(old_buf), new_buffer_(new_buf), reduction_buffer_(reduction_buf), @@ -1232,8 +1233,8 @@ Block ReductionEpilogueFuser::CreateFusedReductionBlock(const BlockNode* reducti Stmt VisitStmt_(const BufferStoreNode* op) final { BufferStore store = Downcast(StmtExprMutator::VisitStmt_(op)); if (store->buffer.same_as(old_buffer_)) { - // Replace old_buffer_ in store->value with new_buffer_ to get the reduction update expression - // This ensures store->value references new_buffer_ instead of old_buffer_ + // Replace old_buffer_ in store->value with new_buffer_ to get the reduction update + // expression This ensures store->value references new_buffer_ instead of old_buffer_ class ReductionUpdateReplacer : public ExprMutator { public: ReductionUpdateReplacer(const Buffer& old_buf, const Buffer& new_buf) @@ -1256,8 +1257,8 @@ Block ReductionEpilogueFuser::CreateFusedReductionBlock(const BlockNode* reducti PrimExpr reduction_update = reduction_replacer(store->value); // Generalized approach: apply epilogue expression with reduction buffer load replaced - // If reduction buffer load's direct parent is Add and the other operand is not a reduction buffer, - // remove that operand (bias addend) from the update expression + // If reduction buffer load's direct parent is Add and the other operand is not a reduction + // buffer, remove that operand (bias addend) from the update expression class GeneralizedEpilogueApplier : public ExprMutator { public: GeneralizedEpilogueApplier(const Buffer& target_buf, const Buffer& reduction_buf, From 22bc3f857faa327fb297928bdf14be4305fcd0dd Mon Sep 17 00:00:00 2001 From: Hyun Gyu Kim Date: Mon, 19 Jan 2026 17:38:01 +0900 Subject: [PATCH 7/9] Add Test Case to verify Case: Multiple Occurrence --- ...st_tir_schedule_fuse_reduction_epilogue.py | 32 +++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/tests/python/tir-schedule/test_tir_schedule_fuse_reduction_epilogue.py b/tests/python/tir-schedule/test_tir_schedule_fuse_reduction_epilogue.py index dc89f9df56a7..c3cf6bab4d74 100644 --- a/tests/python/tir-schedule/test_tir_schedule_fuse_reduction_epilogue.py +++ b/tests/python/tir-schedule/test_tir_schedule_fuse_reduction_epilogue.py @@ -214,5 +214,37 @@ def test_fuse_reduction_epilogue_multiple_epilogue(): assert mod is not None +@T.prim_func +def matmul_bias_invalid_multiple_use_before( + A: T.Buffer((16, 16), "int8"), + B: T.Buffer((16, 16), "int8"), + C1: T.Buffer((16, 16), "int32"), + C2: T.Buffer((16, 16), "int32"), + D: T.Buffer((16, 16), "int32"), +) -> None: + """Epilogue uses the reduction result twice; fusion must be rejected.""" + temp = T.alloc_buffer((16, 16), dtype="int32") + for i, j, k in T.grid(16, 16, 16): + with T.block("multiply"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + temp[vi, vj] = T.int32(0) + temp[vi, vj] = temp[vi, vj] + T.cast(A[vi, vk], "int32") * T.cast( + B[vj, vk], "int32" + ) + for i, j in T.grid(16, 16): + with T.block("bad_epilogue"): + vi, vj = T.axis.remap("SS", [i, j]) + # temp[vi, vj] is used twice in the epilogue expression + D[vi, vj] = (temp[vi, vj] + C1[vi, vj]) * (temp[vi, vj] + C2[vi, vj]) + + +def test_fuse_reduction_epilogue_reject_multiple_use(): + """fusion should be rejected when the reduction result appears more than once.""" + sch = tir.Schedule(matmul_bias_invalid_multiple_use_before, debug_mask="all") + with pytest.raises(tvm.tir.ScheduleError): + sch.fuse_reduction_epilogue("multiply", "bad_epilogue") + + if __name__ == "__main__": tvm.testing.main() From da8354b40e5dc370acced59d29ae3f586bd6f3af Mon Sep 17 00:00:00 2001 From: Hyun Gyu Kim Date: Mon, 19 Jan 2026 17:45:37 +0900 Subject: [PATCH 8/9] Write Code and Add Test Case to verify Case: Non-additive scaling --- src/tir/schedule/primitive/compute_inline.cc | 84 +++++++++++++++++++ ...st_tir_schedule_fuse_reduction_epilogue.py | 31 +++++++ 2 files changed, 115 insertions(+) diff --git a/src/tir/schedule/primitive/compute_inline.cc b/src/tir/schedule/primitive/compute_inline.cc index d9a457ff9ef6..f9942cdaab6e 100644 --- a/src/tir/schedule/primitive/compute_inline.cc +++ b/src/tir/schedule/primitive/compute_inline.cc @@ -1087,6 +1087,90 @@ bool ReductionEpilogueFuser::BodyPatternAllowFusion(const BlockRealize& epilogue epilogue_expression_ = inlined_store_->value; reduction_buffer_load_ = loads[0]; + // 5. Reject epilogues that scale the reduction result with non-additive ops + // For example, (reduce_out * 2.0) + C[i] is not a valid bias-style epilogue. + // We only allow the reduction result to be combined via Add/Min/Max shells. + class ScalingDetector : public ExprVisitor { + public: + explicit ScalingDetector(const Buffer& buffer) : buffer_(buffer) {} + + bool HasScaling(const PrimExpr& expr) { + has_scaling_ = false; + VisitExpr(expr); + return has_scaling_; + } + + private: + // Helper to check if a subtree contains a load from the reduction buffer + bool ContainsTarget(const PrimExpr& expr) { + class TargetFinder : public ExprVisitor { + public: + explicit TargetFinder(const Buffer& buffer) : buffer_(buffer) {} + + bool Find(const PrimExpr& e) { + found_ = false; + VisitExpr(e); + return found_; + } + + private: + void VisitExpr_(const BufferLoadNode* op) final { + if (op->buffer.same_as(buffer_)) { + found_ = true; + return; + } + ExprVisitor::VisitExpr_(op); + } + + Buffer buffer_; + bool found_{false}; + }; + + TargetFinder finder(buffer_); + return finder.Find(expr); + } + + void VisitExpr_(const MulNode* op) final { + if (has_scaling_) return; + // If either operand subtree contains the reduction buffer load, + // we treat this as invalid scaling of the reduction result. + if (ContainsTarget(op->a) || ContainsTarget(op->b)) { + has_scaling_ = true; + return; + } + ExprVisitor::VisitExpr_(op); + } + + void VisitExpr_(const DivNode* op) final { + if (has_scaling_) return; + if (ContainsTarget(op->a) || ContainsTarget(op->b)) { + has_scaling_ = true; + return; + } + ExprVisitor::VisitExpr_(op); + } + + void VisitExpr_(const ModNode* op) final { + if (has_scaling_) return; + if (ContainsTarget(op->a) || ContainsTarget(op->b)) { + has_scaling_ = true; + return; + } + ExprVisitor::VisitExpr_(op); + } + + Buffer buffer_; + bool has_scaling_{false}; + }; + + { + ScalingDetector detector(inlined_buffer_); + if (detector.HasScaling(inlined_store_->value)) { + // Failure: Non-additive scaling of the reduction result is not supported + return false; + } + } + // 6. Check if producer is a reduction block if (!IsReductionBlock(reduction_block_)) { // Failure: producer is not a reduction block diff --git a/tests/python/tir-schedule/test_tir_schedule_fuse_reduction_epilogue.py b/tests/python/tir-schedule/test_tir_schedule_fuse_reduction_epilogue.py index c3cf6bab4d74..408efa30853a 100644 --- a/tests/python/tir-schedule/test_tir_schedule_fuse_reduction_epilogue.py +++ b/tests/python/tir-schedule/test_tir_schedule_fuse_reduction_epilogue.py @@ -246,5 +246,36 @@ def test_fuse_reduction_epilogue_reject_multiple_use(): sch.fuse_reduction_epilogue("multiply", "bad_epilogue") +@T.prim_func +def matmul_bias_invalid_scaling_before( + A: T.Buffer((16, 16), "int8"), + B: T.Buffer((16, 16), "int8"), + C: T.Buffer((16, 16), "int32"), + D: T.Buffer((16, 16), "int32"), +) -> None: + """Epilogue scales the reduction result; fusion must be rejected.""" + temp = T.alloc_buffer((16, 16), dtype="int32") + for i, j, k in T.grid(16, 16, 16): + with T.block("multiply"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + temp[vi, vj] = T.int32(0) + temp[vi, vj] = temp[vi, vj] + T.cast(A[vi, vk], "int32") * T.cast( + B[vj, vk], "int32" + ) + for i, j in T.grid(16, 16): + with T.block("scaled_epilogue"): + vi, vj = T.axis.remap("SS", [i, j]) + # temp[vi, vj] is scaled by 2 before adding bias; this must not be fused. + D[vi, vj] = temp[vi, vj] * T.int32(2) + C[vi, vj] + + +def test_fuse_reduction_epilogue_reject_scaling(): + """fusion should be rejected when the reduction result is scaled by Mul/Div/Mod.""" + sch = tir.Schedule(matmul_bias_invalid_scaling_before, debug_mask="all") + with pytest.raises(tvm.tir.ScheduleError): + sch.fuse_reduction_epilogue("multiply", "scaled_epilogue") + + if __name__ == "__main__": tvm.testing.main() From cca87ccedd64ede55dec887fbe6783baed064175 Mon Sep 17 00:00:00 2001 From: Hyun Gyu Kim Date: Wed, 21 Jan 2026 14:14:13 +0900 Subject: [PATCH 9/9] lint Signed-off-by: Hyun Gyu Kim --- .../test_tir_schedule_fuse_reduction_epilogue.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/tests/python/tir-schedule/test_tir_schedule_fuse_reduction_epilogue.py b/tests/python/tir-schedule/test_tir_schedule_fuse_reduction_epilogue.py index 408efa30853a..addb7efff323 100644 --- a/tests/python/tir-schedule/test_tir_schedule_fuse_reduction_epilogue.py +++ b/tests/python/tir-schedule/test_tir_schedule_fuse_reduction_epilogue.py @@ -229,9 +229,7 @@ def matmul_bias_invalid_multiple_use_before( vi, vj, vk = T.axis.remap("SSR", [i, j, k]) with T.init(): temp[vi, vj] = T.int32(0) - temp[vi, vj] = temp[vi, vj] + T.cast(A[vi, vk], "int32") * T.cast( - B[vj, vk], "int32" - ) + temp[vi, vj] = temp[vi, vj] + T.cast(A[vi, vk], "int32") * T.cast(B[vj, vk], "int32") for i, j in T.grid(16, 16): with T.block("bad_epilogue"): vi, vj = T.axis.remap("SS", [i, j]) @@ -260,9 +258,7 @@ def matmul_bias_invalid_scaling_before( vi, vj, vk = T.axis.remap("SSR", [i, j, k]) with T.init(): temp[vi, vj] = T.int32(0) - temp[vi, vj] = temp[vi, vj] + T.cast(A[vi, vk], "int32") * T.cast( - B[vj, vk], "int32" - ) + temp[vi, vj] = temp[vi, vj] + T.cast(A[vi, vk], "int32") * T.cast(B[vj, vk], "int32") for i, j in T.grid(16, 16): with T.block("scaled_epilogue"): vi, vj = T.axis.remap("SS", [i, j])