diff --git a/.basedpyright/baseline.json b/.basedpyright/baseline.json index b2469fe4..c01febbe 100644 --- a/.basedpyright/baseline.json +++ b/.basedpyright/baseline.json @@ -889,14 +889,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 36, - "endColumn": 44, - "lineCount": 1 - } - }, { "code": "reportAny", "range": { @@ -2837,6 +2829,14 @@ "lineCount": 1 } }, + { + "code": "reportUnknownMemberType", + "range": { + "startColumn": 19, + "endColumn": 40, + "lineCount": 1 + } + }, { "code": "reportAny", "range": { @@ -2853,14 +2853,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 8, - "endColumn": 14, - "lineCount": 1 - } - }, { "code": "reportAny", "range": { @@ -2894,18 +2886,26 @@ } }, { - "code": "reportUnknownVariableType", + "code": "reportAny", "range": { - "startColumn": 15, - "endColumn": 41, + "startColumn": 55, + "endColumn": 59, "lineCount": 1 } }, { "code": "reportAny", "range": { - "startColumn": 26, - "endColumn": 30, + "startColumn": 55, + "endColumn": 59, + "lineCount": 1 + } + }, + { + "code": "reportAny", + "range": { + "startColumn": 55, + "endColumn": 59, "lineCount": 1 } }, @@ -5877,6 +5877,30 @@ "lineCount": 3 } }, + { + "code": "reportUnknownMemberType", + "range": { + "startColumn": 16, + "endColumn": 44, + "lineCount": 1 + } + }, + { + "code": "reportUnknownMemberType", + "range": { + "startColumn": 16, + "endColumn": 44, + "lineCount": 1 + } + }, + { + "code": "reportUnknownMemberType", + "range": { + "startColumn": 12, + "endColumn": 40, + "lineCount": 1 + } + }, { "code": "reportUnknownVariableType", "range": { @@ -5901,6 +5925,38 @@ "lineCount": 1 } }, + { + "code": "reportUnknownMemberType", + "range": { + "startColumn": 12, + "endColumn": 40, + "lineCount": 1 + } + }, + { + "code": "reportUnknownArgumentType", + "range": { + "startColumn": 58, + "endColumn": 64, + "lineCount": 1 + } + }, + { + "code": "reportUnknownMemberType", + "range": { + "startColumn": 12, + "endColumn": 40, + "lineCount": 1 + } + }, + { + "code": "reportUnknownArgumentType", + "range": { + "startColumn": 66, + "endColumn": 72, + "lineCount": 1 + } + }, { "code": "reportUnknownVariableType", "range": { @@ -5917,6 +5973,38 @@ "lineCount": 1 } }, + { + "code": "reportUnknownMemberType", + "range": { + "startColumn": 12, + "endColumn": 40, + "lineCount": 1 + } + }, + { + "code": "reportUnknownArgumentType", + "range": { + "startColumn": 67, + "endColumn": 73, + "lineCount": 1 + } + }, + { + "code": "reportUnknownMemberType", + "range": { + "startColumn": 12, + "endColumn": 40, + "lineCount": 1 + } + }, + { + "code": "reportUnknownArgumentType", + "range": { + "startColumn": 44, + "endColumn": 50, + "lineCount": 1 + } + }, { "code": "reportUnknownMemberType", "range": { diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 0fc48780..b31ccc9a 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -108,7 +108,6 @@ jobs: build_py_project_in_conda_env conda install graphviz - CI_SUPPORT_SPHINX_VERSION_SPECIFIER=">=4.0" build_docs downstream_tests: diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 0f0947a0..e6384d29 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -82,7 +82,6 @@ Python 3 Conda: Documentation: script: | curl -L -O https://gitlab.tiker.net/inducer/ci-support/raw/main/build-docs.sh - CI_SUPPORT_SPHINX_VERSION_SPECIFIER=">=4.0" . ./build-docs.sh tags: - python3 diff --git a/arraycontext/fake_numpy.py b/arraycontext/fake_numpy.py index 7074fb33..54d3e9f7 100644 --- a/arraycontext/fake_numpy.py +++ b/arraycontext/fake_numpy.py @@ -186,7 +186,7 @@ def linspace(self, num: int = 50, *, endpoint: bool = True, retstep: Literal[False] = False, - dtype: DTypeLike = None, + dtype: DTypeLike | None = None, axis: int = 0 ) -> Array: ... @@ -197,7 +197,7 @@ def linspace(self, num: int = 50, *, endpoint: bool = True, retstep: Literal[True], - dtype: DTypeLike = None, + dtype: DTypeLike | None = None, axis: int = 0 ) -> tuple[Array, NDArray[Any] | float] | Array: ... @@ -207,7 +207,7 @@ def linspace(self, num: int = 50, *, endpoint: bool = True, retstep: bool = False, - dtype: DTypeLike = None, + dtype: DTypeLike | None = None, axis: int = 0 ) -> tuple[Array, NDArray[Any] | float] | Array: num = operator.index(num) @@ -446,19 +446,19 @@ def where(self, def sum(self, a: ArrayOrContainer, /, axis: int | tuple[int, ...] | None = None, - dtype: DTypeLike = None, + dtype: DTypeLike | None = None, ) -> Array: ... @overload def sum(self, a: ScalarLike, /, axis: int | tuple[int, ...] | None = None, - dtype: DTypeLike = None, + dtype: DTypeLike | None = None, ) -> ScalarLike: ... def sum(self, a: ArrayOrContainerOrScalar, /, axis: int | tuple[int, ...] | None = None, - dtype: DTypeLike = None, + dtype: DTypeLike | None = None, ) -> ArrayOrScalar: ... @overload diff --git a/arraycontext/impl/jax/fake_numpy.py b/arraycontext/impl/jax/fake_numpy.py index af3ca5c9..8688b4cd 100644 --- a/arraycontext/impl/jax/fake_numpy.py +++ b/arraycontext/impl/jax/fake_numpy.py @@ -209,20 +209,20 @@ def rec_equal(x, y): def sum(self, a: ArrayOrContainer, axis: int | tuple[int, ...] | None = None, - dtype: DTypeLike = None, + dtype: DTypeLike | None = None, ) -> Array: ... @overload def sum(self, a: Scalar, axis: int | tuple[int, ...] | None = None, - dtype: DTypeLike = None, + dtype: DTypeLike | None = None, ) -> Scalar: ... @override def sum(self, a: ArrayOrContainerOrScalar, axis: int | tuple[int, ...] | None = None, - dtype: DTypeLike = None, + dtype: DTypeLike | None = None, ) -> ArrayOrScalar: return rec_map_reduce_array_container( sum, diff --git a/arraycontext/impl/numpy/fake_numpy.py b/arraycontext/impl/numpy/fake_numpy.py index f923a4d6..dc6501bf 100644 --- a/arraycontext/impl/numpy/fake_numpy.py +++ b/arraycontext/impl/numpy/fake_numpy.py @@ -100,20 +100,20 @@ def __getattr__(self, name: str): def sum(self, a: ArrayOrContainer, axis: int | tuple[int, ...] | None = None, - dtype: DTypeLike = None, + dtype: DTypeLike | None = None, ) -> Array: ... @overload def sum(self, a: Scalar, axis: int | tuple[int, ...] | None = None, - dtype: DTypeLike = None, + dtype: DTypeLike | None = None, ) -> Scalar: ... @override def sum(self, a: ArrayOrContainerOrScalar, axis: int | tuple[int, ...] | None = None, - dtype: DTypeLike = None, + dtype: DTypeLike | None = None, ) -> ArrayOrScalar: return rec_map_reduce_array_container(sum, partial(np.sum, axis=axis, @@ -273,7 +273,7 @@ def array_equal(self, @override def arange(self, *args, **kwargs): - return np.arange(*args, **kwargs) + return cast("Array", cast("object", np.arange(*args, **kwargs))) @override def linspace(self, *args, **kwargs): diff --git a/arraycontext/impl/pyopencl/fake_numpy.py b/arraycontext/impl/pyopencl/fake_numpy.py index 2f063984..62012bf3 100644 --- a/arraycontext/impl/pyopencl/fake_numpy.py +++ b/arraycontext/impl/pyopencl/fake_numpy.py @@ -345,20 +345,20 @@ def inner(ary: ArrayOrScalar) -> ArrayOrScalar: def sum(self, a: ArrayOrContainer, axis: int | tuple[int, ...] | None = None, - dtype: DTypeLike = None, + dtype: DTypeLike | None = None, ) -> Array: ... @overload def sum(self, a: ScalarLike, axis: int | tuple[int, ...] | None = None, - dtype: DTypeLike = None, + dtype: DTypeLike | None = None, ) -> ScalarLike: ... @override def sum(self, a: ArrayOrContainerOrScalar, axis: int | tuple[int, ...] | None = None, - dtype: DTypeLike = None, + dtype: DTypeLike | None = None, ) -> ArrayOrScalar: if isinstance(axis, int): axis = axis, diff --git a/arraycontext/impl/pyopencl/taggable_cl_array.py b/arraycontext/impl/pyopencl/taggable_cl_array.py index 6402ba66..457fe13c 100644 --- a/arraycontext/impl/pyopencl/taggable_cl_array.py +++ b/arraycontext/impl/pyopencl/taggable_cl_array.py @@ -185,7 +185,7 @@ def to_tagged_cl_array(ary: cla.Array, def empty( queue: cl.CommandQueue, shape: tuple[int, ...] | int, - dtype: DTypeLike = float, + dtype: DTypeLike | None = None, *, axes: tuple[Axis, ...] | None = None, tags: frozenset[Tag] = _EMPTY_TAG_SET, order: Literal["C"] | Literal["F"] = "C", diff --git a/arraycontext/impl/pytato/__init__.py b/arraycontext/impl/pytato/__init__.py index 58d18b7b..6e883f55 100644 --- a/arraycontext/impl/pytato/__init__.py +++ b/arraycontext/impl/pytato/__init__.py @@ -663,6 +663,12 @@ def _to_frozen( # All name_hint_tags shared at least some common prefix. function_name = f"frozen_{name_hint}" if name_hint else "frozen_result" + self._compile_trace_callback( + function_name, "frozen-expr", normalized_expr) + + self._compile_trace_callback( + function_name, "post_transform_dag", transformed_dag) + self._dag_transform_cache[normalized_expr] = ( transformed_dag, function_name) @@ -670,13 +676,30 @@ def _to_frozen( opts = _DEFAULT_LOOPY_OPTIONS assert opts.return_dict + self._compile_trace_callback( + function_name, "pre_generate_loopy", transformed_dag) + pt_prg = pt.generate_loopy(transformed_dag, options=opts, function_name=function_name, target=self.get_target() ).bind_to_context(self.context) + + self._compile_trace_callback( + function_name, "post_generate_loopy", pt_prg) + + self._compile_trace_callback( + function_name, "pre_transform_loopy_program", pt_prg) + pt_prg = pt_prg.with_transformed_translation_unit( self.transform_loopy_program) + + self._compile_trace_callback( + function_name, "post_transform_loopy_program", pt_prg) + + self._compile_trace_callback( + function_name, "final", pt_prg) + self._freeze_prg_cache[normalized_expr] = pt_prg else: transformed_dag, function_name = ( diff --git a/arraycontext/impl/pytato/fake_numpy.py b/arraycontext/impl/pytato/fake_numpy.py index c96aaf62..46786cd9 100644 --- a/arraycontext/impl/pytato/fake_numpy.py +++ b/arraycontext/impl/pytato/fake_numpy.py @@ -234,20 +234,20 @@ def rec_equal( def sum(self, a: ArrayOrContainer, axis: int | tuple[int, ...] | None = None, - dtype: DTypeLike = None, + dtype: DTypeLike | None = None, ) -> Array: ... @overload def sum(self, a: Scalar, axis: int | tuple[int, ...] | None = None, - dtype: DTypeLike = None, + dtype: DTypeLike | None = None, ) -> Scalar: ... @override def sum(self, a: ArrayOrContainerOrScalar, axis: int | tuple[int, ...] | None = None, - dtype: DTypeLike = None, + dtype: DTypeLike | None = None, ) -> ArrayOrScalar: def _pt_sum(ary): if dtype not in [ary.dtype, None]: