Description of the bug
Several model forward() methods in graph_weather assume that input tensors have the shape [batch_size, num_nodes, num_features], but this assumption is never explicitly validated.
When an input with an incorrect shape is passed, the code either raises a cryptic PyTorch runtime error or fails later during execution. This makes it hard for users to identify the root cause of the problem.
The result is misleading error behavior instead of failing fast at the API boundary.
To Reproduce
Steps to reproduce the behavior:
- Instantiate a model from graph_weather.models.
- Pass an input tensor with an invalid shape (e.g. missing the batch dimension):
import torch
x = torch.randn(100, 64)  # invalid shape: [nodes, features], no batch dimension
model(x)  # `model` is the instance created in the first step
- Observe that the resulting error is unclear, or is raised deep inside the computation graph rather than at input validation.
Expected behavior
The model should explicitly validate the input tensor shape and fail fast with a clear, informative error message, such as:
ValueError: Expected input shape [batch, nodes, features], got torch.Size([100, 64])
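A minimal sketch of such a check, assuming a shared helper is acceptable. `validate_input_shape` is a hypothetical name, not an existing graph_weather function, and the optional `num_features` check on the last dimension is an extra assumption beyond what this issue asks for. The helper reads only `x.shape`, so it needs no imports and, for a `torch.Tensor`, produces the message format shown above.

```python
def validate_input_shape(x, num_features=None):
    """Fail fast unless ``x`` has shape [batch, nodes, features].

    Only ``x.shape`` is inspected, so this works with torch.Tensor
    (whose ``.shape`` is a torch.Size) as well as NumPy arrays.
    ``num_features`` is a hypothetical, optional check on the last
    dimension; pass None to skip it.
    """
    shape = tuple(x.shape)
    if len(shape) != 3:
        # Wrong rank, e.g. a missing batch dimension: report the raw shape.
        raise ValueError(
            f"Expected input shape [batch, nodes, features], got {x.shape}"
        )
    if num_features is not None and shape[2] != num_features:
        # Right rank, but the feature dimension does not match the model.
        raise ValueError(
            f"Expected {num_features} features in the last dimension, "
            f"got {shape[2]} (input shape {x.shape})"
        )
    return x
```

Each affected `forward()` would call it before any graph operations, e.g. `x = validate_input_shape(x)`. With `x = torch.randn(100, 64)` this raises `ValueError: Expected input shape [batch, nodes, features], got torch.Size([100, 64])`. A unit test can then check that the `ValueError` is raised for a 2-D input and not for a correctly shaped 3-D one.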
Additional context
Adding lightweight input validation and corresponding unit tests would improve the developer experience without changing model behavior or measurably affecting performance. This is particularly helpful for new contributors and for users experimenting with the library.