Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions expr_codegen/polars/code.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,9 +99,15 @@ def codegen(exprs_ldl: ListDictList, exprs_src, syms_dst,
else:
_sym = ','.join(_sym)
if args.over_null == 'partition_by':
func_code.append(f"{va}=({s2}).over({_sym}, _ASSET_, order_by=_DATE_),")
if _sym:
func_code.append(f"{va}=({s2}).over({_sym}, _ASSET_, order_by=_DATE_),")
else:
func_code.append(f"{va}=({s2}).over(_ASSET_, order_by=_DATE_),")
elif args.over_null == 'order_by':
func_code.append(f"{va}=({s2}).over(_ASSET_, order_by=[{_sym}, _DATE_]),")
if _sym:
func_code.append(f"{va}=({s2}).over(_ASSET_, order_by=[{_sym}, _DATE_]),")
else:
func_code.append(f"{va}=({s2}).over(_ASSET_, order_by=_DATE_),")
else:
func_code.append(f"{va}=({s2}).over(_ASSET_, order_by=_DATE_),")
elif k[0] == CS:
Expand Down
311 changes: 310 additions & 1 deletion expr_codegen/polars/printer.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from sympy import Basic, Function, StrPrinter
from sympy import Add, Basic, Float, Function, Integer, Mul, Pow, Rational, StrPrinter
from sympy.printing.precedence import precedence, PRECEDENCE


Expand Down Expand Up @@ -56,6 +56,315 @@ def _print(self, expr, **kwargs) -> str:
def _print_Symbol(self, expr):
return expr.name

def _print_Function(self, expr) -> str:
func_name = expr.func.__name__

def _lit(v: object) -> str:
if isinstance(v, bool):
return "pl.lit(True)" if v else "pl.lit(False)"
if isinstance(v, int):
return f"pl.lit({v})"
if isinstance(v, float):
if v != v: # NaN
return "pl.lit(float('nan'))"
if v == float("inf"):
return "pl.lit(float('inf'))"
if v == float("-inf"):
return "pl.lit(float('-inf'))"
return f"pl.lit({repr(float(v))})"
return f"pl.lit({repr(v)})"

def _const_scalar(node: object) -> int | float | bool | None:
# Only fold constant-only expressions to avoid accidental execution of
# user-defined code. This is deliberately conservative.
if isinstance(node, (int, float, bool)):
return node
if isinstance(node, Integer):
return int(node)
if isinstance(node, Float):
return float(node)
if isinstance(node, Rational):
p, q = int(node.p), int(node.q)
if q == 0:
return None
v = p / q
if float(v).is_integer():
return int(v)
return float(v)
if not isinstance(node, Basic):
return None
try:
if getattr(node, "free_symbols", set()):
return None
except Exception:
return None
if getattr(node, "is_number", False):
try:
if getattr(node, "is_integer", False):
return int(node)
except Exception:
pass
try:
return float(node)
except Exception:
return None
return None

def _fold_const_function_call(call: object) -> str | None:
# Fold a small set of elementwise/conditional functions when all inputs
# are constant-only. This avoids emitting runtime calls like `sign(0)`,
# which may return numpy scalars (e.g. np.int64) and break Polars Lazy
# type inference in some window/weights plans.
if not getattr(call, "is_Function", False):
return None
fn = getattr(call.func, "__name__", "")
args = list(getattr(call, "args", ()))

vals = [_const_scalar(a) for a in args]
if any(v is None for v in vals):
return None

try:
if fn in ("abs_", "abs") and len(vals) == 1:
v = vals[0]
if isinstance(v, bool):
return _lit(1 if v else 0)
if isinstance(v, int):
return _lit(abs(v))
return _lit(abs(float(v)))

if fn in ("sign", "sign_") and len(vals) == 1:
x = float(vals[0])
out = 0 if x == 0.0 else (1 if x > 0.0 else -1)
return _lit(out)

if fn in ("neg_", "neg") and len(vals) == 1:
v = vals[0]
if isinstance(v, bool):
return _lit(-1 if v else 0)
if isinstance(v, int):
return _lit(-v)
return _lit(-float(v))

if fn in ("log", "log_") and len(vals) == 1:
import math

x = float(vals[0])
return _lit(float(math.log(x)))

if fn in ("min_", "min") and len(vals) >= 1:
xs = [float(v) for v in vals]
return _lit(float(min(xs)))

if fn in ("max_", "max") and len(vals) >= 1:
xs = [float(v) for v in vals]
return _lit(float(max(xs)))

if fn == "clip_" and len(vals) == 3:
x, lo, hi = (float(vals[0]), float(vals[1]), float(vals[2]))
return _lit(float(min(max(x, lo), hi)))

if fn in ("if_else", "where_") and len(vals) == 3:
cond = vals[0]
# bool condition: choose branch directly
if isinstance(cond, bool):
return _lit(vals[1] if cond else vals[2])
# numeric truthy/falsy: 0 => False
try:
return _lit(vals[1] if float(cond) != 0.0 else vals[2])
except Exception:
return None
except Exception:
return None

return None

folded = _fold_const_function_call(expr)
if folded is not None:
return folded

if func_name.startswith("ts_"):
# For ts_* functions:
# - The "series" args must be polars expressions (so numeric constants
# there should be wrapped as `pl.lit(...)`).
# - The trailing numeric params (window sizes, etc.) must remain plain
# Python scalars (ints/floats), otherwise downstream libs may do
# `min_samples or ...` and trigger Expr.__bool__ errors.
args = list(expr.args)

def _is_numeric(arg: object) -> bool:
if isinstance(arg, (Integer, Float, Rational)):
return True
if isinstance(arg, Basic):
try:
if getattr(arg, "is_number", False) and not getattr(arg, "free_symbols", set()):
return True
except Exception:
return False
return False

def _scalar_str(arg: object) -> str:
if isinstance(arg, Integer):
return str(int(arg))
if isinstance(arg, Float):
return repr(float(arg))
if isinstance(arg, Rational):
p, q = int(arg.p), int(arg.q)
if q == 1:
return str(p)
return repr(p / q)
if isinstance(arg, Basic) and getattr(arg, "is_number", False):
# Numeric-only SymPy expressions (e.g. Mul(Integer(-1), Integer(0))).
try:
if getattr(arg, "is_integer", False):
return str(int(arg))
except Exception:
pass
try:
return repr(float(arg))
except Exception:
return self._print(arg)
return self._print(arg)

def _scalar_str_from_value(v: object) -> str:
if isinstance(v, bool):
return "1" if v else "0"
if isinstance(v, int):
return str(v)
if isinstance(v, float):
return repr(float(v))
return repr(v)

def _try_const_value(node: object) -> int | float | bool | None:
# Only evaluate constant-only expressions to avoid accidentally
# executing user-defined code.
if isinstance(node, (int, float, bool)):
return node
if isinstance(node, Integer):
return int(node)
if isinstance(node, Float):
return float(node)
if isinstance(node, Rational):
p, q = int(node.p), int(node.q)
if q == 0:
return None
v = p / q
if float(v).is_integer():
return int(v)
return float(v)
if not isinstance(node, Basic):
return None
try:
if getattr(node, "free_symbols", set()):
return None
except Exception:
return None

# Native numeric-only SymPy expressions.
if getattr(node, "is_number", False):
try:
if getattr(node, "is_integer", False):
return int(node)
except Exception:
pass
try:
return float(node)
except Exception:
return None

# Basic arithmetic on constant subexpressions.
if isinstance(node, Add):
vals = [_try_const_value(a) for a in node.args]
if any(v is None for v in vals):
return None
return float(sum(float(v) for v in vals))
if isinstance(node, Mul):
vals = [_try_const_value(a) for a in node.args]
if any(v is None for v in vals):
return None
out = 1.0
for v in vals:
out *= float(v)
return float(out)
if isinstance(node, Pow):
base = _try_const_value(node.args[0])
exp = _try_const_value(node.args[1])
if base is None or exp is None:
return None
try:
return float(float(base) ** float(exp))
except Exception:
return None

# Safe subset of constant function evaluation.
if getattr(node, "is_Function", False):
fn = getattr(node.func, "__name__", "")
vals = [_try_const_value(a) for a in node.args]
if any(v is None for v in vals):
return None
xs = [float(v) for v in vals]
try:
if fn in ("abs_", "abs") and len(xs) == 1:
return abs(xs[0])
if fn in ("sign", "sign_") and len(xs) == 1:
x = xs[0]
return 0.0 if x == 0.0 else (1.0 if x > 0.0 else -1.0)
if fn in ("neg_", "neg") and len(xs) == 1:
return -xs[0]
if fn in ("log", "log_") and len(xs) == 1:
import math

return float(math.log(xs[0]))
if fn in ("min_", "min") and len(xs) >= 1:
return float(min(xs))
if fn in ("max_", "max") and len(xs) >= 1:
return float(max(xs))
if fn == "clip_" and len(xs) == 3:
x, lo, hi = xs
return float(min(max(x, lo), hi))
if fn in ("if_else", "where_") and len(vals) == 3:
cond = vals[0]
if isinstance(cond, bool):
return _try_const_value(node.args[1] if cond else node.args[2])
# Treat numeric condition as truthy/falsy (0 => False).
try:
return _try_const_value(node.args[1] if float(cond) != 0.0 else node.args[2])
except Exception:
return None
except Exception:
return None
return None

last_non_num_idx = -1
for i, arg in enumerate(args):
if not _is_numeric(arg):
last_non_num_idx = i

numeric_start = 1 if last_non_num_idx < 0 else last_non_num_idx + 1

printed_args: list[str] = []
for i, arg in enumerate(args):
if _is_numeric(arg):
if i >= numeric_start:
# Trailing numeric parameters: keep them as scalars.
printed_args.append(_scalar_str(arg))
else:
# Series position numeric constants: lift to Expr.
# Important: numeric-only SymPy expressions must also be lifted,
# otherwise Python will evaluate them as scalars before calling ts_*.
printed_args.append(f"pl.lit({_scalar_str(arg)})")
else:
if i < numeric_start:
v = _try_const_value(arg)
if v is not None:
printed_args.append(f"pl.lit({_scalar_str_from_value(v)})")
continue
printed_args.append(self._print(arg))

return f"{func_name}({','.join(printed_args)})"

return super()._print_Function(expr)

def _print_Equality(self, expr):
PREC = precedence(expr)
return "%s == %s" % (self.parenthesize(expr.args[0], PREC), self.parenthesize(expr.args[1], PREC))
Expand Down