From ac3507774b3d899d54dbb524b97dc230b6f9a910 Mon Sep 17 00:00:00 2001
From: MCFS436 <596170708@qq.com>
Date: Mon, 1 Sep 2025 17:42:29 +0800
Subject: [PATCH 1/8] Update allocator.cc

---
 src/core/allocator.cc | 49 +++++++++++++++++++++++++++++++++++++++----
 1 file changed, 45 insertions(+), 4 deletions(-)

diff --git a/src/core/allocator.cc b/src/core/allocator.cc
index ff593ae..d364981 100644
--- a/src/core/allocator.cc
+++ b/src/core/allocator.cc
@@ -28,22 +28,63 @@ namespace infini
         IT_ASSERT(this->ptr == nullptr);
         // pad the size to the multiple of alignment
         size = this->getAlignedSize(size);
-
+        this->used += size;
+        // first fit: reuse the first free block that is large enough
+        for (auto it = this->free_blocks.begin(); it != this->free_blocks.end();
+             it++) {
+            if (it->second >= size) {
+                size_t addr = it->first;
+                size_t space = it->second - size;
+                this->free_blocks.erase(it);
+                if (space > 0) {
+                    // the unused tail of the block stays on the free list
+                    this->free_blocks[addr + size] = space;
+                }
+                return addr;
+            }
+        }
+        // nothing fits: bump-allocate at the current peak
+        this->peak += size;
+        return this->peak - size;
         // =================================== Assignment ===================================
         // TODO: design an algorithm that allocates memory and returns the start offset
         // =================================== Assignment ===================================
-
-        return 0;
     }

     void Allocator::free(size_t addr, size_t size)
     {
         IT_ASSERT(this->ptr == nullptr);
         size = getAlignedSize(size);
-
+
         // =================================== Assignment ===================================
         // TODO: design an algorithm that reclaims memory
         // =================================== Assignment ===================================
+        this->used -= size;
+        // freeing the block at the top of the arena just lowers the peak
+        if (addr + size == this->peak) {
+            this->peak -= size;
+            return;
+        }
+        // coalesce with the adjacent free block on either side
+        auto next = this->free_blocks.lower_bound(addr);
+        if (next != this->free_blocks.end() && addr + size == next->first) {
+            size += next->second;
+            next = this->free_blocks.erase(next);
+        }
+        if (next != this->free_blocks.begin()) {
+            auto prev = next;
+            --prev;
+            if (prev->first + prev->second == addr) {
+                prev->second += size;
+                return;
+            }
+        }
+        this->free_blocks[addr] = size;
     }

     void *Allocator::getPtr()

From a4575922c8bc0fa20cb625f1a168722ae18b878e Mon Sep 17 00:00:00 2001
From: MCFS436 <596170708@qq.com>
Date: Mon, 1 Sep 2025 17:47:46 +0800
Subject: [PATCH 2/8] Update allocator.h

---
 include/core/allocator.h | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/include/core/allocator.h b/include/core/allocator.h
index 002601d..f51e11c 100644
--- a/include/core/allocator.h
+++ b/include/core/allocator.h
@@ -22,7 +22,7 @@ namespace infini
     {
         // pointer to the memory actually allocated
        void *ptr;
-
+        std::map<size_t, size_t> free_blocks; // start offset -> block size
         // =================================== Assignment ===================================
        // TODO: a data structure may be needed to store the free blocks, for ease of management and merging
        // HINT: a map can store the free blocks, with a block's start/end offset as key and the block size as value
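Note on PATCH 1: alloc() is first-fit over a std::map<size_t, size_t> keyed
by start offset, and free() coalesces the freed range with its neighbours.
A minimal usage sketch of the intended behaviour (the Runtime handle
`runtime` is an assumption here, and the sizes are chosen to already be
alignment multiples):

    Allocator a(runtime);
    size_t x = a.alloc(64); // free list empty: bump the peak, x == 0
    size_t y = a.alloc(64); // y == 64, peak == 128
    a.free(x, 64);          // not at the peak: [0, 64) joins free_blocks
    size_t z = a.alloc(32); // first fit splits [0, 64): z == 0, and the
                            // tail [32, 64) stays on the free list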
+#include "operators/transpose.h" #include +#include #include #include -namespace infini -{ - - void GraphObj::addOperatorAndConnect(const Operator &op) - { - sorted = false; - ops.push_back(op); - for (auto &input : op->getInputs()) - { - if (input) - { - input->addTarget(op); - if (auto pred = input->getSource()) - { - pred->addSuccessors(op); - op->addPredecessors(pred); - } +namespace infini { + +void GraphObj::addOperatorAndConnect(const Operator &op) { + sorted = false; + ops.push_back(op); + for (auto &input : op->getInputs()) { + if (input) { + input->addTarget(op); + if (auto pred = input->getSource()) { + pred->addSuccessors(op); + op->addPredecessors(pred); } } - for (auto &output : op->getOutputs()) - { - if (output) - { - output->setSource(op); - for (auto &succ : output->getTargets()) - { - succ->addPredecessors(op); - op->addSuccessors(succ); - } + } + for (auto &output : op->getOutputs()) { + if (output) { + output->setSource(op); + for (auto &succ : output->getTargets()) { + succ->addPredecessors(op); + op->addSuccessors(succ); } } } +} - string GraphObj::toString() const - { - std::ostringstream oss; - oss << "Graph Tensors:\n"; - for (const auto &tensor : tensors) - oss << tensor << "\n"; - - oss << "Graph operators:\n"; - for (const auto &op : ops) - { - vector preds, succs; - for (auto &o : op->getPredecessors()) - preds.emplace_back(o->getGuid()); - for (auto &o : op->getSuccessors()) - succs.emplace_back(o->getGuid()); - oss << "OP " << op->getGuid(); - oss << ", pred " << vecToString(preds); - oss << ", succ " << vecToString(succs); - oss << ", " << op << "\n"; +string GraphObj::toString() const { + std::ostringstream oss; + oss << "Graph Tensors:\n"; + for (const auto &tensor : tensors) + oss << tensor << "\n"; + + oss << "Graph operators:\n"; + for (const auto &op : ops) { + vector preds, succs; + for (auto &o : op->getPredecessors()) + preds.emplace_back(o->getGuid()); + for (auto &o : op->getSuccessors()) + succs.emplace_back(o->getGuid()); + oss << "OP " << op->getGuid(); + oss << ", pred " << vecToString(preds); + oss << ", succ " << vecToString(succs); + oss << ", " << op << "\n"; + } + return oss.str(); +} + +bool GraphObj::topo_sort() { + if (this->sorted) { + return true; + } + std::vector sorted; + std::unordered_set flags; + sorted.reserve(ops.size()); + flags.reserve(ops.size()); + while (sorted.size() < ops.size()) { + // Any node is move to sorted in this loop. + auto modified = false; + for (auto const &op : ops) { + if (auto const &inputs = op->getInputs(); + flags.find(op.get()) == flags.end() && + std::all_of(inputs.begin(), inputs.end(), + [&flags](auto const &input) { + auto ptr = input->getSource().get(); + return !ptr || flags.find(ptr) != flags.end(); + })) { + modified = true; + sorted.emplace_back(op); + flags.insert(op.get()); + } + } + if (!modified) { + return false; } - return oss.str(); } + this->ops = std::move(sorted); + return this->sorted = true; +} + +void GraphObj::optimize() { + // =================================== 作业 + // =================================== + // TODO: 设计一个算法来实现指定的图优化规则 + // 图优化规则如下: + // 1. 去除冗余的算子(例如,两个相邻的算子都是 transpose + // 算子,且做的是相反的操作,可以将其全部删除) + // 2. 
+ // 合并算子(例如,矩阵乘算子中含有属性transA、transB,如果其输入存在transpose,且对最后两个维度做交换,就可以将transpose融入到矩阵乘算子的属性中去) + // =================================== 作业 + // =================================== - bool GraphObj::topo_sort() - { - if (this->sorted) - { - return true; + if (!this->topo_sort()) { + return; + } + int opsSize = ops.size(); + for (int i = 0; i < opsSize; i++) { + auto op = ops[i]; + if (op->getOpType() == OpType::Transpose) { + auto opd = std::dynamic_pointer_cast(op); + auto input = op->getInputs(0); + auto prevOp = input->getSource(); + if (prevOp && prevOp->getOpType() == OpType::Transpose && + input->getTargets().size() == 1) { + auto prevOpd = std::dynamic_pointer_cast(prevOp); + auto prevInput = prevOp->getInputs(0); + auto perm = opd->getPermute(); + bool flag = true; + for (size_t j = 0; j < perm.size(); j++) { + perm[j] = prevOpd->getPermute()[perm[j]]; + if (perm[j] != int(j)) { + flag = false; + } + } + prevInput->removeTarget(prevOp); + if (flag) { + for (auto succ : op->getSuccessors()) { + succ->replaceInput(op->getOutput(), prevInput); + prevInput->addTarget(succ); + } + this->removeTensor(op->getOutput()); + } else { + auto newOp = make_ref(this, prevInput, + op->getOutput(), perm); + this->addOperatorAndConnect(newOp); + } + for (auto pred : prevOp->getPredecessors()) { + pred->removeSuccessors(prevOp); + } + for (auto succ : op->getSuccessors()) { + succ->removePredecessors(op); + } + this->removeTensor(input); + this->removeOperator(op); + this->removeOperator(prevOp); + i -= 2; + opsSize -= 2; + continue; + } } - std::vector sorted; - std::unordered_set flags; - sorted.reserve(ops.size()); - flags.reserve(ops.size()); - while (sorted.size() < ops.size()) - { - // Any node is move to sorted in this loop. - auto modified = false; - for (auto const &op : ops) - { - if (auto const &inputs = op->getInputs(); - flags.find(op.get()) == flags.end() && - std::all_of(inputs.begin(), inputs.end(), - [&flags](auto const &input) - { - auto ptr = input->getSource().get(); - return !ptr || flags.find(ptr) != flags.end(); - })) - { - modified = true; - sorted.emplace_back(op); - flags.insert(op.get()); + + if (op->getOpType() == OpType::MatMul) { + auto opd = std::dynamic_pointer_cast(op); + auto ita = op->getInputs(0), itb = op->getInputs(1); + auto prevOpA = ita->getSource(), prevOpB = itb->getSource(); + if (prevOpA && prevOpA->getOpType() == OpType::Transpose && + ita->targets.size() == 1) { + auto prevOpd = std::dynamic_pointer_cast(prevOpA); + auto perm = prevOpd->getPermute(); + bool chk = true; + for (size_t j = 0; j < perm.size() - 2; j++) { + if (perm[j] != int(j)) { + chk = false; + break; + } + } + if (!chk || perm[perm.size() - 2] != int(perm.size() - 1) || + perm[perm.size() - 1] != int(perm.size() - 2)) { + continue; + } + auto prevInput = prevOpA->getInputs(0); + opd->setTransA(!opd->getTransA()); + opd->removePredecessors(prevOpA); + for (auto pred : prevOpA->getPredecessors()) { + pred->removeSuccessors(prevOpA); + pred->addSuccessors(op); + op->addPredecessors(pred); } + prevInput->removeTarget(prevOpA); + prevInput->addTarget(op); + opd->inputs[0] = prevInput; + this->removeTensor(ita); + this->removeOperator(prevOpA); + i--; + opsSize--; } - if (!modified) - { - return false; + if (prevOpB && prevOpB->getOpType() == OpType::Transpose && + itb->targets.size() == 1) { + auto prevOpd = std::dynamic_pointer_cast(prevOpB); + auto perm = prevOpd->getPermute(); + bool chk = true; + for (size_t j = 0; j < perm.size() - 2; j++) { + if (perm[j] != int(j)) { + chk = false; + 
break; + } + } + if (!chk || perm[perm.size() - 2] != int(perm.size() - 1) || + perm[perm.size() - 1] != int(perm.size() - 2)) { + continue; + } + auto prevInput = prevOpB->getInputs(0); + opd->setTransB(!opd->getTransB()); + opd->removePredecessors(prevOpB); + for (auto pred : prevOpB->getPredecessors()) { + pred->removeSuccessors(prevOpB); + pred->addSuccessors(op); + opd->addPredecessors(pred); + } + prevInput->removeTarget(prevOpB); + prevInput->addTarget(op); + opd->inputs[1] = prevInput; + this->removeTensor(itb); + this->removeOperator(prevOpB); + i--; + opsSize--; } } - this->ops = std::move(sorted); - return this->sorted = true; - } - - void GraphObj::optimize() - { - // =================================== 作业 =================================== - // TODO: 设计一个算法来实现指定的图优化规则 - // 图优化规则如下: - // 1. 去除冗余的算子(例如,两个相邻的算子都是 transpose 算子,且做的是相反的操作,可以将其全部删除) - // 2. 合并算子(例如,矩阵乘算子中含有属性transA、transB,如果其输入存在transpose,且对最后两个维度做交换,就可以将transpose融入到矩阵乘算子的属性中去) - // =================================== 作业 =================================== } +} - Tensor GraphObj::getTensor(int fuid) const - { - for (auto tensor : tensors) - { - if (tensor->getFuid() == fuid) - { - return tensor; - } +Tensor GraphObj::getTensor(int fuid) const { + for (auto tensor : tensors) { + if (tensor->getFuid() == fuid) { + return tensor; } - return nullptr; } + return nullptr; +} - void GraphObj::shape_infer() - { - for (auto &op : ops) - { - auto ans = op->inferShape(); - IT_ASSERT(ans.has_value()); - auto oldOutputs = op->getOutputs(); - IT_ASSERT(ans.value().size() == oldOutputs.size()); - // replace the old outputshape and size with new one - for (int i = 0; i < (int)ans.value().size(); ++i) - { - auto newShape = ans.value()[i]; - auto oldShape = oldOutputs[i]->getDims(); - auto fuid = oldOutputs[i]->getFuid(); - if (newShape != oldShape) - { - auto tensor = this->getTensor(fuid); - tensor->setShape(newShape); - } +void GraphObj::shape_infer() { + for (auto &op : ops) { + auto ans = op->inferShape(); + IT_ASSERT(ans.has_value()); + auto oldOutputs = op->getOutputs(); + IT_ASSERT(ans.value().size() == oldOutputs.size()); + // replace the old outputshape and size with new one + for (int i = 0; i < (int)ans.value().size(); ++i) { + auto newShape = ans.value()[i]; + auto oldShape = oldOutputs[i]->getDims(); + auto fuid = oldOutputs[i]->getFuid(); + if (newShape != oldShape) { + auto tensor = this->getTensor(fuid); + tensor->setShape(newShape); } } } +} - void GraphObj::dataMalloc() - { - // topological sorting first - IT_ASSERT(topo_sort() == true); +void GraphObj::dataMalloc() { + // topological sorting first + IT_ASSERT(topo_sort() == true); - // =================================== 作业 =================================== - // TODO:利用 allocator 给计算图分配内存 - // HINT: 获取分配好的内存指针后,可以调用 tensor 的 setDataBlob 函数给 tensor 绑定内存 - // =================================== 作业 =================================== - - allocator.info(); + // =================================== 作业 + // =================================== + // TODO:利用 allocator 给计算图分配内存 + // HINT: 获取分配好的内存指针后,可以调用 tensor 的 setDataBlob 函数给 + // tensor 绑定内存 + // =================================== 作业 + // =================================== + auto n = this->tensors.size(); + vector offsets(n); + for (size_t i = 0; i < n; i++) { + offsets[i] = this->allocator.alloc(this->tensors[i]->getBytes()); } - - Tensor GraphObj::addTensor(Shape dim, DataType dtype) - { - return tensors.emplace_back(make_ref(dim, dtype, runtime)); + auto hptr = this->allocator.getPtr(); + for (size_t i = 0; i < 
n; i++) { + auto ptr = hptr + offsets[i]; + auto blob = make_ref(this->runtime, ptr); + this->tensors[i]->setDataBlob(blob); } + allocator.info(); +} - Tensor GraphObj::addTensor(const Tensor &tensor) - { - IT_ASSERT(tensor->getRuntime() == runtime, - std::string("Tensor runtime mismatch: cannot add a tenosr in ") + - tensor->getRuntime()->toString() + " to " + - runtime->toString()); - tensors.emplace_back(tensor); - return tensor; - } +Tensor GraphObj::addTensor(Shape dim, DataType dtype) { + return tensors.emplace_back(make_ref(dim, dtype, runtime)); +} - TensorVec GraphObj::addTensor(const TensorVec &tensors) - { - for (auto &t : tensors) - addTensor(t); - return tensors; - } +Tensor GraphObj::addTensor(const Tensor &tensor) { + IT_ASSERT(tensor->getRuntime() == runtime, + std::string("Tensor runtime mismatch: cannot add a tenosr in ") + + tensor->getRuntime()->toString() + " to " + + runtime->toString()); + tensors.emplace_back(tensor); + return tensor; +} - // tensor's "source" and "target" must be in "ops". - // tensor has no "source" and no "target" must not exist. - // "inputs" or "outputs" of operators must be in "tensors" - // "predecessors" and "successors" of an operator of "ops" must be in "ops". - bool GraphObj::checkValid() const - { - for (auto tensor : tensors) - { - IT_ASSERT(!(tensor->getTargets().size() == 0 && - nullptr == tensor->getSource())); - for (auto op : tensor->getTargets()) - { - IT_ASSERT(std::find(ops.begin(), ops.end(), op) != ops.end()); - } - auto op = tensor->getSource(); - IT_ASSERT(!(op && std::find(ops.begin(), ops.end(), op) == ops.end())); +TensorVec GraphObj::addTensor(const TensorVec &tensors) { + for (auto &t : tensors) + addTensor(t); + return tensors; +} + +// tensor's "source" and "target" must be in "ops". +// tensor has no "source" and no "target" must not exist. +// "inputs" or "outputs" of operators must be in "tensors" +// "predecessors" and "successors" of an operator of "ops" must be in "ops". 
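Note on PATCH 3: after topo_sort(), optimize() applies the two required
rewrites in one pass: an adjacent transpose pair whose composed permutation
is the identity disappears (a non-identity composition becomes one fused
transpose), and a transpose that only swaps the last two axes of a MatMul
input is folded into transA/transB. A sketch of the MatMul case, assuming
the usual GraphObj::addOp<T> helper, the MatmulObj(A, B, C, transA, transB)
argument order, and tensors a of shape {2, 3, 4} and b of shape {2, 5, 4}
already in graph g:

    auto t = g->addOp<TransposeObj>(b, nullptr, vector<int>{0, 2, 1});
    auto mm = g->addOp<MatmulObj>(a, t->getOutput(), nullptr,
                                  false, false); // {2,3,4} x {2,4,5}
    g->optimize(); // t is removed; mm reads b directly with
                   // getTransB() == true, output stays {2, 3, 5}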
From 260af1ee3ac4cc525c4d8acef0f52ee1526fb9c5 Mon Sep 17 00:00:00 2001
From: MCFS436 <596170708@qq.com>
Date: Mon, 1 Sep 2025 18:08:30 +0800
Subject: [PATCH 4/8] Update transpose.cc

---
 src/operators/transpose.cc | 14 +++++++++++---
 1 file changed, 11 insertions(+), 3 deletions(-)

diff --git a/src/operators/transpose.cc b/src/operators/transpose.cc
index faab2b6..e707369 100644
--- a/src/operators/transpose.cc
+++ b/src/operators/transpose.cc
@@ -26,15 +26,23 @@ namespace infini
     {
         const auto A = inputs[0];
         auto input_dim = A->getDims();
-        auto output_dim = input_dim;
         int rank = A->getRank();
-
+
+        vector<int> output_dim(rank);
+
+        // rearrange the dimensions according to the permutation:
+        // output_dim[i] = input_dim[transposePermute[i]]
+        for (int i = 0; i < rank; ++i) {
+            output_dim[i] = input_dim[transposePermute[i]];
+        }
+
+        // return the vector holding the output shape
+        return {{output_dim}};
         // =================================== Assignment ===================================
         // TODO: compute output_dim, the shape after the transpose
         // REF: https://onnx.ai/onnx/operators/onnx__Transpose.html#transpose-21
         // =================================== Assignment ===================================
-
-        return std::nullopt;
     }
 
     std::string TransposeObj::toString() const
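Note on PATCH 4: the rule is output_dim[i] = input_dim[transposePermute[i]],
i.e. the i-th output axis is whichever input axis the permutation names.
A worked example:

    // input shape {2, 3, 4}, transposePermute {2, 0, 1}
    // output_dim[0] = input_dim[2] = 4
    // output_dim[1] = input_dim[0] = 2
    // output_dim[2] = input_dim[1] = 3   -> output shape {4, 2, 3}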
From db6a4c0d0236d2d075138c5d46ebf1d7fd0fda2b Mon Sep 17 00:00:00 2001
From: MCFS436 <596170708@qq.com>
Date: Mon, 1 Sep 2025 18:10:03 +0800
Subject: [PATCH 5/8] Update unary.cc

---
 src/operators/unary.cc | 13 ++++++++++---
 1 file changed, 10 insertions(+), 3 deletions(-)

diff --git a/src/operators/unary.cc b/src/operators/unary.cc
index 3daad36..cb83881 100644
--- a/src/operators/unary.cc
+++ b/src/operators/unary.cc
@@ -35,11 +35,17 @@ namespace infini
 
     optional<vector<Shape>> ClipObj::inferShape(const TensorVec &inputs)
     {
+        // clip is elementwise: each output keeps its input's shape
+        vector<Shape> outputShapes;
+        for (const auto &input : inputs) {
+            outputShapes.push_back(input->getDims());
+        }
         // =================================== Assignment ===================================
         // TODO: return the shape after the clip operation
         // REF: https://onnx.ai/onnx/operators/onnx__Clip.html#clip-13
         // =================================== Assignment ===================================
-        return std::nullopt;
+        return outputShapes;
     }
 
     std::string ClipObj::toString() const
@@ -61,12 +67,13 @@ namespace infini
 
     vector<DataType> CastObj::inferDataType(const TensorVec &inputs) const
     {
+
         // =================================== Assignment ===================================
         // TODO: return the number of output tensors and their data types after the cast
         // REF_FILE: src/core/operator.cc
         // REF: https://onnx.ai/onnx/operators/onnx__Cast.html#cast-21
         // =================================== Assignment ===================================
-        return {};
+        return {this->getOutputDataType()};
     }
 
     optional<vector<Shape>> CastObj::inferShape(const TensorVec &inputs)
@@ -75,7 +82,7 @@ namespace infini
         // TODO: return the shape after the cast operation
         // REF: https://onnx.ai/onnx/operators/onnx__Cast.html#cast-21
         // =================================== Assignment ===================================
-        return std::nullopt;
+        return {{inputs[0]->getDims()}};
     }
 
     std::string CastObj::toString() const

From c2aee93abb48d2c399a8a3af7e2a5e4fbf17faa8 Mon Sep 17 00:00:00 2001
From: MCFS436 <596170708@qq.com>
Date: Mon, 1 Sep 2025 18:11:19 +0800
Subject: [PATCH 6/8] Update matmul.cc

---
 src/operators/matmul.cc | 55 +++++++++++++++++++++++++++++++++++++++--
 1 file changed, 53 insertions(+), 2 deletions(-)

diff --git a/src/operators/matmul.cc b/src/operators/matmul.cc
index 7a16ca2..95d6f35 100644
--- a/src/operators/matmul.cc
+++ b/src/operators/matmul.cc
@@ -23,11 +23,62 @@ namespace infini
     optional<vector<Shape>> MatmulObj::inferShape(const TensorVec &inputs)
     {
+        auto rankA = inputs[0]->getRank();
+        auto rankB = inputs[1]->getRank();
+        auto dimA = inputs[0]->getDims();
+        auto dimB = inputs[1]->getDims();
+        if (rankA < 2 || rankB < 2) {
+            return std::nullopt;
+        }
+        // broadcast the batch dimensions (everything but the last two)
+        auto rankR = std::max(rankA, rankB);
+        Shape res(rankR);
+        int i = int(rankA) - 3, j = int(rankB) - 3;
+        for (; i >= 0 && j >= 0; i--, j--) {
+            if (dimA[i] == dimB[j] || dimB[j] == 1) {
+                res[std::max(i, j)] = dimA[i];
+            } else if (dimA[i] == 1) {
+                res[std::max(i, j)] = dimB[j];
+            } else {
+                return std::nullopt; // incompatible batch dimensions
+            }
+        }
+        for (; i >= 0; i--) {
+            res[i] = dimA[i];
+        }
+        for (; j >= 0; j--) {
+            res[j] = dimB[j];
+        }
+        // the contraction dimensions must match after transA/transB
+        auto m = dimA[rankA - 2], kA = dimA[rankA - 1], kB = dimB[rankB - 2],
+             n = dimB[rankB - 1];
+        if (this->transA) {
+            std::swap(m, kA);
+        }
+        if (this->transB) {
+            std::swap(kB, n);
+        }
+        if (kA != kB) {
+            return std::nullopt;
+        }
+        res[rankR - 2] = m;
+        res[rankR - 1] = n;
+        if (inputs.size() <= 2) {
+            return {{res}};
+        }
+        // an optional bias C must be unidirectionally broadcastable to res
+        auto rankC = inputs[2]->getRank();
+        auto dimC = inputs[2]->getDims();
+        if (rankC > rankR) {
+            return std::nullopt;
+        }
+        for (int i = rankC - 1, j = rankR - 1; i >= 0; i--, j--) {
+            if (dimC[i] != res[j] && dimC[i] != 1) {
+                return std::nullopt;
+            }
+        }
+        return {{res}};
         // =================================== Assignment ===================================
         // TODO: return the shape after the matmul operation
         // REF: https://github.com/onnx/onnx/blob/main/docs/Operators.md#gemm
         // =================================== Assignment ===================================
-
-        return std::nullopt;
     }
 
-} // namespace infini
\ No newline at end of file
+} // namespace infini
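Note on PATCH 6: the batch dimensions (all but the last two) broadcast the
ONNX way, and the contraction dimension is checked only after transA/transB
are applied. A worked example:

    // A: {2, 1, 5, 3}, transA = true  -> m = 3, kA = 5
    // B: {   4, 5, 6}, transB = false -> kB = 5, n = 6
    // batch dims: {2, 1} vs {4} broadcast to {2, 4}
    // kA == kB, so the inferred output shape is {2, 4, 3, 6}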
From 93a6702309efaf142a6e0ea4fb511d19627f37f5 Mon Sep 17 00:00:00 2001
From: MCFS436 <596170708@qq.com>
Date: Mon, 1 Sep 2025 18:18:11 +0800
Subject: [PATCH 7/8] Update operator_utils.cc

---
 src/utils/operator_utils.cc | 28 ++++++++++++++++++++++++++--
 1 file changed, 26 insertions(+), 2 deletions(-)

diff --git a/src/utils/operator_utils.cc b/src/utils/operator_utils.cc
index edbd2c8..8635758 100644
--- a/src/utils/operator_utils.cc
+++ b/src/utils/operator_utils.cc
@@ -4,13 +4,37 @@ namespace infini {
 
 Shape infer_broadcast(const Shape &A, const Shape &B) {
-
+    if (A == B)
+        return A;
+    // right-align the two shapes, padding the shorter one with leading 1s
+    auto aRank = A.size();
+    auto bRank = B.size();
+    auto rank = std::max(aRank, bRank);
+    Shape aShape = A;
+    Shape bShape = B;
+    if (aRank < rank) {
+        aShape.insert(aShape.begin(), rank - aRank, 1);
+    }
+    if (bRank < rank) {
+        bShape.insert(bShape.begin(), rank - bRank, 1);
+    }
+    Shape dims(rank);
+    for (size_t i = 0; i < rank; i++) {
+        if (aShape[i] == bShape[i]) {
+            dims[i] = aShape[i];
+        } else if (aShape[i] == 1) {
+            dims[i] = bShape[i];
+        } else if (bShape[i] == 1) {
+            dims[i] = aShape[i];
+        } else {
+            return {}; // incompatible shapes
+        }
+    }
     // =================================== Assignment ===================================
     // TODO: broadcast A and B bidirectionally and return the broadcast shape
     // REF: https://github.com/onnx/onnx/blob/main/docs/Broadcasting.md
     // =================================== Assignment ===================================
-    return {};
+    return dims;
 }
 
 int get_real_axis(const int &axis, const int &rank) {

From fa0376922ff7ef716700ce50b25ad6100b966ede Mon Sep 17 00:00:00 2001
From: MCFS436 <596170708@qq.com>
Date: Mon, 1 Sep 2025 18:28:12 +0800
Subject: [PATCH 8/8] Update concat.cc

---
 src/operators/concat.cc | 15 ++++++++++++++-
 1 file changed, 14 insertions(+), 1 deletion(-)

diff --git a/src/operators/concat.cc b/src/operators/concat.cc
index d196330..9626840 100644
--- a/src/operators/concat.cc
+++ b/src/operators/concat.cc
@@ -12,7 +12,20 @@ ConcatObj::ConcatObj(GraphObj *graph, TensorVec inputs, Tensor output, int _dim)
 optional<vector<Shape>> ConcatObj::inferShape(const TensorVec &inputs) {
     Shape dims = inputs[0]->getDims();
     auto rank = inputs[0]->getRank();
-
+    // all inputs must share the rank and agree on every axis except `dim`,
+    // whose extents are summed
+    auto n = inputs.size();
+    for (size_t i = 1; i < n; i++) {
+        auto input_dims = inputs[i]->getDims();
+        if (input_dims.size() != rank) {
+            return std::nullopt;
+        }
+        for (size_t j = 0; j < rank; j++) {
+            if (j == size_t(dim)) {
+                dims[j] += input_dims[j];
+            } else if (dims[j] != input_dims[j]) {
+                return std::nullopt;
+            }
+        }
+    }
     // =================================== Assignment ===================================
     // TODO: modify dims and return the correct shape after the concat
     // REF: https://onnx.ai/onnx/operators/onnx__Concat.html#concat-13
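Note on PATCH 7: infer_broadcast follows the ONNX multidirectional
broadcasting rules: right-align both shapes, pad the shorter one with
leading 1s, then take the larger extent wherever the pair is compatible
(equal, or one side is 1). A worked example:

    // A = {5, 1, 3}, B = {4, 1}
    // pad B to {1, 4, 1}
    // position-wise: {5|1, 1|4, 3|1} -> {5, 4, 3}
    infer_broadcast({5, 1, 3}, {4, 1}); // returns {5, 4, 3}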