From 9f1fd0b6a447ab1638092d365c93538a50bdcc0b Mon Sep 17 00:00:00 2001
From: owCode
Date: Fri, 16 Jan 2026 23:06:06 -0500
Subject: [PATCH 1/7] finish allocator

---
 include/core/allocator.h |  7 ++--
 src/core/allocator.cc    | 73 +++++++++++++++++++++++++++++++++++++++-
 src/core/graph.cc        | 16 +++++++++
 3 files changed, 92 insertions(+), 4 deletions(-)

diff --git a/include/core/allocator.h b/include/core/allocator.h
index 002601d..adf54e3 100644
--- a/include/core/allocator.h
+++ b/include/core/allocator.h
@@ -14,19 +14,20 @@ namespace infini
     private:
         Runtime runtime;

-        size_t used;
+        size_t used; // number of bytes currently in use

-        size_t peak;
+        size_t peak; // peak memory usage so far, in bytes

         size_t alignment;

         // pointer to the memory actually allocated
-        void *ptr;
+        void *ptr; // must still be nullptr while offsets are being planned; all planned offsets are relative to ptr

         // =================================== HOMEWORK ===================================
         // TODO: a data structure is probably needed to store the free blocks, so they can be managed and coalesced
         // HINT: a map can store the free blocks, keyed by a block's start/end offset, with the block size as the value
         // =================================== HOMEWORK ===================================
+        std::map<size_t, size_t> freeBlocks; // free blocks of the planned arena: key = start offset, value = block size in bytes

     public:
         Allocator(Runtime runtime);
diff --git a/src/core/allocator.cc b/src/core/allocator.cc
index ff593ae..c3e1141 100644
--- a/src/core/allocator.cc
+++ b/src/core/allocator.cc
@@ -32,8 +32,50 @@ namespace infini
         // =================================== HOMEWORK ===================================
         // TODO: design an algorithm that allocates memory and returns the start offset
         // =================================== HOMEWORK ===================================
+        // 1. Search freeBlocks for a suitable block (first fit).
+        for (auto it = freeBlocks.begin(); it != freeBlocks.end(); ++it) {
+            if (it->second >= size) { // found a block that is large enough
+                size_t start_addr = it->first;
+                size_t block_size = it->second; // original size of the free block
+                // Remove the block from the free list.
+                freeBlocks.erase(it);
+                // Split if the block is larger than requested.
+                if (block_size > size)
+                {
+                    // The remainder becomes a new free block.
+                    size_t remaining_addr = start_addr + size;
+                    size_t remaining_size = block_size - size;
+                    freeBlocks[remaining_addr] = remaining_size;
+                }
+                // If block_size == size, no new free block is needed.
+                used += size;
+                // Reusing a free block never reaches past the current peak,
+                // so peak stays unchanged on this path.
+                return start_addr;
+            }
+        }
+        // 2. No suitable free block: allocate at the end of the used region,
+        // whose end offset is the current peak.
+        size_t start_addr;
+        auto last = freeBlocks.empty() ? freeBlocks.end() : std::prev(freeBlocks.end());
+        if (last == freeBlocks.end() || last->first + last->second != this->peak)
+        {
+            start_addr = this->peak; // allocate starting at the current peak
+            this->peak = start_addr + size;
+            used += size;
+        }
+        else // the last free block touches the peak: extend it instead
+        {
+            start_addr = last->first;
+            this->peak = start_addr + size;
+            used += size;
+            freeBlocks.erase(last); // erase only after reading last->first
+        }
-        return 0;
+        return start_addr;
     }

     void Allocator::free(size_t addr, size_t size)
@@ -44,6 +86,35 @@ namespace infini
         // =================================== HOMEWORK ===================================
         // TODO: design an algorithm that reclaims memory
         // =================================== HOMEWORK ===================================
+        // Insert the released block into freeBlocks.
+        this->freeBlocks[addr] = size;
+        this->used -= size;
+        // Coalesce with neighbouring free blocks.
+        auto it = this->freeBlocks.find(addr);
+        // Try to merge with the previous block.
+        if (it != this->freeBlocks.begin()) {
+            auto prev = std::prev(it);
+            if (prev->first + prev->second == it->first) { // first: start offset, second: size
+                // Merge: grow the previous block and drop the current one.
+                prev->second += it->second;
+                this->freeBlocks.erase(it);
+                it = prev; // it now points at the merged block
+            }
+            // If no merge happened, it still points at the inserted block.
+        }
+
+        // Try to merge with the next block.
+        auto next = std::next(it);
+        if (next != freeBlocks.end() && it->first + it->second == next->first) {
+            it->second += next->second;
+            freeBlocks.erase(next);
+        }
+        // Coalescing by checking neighbours works because std::map keeps its keys sorted.
     }

     void *Allocator::getPtr()
diff --git a/src/core/graph.cc b/src/core/graph.cc
index 3a90637..7b5d1d0 100644
--- a/src/core/graph.cc
+++ b/src/core/graph.cc
@@ -152,6 +152,22 @@
         // TODO: use the allocator to allocate memory for the computation graph
         // HINT: after obtaining the allocated base pointer, call each tensor's setDataBlob to bind its memory
         // =================================== HOMEWORK ===================================
+        // Plan the offsets of all tensors first, then bind memory.
+        std::vector<size_t> tensorOffsets;
+        for (auto &tensor : tensors)
+        {
+            size_t offset = allocator.alloc(tensor->getBytes());
+            tensorOffsets.emplace_back(offset);
+        }
+
+        auto ptr = allocator.getPtr(); // the actual base address of the arena
+        IT_ASSERT(ptr != nullptr, "Allocator getPtr() returned nullptr");
+
+        for (int i = 0; i < (int)tensors.size(); ++i)
+        {
+            auto blob = make_ref<BlobObj>(runtime, static_cast<char *>(ptr) + tensorOffsets[i]);
+            tensors[i]->setDataBlob(blob);
+        }

         allocator.info();
     }
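
A quick way to sanity-check the free-list behaviour above (a standalone
sketch, not part of the patch; it assumes a valid `runtime` handle and
ignores any alignment rounding alloc() may apply):

    Allocator a(runtime);
    size_t x = a.alloc(64); // x == 0,  peak == 64
    size_t y = a.alloc(32); // y == 64, peak == 96
    a.free(x, 64);          // freeBlocks: {0 -> 64}
    size_t z = a.alloc(16); // first fit reuses offset 0; freeBlocks: {16 -> 48}
    a.free(z, 16);          // merges with the next block:     {0 -> 64}
    a.free(y, 32);          // merges with the previous block: {0 -> 96}
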
From 017c9f0c1ccc768adfe27a3828fd551e99603f8f Mon Sep 17 00:00:00 2001
From: owCode
Date: Fri, 16 Jan 2026 23:19:58 -0500
Subject: [PATCH 2/7] finish transpose

---
 src/operators/transpose.cc | 31 +++++++++++++++++++++++++------
 1 file changed, 25 insertions(+), 6 deletions(-)

diff --git a/src/operators/transpose.cc b/src/operators/transpose.cc
index faab2b6..6b3c876 100644
--- a/src/operators/transpose.cc
+++ b/src/operators/transpose.cc
@@ -24,17 +24,36 @@ namespace infini

     optional<vector<Shape>> TransposeObj::inferShape(const TensorVec &inputs)
     {
-        const auto A = inputs[0];
-        auto input_dim = A->getDims();
-        auto output_dim = input_dim;
-        int rank = A->getRank();

         // =================================== HOMEWORK ===================================
         // TODO: fill in output_dim and return the correct post-transpose shape
         // REF: https://onnx.ai/onnx/operators/onnx__Transpose.html#transpose-21
         // =================================== HOMEWORK ===================================
+        // 1. Validate the inputs.
+        if (inputs.size() != 1) {
+            // Ideally this would be logged or reported as an error.
+            return std::nullopt;
+        }
+        const auto A = inputs[0];
+        auto input_dim = A->getDims();
+        int rank = A->getRank();
+        // 2. Validate the permutation.
+        if (transposePermute.size() != static_cast<size_t>(rank)) {
+            // Error: the permutation length must equal the rank.
+            return std::nullopt;
+        }
+        // 3. Build the output shape.
+        vector<int> output_dim(rank);
+        for (int i = 0; i < rank; ++i) {
+            int src_idx = transposePermute[i];
+            // Reject out-of-range axis indices.
+            if (src_idx < 0 || src_idx >= rank) {
+                return std::nullopt;
+            }
+            output_dim[i] = input_dim[src_idx];
+        }
-        return std::nullopt;
+        return vector<Shape>{output_dim}; // note: the return type nests the single shape in a vector
     }

     std::string TransposeObj::toString() const
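
The rule implemented above is output_dim[i] = input_dim[permute[i]]. A
standalone illustration (hypothetical helper, not part of the patch):

    #include <vector>
    // Mirrors the shape loop in TransposeObj::inferShape.
    std::vector<int> permuteShape(const std::vector<int> &in,
                                  const std::vector<int> &perm) {
        std::vector<int> out(in.size());
        for (size_t i = 0; i < in.size(); ++i)
            out[i] = in[perm[i]]; // axis i of the output comes from axis perm[i]
        return out;
    }
    // permuteShape({2, 3, 4}, {0, 2, 1}) == {2, 4, 3}: the last two axes swap.
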
From 6622ef8c60edc78c3b80f7e1c8f88e9c54b4fe9a Mon Sep 17 00:00:00 2001
From: owCode
Date: Fri, 16 Jan 2026 23:23:26 -0500
Subject: [PATCH 3/7] finish concat

---
 src/operators/concat.cc | 9 ++++++++-
 1 file changed, 8 insertions(+), 1 deletion(-)

diff --git a/src/operators/concat.cc b/src/operators/concat.cc
index d196330..3c9bc58 100644
--- a/src/operators/concat.cc
+++ b/src/operators/concat.cc
@@ -17,7 +17,14 @@ optional<vector<Shape>> ConcatObj::inferShape(const TensorVec &inputs) {
     // TODO: update dims and return the correct post-concat shape
     // REF: https://onnx.ai/onnx/operators/onnx__Concat.html#concat-13
     // =================================== HOMEWORK ===================================
+    for (size_t i = 1; i < inputs.size(); ++i) {
+        auto i_dims = inputs[i]->getDims();
+        auto i_rank = inputs[i]->getRank();
+        // All inputs must have the same rank (the ONNX spec additionally
+        // requires every dimension except the concat axis to match).
+        IT_ASSERT(rank == i_rank, "All input tensors must have the same rank");
+        // Accumulate the size along the concat axis.
+        dims[dim] += i_dims[dim];
+    }
     return {{dims}};
 }
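
Worked example of the rule (illustration only, assuming Shape is
std::vector<int> as elsewhere in the codebase): concatenating [2, 3, 4]
and [2, 5, 4] along axis 1 keeps every other axis and sums the concat
axis:

    Shape dims = {2, 3, 4};  // dims of inputs[0]
    Shape other = {2, 5, 4}; // dims of inputs[1]
    int dim = 1;             // concat axis
    dims[dim] += other[dim]; // dims == {2, 8, 4}
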
From 1ad132cfcf898d5f3aeff4712e10d53a18daad82 Mon Sep 17 00:00:00 2001
From: owCode
Date: Fri, 16 Jan 2026 23:28:35 -0500
Subject: [PATCH 4/7] finish clip cast

---
 src/operators/unary.cc | 15 ++++++++++++---
 1 file changed, 12 insertions(+), 3 deletions(-)

diff --git a/src/operators/unary.cc b/src/operators/unary.cc
index 3daad36..c466af3 100644
--- a/src/operators/unary.cc
+++ b/src/operators/unary.cc
@@ -39,7 +39,12 @@ namespace infini
         // TODO: return the shape after the clip operation
         // REF: https://onnx.ai/onnx/operators/onnx__Clip.html#clip-13
         // =================================== HOMEWORK ===================================
-        return std::nullopt;
+        if (inputs.size() != 1) {
+            return std::nullopt;
+        }
+        const auto A = inputs[0];
+        return vector<Shape>{A->getDims()}; // clip does not change the shape
     }

     std::string ClipObj::toString() const
@@ -66,7 +71,7 @@
         // REF_FILE: src/core/operator.cc
         // REF: https://onnx.ai/onnx/operators/onnx__Cast.html#cast-21
         // =================================== HOMEWORK ===================================
-        return {};
+        return vector<DataType>{getOutputDataType()};
     }

     optional<vector<Shape>> CastObj::inferShape(const TensorVec &inputs)
@@ -75,7 +80,11 @@
         // =================================== HOMEWORK ===================================
         // TODO: return the shape after the cast operation
         // REF: https://onnx.ai/onnx/operators/onnx__Cast.html#cast-21
         // =================================== HOMEWORK ===================================
-        return std::nullopt;
+        if (inputs.size() != 1) {
+            return std::nullopt;
+        }
+        const auto A = inputs[0];
+        return vector<Shape>{A->getDims()}; // cast does not change the shape
     }

     std::string CastObj::toString() const
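
Clip and Cast are both elementwise, which is why shape inference simply
forwards the input shape: Clip clamps each value into [min, max] and Cast
only changes the element type. A standalone sketch of the per-element
clip computation (hypothetical helper, assuming float data and
already-resolved bounds; not part of the patch):

    #include <algorithm>
    float clipOne(float x, float lo, float hi) {
        return std::min(std::max(x, lo), hi); // clipOne(5.f, 0.f, 1.f) == 1.f
    }
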
From 9dcb6adca68165e5d5367aa6edcb670e3d17907c Mon Sep 17 00:00:00 2001
From: owCode
Date: Fri, 16 Jan 2026 23:33:40 -0500
Subject: [PATCH 5/7] finish elementwise

---
 src/utils/operator_utils.cc | 38 +++++++++++++++++++++++++++++++++++--
 1 file changed, 36 insertions(+), 2 deletions(-)

diff --git a/src/utils/operator_utils.cc b/src/utils/operator_utils.cc
index edbd2c8..a410880 100644
--- a/src/utils/operator_utils.cc
+++ b/src/utils/operator_utils.cc
@@ -9,8 +9,42 @@ Shape infer_broadcast(const Shape &A, const Shape &B) {
    // TODO: apply bidirectional broadcasting to A and B and return the broadcast shape
    // REF: https://github.com/onnx/onnx/blob/main/docs/Broadcasting.md
    // =================================== HOMEWORK ===================================
-
-    return {};
+    // Ranks of A and B.
+    int64_t dimsA = static_cast<int64_t>(A.size());
+    int64_t dimsB = static_cast<int64_t>(B.size());
+    // The result has the larger of the two ranks.
+    int64_t maxDims = std::max(dimsA, dimsB);
+    Shape result(maxDims);
+    // Compare dimensions starting from the rightmost (trailing) side.
+    for (int64_t i = 0; i < maxDims; i++) {
+        // Indices into A and B, counted from the right.
+        int64_t idxA = dimsA - 1 - i;
+        int64_t idxB = dimsB - 1 - i;
+        // Missing dimensions are treated as 1.
+        int64_t dimA = (idxA >= 0) ? A[idxA] : 1;
+        int64_t dimB = (idxB >= 0) ? B[idxB] : 1;
+        if (dimA == dimB) {
+            // Case 1: the sizes match.
+            result[maxDims - 1 - i] = dimA;
+        } else if (dimA == 1) {
+            // Case 2: A broadcasts along this axis; take B's size.
+            result[maxDims - 1 - i] = dimB;
+        } else if (dimB == 1) {
+            // Case 3: B broadcasts along this axis; take A's size.
+            result[maxDims - 1 - i] = dimA;
+        } else {
+            // Case 4: the sizes differ and neither is 1 -- not broadcastable.
+            throw std::invalid_argument(
+                "Incompatible shapes for broadcasting: dimension " +
+                std::to_string(i) + " is " + std::to_string(dimA) +
+                " and " + std::to_string(dimB) + ", but both are > 1");
+        }
+    }
+    return result;
 }

 int get_real_axis(const int &axis, const int &rank) {
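
A worked example of the trailing alignment (illustration only):
broadcasting A = [5, 1, 4] against B = [3, 1] conceptually pads B to
[1, 3, 1] and resolves each axis right to left:

    // axis 2: 4 vs 1 -> 4
    // axis 1: 1 vs 3 -> 3
    // axis 0: 5 vs (missing, treated as 1) -> 5
    // infer_broadcast({5, 1, 4}, {3, 1}) == Shape{5, 3, 4}
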
From 732ac8584f902ab813e5c39a5abd9ffb3d96a4f3 Mon Sep 17 00:00:00 2001
From: owCode
Date: Fri, 16 Jan 2026 23:36:08 -0500
Subject: [PATCH 6/7] finish matmul

---
 src/operators/matmul.cc | 174 +++++++++++++++++++++++++++++++++++++++-
 1 file changed, 173 insertions(+), 1 deletion(-)

diff --git a/src/operators/matmul.cc b/src/operators/matmul.cc
index 7a16ca2..7206b43 100644
--- a/src/operators/matmul.cc
+++ b/src/operators/matmul.cc
@@ -27,7 +27,179 @@ namespace infini
        // TODO: return the shape after the matmul operation
        // REF: https://github.com/onnx/onnx/blob/main/docs/Operators.md#gemm
        // =================================== HOMEWORK ===================================
-        return std::nullopt;
+        // Validate the number of inputs.
+        if (inputs.size() != 2) {
+            return std::nullopt;
+        }
+
+        Shape A_dims = inputs[0]->getDims();
+        auto A_rank = inputs[0]->getRank();
+        Shape B_dims = inputs[1]->getDims();
+        auto B_rank = inputs[1]->getRank();
+
+        IT_ASSERT(A_rank >= 2 && B_rank >= 2);
+
+        // Transpose attributes.
+        bool transA = this->getTransA();
+        bool transB = this->getTransB();
+
+        // Effective matrix dimensions of A and B.
+        int A_M, A_K; // rows (M) and columns (K) of A
+        int B_K, B_N; // rows (K) and columns (N) of B
+
+        if (!transA) {
+            // A: [..., M, K]
+            A_M = A_dims[A_rank - 2];
+            A_K = A_dims[A_rank - 1];
+        } else {
+            // A transposed: [..., K, M] is used as [..., M, K]
+            A_M = A_dims[A_rank - 1];
+            A_K = A_dims[A_rank - 2];
+        }
+
+        if (!transB) {
+            // B: [..., K, N]
+            B_K = B_dims[B_rank - 2];
+            B_N = B_dims[B_rank - 1];
+        } else {
+            // B transposed: [..., N, K] is used as [..., K, N]
+            B_K = B_dims[B_rank - 1];
+            B_N = B_dims[B_rank - 2];
+        }
+
+        // The inner (K) dimensions must match.
+        if (A_K != B_K) {
+            return std::nullopt;
+        }
+
+        Shape output_shape;
+
+        if (A_rank == 2 && B_rank == 2) {
+            // Plain 2-D matmul.
+            output_shape = {A_M, B_N};
+        } else {
+            // Batched matmul with broadcast over the batch dimensions.
+            size_t A_batch_dims = A_rank - 2;
+            size_t B_batch_dims = B_rank - 2;
+            size_t max_batch_dims = std::max(A_batch_dims, B_batch_dims);
+
+            // Batch dimensions broadcast like NumPy: aligned from the
+            // trailing side, with the shorter shape padded with 1s in front.
+            for (size_t i = 0; i < max_batch_dims; ++i) {
+                int64_t idxA = static_cast<int64_t>(A_batch_dims) -
+                               static_cast<int64_t>(max_batch_dims) + static_cast<int64_t>(i);
+                int64_t idxB = static_cast<int64_t>(B_batch_dims) -
+                               static_cast<int64_t>(max_batch_dims) + static_cast<int64_t>(i);
+                int64_t A_dim = (idxA >= 0) ? A_dims[idxA] : 1;
+                int64_t B_dim = (idxB >= 0) ? B_dims[idxB] : 1;
+
+                // Standard broadcast rules.
+                if (A_dim == B_dim) {
+                    output_shape.push_back(A_dim);
+                } else if (A_dim == 1) {
+                    output_shape.push_back(B_dim);
+                } else if (B_dim == 1) {
+                    output_shape.push_back(A_dim);
+                } else {
+                    return std::nullopt;
+                }
+            }
+
+            // Append the matrix dimensions of the result.
+            output_shape.push_back(A_M);
+            output_shape.push_back(B_N);
+        }
+
+        return vector<Shape>{output_shape};
     }

 } // namespace infini
\ No newline at end of file
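
Worked example of the batched case (illustration only): A of shape
[2, 1, 3, 4] with transA = false times B of shape [5, 4, 6] broadcasts
the batch dims trailing-aligned and multiplies 3x4 by 4x6 matrices:

    // batch dims: [2, 1] vs [5] -> pad B to [1, 5] -> broadcast to [2, 5]
    // matrices:   M = 3, K = 4 (matches B's K = 4), N = 6
    // output shape: [2, 5, 3, 6]
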
From 294c9ec111edadae1fad836ca95b8de8ab25976b Mon Sep 17 00:00:00 2001
From: owCode
Date: Sat, 17 Jan 2026 00:00:14 -0500
Subject: [PATCH 7/7] finish all

---
 include/core/graph.h |   2 +
 src/core/graph.cc    | 323 +++++++++++++++++++++++++++++++++++++------
 2 files changed, 283 insertions(+), 42 deletions(-)

diff --git a/include/core/graph.h b/include/core/graph.h
index c45580c..64d5e4f 100644
--- a/include/core/graph.h
+++ b/include/core/graph.h
@@ -5,6 +5,8 @@
 #include "core/operator.h"
 #include "core/tensor.h"
+#include "../operators/transpose.h"
+#include "../operators/matmul.h"

 namespace infini
 {
diff --git a/src/core/graph.cc b/src/core/graph.cc
index 7b5d1d0..a2f452d 100644
--- a/src/core/graph.cc
+++ b/src/core/graph.cc
@@ -1,5 +1,6 @@
 #include "core/graph.h"
 #include <algorithm>
+#include <queue>
 #include <numeric>
 #include <unordered_set>
@@ -65,49 +66,286 @@ namespace infini
         {
             return true;
         }
-        std::vector<Operator> sorted;
-        std::unordered_set<OperatorObj *> flags;
-        sorted.reserve(ops.size());
-        flags.reserve(ops.size());
-        while (sorted.size() < ops.size())
-        {
-            // Any node is move to sorted in this loop.
-            auto modified = false;
-            for (auto const &op : ops)
-            {
-                if (auto const &inputs = op->getInputs();
-                    flags.find(op.get()) == flags.end() &&
-                    std::all_of(inputs.begin(), inputs.end(),
-                                [&flags](auto const &input)
-                                {
-                                    auto ptr = input->getSource().get();
-                                    return !ptr || flags.find(ptr) != flags.end();
-                                }))
-                {
-                    modified = true;
-                    sorted.emplace_back(op);
-                    flags.insert(op.get());
-                }
-            }
-            if (!modified)
-            {
-                return false;
-            }
-        }
+        // Kahn's algorithm: O(V + E) instead of the old quadratic scan.
+        std::unordered_map<OperatorObj *, int> inDegree;
+        std::queue<Operator> zeroInDegree;
+        std::vector<Operator> sorted;
+        sorted.reserve(ops.size());
+        inDegree.reserve(ops.size());
+
+        // Compute the in-degree of every operator.
+        for (const auto &op : ops)
+        {
+            inDegree[op.get()] = 0;
+        }
+        for (const auto &op : ops)
+        {
+            for (const auto &succ : op->getSuccessors())
+            {
+                inDegree[succ.get()]++;
+            }
+        }
+
+        // Seed the queue with operators that have no unmet dependencies.
+        for (const auto &op : ops)
+        {
+            if (inDegree[op.get()] == 0)
+            {
+                zeroInDegree.push(op);
+            }
+        }
+
+        // Emit operators in topological order.
+        while (!zeroInDegree.empty())
+        {
+            auto current = zeroInDegree.front();
+            zeroInDegree.pop();
+            sorted.emplace_back(current);
+
+            // A processed operator releases its successors.
+            for (const auto &succ : current->getSuccessors())
+            {
+                if (--inDegree[succ.get()] == 0)
+                {
+                    zeroInDegree.push(succ);
+                }
+            }
+        }
+
+        // If some operators were never emitted, the graph has a cycle.
+        if (sorted.size() != ops.size())
+        {
+            return false;
+        }
+
         this->ops = std::move(sorted);
         return this->sorted = true;
     }

     void GraphObj::optimize()
     {
         // =================================== HOMEWORK ===================================
         // TODO: design an algorithm that implements the required graph optimizations:
         // 1. Remove redundant operators (e.g. two adjacent transpose operators that
         //    perform opposite permutations can both be deleted).
         // 2. Fuse operators (e.g. matmul has transA/transB attributes; if an input comes
         //    from a transpose that swaps only the last two axes, the transpose can be
         //    folded into the matmul attribute).
         // =================================== HOMEWORK ===================================
+        // Rule 1: remove adjacent transpose pairs that are inverses of each other.
+        bool modified = true;
+        while (modified) {
+            modified = false;
+            for (auto it = ops.begin(); it != ops.end();) {
+                auto op = *it;
+
+                // Only interested in transpose operators.
+                if (op->getOpType() != OpType::Transpose) {
+                    ++it;
+                    continue;
+                }
+
+                auto transposeOp = std::static_pointer_cast<TransposeObj>(op);
+                auto permute = transposeOp->getPermute();
+
+                // The sole successor must also be a transpose operator.
+                auto succs = op->getSuccessors();
+                if (succs.size() != 1) {
+                    ++it;
+                    continue;
+                }
+
+                auto succOp = succs[0];
+                if (succOp->getOpType() != OpType::Transpose) {
+                    ++it;
+                    continue;
+                }
+
+                auto succTranspose = std::static_pointer_cast<TransposeObj>(succOp);
+                auto succPermute = succTranspose->getPermute();
+
+                // The two permutations cancel iff succPermute[permute[i]] == i for all i.
+                bool isInverse = (permute.size() == succPermute.size());
+                for (size_t i = 0; i < permute.size() && isInverse; ++i) {
+                    if (permute[i] >= static_cast<int>(succPermute.size()) ||
+                        succPermute[permute[i]] != static_cast<int>(i)) {
+                        isInverse = false;
+                        break;
+                    }
+                }
+
+                if (!isInverse) {
+                    ++it;
+                    continue;
+                }
+
+                // Tensors around the pair.
+                auto inputTensor = op->getInputs()[0];
+                auto outputTensor = succOp->getOutputs()[0];
+                auto intermediateTensor = op->getOutputs()[0];
+
+                // Rewire the producer of inputTensor directly to the
+                // consumers of outputTensor.
+                auto predecessor = inputTensor->getSource();
+                if (predecessor) {
+                    for (auto targetOp : succOp->getSuccessors()) {
+                        predecessor->addSuccessors(targetOp);
+                        targetOp->addPredecessors(predecessor);
+                    }
+                    predecessor->removeSuccessors(op);
+                }
+
+                // Rewire the tensor edges.
+                inputTensor->removeTarget(op);
+                for (auto targetOp : outputTensor->getTargets()) {
+                    inputTensor->addTarget(targetOp);
+                    targetOp->replaceInput(outputTensor, inputTensor);
+                }
+
+                // Drop both operators.
+                if (auto pos = std::find(ops.begin(), ops.end(), op); pos != ops.end()) {
+                    ops.erase(pos);
+                }
+                if (auto pos = std::find(ops.begin(), ops.end(), succOp); pos != ops.end()) {
+                    ops.erase(pos);
+                }
+
+                // Drop the now-dangling tensors.
+                if (auto pos = std::find(tensors.begin(), tensors.end(), intermediateTensor); pos != tensors.end()) {
+                    tensors.erase(pos);
+                }
+                if (auto pos = std::find(tensors.begin(), tensors.end(), outputTensor); pos != tensors.end()) {
+                    tensors.erase(pos);
+                }
+
+                modified = true;
+                it = ops.begin(); // rescan from the beginning
+            }
+        }
+
+        // Rule 2: fold transposes that only swap the last two axes into the
+        // transA/transB attributes of the matmul they feed.
+        modified = true;
+        while (modified) {
+            modified = false;
+            for (auto it = ops.begin(); it != ops.end();) {
+                auto op = *it;
+                bool currentModified = false;
+
+                if (op->getOpType() != OpType::MatMul) {
+                    ++it;
+                    continue;
+                }
+
+                auto matmulOp = std::static_pointer_cast<MatmulObj>(op);
+                bool newTransA = matmulOp->getTransA();
+                bool newTransB = matmulOp->getTransB();
+
+                // Try to absorb a transpose feeding the given matmul input.
+                auto tryMergeTranspose = [&](int inputIdx, bool &transFlag) -> bool {
+                    auto inputTensor = op->getInputs()[inputIdx];
+                    auto sourceOp = inputTensor->getSource();
+
+                    if (!sourceOp || sourceOp->getOpType() != OpType::Transpose) {
+                        return false;
+                    }
+
+                    auto transposeOp = std::static_pointer_cast<TransposeObj>(sourceOp);
+                    auto permute = transposeOp->getPermute();
+                    auto shape = inputTensor->getDims();
+
+                    // The transpose output must feed only this matmul.
+                    if (inputTensor->getTargets().size() > 1) {
+                        return false;
+                    }
+
+                    // The transpose must swap exactly the last two axes and
+                    // leave every other axis in place.
+                    if (shape.size() < 2) {
+                        return false;
+                    }
+
+                    size_t n = shape.size();
+                    bool isValidTranspose = true;
+                    for (size_t i = 0; i < n; ++i) {
+                        if (i < n - 2) {
+                            // Leading axes must be unchanged.
+                            if (permute[i] != static_cast<int>(i)) {
+                                isValidTranspose = false;
+                                break;
+                            }
+                        } else if (i == n - 2) {
+                            // The second-to-last axis must map to the last.
+                            if (permute[i] != static_cast<int>(n - 1)) {
+                                isValidTranspose = false;
+                                break;
+                            }
+                        } else { // i == n - 1
+                            // The last axis must map to the second-to-last.
+                            if (permute[i] != static_cast<int>(n - 2)) {
+                                isValidTranspose = false;
+                                break;
+                            }
+                        }
+                    }
+
+                    if (!isValidTranspose) {
+                        return false;
+                    }
+
+                    // Fold the transpose into the matmul attribute.
+                    transFlag = !transFlag;
+
+                    // Bypass the transpose: feed its input straight into the matmul.
+                    auto bypassTensor = transposeOp->getInputs()[0];
+                    op->replaceInput(inputTensor, bypassTensor);
+                    bypassTensor->addTarget(op);
+                    inputTensor->removeTarget(op);
+
+                    // Remove the transpose operator.
+                    auto transposeIt = std::find(ops.begin(), ops.end(), sourceOp);
+                    if (transposeIt != ops.end()) {
+                        ops.erase(transposeIt);
+                    }
+
+                    // Remove the intermediate tensor.
+                    auto tensorIt = std::find(tensors.begin(), tensors.end(), inputTensor);
+                    if (tensorIt != tensors.end()) {
+                        tensors.erase(tensorIt);
+                    }
+
+                    return true;
+                };
+
+                // Try both matmul inputs.
+                if (tryMergeTranspose(0, newTransA)) {
+                    currentModified = true;
+                }
+                if (tryMergeTranspose(1, newTransB)) {
+                    currentModified = true;
+                }
+
+                if (currentModified) {
+                    matmulOp->setTransA(newTransA);
+                    matmulOp->setTransB(newTransB);
+                    modified = true;
+                    it = ops.begin(); // rescan from the beginning
+                } else {
+                    ++it;
+                }
+            }
+        }
     }

     Tensor GraphObj::getTensor(int fuid) const
     {
         for (auto tensor : tensors)
@@ -152,6 +390,24 @@
         // TODO: use the allocator to allocate memory for the computation graph
         // HINT: after obtaining the allocated base pointer, call each tensor's setDataBlob to bind its memory
         // =================================== HOMEWORK ===================================
-        // Plan the offsets of all tensors first, then bind memory.
-        std::vector<size_t> tensorOffsets;
-        for (auto &tensor : tensors)
-        {
-            size_t offset = allocator.alloc(tensor->getBytes());
-            tensorOffsets.emplace_back(offset);
-        }
-
-        auto ptr = allocator.getPtr(); // the actual base address of the arena
-        IT_ASSERT(ptr != nullptr, "Allocator getPtr() returned nullptr");
-
-        for (int i = 0; i < (int)tensors.size(); ++i)
-        {
-            auto blob = make_ref<BlobObj>(runtime, static_cast<char *>(ptr) + tensorOffsets[i]);
-            tensors[i]->setDataBlob(blob);
-        }
+        std::vector<size_t> tensorOffsets; // planned offset of every tensor
+
+        // alloc() must finish for every tensor before getPtr() is called,
+        // so the offsets are collected in a container first.
+        for (auto &tensor : tensors) {
+            size_t offset = allocator.alloc(tensor->getBytes());
+            tensorOffsets.emplace_back(offset);
+        }
+
+        auto ptr = allocator.getPtr();
+        IT_ASSERT(ptr != nullptr, "Failed to get memory pointer from allocator");
+
+        for (int i = 0; i < (int)tensors.size(); ++i) {
+            auto blob = make_ref<BlobObj>(runtime, static_cast<char *>(ptr) + tensorOffsets[i]);
+            tensors[i]->setDataBlob(blob);
+        }

         allocator.info();
     }
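
To see rule 2 in action (illustration only): for matmul(transpose(A,
perm = [0, 2, 1]), B) with A of shape [4, 5, 6], B of shape [4, 5, 7],
and transA initially false, the pass deletes the transpose and toggles
the attribute; the inferred output shape is unchanged:

    // before: A:[4,5,6] -> Transpose([0,2,1]) -> [4,6,5] -> MatMul(transA=false, B:[4,5,7])
    // after:  A:[4,5,6] ------------------------------------> MatMul(transA=true,  B:[4,5,7])
    // output in both cases: [4, 6, 7]

Rule 1 behaves analogously: transpose(transpose(X, p), q) disappears
entirely whenever q[p[i]] == i for every axis i.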