Thanks for open-sourcing your great work on HDiT!
In `def forward(x, theta, conj):`, the `ctx` argument seems to be missing. Strangely, the function runs fine without `ctx` on PyTorch 2.1, but after switching to PyTorch 1.12 the code no longer runs. What is the reason for this difference?
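
To make the question concrete, here is a minimal sketch of the two `torch.autograd.Function` styles as I understand them. `ScaleNew`, `ScaleLegacy`, and `grad_of` are hypothetical names for illustration, not code from this repository. My guess, which I have not confirmed against the repository, is that PyTorch 2.0+ accepts a `forward` without `ctx` when a `setup_context` static method is defined, while PyTorch 1.x always passes `ctx` as the first argument to `forward`.

```python
import torch


class ScaleNew(torch.autograd.Function):
    """Assumed PyTorch >= 2.0 style: forward has no ctx, state is saved in setup_context."""

    @staticmethod
    def forward(x, scale):
        return x * scale

    @staticmethod
    def setup_context(ctx, inputs, output):
        # inputs is the tuple of arguments that was passed to forward
        _, scale = inputs
        ctx.scale = scale

    @staticmethod
    def backward(ctx, grad_output):
        # One gradient per forward input; the non-tensor scale gets None
        return grad_output * ctx.scale, None


class ScaleLegacy(torch.autograd.Function):
    """Legacy style: ctx is the first argument of forward."""

    @staticmethod
    def forward(ctx, x, scale):
        ctx.scale = scale
        return x * scale

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output * ctx.scale, None


def grad_of(fn_cls, scale=3.0):
    # Hypothetical helper: gradient of sum(fn(x, scale)) with respect to x
    x = torch.ones(4, requires_grad=True)
    fn_cls.apply(x, scale).sum().backward()
    return x.grad
```

If this guess is right, calling `grad_of(ScaleLegacy)` should work on both versions, while `grad_of(ScaleNew)` would only work on PyTorch 2.x, which would match the behaviour I am seeing with `forward(x, theta, conj)`.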