diff --git a/include/core/allocator.h b/include/core/allocator.h
index 002601d..7a93aad 100644
--- a/include/core/allocator.h
+++ b/include/core/allocator.h
@@ -26,6 +26,11 @@ namespace infini {
         // =================================== 作业 ===================================
         // TODO:可能需要设计一个数据结构来存储free block,以便于管理和合并
         // HINT: 可以使用一个 map 来存储 free block,key 为 block 的起始/结尾地址,value 为 block 的大小
+        // Free block management:
+        // - freeBlocksByAddr: map from start address to size, for allocation search
+        // - freeBlocksByEnd: map from end address to size, for merging adjacent blocks
+        std::map<size_t, size_t> freeBlocksByAddr; // key: start address, value: size
+        std::map<size_t, size_t> freeBlocksByEnd;  // key: end address, value: size
         // =================================== 作业 ===================================
 
     public:
diff --git a/src/core/allocator.cc b/src/core/allocator.cc
index ff593ae..e12eb47 100644
--- a/src/core/allocator.cc
+++ b/src/core/allocator.cc
@@ -31,6 +31,45 @@ namespace infini
 
         // =================================== 作业 ===================================
         // TODO: 设计一个算法来分配内存,返回起始地址偏移量
+        // Use First Fit algorithm: find the first free block that is large enough
+        for (auto it = freeBlocksByAddr.begin(); it != freeBlocksByAddr.end(); ++it)
+        {
+            size_t addr = it->first;
+            size_t blockSize = it->second;
+
+            if (blockSize >= size)
+            {
+                // Found a suitable block
+                // Remove this block from both maps
+                freeBlocksByAddr.erase(it);
+                freeBlocksByEnd.erase(addr + blockSize);
+
+                // If the block is larger than needed, add the remaining part back
+                if (blockSize > size)
+                {
+                    size_t newAddr = addr + size;
+                    size_t newSize = blockSize - size;
+                    freeBlocksByAddr[newAddr] = newSize;
+                    freeBlocksByEnd[newAddr + newSize] = newSize;
+                }
+
+                // Update memory usage statistics
+                used += size;
+                if (used > peak)
+                {
+                    peak = used;
+                }
+
+                return addr;
+            }
+        }
+
+        // No suitable free block found, allocate at the end
+        size_t addr = peak;
+        used += size;
+        peak += size;
+
+        return addr;
         // =================================== 作业 ===================================
 
         return 0;
@@ -43,6 +82,75 @@ namespace infini
 
         // =================================== 作业 ===================================
         // TODO: 设计一个算法来回收内存
+        // Update memory usage
+        used -= size;
+
+        size_t blockStart = addr;
+        size_t blockEnd = addr + size;
+        size_t blockSize = size;
+
+        // Special case: if freeing the block at the end, just reduce peak
+        if (blockEnd == peak)
+        {
+            // Check if we can merge with a previous free block at the end
+            auto prevIt = freeBlocksByEnd.find(blockStart);
+            if (prevIt != freeBlocksByEnd.end())
+            {
+                // Merge with the previous block and reduce peak further
+                size_t prevSize = prevIt->second;
+                size_t prevStart = blockStart - prevSize;
+
+                // Remove the previous block from both maps
+                freeBlocksByAddr.erase(prevStart);
+                freeBlocksByEnd.erase(blockStart);
+
+                // Reduce peak to the start of the merged block
+                peak = prevStart;
+            }
+            else
+            {
+                // Just reduce peak
+                peak = blockStart;
+            }
+            return;
+        }
+
+        // Try to merge with the previous adjacent free block
+        auto prevIt = freeBlocksByEnd.find(blockStart);
+        if (prevIt != freeBlocksByEnd.end())
+        {
+            // Found a previous adjacent block
+            size_t prevSize = prevIt->second;
+            size_t prevStart = blockStart - prevSize;
+
+            // Remove the previous block from both maps
+            freeBlocksByAddr.erase(prevStart);
+            freeBlocksByEnd.erase(blockStart);
+
+            // Merge: extend the current block backwards
+            blockStart = prevStart;
+            blockSize += prevSize;
+        }
+
+        // Try to merge with the next adjacent free block
+        auto nextIt = freeBlocksByAddr.find(blockEnd);
+        if (nextIt != freeBlocksByAddr.end())
+        {
+            // Found a next adjacent block
+            size_t nextSize = nextIt->second;
+
+            // Remove the next block from both maps
+            freeBlocksByAddr.erase(blockEnd);
+            freeBlocksByEnd.erase(blockEnd + nextSize);
+
+            // Merge: extend the current block forwards
+            blockSize += nextSize;
+            blockEnd += nextSize;
+        }
+
+        // Add the merged (or original) free block to both maps
+        freeBlocksByAddr[blockStart] = blockSize;
+        freeBlocksByEnd[blockEnd] = blockSize;
         // =================================== 作业 ===================================
     }
 
diff --git a/src/core/graph.cc b/src/core/graph.cc
index 3a90637..9b570ca 100644
--- a/src/core/graph.cc
+++ b/src/core/graph.cc
@@ -151,6 +151,95 @@ namespace infini
         // =================================== 作业 ===================================
         // TODO:利用 allocator 给计算图分配内存
         // HINT: 获取分配好的内存指针后,可以调用 tensor 的 setDataBlob 函数给 tensor 绑定内存
+
+        // Track tensor lifetime: when a tensor is last used (by which operator index)
+        std::unordered_map<TensorObj *, int> tensorLastUse;
+        std::unordered_map<TensorObj *, size_t> tensorAddress;
+
+        // Initialize: all tensors are used at least once initially (for outputs without targets)
+        for (auto &tensor : tensors)
+        {
+            tensorLastUse[tensor.get()] = -1;
+        }
+
+        // Allocate memory for input tensors (tensors without source) first
+        for (auto &tensor : tensors)
+        {
+            if (!tensor->getSource())
+            {
+                size_t offset = allocator.alloc(tensor->getBytes());
+                tensorAddress[tensor.get()] = offset;
+            }
+        }
+
+        // Calculate last use for each tensor based on the operators
+        for (size_t i = 0; i < ops.size(); ++i)
+        {
+            auto &op = ops[i];
+
+            // Check inputs - update their last use time
+            for (auto &input : op->getInputs())
+            {
+                if (input)
+                {
+                    tensorLastUse[input.get()] = i;
+                }
+            }
+        }
+
+        // For output tensors that have no targets, they should live until the end
+        for (auto &tensor : tensors)
+        {
+            if (tensor->getTargets().size() == 0 && tensor->getSource())
+            {
+                // This is a graph output, it should live until the end
+                tensorLastUse[tensor.get()] = ops.size();
+            }
+        }
+
+        // Process each operator in topological order
+        for (size_t i = 0; i < ops.size(); ++i)
+        {
+            auto &op = ops[i];
+
+            // Allocate memory for outputs
+            for (auto &output : op->getOutputs())
+            {
+                if (output && tensorAddress.find(output.get()) == tensorAddress.end())
+                {
+                    size_t offset = allocator.alloc(output->getBytes());
+                    tensorAddress[output.get()] = offset;
+                }
+            }
+
+            // Free inputs that are no longer needed after this operator
+            for (auto &input : op->getInputs())
+            {
+                if (input && tensorLastUse[input.get()] == (int)i)
+                {
+                    // This is the last use of this tensor
+                    if (tensorAddress.find(input.get()) != tensorAddress.end())
+                    {
+                        allocator.free(tensorAddress[input.get()], input->getBytes());
+                    }
+                }
+            }
+        }
+
+        // Get the actual memory pointer from allocator
+        void *basePtr = allocator.getPtr();
+
+        // Bind memory to each tensor
+        for (auto &tensor : tensors)
+        {
+            if (tensorAddress.find(tensor.get()) != tensorAddress.end())
+            {
+                size_t offset = tensorAddress[tensor.get()];
+                void *tensorPtr = reinterpret_cast<char *>(basePtr) + offset;
+                auto blob = make_ref<BlobObj>(runtime, tensorPtr);
+                tensor->setDataBlob(blob);
+            }
+        }
         // =================================== 作业 ===================================
 
         allocator.info();
diff --git a/src/operators/concat.cc b/src/operators/concat.cc
index d196330..d8a35e3 100644
--- a/src/operators/concat.cc
+++ b/src/operators/concat.cc
@@ -16,6 +16,30 @@ optional<vector<Shape>> ConcatObj::inferShape(const TensorVec &inputs) {
     // =================================== 作业 ===================================
     // TODO:修改 dims,返回正确的 concat 后的 shape
     // REF: https://onnx.ai/onnx/operators/onnx__Concat.html#concat-13
+
+    // All inputs should have the same shape except for the dimension being concatenated
+    // Sum up the sizes along the concatenation dimension
+    int concatDimSize = 0;
+    for (size_t i = 0; i < inputs.size(); ++i) {
+        auto inputDims = inputs[i]->getDims();
+
+        // Verify that all other dimensions match
+        if (inputDims.size() != dims.size()) {
+            return std::nullopt;
+        }
+
+        for (size_t j = 0; j < dims.size(); ++j) {
+            if ((int)j != dim && inputDims[j] != dims[j]) {
+                return std::nullopt;
+            }
+        }
+
+        // Accumulate the size of the concatenation dimension
+        concatDimSize += inputDims[dim];
+    }
+
+    // Update the output shape
+    dims[dim] = concatDimSize;
     // =================================== 作业 ===================================
 
     return {{dims}};
diff --git a/src/operators/transpose.cc b/src/operators/transpose.cc
index faab2b6..1017eb3 100644
--- a/src/operators/transpose.cc
+++ b/src/operators/transpose.cc
@@ -32,9 +32,15 @@ namespace infini
         // =================================== 作业 ===================================
         // TODO:修改 output_dim,返回正确的 transpose 后的 shape
         // REF: https://onnx.ai/onnx/operators/onnx__Transpose.html#transpose-21
+
+        // Apply the permutation to get the output shape
+        for (int i = 0; i < rank; ++i)
+        {
+            output_dim[i] = input_dim[transposePermute[i]];
+        }
         // =================================== 作业 ===================================
 
-        return std::nullopt;
+        return {{output_dim}};
     }
 
     std::string TransposeObj::toString() const
diff --git a/src/utils/operator_utils.cc b/src/utils/operator_utils.cc
index edbd2c8..a4cc27e 100644
--- a/src/utils/operator_utils.cc
+++ b/src/utils/operator_utils.cc
@@ -8,9 +8,47 @@ Shape infer_broadcast(const Shape &A, const Shape &B) {
     // =================================== 作业 ===================================
     // TODO:对 A 和 B 进行双向广播,返回广播后的形状。
    // REF: https://github.com/onnx/onnx/blob/main/docs/Broadcasting.md
+
+    // Broadcasting rules:
+    // 1. If two shapes have different ranks, prepend 1s to the shorter one
+    // 2. For each dimension, the output dimension is max(dim_A, dim_B)
+    // 3. Dimensions are compatible if they are equal or one of them is 1
+
+    size_t rankA = A.size();
+    size_t rankB = B.size();
+    size_t maxRank = std::max(rankA, rankB);
+
+    Shape result(maxRank);
+
+    // Iterate from the trailing dimensions
+    for (size_t i = 0; i < maxRank; ++i) {
+        int dimA = 1, dimB = 1;
+
+        // Get dimension from A (if exists)
+        if (i < rankA) {
+            dimA = A[rankA - 1 - i];
+        }
+
+        // Get dimension from B (if exists)
+        if (i < rankB) {
+            dimB = B[rankB - 1 - i];
+        }
+
+        // Check compatibility and compute output dimension
+        if (dimA == dimB) {
+            result[maxRank - 1 - i] = dimA;
+        } else if (dimA == 1) {
+            result[maxRank - 1 - i] = dimB;
+        } else if (dimB == 1) {
+            result[maxRank - 1 - i] = dimA;
+        } else {
+            // Incompatible dimensions
+            IT_ASSERT(false, "Incompatible broadcast dimensions");
+        }
+    }
     // =================================== 作业 ===================================
 
-    return {};
+    return result;
 }
 
 int get_real_axis(const int &axis, const int &rank) {