[BUG] Annotated Generics Not Behaving As Expected With Type Checking #251

@RyanSaxe

Description

Required prerequisites

What version of OpTree are you using?

0.17

System information

3.13.1 (main, Dec 6 2024, 20:13:21) [Clang 18.1.8 ] darwin
0.17.0

Problem description

I would expect using PyTree with generics to let me annotate functions properly, so that my LSP and/or type checker works nicely with them.

Here is a very simple example:

import torch
from optree import PyTree, tree_map


def example(data: PyTree[torch.Tensor]) -> PyTree[int]:

    def get_ndim(tensor: torch.Tensor) -> int:
        return len(tensor.shape)

    return tree_map(get_ndim, data)


data = {"a": torch.zeros((3, 4))}
result = example(data)

In the code above, Pylance, Pyright, and basedpyright all accept the definition of `example`, but the final line `result = example(data)` yields the following diagnostic:

Argument of type "dict[str, Tensor]" cannot be assigned to parameter "data" of type "PyTree[Tensor]" in function "example"
  "dict[str, Tensor]" is not assignable to "PyTree[Tensor]"  Pylance(reportArgumentType: https://github.com/microsoft/pylance-release/blob/main/docs/diagnostics/reportArgumentType.md)
(variable) data: dict[str, Tensor]

Improved annotations were discussed in the closed issue #6, which was resolved by PR #166.

Based on the documentation, the example it provides makes me believe my example should work, but it does not:

>>> import torch
>>> TensorTree = PyTree[torch.Tensor]
>>> TensorTree
typing.Union[torch.Tensor,
             tuple[ForwardRef('PyTree[torch.Tensor]'), ...],
             list[ForwardRef('PyTree[torch.Tensor]')],
             dict[typing.Any, ForwardRef('PyTree[torch.Tensor]')],
             collections.deque[ForwardRef('PyTree[torch.Tensor]')],
             optree.typing.CustomTreeNode[ForwardRef('PyTree[torch.Tensor]')]]
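To make the expansion concrete, here is a minimal stdlib-only sketch that checks at runtime whether a value fits the Union shown above. It is not part of OpTree: `matches_tree` and the `Tensor` class are hypothetical stand-ins (so the snippet runs without torch), and the `CustomTreeNode` member of the Union is deliberately not modeled. Under these assumptions, a `dict[str, Tensor]` value does match the `dict[typing.Any, ...]` member.

```python
from collections import deque
from typing import Any


class Tensor:
    """Hypothetical stand-in for torch.Tensor so the sketch runs without torch."""


def matches_tree(value: Any, leaf_type: type) -> bool:
    """Return True if `value` fits the Union expansion of PyTree[leaf_type] above.

    The CustomTreeNode member of the Union is deliberately not modeled.
    """
    if isinstance(value, leaf_type):
        return True
    if isinstance(value, (tuple, list, deque)):
        # tuple[...], list[...] and deque[...] members: every child must be a subtree.
        return all(matches_tree(child, leaf_type) for child in value)
    if isinstance(value, dict):
        # dict[Any, ...] member: keys are unconstrained, only values are checked.
        return all(matches_tree(child, leaf_type) for child in value.values())
    return False


print(matches_tree({"a": Tensor()}, Tensor))  # True
print(matches_tree({"a": 3}, Tensor))  # False
```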

My example creates a value of type dict[str, torch.Tensor], which should be valid according to the expansion above (I confirmed that I get the same output on my machine). Yet all the LSPs reject it as an input.

Reproducible example code

The Python snippet is identical to the one in the problem description above.

Traceback

Expected behavior

I would expect example(data) to produce no diagnostics.

Additional context

No response

Metadata

Labels

bug (Something isn't working), python (Something related to the Python source code), python-typing (Something related to Python typing)
