diff --git a/include/core/allocator.h b/include/core/allocator.h
index 002601d..c66491f 100644
--- a/include/core/allocator.h
+++ b/include/core/allocator.h
@@ -27,7 +27,10 @@ namespace infini {
// TODO: may need a data structure to track free blocks, to ease management and merging
// HINT: a map can store the free blocks, keyed by the block's start/end address, with the block size as the value
// =================================== Assignment ===================================
-
+ // Free blocks keyed by start address; std::map keeps them address-ordered,
+ // which makes neighbor lookup for coalescing cheap.
+ std::map<size_t, size_t> free_blocks;
+
+
public:
Allocator(Runtime runtime);
@@ -55,5 +58,11 @@ namespace infini {
// function: memory alignment, rounded up
// return: size of the aligned memory block
size_t getAlignedSize(size_t size);
+
+ // function: merge adjacent free blocks
+ // arguments:
+ // addr: address of the newly freed block
+ // size: size of the newly freed block
+ void mergeAdjacentBlocks(size_t addr, size_t size);
};
}
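Note on the free-block map: keying by start address keeps the blocks sorted, so the neighbor that ends at a given address is found by stepping back from lower_bound rather than by a direct key lookup. A minimal standalone sketch of that idiom (hypothetical helper, not part of the patch):

    #include <cstddef>
    #include <iterator>
    #include <map>

    using FreeMap = std::map<size_t, size_t>; // start address -> size

    // Returns the block that ends exactly at `addr`, or blocks.end().
    FreeMap::iterator find_preceding_block(FreeMap &blocks, size_t addr) {
        auto it = blocks.lower_bound(addr); // first block at or after addr
        if (it == blocks.begin())
            return blocks.end();
        auto prev = std::prev(it);
        return (prev->first + prev->second == addr) ? prev : blocks.end();
    }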
diff --git a/include/core/graph.h b/include/core/graph.h
index c45580c..c086613 100644
--- a/include/core/graph.h
+++ b/include/core/graph.h
@@ -2,11 +2,14 @@
#include "core/allocator.h"
#include "core/operator.h"
#include "core/tensor.h"
+#include "operators/transpose.h"
#include <algorithm>
#include <cstdint>
namespace infini
{
+ // Forward declaration
+ class MatmulObj;
class GraphObj : public Object
{
@@ -27,6 +30,36 @@ namespace infini
TensorVec addTensor(const TensorVec &tensors);
void removeOperator(Operator op)
{
+ // Detach this operator from its input and output tensors
+ for (auto& input : op->getInputs()) {
+ if (input) {
+ input->removeTarget(op);
+ }
+ }
+ for (auto& output : op->getOutputs()) {
+ if (output) {
+ output->setSource(nullptr);
+ }
+ }
+
+ // Break operator-to-operator links:
+ // remove references to this operator from all predecessors
+ auto predecessors = op->getPredecessors();
+ for (const auto& pred : predecessors) {
+ if (pred) {
+ pred->removeSuccessors(op);
+ }
+ }
+
+ // remove references to this operator from all successors
+ auto successors = op->getSuccessors();
+ for (const auto& succ : successors) {
+ if (succ) {
+ succ->removePredecessors(op);
+ }
+ }
+
+ // finally erase the operator from the operator list
auto it = std::find(ops.begin(), ops.end(), op);
if (it != ops.end())
ops.erase(it);
@@ -116,6 +149,36 @@ namespace infini
* @brief Whether the nodes are sorted in topological order.
*/
bool sorted;
+
+ /**
+ * @brief Check whether two transposes are inverse permutations of each other
+ */
+ bool areInverseTransposes(const TransposeObj *transpose1, const TransposeObj *transpose2);
+
+ /**
+ * @brief Check whether two transposes apply the same permutation
+ */
+ bool areSameTransposes(const TransposeObj *transpose1, const TransposeObj *transpose2);
+
+ /**
+ * @brief Check if transpose swaps last two dimensions
+ */
+ bool isLastTwoDimsSwap(const TransposeObj *transpose);
+
+ /**
+ * @brief Merge transpose into matmul operator
+ */
+ void mergeTransposeToMatmul(const Operator& transpose, const Operator& matmul);
+
+ /**
+ * @brief Reconnect graph after removing operators
+ */
+ void reconnectGraph(const Operator& op1, const Operator& op2);
+
+ /**
+ * @brief Clean up unused tensors
+ */
+ void cleanupUnusedTensors();
};
} // namespace infini
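For reference, perm2 inverts perm1 exactly when perm1[perm2[i]] == i for all i; a permutation cancels itself only when it is an involution, which is why an inverse check (not a same-permutation check) is the safe cancellation test. A small self-contained illustration (hypothetical helper, independent of TransposeObj):

    #include <vector>

    bool is_inverse(const std::vector<int> &p1, const std::vector<int> &p2) {
        if (p1.size() != p2.size())
            return false;
        for (size_t i = 0; i < p1.size(); ++i)
            if (p1[p2[i]] != static_cast<int>(i))
                return false;
        return true;
    }

    // is_inverse({1, 2, 0}, {2, 0, 1}) == true: the two rotations cancel,
    // yet the permutations are not equal, so a same-permutation test misses them.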
diff --git a/src/core/allocator.cc b/src/core/allocator.cc
index ff593ae..b0ceede 100644
--- a/src/core/allocator.cc
+++ b/src/core/allocator.cc
@@ -28,12 +28,48 @@ namespace infini
IT_ASSERT(this->ptr == nullptr);
// pad the size to the multiple of alignment
size = this->getAlignedSize(size);
-
+
+
// =================================== Assignment ===================================
// TODO: design an algorithm to allocate memory; return the starting offset
// =================================== Assignment ===================================
- return 0;
+ // First fit: scan the address-ordered free list for the first block that is large enough
+ for (auto it = free_blocks.begin(); it != free_blocks.end(); ++it) {
+ size_t block_addr = it->first;
+ size_t block_size = it->second;
+
+ if (block_size >= size) {
+ // Found a usable block: carve the request off its front
+ size_t remaining_size = block_size - size;
+
+ if (remaining_size > 0) {
+ // Keep the tail as a smaller free block
+ free_blocks[block_addr + size] = remaining_size;
+ }
+
+ // The original entry is consumed either way
+ free_blocks.erase(it);
+
+ // A reused block lies below `peak` by construction, so only `used`
+ // changes here
+ used += size;
+
+ return block_addr;
+ }
+ }
+
+ // No suitable free block: extend the arena at the high-water mark.
+ // Returning `used` here would be wrong once blocks have been freed
+ // (`used` shrinks on free while the arena does not), so `peak` doubles
+ // as the bump pointer: everything below it is live or in free_blocks.
+ size_t new_addr = peak;
+ peak += size;
+ used += size;
+
+ return new_addr;
}
void Allocator::free(size_t addr, size_t size)
@@ -44,6 +80,15 @@ namespace infini
// =================================== Assignment ===================================
// TODO: design an algorithm to reclaim memory
// =================================== Assignment ===================================
+
+ // Record the freed range in the free-block map
+ free_blocks[addr] = size;
+
+ // Coalesce with neighboring free blocks where possible
+ mergeAdjacentBlocks(addr, size);
+
+ // Update usage statistics
+ used -= size;
}
void *Allocator::getPtr()
@@ -61,6 +106,38 @@ namespace infini
return ((size - 1) / this->alignment + 1) * this->alignment;
}
+ void Allocator::mergeAdjacentBlocks(size_t addr, size_t size)
+ {
+ // Look for a free block that ends exactly at `addr`. The map is keyed
+ // by start address, so take the first entry at or after `addr` (the
+ // block just inserted by free()) and step back one; a direct lookup of
+ // addr - 1 would only find a one-byte predecessor.
+ auto it = free_blocks.lower_bound(addr);
+ if (it != free_blocks.begin()) {
+ auto prev_it = std::prev(it);
+ size_t prev_addr = prev_it->first;
+ size_t prev_size = prev_it->second;
+
+ if (prev_addr + prev_size == addr) {
+ // Merge into the preceding block
+ size_t new_size = prev_size + size;
+ free_blocks[prev_addr] = new_size;
+ free_blocks.erase(addr); // remove the just-freed entry
+
+ // Continue from the combined block when checking the successor
+ addr = prev_addr;
+ size = new_size;
+ }
+ }
+
+ // Look for a free block that starts exactly where this one ends
+ auto next_it = free_blocks.find(addr + size);
+ if (next_it != free_blocks.end()) {
+ // Absorb the following block
+ size_t next_size = next_it->second;
+ free_blocks.erase(next_it); // remove the following block
+ free_blocks[addr] = size + next_size;
+ }
+ }
+
void Allocator::info()
{
std::cout << "Used memory: " << this->used
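A quick trace of the intended allocator behavior under these changes (assuming an alignment of 1 so the offsets stay readable; real requests pass through getAlignedSize first):

    // Allocator a(runtime);   free_blocks = {}
    // auto x = a.alloc(64);   // no free block -> bump at peak: x = 0
    // auto y = a.alloc(32);   // y = 64, peak = 96
    // a.free(x, 64);          // free_blocks = {0: 64}
    // a.free(y, 32);          // {64: 32} coalesces with {0: 64} -> {0: 96}
    // auto z = a.alloc(96);   // first fit reuses the merged block: z = 0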
diff --git a/src/core/graph.cc b/src/core/graph.cc
index 3a90637..9eb7808 100644
--- a/src/core/graph.cc
+++ b/src/core/graph.cc
@@ -1,4 +1,6 @@
#include "core/graph.h"
+#include "operators/transpose.h"
+#include "operators/matmul.h"
#include <algorithm>
#include <numeric>
#include <queue>
@@ -106,7 +108,86 @@ namespace infini
// 1. Remove redundant operators (e.g. two adjacent transpose operators performing opposite permutations can both be deleted)
// 2. Fuse operators (e.g. matmul carries transA/transB attributes; if an input comes from a transpose that swaps the last two dimensions, the transpose can be folded into those attributes)
// =================================== Assignment ===================================
+ // Safer approach: collect the operator pairs to delete first, then remove them in one pass
+ std::vector<std::pair<Operator, Operator>> to_remove;
+
+ for (const auto& op : ops) {
+ if (op->getOpType() == OpType::Transpose) {
+ auto successors = op->getSuccessors();
+
+ for (const auto& succ : successors) {
+ if (succ->getOpType() == OpType::Transpose) {
+ // Check whether the two permutations cancel each other
+ auto transpose1 = dynamic_cast<TransposeObj *>(op.get());
+ auto transpose2 = dynamic_cast<TransposeObj *>(succ.get());
+
+ if (this->areInverseTransposes(transpose1, transpose2)) {
+ to_remove.emplace_back(op, succ);
+ break; // one match per transpose is enough
+ }
+ }
+ }
+ }
+ }
+
+ // Rewire and delete the collected pairs
+ for (const auto& pair : to_remove) {
+ auto op1 = pair.first;
+ auto op2 = pair.second;
+
+ // Reconnect the graph around the pair
+ this->reconnectGraph(op1, op2);
+
+ // Remove both operators
+ this->removeOperator(op1);
+ this->removeOperator(op2);
+ }
+
+ // =================================== Operator fusion ===================================
+ // Fold transpose operators into the matmul's transA/transB attributes
+ std::vector<std::pair<Operator, Operator>> to_merge;
+
+ for (const auto& op : ops) {
+ if (op && op->getOpType() == OpType::Transpose) {
+ auto successors = op->getSuccessors();
+
+ for (const auto& succ : successors) {
+ // Safety check: the successor must still be valid
+ if (!succ) continue;
+
+ if (succ->getOpType() == OpType::MatMul) {
+ // Check that the transpose swaps exactly the last two dimensions
+ auto transpose = dynamic_cast<TransposeObj *>(op.get());
+ auto matmul = dynamic_cast<MatmulObj *>(succ.get());
+
+ if (transpose && matmul && this->isLastTwoDimsSwap(transpose)) {
+ to_merge.emplace_back(op, succ);
+ break;
+ }
+ }
+ }
+ }
}
+
+ // Apply the collected fusions
+ for (const auto& pair : to_merge) {
+ auto transpose = pair.first;
+ auto matmul = pair.second;
+
+ // Safety check: both operators must still be valid
+ if (!transpose || !matmul) continue;
+
+ // Fold the transpose into the matmul
+ this->mergeTransposeToMatmul(transpose, matmul);
+
+ // Remove the now-redundant transpose
+ this->removeOperator(transpose);
+ }
+
+ // Drop tensors that are no longer referenced by any operator
+ this->cleanupUnusedTensors();
+
+}
Tensor GraphObj::getTensor(int fuid) const
{
@@ -152,7 +233,49 @@ namespace infini
// TODO: use the allocator to assign memory to the computation graph
// HINT: once the allocated pointer is available, call the tensor's setDataBlob function to bind the memory to the tensor
// =================================== Assignment ===================================
-
+
+ // Phase 1: collect every tensor's memory requirement and assign offsets
+ std::unordered_set<TensorObj *> allocated_tensors;
+ std::vector<std::pair<Tensor, size_t>> tensor_offsets;
+
+ // Walk all tensors and reserve an offset for each
+ for (auto &tensor : tensors)
+ {
+ if (tensor && allocated_tensors.find(tensor.get()) == allocated_tensors.end())
+ {
+ // Size in bytes required by this tensor
+ size_t tensor_size = tensor->getBytes();
+
+ // Reserve an offset inside the arena via the allocator
+ size_t offset = allocator.alloc(tensor_size);
+
+ // Remember the tensor-to-offset mapping
+ tensor_offsets.emplace_back(tensor, offset);
+ allocated_tensors.insert(tensor.get());
+ }
+ }
+
+ // Phase 2: fetch the real base pointer and bind memory to each tensor
+ void *base_ptr = allocator.getPtr();
+ if (base_ptr)
+ {
+ for (const auto &pair : tensor_offsets)
+ {
+ auto tensor = pair.first;
+ size_t offset = pair.second;
+
+ // Compute the actual address; char * gives byte-granular pointer arithmetic
+ void *tensor_ptr = static_cast<char *>(base_ptr) + offset;
+
+ // Wrap the raw pointer in a Blob
+ Blob blob = make_ref<BlobObj>(runtime, tensor_ptr);
+
+ // Bind the memory to the tensor
+ tensor->setDataBlob(blob);
+ }
+ }
+
allocator.info();
}
@@ -227,4 +350,215 @@ namespace infini
return true;
}
+ bool GraphObj::areInverseTransposes(const TransposeObj* transpose1, const TransposeObj* transpose2) {
+ if (!transpose1 || !transpose2) {
+ return false;
+ }
+
+ auto perm1 = transpose1->getPermute();
+ auto perm2 = transpose2->getPermute();
+
+ if (perm1.size() != perm2.size()) {
+ return false;
+ }
+
+ // perm2 inverts perm1 iff perm1[perm2[i]] == i for every i
+ for (size_t i = 0; i < perm1.size(); ++i) {
+ if (perm1[perm2[i]] != static_cast<int>(i)) {
+ return false;
+ }
+ }
+
+ return true;
+ }
+
+ bool GraphObj::areSameTransposes(const TransposeObj* transpose1, const TransposeObj* transpose2) {
+ if (!transpose1 || !transpose2) {
+ return false;
+ }
+
+ auto perm1 = transpose1->getPermute();
+ auto perm2 = transpose2->getPermute();
+
+ if (perm1.size() != perm2.size()) {
+ return false;
+ }
+
+ // Element-wise comparison of the two permutations
+ for (size_t i = 0; i < perm1.size(); ++i) {
+ if (perm1[i] != perm2[i]) {
+ return false;
+ }
+ }
+
+ return true;
+ }
+
+ bool GraphObj::isLastTwoDimsSwap(const TransposeObj* transpose) {
+ if (!transpose) {
+ return false;
+ }
+
+ auto perm = transpose->getPermute();
+ if (perm.size() < 2) {
+ return false;
+ }
+
+ // Only the last two dimensions may be swapped; every leading (batch)
+ // dimension must stay in place, otherwise folding the transpose into
+ // matmul's transA/transB would change the result
+ size_t n = perm.size();
+ for (size_t i = 0; i + 2 < n; ++i) {
+ if (perm[i] != static_cast<int>(i)) {
+ return false;
+ }
+ }
+ return (perm[n-2] == static_cast<int>(n-1) && perm[n-1] == static_cast<int>(n-2));
+ }
+
+ void GraphObj::mergeTransposeToMatmul(const Operator& transpose, const Operator& matmul) {
+ auto transpose_obj = dynamic_cast<TransposeObj *>(transpose.get());
+ auto matmul_obj = dynamic_cast<MatmulObj *>(matmul.get());
+
+ if (!transpose_obj || !matmul_obj) {
+ return;
+ }
+
+ // Input tensor of the transpose
+ auto transpose_inputs = transpose->getInputs();
+ if (transpose_inputs.empty()) {
+ return;
+ }
+
+ // Input tensors of the matmul
+ auto matmul_inputs = matmul->getInputs();
+ if (matmul_inputs.size() < 2) {
+ return;
+ }
+
+ // Determine which matmul input (A or B) the transpose feeds
+ auto transpose_output = transpose->getOutput();
+ if (!transpose_output) {
+ return;
+ }
+
+ bool isInputA = (transpose_output == matmul_inputs[0]);
+ bool isInputB = (transpose_output == matmul_inputs[1]);
+
+ if (isInputA) {
+ // The transpose feeds input A: toggle transA (toggling, rather than
+ // forcing true, stays correct if transA was already set)
+ matmul_obj->setTransA(!matmul_obj->getTransA());
+ // Rewire matmul's A input to the transpose's own input
+ matmul->replaceInput(matmul_inputs[0], transpose_inputs[0]);
+
+ // Update tensor connectivity
+ transpose_inputs[0]->addTarget(matmul);
+ transpose_output->removeTarget(matmul);
+ } else if (isInputB) {
+ // The transpose feeds input B: toggle transB
+ matmul_obj->setTransB(!matmul_obj->getTransB());
+ // Rewire matmul's B input to the transpose's own input
+ matmul->replaceInput(matmul_inputs[1], transpose_inputs[0]);
+
+ // Update tensor connectivity
+ transpose_inputs[0]->addTarget(matmul);
+ transpose_output->removeTarget(matmul);
+ }
+ }
+
+ void GraphObj::reconnectGraph(const Operator& op1, const Operator& op2) {
+ // Validate arguments
+ if (!op1 || !op2) {
+ return;
+ }
+
+ // Verify that op2 is a direct successor of op1
+ auto op1_successors = op1->getSuccessors();
+ bool areAdjacent = false;
+ for (const auto& succ : op1_successors) {
+ if (succ == op2) {
+ areAdjacent = true;
+ break;
+ }
+ }
+
+ if (!areAdjacent) {
+ // Non-adjacent operators would need different handling; bail out
+ return;
+ }
+
+ // Tensor rewiring: route op1's input directly to op2's consumers
+ auto op1_inputs = op1->getInputs();
+ auto op2_outputs = op2->getOutputs();
+ auto op2_successors = op2->getSuccessors();
+
+ if (!op1_inputs.empty() && !op2_outputs.empty()) {
+ auto input_tensor = op1_inputs[0]; // input tensor of op1
+ auto output_tensor = op2_outputs[0]; // output tensor of op2
+
+ // Redirect every consumer of op2's output to read op1's input instead
+ auto targets = output_tensor->getTargets();
+ for (const auto& target : targets) {
+ if (target && target != op1 && target != op2) {
+ target->replaceInput(output_tensor, input_tensor);
+ input_tensor->addTarget(target);
+ }
+ }
+ }
+
+ // Predecessors of op1, needed for relinking
+ auto op1_predecessors = op1->getPredecessors();
+
+ // New links: op1's predecessors -> op2's successors
+ for (const auto& pred : op1_predecessors) {
+ for (const auto& succ : op2_successors) {
+ // avoid self-loops
+ if (pred != succ) {
+ pred->addSuccessors(succ);
+ succ->addPredecessors(pred);
+ }
+ }
+ }
+
+ // Drop the old links
+ for (const auto& pred : op1_predecessors) {
+ pred->removeSuccessors(op1);
+ }
+
+ for (const auto& succ : op2_successors) {
+ succ->removePredecessors(op2);
+ }
+}
+
+void GraphObj::cleanupUnusedTensors() {
+ // Remove tensors no longer referenced by any remaining operator
+ auto it = tensors.begin();
+ while (it != tensors.end()) {
+ auto tensor = *it;
+ bool isUsed = false;
+
+ // Is this tensor an input or output of any operator?
+ for (const auto& op : ops) {
+ // check inputs
+ for (const auto& input : op->getInputs()) {
+ if (input == tensor) {
+ isUsed = true;
+ break;
+ }
+ }
+ if (isUsed) break;
+
+ // check outputs
+ for (const auto& output : op->getOutputs()) {
+ if (output == tensor) {
+ isUsed = true;
+ break;
+ }
+ }
+ if (isUsed) break;
+ }
+
+ // Not used by any operator: erase it
+ if (!isUsed) {
+ it = tensors.erase(it);
+ } else {
+ ++it;
+ }
+ }
+}
+
} // namespace infini
\ No newline at end of file
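A sketch of what the optimizer should achieve on a two-transpose chain (hypothetical test, assuming the GraphObj API with addTensor/addOp/optimize and a CPU runtime as used elsewhere in this codebase):

    // Graph g = make_ref<GraphObj>(NativeCpuRuntimeObj::getInstance());
    // auto t0 = g->addTensor({2, 3, 4}, DataType::Float32);
    // auto t1 = g->addOp<TransposeObj>(t0, nullptr, vector<int>{0, 2, 1})->getOutput();
    // auto t2 = g->addOp<TransposeObj>(t1, nullptr, vector<int>{0, 2, 1})->getOutput();
    // g->optimize();
    // // {0, 2, 1} is its own inverse, so both transposes should be removed
    // // and consumers of t2 rewired to read t0 directly.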
diff --git a/src/core/operator.cc b/src/core/operator.cc
index a70ca48..a856797 100644
--- a/src/core/operator.cc
+++ b/src/core/operator.cc
@@ -11,6 +11,7 @@ namespace infini
{
for (auto it = predecessors.begin(); it != predecessors.end();)
{
+ // lock() promotes the weak_ptr to a shared_ptr (empty if expired) for comparison
if (it->lock() == op)
it = predecessors.erase(it);
else
diff --git a/src/operators/concat.cc b/src/operators/concat.cc
index d196330..82567fe 100644
--- a/src/operators/concat.cc
+++ b/src/operators/concat.cc
@@ -10,6 +10,10 @@ ConcatObj::ConcatObj(GraphObj *graph, TensorVec inputs, Tensor output, int _dim)
}
optional<vector<Shape>> ConcatObj::inferShape(const TensorVec &inputs) {
+ if (inputs.empty()) {
+ return std::nullopt;
+ }
+
Shape dims = inputs[0]->getDims();
auto rank = inputs[0]->getRank();
@@ -17,7 +21,22 @@ optional<vector<Shape>> ConcatObj::inferShape(const TensorVec &inputs) {
// TODO: modify dims and return the correct post-concat shape
// REF: https://onnx.ai/onnx/operators/onnx__Concat.html#concat-13
// =================================== Assignment ===================================
+ // Pre-compute the size of the concatenation dimension
+ IT_ASSERT(dim >= 0 && dim < static_cast<int>(rank), "Dimension out of range");
+
+ int concat_size = dims[dim];
+ for (size_t i = 1; i < inputs.size(); i++) {
+ IT_ASSERT(inputs[i]->getRank() == rank, "All input tensors must have the same rank");
+ for (int j = 0; j < static_cast<int>(rank); j++) {
+ if (j != dim){
+ IT_ASSERT(dims[j] == inputs[i]->getDims()[j], "All input tensors must have the same shape except the concatenation dimension");
+ }
+ }
+ concat_size += inputs[i]->getDims()[dim];
+ }
+ dims[dim] = concat_size;
+
return {{dims}};
}
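Worked example of the rule implemented above: every dimension must match except the concat axis, which accumulates. With dim = 1:

    // inputs[0]: {2, 3, 4}
    // inputs[1]: {2, 5, 4}
    // inputs[2]: {2, 1, 4}
    // output:    {2, 3 + 5 + 1, 4} = {2, 9, 4}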
diff --git a/src/operators/matmul.cc b/src/operators/matmul.cc
index 7a16ca2..edb3056 100644
--- a/src/operators/matmul.cc
+++ b/src/operators/matmul.cc
@@ -1,4 +1,5 @@
#include "operators/matmul.h"
+#include "utils/operator_utils.h"
namespace infini
{
@@ -6,7 +7,7 @@ namespace infini
MatmulObj::MatmulObj(GraphObj *graph, Tensor A, Tensor B, Tensor C, bool transA,
bool transB)
: OperatorObj(OpType::MatMul, TensorVec{A, B}, {C}),
- transA(transA), transB(transB)
+ transA(transA), transB(transB), m(0), n(0), k(0)
{
IT_ASSERT(checkValid(graph));
}
@@ -27,7 +28,90 @@ namespace infini
// TODO: return the shape produced by the matmul
// REF: https://github.com/onnx/onnx/blob/main/docs/Operators.md#gemm
// =================================== Assignment ===================================
- return std::nullopt;
+
+ if (inputs.size() != 2) {
+ return std::nullopt;
+ }
+
+ auto A = inputs[0];
+ auto B = inputs[1];
+
+ if (!A || !B) {
+ return std::nullopt;
+ }
+
+ Shape shapeA = A->getDims();
+ Shape shapeB = B->getDims();
+
+ if (shapeA.size() < 2 || shapeB.size() < 2) {
+ return std::nullopt;
+ }
+
+ // Effective matrix dimensions of A and B (after applying transposes)
+ Shape effectiveA = shapeA;
+ Shape effectiveB = shapeB;
+
+ // A transpose here only affects the last two dimensions
+ if (transA && effectiveA.size() >= 2) {
+ size_t rank = effectiveA.size();
+ std::swap(effectiveA[rank-2], effectiveA[rank-1]);
+ }
+
+ if (transB && effectiveB.size() >= 2) {
+ size_t rank = effectiveB.size();
+ std::swap(effectiveB[rank-2], effectiveB[rank-1]);
+ }
+
+ // Ranks of the effective shapes
+ size_t rankA = effectiveA.size();
+ size_t rankB = effectiveB.size();
+
+ // The last two dimensions take part in the matrix product
+ int m = effectiveA[rankA-2]; // rows of A
+ int k_A = effectiveA[rankA-1]; // cols of A
+ int k_B = effectiveB[rankB-2]; // rows of B
+ int n = effectiveB[rankB-1]; // cols of B
+
+ // Inner dimensions must agree
+ if (k_A != k_B) {
+ return std::nullopt;
+ }
+
+ // Batch dimensions (everything except the last two)
+ Shape batchA(shapeA.begin(), shapeA.end()-2);
+ Shape batchB(shapeB.begin(), shapeB.end()-2);
+
+ // Broadcast the batch dimensions
+ Shape resultBatch;
+ try {
+ if (batchA.empty() && batchB.empty()) {
+ // neither input has batch dimensions
+ resultBatch = {};
+ } else if (batchA.empty()) {
+ // only B has batch dimensions
+ resultBatch = batchB;
+ } else if (batchB.empty()) {
+ // only A has batch dimensions
+ resultBatch = batchA;
+ } else {
+ // both have batch dimensions: broadcast them
+ resultBatch = infer_broadcast(batchA, batchB);
+ }
+ } catch (...) {
+ return std::nullopt;
+ }
+
+ // Output shape: batch dimensions followed by [m, n]
+ Shape outputShape = resultBatch;
+ outputShape.push_back(m);
+ outputShape.push_back(n);
+
+ // Cache the dimensions for later kernel computation
+ this->m = m;
+ this->n = n;
+ this->k = k_A;
+
+ return {{outputShape}};
}
} // namespace infini
\ No newline at end of file
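A worked shape inference under these rules, combining a transposed B with batch broadcasting:

    // A: {2, 3, 4}, B: {1, 5, 4}, transA = false, transB = true
    //   effective B (last two dims swapped): {1, 4, 5}
    //   m = 3, k = 4 (A's cols match B's rows), n = 5
    //   batch dims {2} and {1} broadcast to {2}
    //   output: {2, 3, 5}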
diff --git a/src/operators/transpose.cc b/src/operators/transpose.cc
index faab2b6..dce689d 100644
--- a/src/operators/transpose.cc
+++ b/src/operators/transpose.cc
@@ -9,6 +9,7 @@ namespace infini
auto rank = input->getRank();
if (permute.empty())
{
+ transposePermute.resize(rank);
for (size_t i = 0; i < rank; ++i)
{
transposePermute[i] = i;
@@ -34,7 +35,11 @@ namespace infini
// REF: https://onnx.ai/onnx/operators/onnx__Transpose.html#transpose-21
// =================================== Assignment ===================================
- return std::nullopt;
+ // Dimension i of the output comes from dimension permute[i] of the input
+ for (size_t i = 0; i < rank; ++i) {
+ output_dim[i] = input_dim[transposePermute[i]];
+ }
+
+ return vector<Shape>{output_dim};
}
std::string TransposeObj::toString() const
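Example of the permutation rule: output dimension i takes its size from input dimension permute[i].

    // input:  {2, 3, 4}, permute = {0, 2, 1}
    // output: {input[0], input[2], input[1]} = {2, 4, 3}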
diff --git a/src/operators/unary.cc b/src/operators/unary.cc
index 3daad36..43540c7 100644
--- a/src/operators/unary.cc
+++ b/src/operators/unary.cc
@@ -39,7 +39,15 @@ namespace infini
// TODO: return the shape produced by the clip
// REF: https://onnx.ai/onnx/operators/onnx__Clip.html#clip-13
// =================================== Assignment ===================================
- return std::nullopt;
+
+ // Validate the input count
+ IT_ASSERT(inputs.size() == 1, "Clip operation requires exactly one input tensor");
+
+ const auto input = inputs[0];
+
+ // Clip only clamps values; it does not change the tensor's shape,
+ // so the output shape is identical to the input shape
+ return {{input->getDims()}};
}
std::string ClipObj::toString() const
@@ -66,7 +74,8 @@ namespace infini
// REF_FILE: src/core/operator.cc
// REF: https://onnx.ai/onnx/operators/onnx__Cast.html#cast-21
// =================================== Assignment ===================================
- return {};
+ auto data_type = getOutputDataType();
+ return {data_type};
}
optional<vector<Shape>> CastObj::inferShape(const TensorVec &inputs)
@@ -75,7 +84,7 @@ namespace infini
// TODO: return the shape produced by the cast
// REF: https://onnx.ai/onnx/operators/onnx__Cast.html#cast-21
// =================================== Assignment ===================================
- return std::nullopt;
+ return {{inputs[0]->getDims()}};
}
std::string CastObj::toString() const
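Expected behavior of the two shape/type rules above (per the ONNX Clip and Cast specs; Int64 is an illustrative target type):

    // Clip: input {2, 3} -> output {2, 3}; only values are clamped to [min, max]
    // Cast: input {2, 3} of Float32, cast to Int64
    //       inferShape    -> {2, 3}
    //       inferDataType -> the configured target type (Int64 here)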
diff --git a/src/utils/operator_utils.cc b/src/utils/operator_utils.cc
index edbd2c8..8495222 100644
--- a/src/utils/operator_utils.cc
+++ b/src/utils/operator_utils.cc
@@ -10,7 +10,55 @@ Shape infer_broadcast(const Shape &A, const Shape &B) {
// REF: https://github.com/onnx/onnx/blob/main/docs/Broadcasting.md
// =================================== Assignment ===================================
- return {};
+ // An empty shape broadcasts to the other shape
+ if (A.empty()) return B;
+ if (B.empty()) return A;
+
+ int rank_A = A.size();
+ int rank_B = B.size();
+
+ // The result has the larger of the two ranks
+ int max_rank = std::max(rank_A, rank_B);
+
+ // Working copies that will be rank-aligned
+ Shape local_A = A;
+ Shape local_B = B;
+
+ // Left-pad the shorter shape with 1s so both have the same rank
+ if (rank_A < max_rank) {
+ local_A.insert(local_A.begin(), max_rank - rank_A, 1);
+ }
+ if (rank_B < max_rank) {
+ local_B.insert(local_B.begin(), max_rank - rank_B, 1);
+ }
+
+ // Result shape
+ Shape result(max_rank);
+
+ // Apply the broadcasting rule dimension by dimension
+ for (int i = 0; i < max_rank; i++) {
+ int dim_A = local_A[i];
+ int dim_B = local_B[i];
+
+ if (dim_A == dim_B) {
+ // equal sizes: keep as-is
+ result[i] = dim_A;
+ } else if (dim_A == 1) {
+ // A broadcasts along this dimension
+ result[i] = dim_B;
+ } else if (dim_B == 1) {
+ // B broadcasts along this dimension
+ result[i] = dim_A;
+ } else {
+ // incompatible sizes: broadcasting is impossible
+ IT_ASSERT(false,
+ "Cannot broadcast shapes: dimension " + std::to_string(i) +
+ " has incompatible sizes " + std::to_string(dim_A) +
+ " and " + std::to_string(dim_B));
+ }
+ }
+
+ return result;
}
int get_real_axis(const int &axis, const int &rank) {
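Two broadcasts worked through the rule above:

    // A = {2, 3, 1}, B = {4}
    //   right-align and pad B with 1s -> {1, 1, 4}
    //   per dimension: (2,1) -> 2, (3,1) -> 3, (1,4) -> 4
    //   result: {2, 3, 4}
    //
    // A = {2, 3}, B = {4, 3}
    //   dimension 0 has sizes 2 and 4, neither is 1 -> IT_ASSERT fires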
diff --git a/test/core/test_memory_allocation.cc b/test/core/test_memory_allocation.cc
new file mode 100644
index 0000000..8b13789
--- /dev/null
+++ b/test/core/test_memory_allocation.cc
@@ -0,0 +1 @@
+