diff --git a/include/core/allocator.h b/include/core/allocator.h
index 002601d..c66491f 100644
--- a/include/core/allocator.h
+++ b/include/core/allocator.h
@@ -27,7 +27,10 @@ namespace infini
        // TODO: a data structure may be needed to store the free blocks, to ease management and merging
        // HINT: a map can be used to store the free blocks, with a block's start/end address as key and the block size as value
        // =================================== Assignment ===================================
-
+
+        // Free blocks, keyed by start address; the value is the block size.
+        std::map<size_t, size_t> free_blocks;
+
     public:
         Allocator(Runtime runtime);
@@ -55,5 +58,11 @@ namespace infini
         // function: memory alignment, rounded up
         // return: size of the aligned memory block
         size_t getAlignedSize(size_t size);
+
+        // function: merge adjacent free blocks
+        // arguments:
+        //     addr: address of the newly freed block
+        //     size: size of the newly freed block
+        void mergeAdjacentBlocks(size_t addr, size_t size);
     };
 }
diff --git a/include/core/graph.h b/include/core/graph.h
index c45580c..c086613 100644
--- a/include/core/graph.h
+++ b/include/core/graph.h
@@ -2,11 +2,14 @@
 #include "core/allocator.h"
 #include "core/operator.h"
 #include "core/tensor.h"
+#include "operators/transpose.h"
 #include
 #include

 namespace infini
 {
+    // forward declaration
+    class MatmulObj;

     class GraphObj : public Object
     {
@@ -27,6 +30,36 @@ namespace infini
         TensorVec addTensor(const TensorVec &tensors);

         void removeOperator(Operator op)
         {
+            // Detach the operator from its input and output tensors.
+            for (auto &input : op->getInputs())
+            {
+                if (input)
+                    input->removeTarget(op);
+            }
+            for (auto &output : op->getOutputs())
+            {
+                if (output)
+                    output->setSource(nullptr);
+            }
+
+            // Detach the operator from its neighbors:
+            // drop the reference to it from every predecessor...
+            auto predecessors = op->getPredecessors();
+            for (const auto &pred : predecessors)
+            {
+                if (pred)
+                    pred->removeSuccessors(op);
+            }
+
+            // ...and from every successor.
+            auto successors = op->getSuccessors();
+            for (const auto &succ : successors)
+            {
+                if (succ)
+                    succ->removePredecessors(op);
+            }
+
+            // Finally erase it from the operator list.
             auto it = std::find(ops.begin(), ops.end(), op);
             if (it != ops.end())
                 ops.erase(it);
@@ -116,6 +149,36 @@ namespace infini
         * @brief Whether the nodes are sorted in topological order.
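
Review note: keying the free list by start address in an ordered std::map is what makes mergeAdjacentBlocks cheap, since lower_bound locates the candidate neighbor on each side in O(log n). A minimal standalone sketch of that lookup (plain STL; the block values are illustrative):

    #include <cstddef>
    #include <iostream>
    #include <iterator>
    #include <map>

    // Given a free-block map keyed by start address, locate the blocks that
    // border [addr, addr + size) on both sides in O(log n).
    int main() {
        std::map<size_t, size_t> free_blocks = {{0, 64}, {128, 64}};
        size_t addr = 64, size = 64; // block being freed

        // Left neighbor: the last block starting before `addr`.
        auto it = free_blocks.lower_bound(addr);
        if (it != free_blocks.begin()) {
            auto prev = std::prev(it);
            bool touches = prev->first + prev->second == addr;
            std::cout << "left neighbor at " << prev->first
                      << (touches ? " (adjacent)\n" : " (gap)\n");
        }

        // Right neighbor: a block starting exactly at addr + size.
        auto next = free_blocks.find(addr + size);
        if (next != free_blocks.end())
            std::cout << "right neighbor at " << next->first << " (adjacent)\n";
    }
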
         */
        bool sorted;
+
+        /**
+         * @brief Check whether two transposes are inverses of each other.
+         */
+        bool areInverseTransposes(const TransposeObj *transpose1, const TransposeObj *transpose2);
+
+        /**
+         * @brief Check whether two transposes apply the same permutation.
+         */
+        bool areSameTransposes(const TransposeObj *transpose1, const TransposeObj *transpose2);
+
+        /**
+         * @brief Check if a transpose swaps only the last two dimensions.
+         */
+        bool isLastTwoDimsSwap(const TransposeObj *transpose);
+
+        /**
+         * @brief Fold a transpose into a matmul operator's transA/transB attributes.
+         */
+        void mergeTransposeToMatmul(const Operator& transpose, const Operator& matmul);
+
+        /**
+         * @brief Reconnect the graph after removing a pair of operators.
+         */
+        void reconnectGraph(const Operator& op1, const Operator& op2);
+
+        /**
+         * @brief Clean up tensors no longer referenced by any operator.
+         */
+        void cleanupUnusedTensors();
     };

 } // namespace infini
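
Review note: the fallback branch in alloc() matters. If it returned `used` when no free block fits, a free-then-realloc sequence would hand out overlapping offsets, because `used` shrinks on free(); extending at `peak` avoids that. A self-contained toy model of the same first-fit strategy (merging omitted for brevity; this is not the project's Allocator):

    #include <cstddef>
    #include <iostream>
    #include <map>

    // Minimal model of the offset allocator: first fit over a free-block map,
    // extending at `peak` (the end of the reserved region) when nothing fits.
    struct Model {
        std::map<size_t, size_t> free_blocks;
        size_t used = 0, peak = 0;

        size_t alloc(size_t size) {
            for (auto it = free_blocks.begin(); it != free_blocks.end(); ++it) {
                if (it->second >= size) {
                    size_t addr = it->first, rest = it->second - size;
                    free_blocks.erase(it);
                    if (rest > 0) free_blocks[addr + size] = rest;
                    used += size;
                    return addr;
                }
            }
            size_t addr = peak; // extending at `used` would overlap B below
            peak += size;
            used += size;
            return addr;
        }

        void free(size_t addr, size_t size) {
            free_blocks[addr] = size;
            used -= size;
        }
    };

    int main() {
        Model m;
        size_t a = m.alloc(64);  // [0, 64)
        size_t b = m.alloc(64);  // [64, 128)
        m.free(a, 64);           // used drops to 64, but B still lives at 64
        size_t c = m.alloc(128); // no 128-byte hole -> extend at peak = 128
        std::cout << "a=" << a << " b=" << b << " c=" << c << "\n"; // a=0 b=64 c=128
    }
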
+#include "operators/transpose.h" +#include "operators/matmul.h" #include #include #include @@ -106,7 +108,86 @@ namespace infini // 1. 去除冗余的算子(例如,两个相邻的算子都是 transpose 算子,且做的是相反的操作,可以将其全部删除) // 2. 合并算子(例如,矩阵乘算子中含有属性transA、transB,如果其输入存在transpose,且对最后两个维度做交换,就可以将transpose融入到矩阵乘算子的属性中去) // =================================== 作业 =================================== + // 使用更安全的方式:先收集要删除的操作符,然后统一删除 + std::vector> to_remove; + + for (const auto& op : ops) { + if (op->getOpType() == OpType::Transpose) { + auto successors = op->getSuccessors(); + + for (const auto& succ : successors) { + if (succ->getOpType() == OpType::Transpose) { + // 检查是否是逆操作 + auto transpose1 = dynamic_cast(op.get()); + auto transpose2 = dynamic_cast(succ.get()); + + if (this->areSameTransposes(transpose1, transpose2)) { + to_remove.emplace_back(op, succ); + break; // 找到一个匹配就跳出内层循环 + } + } + } + } + } + + // 统一处理删除和重新连接 + for (const auto& pair : to_remove) { + auto op1 = pair.first; + auto op2 = pair.second; + + // 重新连接图结构 + this->reconnectGraph(op1, op2); + + // 删除操作符 + this->removeOperator(op1); + this->removeOperator(op2); + } + + // =================================== 合并算子优化 =================================== + // 将转置操作融入到矩阵乘算子的属性中 + std::vector> to_merge; + + for (const auto& op : ops) { + if (op && op->getOpType() == OpType::Transpose) { + auto successors = op->getSuccessors(); + + for (const auto& succ : successors) { + // 安全检查:确保后继操作符仍然有效 + if (!succ) continue; + + if (succ->getOpType() == OpType::MatMul) { + // 检查转置是否对最后两个维度做交换 + auto transpose = dynamic_cast(op.get()); + auto matmul = dynamic_cast(succ.get()); + + if (transpose && matmul && this->isLastTwoDimsSwap(transpose)) { + to_merge.emplace_back(op, succ); + break; + } + } + } + } } + + // 统一处理合并 + for (const auto& pair : to_merge) { + auto transpose = pair.first; + auto matmul = pair.second; + + // 安全检查:确保操作符仍然有效 + if (!transpose || !matmul) continue; + + // 合并转置到矩阵乘 + this->mergeTransposeToMatmul(transpose, matmul); + + // 删除转置操作符 + this->removeOperator(transpose); + } + + // 清理未使用的张量 + this->cleanupUnusedTensors(); + +} Tensor GraphObj::getTensor(int fuid) const { @@ -152,7 +233,49 @@ namespace infini // TODO:利用 allocator 给计算图分配内存 // HINT: 获取分配好的内存指针后,可以调用 tensor 的 setDataBlob 函数给 tensor 绑定内存 // =================================== 作业 =================================== - + + // 第一阶段:收集所有tensor的内存需求并分配偏移 + std::unordered_set allocated_tensors; + std::vector> tensor_offsets; + + // 遍历所有tensor,分配内存偏移 + for (auto &tensor : tensors) + { + if (tensor && allocated_tensors.find(tensor.get()) == allocated_tensors.end()) + { + // 计算tensor需要的内存大小 + size_t tensor_size = tensor->getBytes(); + + // 使用allocator分配内存地址偏移 + size_t offset = allocator.alloc(tensor_size); + + // 保存tensor和偏移的对应关系 + tensor_offsets.emplace_back(tensor, offset); + allocated_tensors.insert(tensor.get()); + } + } + + // 第二阶段:获取实际内存指针并绑定到tensor + void *base_ptr = allocator.getPtr(); + if (base_ptr) + { + for (const auto &pair : tensor_offsets) + { + auto tensor = pair.first; + size_t offset = pair.second; + + // 计算实际的内存地址 + //char* 以字节为单位进行指针算术 + void *tensor_ptr = static_cast(base_ptr) + offset; + + // 创建Blob对象 + Blob blob = make_ref(runtime, tensor_ptr); + + // 绑定内存到tensor + tensor->setDataBlob(blob); + } + } + allocator.info(); } @@ -227,4 +350,215 @@ namespace infini return true; } + bool GraphObj::areInverseTransposes(const TransposeObj* transpose1, const TransposeObj* transpose2) { + if (!transpose1 || !transpose2) { + return false; + } + + auto perm1 = transpose1->getPermute(); + auto perm2 = 
@@ -152,7 +233,49 @@ namespace infini
         // TODO: use the allocator to allocate memory for the computation graph
         // HINT: after obtaining the allocated memory pointer, call the tensor's setDataBlob function to bind the memory to the tensor
         // =================================== Assignment ===================================
-
+
+        // Phase 1: collect every tensor's memory requirement and assign offsets.
+        std::unordered_set<TensorObj *> allocated_tensors;
+        std::vector<std::pair<Tensor, size_t>> tensor_offsets;
+
+        for (auto &tensor : tensors)
+        {
+            if (tensor && allocated_tensors.find(tensor.get()) == allocated_tensors.end())
+            {
+                // Size in bytes required by this tensor.
+                size_t tensor_size = tensor->getBytes();
+
+                // Ask the allocator for an address offset.
+                size_t offset = allocator.alloc(tensor_size);
+
+                // Remember the tensor-to-offset mapping.
+                tensor_offsets.emplace_back(tensor, offset);
+                allocated_tensors.insert(tensor.get());
+            }
+        }
+
+        // Phase 2: fetch the real base pointer and bind each tensor's memory.
+        void *base_ptr = allocator.getPtr();
+        if (base_ptr)
+        {
+            for (const auto &pair : tensor_offsets)
+            {
+                auto tensor = pair.first;
+                size_t offset = pair.second;
+
+                // Compute the actual address; char * gives byte-wise pointer arithmetic.
+                void *tensor_ptr = static_cast<char *>(base_ptr) + offset;
+
+                // Wrap the address in a Blob and bind it to the tensor.
+                Blob blob = make_ref<BlobObj>(runtime, tensor_ptr);
+                tensor->setDataBlob(blob);
+            }
+        }
+
         allocator.info();
     }
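
Review note: dataMalloc follows a two-phase scheme: plan offsets first, materialize a single backing buffer, then bind base + offset for each tensor. A compact standalone illustration of the byte-wise pointer arithmetic (plain C++; the sizes are made up):

    #include <cstddef>
    #include <cstdlib>
    #include <iostream>
    #include <vector>

    int main() {
        // Phase 1: plan offsets for three "tensors" without touching memory.
        std::vector<size_t> sizes = {256, 128, 512};
        std::vector<size_t> offsets;
        size_t total = 0;
        for (size_t s : sizes) {
            offsets.push_back(total);
            total += s; // a real allocator would also align and reuse holes
        }

        // Phase 2: one backing allocation, then bind each offset to a pointer.
        void *base = std::malloc(total);
        for (size_t i = 0; i < sizes.size(); ++i) {
            // char * makes the "+ offset" advance by bytes, not by elements.
            void *p = static_cast<char *>(base) + offsets[i];
            std::cout << "tensor " << i << " -> base+" << offsets[i]
                      << " (" << p << ")\n";
        }
        std::free(base);
    }
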
@@ -227,4 +350,215 @@ namespace infini
         return true;
     }

+    bool GraphObj::areInverseTransposes(const TransposeObj *transpose1, const TransposeObj *transpose2)
+    {
+        if (!transpose1 || !transpose2)
+            return false;
+
+        auto perm1 = transpose1->getPermute();
+        auto perm2 = transpose2->getPermute();
+
+        if (perm1.size() != perm2.size())
+            return false;
+
+        // The permutations are inverse iff applying one after the other
+        // yields the identity.
+        for (size_t i = 0; i < perm1.size(); ++i)
+        {
+            if (perm1[perm2[i]] != static_cast<int>(i))
+                return false;
+        }
+
+        return true;
+    }
+
+    bool GraphObj::areSameTransposes(const TransposeObj *transpose1, const TransposeObj *transpose2)
+    {
+        if (!transpose1 || !transpose2)
+            return false;
+
+        auto perm1 = transpose1->getPermute();
+        auto perm2 = transpose2->getPermute();
+
+        if (perm1.size() != perm2.size())
+            return false;
+
+        // Element-wise comparison of the two permutations.
+        for (size_t i = 0; i < perm1.size(); ++i)
+        {
+            if (perm1[i] != perm2[i])
+                return false;
+        }
+
+        return true;
+    }
+
+    bool GraphObj::isLastTwoDimsSwap(const TransposeObj *transpose)
+    {
+        if (!transpose)
+            return false;
+
+        auto perm = transpose->getPermute();
+        if (perm.size() < 2)
+            return false;
+
+        // The swap can only be folded into matmul's transA/transB if every
+        // leading (batch) dimension is left in place...
+        size_t n = perm.size();
+        for (size_t i = 0; i + 2 < n; ++i)
+        {
+            if (perm[i] != static_cast<int>(i))
+                return false;
+        }
+
+        // ...and the last two dimensions (indices n-2 and n-1) are exchanged.
+        return perm[n - 2] == static_cast<int>(n - 1) && perm[n - 1] == static_cast<int>(n - 2);
+    }
+
+    void GraphObj::mergeTransposeToMatmul(const Operator &transpose, const Operator &matmul)
+    {
+        auto transpose_obj = dynamic_cast<TransposeObj *>(transpose.get());
+        auto matmul_obj = dynamic_cast<MatmulObj *>(matmul.get());
+
+        if (!transpose_obj || !matmul_obj)
+            return;
+
+        // The transpose's input tensor will replace its output as the matmul input.
+        auto transpose_inputs = transpose->getInputs();
+        if (transpose_inputs.empty())
+            return;
+
+        auto matmul_inputs = matmul->getInputs();
+        if (matmul_inputs.size() < 2)
+            return;
+
+        // Determine whether the transpose feeds matmul input A or input B.
+        auto transpose_output = transpose->getOutput();
+        if (!transpose_output)
+            return;
+
+        bool isInputA = (transpose_output == matmul_inputs[0]);
+        bool isInputB = (transpose_output == matmul_inputs[1]);
+
+        if (isInputA)
+        {
+            // The transpose feeds input A: flip transA (the matmul may
+            // already have it set) and bypass the transpose.
+            matmul_obj->setTransA(!matmul_obj->getTransA());
+            matmul->replaceInput(matmul_inputs[0], transpose_inputs[0]);
+
+            // Update the tensor connections.
+            transpose_inputs[0]->addTarget(matmul);
+            transpose_output->removeTarget(matmul);
+        }
+        else if (isInputB)
+        {
+            // The transpose feeds input B: flip transB and bypass the transpose.
+            matmul_obj->setTransB(!matmul_obj->getTransB());
+            matmul->replaceInput(matmul_inputs[1], transpose_inputs[0]);
+
+            // Update the tensor connections.
+            transpose_inputs[0]->addTarget(matmul);
+            transpose_output->removeTarget(matmul);
+        }
+    }
+
+    void GraphObj::reconnectGraph(const Operator &op1, const Operator &op2)
+    {
+        // Validate the arguments.
+        if (!op1 || !op2)
+            return;
+
+        // Make sure op1 and op2 are adjacent.
+        auto op1_successors = op1->getSuccessors();
+        bool areAdjacent = false;
+        for (const auto &succ : op1_successors)
+        {
+            if (succ == op2)
+            {
+                areAdjacent = true;
+                break;
+            }
+        }
+
+        if (!areAdjacent)
+        {
+            // Non-adjacent operators would need different handling.
+            return;
+        }
+
+        // Rewire the tensors: connect op1's input directly to op2's consumers.
+        auto op1_inputs = op1->getInputs();
+        auto op2_outputs = op2->getOutputs();
+        auto op2_successors = op2->getSuccessors();
+
+        if (!op1_inputs.empty() && !op2_outputs.empty())
+        {
+            auto input_tensor = op1_inputs[0];   // op1's input tensor
+            auto output_tensor = op2_outputs[0]; // op2's output tensor
+
+            // Every operator consuming op2's output now reads op1's input instead.
+            auto targets = output_tensor->getTargets();
+            for (const auto &target : targets)
+            {
+                if (target && target != op1 && target != op2)
+                {
+                    target->replaceInput(output_tensor, input_tensor);
+                    input_tensor->addTarget(target);
+                }
+            }
+        }
+
+        // Rewire the operators: op1's predecessors -> op2's successors.
+        auto op1_predecessors = op1->getPredecessors();
+        for (const auto &pred : op1_predecessors)
+        {
+            for (const auto &succ : op2_successors)
+            {
+                // Avoid self-loops.
+                if (pred != succ)
+                {
+                    pred->addSuccessors(succ);
+                    succ->addPredecessors(pred);
+                }
+            }
+        }
+
+        // Drop the old links.
+        for (const auto &pred : op1_predecessors)
+            pred->removeSuccessors(op1);
+        for (const auto &succ : op2_successors)
+            succ->removePredecessors(op2);
+    }
+
+    void GraphObj::cleanupUnusedTensors()
+    {
+        // Remove tensors that are no longer used by any remaining operator.
+        auto it = tensors.begin();
+        while (it != tensors.end())
+        {
+            auto tensor = *it;
+            bool isUsed = false;
+
+            // Is the tensor an input or output of any operator?
+            for (const auto &op : ops)
+            {
+                for (const auto &input : op->getInputs())
+                {
+                    if (input == tensor)
+                    {
+                        isUsed = true;
+                        break;
+                    }
+                }
+                if (isUsed)
+                    break;
+
+                for (const auto &output : op->getOutputs())
+                {
+                    if (output == tensor)
+                    {
+                        isUsed = true;
+                        break;
+                    }
+                }
+                if (isUsed)
+                    break;
+            }
+
+            // Erase the tensor if nothing references it.
+            if (!isUsed)
+                it = tensors.erase(it);
+            else
+                ++it;
+        }
+    }
+
 } // namespace infini
\ No newline at end of file
diff --git a/src/core/operator.cc b/src/core/operator.cc
index a70ca48..a856797 100644
--- a/src/core/operator.cc
+++ b/src/core/operator.cc
@@ -11,6 +11,7 @@ namespace infini
     {
         for (auto it = predecessors.begin(); it != predecessors.end();)
         {
+            // lock() turns the weak_ptr into a shared_ptr for comparison
             if (it->lock() == op)
                 it = predecessors.erase(it);
             else
diff --git a/src/operators/concat.cc b/src/operators/concat.cc
index d196330..82567fe 100644
--- a/src/operators/concat.cc
+++ b/src/operators/concat.cc
@@ -10,6 +10,10 @@ ConcatObj::ConcatObj(GraphObj *graph, TensorVec inputs, Tensor output, int _dim)
 }

 optional<vector<Shape>> ConcatObj::inferShape(const TensorVec &inputs) {
+    if (inputs.empty()) {
+        return std::nullopt;
+    }
+
     Shape dims = inputs[0]->getDims();
     auto rank = inputs[0]->getRank();

@@ -17,7 +21,22 @@ optional<vector<Shape>> ConcatObj::inferShape(const TensorVec &inputs) {
     // TODO: modify dims and return the correct shape after concat
     // REF: https://onnx.ai/onnx/operators/onnx__Concat.html#concat-13
     // =================================== Assignment ===================================
+    // Accumulate the size along the concatenation dimension.
+    // Note: dim 0 is a valid concat axis, so the lower bound is inclusive.
+    IT_ASSERT(dim >= 0 && dim < static_cast<int>(rank), "Dimension out of range");
+
+    int concat_size = dims[dim];
+    for (size_t i = 1; i < inputs.size(); i++) {
+        IT_ASSERT(inputs[i]->getRank() == rank, "All input tensors must have the same rank");
+        for (int j = 0; j < static_cast<int>(rank); j++) {
+            if (j != dim) {
+                IT_ASSERT(dims[j] == inputs[i]->getDims()[j], "All input tensors must have the same shape except the concatenation dimension");
+            }
+        }
+        concat_size += inputs[i]->getDims()[dim];
+    }
+    dims[dim] = concat_size;
+
     return {{dims}};
 }
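
Review note: the concat rule in isolation: all dimensions must agree except the concat axis, whose sizes accumulate. A standalone mirror of the logic (validation elided; this is not the project's Shape machinery):

    #include <cstddef>
    #include <iostream>
    #include <vector>

    using Shape = std::vector<int>;

    // Output shape of Concat: copy the first input's shape and sum the
    // sizes along the concatenation axis. Inputs must agree elsewhere.
    Shape concatShape(const std::vector<Shape> &inputs, int dim) {
        Shape out = inputs[0];
        for (size_t i = 1; i < inputs.size(); ++i)
            out[dim] += inputs[i][dim];
        return out;
    }

    int main() {
        // Concat along axis 1: (2,3,4) + (2,5,4) -> (2,8,4).
        Shape s = concatShape({{2, 3, 4}, {2, 5, 4}}, 1);
        for (int d : s)
            std::cout << d << ' '; // 2 8 4
        std::cout << '\n';
    }
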
diff --git a/src/operators/matmul.cc b/src/operators/matmul.cc
index 7a16ca2..edb3056 100644
--- a/src/operators/matmul.cc
+++ b/src/operators/matmul.cc
@@ -1,4 +1,5 @@
 #include "operators/matmul.h"
+#include "utils/operator_utils.h"

 namespace infini
 {
@@ -6,7 +7,7 @@ namespace infini
     MatmulObj::MatmulObj(GraphObj *graph, Tensor A, Tensor B, Tensor C, bool transA,
                          bool transB)
         : OperatorObj(OpType::MatMul, TensorVec{A, B}, {C}),
-          transA(transA), transB(transB)
+          transA(transA), transB(transB), m(0), n(0), k(0)
     {
         IT_ASSERT(checkValid(graph));
     }
@@ -27,7 +28,90 @@ namespace infini
         // TODO: return the shape after the matmul operation
         // REF: https://github.com/onnx/onnx/blob/main/docs/Operators.md#gemm
         // =================================== Assignment ===================================
-        return std::nullopt;
+
+        if (inputs.size() != 2)
+            return std::nullopt;
+
+        auto A = inputs[0];
+        auto B = inputs[1];
+        if (!A || !B)
+            return std::nullopt;
+
+        Shape shapeA = A->getDims();
+        Shape shapeB = B->getDims();
+        if (shapeA.size() < 2 || shapeB.size() < 2)
+            return std::nullopt;
+
+        // Effective matrix dimensions of A and B, with transposition applied.
+        Shape effectiveA = shapeA;
+        Shape effectiveB = shapeB;
+
+        // Transposition only affects the last two dimensions.
+        if (transA && effectiveA.size() >= 2)
+        {
+            size_t rank = effectiveA.size();
+            std::swap(effectiveA[rank - 2], effectiveA[rank - 1]);
+        }
+        if (transB && effectiveB.size() >= 2)
+        {
+            size_t rank = effectiveB.size();
+            std::swap(effectiveB[rank - 2], effectiveB[rank - 1]);
+        }
+
+        size_t rankA = effectiveA.size();
+        size_t rankB = effectiveB.size();
+
+        // The last two dimensions take part in the matrix multiplication.
+        int m = effectiveA[rankA - 2];   // rows of A
+        int k_A = effectiveA[rankA - 1]; // columns of A
+        int k_B = effectiveB[rankB - 2]; // rows of B
+        int n = effectiveB[rankB - 1];   // columns of B
+
+        // The inner dimensions must match.
+        if (k_A != k_B)
+            return std::nullopt;
+
+        // Broadcast the batch dimensions.
+        Shape batchA(shapeA.begin(), shapeA.end() - 2); // batch dims of A
+        Shape batchB(shapeB.begin(), shapeB.end() - 2); // batch dims of B
+
+        Shape resultBatch;
+        try
+        {
+            if (batchA.empty() && batchB.empty())
+            {
+                // Neither input has batch dimensions.
+                resultBatch = {};
+            }
+            else if (batchA.empty())
+            {
+                // Only B has batch dimensions.
+                resultBatch = batchB;
+            }
+            else if (batchB.empty())
+            {
+                // Only A has batch dimensions.
+                resultBatch = batchA;
+            }
+            else
+            {
+                // Both have batch dimensions; broadcast them.
+                resultBatch = infer_broadcast(batchA, batchB);
+            }
+        }
+        catch (...)
+        {
+            return std::nullopt;
+        }
+
+        // Output shape: batch dimensions + [m, n].
+        Shape outputShape = resultBatch;
+        outputShape.push_back(m);
+        outputShape.push_back(n);
+
+        // Cache m, n, k for later computation.
+        this->m = m;
+        this->n = n;
+        this->k = k_A;
+
+        return {{outputShape}};
     }

 } // namespace infini
\ No newline at end of file
diff --git a/src/operators/transpose.cc b/src/operators/transpose.cc
index faab2b6..dce689d 100644
--- a/src/operators/transpose.cc
+++ b/src/operators/transpose.cc
@@ -9,6 +9,7 @@ namespace infini
         auto rank = input->getRank();
         if (permute.empty())
         {
+            transposePermute.resize(rank);
             for (size_t i = 0; i < rank; ++i)
             {
                 transposePermute[i] = i;
@@ -34,7 +35,11 @@ namespace infini
         // REF: https://onnx.ai/onnx/operators/onnx__Transpose.html#transpose-21
         // =================================== Assignment ===================================

-        return std::nullopt;
+        // Output dimension i takes the size of input dimension permute[i].
+        for (int i = 0; i < rank; ++i) {
+            output_dim[i] = input_dim[transposePermute[i]];
+        }
+
+        return vector<Shape>{output_dim};
     }

     std::string TransposeObj::toString() const
diff --git a/src/operators/unary.cc b/src/operators/unary.cc
index 3daad36..43540c7 100644
--- a/src/operators/unary.cc
+++ b/src/operators/unary.cc
@@ -39,7 +39,15 @@ namespace infini
         // TODO: return the shape after the clip operation
         // REF: https://onnx.ai/onnx/operators/onnx__Clip.html#clip-13
         // =================================== Assignment ===================================
-        return std::nullopt;
+
+        // Check the number of inputs.
+        IT_ASSERT(inputs.size() == 1, "Clip operation requires exactly one input tensor");
+
+        const auto input = inputs[0];
+
+        // Clip only clamps values; it does not change the tensor's shape,
+        // so the output shape is identical to the input shape.
+        return {{input->getDims()}};
     }

     std::string ClipObj::toString() const
@@ -66,7 +74,8 @@ namespace infini
         // REF_FILE: src/core/operator.cc
         // REF: https://onnx.ai/onnx/operators/onnx__Cast.html#cast-21
         // =================================== Assignment ===================================
-        return {};
+        auto data_type = getOutputDataType();
+        return {data_type};
     }

     optional<vector<Shape>> CastObj::inferShape(const TensorVec &inputs)
@@ -75,7 +84,7 @@ namespace infini
         // TODO: return the shape after the cast operation
         // REF: https://onnx.ai/onnx/operators/onnx__Cast.html#cast-21
         // =================================== Assignment ===================================
-        return std::nullopt;
+        return {{inputs[0]->getDims()}};
     }

     std::string CastObj::toString() const
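
Review note: the matmul shape rule end to end: swap the trailing dims of a transposed operand, require the inner dimensions to match, broadcast the batch dims, then append [m, n]. A standalone mirror of the logic (equal-rank batch dims assumed for brevity):

    #include <algorithm>
    #include <cstddef>
    #include <iostream>
    #include <vector>

    using Shape = std::vector<int>;

    // Output shape of a (batched) matmul with optional transposition.
    Shape matmulShape(Shape a, Shape b, bool transA, bool transB) {
        size_t ra = a.size(), rb = b.size();
        if (transA) std::swap(a[ra - 2], a[ra - 1]);
        if (transB) std::swap(b[rb - 2], b[rb - 1]);
        // Inner dimensions must match: a = [..., m, k], b = [..., k, n].
        if (a[ra - 1] != b[rb - 2]) return {};

        Shape out;
        for (size_t i = 0; i + 2 < ra; ++i) {
            // NumPy-style broadcast of the batch dimensions.
            out.push_back(a[i] == 1 ? b[i] : a[i]);
        }
        out.push_back(a[ra - 2]); // m
        out.push_back(b[rb - 1]); // n
        return out;
    }

    int main() {
        // (8,4,3) with transA -> effective (8,3,4); k = 4 matches b's (1,4,7).
        Shape s = matmulShape({8, 4, 3}, {1, 4, 7}, /*transA=*/true, false);
        for (int d : s)
            std::cout << d << ' '; // 8 3 7
        std::cout << '\n';
    }
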
diff --git a/src/utils/operator_utils.cc b/src/utils/operator_utils.cc
index edbd2c8..8495222 100644
--- a/src/utils/operator_utils.cc
+++ b/src/utils/operator_utils.cc
@@ -10,7 +10,55 @@ Shape infer_broadcast(const Shape &A, const Shape &B) {
     // REF: https://github.com/onnx/onnx/blob/main/docs/Broadcasting.md
     // =================================== Assignment ===================================
-    return {};
+    // Handle empty shapes.
+    if (A.empty()) return B;
+    if (B.empty()) return A;
+
+    int rank_A = A.size();
+    int rank_B = B.size();
+
+    // The result has the larger of the two ranks.
+    int max_rank = std::max(rank_A, rank_B);
+
+    // Align the two shapes...
+    Shape local_A = A;
+    Shape local_B = B;
+
+    // ...by left-padding the shorter one with 1s.
+    if (rank_A < max_rank) {
+        local_A.insert(local_A.begin(), max_rank - rank_A, 1);
+    }
+    if (rank_B < max_rank) {
+        local_B.insert(local_B.begin(), max_rank - rank_B, 1);
+    }
+
+    // Result shape.
+    Shape result(max_rank);
+
+    // Apply the broadcasting rules dimension by dimension.
+    for (int i = 0; i < max_rank; i++) {
+        int dim_A = local_A[i];
+        int dim_B = local_B[i];
+
+        if (dim_A == dim_B) {
+            // Equal dimensions are taken as-is.
+            result[i] = dim_A;
+        } else if (dim_A == 1) {
+            // A's dimension is 1 and broadcasts to B's.
+            result[i] = dim_B;
+        } else if (dim_B == 1) {
+            // B's dimension is 1 and broadcasts to A's.
+            result[i] = dim_A;
+        } else {
+            // Incompatible dimensions: broadcasting is impossible.
+            IT_ASSERT(false,
+                      "Cannot broadcast shapes: dimension " + std::to_string(i) +
+                          " has incompatible sizes " + std::to_string(dim_A) +
+                          " and " + std::to_string(dim_B));
+        }
+    }
+
+    return result;
 }

 int get_real_axis(const int &axis, const int &rank) {
diff --git a/test/core/test_memory_allocation.cc b/test/core/test_memory_allocation.cc
new file mode 100644
index 0000000..8b13789
--- /dev/null
+++ b/test/core/test_memory_allocation.cc
@@ -0,0 +1 @@
+
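
Review note: test/core/test_memory_allocation.cc is added empty. Below is one possible starting point, sketched against the Allocator interface in this diff; the "test.h" gtest wrapper and NativeCpuRuntimeObj::getInstance() are assumed from the project's existing tests and may need adjusting.

    #include "core/allocator.h"
    #include "core/runtime.h"
    #include "test.h"

    namespace infini {

    TEST(Allocator, allocFreeReuse) {
        // Assumed runtime factory; swap in the project's actual CPU runtime.
        Runtime runtime = NativeCpuRuntimeObj::getInstance();
        Allocator allocator(runtime);

        // Two back-to-back allocations occupy disjoint offsets.
        size_t a = allocator.alloc(64);
        size_t b = allocator.alloc(64);
        EXPECT_NE(a, b);

        // After freeing the first block, an allocation of the same size
        // should reuse its offset (first fit over the free-block map).
        allocator.free(a, 64);
        size_t c = allocator.alloc(64);
        EXPECT_EQ(c, a);
    }

    } // namespace infini
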