Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
130 changes: 109 additions & 21 deletions .basedpyright/baseline.json
Original file line number Diff line number Diff line change
Expand Up @@ -889,14 +889,6 @@
"lineCount": 1
}
},
{
"code": "reportUnknownMemberType",
"range": {
"startColumn": 36,
"endColumn": 44,
"lineCount": 1
}
},
{
"code": "reportAny",
"range": {
Expand Down Expand Up @@ -2837,6 +2829,14 @@
"lineCount": 1
}
},
{
"code": "reportUnknownMemberType",
"range": {
"startColumn": 19,
"endColumn": 40,
"lineCount": 1
}
},
{
"code": "reportAny",
"range": {
Expand All @@ -2853,14 +2853,6 @@
"lineCount": 1
}
},
{
"code": "reportUnknownParameterType",
"range": {
"startColumn": 8,
"endColumn": 14,
"lineCount": 1
}
},
{
"code": "reportAny",
"range": {
Expand Down Expand Up @@ -2894,18 +2886,26 @@
}
},
{
"code": "reportUnknownVariableType",
"code": "reportAny",
"range": {
"startColumn": 15,
"endColumn": 41,
"startColumn": 55,
"endColumn": 59,
"lineCount": 1
}
},
{
"code": "reportAny",
"range": {
"startColumn": 26,
"endColumn": 30,
"startColumn": 55,
"endColumn": 59,
"lineCount": 1
}
},
{
"code": "reportAny",
"range": {
"startColumn": 55,
"endColumn": 59,
"lineCount": 1
}
},
Expand Down Expand Up @@ -5877,6 +5877,30 @@
"lineCount": 3
}
},
{
"code": "reportUnknownMemberType",
"range": {
"startColumn": 16,
"endColumn": 44,
"lineCount": 1
}
},
{
"code": "reportUnknownMemberType",
"range": {
"startColumn": 16,
"endColumn": 44,
"lineCount": 1
}
},
{
"code": "reportUnknownMemberType",
"range": {
"startColumn": 12,
"endColumn": 40,
"lineCount": 1
}
},
{
"code": "reportUnknownVariableType",
"range": {
Expand All @@ -5901,6 +5925,38 @@
"lineCount": 1
}
},
{
"code": "reportUnknownMemberType",
"range": {
"startColumn": 12,
"endColumn": 40,
"lineCount": 1
}
},
{
"code": "reportUnknownArgumentType",
"range": {
"startColumn": 58,
"endColumn": 64,
"lineCount": 1
}
},
{
"code": "reportUnknownMemberType",
"range": {
"startColumn": 12,
"endColumn": 40,
"lineCount": 1
}
},
{
"code": "reportUnknownArgumentType",
"range": {
"startColumn": 66,
"endColumn": 72,
"lineCount": 1
}
},
{
"code": "reportUnknownVariableType",
"range": {
Expand All @@ -5917,6 +5973,38 @@
"lineCount": 1
}
},
{
"code": "reportUnknownMemberType",
"range": {
"startColumn": 12,
"endColumn": 40,
"lineCount": 1
}
},
{
"code": "reportUnknownArgumentType",
"range": {
"startColumn": 67,
"endColumn": 73,
"lineCount": 1
}
},
{
"code": "reportUnknownMemberType",
"range": {
"startColumn": 12,
"endColumn": 40,
"lineCount": 1
}
},
{
"code": "reportUnknownArgumentType",
"range": {
"startColumn": 44,
"endColumn": 50,
"lineCount": 1
}
},
{
"code": "reportUnknownMemberType",
"range": {
Expand Down
1 change: 0 additions & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,6 @@ jobs:
build_py_project_in_conda_env
conda install graphviz

CI_SUPPORT_SPHINX_VERSION_SPECIFIER=">=4.0"
build_docs

downstream_tests:
Expand Down
1 change: 0 additions & 1 deletion .gitlab-ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,6 @@ Python 3 Conda:
Documentation:
script: |
curl -L -O https://gitlab.tiker.net/inducer/ci-support/raw/main/build-docs.sh
CI_SUPPORT_SPHINX_VERSION_SPECIFIER=">=4.0"
. ./build-docs.sh
tags:
- python3
Expand Down
12 changes: 6 additions & 6 deletions arraycontext/fake_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ def linspace(self,
num: int = 50,
*, endpoint: bool = True,
retstep: Literal[False] = False,
dtype: DTypeLike = None,
dtype: DTypeLike | None = None,
axis: int = 0
) -> Array: ...

Expand All @@ -197,7 +197,7 @@ def linspace(self,
num: int = 50,
*, endpoint: bool = True,
retstep: Literal[True],
dtype: DTypeLike = None,
dtype: DTypeLike | None = None,
axis: int = 0
) -> tuple[Array, NDArray[Any] | float] | Array: ...

Expand All @@ -207,7 +207,7 @@ def linspace(self,
num: int = 50,
*, endpoint: bool = True,
retstep: bool = False,
dtype: DTypeLike = None,
dtype: DTypeLike | None = None,
axis: int = 0
) -> tuple[Array, NDArray[Any] | float] | Array:
num = operator.index(num)
Expand Down Expand Up @@ -446,19 +446,19 @@ def where(self,
def sum(self,
a: ArrayOrContainer, /,
axis: int | tuple[int, ...] | None = None,
dtype: DTypeLike = None,
dtype: DTypeLike | None = None,
) -> Array: ...
@overload
def sum(self,
a: ScalarLike, /,
axis: int | tuple[int, ...] | None = None,
dtype: DTypeLike = None,
dtype: DTypeLike | None = None,
) -> ScalarLike: ...

def sum(self,
a: ArrayOrContainerOrScalar, /,
axis: int | tuple[int, ...] | None = None,
dtype: DTypeLike = None,
dtype: DTypeLike | None = None,
) -> ArrayOrScalar: ...

@overload
Expand Down
6 changes: 3 additions & 3 deletions arraycontext/impl/jax/fake_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,20 +209,20 @@ def rec_equal(x, y):
def sum(self,
a: ArrayOrContainer,
axis: int | tuple[int, ...] | None = None,
dtype: DTypeLike = None,
dtype: DTypeLike | None = None,
) -> Array: ...
@overload
def sum(self,
a: Scalar,
axis: int | tuple[int, ...] | None = None,
dtype: DTypeLike = None,
dtype: DTypeLike | None = None,
) -> Scalar: ...

@override
def sum(self,
a: ArrayOrContainerOrScalar,
axis: int | tuple[int, ...] | None = None,
dtype: DTypeLike = None,
dtype: DTypeLike | None = None,
) -> ArrayOrScalar:
return rec_map_reduce_array_container(
sum,
Expand Down
8 changes: 4 additions & 4 deletions arraycontext/impl/numpy/fake_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,20 +100,20 @@ def __getattr__(self, name: str):
def sum(self,
a: ArrayOrContainer,
axis: int | tuple[int, ...] | None = None,
dtype: DTypeLike = None,
dtype: DTypeLike | None = None,
) -> Array: ...
@overload
def sum(self,
a: Scalar,
axis: int | tuple[int, ...] | None = None,
dtype: DTypeLike = None,
dtype: DTypeLike | None = None,
) -> Scalar: ...

@override
def sum(self,
a: ArrayOrContainerOrScalar,
axis: int | tuple[int, ...] | None = None,
dtype: DTypeLike = None,
dtype: DTypeLike | None = None,
) -> ArrayOrScalar:
return rec_map_reduce_array_container(sum, partial(np.sum,
axis=axis,
Expand Down Expand Up @@ -273,7 +273,7 @@ def array_equal(self,

@override
def arange(self, *args, **kwargs):
return np.arange(*args, **kwargs)
return cast("Array", cast("object", np.arange(*args, **kwargs)))

@override
def linspace(self, *args, **kwargs):
Expand Down
6 changes: 3 additions & 3 deletions arraycontext/impl/pyopencl/fake_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,20 +345,20 @@ def inner(ary: ArrayOrScalar) -> ArrayOrScalar:
def sum(self,
a: ArrayOrContainer,
axis: int | tuple[int, ...] | None = None,
dtype: DTypeLike = None,
dtype: DTypeLike | None = None,
) -> Array: ...
@overload
def sum(self,
a: ScalarLike,
axis: int | tuple[int, ...] | None = None,
dtype: DTypeLike = None,
dtype: DTypeLike | None = None,
) -> ScalarLike: ...

@override
def sum(self,
a: ArrayOrContainerOrScalar,
axis: int | tuple[int, ...] | None = None,
dtype: DTypeLike = None,
dtype: DTypeLike | None = None,
) -> ArrayOrScalar:
if isinstance(axis, int):
axis = axis,
Expand Down
2 changes: 1 addition & 1 deletion arraycontext/impl/pyopencl/taggable_cl_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ def to_tagged_cl_array(ary: cla.Array,
def empty(
queue: cl.CommandQueue,
shape: tuple[int, ...] | int,
dtype: DTypeLike = float,
dtype: DTypeLike | None = None,
*, axes: tuple[Axis, ...] | None = None,
tags: frozenset[Tag] = _EMPTY_TAG_SET,
order: Literal["C"] | Literal["F"] = "C",
Expand Down
23 changes: 23 additions & 0 deletions arraycontext/impl/pytato/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -663,20 +663,43 @@ def _to_frozen(
# All name_hint_tags shared at least some common prefix.
function_name = f"frozen_{name_hint}" if name_hint else "frozen_result"

self._compile_trace_callback(
function_name, "frozen-expr", normalized_expr)

self._compile_trace_callback(
function_name, "post_transform_dag", transformed_dag)

self._dag_transform_cache[normalized_expr] = (
transformed_dag, function_name)

from arraycontext.loopy import _DEFAULT_LOOPY_OPTIONS
opts = _DEFAULT_LOOPY_OPTIONS
assert opts.return_dict

self._compile_trace_callback(
function_name, "pre_generate_loopy", transformed_dag)

pt_prg = pt.generate_loopy(transformed_dag,
options=opts,
function_name=function_name,
target=self.get_target()
).bind_to_context(self.context)

self._compile_trace_callback(
function_name, "post_generate_loopy", pt_prg)

self._compile_trace_callback(
function_name, "pre_transform_loopy_program", pt_prg)

pt_prg = pt_prg.with_transformed_translation_unit(
self.transform_loopy_program)

self._compile_trace_callback(
function_name, "post_transform_loopy_program", pt_prg)

self._compile_trace_callback(
function_name, "final", pt_prg)

self._freeze_prg_cache[normalized_expr] = pt_prg
else:
transformed_dag, function_name = (
Expand Down
Loading
Loading