Description
When running the provided training code on 8*4090, we observed significantly higher memory usage on one GPU compared to the others.

We also observed that at the beginning of training, all GPUs show relatively low memory usage, but memory usage on some GPUs grows noticeably higher after a few training batches. This imbalance impedes multi-batch training on 8*4090. Are there any known issues in the code where unexpected memory allocation could happen during training?
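
In case it helps narrow this down, logging per-rank memory after each batch could show which step triggers the extra allocation. Below is a minimal sketch, assuming the training code uses PyTorch (the `log_cuda_memory` helper and the loop variables `step`, `batch`, `model`, `optimizer` are placeholders, not part of the repo):

```python
import torch
import torch.distributed as dist

def log_cuda_memory(step: int) -> None:
    """Print current and peak allocated CUDA memory for this rank."""
    rank = dist.get_rank() if dist.is_initialized() else 0
    device = torch.cuda.current_device()
    allocated = torch.cuda.memory_allocated(device) / 1024**2   # MiB currently allocated
    peak = torch.cuda.max_memory_allocated(device) / 1024**2    # MiB peak since start/reset
    print(f"[rank {rank}] step {step}: allocated={allocated:.0f} MiB, peak={peak:.0f} MiB")

# Example placement inside the training loop:
# for step, batch in enumerate(dataloader):
#     loss = model(batch).mean()
#     loss.backward()
#     optimizer.step()
#     optimizer.zero_grad(set_to_none=True)
#     log_cuda_memory(step)
```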