Description
Required prerequisites
- I have read the documentation https://optree.readthedocs.io.
- I have searched the Issue Tracker to make sure this hasn't already been reported (comment there if it has).
What version of OpTree are you using?
0.17
System information
3.13.1 (main, Dec 6 2024, 20:13:21) [Clang 18.1.8 ] darwin
0.17.0
Problem description
I would expect that using PyTree with generics would let me annotate functions properly, so that my LSP and/or type checker works correctly with them.
Here is a very simple example:
```python
import torch
from optree import PyTree, tree_map


def example(data: PyTree[torch.Tensor]) -> PyTree[int]:
    def get_ndim(tensor: torch.Tensor) -> int:
        # Number of dimensions of a single leaf tensor.
        return len(tensor.shape)

    # Apply get_ndim to every leaf, preserving the container structure.
    return tree_map(get_ndim, data)


data = {"a": torch.zeros((3, 4))}
result = example(data)  # this call is flagged by the type checker
```

In the above code, with Pylance, Pyright, and basedpyright alike, the definition of `example` type-checks fine, but the final line `result = example(data)` yields the following diagnostic:
```text
Argument of type "dict[str, Tensor]" cannot be assigned to parameter "data" of type "PyTree[Tensor]" in function "example"
  "dict[str, Tensor]" is not assignable to "PyTree[Tensor]"  Pylance(reportArgumentType)
(variable) data: dict[str, Tensor]
```

The rule is documented at https://github.com/microsoft/pylance-release/blob/main/docs/diagnostics/reportArgumentType.md.
Improved annotations were discussed in the closed issue #6, which was resolved by PR #166.
The example in the documentation makes me believe my code should work, but it does not:
```python
>>> import torch
>>> TensorTree = PyTree[torch.Tensor]
>>> TensorTree
typing.Union[torch.Tensor,
             tuple[ForwardRef('PyTree[torch.Tensor]'), ...],
             list[ForwardRef('PyTree[torch.Tensor]')],
             dict[typing.Any, ForwardRef('PyTree[torch.Tensor]')],
             collections.deque[ForwardRef('PyTree[torch.Tensor]')],
             optree.typing.CustomTreeNode[ForwardRef('PyTree[torch.Tensor]')]]
```

My example creates a value of type `dict[str, torch.Tensor]`, which according to the output above should be accepted (I confirmed that I get the same output on my machine), yet all of these language servers reject it as an argument.
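As far as I can tell (this is my reading of the behavior, not something confirmed by the OpTree docs), the Union shown above exists only at runtime: `PyTree[...]` appears to be expanded by a `__class_getitem__` hook, and static checkers such as Pyright never execute that hook, so they can only reason about `PyTree` from its class declaration. The sketch below uses a hypothetical class `RuntimeOnlyTree` (not OpTree's actual code) to illustrate the mechanism with the standard library only:

```python
import typing
from typing import Any, Union


class RuntimeOnlyTree:
    """Hypothetical stand-in for a class like optree's PyTree (not its real code)."""

    def __class_getitem__(cls, item: Any) -> Any:
        # Runs only when the interpreter evaluates the subscription.
        # A static checker does not execute this method, so it cannot learn
        # that the subscription produces a Union; it only sees the class.
        return Union[item, list[Any], dict[Any, Any]]


# At runtime the subscription evaluates to a plain typing.Union ...
IntTree = RuntimeOnlyTree[int]
print(typing.get_origin(IntTree), typing.get_args(IntTree))

# ... but the class itself is unrelated to dict, which is all a static
# checker can see when it compares dict[str, ...] against the class.
print(isinstance({}, RuntimeOnlyTree))
```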
Reproducible example code
The Python snippets:
```python
import torch
from optree import PyTree, tree_map


def example(data: PyTree[torch.Tensor]) -> PyTree[int]:
    def get_ndim(tensor: torch.Tensor) -> int:
        return len(tensor.shape)

    return tree_map(get_ndim, data)


data = {"a": torch.zeros((3, 4))}
result = example(data)
```

Traceback
Expected behavior
I would expect `example(data)` to type-check without any diagnostics.
Additional context
No response