Thank you for sharing your project.
However, I am having trouble running the training code:
Traceback (most recent call last):
  File "torchrun.py", line 325, in <module>
    net_manager.train()
  File "torchrun.py", line 145, in train
    outputs = model(x, *y)
  File "D:\download\Anaconda3\envs\torch_env\lib\site-packages\torch\nn\modules\module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "D:\Study\3D\ExampleCode\2021\PRNet-PyTorch-master\torchmodel.py", line 63, in forward
    x = self.encoder(x)
  File "D:\download\Anaconda3\envs\torch_env\lib\site-packages\torch\nn\modules\module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "D:\download\Anaconda3\envs\torch_env\lib\site-packages\torch\nn\modules\container.py", line 100, in forward
    input = module(input)
  File "D:\download\Anaconda3\envs\torch_env\lib\site-packages\torch\nn\modules\module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "D:\Study\3D\ExampleCode\2021\PRNet-PyTorch-master\torchmodule.py", line 118, in forward
    assert (s.shape == out.shape)
AssertionError
The two shapes compared in the failing assertion are:
s.shape: torch.Size([15, 32, 128, 128])
out.shape: torch.Size([15, 32, 129, 129])
I did not change any code.
How can I solve it?
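The mismatch is only one pixel per spatial axis (128 vs 129), which is the typical symptom of a shortcut branch and a main branch whose strided convolutions use different padding. The sketch below reproduces that arithmetic with the `torch.nn.Conv2d` output-size formula from the PyTorch docs, plus the padding TensorFlow's `'SAME'` mode would apply. It does not read the repository's code: the 256x256 input, the 1x1 stride-2 shortcut, and the 4x4 kernel with padding 1 or 2 are all hypothetical values chosen for illustration, and the helper names `conv2d_out` and `tf_same_padding` are made up for this example.

```python
def conv2d_out(size, kernel, stride=1, padding=0, dilation=1):
    """Output size along one axis for torch.nn.Conv2d (formula from the PyTorch docs)."""
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


def tf_same_padding(size, kernel, stride):
    """Total and (before, after) padding that TensorFlow 'SAME' would use along one axis."""
    out = -(-size // stride)  # ceil(size / stride)
    total = max((out - 1) * stride + kernel - size, 0)
    before = total // 2
    return total, (before, total - before)


if __name__ == "__main__":
    # Hypothetical residual block on a 256x256 map, downsampling with stride 2.
    # Assumed shortcut: 1x1 conv, stride 2, no padding.
    shortcut = conv2d_out(256, kernel=1, stride=2)
    # Assumed main branch: 4x4 conv, stride 2; only the padding differs.
    main_pad1 = conv2d_out(256, kernel=4, stride=2, padding=1)
    main_pad2 = conv2d_out(256, kernel=4, stride=2, padding=2)
    print(shortcut, main_pad1, main_pad2)  # 128 128 129
    # TF 'SAME' padding for the same layer, and for a 3x3 kernel (asymmetric).
    print(tf_same_padding(256, kernel=4, stride=2))  # (2, (1, 1))
    print(tf_same_padding(256, kernel=3, stride=2))  # (1, (0, 1))
```

Under these assumptions, padding 2 on the main branch gives exactly the 129 vs 128 mismatch reported above, while padding 1 lines up with the shortcut. When a port from TensorFlow needs asymmetric `'SAME'` padding (as in the 3x3 case), one common approach is `torch.nn.functional.pad` before a convolution with `padding=0`. It may also be worth confirming that the input images really are 256x256, since a different input size changes every value above.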