From 1195d8cc9bfca9707e9df3e8a74db43dd4ae97e4 Mon Sep 17 00:00:00 2001 From: 0xhhh321321 <51537937+0xhhh321321@users.noreply.github.com> Date: Sun, 25 Jan 2026 00:40:30 +0800 Subject: [PATCH 1/6] Fix polars codegen for constants and empty over_null mask --- expr_codegen/polars/code.py | 10 ++++++++-- expr_codegen/polars/printer.py | 12 +++++++++++- 2 files changed, 19 insertions(+), 3 deletions(-) diff --git a/expr_codegen/polars/code.py b/expr_codegen/polars/code.py index 1eb2cbc..65ebf63 100644 --- a/expr_codegen/polars/code.py +++ b/expr_codegen/polars/code.py @@ -99,9 +99,15 @@ def codegen(exprs_ldl: ListDictList, exprs_src, syms_dst, else: _sym = ','.join(_sym) if args.over_null == 'partition_by': - func_code.append(f"{va}=({s2}).over({_sym}, _ASSET_, order_by=_DATE_),") + if _sym: + func_code.append(f"{va}=({s2}).over({_sym}, _ASSET_, order_by=_DATE_),") + else: + func_code.append(f"{va}=({s2}).over(_ASSET_, order_by=_DATE_),") elif args.over_null == 'order_by': - func_code.append(f"{va}=({s2}).over(_ASSET_, order_by=[{_sym}, _DATE_]),") + if _sym: + func_code.append(f"{va}=({s2}).over(_ASSET_, order_by=[{_sym}, _DATE_]),") + else: + func_code.append(f"{va}=({s2}).over(_ASSET_, order_by=_DATE_),") else: func_code.append(f"{va}=({s2}).over(_ASSET_, order_by=_DATE_),") elif k[0] == CS: diff --git a/expr_codegen/polars/printer.py b/expr_codegen/polars/printer.py index 551514a..573af08 100644 --- a/expr_codegen/polars/printer.py +++ b/expr_codegen/polars/printer.py @@ -1,4 +1,4 @@ -from sympy import Basic, Function, StrPrinter +from sympy import Basic, Float, Function, Integer, Rational, StrPrinter from sympy.printing.precedence import precedence, PRECEDENCE @@ -56,6 +56,16 @@ def _print(self, expr, **kwargs) -> str: def _print_Symbol(self, expr): return expr.name + def _print_Integer(self, expr: Integer) -> str: + return f"pl.lit({int(expr)})" + + def _print_Float(self, expr: Float) -> str: + return f"pl.lit({float(expr)!r})" + + def _print_Rational(self, expr: Rational) -> str: + p, q = int(expr.p), int(expr.q) + return f"pl.lit({p}/{q})" + def _print_Equality(self, expr): PREC = precedence(expr) return "%s == %s" % (self.parenthesize(expr.args[0], PREC), self.parenthesize(expr.args[1], PREC)) From 14bcfe70042440cd0d51d7dc05bbbef9bde0d579 Mon Sep 17 00:00:00 2001 From: 0xhhh321321 <51537937+0xhhh321321@users.noreply.github.com> Date: Sun, 25 Jan 2026 01:19:33 +0800 Subject: [PATCH 2/6] Fix ts_* window args printing --- expr_codegen/polars/printer.py | 37 ++++++++++++++++++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/expr_codegen/polars/printer.py b/expr_codegen/polars/printer.py index 573af08..5688cd2 100644 --- a/expr_codegen/polars/printer.py +++ b/expr_codegen/polars/printer.py @@ -66,6 +66,43 @@ def _print_Rational(self, expr: Rational) -> str: p, q = int(expr.p), int(expr.q) return f"pl.lit({p}/{q})" + def _print_Function(self, expr) -> str: + func_name = expr.func.__name__ + if func_name.startswith("ts_"): + # For ts_* functions, numeric trailing args are usually window sizes / params + # that must be plain Python scalars (not `pl.lit(...)`). + args = list(expr.args) + + last_non_num_idx = -1 + for i, arg in enumerate(args): + if not isinstance(arg, (Integer, Float, Rational)): + last_non_num_idx = i + + numeric_start = 1 if last_non_num_idx < 0 else last_non_num_idx + 1 + + def _print_ts_param(arg) -> str: + if isinstance(arg, Integer): + return str(int(arg)) + if isinstance(arg, Float): + return repr(float(arg)) + if isinstance(arg, Rational): + p, q = int(arg.p), int(arg.q) + if q == 1: + return str(p) + return repr(p / q) + return self._print(arg) + + printed_args: list[str] = [] + for i, arg in enumerate(args): + if i >= numeric_start and isinstance(arg, (Integer, Float, Rational)): + printed_args.append(_print_ts_param(arg)) + else: + printed_args.append(self._print(arg)) + + return f"{func_name}({','.join(printed_args)})" + + return super()._print_Function(expr) + def _print_Equality(self, expr): PREC = precedence(expr) return "%s == %s" % (self.parenthesize(expr.args[0], PREC), self.parenthesize(expr.args[1], PREC)) From 7699286d16205111e61d56e8df332cd5e68fd6d7 Mon Sep 17 00:00:00 2001 From: 0xhhh321321 <51537937+0xhhh321321@users.noreply.github.com> Date: Sun, 25 Jan 2026 01:43:15 +0800 Subject: [PATCH 3/6] Polars printer: only lift numeric series args --- expr_codegen/polars/printer.py | 29 ++++++++++++++--------------- 1 file changed, 14 insertions(+), 15 deletions(-) diff --git a/expr_codegen/polars/printer.py b/expr_codegen/polars/printer.py index 5688cd2..2e41a20 100644 --- a/expr_codegen/polars/printer.py +++ b/expr_codegen/polars/printer.py @@ -56,21 +56,15 @@ def _print(self, expr, **kwargs) -> str: def _print_Symbol(self, expr): return expr.name - def _print_Integer(self, expr: Integer) -> str: - return f"pl.lit({int(expr)})" - - def _print_Float(self, expr: Float) -> str: - return f"pl.lit({float(expr)!r})" - - def _print_Rational(self, expr: Rational) -> str: - p, q = int(expr.p), int(expr.q) - return f"pl.lit({p}/{q})" - def _print_Function(self, expr) -> str: func_name = expr.func.__name__ if func_name.startswith("ts_"): - # For ts_* functions, numeric trailing args are usually window sizes / params - # that must be plain Python scalars (not `pl.lit(...)`). + # For ts_* functions: + # - The "series" args must be polars expressions (so numeric constants + # there should be wrapped as `pl.lit(...)`). + # - The trailing numeric params (window sizes, etc.) must remain plain + # Python scalars (ints/floats), otherwise downstream libs may do + # `min_samples or ...` and trigger Expr.__bool__ errors. args = list(expr.args) last_non_num_idx = -1 @@ -80,7 +74,7 @@ def _print_Function(self, expr) -> str: numeric_start = 1 if last_non_num_idx < 0 else last_non_num_idx + 1 - def _print_ts_param(arg) -> str: + def _print_scalar(arg) -> str: if isinstance(arg, Integer): return str(int(arg)) if isinstance(arg, Float): @@ -94,8 +88,13 @@ def _print_ts_param(arg) -> str: printed_args: list[str] = [] for i, arg in enumerate(args): - if i >= numeric_start and isinstance(arg, (Integer, Float, Rational)): - printed_args.append(_print_ts_param(arg)) + if isinstance(arg, (Integer, Float, Rational)): + if i >= numeric_start: + # Trailing numeric parameters: keep them as scalars. + printed_args.append(_print_scalar(arg)) + else: + # Series position numeric constants: lift to Expr. + printed_args.append(f"pl.lit({_print_scalar(arg)})") else: printed_args.append(self._print(arg)) From 79532801f5597892b52e24cf834c70d836049e3d Mon Sep 17 00:00:00 2001 From: 0xhhh321321 <51537937+0xhhh321321@users.noreply.github.com> Date: Sun, 25 Jan 2026 03:03:04 +0800 Subject: [PATCH 4/6] Polars printer: lift numeric-only ts_* series args --- expr_codegen/polars/printer.py | 46 ++++++++++++++++++++++++++-------- 1 file changed, 35 insertions(+), 11 deletions(-) diff --git a/expr_codegen/polars/printer.py b/expr_codegen/polars/printer.py index 2e41a20..7312d85 100644 --- a/expr_codegen/polars/printer.py +++ b/expr_codegen/polars/printer.py @@ -67,14 +67,18 @@ def _print_Function(self, expr) -> str: # `min_samples or ...` and trigger Expr.__bool__ errors. args = list(expr.args) - last_non_num_idx = -1 - for i, arg in enumerate(args): - if not isinstance(arg, (Integer, Float, Rational)): - last_non_num_idx = i - - numeric_start = 1 if last_non_num_idx < 0 else last_non_num_idx + 1 - - def _print_scalar(arg) -> str: + def _is_numeric(arg: object) -> bool: + if isinstance(arg, (Integer, Float, Rational)): + return True + if isinstance(arg, Basic): + try: + if getattr(arg, "is_number", False) and not getattr(arg, "free_symbols", set()): + return True + except Exception: + return False + return False + + def _scalar_str(arg: object) -> str: if isinstance(arg, Integer): return str(int(arg)) if isinstance(arg, Float): @@ -84,17 +88,37 @@ def _print_scalar(arg) -> str: if q == 1: return str(p) return repr(p / q) + if isinstance(arg, Basic) and getattr(arg, "is_number", False): + # Numeric-only SymPy expressions (e.g. Mul(Integer(-1), Integer(0))). + try: + if getattr(arg, "is_integer", False): + return str(int(arg)) + except Exception: + pass + try: + return repr(float(arg)) + except Exception: + return self._print(arg) return self._print(arg) + last_non_num_idx = -1 + for i, arg in enumerate(args): + if not _is_numeric(arg): + last_non_num_idx = i + + numeric_start = 1 if last_non_num_idx < 0 else last_non_num_idx + 1 + printed_args: list[str] = [] for i, arg in enumerate(args): - if isinstance(arg, (Integer, Float, Rational)): + if _is_numeric(arg): if i >= numeric_start: # Trailing numeric parameters: keep them as scalars. - printed_args.append(_print_scalar(arg)) + printed_args.append(_scalar_str(arg)) else: # Series position numeric constants: lift to Expr. - printed_args.append(f"pl.lit({_print_scalar(arg)})") + # Important: numeric-only SymPy expressions must also be lifted, + # otherwise Python will evaluate them as scalars before calling ts_*. + printed_args.append(f"pl.lit({_scalar_str(arg)})") else: printed_args.append(self._print(arg)) From d467f419cff8791917bb3a5f34d0892d39c7667d Mon Sep 17 00:00:00 2001 From: 0xhhh321321 <51537937+0xhhh321321@users.noreply.github.com> Date: Sun, 25 Jan 2026 08:27:17 +0800 Subject: [PATCH 5/6] Polars printer: lift constant-only ts_* series args --- expr_codegen/polars/printer.py | 116 ++++++++++++++++++++++++++++++++- 1 file changed, 115 insertions(+), 1 deletion(-) diff --git a/expr_codegen/polars/printer.py b/expr_codegen/polars/printer.py index 7312d85..6eab201 100644 --- a/expr_codegen/polars/printer.py +++ b/expr_codegen/polars/printer.py @@ -1,4 +1,4 @@ -from sympy import Basic, Float, Function, Integer, Rational, StrPrinter +from sympy import Add, Basic, Float, Function, Integer, Mul, Pow, Rational, StrPrinter from sympy.printing.precedence import precedence, PRECEDENCE @@ -101,6 +101,115 @@ def _scalar_str(arg: object) -> str: return self._print(arg) return self._print(arg) + def _scalar_str_from_value(v: object) -> str: + if isinstance(v, bool): + return "1" if v else "0" + if isinstance(v, int): + return str(v) + if isinstance(v, float): + return repr(float(v)) + return repr(v) + + def _try_const_value(node: object) -> int | float | bool | None: + # Only evaluate constant-only expressions to avoid accidentally + # executing user-defined code. + if isinstance(node, (int, float, bool)): + return node + if isinstance(node, Integer): + return int(node) + if isinstance(node, Float): + return float(node) + if isinstance(node, Rational): + p, q = int(node.p), int(node.q) + if q == 0: + return None + v = p / q + if float(v).is_integer(): + return int(v) + return float(v) + if not isinstance(node, Basic): + return None + try: + if getattr(node, "free_symbols", set()): + return None + except Exception: + return None + + # Native numeric-only SymPy expressions. + if getattr(node, "is_number", False): + try: + if getattr(node, "is_integer", False): + return int(node) + except Exception: + pass + try: + return float(node) + except Exception: + return None + + # Basic arithmetic on constant subexpressions. + if isinstance(node, Add): + vals = [_try_const_value(a) for a in node.args] + if any(v is None for v in vals): + return None + return float(sum(float(v) for v in vals)) + if isinstance(node, Mul): + vals = [_try_const_value(a) for a in node.args] + if any(v is None for v in vals): + return None + out = 1.0 + for v in vals: + out *= float(v) + return float(out) + if isinstance(node, Pow): + base = _try_const_value(node.args[0]) + exp = _try_const_value(node.args[1]) + if base is None or exp is None: + return None + try: + return float(float(base) ** float(exp)) + except Exception: + return None + + # Safe subset of constant function evaluation. + if getattr(node, "is_Function", False): + fn = getattr(node.func, "__name__", "") + vals = [_try_const_value(a) for a in node.args] + if any(v is None for v in vals): + return None + xs = [float(v) for v in vals] + try: + if fn in ("abs_", "abs") and len(xs) == 1: + return abs(xs[0]) + if fn in ("sign", "sign_") and len(xs) == 1: + x = xs[0] + return 0.0 if x == 0.0 else (1.0 if x > 0.0 else -1.0) + if fn in ("neg_", "neg") and len(xs) == 1: + return -xs[0] + if fn in ("log", "log_") and len(xs) == 1: + import math + + return float(math.log(xs[0])) + if fn in ("min_", "min") and len(xs) >= 1: + return float(min(xs)) + if fn in ("max_", "max") and len(xs) >= 1: + return float(max(xs)) + if fn == "clip_" and len(xs) == 3: + x, lo, hi = xs + return float(min(max(x, lo), hi)) + if fn in ("if_else", "where_") and len(vals) == 3: + cond = vals[0] + if isinstance(cond, bool): + return _try_const_value(node.args[1] if cond else node.args[2]) + # Treat numeric condition as truthy/falsy (0 => False). + try: + return _try_const_value(node.args[1] if float(cond) != 0.0 else node.args[2]) + except Exception: + return None + except Exception: + return None + return None + last_non_num_idx = -1 for i, arg in enumerate(args): if not _is_numeric(arg): @@ -120,6 +229,11 @@ def _scalar_str(arg: object) -> str: # otherwise Python will evaluate them as scalars before calling ts_*. printed_args.append(f"pl.lit({_scalar_str(arg)})") else: + if i < numeric_start: + v = _try_const_value(arg) + if v is not None: + printed_args.append(f"pl.lit({_scalar_str_from_value(v)})") + continue printed_args.append(self._print(arg)) return f"{func_name}({','.join(printed_args)})" From 16c9c00f2505f38e0c5b94b7ac6b354ea20d7212 Mon Sep 17 00:00:00 2001 From: 0xhhh321321 <51537937+0xhhh321321@users.noreply.github.com> Date: Sun, 25 Jan 2026 10:14:37 +0800 Subject: [PATCH 6/6] polars printer: fold constant elementwise calls to pl.lit --- expr_codegen/polars/printer.py | 125 +++++++++++++++++++++++++++++++++ 1 file changed, 125 insertions(+) diff --git a/expr_codegen/polars/printer.py b/expr_codegen/polars/printer.py index 6eab201..bd3d324 100644 --- a/expr_codegen/polars/printer.py +++ b/expr_codegen/polars/printer.py @@ -58,6 +58,131 @@ def _print_Symbol(self, expr): def _print_Function(self, expr) -> str: func_name = expr.func.__name__ + + def _lit(v: object) -> str: + if isinstance(v, bool): + return "pl.lit(True)" if v else "pl.lit(False)" + if isinstance(v, int): + return f"pl.lit({v})" + if isinstance(v, float): + if v != v: # NaN + return "pl.lit(float('nan'))" + if v == float("inf"): + return "pl.lit(float('inf'))" + if v == float("-inf"): + return "pl.lit(float('-inf'))" + return f"pl.lit({repr(float(v))})" + return f"pl.lit({repr(v)})" + + def _const_scalar(node: object) -> int | float | bool | None: + # Only fold constant-only expressions to avoid accidental execution of + # user-defined code. This is deliberately conservative. + if isinstance(node, (int, float, bool)): + return node + if isinstance(node, Integer): + return int(node) + if isinstance(node, Float): + return float(node) + if isinstance(node, Rational): + p, q = int(node.p), int(node.q) + if q == 0: + return None + v = p / q + if float(v).is_integer(): + return int(v) + return float(v) + if not isinstance(node, Basic): + return None + try: + if getattr(node, "free_symbols", set()): + return None + except Exception: + return None + if getattr(node, "is_number", False): + try: + if getattr(node, "is_integer", False): + return int(node) + except Exception: + pass + try: + return float(node) + except Exception: + return None + return None + + def _fold_const_function_call(call: object) -> str | None: + # Fold a small set of elementwise/conditional functions when all inputs + # are constant-only. This avoids emitting runtime calls like `sign(0)`, + # which may return numpy scalars (e.g. np.int64) and break Polars Lazy + # type inference in some window/weights plans. + if not getattr(call, "is_Function", False): + return None + fn = getattr(call.func, "__name__", "") + args = list(getattr(call, "args", ())) + + vals = [_const_scalar(a) for a in args] + if any(v is None for v in vals): + return None + + try: + if fn in ("abs_", "abs") and len(vals) == 1: + v = vals[0] + if isinstance(v, bool): + return _lit(1 if v else 0) + if isinstance(v, int): + return _lit(abs(v)) + return _lit(abs(float(v))) + + if fn in ("sign", "sign_") and len(vals) == 1: + x = float(vals[0]) + out = 0 if x == 0.0 else (1 if x > 0.0 else -1) + return _lit(out) + + if fn in ("neg_", "neg") and len(vals) == 1: + v = vals[0] + if isinstance(v, bool): + return _lit(-1 if v else 0) + if isinstance(v, int): + return _lit(-v) + return _lit(-float(v)) + + if fn in ("log", "log_") and len(vals) == 1: + import math + + x = float(vals[0]) + return _lit(float(math.log(x))) + + if fn in ("min_", "min") and len(vals) >= 1: + xs = [float(v) for v in vals] + return _lit(float(min(xs))) + + if fn in ("max_", "max") and len(vals) >= 1: + xs = [float(v) for v in vals] + return _lit(float(max(xs))) + + if fn == "clip_" and len(vals) == 3: + x, lo, hi = (float(vals[0]), float(vals[1]), float(vals[2])) + return _lit(float(min(max(x, lo), hi))) + + if fn in ("if_else", "where_") and len(vals) == 3: + cond = vals[0] + # bool condition: choose branch directly + if isinstance(cond, bool): + return _lit(vals[1] if cond else vals[2]) + # numeric truthy/falsy: 0 => False + try: + return _lit(vals[1] if float(cond) != 0.0 else vals[2]) + except Exception: + return None + except Exception: + return None + + return None + + folded = _fold_const_function_call(expr) + if folded is not None: + return folded + if func_name.startswith("ts_"): # For ts_* functions: # - The "series" args must be polars expressions (so numeric constants