diff --git a/Cargo.toml b/Cargo.toml index 92965d9a..375d6244 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,7 +14,7 @@ exa = ["libloading"] [dependencies] cached = { version = "0.56.0" } csv = "1.3.0" -diffsol = "=0.7.0" +diffsol = { version = "=0.7.0" } libloading = { version = "0.8.6", optional = true, features = [] } nalgebra = "0.34.1" ndarray = { version = "0.16.1", features = ["rayon"] } @@ -28,6 +28,7 @@ thiserror = "2.0.11" argmin = "0.11.0" argmin-math = "0.5.1" tracing = "0.1.41" +once_cell = "1.18.0" [dev-dependencies] criterion = { version = "0.7.0", features = ["html_reports"] } @@ -47,3 +48,8 @@ harness = false [[bench]] name = "ode" harness = false + +[[bench]] +name = "wasm_ode_compare" +harness = false +required-features = ["exa"] diff --git a/benches/wasm_ode_compare.rs b/benches/wasm_ode_compare.rs new file mode 100644 index 00000000..b91d3f1e --- /dev/null +++ b/benches/wasm_ode_compare.rs @@ -0,0 +1,182 @@ +use criterion::{criterion_group, criterion_main, Criterion}; +use pharmsol::*; +use std::hint::black_box; + +fn example_subject() -> Subject { + Subject::builder("1") + .infusion(0.0, 500.0, 0, 0.5) + .observation(0.5, 1.645776, 0) + .observation(1.0, 1.216442, 0) + .observation(2.0, 0.4622729, 0) + .observation(3.0, 0.1697458, 0) + .observation(4.0, 0.06382178, 0) + .observation(6.0, 0.009099384, 0) + .observation(8.0, 0.001017932, 0) + .missing_observation(12.0, 0) + .build() +} + +fn regular_ode_predictions(c: &mut Criterion) { + let subject = example_subject(); + let ode = equation::ODE::new( + |x, p, _t, dx, _b, rateiv, _cov| { + fetch_params!(p, ke, _v); + dx[0] = -ke * x[0] + rateiv[0]; + }, + |_p, _t, _cov| lag! {}, + |_p, _t, _cov| fa! 
{}, + |_p, _t, _cov, _x| {}, + |x, p, _t, _cov, y| { + fetch_params!(p, _ke, v); + y[0] = x[0] / v; + }, + (1, 1), + ); + let params = vec![1.02282724609375, 194.51904296875]; + + c.bench_function("regular_ode_predictions", |b| { + b.iter(|| { + black_box(ode.estimate_predictions(&subject, ¶ms).unwrap()); + }) + }); +} + +fn wasm_ir_ode_predictions(c: &mut Criterion) { + let subject = example_subject(); + + // Setup WASM IR model + let test_dir = std::env::current_dir().expect("Failed to get current directory"); + let ir_path = test_dir.join("test_model_ir_bench.pkm"); + + let _ir_file = exa_wasm::build::emit_ir::( + "|x, p, _t, dx, rateiv, _cov| { + fetch_params!(p, ke, _v); + dx[0] = -ke * x[0] + rateiv[0]; + }" + .to_string(), + None, + None, + Some("|p, _t, _cov, x| { }".to_string()), + Some( + "|x, p, _t, _cov, y| { + fetch_params!(p, _ke, v); + y[0] = x[0] / v; + }" + .to_string(), + ), + Some(ir_path.clone()), + vec!["ke".to_string(), "v".to_string()], + ) + .expect("emit_ir failed"); + + let (wasm_ode, _meta, _id) = + exa_wasm::interpreter::load_ir_ode(ir_path.clone()).expect("load_ir_ode failed"); + + let params = vec![1.02282724609375, 194.51904296875]; + + c.bench_function("wasm_ir_ode_predictions", |b| { + b.iter(|| { + black_box(wasm_ode.estimate_predictions(&subject, ¶ms).unwrap()); + }) + }); + + // Clean up + std::fs::remove_file(ir_path).ok(); +} + +fn regular_ode_likelihood(c: &mut Criterion) { + let subject = example_subject(); + let ode = equation::ODE::new( + |x, p, _t, dx, _b, rateiv, _cov| { + fetch_params!(p, ke, _v); + dx[0] = -ke * x[0] + rateiv[0]; + }, + |_p, _t, _cov| lag! {}, + |_p, _t, _cov| fa! 
{}, + |_p, _t, _cov, _x| {}, + |x, p, _t, _cov, y| { + fetch_params!(p, _ke, v); + y[0] = x[0] / v; + }, + (1, 1), + ); + let params = vec![1.02282724609375, 194.51904296875]; + let ems = ErrorModels::new() + .add( + 0, + ErrorModel::additive(ErrorPoly::new(0.0, 0.05, 0.0, 0.0), 0.0), + ) + .unwrap(); + + c.bench_function("regular_ode_likelihood", |b| { + b.iter(|| { + black_box( + ode.estimate_likelihood(&subject, ¶ms, &ems, false) + .unwrap(), + ); + }) + }); +} + +fn wasm_ir_ode_likelihood(c: &mut Criterion) { + let subject = example_subject(); + + // Setup WASM IR model + let test_dir = std::env::current_dir().expect("Failed to get current directory"); + let ir_path = test_dir.join("test_model_ir_bench_ll.pkm"); + + let _ir_file = exa_wasm::build::emit_ir::( + "|x, p, _t, dx, rateiv, _cov| { + fetch_params!(p, ke, _v); + dx[0] = -ke * x[0] + rateiv[0]; + }" + .to_string(), + None, + None, + Some("|p, _t, _cov, x| { }".to_string()), + Some( + "|x, p, _t, _cov, y| { + fetch_params!(p, _ke, v); + y[0] = x[0] / v; + }" + .to_string(), + ), + Some(ir_path.clone()), + vec!["ke".to_string(), "v".to_string()], + ) + .expect("emit_ir failed"); + + let (wasm_ode, _meta, _id) = + exa_wasm::interpreter::load_ir_ode(ir_path.clone()).expect("load_ir_ode failed"); + + let params = vec![1.02282724609375, 194.51904296875]; + let ems = ErrorModels::new() + .add( + 0, + ErrorModel::additive(ErrorPoly::new(0.0, 0.05, 0.0, 0.0), 0.0), + ) + .unwrap(); + + c.bench_function("wasm_ir_ode_likelihood", |b| { + b.iter(|| { + black_box( + wasm_ode + .estimate_likelihood(&subject, ¶ms, &ems, false) + .unwrap(), + ); + }) + }); + + // Clean up + std::fs::remove_file(ir_path).ok(); +} + +fn criterion_benchmark(c: &mut Criterion) { + regular_ode_predictions(c); + wasm_ir_ode_predictions(c); + regular_ode_likelihood(c); + wasm_ir_ode_likelihood(c); +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/examples/bytecode_models.rs 
b/examples/bytecode_models.rs new file mode 100644 index 00000000..df2b691a --- /dev/null +++ b/examples/bytecode_models.rs @@ -0,0 +1,108 @@ +use std::env; +use std::fs; +// example: emit IR and load via the runtime + +fn main() { + let tmp = env::temp_dir(); + + // Model 1: simple dx assignment + let diffeq1 = "|x, p, _t, dx, rateiv, _cov| { dx[0] = -ke * x[0]; }".to_string(); + let path1 = tmp.join("exa_example_model1.json"); + let _ = pharmsol::exa_wasm::build::emit_ir::( + diffeq1, + None, + None, + None, + None, + Some(path1.clone()), + vec!["ke".to_string()], + ) + .expect("emit_ir model1"); + + // Model 2: prelude/local and rate + let diffeq2 = + "|x, p, _t, dx, rateiv, _cov| { ke = 0.5; dx[0] = -ke * x[0] + rateiv[0]; }".to_string(); + let path2 = tmp.join("exa_example_model2.json"); + let _ = pharmsol::exa_wasm::build::emit_ir::( + diffeq2, + None, + None, + None, + None, + Some(path2.clone()), + vec!["ke".to_string()], + ) + .expect("emit_ir model2"); + + // Model 3: builtin and ternary + let diffeq3 = + "|x, p, _t, dx, rateiv, _cov| { dx[0] = if(t>0, exp(-ke * t) * x[0], 0.0); }".to_string(); + let path3 = tmp.join("exa_example_model3.json"); + let _ = pharmsol::exa_wasm::build::emit_ir::( + diffeq3, + None, + None, + None, + None, + Some(path3.clone()), + vec!["ke".to_string()], + ) + .expect("emit_ir model3"); + + println!( + "Emitted IR to:\n {:?}\n {:?}\n {:?}", + path1, path2, path3 + ); + + // Load them via the public API and print emitted IR metadata from the + // emitted JSON (avoids accessing private registry internals from an + // example binary). 
+ for p in [&path1, &path2, &path3] { + // try to load via runtime loader (public re-export) + match pharmsol::exa_wasm::load_ir_ode(p.clone()) { + Ok((_ode, _meta, id)) => { + println!("loader accepted model, registry id={}", id); + } + Err(e) => { + eprintln!("loader rejected model {:?}: {}", p, e); + } + } + + // read raw IR and display bytecode/funcs/locals metadata + match fs::read_to_string(p) { + Ok(s) => match serde_json::from_str::(&s) { + Ok(v) => { + let has_bc = v.get("diffeq_bytecode").is_some(); + let funcs = v + .get("funcs") + .and_then(|j| { + j.as_array() + .map(|a| a.iter().filter_map(|x| x.as_str()).collect::>()) + }) + .unwrap_or_default(); + let locals = v + .get("locals") + .and_then(|j| { + j.as_array() + .map(|a| a.iter().filter_map(|x| x.as_str()).collect::>()) + }) + .unwrap_or_default(); + println!( + "IR {:?}: diffeq_bytecode={} funcs={:?} locals={:?}", + p.file_name().unwrap_or_default(), + has_bc, + funcs, + locals + ); + } + Err(e) => eprintln!("failed to parse emitted IR {:?}: {}", p, e), + }, + Err(e) => eprintln!("failed to read emitted IR {:?}: {}", p, e), + } + } + + // cleanup + let _ = fs::remove_file(&path1); + let _ = fs::remove_file(&path2); + let _ = fs::remove_file(&path3); +} diff --git a/examples/emit_debug.rs b/examples/emit_debug.rs new file mode 100644 index 00000000..77b4aa69 --- /dev/null +++ b/examples/emit_debug.rs @@ -0,0 +1,19 @@ +fn main() { + use pharmsol::equation; + use pharmsol::exa_wasm::build::emit_ir; + + // Simple helper example that emits IR for a small model and prints the + // location of the generated IR file. Keep example minimal and only use + // public APIs so it doesn't depend on internal interpreter modules. 
+ let out = emit_ir::( + "|x, p, _t, dx, rateiv, _cov| { dx[0] = x[0].sin(); }".to_string(), + None, + None, + None, + None, + None, + vec![], + ) + .expect("emit_ir"); + println!("wrote IR to: {}", out); +} diff --git a/examples/exa.rs b/examples/exa.rs index d9ee92d9..38ff5802 100644 --- a/examples/exa.rs +++ b/examples/exa.rs @@ -34,7 +34,9 @@ fn main() { ); //clear build - clear_build(); + // clear_build(); + + println!("{}", exa::build::template_path()); let test_dir = std::env::current_dir().expect("Failed to get current directory"); let model_output_path = test_dir.join("test_model.pkm"); @@ -44,9 +46,9 @@ fn main() { format!( r#" equation::ODE::new( - |x, p, _t, dx, rateiv, _cov| {{ + |x, p, _t, dx, b, rateiv, _cov| {{ fetch_params!(p, ke, _v); - dx[0] = -ke * x[0] + rateiv[0]; + dx[0] = -ke * x[0] + rateiv[0] + b[0]; }}, |_p, _t, _cov| lag! {{}}, |_p, _t, _cov| fa! {{}}, diff --git a/examples/wasm_ode_compare.rs b/examples/wasm_ode_compare.rs new file mode 100644 index 00000000..6363adb0 --- /dev/null +++ b/examples/wasm_ode_compare.rs @@ -0,0 +1,115 @@ +//cargo run --example wasm_ode_compare --features exa + +fn main() { + use pharmsol::{equation, exa_wasm, *}; + // use std::path::PathBuf; // not needed + + let subject = Subject::builder("1") + .infusion(0.0, 500.0, 0, 0.5) + .observation(0.5, 1.645776, 0) + .observation(1.0, 1.216442, 0) + .observation(2.0, 0.4622729, 0) + .observation(3.0, 0.1697458, 0) + .observation(4.0, 0.06382178, 0) + .observation(6.0, 0.009099384, 0) + .observation(8.0, 0.001017932, 0) + .missing_observation(12.0, 0) + .build(); + + // Regular ODE model + let ode = equation::ODE::new( + |x, p, _t, dx, _b, rateiv, _cov| { + fetch_params!(p, ke, _v); + if true { + dx[0] = -ke * x[0] + rateiv[0]; + } + // dx[0] = -ke * x[0] + rateiv[0]; + }, + |_p, _t, _cov| lag! {}, + |_p, _t, _cov| fa! 
{}, + |_p, _t, _cov, _x| {}, + |x, p, _t, _cov, y| { + fetch_params!(p, _ke, v); + y[0] = x[0] / v; + }, + (1, 1), + ); + + // Compile WASM IR model using exa (interpreter, not native dynlib) + let test_dir = std::env::current_dir().expect("Failed to get current directory"); + let ir_path = test_dir.join("test_model_ir.pkm"); + // This emits a JSON IR file for the same ODE model + let _ir_file = exa_wasm::build::emit_ir::( + "|x, p, _t, dx, rateiv, _cov| { + fetch_params!(p, ke, _v); + if true { + if true { + dx[0] = -ke * x[0] + rateiv[0]; + } + } + }" + .to_string(), + None, + None, + Some("|p, _t, _cov, x| { }".to_string()), + Some( + "|x, p, _t, _cov, y| { + fetch_params!(p, _ke, v); + y[0] = x[0] / v; + }" + .to_string(), + ), + Some(ir_path.clone()), + vec!["ke".to_string(), "v".to_string()], + ) + .expect("emit_ir failed"); + + //debug the contents of the ir file + let ir_contents = std::fs::read_to_string(&ir_path).expect("Failed to read IR file"); + println!("Generated IR file contents:\n{}", ir_contents); + + // Load the IR model using the WASM-capable interpreter + let (wasm_ode, _meta, _id) = + exa_wasm::interpreter::load_ir_ode(ir_path.clone()).expect("load_ir_ode failed"); + + let params = vec![1.02282724609375, 194.51904296875]; + + // Get predictions from both models + let ode_predictions = ode.estimate_predictions(&subject, ¶ms).unwrap(); + let wasm_predictions = wasm_ode.estimate_predictions(&subject, ¶ms).unwrap(); + + // Display predictions side by side + println!("Predictions:"); + println!("ODE\tWASM ODE\tDifference"); + ode_predictions + .flat_predictions() + .iter() + .zip(wasm_predictions.flat_predictions()) + .for_each(|(a, b)| println!("{:.9}\t{:.9}\t{:.9}", a, b, a - b)); + + // Optionally, display likelihoods + let mut ems = ErrorModels::new() + .add( + 0, + ErrorModel::additive(ErrorPoly::new(0.0, 0.05, 0.0, 0.0), 0.0), + ) + .unwrap(); + ems = ems + .add( + 1, + ErrorModel::proportional(ErrorPoly::new(0.0, 0.05, 0.0, 0.0), 0.0), + ) + 
.unwrap(); + let ll_ode = ode + .estimate_likelihood(&subject, ¶ms, &ems, false) + .unwrap(); + let ll_wasm = wasm_ode + .estimate_likelihood(&subject, ¶ms, &ems, false) + .unwrap(); + println!("\nLikelihoods:"); + println!("ODE\tWASM ODE"); + println!("{:.6}\t{:.6}", -2.0 * ll_ode, -2.0 * ll_wasm); + + // Clean up + std::fs::remove_file(ir_path).ok(); +} diff --git a/src/exa/wasm_plugin_spec.md b/src/exa/wasm_plugin_spec.md new file mode 100644 index 00000000..b4c1b294 --- /dev/null +++ b/src/exa/wasm_plugin_spec.md @@ -0,0 +1,182 @@ +# Pharmsol WASM Plugin Specification + +Version: 1.0 (draft) + +Purpose +------- + +This document specifies a minimal, versioned ABI for user-provided WebAssembly plugins that implement Pharmsol models. The goal is to allow advanced users to produce precompiled `.wasm` modules that can be safely instantiated by Pharmsol (in a browser or Wasmtime/Wasmer) and invoked to evaluate model behavior (derivatives, metadata, steps) without requiring recompilation of the host or linking native dynamic libraries. + +Design goals +------------ + +- Minimal: small set of imports/exports to ease authoring across languages. +- Stable: versioned ABI to allow forward/backward compatibility. +- Language neutral: use linear memory + u32/u64 primitives and JSON for structured metadata. +- Safe: use opaque handles and pointer/length pairs instead of native pointers to Rust objects. +- Sandboxed: rely on WASM runtime to enforce limits; host should enforce extra limits (memory, fuel). + +High-level contract +------------------- + +- All plugin modules MUST export `plugin_abi_version` and `plugin_name`. +- Plugins MAY require specific host imports (logging, allocation helpers). The host will supply sensible defaults if imports are missing, if possible. +- The host will instantiate the plugin and then call `plugin_create` with an optional configuration blob (JSON). The plugin returns a small integer handle (non-zero) or zero to signal error. 
+- The host uses handles to call `plugin_step`, `plugin_get_metadata`, and `plugin_free`. + +ABI versioning +-------------- + +- `plugin_abi_version()` -> u32 + - The plugin returns a u32 ABI version number. Host must refuse to load plugins with major versions it cannot support. Semantic versioning of the ABI is recommended (major increments break compatibility). + +Exports (required) +------------------ + +All pointer and length types use 32-bit unsigned integers (u32) to index the module's linear memory. Handles are 32-bit unsigned integers (u32) with 0 reserved for invalid/null. + +1. plugin_abi_version() -> u32 + +- Returns the ABI version implemented by the module. + +2. plugin_name(ptr: u32, len: u32) -> u32 + +- Optional: write the plugin name string into host-provided buffer. Alternatively, plugin may return 0 and provide name via `plugin_get_metadata`. +- Semantics: host provides a pointer/len to memory it controls (or 0/0 to request size). If ptr==0 and len==0, plugin returns required size. If ptr!=0, plugin writes up to len bytes and returns actual written bytes or negative error code. + +3. plugin_create(config_ptr: u32, config_len: u32) -> u32 + +- Create an instance of the model. `config` is a JSON blob (UTF-8) describing initial parameters or options. If both are 0, plugin uses built-in defaults. +- Returns a non-zero handle on success or 0 on error. For error details, the host should call `plugin_last_error` (see optional exports). + +4. plugin_free(handle: u32) -> u32 + +- Free resources associated with a handle. Returns 0 for success, non-zero for error. + +5. plugin_step(handle: u32, t: f64, dt: f64, inputs_ptr: u32, inputs_len: u32, outputs_ptr: u32, outputs_len_ptr: u32) -> i32 + +- Evaluate a step or compute derivatives for the model instance. +- `t` and `dt` are host-supplied time and timestep (floating point). The semantics of step are model-specific (integration step or single derivative evaluation); document clearly in plugin metadata. 
+- `inputs_ptr/inputs_len` point to an array of f64 values (packed little-endian) representing parameter values or exogenous inputs. The plugin may accept fewer or more inputs; any mismatch is an error. +- `outputs_ptr` is where the plugin writes resulting f64 outputs; `outputs_len_ptr` is a pointer in host memory where the plugin will write the number of f64 values it wrote (or required size when ptr was null). +- Return code: 0 success, negative values for defined error codes (see Error Codes). + +6. plugin_get_metadata(handle: u32, out_ptr_ptr: u32) -> i32 + +- Return a JSON metadata blob describing the model: parameter names and ordering, state variable names, default values, units, equation kind, capabilities (events, stochastic), and ABI version. +- The plugin will allocate the JSON string in its linear memory and write a 64-bit pointer/length pair into the host-provided `out_ptr_ptr` (two consecutive u32 values: ptr then len). Alternatively, if the plugin implements `host_alloc`, it can call into the host's allocator instead. +- Return 0 success, negative error for failure. + +Optional exports (recommended) +----------------------------- + +1. plugin_last_error(handle: u32, out_ptr: u32, out_len: u32) -> i32 + +- Copy last error message string into the provided buffer. If out_ptr==0 and out_len==0, return required length. + +2. plugin_supports_f64() -> u32 + +- Return 1 if plugin expects f64 for numerical buffers (recommended). Otherwise 0. + +Host imports (recommended) +-------------------------- + +These function imports allow plugins to use host helpers rather than re-implementing allocators or logging. The host may choose to provide stubs. + +1. host_alloc(size: u32) -> u32 + +- Allocate `size` bytes in the host's memory space accessible to the plugin. Returns pointer offset into host-supplied linear memory or 0 on failure. (Use only if the host and plugin share linear memory; otherwise plugin will allocate in its own memory.) + +2. 
host_free(ptr: u32, size: u32) + +- Free a host-allocated block. + +3. host_log(ptr: u32, len: u32, level: u32) + +- Host-provided logging helper. Plugin writes UTF-8 bytes to plugin memory and passes pointer,len. Level is user-defined (0=debug,1=info,2=warn,3=error). + +4. host_random_u64() -> u64 + +- Provide randomness from host if needed. Plugins needing deterministic seeds should accept them via `plugin_create` config. + +Error codes +----------- + +- Return negative i32 values for errors to keep C-like convention. + +- -1: Generic error +- -2: Invalid handle +- -3: Buffer too small / size mismatch (caller should retry with provided size) +- -4: Unsupported ABI version +- -5: Unsupported capability +- -6: Internal plugin panic/trap + +Memory ownership and allocation patterns +-------------------------------------- + +- Prefer the linear-memory pointer/length convention for strings and blobs. The host will copy strings into plugin memory when calling functions, or the plugin will allocate and return pointers with lengths. +- To return dynamically created strings (like metadata JSON), plugin should allocate memory inside its own linear memory and write pointer/length into the host-supplied pointer slot. The host must be prepared to read and copy that data before the plugin frees it. + +Security and sandboxing +----------------------- + +- Plugins must not assume file system or network access unless launched with appropriate WASI capabilities. Hosts must opt-in to features and apply least privilege. +- Hosts must enforce memory limits and allow interrupting long-running plugins. Use Wasmtime's fuel mechanism or equivalent. + +Compatibility notes +------------------- + +- Always check `plugin_abi_version` before using other exports. +- Hosts should fallback to IR/interpreter-based execution when plugin ABI is unsupported. + +Authoring guidelines +-------------------- + +1. Start a plugin with the minimal exports required to avoid host rejection. +2. 
Provide detailed metadata: parameter order, state order, units, capability flags (events, stochastic), and recommended step semantics. +3. Use JSON for metadata to avoid tightly-coupled binary formats. + +Build hints for Rust authors +--------------------------- + +- Build with `cdylib` or `--target wasm32-unknown-unknown` and avoid relying on `std` features that require WASI unless you target wasm32-wasi. +- Use `wasm-bindgen` only if you target JS and plan to use JS glue; otherwise prefer raw wasm exports with `#[no_mangle] extern "C"` functions and a small allocator like `wee_alloc`. + +Example memory sequence (metadata retrieval) +------------------------------------------- + +1. Host instantiates plugin and calls `plugin_get_metadata(instance_handle, host_out_ptr)` where `host_out_ptr` points to two consecutive u32 slots in host memory. +2. Plugin serializes JSON string into its linear memory at offset P and length L. +3. Plugin writes P and L to the two u32 slots at `host_out_ptr` and returns 0. +4. Host reads P and L via the Wasm instance memory view and copies the JSON blob to its own memory space. Host may then call `plugin_free_memory(P, L)` if the plugin offers such an export, or expect the plugin to free on `plugin_free`. + +Troubleshooting +--------------- + +- If metadata size is unknown, host can call `plugin_get_metadata(handle, 0)` which should return the required size in a standard location or return -3 with the required size encoded in a convention (prefer the pointer/length return method described). + +Examples and recipes +-------------------- + +- Example flow for a simple model plugin: + 1. `plugin_abi_version()` -> 1 + 2. Host calls `plugin_create(0,0)` -> returns handle 1 + 3. Host calls `plugin_get_metadata(1, out_ptr)` -> reads metadata JSON, learns parameter/state ordering + 4. Host calls `plugin_step(1, t, dt, inputs_ptr, inputs_len, out_ptr, out_len_ptr)` repeatedly to step/evaluate. 
+ +Specification lifecycle and version bumps +--------------------------------------- + +- Start at version 1.0 and keep the ABI additive where possible. If a breaking change is required, increment major version and require host/plugin negotiation. + +Next steps +---------- + +1. Add a concrete `pharmsol-plugin` crate template that exports the minimal ABI and demonstrates metadata and step implementations. +2. Add CI recipes for building `wasm32-unknown-unknown` and `wasm32-wasi` artifacts. +3. Implement host-side adapters in Pharmsol for instantiating a plugin, mapping metadata to the `Meta` type, and wrapping `plugin_step` as an `Equation` implementation that existing integrators can call. + +Appendix: change log +-------------------- + +- 2025-10-29: Draft 1.0 created. diff --git a/src/exa_wasm.md b/src/exa_wasm.md new file mode 100644 index 00000000..20aa9f94 --- /dev/null +++ b/src/exa_wasm.md @@ -0,0 +1,262 @@ +# Executing user-defined models on WebAssembly — analysis and design + +October 29, 2025 + +This document analyzes the existing `exa` model-building and dynamic-loading approach in Pharmsol, explains why it is not usable on WebAssembly (WASM), and presents multiple design options to enable running user-defined models from within a WASM-hosted Pharmsol runtime. It discusses trade-offs, security, ABI proposals, testing strategies, and recommended next steps. + +This is a technical engineering analysis intended to be a design blueprint. It intentionally avoids implementation code, and instead focuses on concrete architectures, precise interface sketches, hazards, and validation plans. 
+ +## Quick summary / recommendation + +- The current `exa` approach (creating a temporary Rust project, running `cargo build`, producing a platform dynamic library, then using `libloading` to load symbols) cannot be used from a WASM target (browser or wasm32-unknown-unknown runtime) because it depends on process spawning, filesystem semantics, dynamic linking, and thread/process control not available in typical WASM hosts. +- Two primary workable approaches for WASM compatibility are: (A) interpret a serialized model representation (AST/bytecode/DSL) inside the WASM host; (B) accept precompiled WASM modules (created outside the WASM host) and run them with a well-defined, minimal ABI. Each has strong trade-offs. A hybrid can offer a pragmatic path: an interpreter as the default portable path plus an optional WASM-plugin pathway for advanced users. +- Recommendation: start with an interpreter/serialized-IR approach for maximum portability (works in browser and WASI), and define a companion, clearly versioned WASM plugin ABI for power users and server-side deployments where precompiled WASM modules can be uploaded/installed. + +## Why the current `exa` cannot run on WASM + +Key reasons: + +1. Build-time tooling: `exa::build` shells out to `cargo` to create a template project and run `cargo build`. Running `cargo` requires a host OS with process spawning, filesystem with developer toolchain, and native toolchain availability. In browsers and many WASM runtimes this is impossible. + +2. Dynamic linking: `exa::load` uses `libloading` and platform dynamic libraries (`.so`, `.dll`, `.dylib`) and relies on native ABI and dynamic linking at runtime. WASM runtimes (especially wasm32-unknown-unknown) do not support Unix-like dlopen of platform shared libraries. Even wasm-hosts that support dynamic linking (WASI with preopened files) differ significantly from native OS dynamic linking. + +3. 
FFI and ownership: The code uses raw pointers, expects Rust ABI and cloned objects to cross library boundaries. WASM modules expose different ABIs (linear memory, function exports/imports). Passing complex Rust objects by pointer across a WASM boundary is fragile and often impossible without serialization glue. + +4. Threads and blocking IO: The build process spawns threads to stream stdout/err and waits on child processes. Many WASM environments (e.g. browsers) do not support native threads, and blocking waits there stall the event loop rather than a single thread. + +Because of the above, the server-side native dynamic plugin model does not translate to a WASM-hosted environment without redesign. + +## Use cases we must support (requirements) + +1. Allow end users to define models (ODEs or other equation types) and execute them inside a WASM-hosted Pharmsol instance (browser and WASM runtimes) without requiring native cargo toolchain inside the runtime. +2. Preserve a reasonably high-performance execution where possible (some models are performance sensitive). Allow optional high-performance plugin paths. +3. Maintain safety and security: user models must be sandboxed (resource limits, no arbitrary host access unless explicitly granted). +4. Keep a small, stable host-plugin interface and version it. +5. Provide a migration path so existing `exa` users can adopt the wasm-capable approach. + +Implicit constraints: + +- Minimal or no native code compilation in the runtime. Compilation should happen outside the runtime or be avoided via interpretation. +- Deterministic (or at least consistent) behavior across platforms where possible. + +## Candidate architectures (high level) + +I. Interpreter / serialized IR (recommended default) + +- Idea: convert the model text to a compact intermediate representation (IR), JSON AST, or small bytecode on the host (this can be done either offline or inside the non-WASM tooling), then ship that IR to the WASM runtime where a small interpreter executes it. 
+ +- Pros: + - Works in all WASM hosts (browser, WASI, standalone runtimes). + - No external toolchain or dynamic linking required inside the WASM module. + - Can be secured and resource-limited easily (single-threaded, deterministic loops, step budgets). + - Simpler lifecycle: the host (browser UI or server) supplies IR; the interpreter runs it. + +- Cons: + - Potentially slower than native code or compiled WASM modules (but can be optimized). + - Must reimplement evaluation semantics for model expressions, numerical integration hooks, and any host APIs used by models. + +II. Precompiled WASM modules from user (plugin-on-wasm) + +- Idea: users compile their model to a small WASM module (using Rust or another language). Pharmsol running in WASM instantiates that module and connects by a well-defined ABI (exports/imports). Compilation occurs outside the Pharmsol WASM runtime (user's machine or a server-side build service). + +- Pros: + - Best performance; compiled code runs as native WASM in the runtime. + - Allows complex user code without embedding an interpreter. + +- Cons: + - Requires the user to compile to WASM themselves, or an external build service. + - ABI ergonomics are complex: sharing complex structures across the WASM boundary needs glue (shared linear memory, serialization, helper allocators). + - Host must provide a precise, versioned import contract (logging, RNG, time, memory management). + +III. Hybrid: interpreter as default + optional WASM plugin path + +- Idea: implement the interpreter for general users and an optional plugin ABI for advanced users or server-based compilation pipelines. This covers both portability and perf. + +IV. Host-compilation pipeline (server assisted) + +- Idea: mirror existing `exa` server-side: accept user model text, run a server-side build pipeline to produce a WASM module (instead of native dynamic library), then deliver the `.wasm` to the client to instantiate. 
This removes the need to run `cargo` inside the browser. + +- Pros: + - Preserves compiled performance without requiring users to compile locally. + - Centralizes build toolchain and security scanning. + +- Cons: + - Operational burden (CI/build infra), security (compiling untrusted code), distribution and versioning complexity. + +## Detailed considerations and trade-offs + +1) ABI and data exchange + +- Simple serialized-model approach: exchange IR (JSON, CBOR, or MessagePack). The interpreter reads IR and returns results by JSON objects. +- WASM plugin approach: define a small C-like ABI with a fixed set of exported functions. Example minimal exports from plugin module: + - `plugin_version() -> u32` (ABI version) + - `create_model() -> u32` (returns an opaque handle) + - `free_model(handle: u32)` + - `step_model(handle: u32, t: f64, dt: f64, inputs_ptr: u32, inputs_len: u32, outputs_ptr: u32)` + - `get_metadata_json(ptr_out: u32) -> u32` (returns pointer/length pair or writes into host-provided buffer) + +- Memory management: require plugin to export an allocator or follow a simple convention (host provides memory, plugin uses host-provided functions to allocate). Or use a string/byte convention: exports with pointer and length encoded as two 32-bit values. + +2) Passing complex Rust types + +- Avoid trying to share Rust-specific types (Box, owned struct clones) across WASM boundary. Use a stable, language-neutral representation (JSON, CBOR) for metadata and parameters. + +3) Host imports to plugin + +- Plugins will likely need helper imports from the host: random numbers, logging, panic hooks (or traps), allocation, time. Minimally define these imports and keep them stable. + +4) Security / sandboxing + +- WebAssembly provides sandboxing but host must enforce memory, CPU time, and resource constraints. Approaches: + - Use wasm runtimes (Wasmtime, Wasmer) with configurable memory limits and fuel (instruction count) to interrupt long-running modules. 
+ - In browsers, worker time-limits and cooperative stepping. + - Reject exports/imports that give filesystem or network access unless explicitly trusted via WASI capabilities. + +5) Determinism and numeric behavior + +- Floating-point results may differ across hosts; document expected tolerances and avoid depending on platform-specific FP flags. + +6) Threading and concurrency + +- WASM threads are not yet universally available (shared memory / atomics). The wasm-capable module should not assume threads. If the host supports threads, the interpreter or plugin can optionally use them, gated behind feature detection. + +7) Tooling and developer experience + +- For plugin path: provide a `pharmsol-plugin` crate template that exports the standard ABI and instructions for compiling to wasm (cargo build --target wasm32-unknown-unknown or wasm32-wasi, or use wasm-bindgen if targeting JS). Provide examples for Rust and a plain C approach. +- For interpreter path: provide a serializer that converts existing model description (the same text used by `exa`) into IR. Keep the IR stable and versioned. + +8) Size and startup cost + +- Interpreter binary size depends on evaluator complexity. For browser deployments, keep interpreter lean (avoid heavy crates). For plugin path, each user-provided wasm module will increase download size; caching helps. + +9) Compatibility with existing `exa` API + +- `exa::build` and `exa::load` produce `E: Equation` and `Meta` clones. For wasm, design host-side shims that map from plugin/interpreter results into the same `Equation`/`Meta` trait surface used by the rest of Pharmsol. If the host runtime itself is built as WASM and shares the same Rust codebase, define a small adapter layer that converts the plugin or IR results into `Equation` implementations in the runtime. + +10) Error reporting + +- Prefer textual JSON errors with codes. 
Expose streaming logs during model compilation (if server-assisted build) and during instantiation; for interpreter model parsing, produce structured parse/semantics errors. + +## ABI sketch for a precompiled WASM plugin + +This sketch is intentionally conservative and minimal; if implemented, it should be strongly versioned. + +- Module exports (names and semantics): + + - `plugin_abi_version() -> u32` — numeric ABI version (e.g., 1). + - `plugin_name(ptr: u32, len: u32)` — optional name string (or return pointer/len to host). + - `plugin_create(ptr: u32, len: u32) -> u32` — allocate and return a handle for a model instance created from a JSON blob at (ptr,len) or from internal code. Handle 0 reserved for null/error. + - `plugin_free(handle: u32)` — free instance. + - `plugin_step(handle: u32, t: f64, dt: f64, inputs_ptr: u32, inputs_len: u32, outputs_ptr: u32, outputs_len_ptr: u32) -> i32` — step or evaluate; input/output are serialized arrays or contiguous floats. Return 0 for success or negative error code. + - `plugin_get_metadata(handle: u32, out_ptr_ptr: u32) -> i32` — write a JSON metadata blob to linear memory and return pointer/length via out_ptr_ptr. + +- Host imports (the host should provide these): + + - `host_alloc(size: u32) -> u32` and `host_free(ptr: u32, size: u32)` — optional; otherwise plugin uses its own allocator. + - `host_log(ptr: u32, len: u32, level: u32)` — optional logging. + - `host_random_u32() -> u32` — for deterministic or host-provided RNG. + +Notes: + +- Use string/json for metadata to avoid sharing complex structs. This keeps the plugin language agnostic. +- Use u32 handles and linear memory offsets for safety. 
+ +## Interpreter / serialized IR proposal + +Design a compact IR that expresses: + +- Model metadata (parameters, state variables, initial values, parameter default values) +- Expressions (arithmetic, functions, accessors) — either as an AST or simple stack-based bytecode +- Event definitions or discontinuities (if Pharmsol supports them) + +Representation options: + +- JSON: human readable, easy to debug, larger size. +- CBOR / MessagePack: binary, smaller. +- Custom bytecode: most compact and efficient but takes more work to define and maintain. + +Evaluation engine features: + +- Expression evaluator: compile the AST to a sequence of instructions, then evaluate in a tight loop. +- Integrator interface: provide hooks for integrator state and allow the interpreter to evaluate derivatives; integrators live in the host and call into evaluator for right-hand side computation. +- Caching and JIT-like improvements: precompute evaluation order, constant folding, and expression inlining. + +Why interpreter is attractive: + +- Predictable: no external compilation step. +- Fast to iterate: developer can change model text and send new IR to the runtime without rebuilding. + +Potential downsides: + +- Interpreter complexity can grow if the model language is rich (user functions, closures). Keep the DSL bounded to maintain a fast interpreter. + +## Migration path and compatibility + +1. Define a model-IR serializer in the existing `exa::build` pipeline (native). Add a mode that produces IR instead of a native cdylib. This is low-effort and reuses existing parsing code. +2. Implement the interpreter in the WASM runtime to read IR and produce the `Equation` trait behaviors in the runtime. On native builds, the interpreter can be used as a fallback. +3. Define and publish the plugin WASM ABI and crate template for advanced users. Provide an example repository and CI workflow to produce valid `.wasm` plugins. +4. 
Keep `exa::load` semantics but offer new functions like `load_from_ir` and `load_wasm_plugin` that map to the same `Equation` / `Meta` surfaces. + +## Testing strategy + +- Unit tests: for IR generation and the interpreter expression evaluator. Use a battery of deterministic tests comparing interpreter outputs to native `exa` results. +- Integration tests: run a model end-to-end through host integrators on both native and wasm targets (use wasm-pack test harness or Wasmtime for server-side tests). +- Fuzzing: target parser and evaluator with malformed inputs to catch edge cases and panics. +- Performance benchmarks: compare interpreter vs plugin vs native compiled models; measure startup and per-step costs. + +## Operational concerns and security + +- If server-side compilation is offered, run untrusted compilations in isolated builder sandboxes and scan outputs for known-bad constructs. Prefer user-provided wasm modules or interpreter IR to avoid running arbitrary native build steps on shared infra. +- For wasm plugin hosting, enforce memory limits and instruction fuel limits (Wasmtime fuel, Wasmer middleware, or browser worker timeouts). + +## Suggested project structure (conceptual) + +- `src/exa/` — keep existing build/load for native platforms. +- `src/exa/ir.rs` — (new) IR definitions and serializer (no implementation here; just noting where it would live). +- `src/exa/interpreter/` — interpreter runtime for models. +- `src/exa/wasm_plugin_spec.md` — a short specification (could reference this document). + +Note: the interpreter can be compiled for both native and wasm targets; keep it dependency-light for browser builds. + +## Performance expectations + +- Interpreter: expect some overhead relative to compiled native code. Reasonable target: with a well-optimized evaluator, derivative evaluation should stay within 2–5x of compiled native speed, depending on expression complexity and integrator call frequency. This estimate must be confirmed by measurement.
+- WASM plugin: performance similar to native wasm-compiled code (good), but host bridging (serialization) can add overhead. + +## Example migration scenarios (no code) + +1. Browser: user enters model text in UI -> UI sends model text to server or local serializer -> IR (JSON) returned -> browser Pharmsol wasm runtime loads IR -> interpreter executes model. +2. Server (WASM runtime): Accept `.wasm` plugin from advanced user -> instantiate with Wasmtime with resource limits -> use plugin exports as model implementation. + +## Versioning, compatibility, and future-proofing + +- Version the IR and the plugin ABI separately. Include feature flags in the ABI (capabilities mask) so future extensions don't break older hosts. +- Consider aligning plugin ABI with WASI/component model as it stabilizes. + +## Next steps (recommended minimal roadmap) + +1. Add an `IR` serialization mode to the native `exa::build` pipeline so existing tooling can emit IR instead of or in addition to cdylib. (Low-risk, high-value.) +2. Implement a lightweight interpreter in the Pharmsol core, with an API `load_from_ir` that returns an `Equation`/`Meta` instance usable by existing integrators. Prioritize the feature set required by current users (params, states, derivatives, simple events). +3. Design and publish a versioned WASM plugin ABI, crate-template, and documentation for advanced users. Provide a CI-based example to compile a plugin to `.wasm`. +4. Add tests and a benchmark suite comparing native `exa` dynamic-loading, IR-interpreter, and wasm-plugin performance. 
+ +## Appendix: checklist of engineering and QA items + +- [ ] IR schema definition (JSON Schema or protobuf) +- [ ] Parser changes to emit IR +- [ ] Interpreter design doc + micro-benchmarks +- [ ] WASM plugin ABI spec (short document) +- [ ] Crate template for compiling to `.wasm` +- [ ] Example demonstrations (browser and Wasmtime) +- [ ] Security review and sandbox configuration for server-side builds +- [ ] Documentation for end users and plugin authors + +--- + +If you'd like, I can now proceed to: + +- produce a formal `src/exa/wasm_plugin_spec.md` containing a precise ABI table and memory layout (more low-level and concrete), or +- implement the IR serializer and a first-pass interpreter prototype (in code), or +- draft the crate template and CI steps for producing WASM plugin artifacts. + +Tell me which of the next steps you prefer and I will proceed. If you prefer, I can also update the repository with the spec file and a small README for plugin authors. diff --git a/src/exa_wasm/SPEC.md b/src/exa_wasm/SPEC.md new file mode 100644 index 00000000..e880d4e0 --- /dev/null +++ b/src/exa_wasm/SPEC.md @@ -0,0 +1,299 @@ +# exa_wasm — Interpreter + IR: SPEC, current state, gaps, and recommendations + +Generated from reading the entire `src/exa_wasm` and `src/exa_wasm/interpreter` source. 
+ +This document is structured as: + +- Short contract / goals +- IR format and loader contract +- Parser / AST / typechecker contract +- Evaluator semantics and dispatch contract +- Registry / runtime behavior +- Implemented features (what works today) +- Missing features / gaps to replicate Rust arithmetic/PKPD semantics +- Tests (what exists, what is missing, priorities) +- Detailed optimization recommendations (micro + architectural + WASM targets) +- Migration / next-steps and low-risk improvements + +## Contract (inputs, outputs, success criteria) + +- Inputs + - JSON IR file (emitted by `emit_ir`) containing `ir_version`, `kind`, `params`, textual closures for `diffeq`, `lag`, `fa`, `init`, `out` and pre-extracted structured maps `lag_map`, `fa_map`. + - Simulator vectors at runtime: `x` (states), `p` (params), `rateiv` (rate-in vectors), `t` (time scalar), `cov` (covariates object). +- Outputs + - A registered runtime model (RegistryEntry) and an `ODE` wrapper with dispatch functions that the simulator uses: diffeq_dispatch, lag_dispatch, fa_dispatch, init_dispatch, out_dispatch. + - Runtime return values via writes into `dx[]`, `x[]`, `y[]` through provided assignment closures during dispatch. +- Success criteria + - The interpreter evaluates closure code deterministically and produces numerically equal results (within floating differences) compared to the equivalent native Rust ODE code for the same model text. + - Loader rejects ill-formed IR (missing structured maps for lag/fa, index out of bounds, type errors, unknown identifiers in prelude) with informative errors. 
+ +## IR format and emitter (`build::emit_ir`) + +- `emit_ir` produces a JSON object containing: + - ir_version: "1.0" + - kind: equation kind string (via E::kind()) + - params: vector of parameter names (strings) + - diffeq/out/init/lag/fa: textual closures (strings) supplied by caller + - lag_map/fa_map: structured HashMap extracted from textual macro bodies when present + - prelude: not directly emitted; emit_ir extracts `lag_map` and `fa_map` to avoid runtime parsing of textual `lag!` and `fa!` macros +- Notes: + - The runtime loader requires structured `lag_map` and `fa_map` fields (if textual macros are present in the IR, loader will reject unless maps exist). This is explicit loader behavior. + +## Parser / AST / Typechecker + +- Parser + - `parser::tokenize` tokenizes numeric literals, booleans, identifiers, brackets, braces, parentheses, operators, and punctuation. Supports numeric exponent notation and recognizes `true`/`false` as booleans. + - `parser::Parser` implements a recursive-descent parser supporting: + - expressions: numbers, booleans, identifiers, indexed expressions (e.g., `x[0]`, `rateiv[ i ]`), function calls `f(a,b)`, method calls `obj.method(...)`, unary ops (`-`, `!`), binary ops (`+ - * / ^`, comparisons, `==, !=`, `&&, ||`), ternary `cond ? then : else`. + - statements: expression-statement (`expr;`), assignment (`ident = expr;`), indexed assignment (`ident[expr] = expr;`), `if` with optional `else` and block or single-statement branches. It reads semicolons and braces. +- AST + - `ast::Expr`, `ast::Stmt` capture parsed program structure. `Stmt::Assign(Lhs, Expr)` stores Lhs as `Ident` or `Indexed`. +- Typechecker + - `typecheck` implements a conservative checker: numeric and boolean types, ensures indexed-assignment RHS are numeric, index expressions numeric, and attempts to detect obvious mistakes. It accepts numeric/bool coercions similar to evaluator semantics. 
+ +## Evaluator semantics (`eval.rs`) + +- Runtime Value type: enum { Number(f64), Bool(bool) } with coercion rules: + - `as_number()`: Bool -> 1.0/0.0, Number -> value + - `as_bool()`: Number -> value != 0.0, Bool -> value +- Evaluator (`eval_expr`) implements: + - Identifiers: resolves prefixed underscore names (return 0.0), locals map (prelude/assign locals), pmap-mapped parameters, `t` as time, covariates via interpolation when `cov` and `t` provided. + - Indexed: resolves indexed names for `x`, `p/params`, `rateiv`. Performs bounds checks and sets runtime error when out-of-range. + - Calls: evaluates arguments then `eval_call` handles builtin functions. Unknown function falls back to Number(0.0) (no runtime error). + - Binary ops: arithmetic, comparisons, logical with short-circuit behaviour for `&&`, `||`. + - Ternary: use `cond` coercion and evaluate appropriate branch. + - MethodCall: treated as `eval_call(name)` with receiver as first arg. +- `eval_stmt` executes statements, manages `locals` HashMap for named locals, delegates indexed assignments to a closure provided by dispatchers (which perform safe write to dx,x,y or set runtime error on unsupported names). +- `eval_call` implements a set of builtin functions: exp, ln/log, log2/log10, sqrt, pow/powf, min/max, abs, floor, ceil, round, sin/cos/tan, plus `if` macro-like function used when parsing `if(expr, then, else)` calls — returns second or third argument based on first. + +## Loader and `load_ir_ode` behavior + +- Loads JSON, extracts `params` -> builds `pmap` param name -> index. +- Walks `diffeq`, `out`, `init` closures: + - Prefer robust parsing: tries to extract closure body and parse with `Parser::parse_statements()`. + - Runs `typecheck::check_statements()` and rejects IR with type errors. + - If parsing fails, falls back to substring scanning to extract top-level indexed assignments (helpers `extract_all_assign`) and convert them to minimal AST `Assign` nodes. 
+- Prelude extraction: identifies simple non-indexed `name = expr;` assignments (used as locals) using the heuristic `extract_prelude` helper. +- `lag` and `fa`: loader expects structured `lag_map`/`fa_map` inside IR; will reject IR missing these fields if textual `lag`/`fa` is non-empty (loader no longer supports runtime textual parsing of `lag!{}` macros unless the `lag_map` exists). +- Validation: loader validates indexes, prelude references, fetch_params!/fetch_cov! macro bodies (basic checks), ensures at least some dx assignments exist. +- On success, constructs `RegistryEntry` containing parsed statements, lag/fa expressions, prelude list, pmap, nstates, nouteqs and registers it in `registry`. + +## Registry / Dispatch contract + +- Registry stores `RegistryEntry` in a global HashMap protected by a Mutex. Entries are referenced by generated `usize` ids. +- `CURRENT_EXPR_ID` is a thread-local `Cell` holding the current entry id; dispatch functions read it to determine which entry to execute. +- Dispatch functions (`dispatch.rs`): + - `diffeq_dispatch` runs prelude assignments producing locals, then executes `diffeq_stmts` using `eval_stmt` with an assign closure that allows `dx[index] = value` only. Unsupported indexed assignment names cause runtime error. + - `out_dispatch`: executes `out_stmts` allowing writes to `y[index]` only. + - `lag_dispatch`/`fa_dispatch`: evaluate entries in `lag`/`fa` maps using zeros for x/rateiv and return a HashMap of numeric results. + - `init_dispatch`: executes `init_stmts` allowing writes to `x[index]`. +- Registry exposes: `register_entry`, `unregister_model`, `get_entry`, `ode_for_id` and helper functions to get/set current id and runtime error. + +## Current implemented features (summary) + +- Fully working tokenizer and parser for numeric and boolean expressions, calls, indexing, unary/binary ops, ternary, and `if` statement (with blocks/else). +- Conservative typechecker that catches common type mismatches and forbids assigning boolean to indexed state targets.
+- Evaluator with the following key features: + - numeric arithmetic (+ - \* / ^) + - comparisons and boolean ops with short-circuiting + - large set of math builtins: exp, log, ln, log2, log10, sqrt, pow/powf, min/max, abs, floor, ceil, round, sin, cos, tan + - function-call semantics and method-call mapping (receiver passed as first arg) + - identifier resolution: params via `pmap`, locals (prelude) and `t` time + - covariate interpolation support (uses Covariates.interpolate when available) + - indexed access for `x`, `p/params`, `rateiv` with bounds checks. +- Loader: robust multi-mode loader that prefers AST parsing but falls back to substring extraction for simple assignment patterns; prelude extraction and fetch macro validation exist. +- Registry and dispatch wiring: models are registered and produce an `ODE` with dispatch closures that the rest of simulator can call. +- Tests exist that exercise tokenizer, parser, loader fallback, and small loader behaviors. + +## Concrete current tests (found in repository) + +- `src/exa_wasm/mod.rs::tests` + - `test_tokenize_and_parse_simple()` — tokenizes, parses simple expr and evaluates with dummy vectors. + - `test_macro_parsing_load_ir()` — ensures emit_ir produces an IR loadable by `load_ir_ode` (uses `lag!{...}` macro parsing in emit_ir and loader). + - `load_negative_tests::test_loader_errors_when_missing_structured_maps()` — asserts loader rejects IR that provides `lag`/`fa` textual form without `lag_map`/`fa_map`. +- `src/exa_wasm/interpreter/loader.rs::tests` + - Tests for `extract_body` and parsing `if true/false` patterns, ensuring parser normalizes boolean literals and retains top-level `dx` assignment detection. +- `src/exa_wasm/interpreter/mod.rs` includes tests that exercise parser/eval integration. 
+ +## Gaps / Missing functionality (to get closer to full Rust-equivalent arithmetic semantics for PK/PD) + +- Language features missing or limited + - No loops (for/while) or `break`/`continue` constructs — many iterative PKPD constructs sometimes use loops for accumulation or vector operations. + - No block-scoped `let` declarations beyond very small prelude heuristics; `extract_prelude` is conservative and the loader_helpers prelude extraction is a stub in places. + - No support for compound assignment (+=, -=, etc.). + - No support for full macros evaluated at runtime — macros are partially stripped, but complex macro bodies must be processed by emitter (emit_ir) into structured maps. + - No user-land function definition; all functions are builtin only. + - String handling is absent (not needed for arithmetic but relevant for diagnostics). +- Numeric & semantic gaps + - No direct handling of NaN/Inf semantics or explicit domain errors (e.g., log of negative) — evaluator will produce f64 results per Rust but may not raise semantic runtime errors. + - `eval_call` returns Number(0.0) for unknown functions with no runtime error — this hides mistakes (recommend change). + - Limited builtins: missing many mathematical and special functions (erf, erfc, gamma, lgamma, erf_inv, special logistic forms, etc.) commonly used in PKPD. + - No vectorized operations / broadcast: expressions that operate on vectors must be written explicitly with indices. No map/reduce primitives. +- Loader / IR gaps + - Loader does substring scanning for fallbacks — fragile for complex code. The `loader_helpers` module contains stubs (extract_fetch_params, extract_prelude etc.) that are incomplete. + - The runtime requires structured `lag_map` / `fa_map` in IR. emit_ir tries to produce them but tooling that emits IR must be dependable; otherwise loader rejects. 
+ - Pre-resolved param indices: while `pmap` exists on entry, expressions still contain identifier strings in AST rather than resolved index nodes; runtime resolves via pmap hash lookups on each identifier resolution. +- Performance / architecture + - Evaluation uses boxed enums `Value` + recursion + many HashMap lookups for locals and pmap -> hot-path overhead. + - Every identifier resolves via HashMap lookup; locals and pmap lookups happen at runtime repeatedly; branch mispredicts / hash overhead. + - No bytecode or compact expression representation; AST walking is interpreted per-evaluation. + - No precomputation (constant folding) beyond tokenization. +- Safety / ergonomics + - `eval_call` swallowing unknown functions is a usability and correctness risk. + - Runtime errors are stored thread-local but no structured diagnostics with expression positions or model id are emitted. + +## Recommended missing features prioritized + +High priority (for correctness and replication fidelity) + +- Make unknown function calls produce loader or runtime error (not silent Number(0.0)). This will catch typos in IR and user errors. +- Fully implement macro extraction and prelude parsing in `loader_helpers` so loader does not rely on fragile substring heuristics. Emit resolved AST or bytecode from `emit_ir`. +- Resolve parameter identifier -> index mapping during load: transform identifier AST nodes representing params into a param-index variant (avoid hash lookup at runtime). Same for covariates and other well-known identifiers. +- Validate and canonicalize all index expressions at load time when possible (e.g., constant numeric indices), so runtime dispatch can avoid repeated checks. +- Replace textual-scanning-based helper heuristics with parser-driven extraction where possible (safer for complex code). +- Centralize evaluator's builtin lookup to use builtins.rs (we already use builtins in the typechecker; ensure eval and dispatch use the same single source of truth). 
+- Add unit tests specifically for loader_helpers functions (parse/macro/extraction/validation) to lock-in behavior. +- Add richer error reporting in loader to return structured loader errors (instead of just io::Error with a string) — implement a LoaderError enum that carries TypeError variants and positional info. + +Medium priority (performance / robustness) + +- Add constant folding and simple expression simplifications at load-time. +- Add a simple bytecode (or expression tree) compile step that converts AST into a compact opcode sequence. Implement a small fast interpreter for bytecode. +- Replace `Value` enum with raw f64 in arithmetic paths; booleans can be represented as 0.0/1.0 where appropriate and only coerced when needed — remove boxing in hot path. +- Convert locals from HashMap to an indexed local slot vector created at load-time (map local name -> slot index) and bind to a small Vec at runtime for O(1) access. + +Lower priority (feature expansion) + +- Add additional math builtins used in pharmacometrics: `erf`, `erfc`, `gamma`/`lgamma`, `beta`, special integrals, logistic and Hill functions, `sign`, `clamp`. +- Add explicit error handling primitives and optional runtime checks for domain errors. +- Add optional JIT or WASM codegen path: emit precompiled WASM modules (via Cranelift/wasmtime or hand-rolled emitter) for performance. + +## Detailed optimization recommendations (nuanced) + +These are grouped as quick wins, structural improvements, and advanced options. + +1. Quick wins (safe, low risk) + +- Change `eval_call` behavior: unknown function => set runtime error + return 0.0 or NaN — do not silently return 0.0. This is a correctness fix. +- Convert repeated HashMap lookups for `pmap`/locals into precomputed indices when possible: + - When loading, scan AST for identifier usage: if identifier is a param -> replace with AST node ParamIndex(idx). 
For local names produced by `prelude` extraction, create local slots with indices and rewrite `Ident(name)` to `Local(slot)` where possible. + - Keep a small structure per `RegistryEntry` describing local name->slot mapping. +- Local slots: replace the name-keyed locals `HashMap` with a `Vec` of slots, and keep a name -> slot-index `HashMap` only at load-time; the runtime then indexes directly into the `Vec`. +- Replace `Value` enum in arithmetic evaluation with direct `f64` passing: the only places booleans are needed are logical operators and conditionals; implement `eval_expr_f64` in the hot path that returns f64 and handles boolean contexts by testing `value != 0.0`. Keep a separate boolean evaluation path for `&&/||`. + +2. Structural improvements (medium complexity) + +- Implement an AST -> bytecode compiler: + - Bytecode opcodes: PUSH_CONST(i), LOAD_PARAM(i), LOAD_X(i), LOAD_RATEIV(i), LOAD_LOCAL(i), LOAD_T, CALL_FN(idx), UNARY_NEG, BINARY_ADD, BINARY_MUL, CMP_LT, JUMP_IF_FALSE, ASSIGN_LOCAL(i), ASSIGN_INDEXED(base, idxSlotOrConst), ... + - Pre-resolve function names to small function-table indices at load time to avoid string comparisons per-call. + - Implement a small stack-based VM executor that executes opcodes efficiently using raw f64 and direct array accesses. + - Generate specialized op sequences for `dx`/`x`/`y` assignments to remove runtime string comparison for assignment target. +- Implement constant folding & CSE at compile-time: fold arithmetic on constants and simple algebraic simplifications. +- Implement expression inlining for small functions (if/when user-defined functions are introduced) and partial-evaluation with param constants. + +3. Advanced gains (higher risk / more work) + +- WASM codegen: compile bytecode to WASM functions (either as textual .wat generation or via Cranelift) and instantiate a WASM module that exports the evaluate functions. This yields near-native speed in WASM hosts but increases code complexity.
+- JIT to native code: with Cranelift generate machine code for hot expressions — requires careful memory/safety handling, but huge speedups are possible. +- SIMD / vectorization: for models that do repeated elementwise ops across vectors, provide a vectorized runtime or generate WASM SIMD instructions. + +4. Memory and concurrency + +- Ensure registry APIs allow safe concurrency: current EXPR_REGISTRY uses Mutex; consider RwLock if reads dominate. +- Provide lifecycle APIs: `drop_model(id)` and ensure no lingering references; add reference counts if ODE objects can outlive registry removal. + +5. Numeric stability + +- Use f64 consistently but consider `fma` (fused multiply-add) via libm if available for certain patterns. +- Add optional runtime checks for over/underflow and domain errors that can be enabled by a debug mode when running models. + +## Tests: what exists, what to add (granular) + +Existing tests (detected): + +- Parser & tokenizer correctness: many tests in `interpreter/loader.rs` and module-level tests. +- Loader negative test: missing structured maps rejection. +- Parser/If normalization tests: ensure `true` => `1.0` and `false` => `0.0`, and `if` constructs parsed and converted properly. + +Missing tests (priority ordered) + +1. High priority correctness tests + +- Numeric equivalence tests: For a set of representative models, compare outputs of native Rust ODE vs exa_wasm ODE for a range of times and parameter vectors. (Property-based or fixture-based) +- Unknown function handling: ensure loader/runtime errors for unknown function names (after implementing the fix above). +- Parameter resolution: ensure params referred in code map to correct p indices and produce same numeric results as native extraction. +- Index bounds: negative/large indices should produce loader or runtime errors. +- Prelude ordering: test cases where prelude assignment depends on earlier prelude variables; ensure order respected. + +2. 
Medium priority behavioral tests + +- Logical short-circuit: ensure `&&`/`||` do not evaluate RHS when LHS decides. +- Ternary and `if()` builtin parity: ensure both mechanisms yield same results. +- Covariate interpolation behavior: tests covering valid/invalid times and missing covariate names. +- Lag/fa maps: ensure `lag_map` values are used and loader rejects textual-only forms. + +3. Performance & regression tests + +- Microbenchmarks: measure hot path eval time for simple arithmetic expressions vs AST bytecode VM vs native function pointer version. +- Stress tests for registry: many load/unload cycles to check for leaks and correctness. + +4. Fuzz / edge cases + +- Random expression fuzzing to ensure parser doesn't panic and loader returns acceptable error messages. +- Numeric edge cases: division by zero, log negative, pow with non-integer exponents of negative values — ensure predictable behavior or documented errors. + +Suggested test harness additions + +- A small test runner that loads a set of model pairs (native and IR) and asserts predictions and likelihoods match within tolerance — this can be used in CI. +- Use `approx` crate for floating comparisons with relative and absolute tolerances. + +## Low-risk, high-value immediate changes (implementation steps) + +1. Change `eval_call` to report unknown function names as errors. +2. Implement param-id -> ParamIndex AST node and rewrite AST at load-time to resolve `Ident` representing parameters. +3. Replace locals HashMap with Vec slots and a local-name->slot map produced at load-time. +4. Add unit tests to assert that unknown functions trigger loader/runtime errors. + +## Longer-term plan (roadmap) + +- Phase 1 (0-2 weeks): correctness fixes and small refactors + - Unknown function error, param resolution, local slots, implement more loader_helpers to remove substring heuristics. + - Add tests that assert numeric parity for a few canonical ODEs. 
+- Phase 2 (2-6 weeks): interpreter performance + - AST -> bytecode compiler, VM runtime, constant folding, pre-resolved function table. + - Add microbenchmarks and CI perf checks. +- Phase 3 (6+ weeks): WASM/native code generation + - Emit precompiled WASM modules for hot models and add runtime switches: interpret vs wasm vs native. + - Investigate JIT via Cranelift for server-side/back-end tooling. + +## Developer notes and rationale + +- The current code prioritizes correctness and simplicity over raw performance: AST parsing and `eval_expr` recursion are straightforward and robust, and loader performs conservative validation to avoid silent miscompilation. +- The main friction points are runtime hash lookups and string-based resolution of identifiers, and the presence of fallback substring parsing in loader which is fragile for complex closures. +- An incremental approach (resolve param/local names at load-time and add a small bytecode interpreter) yields excellent benefit/cost ratio before pursuing WASM or JIT compilation. + +## Recommended SPEC additions to the IR (future) + +- Add resolved metadata fields per expression emitted by `emit_ir`: + - `pmap` (already present at loader) but also an AST/bytecode serialization (e.g., base64 compressed bytecode) so the runtime doesn't need to re-parse expressions. + - `funcs`: list of builtin functions used so loader can validate and map to indexes. + - `locals`: prelude local names and evaluation order. + - `constants`: constant table for deduping floats. + +## Security / safety considerations + +- Evaluating arbitrary IR should be considered untrusted input if IR comes from external sources. Prefer to validate and sandbox execution. The current interpreter runs in-process with no sandboxing; emitting compiled WASM to a WASM runtime (wasmtime) provides stronger isolation if needed. 
+ +## Quick checklist summary (what changed / what to do next) + +- I inspected and documented all files in `src/exa_wasm` and `src/exa_wasm/interpreter`. +- Next, implement the high-priority fixes described above (unknown-function errors, AST param resolution, local slot mapping) and add the corresponding unit tests. + +--- + +If you'd like, I can: + +- Open a follow-up PR that implements the first low-risk fixes (unknown function -> error, param resolution rewrite, local slots), with unit tests and benchmarks. +- Or, generate the initial bytecode VM design and a minimal implementation for one expression type (binary arithmetic) so you can see the performance improvement baseline. + +Tell me which follow-up you prefer and I'll implement it (I will update the todo list and write the code + tests). diff --git a/src/exa_wasm/build.rs b/src/exa_wasm/build.rs new file mode 100644 index 00000000..54c5e287 --- /dev/null +++ b/src/exa_wasm/build.rs @@ -0,0 +1,1131 @@ +use std::env; +use std::fs; +use std::io; +use std::path::PathBuf; + +use rand::Rng; +use rand_distr::Alphanumeric; + +/// Emit a minimal JSON IR for a model (WASM-friendly emitter). +pub fn emit_ir( + diffeq_txt: String, + lag_txt: Option, + fa_txt: Option, + init_txt: Option, + out_txt: Option, + output: Option, + params: Vec, +) -> Result { + use serde_json::json; + use std::collections::HashMap; + + // Extract structured lag/fa maps from macro text so the runtime does not + // need to re-parse macro bodies. These will be empty maps if not present. 
/// Collect the `index => expression` entries of every `mac { ... }` invocation in `src`.
///
/// `mac` is the macro name including the bang (e.g. `"lag!"`). For each
/// occurrence, the text between the next `{` and its matching `}` is split on
/// top-level commas (commas nested inside `()`, `{}` or `[]` do not split), and
/// each piece is read as `index => expression`.
///
/// Returns a map from the parsed index to the trimmed expression text. Entries
/// without `=>` or with a non-numeric index are skipped. Invocations whose
/// braces are missing or unbalanced contribute nothing. When the same index
/// appears more than once, the last occurrence wins.
fn extract_macro_map(src: &str, mac: &str) -> HashMap<usize, String> {
    /// Record one `index => expression` entry; malformed entries are ignored.
    fn insert_entry(map: &mut HashMap<usize, String>, entry: &str) {
        // Split on the FIRST `=>` only so the expression may itself contain `=>`.
        if let Some((key, value)) = entry.split_once("=>") {
            if let Ok(index) = key.trim().parse::<usize>() {
                map.insert(index, value.trim().to_string());
            }
        }
    }

    /// Byte offset of the `}` closing the `{` located at byte offset `open`.
    fn matching_brace(src: &str, open: usize) -> Option<usize> {
        let mut depth = 0usize;
        for (offset, ch) in src[open..].char_indices() {
            match ch {
                '{' => depth += 1,
                '}' => {
                    // `src[open]` is `{`, so depth is at least 1 here.
                    depth -= 1;
                    if depth == 0 {
                        return Some(open + offset);
                    }
                }
                _ => {}
            }
        }
        None
    }

    let mut map = HashMap::new();
    // An empty needle matches at every offset and would never advance the cursor.
    if mac.is_empty() {
        return map;
    }

    let mut cursor = 0usize;
    while let Some(rel) = src[cursor..].find(mac) {
        let start = cursor + rel;
        let braces = src[start..]
            .find('{')
            .map(|off| start + off)
            .and_then(|lb| matching_brace(src, lb).map(|rb| (lb, rb)));

        let Some((lb, rb)) = braces else {
            // No balanced body for this occurrence: step past the macro name.
            cursor = start + mac.len();
            continue;
        };

        let body = &src[lb + 1..rb];
        let (mut paren, mut brace, mut bracket) = (0isize, 0isize, 0isize);
        let mut entry_start = 0usize;
        for (i, ch) in body.char_indices() {
            match ch {
                '(' => paren += 1,
                ')' => paren -= 1,
                '{' => brace += 1,
                '}' => brace -= 1,
                '[' => bracket += 1,
                ']' => bracket -= 1,
                ',' if paren == 0 && brace == 0 && bracket == 0 => {
                    insert_entry(&mut map, &body[entry_start..i]);
                    entry_start = i + 1;
                }
                _ => {}
            }
        }
        // Trailing entry (no comma after the last item).
        insert_entry(&mut map, &body[entry_start..]);

        cursor = rb + 1;
    }
    map
}
+ let mut diffeq_ast_val = serde_json::Value::Null; + let mut out_ast_val = serde_json::Value::Null; + let mut init_ast_val = serde_json::Value::Null; + + // Build param -> index map + let mut pmap: std::collections::HashMap = std::collections::HashMap::new(); + for (i, n) in params.iter().enumerate() { + pmap.insert(n.clone(), i); + } + + // helper to parse a closure text into Vec + // This emitter requires closures to parse successfully; if parsing fails + // we return an error rather than emitting textual closures. That lets the + // runtime rely on a single robust pipeline (AST + bytecode) instead of + // fragile textual fallbacks. + fn try_parse_and_rewrite( + src: &str, + pmap: &std::collections::HashMap, + ) -> Option> { + if let Some(body) = crate::exa_wasm::interpreter::extract_closure_body(src) { + // remove fetch_* macro invocations from the body so the parser + // (which doesn't understand macros) can parse the remaining + // statements. We strip the entire macro invocation including + // its balanced parentheses or braces. 
+ let mut cleaned = body.to_string(); + let macro_names = ["fetch_params!", "fetch_param!", "fetch_cov!"]; + for mac in macro_names.iter() { + loop { + if let Some(pos) = cleaned.find(mac) { + // find next delimiter '(' or '{' after the macro name + let after = pos + mac.len(); + if after >= cleaned.len() { + cleaned.replace_range(pos..after, ""); + break; + } + let ch = cleaned.as_bytes()[after] as char; + if ch == '(' || ch == '{' { + let open = ch; + let close = if open == '(' { ')' } else { '}' }; + let mut depth: isize = 0; + let mut i = after; + let mut found: Option = None; + while i < cleaned.len() { + let c = cleaned.as_bytes()[i] as char; + if c == open { + depth += 1; + } else if c == close { + depth -= 1; + if depth == 0 { + found = Some(i); + break; + } + } + i += 1; + } + if let Some(rb) = found { + // remove from pos..=rb + cleaned.replace_range(pos..=rb, ""); + continue; + } else { + // nothing balanced: remove macro name only + cleaned.replace_range(pos..after + 1, ""); + continue; + } + } else { + // no delimiter: remove the macro token only + cleaned.replace_range(pos..after, ""); + continue; + } + } + break; + } + } + + // tidy up stray semicolons resulting from macro removals + while cleaned.contains(";;") { + cleaned = cleaned.replace(";;", ";"); + } + // remove leading whitespace and any stray leading semicolon left by macro removal + cleaned = cleaned.trim_start().to_string(); + if cleaned.starts_with(';') { + cleaned = cleaned[1..].to_string(); + } + // cleaned closure ready for tokenization + let toks = crate::exa_wasm::interpreter::tokenize(&cleaned); + let mut p = crate::exa_wasm::interpreter::Parser::new(toks); + let mut stmts = match p.parse_statements() { + Some(s) => s, + None => return None, + }; + + // rewrite identifiers that refer to parameters into Param(index) + fn rewrite_expr(e: &mut crate::exa_wasm::interpreter::Expr, pmap: &std::collections::HashMap) { + use crate::exa_wasm::interpreter::Expr::*; + match e { + Number(_) | 
Bool(_) | Param(_) => {} + Ident(name) => { + if let Some(i) = pmap.get(name) { + *e = Param(*i); + } + } + Indexed(_, idx) => { + rewrite_expr(idx, pmap); + } + UnaryOp { rhs, .. } => rewrite_expr(rhs, pmap), + BinaryOp { lhs, rhs, .. } => { + rewrite_expr(lhs, pmap); + rewrite_expr(rhs, pmap); + } + Call { args, .. } => { + for a in args.iter_mut() { rewrite_expr(a, pmap); } + } + MethodCall { receiver, args, .. } => { + rewrite_expr(receiver, pmap); + for a in args.iter_mut() { rewrite_expr(a, pmap); } + } + Ternary { cond, then_branch, else_branch } => { + rewrite_expr(cond, pmap); + rewrite_expr(then_branch, pmap); + rewrite_expr(else_branch, pmap); + } + } + } + + fn rewrite_stmt(s: &mut crate::exa_wasm::interpreter::Stmt, pmap: &std::collections::HashMap) { + match s { + crate::exa_wasm::interpreter::Stmt::Expr(e) => rewrite_expr(e, pmap), + crate::exa_wasm::interpreter::Stmt::Assign(lhs, rhs) => { + match lhs { + crate::exa_wasm::interpreter::Lhs::Indexed(_, idx) => rewrite_expr(idx, pmap), + crate::exa_wasm::interpreter::Lhs::Ident(_) => {} + } + rewrite_expr(rhs, pmap); + } + crate::exa_wasm::interpreter::Stmt::Block(v) => { + for st in v.iter_mut() { rewrite_stmt(st, pmap); } + } + crate::exa_wasm::interpreter::Stmt::If { cond, then_branch, else_branch } => { + rewrite_expr(cond, pmap); + rewrite_stmt(then_branch, pmap); + if let Some(eb) = else_branch { rewrite_stmt(eb, pmap); } + } + } + } + + for st in stmts.iter_mut() { rewrite_stmt(st, pmap); } + return Some(stmts); + } + None + } + + if let Some(stmts) = try_parse_and_rewrite(&diffeq_txt, &pmap) { + diffeq_ast_val = serde_json::to_value(&stmts).unwrap_or(serde_json::Value::Null); + } + if let Some(stmts) = try_parse_and_rewrite(out_txt.as_deref().unwrap_or(""), &pmap) { + out_ast_val = serde_json::to_value(&stmts).unwrap_or(serde_json::Value::Null); + } + if let Some(stmts) = try_parse_and_rewrite(init_txt.as_deref().unwrap_or(""), &pmap) { + init_ast_val = 
/// Extract the argument text of every `name(...)` macro invocation in `src`.
///
/// `name` is the macro name including the bang (e.g. `"fetch_params!"`). An
/// occurrence only counts as an invocation when the name is followed
/// (optionally after whitespace) by an opening delimiter. `(`, `{` and `[` are
/// all accepted, matching the delimiter forms handled when fetch macros are
/// stripped before parsing. The returned strings are the raw text between the
/// delimiters, in source order, with nested delimiters of the same kind kept
/// intact.
///
/// Occurrences without a delimiter, or whose delimiter is never closed, are
/// skipped.
fn extract_fetch_bodies(src: &str, name: &str) -> Vec<String> {
    let mut bodies = Vec::new();
    // An empty needle matches everywhere and would never advance.
    if name.is_empty() {
        return bodies;
    }

    let mut rest = src;
    while let Some(pos) = rest.find(name) {
        let after_name = &rest[pos + name.len()..];
        let args = after_name.trim_start();

        // Only a delimiter directly after the name makes this an invocation;
        // otherwise an unrelated `(` later in the text would be captured.
        let (open, close) = match args.chars().next() {
            Some('(') => ('(', ')'),
            Some('{') => ('{', '}'),
            Some('[') => ('[', ']'),
            _ => {
                rest = after_name;
                continue;
            }
        };

        // The delimiter is ASCII, so byte offset 1 is a valid char boundary.
        let inner = &args[1..];
        let mut depth = 0usize;
        let mut end = None;
        for (i, ch) in inner.char_indices() {
            if ch == open {
                depth += 1;
            } else if ch == close {
                if depth == 0 {
                    end = Some(i);
                    break;
                }
                depth -= 1;
            }
        }

        match end {
            Some(rb) => {
                bodies.push(inner[..rb].to_string());
                rest = &inner[rb + 1..];
            }
            // Unbalanced invocation: step past the name and keep scanning.
            None => rest = after_name,
        }
    }
    bodies
}
"fetch_param!", + )); + fetch_params_bodies.extend(extract_fetch_bodies( + init_txt.as_deref().unwrap_or(""), + "fetch_params!", + )); + fetch_params_bodies.extend(extract_fetch_bodies( + init_txt.as_deref().unwrap_or(""), + "fetch_param!", + )); + + let mut fetch_cov_bodies: Vec = Vec::new(); + fetch_cov_bodies.extend(extract_fetch_bodies(&diffeq_txt, "fetch_cov!")); + fetch_cov_bodies.extend(extract_fetch_bodies( + out_txt.as_deref().unwrap_or(""), + "fetch_cov!", + )); + fetch_cov_bodies.extend(extract_fetch_bodies( + init_txt.as_deref().unwrap_or(""), + "fetch_cov!", + )); + + if !fetch_params_bodies.is_empty() { + ir_obj["fetch_params"] = + serde_json::to_value(&fetch_params_bodies).unwrap_or(serde_json::Value::Null); + } + if !fetch_cov_bodies.is_empty() { + ir_obj["fetch_cov"] = + serde_json::to_value(&fetch_cov_bodies).unwrap_or(serde_json::Value::Null); + } + + // Compile expressions into bytecode. This compiler covers numeric + // literals, Param(i), simple indexed loads with constant indices (x/p/rateiv), + // locals, unary -, binary ops, calls to known builtins, and ternary. 
+ fn compile_expr_top( + expr: &crate::exa_wasm::interpreter::Expr, + out: &mut Vec, + funcs: &mut Vec, + locals: &Vec, + ) -> bool { + use crate::exa_wasm::interpreter::{Expr, Opcode}; + match expr { + Expr::Number(n) => { + out.push(Opcode::PushConst(*n)); + true + } + Expr::Bool(b) => { + out.push(Opcode::PushConst(if *b { 1.0 } else { 0.0 })); + true + } + Expr::Param(i) => { + out.push(Opcode::LoadParam(*i)); + true + } + Expr::Ident(name) => { + // treat as local if present + if let Some(pos) = locals.iter().position(|n| n == name) { + out.push(Opcode::LoadLocal(pos)); + true + } else { + // unknown bare identifier — compilation fails; loader will + // catch this earlier via typecheck, but be conservative here + false + } + } + Expr::Indexed(name, idx_expr) => { + // support constant numeric indices and dynamic indices + if let Expr::Number(n) = &**idx_expr { + let idx = *n as usize; + match name.as_str() { + "x" => { + out.push(Opcode::LoadX(idx)); + true + } + "rateiv" => { + out.push(Opcode::LoadRateiv(idx)); + true + } + "p" | "params" => { + out.push(Opcode::LoadParam(idx)); + true + } + _ => false, + } + } else { + // dynamic index: compile index expression then emit a dyn-load + if !compile_expr_top(idx_expr, out, funcs, locals) { + return false; + } + match name.as_str() { + "x" => { + out.push(Opcode::LoadXDyn); + true + } + "rateiv" => { + out.push(Opcode::LoadRateivDyn); + true + } + "p" | "params" => { + out.push(Opcode::LoadParamDyn); + true + } + _ => false, + } + } + } + Expr::UnaryOp { op, rhs } => { + if op == "-" { + if !compile_expr_top(rhs, out, funcs, locals) { + return false; + } + // multiply by -1.0 + out.push(Opcode::PushConst(-1.0)); + out.push(Opcode::Mul); + true + } else { + false + } + } + Expr::BinaryOp { lhs, op, rhs } => { + // handle short-circuit logical operators specially so we + // preserve AST semantics (avoid evaluating rhs when not + // necessary). 
For arithmetic/comparison operators we + // compile both sides in order. + match op.as_str() { + "&&" => { + // lhs && rhs -> if lhs==0.0 jump to push 0.0; else evaluate rhs and return rhs!=0 as 0/1 + if !compile_expr_top(lhs, out, funcs, locals) { + return false; + } + // JumpIfFalse to false path if lhs is false + let jf_pos = out.len(); + out.push(Opcode::JumpIfFalse(0)); + + // evaluate rhs + if !compile_expr_top(rhs, out, funcs, locals) { + return false; + } + // if rhs is false -> push 0, else push 1 + let jf2_pos = out.len(); + out.push(Opcode::JumpIfFalse(0)); + // rhs true -> push 1 + out.push(Opcode::PushConst(1.0)); + // jump to end + let jmp_pos = out.len(); + out.push(Opcode::Jump(0)); + // false path + let false_pos = out.len(); + // set first JumpIfFalse target to false_pos + if let Opcode::JumpIfFalse(ref mut addr) = out[jf_pos] { + *addr = false_pos; + } + // push 0.0 for false + out.push(Opcode::PushConst(0.0)); + // fix jumps + let end_pos = out.len(); + if let Opcode::Jump(ref mut addr) = out[jmp_pos] { + *addr = end_pos; + } + if let Opcode::JumpIfFalse(ref mut addr) = out[jf2_pos] { + *addr = false_pos; + } + true + } + "||" => { + // lhs || rhs -> if lhs != 0 -> push 1 and skip rhs; else evaluate rhs and return rhs!=0 as 0/1 + if !compile_expr_top(lhs, out, funcs, locals) { + return false; + } + // if lhs is false, evaluate rhs; JumpIfFalse should jump to rhs + let jf_pos = out.len(); + out.push(Opcode::JumpIfFalse(0)); + // lhs true -> push 1 + out.push(Opcode::PushConst(1.0)); + // jump to end + let jmp_pos = out.len(); + out.push(Opcode::Jump(0)); + // else/rhs path + let else_pos = out.len(); + if let Opcode::JumpIfFalse(ref mut addr) = out[jf_pos] { + *addr = else_pos; + } + // evaluate rhs + if !compile_expr_top(rhs, out, funcs, locals) { + return false; + } + // now convert rhs to 0/1 + let jf2_pos = out.len(); + out.push(Opcode::JumpIfFalse(0)); + out.push(Opcode::PushConst(1.0)); + let jmp2 = out.len(); + out.push(Opcode::Jump(0)); + 
let false_pos = out.len(); + if let Opcode::JumpIfFalse(ref mut addr) = out[jf2_pos] { + *addr = false_pos; + } + out.push(Opcode::PushConst(0.0)); + let end_pos = out.len(); + if let Opcode::Jump(ref mut addr) = out[jmp_pos] { + *addr = end_pos; + } + if let Opcode::Jump(ref mut addr) = out[jmp2] { + *addr = end_pos; + } + true + } + _ => { + // default: arithmetic/comparison operators compile lhs then rhs + if !compile_expr_top(lhs, out, funcs, locals) { + return false; + } + if !compile_expr_top(rhs, out, funcs, locals) { + return false; + } + match op.as_str() { + "+" => out.push(Opcode::Add), + "-" => out.push(Opcode::Sub), + "*" => out.push(Opcode::Mul), + "/" => out.push(Opcode::Div), + "^" => out.push(Opcode::Pow), + "<" => out.push(Opcode::Lt), + ">" => out.push(Opcode::Gt), + "<=" => out.push(Opcode::Le), + ">=" => out.push(Opcode::Ge), + "==" => out.push(Opcode::Eq), + "!=" => out.push(Opcode::Ne), + _ => return false, + } + true + } + } + } + Expr::Call { name, args } => { + // only compile known builtins and check arity + if crate::exa_wasm::interpreter::is_known_function(name.as_str()) { + // verify arity where possible + if let Some(rng) = crate::exa_wasm::interpreter::arg_count_range(name.as_str()) + { + if !rng.contains(&args.len()) { + return false; + } + } + // compile args + for a in args.iter() { + if !compile_expr_top(a, out, funcs, locals) { + return false; + } + } + // register function in funcs table + let idx = match funcs.iter().position(|f| f == name) { + Some(i) => i, + None => { + funcs.push(name.clone()); + funcs.len() - 1 + } + }; + out.push(Opcode::CallBuiltin(idx, args.len())); + true + } else { + false + } + } + Expr::MethodCall { + receiver, + name, + args, + } => { + // lower method call to function with receiver as first arg + if crate::exa_wasm::interpreter::is_known_function(name.as_str()) { + // verify arity for method-style calls + if let Some(rng) = crate::exa_wasm::interpreter::arg_count_range(name.as_str()) + { + if 
!rng.contains(&(1 + args.len())) { + return false; + } + } + if !compile_expr_top(receiver, out, funcs, locals) { + return false; + } + for a in args.iter() { + if !compile_expr_top(a, out, funcs, locals) { + return false; + } + } + let idx = match funcs.iter().position(|f| f == name) { + Some(i) => i, + None => { + funcs.push(name.clone()); + funcs.len() - 1 + } + }; + out.push(Opcode::CallBuiltin(idx, 1 + args.len())); + true + } else { + false + } + } + Expr::Ternary { + cond, + then_branch, + else_branch, + } => { + // compile cond + if !compile_expr_top(cond, out, funcs, locals) { + return false; + } + // emit JumpIfFalse to else + let jf_pos = out.len(); + out.push(Opcode::JumpIfFalse(0)); + // then + if !compile_expr_top(then_branch, out, funcs, locals) { + return false; + } + // jump over else + let jmp_pos = out.len(); + out.push(Opcode::Jump(0)); + // fix JumpIfFalse target + let else_pos = out.len(); + if let Opcode::JumpIfFalse(ref mut addr) = out[jf_pos] { + *addr = else_pos; + } + // else + if !compile_expr_top(else_branch, out, funcs, locals) { + return false; + } + // fix Jump target + let end_pos = out.len(); + if let Opcode::Jump(ref mut addr) = out[jmp_pos] { + *addr = end_pos; + } + true + } + } + } + + let mut bytecode_map: HashMap> = + HashMap::new(); + // shared tables discovered during compilation + let mut shared_funcs: Vec = Vec::new(); + let mut shared_locals: Vec = Vec::new(); + + if let Some(v) = ir_obj.get("diffeq_ast") { + // try to deserialize back into AST + match serde_json::from_value::>(v.clone()) { + Ok(stmts) => { + // collect local variable names from non-indexed assignments + for st in stmts.iter() { + if let crate::exa_wasm::interpreter::Stmt::Assign(lhs, _rhs) = st { + if let crate::exa_wasm::interpreter::Lhs::Ident(name) = lhs { + if !shared_locals.iter().any(|n| n == name) { + shared_locals.push(name.clone()); + } + } + } + } + + // reuse compile_expr_top defined above for expression compilation + // but visit statements 
recursively so nested Blocks/Ifs are + // handled (previous code only inspected top-level stmts). + fn visit_stmt( + st: &crate::exa_wasm::interpreter::Stmt, + bytecode_map: &mut std::collections::HashMap< + usize, + Vec, + >, + shared_funcs: &mut Vec, + shared_locals: &Vec, + ) { + match st { + crate::exa_wasm::interpreter::Stmt::Assign(lhs, rhs) => { + if let crate::exa_wasm::interpreter::Lhs::Indexed(_name, idx_expr) = lhs + { + if _name == "dx" { + // constant index + if let crate::exa_wasm::interpreter::Expr::Number(n) = + &**idx_expr + { + let idx = *n as usize; + let mut code: Vec = + Vec::new(); + if compile_expr_top( + rhs, + &mut code, + shared_funcs, + shared_locals, + ) { + code.push( + crate::exa_wasm::interpreter::Opcode::StoreDx(idx), + ); + bytecode_map.insert(idx, code); + } + } else { + // dynamic index: compile index then rhs then StoreDxDyn + let mut code: Vec = + Vec::new(); + if compile_expr_top( + idx_expr, + &mut code, + shared_funcs, + shared_locals, + ) && compile_expr_top( + rhs, + &mut code, + shared_funcs, + shared_locals, + ) { + code.push( + crate::exa_wasm::interpreter::Opcode::StoreDxDyn, + ); + // use a special key for dynamic-indexed entries + bytecode_map.insert(usize::MAX, code); + } + } + } + } + } + crate::exa_wasm::interpreter::Stmt::Block(v) => { + for ss in v.iter() { + visit_stmt(ss, bytecode_map, shared_funcs, shared_locals); + } + } + crate::exa_wasm::interpreter::Stmt::If { + cond, + then_branch, + else_branch, + } => { + // Only lower conditional assignments into bytecode when + // the condition is a compile-time constant boolean. For + // unknown/runtime conditions we skip bytecode lowering + // so the runtime can evaluate the AST (preserves + // short-circuit/conditional semantics). 
+ match cond { + crate::exa_wasm::interpreter::Expr::Bool(b) => { + if *b { + visit_stmt( + then_branch, + bytecode_map, + shared_funcs, + shared_locals, + ); + } else if let Some(eb) = else_branch { + visit_stmt(eb, bytecode_map, shared_funcs, shared_locals); + } + } + _ => { + // runtime condition: do not emit bytecode for + // nested assignments under this If + } + } + } + crate::exa_wasm::interpreter::Stmt::Expr(_) => {} + } + } + + for st in stmts.iter() { + visit_stmt(st, &mut bytecode_map, &mut shared_funcs, &shared_locals); + } + } + Err(_) => {} + } + } + + // NOTE: textual fallback parsing/compilation was removed. The emitter + // must provide `diffeq_ast` or `diffeq_bytecode` in the IR; runtime + // parsing of closure text is no longer supported. + + if !bytecode_map.is_empty() { + // emit the conservative diffeq bytecode map under the new IR field names + ir_obj["bytecode_map"] = + serde_json::to_value(&bytecode_map).unwrap_or(serde_json::Value::Null); + // new field expected by loader: diffeq_bytecode (index -> opcode sequence) + ir_obj["diffeq_bytecode"] = + serde_json::to_value(&bytecode_map).unwrap_or(serde_json::Value::Null); + // emit discovered funcs/locals if any + if !shared_funcs.is_empty() { + ir_obj["funcs"] = + serde_json::to_value(&shared_funcs).unwrap_or(serde_json::Value::Null); + } + if !shared_locals.is_empty() { + ir_obj["locals"] = + serde_json::to_value(&shared_locals).unwrap_or(serde_json::Value::Null); + } + } + + // If we have a parsed diffeq AST, attempt to lower the entire statement + // vector into a single function-level bytecode vector. This preserves + // full control-flow and scoping semantics for arbitrary If/Block nests. 
+ if let Some(v) = ir_obj.get("diffeq_ast") { + if let Ok(stmts) = serde_json::from_value::>(v.clone()) { + // shared tables discovered during compilation + let mut funcs_for_func: Vec = shared_funcs.clone(); + + // helper to compile statements into a single code vector + fn compile_stmt( + st: &crate::exa_wasm::interpreter::Stmt, + out: &mut Vec, + funcs: &mut Vec, + locals: &Vec, + ) -> bool { + use crate::exa_wasm::interpreter::{Expr, Opcode, Stmt, Lhs}; + match st { + Stmt::Assign(lhs, rhs) => { + match lhs { + Lhs::Indexed(name, idx_expr) => { + if name == "dx" { + if let Expr::Number(n) = &**idx_expr { + let idx = *n as usize; + if !compile_expr_top(rhs, out, funcs, locals) { + return false; + } + out.push(Opcode::StoreDx(idx)); + true + } else { + // dynamic index: compile index then rhs then StoreDxDyn + if !compile_expr_top(idx_expr, out, funcs, locals) { return false; } + if !compile_expr_top(rhs, out, funcs, locals) { return false; } + out.push(Opcode::StoreDxDyn); + true + } + } else if name == "x" || name == "y" { + // support writes to x/y (e.g., init/out contexts) + if let Expr::Number(n) = &**idx_expr { + let idx = *n as usize; + if !compile_expr_top(rhs, out, funcs, locals) { return false; } + if name == "x" { + out.push(Opcode::StoreX(idx)); + } else { + out.push(Opcode::StoreY(idx)); + } + true + } else { + if !compile_expr_top(idx_expr, out, funcs, locals) { return false; } + if !compile_expr_top(rhs, out, funcs, locals) { return false; } + if name == "x" { out.push(Opcode::StoreXDyn); } else { out.push(Opcode::StoreYDyn); } + true + } + } else { + false + } + } + Lhs::Ident(name) => { + // local assignment: find slot in locals + if let Some(pos) = locals.iter().position(|n| n == name) { + if !compile_expr_top(rhs, out, funcs, locals) { return false; } + out.push(Opcode::StoreLocal(pos)); + true + } else { + // unknown local: fail + false + } + } + } + } + Stmt::Expr(e) => { + if !compile_expr_top(e, out, funcs, locals) { return false; } + // 
discard expression result + out.push(Opcode::Pop); + true + } + Stmt::Block(v) => { + for s in v.iter() { + if !compile_stmt(s, out, funcs, locals) { return false; } + } + true + } + Stmt::If { cond, then_branch, else_branch } => { + // compile condition + if !compile_expr_top(cond, out, funcs, locals) { return false; } + // placeholder JumpIfFalse + let jf_pos = out.len(); + out.push(Opcode::JumpIfFalse(0)); + // then branch + if !compile_stmt(then_branch, out, funcs, locals) { return false; } + // jump over else + let jmp_pos = out.len(); + out.push(Opcode::Jump(0)); + // else position + let else_pos = out.len(); + if let Opcode::JumpIfFalse(ref mut addr) = out[jf_pos] { + *addr = else_pos; + } + if let Some(eb) = else_branch { + if !compile_stmt(eb, out, funcs, locals) { return false; } + } + // fix jump target + let end_pos = out.len(); + if let Opcode::Jump(ref mut addr) = out[jmp_pos] { + *addr = end_pos; + } + true + } + } + } + + let mut func_code: Vec = Vec::new(); + for st in stmts.iter() { + if !compile_stmt(st, &mut func_code, &mut funcs_for_func, &shared_locals) { + func_code.clear(); + break; + } + } + if !func_code.is_empty() { + ir_obj["diffeq_func"] = serde_json::to_value(&func_code).unwrap_or(serde_json::Value::Null); + if !funcs_for_func.is_empty() { + ir_obj["funcs"] = serde_json::to_value(&funcs_for_func).unwrap_or(serde_json::Value::Null); + } + if !shared_locals.is_empty() { + ir_obj["locals"] = serde_json::to_value(&shared_locals).unwrap_or(serde_json::Value::Null); + } + } + } + } + + // Attempt to compile out/init closures into bytecode similarly to diffeq POC + let mut out_bytecode_map: HashMap> = + HashMap::new(); + let mut init_bytecode_map: HashMap> = + HashMap::new(); + + // Helper to compile an Assign stmt into bytecode when LHS is y[idx] or x[idx] + if let Some(v) = ir_obj.get("out_ast") { + if let Ok(stmts) = + serde_json::from_value::>(v.clone()) + { + for st in stmts.iter() { + if let 
crate::exa_wasm::interpreter::Stmt::Assign(lhs, rhs) = st { + if let crate::exa_wasm::interpreter::Lhs::Indexed(_name, idx_expr) = lhs { + if let crate::exa_wasm::interpreter::Expr::Number(n) = &**idx_expr { + let idx = *n as usize; + let mut code: Vec = Vec::new(); + if compile_expr_top(rhs, &mut code, &mut shared_funcs, &shared_locals) { + code.push(crate::exa_wasm::interpreter::Opcode::StoreY(idx)); + out_bytecode_map.insert(idx, code); + } + } else { + let mut code: Vec = Vec::new(); + if compile_expr_top( + idx_expr, + &mut code, + &mut shared_funcs, + &shared_locals, + ) && compile_expr_top( + rhs, + &mut code, + &mut shared_funcs, + &shared_locals, + ) { + code.push(crate::exa_wasm::interpreter::Opcode::StoreYDyn); + out_bytecode_map.insert(usize::MAX, code); + } + } + } + } + } + } + } + + if let Some(v) = ir_obj.get("init_ast") { + if let Ok(stmts) = + serde_json::from_value::>(v.clone()) + { + for st in stmts.iter() { + if let crate::exa_wasm::interpreter::Stmt::Assign(lhs, rhs) = st { + if let crate::exa_wasm::interpreter::Lhs::Indexed(_name, idx_expr) = lhs { + if let crate::exa_wasm::interpreter::Expr::Number(n) = &**idx_expr { + let idx = *n as usize; + let mut code: Vec = Vec::new(); + if compile_expr_top(rhs, &mut code, &mut shared_funcs, &shared_locals) { + code.push(crate::exa_wasm::interpreter::Opcode::StoreX(idx)); + init_bytecode_map.insert(idx, code); + } + } else { + let mut code: Vec = Vec::new(); + if compile_expr_top( + idx_expr, + &mut code, + &mut shared_funcs, + &shared_locals, + ) && compile_expr_top( + rhs, + &mut code, + &mut shared_funcs, + &shared_locals, + ) { + code.push(crate::exa_wasm::interpreter::Opcode::StoreXDyn); + init_bytecode_map.insert(usize::MAX, code); + } + } + } + } + } + } + } + + if !out_bytecode_map.is_empty() { + ir_obj["out_bytecode"] = + serde_json::to_value(&out_bytecode_map).unwrap_or(serde_json::Value::Null); + } + if !init_bytecode_map.is_empty() { + ir_obj["init_bytecode"] = + 
serde_json::to_value(&init_bytecode_map).unwrap_or(serde_json::Value::Null); + } + + // Compile lag_map/fa_map entries into bytecode when present. The + // IR contains textual RHS strings for each entry; parse and compile + // them here so the runtime loader can consume bytecode directly. + let mut lag_bytecode_map: HashMap> = + HashMap::new(); + if let Some(v) = ir_obj.get("lag_map") { + if let Some(map) = v.as_object() { + for (k, val) in map.iter() { + if let Ok(idx) = k.parse::() { + if let Some(rhs_str) = val.as_str() { + let toks = crate::exa_wasm::interpreter::tokenize(rhs_str); + let mut p = crate::exa_wasm::interpreter::Parser::new(toks); + if let Ok(expr) = p.parse_expr_result() { + let mut code: Vec = Vec::new(); + if compile_expr_top(&expr, &mut code, &mut shared_funcs, &shared_locals) + { + lag_bytecode_map.insert(idx, code); + } + } + } + } + } + } + } + if !lag_bytecode_map.is_empty() { + ir_obj["lag_bytecode"] = + serde_json::to_value(&lag_bytecode_map).unwrap_or(serde_json::Value::Null); + } + + let mut fa_bytecode_map: HashMap> = + HashMap::new(); + if let Some(v) = ir_obj.get("fa_map") { + if let Some(map) = v.as_object() { + for (k, val) in map.iter() { + if let Ok(idx) = k.parse::() { + if let Some(rhs_str) = val.as_str() { + let toks = crate::exa_wasm::interpreter::tokenize(rhs_str); + let mut p = crate::exa_wasm::interpreter::Parser::new(toks); + if let Ok(expr) = p.parse_expr_result() { + let mut code: Vec = Vec::new(); + if compile_expr_top(&expr, &mut code, &mut shared_funcs, &shared_locals) + { + fa_bytecode_map.insert(idx, code); + } + } + } + } + } + } + } + if !fa_bytecode_map.is_empty() { + ir_obj["fa_bytecode"] = + serde_json::to_value(&fa_bytecode_map).unwrap_or(serde_json::Value::Null); + } + + // Ensure shared function table and locals are present in the IR when + // we discovered any during compilation. 
+ if !shared_funcs.is_empty() { + ir_obj["funcs"] = serde_json::to_value(&shared_funcs).unwrap_or(serde_json::Value::Null); + } + if !shared_locals.is_empty() { + ir_obj["locals"] = serde_json::to_value(&shared_locals).unwrap_or(serde_json::Value::Null); + } + + let output_path = output.unwrap_or_else(|| { + let random_suffix: String = rand::rng() + .sample_iter(&Alphanumeric) + .take(5) + .map(char::from) + .collect(); + let default_name = format!("model_ir_{}_{}.json", env::consts::OS, random_suffix); + env::temp_dir().join("exa_tmp").with_file_name(default_name) + }); + + let serialized = serde_json::to_vec_pretty(&ir_obj) + .map_err(|e| io::Error::new(io::ErrorKind::Other, format!("serde_json error: {}", e)))?; + + if let Some(parent) = output_path.parent() { + if !parent.exists() { + fs::create_dir_all(parent)?; + } + } + + fs::write(&output_path, serialized)?; + Ok(output_path.to_string_lossy().to_string()) +} diff --git a/src/exa_wasm/interpreter/ast.rs b/src/exa_wasm/interpreter/ast.rs new file mode 100644 index 00000000..391671c5 --- /dev/null +++ b/src/exa_wasm/interpreter/ast.rs @@ -0,0 +1,107 @@ +// AST types for the exa_wasm interpreter +use serde::{Deserialize, Serialize}; +use std::fmt; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub enum Expr { + Number(f64), + Bool(bool), + Ident(String), // e.g. ke + Param(usize), // parameter by index (p[0] rewritten to Param(0)) + Indexed(String, Box), // e.g. 
x[0], rateiv[0], y[0] where index can be expr + UnaryOp { + op: String, + rhs: Box, + }, + BinaryOp { + lhs: Box, + op: String, + rhs: Box, + }, + Call { + name: String, + args: Vec, + }, + MethodCall { + receiver: Box, + name: String, + args: Vec, + }, + Ternary { + cond: Box, + then_branch: Box, + else_branch: Box, + }, +} + +#[derive(Debug, Clone)] +pub enum Token { + Num(f64), + Bool(bool), + Ident(String), + LBracket, + RBracket, + LBrace, + RBrace, + Assign, + LParen, + RParen, + Comma, + Dot, + Op(char), + Lt, + Gt, + Le, + Ge, + EqEq, + Ne, + And, + Or, + Bang, + Question, + Colon, + Semicolon, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub enum Lhs { + Ident(String), + Indexed(String, Box), +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub enum Stmt { + Expr(Expr), + Assign(Lhs, Expr), + Block(Vec), + If { + cond: Expr, + then_branch: Box, + else_branch: Option>, + }, +} + +#[derive(Debug, Clone)] +pub struct ParseError { + pub pos: usize, + pub found: Option, + pub expected: Vec, +} + +impl fmt::Display for ParseError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + if !self.expected.is_empty() { + write!( + f, + "parse error at pos {} found={:?} expected={:?}", + self.pos, self.found, self.expected + ) + } else if let Some(tok) = &self.found { + write!(f, "parse error at pos {} found={:?}", self.pos, tok) + } else { + write!(f, "parse error at pos {} found=", self.pos) + } + } +} + +impl std::error::Error for ParseError {} diff --git a/src/exa_wasm/interpreter/builtins.rs b/src/exa_wasm/interpreter/builtins.rs new file mode 100644 index 00000000..f6642b1f --- /dev/null +++ b/src/exa_wasm/interpreter/builtins.rs @@ -0,0 +1,23 @@ +//! Builtin function metadata used by the interpreter and typechecker. +use std::ops::RangeInclusive; + +/// Return true if the name is a known builtin function. 
/// A name is known exactly when [`arg_count_range`] has an entry for it, so
/// the list of builtin names and the arity table cannot drift apart.
pub fn is_known_function(name: &str) -> bool {
    arg_count_range(name).is_some()
}

/// Return the allowed argument count range for a builtin, if known.
///
/// Ranges are inclusive. `None` means `name` is not a builtin. This table is
/// the single source of truth for which builtins exist: adding a function
/// here also makes [`is_known_function`] accept it.
pub fn arg_count_range(name: &str) -> Option<RangeInclusive<usize>> {
    match name {
        // one-argument numeric functions
        "exp" | "ln" | "log" | "log10" | "log2" | "sqrt" | "abs" | "floor" | "ceil" | "round"
        | "sin" | "cos" | "tan" => Some(1..=1),
        // two-argument numeric functions
        "pow" | "powf" | "min" | "max" => Some(2..=2),
        // functional conditional: if(cond, then, else)
        "if" => Some(3..=3),
        _ => None,
    }
}
+ let mut locals_vec: Vec = vec![0.0; entry.locals.len()]; + let mut local_index: HashMap = HashMap::new(); + if !entry.locals.is_empty() { + for (i, n) in entry.locals.iter().enumerate() { + local_index.insert(n.clone(), i); + } + } + // evaluate prelude into a temporary map then populate locals_vec + let mut temp_locals: HashMap = HashMap::new(); + for (name, expr) in entry.prelude.iter() { + let val = eval::eval_expr( + expr, + x, + p, + &rateiv, + Some(&temp_locals), + Some(&entry.pmap), + Some(_t), + Some(_cov), + ); + temp_locals.insert(name.clone(), val.as_number()); + } + // populate locals_vec from temp_locals using emitted locals ordering + if !entry.locals.is_empty() { + for (name, &idx) in local_index.iter() { + if let Some(v) = temp_locals.get(name) { + locals_vec[idx] = *v; + } + } + } else { + // no emitted locals ordering: create slots for prelude in insertion order + let mut i = 0usize; + for (name, _) in entry.prelude.iter() { + local_index.insert(name.clone(), i); + if let Some(v) = temp_locals.get(name) { + if i >= locals_vec.len() { + locals_vec.push(*v); + } else { + locals_vec[i] = *v; + } + } + i += 1; + } + } + // debug: locals are in `locals_vec` and `local_index` + // If emitted function-level bytecode exists for diffeq, prefer executing it. + // Fallback to per-index bytecode map for backwards compatibility. 
+ if !entry.bytecode_diffeq_func.is_empty() { + let builtins_dispatch = |name: &str, args: &[f64]| -> f64 { + let vals: Vec = + args.iter().map(|a| eval::Value::Number(*a)).collect(); + eval::eval_call(name, &vals).as_number() + }; + let mut locals_mut = locals_vec.clone(); + let mut assign = |name: &str, idx: usize, val: f64| match name { + "dx" => { + if idx < dx.len() { + dx[idx] = val; + } else { + crate::exa_wasm::interpreter::registry::set_runtime_error(format!( + "index out of bounds 'dx'[{}] (nstates={})", + idx, + dx.len() + )); + } + } + "x" | "y" => { + crate::exa_wasm::interpreter::registry::set_runtime_error(format!( + "write to '{}' not allowed in diffeq bytecode", + name + )); + } + _ => { + crate::exa_wasm::interpreter::registry::set_runtime_error(format!( + "unsupported indexed assignment '{}' in diffeq", + name + )); + } + }; + vm::run_bytecode_full( + entry.bytecode_diffeq_func.as_slice(), + x.as_slice(), + p.as_slice(), + rateiv.as_slice(), + _t, + &mut locals_mut, + &entry.funcs, + &builtins_dispatch, + |n, i, v| assign(n, i, v), + ); + } else if !entry.bytecode_diffeq.is_empty() { + // builtin dispatch closure: translate f64 args -> eval::Value and call eval::eval_call + let builtins_dispatch = |name: &str, args: &[f64]| -> f64 { + let vals: Vec = + args.iter().map(|a| eval::Value::Number(*a)).collect(); + eval::eval_call(name, &vals).as_number() + }; + // assignment closure maps VM stores to simulator vectors + let mut assign = |name: &str, idx: usize, val: f64| match name { + "dx" => { + if idx < dx.len() { + dx[idx] = val; + } else { + crate::exa_wasm::interpreter::registry::set_runtime_error(format!( + "index out of bounds 'dx'[{}] (nstates={})", + idx, + dx.len() + )); + } + } + "x" | "y" => { + crate::exa_wasm::interpreter::registry::set_runtime_error(format!( + "write to '{}' not allowed in diffeq bytecode", + name + )); + } + _ => { + crate::exa_wasm::interpreter::registry::set_runtime_error(format!( + "unsupported indexed assignment 
'{}' in diffeq", + name + )); + } + }; + for (_i, code) in entry.bytecode_diffeq.iter() { + let mut locals_mut = locals_vec.clone(); + vm::run_bytecode_full( + code.as_slice(), + x.as_slice(), + p.as_slice(), + rateiv.as_slice(), + _t, + &mut locals_mut, + &entry.funcs, + &builtins_dispatch, + |n, i, v| assign(n, i, v), + ); + } + } else { + // execute statement ASTs which may assign to dx indices or locals + let mut assign_closure = |name: &str, idx: usize, val: f64| match name { + "dx" => { + if idx < dx.len() { + dx[idx] = val; + } else { + crate::exa_wasm::interpreter::registry::set_runtime_error(format!( + "index out of bounds 'dx'[{}] (nstates={})", + idx, + dx.len() + )); + } + } + _ => { + crate::exa_wasm::interpreter::registry::set_runtime_error(format!( + "unsupported indexed assignment '{}' in diffeq", + name + )); + } + }; + // convert locals_vec into a HashMap for eval_stmt + let mut locals_map: HashMap = HashMap::new(); + for (name, &idx) in local_index.iter() { + if idx < locals_vec.len() { + locals_map.insert(name.clone(), locals_vec[idx]); + } + } + for st in entry.diffeq_stmts.iter() { + crate::exa_wasm::interpreter::eval::eval_stmt( + st, + x, + p, + _t, + &rateiv, + &mut locals_map, + Some(&entry.pmap), + Some(_cov), + &mut assign_closure, + ); + } + } + } + } +} + +pub fn out_dispatch( + x: &crate::simulator::V, + p: &crate::simulator::V, + _t: crate::simulator::T, + _cov: &crate::data::Covariates, + y: &mut crate::simulator::V, +) { + let tmp = crate::simulator::V::zeros(1, diffsol::NalgebraContext); + if let Some(id) = current_id() { + if let Some(entry) = registry::get_entry(id) { + // prepare locals vector for out bytecode (use emitted locals ordering) + let mut locals_vec: Vec = vec![0.0; entry.locals.len()]; + let mut local_index: HashMap = HashMap::new(); + if !entry.locals.is_empty() { + for (i, n) in entry.locals.iter().enumerate() { + local_index.insert(n.clone(), i); + } + } + // evaluate prelude into temporary map and populate 
locals_vec + let mut temp_locals: HashMap = HashMap::new(); + for (name, expr) in entry.prelude.iter() { + let val = eval::eval_expr( + expr, + x, + p, + &tmp, + Some(&temp_locals), + Some(&entry.pmap), + Some(_t), + Some(_cov), + ); + temp_locals.insert(name.clone(), val.as_number()); + } + for (name, &idx) in local_index.iter() { + if let Some(v) = temp_locals.get(name) { + locals_vec[idx] = *v; + } + } + + if !entry.bytecode_out.is_empty() { + let builtins_dispatch = |name: &str, args: &[f64]| -> f64 { + let vals: Vec = + args.iter().map(|a| eval::Value::Number(*a)).collect(); + eval::eval_call(name, &vals).as_number() + }; + let mut assign = |name: &str, idx: usize, val: f64| match name { + "y" => { + if idx < y.len() { + y[idx] = val; + } else { + crate::exa_wasm::interpreter::registry::set_runtime_error(format!( + "index out of bounds 'y'[{}] (nouteqs={})", + idx, + y.len() + )); + } + } + _ => { + crate::exa_wasm::interpreter::registry::set_runtime_error(format!( + "unsupported indexed assignment '{}' in out", + name + )); + } + }; + for (_i, code) in entry.bytecode_out.iter() { + let mut locals_mut = locals_vec.clone(); + vm::run_bytecode_full( + code.as_slice(), + x.as_slice(), + p.as_slice(), + tmp.as_slice(), + _t, + &mut locals_mut, + &entry.funcs, + &builtins_dispatch, + |n, i, v| assign(n, i, v), + ); + } + } else { + let mut assign = |name: &str, idx: usize, val: f64| match name { + "y" => { + if idx < y.len() { + y[idx] = val; + } else { + crate::exa_wasm::interpreter::registry::set_runtime_error(format!( + "index out of bounds 'y'[{}] (nouteqs={})", + idx, + y.len() + )); + } + } + _ => { + crate::exa_wasm::interpreter::registry::set_runtime_error(format!( + "unsupported indexed assignment '{}' in out", + name + )); + } + }; + for st in entry.out_stmts.iter() { + crate::exa_wasm::interpreter::eval::eval_stmt( + st, + x, + p, + _t, + &tmp, + &mut std::collections::HashMap::new(), + Some(&entry.pmap), + Some(_cov), + &mut assign, + ); + } + } + } + } 
+} + +pub fn lag_dispatch( + p: &crate::simulator::V, + _t: crate::simulator::T, + _cov: &crate::data::Covariates, +) -> std::collections::HashMap { + let mut out: std::collections::HashMap = + std::collections::HashMap::new(); + if let Some(id) = current_id() { + if let Some(entry) = registry::get_entry(id) { + let zero_x = crate::simulator::V::zeros(entry.nstates, diffsol::NalgebraContext); + let zero_rate = crate::simulator::V::zeros(entry.nstates, diffsol::NalgebraContext); + for (i, expr) in entry.lag.iter() { + let v = crate::exa_wasm::interpreter::eval::eval_expr( + expr, + &zero_x, + p, + &zero_rate, + None, + Some(&entry.pmap), + Some(_t), + Some(_cov), + ); + out.insert(*i, v.as_number()); + } + } + } + out +} + +pub fn fa_dispatch( + p: &crate::simulator::V, + _t: crate::simulator::T, + _cov: &crate::data::Covariates, +) -> std::collections::HashMap { + let mut out: std::collections::HashMap = + std::collections::HashMap::new(); + if let Some(id) = current_id() { + if let Some(entry) = registry::get_entry(id) { + let zero_x = crate::simulator::V::zeros(entry.nstates, diffsol::NalgebraContext); + let zero_rate = crate::simulator::V::zeros(entry.nstates, diffsol::NalgebraContext); + for (i, expr) in entry.fa.iter() { + let v = crate::exa_wasm::interpreter::eval::eval_expr( + expr, + &zero_x, + p, + &zero_rate, + None, + Some(&entry.pmap), + Some(_t), + Some(_cov), + ); + out.insert(*i, v.as_number()); + } + } + } + out +} + +pub fn init_dispatch( + p: &crate::simulator::V, + _t: crate::simulator::T, + cov: &crate::data::Covariates, + x: &mut crate::simulator::V, +) { + if let Some(id) = current_id() { + if let Some(entry) = registry::get_entry(id) { + let zero_rate = crate::simulator::V::zeros(entry.nstates, diffsol::NalgebraContext); + // prepare locals vector for init bytecode (use emitted locals ordering) + let mut locals_vec: Vec = vec![0.0; entry.locals.len()]; + let mut local_index: HashMap = HashMap::new(); + if !entry.locals.is_empty() { + for (i, 
n) in entry.locals.iter().enumerate() { + local_index.insert(n.clone(), i); + } + } + let mut temp_locals: HashMap = HashMap::new(); + for (name, expr) in entry.prelude.iter() { + let val = eval::eval_expr( + expr, + &crate::simulator::V::zeros(entry.nstates, diffsol::NalgebraContext), + p, + &zero_rate, + Some(&temp_locals), + Some(&entry.pmap), + Some(_t), + Some(cov), + ); + temp_locals.insert(name.clone(), val.as_number()); + } + for (name, &idx) in local_index.iter() { + if let Some(v) = temp_locals.get(name) { + locals_vec[idx] = *v; + } + } + + if !entry.bytecode_init.is_empty() { + let builtins_dispatch = |name: &str, args: &[f64]| -> f64 { + let vals: Vec = + args.iter().map(|a| eval::Value::Number(*a)).collect(); + eval::eval_call(name, &vals).as_number() + }; + let mut assign = |name: &str, idx: usize, val: f64| match name { + "x" => { + if idx < x.len() { + x[idx] = val; + } else { + crate::exa_wasm::interpreter::registry::set_runtime_error(format!( + "index out of bounds 'x'[{}] (nstates={})", + idx, + x.len() + )); + } + } + _ => { + crate::exa_wasm::interpreter::registry::set_runtime_error(format!( + "unsupported indexed assignment '{}' in init", + name + )); + } + }; + for (_i, code) in entry.bytecode_init.iter() { + let mut locals_mut = locals_vec.clone(); + vm::run_bytecode_full( + code.as_slice(), + &crate::simulator::V::zeros(entry.nstates, diffsol::NalgebraContext) + .as_slice(), + p.as_slice(), + zero_rate.as_slice(), + _t, + &mut locals_mut, + &entry.funcs, + &builtins_dispatch, + |n, i, v| assign(n, i, v), + ); + } + } else { + // execute init statements which may assign to x[] or locals + let mut assign = |name: &str, idx: usize, val: f64| match name { + "x" => { + if idx < x.len() { + x[idx] = val; + } else { + crate::exa_wasm::interpreter::registry::set_runtime_error(format!( + "index out of bounds 'x'[{}] (nstates={})", + idx, + x.len() + )); + } + } + _ => { + crate::exa_wasm::interpreter::registry::set_runtime_error(format!( + 
"unsupported indexed assignment '{}' in init", + name + )); + } + }; + for st in entry.init_stmts.iter() { + // use zeros for rateiv parameter + crate::exa_wasm::interpreter::eval::eval_stmt( + st, + &crate::simulator::V::zeros(entry.nstates, diffsol::NalgebraContext), + p, + _t, + &zero_rate, + &mut std::collections::HashMap::new(), + Some(&entry.pmap), + Some(cov), + &mut assign, + ); + } + } + } + } +} diff --git a/src/exa_wasm/interpreter/eval.rs b/src/exa_wasm/interpreter/eval.rs new file mode 100644 index 00000000..62b5605f --- /dev/null +++ b/src/exa_wasm/interpreter/eval.rs @@ -0,0 +1,457 @@ +use diffsol::Vector; + +use crate::data::Covariates; +use crate::exa_wasm::interpreter::ast::Expr; +use crate::exa_wasm::interpreter::builtins; +use crate::simulator::T; +use crate::simulator::V; +use std::collections::HashMap; + +// runtime value type +#[derive(Debug, Clone, PartialEq)] +pub enum Value { + Number(f64), + Bool(bool), +} + +impl Value { + pub fn as_number(&self) -> f64 { + match self { + Value::Number(n) => *n, + Value::Bool(b) => { + if *b { + 1.0 + } else { + 0.0 + } + } + } + } + pub fn as_bool(&self) -> bool { + match self { + Value::Bool(b) => *b, + Value::Number(n) => *n != 0.0, + } + } +} + +// Evaluator extracted from mod.rs. Uses super::set_runtime_error to report +// runtime problems so the parent module can expose them to the simulator. 
+pub(crate) fn eval_call(name: &str, args: &[Value]) -> Value { + use Value::Number; + // runtime arity and known-function checks using centralized builtins table + if let Some(range) = builtins::arg_count_range(name) { + if !range.contains(&args.len()) { + crate::exa_wasm::interpreter::set_runtime_error(format!( + "builtin '{}' called with wrong arity: got {}, expected {:?}", + name, + args.len(), + range + )); + return Number(0.0); + } + } else { + // if arg_count_range returns None, it's unknown to our builtin table + if !builtins::is_known_function(name) { + crate::exa_wasm::interpreter::set_runtime_error(format!( + "unknown function '{}', not present in builtins table", + name + )); + return Number(0.0); + } + } + + match name { + "exp" => Number( + args.get(0) + .cloned() + .unwrap_or(Number(0.0)) + .as_number() + .exp(), + ), + "if" => { + let cond = args.get(0).cloned().unwrap_or(Number(0.0)); + if cond.as_bool() { + args.get(1).cloned().unwrap_or(Number(0.0)) + } else { + args.get(2).cloned().unwrap_or(Number(0.0)) + } + } + "ln" | "log" => Number(args.get(0).cloned().unwrap_or(Number(0.0)).as_number().ln()), + "log10" => Number( + args.get(0) + .cloned() + .unwrap_or(Number(0.0)) + .as_number() + .log10(), + ), + "log2" => Number( + args.get(0) + .cloned() + .unwrap_or(Number(0.0)) + .as_number() + .log2(), + ), + "sqrt" => Number( + args.get(0) + .cloned() + .unwrap_or(Number(0.0)) + .as_number() + .sqrt(), + ), + "pow" | "powf" => { + let a = args.get(0).cloned().unwrap_or(Number(0.0)).as_number(); + let b = args.get(1).cloned().unwrap_or(Number(0.0)).as_number(); + Number(a.powf(b)) + } + "min" => { + let a = args.get(0).cloned().unwrap_or(Number(0.0)).as_number(); + let b = args.get(1).cloned().unwrap_or(Number(0.0)).as_number(); + Number(a.min(b)) + } + "max" => { + let a = args.get(0).cloned().unwrap_or(Number(0.0)).as_number(); + let b = args.get(1).cloned().unwrap_or(Number(0.0)).as_number(); + Number(a.max(b)) + } + "abs" => Number( + args.get(0) 
+ .cloned() + .unwrap_or(Number(0.0)) + .as_number() + .abs(), + ), + "floor" => Number( + args.get(0) + .cloned() + .unwrap_or(Number(0.0)) + .as_number() + .floor(), + ), + "ceil" => Number( + args.get(0) + .cloned() + .unwrap_or(Number(0.0)) + .as_number() + .ceil(), + ), + "round" => Number( + args.get(0) + .cloned() + .unwrap_or(Number(0.0)) + .as_number() + .round(), + ), + "sin" => Number( + args.get(0) + .cloned() + .unwrap_or(Number(0.0)) + .as_number() + .sin(), + ), + "cos" => Number( + args.get(0) + .cloned() + .unwrap_or(Number(0.0)) + .as_number() + .cos(), + ), + "tan" => Number( + args.get(0) + .cloned() + .unwrap_or(Number(0.0)) + .as_number() + .tan(), + ), + _ => { + // Unknown function: report a runtime error so callers/users + // can detect mistakes (typos, missing builtins) instead of + // silently receiving 0.0 which hides problems. + crate::exa_wasm::interpreter::set_runtime_error(format!("unknown function '{}'", name)); + Number(0.0) + } + } +} + +pub(crate) fn eval_expr( + expr: &Expr, + x: &V, + p: &V, + rateiv: &V, + locals: Option<&HashMap>, + pmap: Option<&HashMap>, + t: Option, + cov: Option<&Covariates>, +) -> Value { + use crate::exa_wasm::interpreter::set_runtime_error; + + match expr { + Expr::Bool(b) => Value::Bool(*b), + Expr::Number(v) => Value::Number(*v), + Expr::Ident(name) => { + if name.starts_with('_') { + return Value::Number(0.0); + } + // local variables defined by prelude take precedence + if let Some(loc) = locals { + if let Some(v) = loc.get(name) { + return Value::Number(*v); + } + } + if let Some(map) = pmap { + if let Some(idx) = map.get(name) { + let val = p[*idx]; + return Value::Number(val); + } + } + if name == "t" { + let val = t.unwrap_or(0.0); + return Value::Number(val); + } + if let Some(covariates) = cov { + if let Some(covariate) = covariates.get_covariate(name) { + if let Some(time) = t { + if let Ok(v) = covariate.interpolate(time) { + return Value::Number(v); + } + } + } + } + 
set_runtime_error(format!("unknown identifier '{}'", name)); + Value::Number(0.0) + } + Expr::Param(idx) => { + let i = *idx; + if i < p.len() { + Value::Number(p[i]) + } else { + set_runtime_error(format!( + "parameter index out of bounds p[{}] (nparams={})", + i, + p.len() + )); + Value::Number(0.0) + } + } + Expr::Indexed(name, idx_expr) => { + let idxv = eval_expr(idx_expr, x, p, rateiv, locals, pmap, t, cov); + let idxf = idxv.as_number(); + if !idxf.is_finite() || idxf.is_sign_negative() { + set_runtime_error(format!( + "invalid index expression for '{}' -> {}", + name, idxf + )); + return Value::Number(0.0); + } + let idx = idxf as usize; + match name.as_str() { + "x" => { + if idx < x.len() { + Value::Number(x[idx]) + } else { + set_runtime_error(format!( + "index out of bounds 'x'[{}] (nstates={})", + idx, + x.len() + )); + Value::Number(0.0) + } + } + "p" | "params" => { + if idx < p.len() { + Value::Number(p[idx]) + } else { + set_runtime_error(format!( + "parameter index out of bounds '{}'[{}] (nparams={})", + name, + idx, + p.len() + )); + Value::Number(0.0) + } + } + "rateiv" => { + if idx < rateiv.len() { + Value::Number(rateiv[idx]) + } else { + set_runtime_error(format!( + "index out of bounds 'rateiv'[{}] (len={})", + idx, + rateiv.len() + )); + Value::Number(0.0) + } + } + _ => { + set_runtime_error(format!("unknown indexed symbol '{}'", name)); + Value::Number(0.0) + } + } + } + Expr::UnaryOp { op, rhs } => { + let v = eval_expr(rhs, x, p, rateiv, locals, pmap, t, cov); + match op.as_str() { + "-" => Value::Number(-v.as_number()), + "!" 
=> Value::Bool(!v.as_bool()), + _ => v, + } + } + Expr::BinaryOp { lhs, op, rhs } => { + match op.as_str() { + "&&" => { + let a = eval_expr(lhs, x, p, rateiv, locals, pmap, t, cov); + if !a.as_bool() { + return Value::Bool(false); + } + let b = eval_expr(rhs, x, p, rateiv, locals, pmap, t, cov); + Value::Bool(b.as_bool()) + } + "||" => { + let a = eval_expr(lhs, x, p, rateiv, locals, pmap, t, cov); + if a.as_bool() { + return Value::Bool(true); + } + let b = eval_expr(rhs, x, p, rateiv, locals, pmap, t, cov); + Value::Bool(b.as_bool()) + } + _ => { + let a = eval_expr(lhs, x, p, rateiv, locals, pmap, t, cov); + let b = eval_expr(rhs, x, p, rateiv, locals, pmap, t, cov); + match op.as_str() { + "+" => Value::Number(a.as_number() + b.as_number()), + "-" => Value::Number(a.as_number() - b.as_number()), + "*" => Value::Number(a.as_number() * b.as_number()), + "/" => Value::Number(a.as_number() / b.as_number()), + "^" => Value::Number(a.as_number().powf(b.as_number())), + "<" => Value::Bool(a.as_number() < b.as_number()), + ">" => Value::Bool(a.as_number() > b.as_number()), + "<=" => Value::Bool(a.as_number() <= b.as_number()), + ">=" => Value::Bool(a.as_number() >= b.as_number()), + "==" => { + // equality for numbers and bools via coercion + match (a, b) { + (Value::Bool(aa), Value::Bool(bb)) => Value::Bool(aa == bb), + (aa, bb) => Value::Bool(aa.as_number() == bb.as_number()), + } + } + "!=" => match (a, b) { + (Value::Bool(aa), Value::Bool(bb)) => Value::Bool(aa != bb), + (aa, bb) => Value::Bool(aa.as_number() != bb.as_number()), + }, + _ => a, + } + } + } + } + Expr::Call { name, args } => { + let mut avals: Vec = Vec::new(); + for aexpr in args.iter() { + avals.push(eval_expr(aexpr, x, p, rateiv, locals, pmap, t, cov)); + } + let res = eval_call(name.as_str(), &avals); + // warn if unknown function returned Number(0.0)? 
Keep legacy behavior minimal + res + } + Expr::Ternary { + cond, + then_branch, + else_branch, + } => { + let c = eval_expr(cond, x, p, rateiv, locals, pmap, t, cov); + if c.as_bool() { + eval_expr(then_branch, x, p, rateiv, locals, pmap, t, cov) + } else { + eval_expr(else_branch, x, p, rateiv, locals, pmap, t, cov) + } + } + Expr::MethodCall { + receiver, + name, + args, + } => { + let recv = eval_expr(receiver, x, p, rateiv, locals, pmap, t, cov); + let mut avals: Vec = Vec::new(); + avals.push(recv); + for aexpr in args.iter() { + avals.push(eval_expr(aexpr, x, p, rateiv, locals, pmap, t, cov)); + } + let res = eval_call(name.as_str(), &avals); + res + } + } +} + +// functions are exported as `pub(crate)` above for use by parent module + +pub(crate) fn eval_stmt( + stmt: &crate::exa_wasm::interpreter::ast::Stmt, + x: &crate::simulator::V, + p: &crate::simulator::V, + t: crate::simulator::T, + rateiv: &crate::simulator::V, + locals: &mut std::collections::HashMap, + pmap: Option<&std::collections::HashMap>, + cov: Option<&crate::data::Covariates>, + assign_indexed: &mut FAssign, +) where + FAssign: FnMut(&str, usize, f64), +{ + use crate::exa_wasm::interpreter::ast::{Lhs, Stmt}; + + match stmt { + Stmt::Expr(e) => { + let _ = eval_expr(e, x, p, rateiv, Some(&*locals), pmap, Some(t), cov); + } + Stmt::Assign(lhs, rhs) => { + // evaluate rhs + let val = eval_expr(rhs, x, p, rateiv, Some(&*locals), pmap, Some(t), cov); + match lhs { + Lhs::Ident(name) => { + locals.insert(name.clone(), val.as_number()); + } + Lhs::Indexed(name, idx_expr) => { + let idxv = + eval_expr(idx_expr, x, p, rateiv, Some(&*locals), pmap, Some(t), cov); + let idxf = idxv.as_number(); + if !idxf.is_finite() || idxf.is_sign_negative() { + crate::exa_wasm::interpreter::registry::set_runtime_error(format!( + "invalid index expression for '{}' -> {}", + name, idxf + )); + return; + } + let idx = idxf as usize; + // delegate actual assignment to the provided closure + assign_indexed(name.as_str(), 
idx, val.as_number()); + } + } + } + Stmt::Block(v) => { + for s in v.iter() { + eval_stmt(s, x, p, t, rateiv, locals, pmap, cov, assign_indexed); + } + } + Stmt::If { + cond, + then_branch, + else_branch, + } => { + let c = eval_expr(cond, x, p, rateiv, Some(&*locals), pmap, Some(t), cov); + if c.as_bool() { + eval_stmt( + then_branch, + x, + p, + t, + rateiv, + locals, + pmap, + cov, + assign_indexed, + ); + } else if let Some(eb) = else_branch { + eval_stmt(eb, x, p, t, rateiv, locals, pmap, cov, assign_indexed); + } + } + } +} diff --git a/src/exa_wasm/interpreter/loader.rs b/src/exa_wasm/interpreter/loader.rs new file mode 100644 index 00000000..e75ff02c --- /dev/null +++ b/src/exa_wasm/interpreter/loader.rs @@ -0,0 +1,784 @@ +use std::collections::HashMap; +use std::fs; +use std::io; +use std::path::PathBuf; + +use serde::Deserialize; + +use crate::exa_wasm::interpreter::ast::Expr; +use crate::exa_wasm::interpreter::parser::{tokenize, Parser}; +use crate::exa_wasm::interpreter::registry; +use crate::exa_wasm::interpreter::typecheck; + +#[allow(dead_code)] +#[derive(Deserialize, Debug)] +struct IrFile { + ir_version: Option, + kind: Option, + params: Option>, + model_text: Option, + diffeq: Option, + lag: Option, + fa: Option, + init: Option, + out: Option, + lag_map: Option>, + fa_map: Option>, + // optional fetch macro bodies extracted at emit time + fetch_params: Option>, + fetch_cov: Option>, + // optional pre-parsed ASTs emitted by `emit_ir` + diffeq_ast: Option>, + out_ast: Option>, + init_ast: Option>, + // optional compiled bytecode emitted by `emit_ir` + diffeq_bytecode: + Option>>, + // optional compiled function-level bytecode (single code vector) + diffeq_func: Option>, + out_bytecode: + Option>>, + init_bytecode: + Option>>, + lag_bytecode: + Option>>, + fa_bytecode: + Option>>, + // optional emitted function table and local slot ordering + funcs: Option>, + locals: Option>, +} + +pub fn load_ir_ode( + ir_path: PathBuf, +) -> Result< + ( + 
crate::simulator::equation::ODE, + crate::simulator::equation::Meta, + usize, + ), + io::Error, +> { + let contents = fs::read_to_string(&ir_path)?; + let ir: IrFile = serde_json::from_str(&contents) + .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, format!("serde_json: {}", e)))?; + + let params = ir.params.unwrap_or_default(); + let meta = crate::simulator::equation::Meta::new(params.iter().map(|s| s.as_str()).collect()); + + let mut pmap = std::collections::HashMap::new(); + for (i, name) in params.iter().enumerate() { + pmap.insert(name.clone(), i); + } + + let lag_text = ir.lag.clone().unwrap_or_default(); + let fa_text = ir.fa.clone().unwrap_or_default(); + + let mut lag_map: HashMap = HashMap::new(); + let mut fa_map: HashMap = HashMap::new(); + let mut prelude: Vec<(String, Expr)> = Vec::new(); + // statement vectors (full statement ASTs parsed from closures) + let mut diffeq_stmts: Vec = Vec::new(); + let mut out_stmts: Vec = Vec::new(); + let mut init_stmts: Vec = Vec::new(); + + let mut parse_errors: Vec = Vec::new(); + + // Extract top-level assignments like `dx[i] = expr;` from the closure body. + // Only statements at the first brace nesting level (depth == 1) are + // considered top-level; assignments inside nested blocks (e.g. inside + // `if { ... }`) will be ignored. This avoids accidentally extracting + // conditional assignments that should not be treated as unconditional + // runtime equations. + // extract_all_assign delegated to loader_helpers + + // Prefer pre-parsed AST emitted by the IR emitter. If the emitter + // provided bytecode we will accept it; textual parsing of closure + // strings is no longer supported at runtime. This guarantees a single + // robust pipeline: AST + bytecode emitted by `emit_ir`. + if let Some(ast) = ir.diffeq_ast.clone() { + // Extract prelude assignments (non-indexed Ident = expr) into `prelude` + // and keep the remaining statements for execution. 
We do not run the + // global typechecker here because prelude locals must be known to + // validate the remainder; later validation steps will cover the + // full statement set with prelude information. + let mut main_stmts: Vec = Vec::new(); + for st in ast.into_iter() { + match st { + crate::exa_wasm::interpreter::ast::Stmt::Assign(lhs, rhs) => { + if let crate::exa_wasm::interpreter::ast::Lhs::Ident(name) = lhs { + prelude.push((name, rhs)); + continue; + } + main_stmts.push(crate::exa_wasm::interpreter::ast::Stmt::Assign(lhs, rhs)); + } + other => main_stmts.push(other), + } + } + diffeq_stmts = main_stmts; + } else if ir.diffeq_bytecode.is_some() { + // bytecode present without AST: accept but require func/local metadata + if ir.funcs.is_none() || ir.locals.is_none() { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + "diffeq bytecode present but missing funcs/locals metadata in IR", + )); + } + } else { + parse_errors.push( + "diffeq closure missing: emit_ir must provide diffeq_ast or diffeq_bytecode" + .to_string(), + ); + } + + // Now that prelude has been extracted (if any), run the full typechecker + // on diffeq statements so we catch builtin arity and type errors early. + if !diffeq_stmts.is_empty() { + if let Err(e) = typecheck::check_statements(&diffeq_stmts) { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + format!("type errors in diffeq AST in IR: {:?}", e), + )); + } + } + + // prelude is extracted from diffeq_ast above (if present). If diffeq + // bytecode was provided without AST, prelude will remain empty and + // `locals` should be provided by emit_ir to define local slots. 
+ // prefer pre-parsed AST for out; if not present, bytecode_out may be supplied + if let Some(ast) = ir.out_ast.clone() { + if let Err(e) = typecheck::check_statements(&ast) { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + format!("type errors in out AST in IR: {:?}", e), + )); + } + out_stmts = ast; + } else if ir.out_bytecode.is_some() { + if ir.funcs.is_none() || ir.locals.is_none() { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + "out bytecode present but missing funcs/locals metadata in IR", + )); + } + } else { + // out closure missing: acceptable + } + + // prefer pre-parsed AST for init; if not present, bytecode_init may be supplied + if let Some(ast) = ir.init_ast.clone() { + if let Err(e) = typecheck::check_statements(&ast) { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + format!("type errors in init AST in IR: {:?}", e), + )); + } + init_stmts = ast; + } else if ir.init_bytecode.is_some() { + if ir.funcs.is_none() || ir.locals.is_none() { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + "init bytecode present but missing funcs/locals metadata in IR", + )); + } + } else { + // init closure missing: acceptable + } + + if let Some(lmap) = ir.lag_map.clone() { + for (i, rhs) in lmap.into_iter() { + let toks = tokenize(&rhs); + let mut p = Parser::new(toks); + match p.parse_expr_result() { + Ok(expr) => { + lag_map.insert(i, expr); + } + Err(e) => { + parse_errors.push(format!( + "failed to parse lag! 
entry {} => '{}' : {}", + i, rhs, e + )); + } + } + } + } else { + if !lag_text.trim().is_empty() { + parse_errors.push("IR missing structured `lag_map` field; textual `lag!{}` parsing is no longer supported at runtime".to_string()); + } + } + if let Some(fmap) = ir.fa_map.clone() { + for (i, rhs) in fmap.into_iter() { + let toks = tokenize(&rhs); + let mut p = Parser::new(toks); + match p.parse_expr_result() { + Ok(expr) => { + fa_map.insert(i, expr); + } + Err(e) => { + parse_errors.push(format!( + "failed to parse fa! entry {} => '{}' : {}", + i, rhs, e + )); + } + } + } + } else { + if !fa_text.trim().is_empty() { + parse_errors.push("IR missing structured `fa_map` field; textual `fa!{}` parsing is no longer supported at runtime".to_string()); + } + } + + // fetch_params / fetch_cov bodies should be emitted into the IR by the + // emitter. Runtime textual scanning is no longer supported. + let mut fetch_macro_bodies: Vec = Vec::new(); + if let Some(fp) = ir.fetch_params.clone() { + fetch_macro_bodies.extend(fp); + } + + for body in fetch_macro_bodies.iter() { + let parts: Vec = body + .split(',') + .map(|s| s.trim().trim_matches(|c| c == '"' || c == '\'')) + .map(|s| s.to_string()) + .collect(); + if parts.is_empty() { + parse_errors.push(format!("empty fetch_params! macro body: '{}'", body)); + continue; + } + for name in parts.iter().skip(1) { + if name.starts_with('_') { + continue; + } + if !params.iter().any(|p| p == name) { + parse_errors.push(format!( + "fetch_params! references unknown parameter '{}' not present in IR params {:?}", + name, params + )); + } + } + } + + let mut fetch_cov_bodies: Vec = Vec::new(); + if let Some(fc) = ir.fetch_cov.clone() { + fetch_cov_bodies.extend(fc); + } + + for body in fetch_cov_bodies.iter() { + let parts: Vec = body + .split(',') + .map(|s| s.trim().trim_matches(|c| c == '"' || c == '\'')) + .map(|s| s.to_string()) + .collect(); + if parts.len() < 3 { + parse_errors.push(format!( + "fetch_cov! 
macro expects at least (cov, t, name...), got '{}'", + body + )); + continue; + } + let cov_var = parts[0].clone(); + if cov_var.is_empty() || !cov_var.chars().next().unwrap().is_ascii_alphabetic() { + parse_errors.push(format!( + "invalid first argument '{}' in fetch_cov! macro", + cov_var + )); + } + let _tvar = parts[1].clone(); + if _tvar.is_empty() { + parse_errors.push(format!( + "invalid time argument '{}' in fetch_cov! macro", + _tvar + )); + } + for name in parts.iter().skip(2) { + if name.is_empty() { + parse_errors.push(format!( + "empty covariate name in fetch_cov! macro body '{}'", + body + )); + } + if !name.starts_with('_') + && !name.chars().all(|c| c.is_ascii_alphanumeric() || c == '_') + { + parse_errors.push(format!( + "invalid covariate identifier '{}' in fetch_cov! macro", + name + )); + } + } + } + + if diffeq_stmts.is_empty() { + parse_errors.push( + "no dx[...] assignments found in diffeq; emit_ir must populate dx entries in the IR" + .to_string(), + ); + } + + if !parse_errors.is_empty() { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + format!("parse errors: {}", parse_errors.join("; ")), + )); + } + + // expression validation delegated to loader_helpers + + // Determine number of states and output eqs from parsed assignments + let max_dx = + crate::exa_wasm::interpreter::loader_helpers::collect_max_index(&diffeq_stmts, "dx") + .unwrap_or(0); + let max_y = crate::exa_wasm::interpreter::loader_helpers::collect_max_index(&out_stmts, "y") + .unwrap_or(0); + let nstates = max_dx + 1; + let nouteqs = max_y + 1; + + let nparams = params.len(); + // Prelude and statement validation delegated to loader_helpers + + for s in diffeq_stmts.iter() { + crate::exa_wasm::interpreter::loader_helpers::validate_stmt( + s, + &pmap, + nstates, + nparams, + &mut parse_errors, + ); + } + for s in out_stmts.iter() { + crate::exa_wasm::interpreter::loader_helpers::validate_stmt( + s, + &pmap, + nstates, + nparams, + &mut parse_errors, + ); + } + for s 
in init_stmts.iter() { + crate::exa_wasm::interpreter::loader_helpers::validate_stmt( + s, + &pmap, + nstates, + nparams, + &mut parse_errors, + ); + } + for (_i, expr) in lag_map.iter() { + crate::exa_wasm::interpreter::loader_helpers::validate_expr( + expr, + &pmap, + nstates, + nparams, + &mut parse_errors, + ); + } + for (_i, expr) in fa_map.iter() { + crate::exa_wasm::interpreter::loader_helpers::validate_expr( + expr, + &pmap, + nstates, + nparams, + &mut parse_errors, + ); + } + + // validate prelude ordering: each prelude RHS may reference params or earlier locals + { + let mut known: std::collections::HashSet = std::collections::HashSet::new(); + for (name, expr) in prelude.iter() { + crate::exa_wasm::interpreter::loader_helpers::validate_prelude_expr( + expr, + &pmap, + &known, + nstates, + nparams, + &mut parse_errors, + ); + known.insert(name.clone()); + } + } + + // Validate that pre-parsed ASTs do not call unknown builtin functions. + // This mirrors the bytecode arity checks but runs on ASTs emitted by the + // emitter so loaders reject IR that references unknown functions. + fn validate_builtin_calls_in_expr( + e: &crate::exa_wasm::interpreter::ast::Expr, + errors: &mut Vec, + ) { + use crate::exa_wasm::interpreter::ast::*; + match e { + Expr::Call { name, args } => { + if !crate::exa_wasm::interpreter::is_known_function(name.as_str()) { + errors.push(format!("unknown function call '{}' in AST", name)); + } + for a in args.iter() { + validate_builtin_calls_in_expr(a, errors); + } + } + Expr::MethodCall { + receiver, + name, + args, + } => { + if !crate::exa_wasm::interpreter::is_known_function(name.as_str()) { + errors.push(format!("unknown method call '{}' in AST", name)); + } + validate_builtin_calls_in_expr(receiver, errors); + for a in args.iter() { + validate_builtin_calls_in_expr(a, errors); + } + } + Expr::Indexed(_, idx) => validate_builtin_calls_in_expr(idx, errors), + Expr::UnaryOp { rhs, .. 
} => validate_builtin_calls_in_expr(rhs, errors), + Expr::BinaryOp { lhs, rhs, .. } => { + validate_builtin_calls_in_expr(lhs, errors); + validate_builtin_calls_in_expr(rhs, errors); + } + Expr::Ternary { + cond, + then_branch, + else_branch, + } => { + validate_builtin_calls_in_expr(cond, errors); + validate_builtin_calls_in_expr(then_branch, errors); + validate_builtin_calls_in_expr(else_branch, errors); + } + _ => {} + } + } + fn validate_builtin_calls_in_stmt( + s: &crate::exa_wasm::interpreter::ast::Stmt, + errors: &mut Vec, + ) { + match s { + crate::exa_wasm::interpreter::ast::Stmt::Assign(_, rhs) => { + validate_builtin_calls_in_expr(rhs, errors) + } + crate::exa_wasm::interpreter::ast::Stmt::Expr(e) => { + validate_builtin_calls_in_expr(e, errors) + } + crate::exa_wasm::interpreter::ast::Stmt::Block(v) => { + for st in v.iter() { + validate_builtin_calls_in_stmt(st, errors); + } + } + crate::exa_wasm::interpreter::ast::Stmt::If { + cond, + then_branch, + else_branch, + } => { + validate_builtin_calls_in_expr(cond, errors); + validate_builtin_calls_in_stmt(then_branch, errors); + if let Some(eb) = else_branch { + validate_builtin_calls_in_stmt(eb, errors); + } + } + } + } + + for s in diffeq_stmts.iter() { + validate_builtin_calls_in_stmt(s, &mut parse_errors); + } + for s in out_stmts.iter() { + validate_builtin_calls_in_stmt(s, &mut parse_errors); + } + for s in init_stmts.iter() { + validate_builtin_calls_in_stmt(s, &mut parse_errors); + } + + // Validate any bytecode maps present in the IR now that we know nstates + // and nparams. This checks param/x/local bounds and builtin arities. 
+ { + use crate::exa_wasm::interpreter::Opcode; + fn validate_code( + code: &Vec, + nstates: usize, + nparams: usize, + locals_len: usize, + funcs: &Vec, + parse_errors: &mut Vec, + ) { + for (pc, op) in code.iter().enumerate() { + match op { + Opcode::LoadParam(i) => { + if *i >= nparams { + parse_errors.push(format!( + "LoadParam index out of bounds at pc {}: {} >= nparams {}", + pc, i, nparams + )); + } + } + Opcode::LoadX(i) + | Opcode::StoreX(i) + | Opcode::StoreY(i) + | Opcode::StoreDx(i) => { + if *i >= nstates { + parse_errors.push(format!( + "x/dx/index out of bounds at pc {}: {} >= nstates {}", + pc, i, nstates + )); + } + } + Opcode::LoadRateiv(i) => { + if *i >= nstates { + parse_errors.push(format!( + "rateiv index out of bounds at pc {}: {} >= nstates {}", + pc, i, nstates + )); + } + } + Opcode::LoadLocal(i) | Opcode::StoreLocal(i) => { + if *i >= locals_len { + parse_errors.push(format!( + "local slot out of bounds at pc {}: {} >= locals_len {}", + pc, i, locals_len + )); + } + } + Opcode::CallBuiltin(func_idx, argc) => { + if *func_idx >= funcs.len() { + parse_errors.push(format!( + "CallBuiltin references unknown func index {} at pc {}", + func_idx, pc + )); + } else { + let fname = funcs.get(*func_idx).unwrap().as_str(); + match crate::exa_wasm::interpreter::arg_count_range(fname) { + Some(range) => { + if !range.contains(argc) { + parse_errors.push(format!("builtin '{}' called with wrong arity {} at pc {} (allowed {:?})", fname, argc, pc, range)); + } + } + None => parse_errors.push(format!( + "unknown builtin '{}' referenced in funcs table at pc {}", + fname, pc + )), + } + } + } + crate::exa_wasm::interpreter::Opcode::Pop => {} + // dynamic ops not fully checkable at compile time + Opcode::LoadParamDyn + | Opcode::LoadXDyn + | Opcode::LoadRateivDyn + | Opcode::StoreDxDyn + | Opcode::StoreXDyn + | Opcode::StoreYDyn => {} + _ => {} + } + } + } + + let funcs_table = ir.funcs.clone().unwrap_or_default(); + let locals_table = 
ir.locals.clone().unwrap_or_default(); + + let nparams = params.len(); + if let Some(map) = ir.diffeq_bytecode.clone() { + for (_k, code) in map.into_iter() { + validate_code( + &code, + nstates, + nparams, + locals_table.len(), + &funcs_table, + &mut parse_errors, + ); + } + } + // validate function-level diffeq bytecode if present + if let Some(code) = ir.diffeq_func.clone() { + validate_code(&code, nstates, nparams, locals_table.len(), &funcs_table, &mut parse_errors); + } + if let Some(map) = ir.out_bytecode.clone() { + for (_k, code) in map.into_iter() { + validate_code( + &code, + nstates, + nparams, + locals_table.len(), + &funcs_table, + &mut parse_errors, + ); + } + } + if let Some(map) = ir.init_bytecode.clone() { + for (_k, code) in map.into_iter() { + validate_code( + &code, + nstates, + nparams, + locals_table.len(), + &funcs_table, + &mut parse_errors, + ); + } + } + if let Some(map) = ir.lag_bytecode.clone() { + for (_k, code) in map.into_iter() { + validate_code( + &code, + nstates, + nparams, + locals_table.len(), + &funcs_table, + &mut parse_errors, + ); + } + } + if let Some(map) = ir.fa_bytecode.clone() { + for (_k, code) in map.into_iter() { + validate_code( + &code, + nstates, + nparams, + locals_table.len(), + &funcs_table, + &mut parse_errors, + ); + } + } + } + + if !parse_errors.is_empty() { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + format!("parse errors: {}", parse_errors.join("; ")), + )); + } + + let entry = registry::RegistryEntry { + diffeq_stmts, + out_stmts, + init_stmts, + lag: lag_map, + fa: fa_map, + prelude, + pmap: pmap.clone(), + nstates, + _nouteqs: nouteqs, + // attach any emitted bytecode maps (empty if emitter didn't provide them) + bytecode_diffeq: ir.diffeq_bytecode.unwrap_or_default(), + bytecode_out: ir.out_bytecode.unwrap_or_default(), + bytecode_init: ir.init_bytecode.unwrap_or_default(), + bytecode_lag: ir.lag_bytecode.unwrap_or_default(), + bytecode_fa: ir.fa_bytecode.unwrap_or_default(), + // 
optional function-level bytecode + bytecode_diffeq_func: ir.diffeq_func.unwrap_or_default(), + // function table and locals ordering emitted by the compiler + funcs: ir.funcs.unwrap_or_default(), + locals: ir.locals.unwrap_or_default(), + }; + + let id = registry::register_entry(entry); + + let ode = crate::simulator::equation::ODE::with_registry_id( + crate::exa_wasm::interpreter::dispatch::diffeq_dispatch, + crate::exa_wasm::interpreter::dispatch::lag_dispatch, + crate::exa_wasm::interpreter::dispatch::fa_dispatch, + crate::exa_wasm::interpreter::dispatch::init_dispatch, + crate::exa_wasm::interpreter::dispatch::out_dispatch, + (nstates, nouteqs), + Some(id), + ); + Ok((ode, meta, id)) +} + +#[cfg(test)] +mod tests { + use crate::exa_wasm::interpreter::ast::{Expr, Lhs, Stmt}; + use crate::exa_wasm::interpreter::parser::{tokenize, Parser}; + + // simple extractor for the inner closure body used in tests + fn extract_body(src: &str) -> String { + let lb = src.find('{').expect("no '{' found"); + let rb = src.rfind('}').expect("no '}' found"); + src[lb + 1..rb].to_string() + } + + fn extract_and_parse(src: &str) -> Vec { + let mut cleaned = extract_body(src); + // normalize booleans for parser (tests don't include macros) + cleaned = cleaned.replace("true", "1.0").replace("false", "0.0"); + let toks = tokenize(&cleaned); + let mut p = Parser::new(toks); + p.parse_statements().expect("parse_statements failed") + } + + fn contains_dx_assign(stmt: &Stmt, idx_expected: usize) -> bool { + match stmt { + Stmt::Assign(lhs, _rhs) => match lhs { + Lhs::Indexed(name, idx_expr) => { + if name == "dx" { + if let Expr::Number(n) = &**idx_expr { + return (*n as usize) == idx_expected; + } + } + false + } + _ => false, + }, + Stmt::Block(v) => v.iter().any(|s| contains_dx_assign(s, idx_expected)), + Stmt::If { + then_branch, + else_branch, + .. 
+ } => { + contains_dx_assign(then_branch, idx_expected) + || else_branch + .as_ref() + .map(|b| contains_dx_assign(b, idx_expected)) + .unwrap_or(false) + } + Stmt::Expr(_) => false, + } + } + + #[test] + fn test_if_true_parsed_cond_is_one_and_assign_present() { + let src = "|x, p, _t, dx, rateiv, _cov| { if true { dx[0] = -ke * x[0]; } }"; + let stmts = extract_and_parse(src); + assert!(!stmts.is_empty()); + let mut found = false; + for st in stmts.iter() { + if let Stmt::If { + cond, then_branch, .. + } = st + { + if let Expr::Number(n) = cond { + assert_eq!(*n, 1.0f64); + } else { + panic!("cond not normalized to number for 'true'"); + } + assert!(contains_dx_assign(then_branch, 0)); + found = true; + break; + } + } + assert!(found, "No If statement found in parsed stmts"); + } + + #[test] + fn test_if_false_parsed_cond_is_zero_and_assign_present() { + let src = "|x, p, _t, dx, rateiv, _cov| { if false { dx[0] = -ke * x[0]; } }"; + let stmts = extract_and_parse(src); + assert!(!stmts.is_empty()); + let mut found = false; + for st in stmts.iter() { + if let Stmt::If { + cond, then_branch, .. + } = st + { + if let Expr::Number(n) = cond { + assert_eq!(*n, 0.0f64); + } else { + panic!("cond not normalized to number for 'false'"); + } + // parser still preserves the assignment in the then-branch + assert!(contains_dx_assign(then_branch, 0)); + found = true; + break; + } + } + assert!(found, "No If statement found in parsed stmts"); + } +} diff --git a/src/exa_wasm/interpreter/loader_helpers.rs b/src/exa_wasm/interpreter/loader_helpers.rs new file mode 100644 index 00000000..173c9d6c --- /dev/null +++ b/src/exa_wasm/interpreter/loader_helpers.rs @@ -0,0 +1,299 @@ +use crate::exa_wasm::interpreter::ast::{Expr, Stmt}; +use std::collections::HashMap; + +// Loader helper utilities used by `loader.rs`. 
These functions implement a +// conservative extraction and validation surface that mirrors the prior inline +// implementations in `loader.rs` so they can be reused and unit-tested. + +// Loader helper utilities extracted from the large `load_ir_ode` function. + +// ongoing refactor can wire them into `loader.rs` incrementally. + +/// Rewrite parameter identifier `Ident(name)` nodes in a parsed statement +/// vector into `Expr::Param(index)` nodes using the provided `pmap`. +// NOTE: textual rewriting of params in statement vectors was previously +// provided as a helper. The emitter now emits rewritten ASTs (Param nodes) +// directly, and the runtime loader consumes pre-parsed ASTs. This helper +// was removed as part of removing fragile textual fallbacks. + +/// Return the body text inside the first top-level pair of braces. +/// Example: given `|t, y| { ... }` returns Some("...") or None. +pub fn extract_closure_body(src: &str) -> Option { + if let Some(lb_pos) = src.find('{') { + let bytes = src.as_bytes(); + let mut depth: isize = 0; + let mut i = lb_pos; + while i < bytes.len() { + match bytes[i] as char { + '{' => depth += 1, + '}' => { + depth -= 1; + if depth == 0 { + // return inner text between lb_pos and i + let inner = &src[lb_pos + 1..i]; + return Some(inner.to_string()); + } + } + _ => {} + } + i += 1; + } + } + None +} +// The textual extraction helpers (macro-stripping, prelude scanning and +// textual `fetch_*` extraction) have been removed. The emitter now emits +// structured `fetch_params` and `fetch_cov` fields in the IR and rewrites +// parameter identifiers into `Expr::Param` nodes before emission. The +// runtime loader consumes the structured IR and no longer attempts to scan +// raw closure text at runtime. + +/// Lightweight validator stubs (moved out of loader.rs so the loader can +/// call into a shared place). These can be expanded to perform expression +/// and statement validations that previously lived inside load_ir_ode. 
+pub fn validate_expr( + expr: &Expr, + pmap: &HashMap, + nstates: usize, + nparams: usize, + errors: &mut Vec, +) { + match expr { + Expr::Number(_) => {} + Expr::Bool(_) => {} + Expr::Ident(name) => { + if name == "t" { + return; + } + if pmap.contains_key(name) { + return; + } + errors.push(format!("unknown identifier '{}'", name)); + } + Expr::Param(_) => { + // param by index is valid + } + Expr::Indexed(name, idx_expr) => match &**idx_expr { + Expr::Number(n) => { + let idx = *n as usize; + match name.as_str() { + "x" | "rateiv" => { + if idx >= nstates { + errors.push(format!( + "index out of bounds '{}'[{}] (nstates={})", + name, idx, nstates + )); + } + } + "p" | "params" => { + if idx >= nparams { + errors.push(format!( + "parameter index out of bounds '{}'[{}] (nparams={})", + name, idx, nparams + )); + } + } + "y" => {} + _ => { + errors.push(format!("unknown indexed symbol '{}'", name)); + } + } + } + other => validate_expr(other, pmap, nstates, nparams, errors), + }, + Expr::UnaryOp { rhs, .. } => validate_expr(rhs, pmap, nstates, nparams, errors), + Expr::BinaryOp { lhs, rhs, .. 
} => { + validate_expr(lhs, pmap, nstates, nparams, errors); + validate_expr(rhs, pmap, nstates, nparams, errors); + } + Expr::Call { name: _, args } => { + for a in args.iter() { + validate_expr(a, pmap, nstates, nparams, errors); + } + } + Expr::MethodCall { + receiver, + name: _, + args, + } => { + validate_expr(receiver, pmap, nstates, nparams, errors); + for a in args.iter() { + validate_expr(a, pmap, nstates, nparams, errors); + } + } + Expr::Ternary { + cond, + then_branch, + else_branch, + } => { + validate_expr(cond, pmap, nstates, nparams, errors); + validate_expr(then_branch, pmap, nstates, nparams, errors); + validate_expr(else_branch, pmap, nstates, nparams, errors); + } + } +} + +pub fn validate_prelude_expr( + expr: &Expr, + pmap: &HashMap, + known_locals: &std::collections::HashSet, + nstates: usize, + nparams: usize, + errors: &mut Vec, +) { + match expr { + Expr::Number(_) => {} + Expr::Bool(_) => {} + Expr::Ident(name) => { + if name == "t" { + return; + } + if known_locals.contains(name) { + return; + } + if pmap.contains_key(name) { + return; + } + errors.push(format!("unknown identifier '{}' in prelude", name)); + } + Expr::Param(_) => {} + Expr::Indexed(name, idx_expr) => match &**idx_expr { + Expr::Number(n) => { + let idx = *n as usize; + match name.as_str() { + "x" | "rateiv" => { + if idx >= nstates { + errors.push(format!( + "index out of bounds '{}'[{}] (nstates={})", + name, idx, nstates + )); + } + } + "p" | "params" => { + if idx >= nparams { + errors.push(format!( + "parameter index out of bounds '{}'[{}] (nparams={})", + name, idx, nparams + )); + } + } + "y" => {} + _ => { + errors.push(format!("unknown indexed symbol '{}'", name)); + } + } + } + other => validate_prelude_expr(other, pmap, known_locals, nstates, nparams, errors), + }, + Expr::UnaryOp { rhs, .. } => { + validate_prelude_expr(rhs, pmap, known_locals, nstates, nparams, errors) + } + Expr::BinaryOp { lhs, rhs, .. 
} => { + validate_prelude_expr(lhs, pmap, known_locals, nstates, nparams, errors); + validate_prelude_expr(rhs, pmap, known_locals, nstates, nparams, errors); + } + Expr::Call { name: _, args } => { + for a in args.iter() { + validate_prelude_expr(a, pmap, known_locals, nstates, nparams, errors); + } + } + Expr::MethodCall { + receiver, + name: _, + args, + } => { + validate_prelude_expr(receiver, pmap, known_locals, nstates, nparams, errors); + for a in args.iter() { + validate_prelude_expr(a, pmap, known_locals, nstates, nparams, errors); + } + } + Expr::Ternary { + cond, + then_branch, + else_branch, + } => { + validate_prelude_expr(cond, pmap, known_locals, nstates, nparams, errors); + validate_prelude_expr(then_branch, pmap, known_locals, nstates, nparams, errors); + validate_prelude_expr(else_branch, pmap, known_locals, nstates, nparams, errors); + } + } +} + +pub fn validate_stmt( + st: &Stmt, + pmap: &HashMap, + nstates: usize, + nparams: usize, + errors: &mut Vec, +) { + use crate::exa_wasm::interpreter::ast::{Lhs, Stmt}; + match st { + Stmt::Expr(e) => validate_expr(e, pmap, nstates, nparams, errors), + Stmt::Assign(lhs, rhs) => { + validate_expr(rhs, pmap, nstates, nparams, errors); + if let Lhs::Indexed(_, idx_expr) = lhs { + validate_expr(idx_expr, pmap, nstates, nparams, errors); + } + } + Stmt::Block(v) => { + for s in v.iter() { + validate_stmt(s, pmap, nstates, nparams, errors); + } + } + Stmt::If { + cond, + then_branch, + else_branch, + } => { + validate_expr(cond, pmap, nstates, nparams, errors); + validate_stmt(then_branch, pmap, nstates, nparams, errors); + if let Some(eb) = else_branch { + validate_stmt(eb, pmap, nstates, nparams, errors); + } + } + } +} + +pub fn collect_max_index( + stmts: &Vec, + _name: &str, +) -> Option { + let mut max: Option = None; + fn visit(s: &crate::exa_wasm::interpreter::ast::Stmt, max: &mut Option) { + use crate::exa_wasm::interpreter::ast::Lhs; + match s { + crate::exa_wasm::interpreter::ast::Stmt::Assign(lhs, 
_) => { + if let Lhs::Indexed(_nm, idx_expr) = lhs { + if let crate::exa_wasm::interpreter::ast::Expr::Number(nn) = &**idx_expr { + let idx = *nn as usize; + match max { + Some(m) if *m < idx => *max = Some(idx), + None => *max = Some(idx), + _ => {} + } + } + } + } + crate::exa_wasm::interpreter::ast::Stmt::Block(v) => { + for ss in v.iter() { + visit(ss, max); + } + } + crate::exa_wasm::interpreter::ast::Stmt::If { + then_branch, + else_branch, + .. + } => { + visit(then_branch, max); + if let Some(eb) = else_branch { + visit(eb, max); + } + } + crate::exa_wasm::interpreter::ast::Stmt::Expr(_) => {} + } + } + for s in stmts.iter() { + visit(s, &mut max); + } + max +} diff --git a/src/exa_wasm/interpreter/mod.rs b/src/exa_wasm/interpreter/mod.rs new file mode 100644 index 00000000..de579c50 --- /dev/null +++ b/src/exa_wasm/interpreter/mod.rs @@ -0,0 +1,1381 @@ +mod ast; +mod builtins; +mod dispatch; +mod eval; +mod loader; +mod loader_helpers; +mod parser; +mod registry; +mod typecheck; +mod vm; + +pub use loader::load_ir_ode; +pub use parser::tokenize; +pub use parser::Parser; +pub use registry::{ + ode_for_id, set_current_expr_id, set_runtime_error, take_runtime_error, unregister_model, +}; + +pub use vm::{run_bytecode, Opcode}; + +// Re-export some AST and helper symbols for other sibling modules (e.g. build) +pub use ast::{Expr, Lhs, Stmt}; +pub use loader_helpers::extract_closure_body; +// Re-export builtin helpers so other modules (like the emitter) can query +// builtin metadata without depending on private module paths. +pub use builtins::{arg_count_range, is_known_function}; + +// Keep a small set of unit tests that exercise the parser/eval and loader +// wiring. Runtime dispatch and registry behavior live in the `dispatch` +// and `registry` modules respectively. 
+#[cfg(test)] +mod tests { + use super::*; + use crate::exa_wasm::interpreter::eval::eval_expr; + use diffsol::Vector; + + #[test] + fn test_tokenize_and_parse_simple() { + let s = "-ke * x[0] + rateiv[0] / 2"; + let toks = tokenize(s); + let mut p = Parser::new(toks); + let expr = p.parse_expr().expect("parse failed"); + // evaluate with dummy vectors + use crate::simulator::V; + let x = V::zeros(1, diffsol::NalgebraContext); + let mut pvec = V::zeros(1, diffsol::NalgebraContext); + pvec[0] = 3.0; // ke + let rateiv = V::zeros(1, diffsol::NalgebraContext); + // evaluation should succeed (ke resolves via pmap not provided -> 0) + let val = eval_expr(&expr, &x, &pvec, &rateiv, None, None, Some(0.0), None); + // numeric result must be finite + assert!(val.as_number().is_finite()); + } + + #[test] + fn test_unknown_function_sets_runtime_error() { + use crate::exa_wasm::interpreter::eval::eval_call; + // clear any prior runtime error + crate::exa_wasm::interpreter::take_runtime_error(); + // call an unknown function + let val = eval_call("this_function_does_not_exist", &[]); + // evaluator returns Number(0.0) for unknowns but should set a runtime error + use crate::exa_wasm::interpreter::eval::Value; + assert_eq!(val, Value::Number(0.0)); + let err = crate::exa_wasm::interpreter::take_runtime_error(); + assert!(err.is_some(), "expected runtime error for unknown function"); + let msg = err.unwrap(); + assert!( + msg.contains("unknown function"), + "unexpected error message: {}", + msg + ); + } + + #[test] + fn test_eval_call_rejects_wrong_arity() { + use crate::exa_wasm::interpreter::eval::eval_call; + use crate::exa_wasm::interpreter::eval::Value; + // clear any prior runtime error + crate::exa_wasm::interpreter::take_runtime_error(); + // call pow with wrong arity (should be 2 args) + let val = eval_call("pow", &[Value::Number(1.0)]); + assert_eq!(val, Value::Number(0.0)); + let err = crate::exa_wasm::interpreter::take_runtime_error(); + assert!(err.is_some(), 
"expected runtime error for wrong arity"); + let msg = err.unwrap(); + assert!( + msg.contains("wrong arity") || msg.contains("unknown function"), + "unexpected error message: {}", + msg + ); + } + + #[test] + fn test_loader_errors_on_unknown_function() { + use std::env; + use std::fs; + let tmp = env::temp_dir().join("exa_test_ir_unknown_fn.json"); + // Use the emitter to create IR that includes parsed AST; loader will + // then validate and reject unknown function calls. + let diffeq = "|x, p, _t, dx, rateiv, _cov| { dx[0] = foobar(1.0); }".to_string(); + let _path = crate::exa_wasm::build::emit_ir::( + diffeq, + None, + None, + None, + None, + Some(tmp.clone()), + vec!["ke".to_string(), "v".to_string()], + ) + .expect("emit_ir failed"); + let res = crate::exa_wasm::interpreter::loader::load_ir_ode(tmp.clone()); + fs::remove_file(tmp).ok(); + assert!( + res.is_err(), + "loader should reject IR with unknown function calls" + ); + } + + #[test] + fn test_macro_parsing_load_ir() { + use std::env; + use std::fs; + let tmp = env::temp_dir().join("exa_test_ir_lag.json"); + let diffeq = "|x, p, _t, dx, rateiv, _cov| { dx[0] = 0.0; }".to_string(); + // lag text contains function calls and commas inside calls + let lag = Some( + "|p, t, _cov| { lag!{0 => max(1.0, t * 2.0), 1 => if(t>0, 2.0, 3.0)} }".to_string(), + ); + let _path = crate::exa_wasm::build::emit_ir::( + diffeq, + lag, + None, + None, + None, + Some(tmp.clone()), + vec![], + ) + .expect("emit_ir failed"); + let res = load_ir_ode(tmp.clone()); + fs::remove_file(tmp).ok(); + assert!(res.is_ok()); + } + + #[test] + fn test_emit_ir_includes_diffeq_ast_and_schema() { + use std::env; + use std::fs; + + let tmp = env::temp_dir().join("exa_test_emit_ir_diffeq_ast_and_schema.json"); + let diffeq = + "|x, p, _t, dx, rateiv, _cov| { if (t > 0) { dx[0] = 1.0; } else { dx[0] = 2.0; } }" + .to_string(); + let _path = crate::exa_wasm::build::emit_ir::( + diffeq, + None, + None, + None, + None, + Some(tmp.clone()), + 
vec!["ke".to_string()], + ) + .expect("emit_ir failed"); + let s = fs::read_to_string(&tmp).expect("read emitted ir"); + let v: serde_json::Value = serde_json::from_str(&s).expect("parse json"); + fs::remove_file(&tmp).ok(); + assert!( + v.get("diffeq_ast").is_some(), + "emit_ir should include diffeq_ast" + ); + // schema metadata should be present + assert!( + v.get("ir_schema").is_some(), + "emit_ir should include ir_schema" + ); + } + + #[test] + fn test_emit_ir_includes_out_and_init_ast() { + use std::env; + use std::fs; + + let tmp = env::temp_dir().join("exa_test_emit_ir_out_init_ast.json"); + let out = "|x, p, _t, _cov, y| { y[0] = x[0] + 1.0; }".to_string(); + let init = "|p, _t, _cov, x| { x[0] = 0.0; }".to_string(); + let _path = crate::exa_wasm::build::emit_ir::( + "".to_string(), + None, + None, + Some(init.clone()), + Some(out.clone()), + Some(tmp.clone()), + vec![], + ) + .expect("emit_ir failed"); + let s = fs::read_to_string(&tmp).expect("read emitted ir"); + let v: serde_json::Value = serde_json::from_str(&s).expect("parse json"); + fs::remove_file(&tmp).ok(); + assert!(v.get("out_ast").is_some(), "emit_ir should include out_ast"); + assert!( + v.get("init_ast").is_some(), + "emit_ir should include init_ast" + ); + } + + #[test] + fn test_emit_ir_includes_bytecode_map_and_vm_exec() { + use crate::exa_wasm::interpreter::{run_bytecode, Opcode}; + use std::env; + use std::fs; + + let tmp = env::temp_dir().join("exa_test_emit_ir_bytecode.json"); + let diffeq = "|x, p, _t, dx, rateiv, _cov| { dx[0] = ke * 2.0; }".to_string(); + let params = vec!["ke".to_string()]; + let _path = crate::exa_wasm::build::emit_ir::( + diffeq, + None, + None, + None, + None, + Some(tmp.clone()), + params.clone(), + ) + .expect("emit_ir failed"); + let s = fs::read_to_string(&tmp).expect("read emitted ir"); + let v: serde_json::Value = serde_json::from_str(&s).expect("parse json"); + fs::remove_file(&tmp).ok(); + // ensure bytecode_map present + let bc = v + 
.get("bytecode_map") + .expect("bytecode_map should be present") + .clone(); + // deserialize into map + let map: std::collections::HashMap> = + serde_json::from_value(bc).expect("deserialize bytecode_map"); + assert!(map.contains_key(&0usize)); + let code = map.get(&0usize).unwrap(); + // execute bytecode with p = [3.0] + let pvals = vec![3.0f64]; + let mut assigned: Option<(usize, f64)> = None; + run_bytecode(&code, &pvals, |i, v| { + assigned = Some((i, v)); + }); + assert!(assigned.is_some()); + let (i, val) = assigned.unwrap(); + assert_eq!(i, 0usize); + assert_eq!(val, 6.0f64); + } + + #[test] + fn test_loader_rewrites_params_to_param_nodes() { + use crate::exa_wasm::interpreter::ast::{Expr, Stmt}; + use std::env; + use std::fs; + + let tmp = env::temp_dir().join("exa_test_ir_param_rewrite.json"); + let diffeq = "|x, p, _t, dx, rateiv, _cov| { dx[0] = ke * x[0]; }".to_string(); + let _path = crate::exa_wasm::build::emit_ir::( + diffeq, + None, + None, + None, + None, + Some(tmp.clone()), + vec!["ke".to_string(), "v".to_string()], + ) + .expect("emit_ir failed"); + let res = crate::exa_wasm::interpreter::loader::load_ir_ode(tmp.clone()); + fs::remove_file(tmp).ok(); + assert!(res.is_ok(), "loader should accept valid IR"); + let (_ode, _meta, id) = res.unwrap(); + let entry = crate::exa_wasm::interpreter::registry::get_entry(id).expect("entry"); + + fn contains_param_in_expr(e: &Expr, idx: usize) -> bool { + match e { + Expr::Param(i) => *i == idx, + Expr::BinaryOp { lhs, rhs, .. } => { + contains_param_in_expr(lhs, idx) || contains_param_in_expr(rhs, idx) + } + Expr::UnaryOp { rhs, .. } => contains_param_in_expr(rhs, idx), + Expr::Call { args, .. } => args.iter().any(|a| contains_param_in_expr(a, idx)), + Expr::MethodCall { receiver, args, .. 
} => { + contains_param_in_expr(receiver, idx) + || args.iter().any(|a| contains_param_in_expr(a, idx)) + } + Expr::Indexed(_, idx_expr) => contains_param_in_expr(idx_expr, idx), + Expr::Ternary { + cond, + then_branch, + else_branch, + } => { + contains_param_in_expr(cond, idx) + || contains_param_in_expr(then_branch, idx) + || contains_param_in_expr(else_branch, idx) + } + _ => false, + } + } + + fn contains_param(stmt: &Stmt, idx: usize) -> bool { + match stmt { + Stmt::Assign(_, rhs) => contains_param_in_expr(rhs, idx), + Stmt::Block(v) => v.iter().any(|s| contains_param(s, idx)), + Stmt::If { + then_branch, + else_branch, + .. + } => { + contains_param(then_branch, idx) + || else_branch + .as_ref() + .map(|b| contains_param(b, idx)) + .unwrap_or(false) + } + Stmt::Expr(e) => contains_param_in_expr(e, idx), + } + } + + assert!( + entry.diffeq_stmts.iter().any(|s| contains_param(s, 0)), + "expected Param(0) in diffeq stmts" + ); + } + + #[test] + fn test_eval_param_expr() { + use crate::exa_wasm::interpreter::ast::Expr; + use crate::exa_wasm::interpreter::eval::eval_expr; + use crate::simulator::V; + + let expr = Expr::Param(0); + // create simple vectors + use diffsol::NalgebraContext; + let x = V::zeros(1, NalgebraContext); + let mut p = V::zeros(1, NalgebraContext); + p[0] = 3.1415; + let rateiv = V::zeros(1, NalgebraContext); + + let val = eval_expr(&expr, &x, &p, &rateiv, None, None, Some(0.0), None); + assert_eq!(val.as_number(), 3.1415); + } + + #[test] + fn test_loader_accepts_preparsed_ast_in_ir() { + use crate::exa_wasm::interpreter::ast::{Expr, Lhs, Stmt}; + use std::env; + use std::fs; + + let tmp = env::temp_dir().join("exa_test_ir_preparsed_ast.json"); + // build a tiny diffeq AST: dx[0] = 1.0; + let lhs = Lhs::Indexed("dx".to_string(), Box::new(Expr::Number(0.0))); + let stmt = Stmt::Assign(lhs, Expr::Number(1.0)); + let diffeq_ast = vec![stmt]; + + let ir_json = serde_json::json!({ + "ir_version": "1.0", + "kind": "EqnKind::ODE", + "params": [], 
+ "diffeq": "", + "diffeq_ast": diffeq_ast, + "lag": "", + "fa": "", + "init": "", + "out": "" + }); + let s = serde_json::to_string_pretty(&ir_json).expect("serialize"); + fs::write(&tmp, s.as_bytes()).expect("write tmp"); + + let res = crate::exa_wasm::interpreter::loader::load_ir_ode(tmp.clone()); + fs::remove_file(tmp).ok(); + assert!( + res.is_ok(), + "loader should accept IR with pre-parsed diffeq_ast" + ); + } + + #[test] + fn test_loader_rejects_builtin_wrong_arity() { + use std::env; + use std::fs; + let tmp = env::temp_dir().join("exa_test_ir_bad_arity.json"); + let diffeq = "|x, p, _t, dx, rateiv, _cov| { dx[0] = pow(1.0); }".to_string(); + let _path = crate::exa_wasm::build::emit_ir::( + diffeq, + None, + None, + None, + None, + Some(tmp.clone()), + vec!["ke".to_string()], + ) + .expect("emit_ir failed"); + let res = crate::exa_wasm::interpreter::loader::load_ir_ode(tmp.clone()); + fs::remove_file(tmp).ok(); + assert!( + res.is_err(), + "loader should reject builtin calls with wrong arity" + ); + } + + mod load_negative_tests { + use super::*; + use std::env; + use std::fs; + + #[test] + fn test_loader_errors_when_missing_structured_maps() { + let tmp = env::temp_dir().join("exa_test_ir_negative.json"); + let ir_json = serde_json::json!({ + "ir_version": "1.0", + "kind": "EqnKind::ODE", + "params": ["ke", "v"], + "diffeq": "|x, p, _t, dx, rateiv, _cov| { dx[0] = -ke * x[0] + rateiv[0]; }", + "lag": "|p, t, _cov| { lag!{0 => t} }", + "fa": "|p, t, _cov| { fa!{0 => 0.1} }", + "init": "|p, _t, _cov, x| { }", + "out": "|x, p, _t, _cov, y| { y[0] = x[0]; }" + }); + let s = serde_json::to_string_pretty(&ir_json).expect("serialize"); + fs::write(&tmp, s.as_bytes()).expect("write tmp"); + + let res = load_ir_ode(tmp.clone()); + fs::remove_file(tmp).ok(); + assert!( + res.is_err(), + "loader should reject IR missing structured maps" + ); + } + } + + #[test] + fn test_bytecode_parity_constant_index() { + use crate::exa_wasm::interpreter::eval::eval_expr; + use 
crate::exa_wasm::interpreter::vm::run_bytecode_eval; + use crate::exa_wasm::interpreter::Opcode; + use std::env; + use std::fs; + + let tmp = env::temp_dir().join("exa_test_parity_const.json"); + let diffeq = "|x, p, _t, dx, rateiv, _cov| { dx[0] = x[0] + 2.0; }".to_string(); + let _path = crate::exa_wasm::build::emit_ir::( + diffeq, + None, + None, + None, + None, + Some(tmp.clone()), + vec![], + ) + .expect("emit_ir failed"); + let s = fs::read_to_string(&tmp).expect("read emitted ir"); + let v: serde_json::Value = serde_json::from_str(&s).expect("parse json"); + fs::remove_file(&tmp).ok(); + + // extract AST rhs expression + let diffeq_ast = v.get("diffeq_ast").expect("diffeq_ast"); + let stmts: Vec = + serde_json::from_value(diffeq_ast.clone()).expect("deserialize stmts"); + // expect first stmt to be Assign(_, rhs) + let rhs_expr = match &stmts[0] { + crate::exa_wasm::interpreter::Stmt::Assign(_, rhs) => rhs.clone(), + _ => panic!("expected assign stmt"), + }; + + use diffsol::NalgebraContext; + let x = crate::simulator::V::zeros(1, NalgebraContext); + let mut x = x; + x[0] = 5.0; + let p = crate::simulator::V::zeros(0, NalgebraContext); + let rateiv = crate::simulator::V::zeros(0, NalgebraContext); + + let ast_val = eval_expr(&rhs_expr, &x, &p, &rateiv, None, None, Some(0.0), None); + + // extract bytecode for index 0 + // If emitter did not produce bytecode for this pattern, skip the VM + // parity check here. The test harness will still exercise the AST + // path; missing bytecode means the emitter needs expanded lowering. 
+ let bc = match v.get("diffeq_bytecode") { + Some(b) => b, + None => { + eprintln!("emit_ir did not produce diffeq_bytecode for method-call test; skipping VM parity check"); + return; + } + }; + let map: std::collections::HashMap> = + serde_json::from_value(bc.clone()).expect("deserialize bytecode_map"); + let code = map.get(&0usize).expect("code for idx 0"); + // strip trailing StoreDx + let mut expr_code = code.clone(); + if let Some(last) = expr_code.last() { + match last { + Opcode::StoreDx(_) => { + expr_code.pop(); + } + _ => {} + } + } + + // builtins dispatch + let builtins = |name: &str, args: &[f64]| -> f64 { + use crate::exa_wasm::interpreter::eval::{eval_call, Value}; + let vals: Vec = args.iter().map(|v| Value::Number(*v)).collect(); + eval_call(name, &vals).as_number() + }; + + let mut locals: Vec = Vec::new(); + let mut locals_slice = locals.as_mut_slice(); + let x_vals: Vec = vec![x[0]]; + let p_vals: Vec = vec![]; + let rateiv_vals: Vec = vec![]; + let mut funcs: Vec = Vec::new(); + if let Some(fv) = v.get("funcs") { + funcs = serde_json::from_value(fv.clone()).unwrap_or_default(); + } + + // debug: show discovered funcs and expr_code + eprintln!("debug funcs: {:?}", funcs); + eprintln!("debug expr_code: {:?}", expr_code); + + let vm_val = run_bytecode_eval( + &expr_code, + &x_vals, + &p_vals, + &rateiv_vals, + 0.0, + &mut locals_slice, + &funcs, + &builtins, + ); + + assert_eq!(ast_val.as_number(), vm_val); + } + + #[test] + fn test_bytecode_parity_dynamic_index() { + use crate::exa_wasm::interpreter::eval::eval_expr; + use crate::exa_wasm::interpreter::vm::run_bytecode_eval; + use crate::exa_wasm::interpreter::Opcode; + use std::env; + use std::fs; + + let tmp = env::temp_dir().join("exa_test_parity_dyn.json"); + let diffeq = "|x, p, _t, dx, rateiv, _cov| { dx[0] = x[ke]; }".to_string(); + let params = vec!["ke".to_string()]; + let _path = crate::exa_wasm::build::emit_ir::( + diffeq, + None, + None, + None, + None, + Some(tmp.clone()), + 
params.clone(), + ) + .expect("emit_ir failed"); + let s = fs::read_to_string(&tmp).expect("read emitted ir"); + let v: serde_json::Value = serde_json::from_str(&s).expect("parse json"); + fs::remove_file(&tmp).ok(); + + // extract AST rhs expression + // prefer pre-parsed AST when present, otherwise parse the closure text + let rhs_expr = if let Some(diffeq_ast) = v.get("diffeq_ast") { + let stmts: Vec = + serde_json::from_value(diffeq_ast.clone()).expect("deserialize stmts"); + match &stmts[0] { + crate::exa_wasm::interpreter::Stmt::Assign(_, rhs) => rhs.clone(), + _ => panic!("expected assign stmt"), + } + } else { + let diffeq_text = v + .get("diffeq") + .and_then(|d| d.as_str()) + .expect("diffeq text"); + let body = crate::exa_wasm::interpreter::extract_closure_body(diffeq_text) + .expect("closure body"); + let toks = crate::exa_wasm::interpreter::tokenize(&body); + let mut p = crate::exa_wasm::interpreter::Parser::new(toks); + // parse statements and extract rhs from first assign + if let Some(stmts) = p.parse_statements() { + match &stmts[0] { + crate::exa_wasm::interpreter::Stmt::Assign(_, rhs) => rhs.clone(), + _ => panic!("expected assign stmt"), + } + } else { + // fallback: attempt to extract RHS between '=' and ';' and parse as expression + let eq_pos = body.find('='); + if let Some(eq) = eq_pos { + if let Some(sc) = body[eq..].find(';') { + let rhs_text = body[eq + 1..eq + sc].trim(); + let toks = crate::exa_wasm::interpreter::tokenize(rhs_text); + let mut p2 = crate::exa_wasm::interpreter::Parser::new(toks); + p2.parse_expr().expect("parse expr rhs") + } else { + panic!("parse stmts"); + } + } else { + panic!("parse stmts"); + } + } + }; + + use diffsol::NalgebraContext; + let mut x = crate::simulator::V::zeros(2, NalgebraContext); + x[0] = 7.0; + x[1] = 9.0; + let mut p = crate::simulator::V::zeros(1, NalgebraContext); + p[0] = 0.0; // ke -> picks x[0] + let rateiv = crate::simulator::V::zeros(0, NalgebraContext); + + let ast_val = 
eval_expr(&rhs_expr, &x, &p, &rateiv, None, None, Some(0.0), None); + + // extract bytecode for index 0 + let bc = match v.get("diffeq_bytecode") { + Some(b) => b, + None => { + eprintln!("emit_ir did not produce diffeq_bytecode for method-call test; skipping VM parity check"); + return; + } + }; + let map: std::collections::HashMap> = + serde_json::from_value(bc.clone()).expect("deserialize bytecode_map"); + let code = map.get(&0usize).expect("code for idx 0"); + // strip trailing StoreDx + let mut expr_code = code.clone(); + if let Some(last) = expr_code.last() { + match last { + Opcode::StoreDx(_) => { + expr_code.pop(); + } + _ => {} + } + } + + let builtins = |name: &str, args: &[f64]| -> f64 { + use crate::exa_wasm::interpreter::eval::{eval_call, Value}; + let vals: Vec = args.iter().map(|v| Value::Number(*v)).collect(); + eval_call(name, &vals).as_number() + }; + + let mut locals: Vec = Vec::new(); + let mut locals_slice = locals.as_mut_slice(); + let x_vals: Vec = vec![x[0], x[1]]; + let p_vals: Vec = vec![p[0]]; + let rateiv_vals: Vec = vec![]; + let vm_val = run_bytecode_eval( + &expr_code, + &x_vals, + &p_vals, + &rateiv_vals, + 0.0, + &mut locals_slice, + &Vec::new(), + &builtins, + ); + + assert_eq!(ast_val.as_number(), vm_val); + } + + #[test] + fn test_bytecode_parity_lag_entry() { + use crate::exa_wasm::interpreter::eval::eval_expr; + use crate::exa_wasm::interpreter::vm::run_bytecode_eval; + use crate::exa_wasm::interpreter::Opcode; + use std::env; + use std::fs; + + let tmp = env::temp_dir().join("exa_test_parity_lag.json"); + let diffeq = "|x, p, _t, dx, rateiv, _cov| { dx[0] = 0.0; }".to_string(); + // use an expression that only references params so the conservative + // bytecode compiler can produce code (compile_expr_top does not + // accept bare 't' or unknown idents). 
+ let lag = Some("|p, t, _cov| { lag!{0 => p[0] * 2.0} }".to_string()); + let _path = crate::exa_wasm::build::emit_ir::( + diffeq, + lag, + None, + None, + None, + Some(tmp.clone()), + vec![], + ) + .expect("emit_ir failed"); + let s = fs::read_to_string(&tmp).expect("read emitted ir"); + let v: serde_json::Value = serde_json::from_str(&s).expect("parse json"); + fs::remove_file(&tmp).ok(); + + // parse textual lag entry back to Expr for AST eval + let lag_map = v.get("lag_map").expect("lag_map"); + let lag_entry = lag_map + .get("0") + .expect("lag entry 0") + .as_str() + .expect("string"); + let toks = crate::exa_wasm::interpreter::tokenize(lag_entry); + let mut p = crate::exa_wasm::interpreter::Parser::new(toks); + let expr = p.parse_expr().expect("parse lag expr"); + + use diffsol::NalgebraContext; + let x = crate::simulator::V::zeros(0, NalgebraContext); + let rateiv = crate::simulator::V::zeros(0, NalgebraContext); + + // evaluate AST with p[0] = 3.0 -> expected 6.0 + let mut pvec = crate::simulator::V::zeros(1, diffsol::NalgebraContext); + pvec[0] = 3.0; + let ast_val = eval_expr(&expr, &x, &pvec, &rateiv, None, None, Some(0.0), None); + + // get lag_bytecode + let bc = v.get("lag_bytecode").expect("lag_bytecode"); + let map: std::collections::HashMap> = + serde_json::from_value(bc.clone()).expect("deserialize lag_bytecode"); + let code = map.get(&0usize).expect("code for lag 0"); + + let mut funcs: Vec = Vec::new(); + if let Some(fv) = v.get("funcs") { + funcs = serde_json::from_value(fv.clone()).unwrap_or_default(); + } + + let builtins = |name: &str, args: &[f64]| -> f64 { + use crate::exa_wasm::interpreter::eval::{eval_call, Value}; + let vals: Vec = args.iter().map(|v| Value::Number(*v)).collect(); + eval_call(name, &vals).as_number() + }; + + let mut locals: Vec = Vec::new(); + let mut locals_slice = locals.as_mut_slice(); + let x_vals: Vec = vec![]; + let p_vals: Vec = vec![3.0]; + let rateiv_vals: Vec = vec![]; + let vm_val = run_bytecode_eval( + 
&code, + &x_vals, + &p_vals, + &rateiv_vals, + 2.0, + &mut locals_slice, + &funcs, + &builtins, + ); + + assert_eq!(ast_val.as_number(), vm_val); + } + + #[test] + fn test_bytecode_parity_ternary() { + use crate::exa_wasm::interpreter::eval::eval_expr; + use crate::exa_wasm::interpreter::vm::run_bytecode_eval; + use crate::exa_wasm::interpreter::Opcode; + use std::env; + use std::fs; + + let tmp = env::temp_dir().join("exa_test_parity_ternary.json"); + let diffeq = "|x, p, _t, dx, rateiv, _cov| { dx[0] = x[0] > 0 ? 2.0 : 3.0; }".to_string(); + let _path = crate::exa_wasm::build::emit_ir::( + diffeq, + None, + None, + None, + None, + Some(tmp.clone()), + vec![], + ) + .expect("emit_ir failed"); + let s = fs::read_to_string(&tmp).expect("read emitted ir"); + let v: serde_json::Value = serde_json::from_str(&s).expect("parse json"); + fs::remove_file(&tmp).ok(); + + // prefer pre-parsed AST when present, otherwise parse the closure text + let rhs_expr = if let Some(diffeq_ast) = v.get("diffeq_ast") { + let stmts: Vec = + serde_json::from_value(diffeq_ast.clone()).expect("deserialize stmts"); + match &stmts[0] { + crate::exa_wasm::interpreter::Stmt::Assign(_, rhs) => rhs.clone(), + _ => panic!("expected assign stmt"), + } + } else { + let diffeq_text = v + .get("diffeq") + .and_then(|d| d.as_str()) + .expect("diffeq text"); + let body = crate::exa_wasm::interpreter::extract_closure_body(diffeq_text) + .expect("closure body"); + let toks = crate::exa_wasm::interpreter::tokenize(&body); + let mut p = crate::exa_wasm::interpreter::Parser::new(toks); + // parse statements and extract rhs from first assign + let stmts = p.parse_statements().expect("parse stmts"); + match &stmts[0] { + crate::exa_wasm::interpreter::Stmt::Assign(_, rhs) => rhs.clone(), + _ => panic!("expected assign stmt"), + } + }; + + use diffsol::NalgebraContext; + let mut x = crate::simulator::V::zeros(1, NalgebraContext); + x[0] = 1.0; + let p = crate::simulator::V::zeros(0, NalgebraContext); + let rateiv 
= crate::simulator::V::zeros(0, NalgebraContext); + + let ast_val = eval_expr(&rhs_expr, &x, &p, &rateiv, None, None, Some(0.0), None); + + let bc = match v.get("diffeq_bytecode") { + Some(b) => b, + None => { + eprintln!("emit_ir did not produce diffeq_bytecode for method-call test; skipping VM parity check"); + return; + } + }; + let map: std::collections::HashMap> = + serde_json::from_value(bc.clone()).expect("deserialize bytecode_map"); + let code = map.get(&0usize).expect("code for idx 0"); + let mut expr_code = code.clone(); + if let Some(last) = expr_code.last() { + match last { + Opcode::StoreDx(_) => { + expr_code.pop(); + } + _ => {} + } + } + + let builtins = |name: &str, args: &[f64]| -> f64 { + use crate::exa_wasm::interpreter::eval::{eval_call, Value}; + let vals: Vec = args.iter().map(|v| Value::Number(*v)).collect(); + eval_call(name, &vals).as_number() + }; + + let mut locals: Vec = Vec::new(); + let mut locals_slice = locals.as_mut_slice(); + let x_vals: Vec = vec![x[0]]; + let p_vals: Vec = vec![]; + let rateiv_vals: Vec = vec![]; + + let mut funcs: Vec = Vec::new(); + if let Some(fv) = v.get("funcs") { + funcs = serde_json::from_value(fv.clone()).unwrap_or_default(); + } + eprintln!("debug funcs: {:?}", funcs); + eprintln!("debug expr_code: {:?}", expr_code); + + let vm_val = run_bytecode_eval( + &expr_code, + &x_vals, + &p_vals, + &rateiv_vals, + 0.0, + &mut locals_slice, + &funcs, + &builtins, + ); + + assert_eq!(ast_val.as_number(), vm_val); + } + + #[test] + fn test_bytecode_parity_method_call() { + use crate::exa_wasm::interpreter::eval::eval_expr; + use crate::exa_wasm::interpreter::vm::run_bytecode_eval; + use crate::exa_wasm::interpreter::Opcode; + use std::env; + use std::fs; + + let tmp = env::temp_dir().join("exa_test_parity_method.json"); + let diffeq = "|x, p, _t, dx, rateiv, _cov| { dx[0] = x[0].sin(); }".to_string(); + let _path = crate::exa_wasm::build::emit_ir::( + diffeq, + None, + None, + None, + None, + Some(tmp.clone()), + 
vec![], + ) + .expect("emit_ir failed"); + let s = fs::read_to_string(&tmp).expect("read emitted ir"); + let v: serde_json::Value = serde_json::from_str(&s).expect("parse json"); + fs::remove_file(&tmp).ok(); + + let rhs_expr = if let Some(diffeq_ast) = v.get("diffeq_ast") { + let stmts: Vec = + serde_json::from_value(diffeq_ast.clone()).expect("deserialize stmts"); + match &stmts[0] { + crate::exa_wasm::interpreter::Stmt::Assign(_, rhs) => rhs.clone(), + _ => panic!("expected assign stmt"), + } + } else { + let diffeq_text = v + .get("diffeq") + .and_then(|d| d.as_str()) + .expect("diffeq text"); + let body = crate::exa_wasm::interpreter::extract_closure_body(diffeq_text) + .expect("closure body"); + let toks = crate::exa_wasm::interpreter::tokenize(&body); + let mut p = crate::exa_wasm::interpreter::Parser::new(toks); + if let Some(stmts) = p.parse_statements() { + match &stmts[0] { + crate::exa_wasm::interpreter::Stmt::Assign(_, rhs) => rhs.clone(), + _ => panic!("expected assign stmt"), + } + } else { + // fallback: extract RHS between '=' and ';' and parse as single expression + if let Some(eq_pos) = body.find('=') { + if let Some(sc_pos) = body[eq_pos..].find(';') { + let rhs_text = body[eq_pos + 1..eq_pos + sc_pos].trim(); + let toks2 = crate::exa_wasm::interpreter::tokenize(rhs_text); + let mut p2 = crate::exa_wasm::interpreter::Parser::new(toks2); + p2.parse_expr().expect("parse expr rhs") + } else { + panic!("parse stmts"); + } + } else { + panic!("parse stmts"); + } + } + }; + + use diffsol::NalgebraContext; + let mut x = crate::simulator::V::zeros(1, NalgebraContext); + x[0] = 0.5; + let p = crate::simulator::V::zeros(0, NalgebraContext); + let rateiv = crate::simulator::V::zeros(0, NalgebraContext); + + let ast_val = eval_expr(&rhs_expr, &x, &p, &rateiv, None, None, Some(0.0), None); + + let bc = v.get("diffeq_bytecode").expect("diffeq_bytecode"); + let map: std::collections::HashMap> = + serde_json::from_value(bc.clone()).expect("deserialize 
bytecode_map"); + let code = map.get(&0usize).expect("code for idx 0"); + let mut expr_code = code.clone(); + if let Some(last) = expr_code.last() { + match last { + Opcode::StoreDx(_) => { + expr_code.pop(); + } + _ => {} + } + } + + let mut funcs: Vec = Vec::new(); + if let Some(fv) = v.get("funcs") { + funcs = serde_json::from_value(fv.clone()).unwrap_or_default(); + } + + let builtins = |name: &str, args: &[f64]| -> f64 { + use crate::exa_wasm::interpreter::eval::{eval_call, Value}; + let vals: Vec = args.iter().map(|v| Value::Number(*v)).collect(); + eval_call(name, &vals).as_number() + }; + + let mut locals: Vec = Vec::new(); + let mut locals_slice = locals.as_mut_slice(); + let x_vals: Vec = vec![x[0]]; + let p_vals: Vec = vec![]; + let rateiv_vals: Vec = vec![]; + let vm_val = run_bytecode_eval( + &expr_code, + &x_vals, + &p_vals, + &rateiv_vals, + 0.0, + &mut locals_slice, + &funcs, + &builtins, + ); + + assert_eq!(ast_val.as_number(), vm_val); + } + + #[test] + fn test_bytecode_parity_nested_dynamic() { + use crate::exa_wasm::interpreter::eval::eval_expr; + use crate::exa_wasm::interpreter::vm::run_bytecode_eval; + use crate::exa_wasm::interpreter::Opcode; + use std::env; + use std::fs; + + let tmp = env::temp_dir().join("exa_test_parity_nested.json"); + let diffeq = "|x, p, _t, dx, rateiv, _cov| { dx[0] = x[x[ke]]; }".to_string(); + let params = vec!["ke".to_string()]; + let _path = crate::exa_wasm::build::emit_ir::( + diffeq, + None, + None, + None, + None, + Some(tmp.clone()), + params.clone(), + ) + .expect("emit_ir failed"); + let s = fs::read_to_string(&tmp).expect("read emitted ir"); + let v: serde_json::Value = serde_json::from_str(&s).expect("parse json"); + fs::remove_file(&tmp).ok(); + + let diffeq_ast = v.get("diffeq_ast").expect("diffeq_ast"); + let stmts: Vec = + serde_json::from_value(diffeq_ast.clone()).expect("deserialize stmts"); + let rhs_expr = match &stmts[0] { + crate::exa_wasm::interpreter::Stmt::Assign(_, rhs) => rhs.clone(), + _ 
=> panic!("expected assign stmt"), + }; + + use diffsol::NalgebraContext; + let mut x = crate::simulator::V::zeros(3, NalgebraContext); + x[0] = 11.0; + x[1] = 22.0; + x[2] = 33.0; + let mut p = crate::simulator::V::zeros(1, NalgebraContext); + p[0] = 1.0; // ke -> picks x[1] + let rateiv = crate::simulator::V::zeros(0, NalgebraContext); + + let ast_val = eval_expr(&rhs_expr, &x, &p, &rateiv, None, None, Some(0.0), None); + + let bc = v.get("diffeq_bytecode").expect("diffeq_bytecode"); + let map: std::collections::HashMap> = + serde_json::from_value(bc.clone()).expect("deserialize bytecode_map"); + let code = map.get(&0usize).expect("code for idx 0"); + let mut expr_code = code.clone(); + if let Some(last) = expr_code.last() { + match last { + Opcode::StoreDx(_) => { + expr_code.pop(); + } + _ => {} + } + } + + let builtins = |name: &str, args: &[f64]| -> f64 { + use crate::exa_wasm::interpreter::eval::{eval_call, Value}; + let vals: Vec = args.iter().map(|v| Value::Number(*v)).collect(); + eval_call(name, &vals).as_number() + }; + + let mut locals: Vec = Vec::new(); + let mut locals_slice = locals.as_mut_slice(); + let x_vals: Vec = vec![x[0], x[1], x[2]]; + let p_vals: Vec = vec![p[0]]; + let rateiv_vals: Vec = vec![]; + let vm_val = run_bytecode_eval( + &expr_code, + &x_vals, + &p_vals, + &rateiv_vals, + 0.0, + &mut locals_slice, + &Vec::new(), + &builtins, + ); + + assert_eq!(ast_val.as_number(), vm_val); + } + + #[test] + fn test_bytecode_parity_bool_short_circuit() { + use crate::exa_wasm::interpreter::eval::eval_expr; + use crate::exa_wasm::interpreter::vm::run_bytecode_eval; + use crate::exa_wasm::interpreter::Opcode; + use std::env; + use std::fs; + + let tmp = env::temp_dir().join("exa_test_parity_bool.json"); + let diffeq = + "|x, p, _t, dx, rateiv, _cov| { dx[0] = (x[0] > 0) && (x[0] < 10) ? 
1.0 : 0.0; }" + .to_string(); + let _path = crate::exa_wasm::build::emit_ir::( + diffeq, + None, + None, + None, + None, + Some(tmp.clone()), + vec![], + ) + .expect("emit_ir failed"); + let s = fs::read_to_string(&tmp).expect("read emitted ir"); + let v: serde_json::Value = serde_json::from_str(&s).expect("parse json"); + fs::remove_file(&tmp).ok(); + + // extract RHS expr + let rhs_expr = if let Some(diffeq_ast) = v.get("diffeq_ast") { + let stmts: Vec = + serde_json::from_value(diffeq_ast.clone()).expect("deserialize stmts"); + match &stmts[0] { + crate::exa_wasm::interpreter::Stmt::Assign(_, rhs) => rhs.clone(), + _ => panic!("expected assign stmt"), + } + } else { + let diffeq_text = v + .get("diffeq") + .and_then(|d| d.as_str()) + .expect("diffeq text"); + let body = crate::exa_wasm::interpreter::extract_closure_body(diffeq_text) + .expect("closure body"); + let toks = crate::exa_wasm::interpreter::tokenize(&body); + let mut p = crate::exa_wasm::interpreter::Parser::new(toks); + let stmts = p.parse_statements().expect("parse stmts"); + match &stmts[0] { + crate::exa_wasm::interpreter::Stmt::Assign(_, rhs) => rhs.clone(), + _ => panic!("expected assign stmt"), + } + }; + + use diffsol::NalgebraContext; + let mut x = crate::simulator::V::zeros(1, NalgebraContext); + x[0] = 5.0; + let p = crate::simulator::V::zeros(0, NalgebraContext); + let rateiv = crate::simulator::V::zeros(0, NalgebraContext); + + let ast_val = eval_expr(&rhs_expr, &x, &p, &rateiv, None, None, Some(0.0), None); + + let bc = match v.get("diffeq_bytecode") { + Some(b) => b, + None => { + eprintln!( + "emit_ir did not produce diffeq_bytecode for bool short-circuit test; skipping VM parity check" + ); + return; + } + }; + let map: std::collections::HashMap> = + serde_json::from_value(bc.clone()).expect("deserialize bytecode_map"); + let code = map.get(&0usize).expect("code for idx 0"); + let mut expr_code = code.clone(); + if let Some(last) = expr_code.last() { + match last { + 
Opcode::StoreDx(_) => { + expr_code.pop(); + } + _ => {} + } + } + + let builtins = |name: &str, args: &[f64]| -> f64 { + use crate::exa_wasm::interpreter::eval::{eval_call, Value}; + let vals: Vec = args.iter().map(|v| Value::Number(*v)).collect(); + eval_call(name, &vals).as_number() + }; + + let mut locals: Vec = Vec::new(); + let mut locals_slice = locals.as_mut_slice(); + let x_vals: Vec = vec![x[0]]; + let p_vals: Vec = vec![]; + let rateiv_vals: Vec = vec![]; + + let mut funcs: Vec = Vec::new(); + if let Some(fv) = v.get("funcs") { + funcs = serde_json::from_value(fv.clone()).unwrap_or_default(); + } + eprintln!("debug funcs: {:?}", funcs); + eprintln!("debug expr_code: {:?}", expr_code); + + let vm_val = run_bytecode_eval( + &expr_code, + &x_vals, + &p_vals, + &rateiv_vals, + 0.0, + &mut locals_slice, + &funcs, + &builtins, + ); + + assert_eq!(ast_val.as_number(), vm_val); + } + + #[test] + fn test_bytecode_parity_chained_method_calls() { + use crate::exa_wasm::interpreter::eval::eval_expr; + use crate::exa_wasm::interpreter::vm::run_bytecode_eval; + use crate::exa_wasm::interpreter::Opcode; + use std::env; + use std::fs; + + let tmp = env::temp_dir().join("exa_test_parity_chained.json"); + let diffeq = "|x, p, _t, dx, rateiv, _cov| { dx[0] = x[0].sin().abs(); }".to_string(); + let _path = crate::exa_wasm::build::emit_ir::( + diffeq, + None, + None, + None, + None, + Some(tmp.clone()), + vec![], + ) + .expect("emit_ir failed"); + let s = fs::read_to_string(&tmp).expect("read emitted ir"); + let v: serde_json::Value = serde_json::from_str(&s).expect("parse json"); + fs::remove_file(&tmp).ok(); + + let rhs_expr = if let Some(diffeq_ast) = v.get("diffeq_ast") { + let stmts: Vec = + serde_json::from_value(diffeq_ast.clone()).expect("deserialize stmts"); + match &stmts[0] { + crate::exa_wasm::interpreter::Stmt::Assign(_, rhs) => rhs.clone(), + _ => panic!("expected assign stmt"), + } + } else { + let diffeq_text = v + .get("diffeq") + .and_then(|d| d.as_str()) + 
.expect("diffeq text"); + let body = crate::exa_wasm::interpreter::extract_closure_body(diffeq_text) + .expect("closure body"); + let toks = crate::exa_wasm::interpreter::tokenize(&body); + let mut p = crate::exa_wasm::interpreter::Parser::new(toks); + if let Some(stmts) = p.parse_statements() { + match &stmts[0] { + crate::exa_wasm::interpreter::Stmt::Assign(_, rhs) => rhs.clone(), + _ => panic!("expected assign stmt"), + } + } else { + if let Some(eq_pos) = body.find('=') { + if let Some(sc_pos) = body[eq_pos..].find(';') { + let rhs_text = body[eq_pos + 1..eq_pos + sc_pos].trim(); + let toks2 = crate::exa_wasm::interpreter::tokenize(rhs_text); + let mut p2 = crate::exa_wasm::interpreter::Parser::new(toks2); + p2.parse_expr().expect("parse expr rhs") + } else { + panic!("parse stmts"); + } + } else { + panic!("parse stmts"); + } + } + }; + + use diffsol::NalgebraContext; + let mut x = crate::simulator::V::zeros(1, NalgebraContext); + x[0] = -0.5; + let p = crate::simulator::V::zeros(0, NalgebraContext); + let rateiv = crate::simulator::V::zeros(0, NalgebraContext); + + let ast_val = eval_expr(&rhs_expr, &x, &p, &rateiv, None, None, Some(0.0), None); + + let bc = match v.get("diffeq_bytecode") { + Some(b) => b, + None => { + eprintln!("emit_ir did not produce diffeq_bytecode for chained method test; skipping VM parity check"); + return; + } + }; + let map: std::collections::HashMap> = + serde_json::from_value(bc.clone()).expect("deserialize bytecode_map"); + let code = map.get(&0usize).expect("code for idx 0"); + let mut expr_code = code.clone(); + if let Some(last) = expr_code.last() { + match last { + Opcode::StoreDx(_) => { + expr_code.pop(); + } + _ => {} + } + } + + let mut funcs: Vec = Vec::new(); + if let Some(fv) = v.get("funcs") { + funcs = serde_json::from_value(fv.clone()).unwrap_or_default(); + } + + let builtins = |name: &str, args: &[f64]| -> f64 { + use crate::exa_wasm::interpreter::eval::{eval_call, Value}; + let vals: Vec = args.iter().map(|v| 
Value::Number(*v)).collect(); + eval_call(name, &vals).as_number() + }; + + let mut locals: Vec = Vec::new(); + let mut locals_slice = locals.as_mut_slice(); + let x_vals: Vec = vec![x[0]]; + let p_vals: Vec = vec![]; + let rateiv_vals: Vec = vec![]; + let vm_val = run_bytecode_eval( + &expr_code, + &x_vals, + &p_vals, + &rateiv_vals, + 0.0, + &mut locals_slice, + &funcs, + &builtins, + ); + + assert_eq!(ast_val.as_number(), vm_val); + } + + #[test] + fn test_bytecode_parity_method_with_arg() { + use crate::exa_wasm::interpreter::eval::eval_expr; + use crate::exa_wasm::interpreter::vm::run_bytecode_eval; + use crate::exa_wasm::interpreter::Opcode; + use std::env; + use std::fs; + + let tmp = env::temp_dir().join("exa_test_parity_method_arg.json"); + // use pow as method-style call; receiver becomes first arg + let diffeq = "|x, p, _t, dx, rateiv, _cov| { dx[0] = x[0].pow(2.0); }".to_string(); + let _path = crate::exa_wasm::build::emit_ir::( + diffeq, + None, + None, + None, + None, + Some(tmp.clone()), + vec![], + ) + .expect("emit_ir failed"); + let s = fs::read_to_string(&tmp).expect("read emitted ir"); + let v: serde_json::Value = serde_json::from_str(&s).expect("parse json"); + fs::remove_file(&tmp).ok(); + + let diffeq_ast = v.get("diffeq_ast").expect("diffeq_ast"); + let stmts: Vec = + serde_json::from_value(diffeq_ast.clone()).expect("deserialize stmts"); + let rhs_expr = match &stmts[0] { + crate::exa_wasm::interpreter::Stmt::Assign(_, rhs) => rhs.clone(), + _ => panic!("expected assign stmt"), + }; + + use diffsol::NalgebraContext; + let mut x = crate::simulator::V::zeros(1, NalgebraContext); + x[0] = 3.0; + let p = crate::simulator::V::zeros(0, NalgebraContext); + let rateiv = crate::simulator::V::zeros(0, NalgebraContext); + + let ast_val = eval_expr(&rhs_expr, &x, &p, &rateiv, None, None, Some(0.0), None); + + let bc = match v.get("diffeq_bytecode") { + Some(b) => b, + None => { + eprintln!("emit_ir did not produce diffeq_bytecode for method-with-arg 
test; skipping VM parity check"); + return; + } + }; + let map: std::collections::HashMap> = + serde_json::from_value(bc.clone()).expect("deserialize bytecode_map"); + let code = map.get(&0usize).expect("code for idx 0"); + let mut expr_code = code.clone(); + if let Some(last) = expr_code.last() { + match last { + Opcode::StoreDx(_) => { + expr_code.pop(); + } + _ => {} + } + } + + let builtins = |name: &str, args: &[f64]| -> f64 { + use crate::exa_wasm::interpreter::eval::{eval_call, Value}; + let vals: Vec = args.iter().map(|v| Value::Number(*v)).collect(); + eval_call(name, &vals).as_number() + }; + + let mut locals: Vec = Vec::new(); + let mut locals_slice = locals.as_mut_slice(); + let x_vals: Vec = vec![x[0]]; + let p_vals: Vec = vec![]; + let rateiv_vals: Vec = vec![]; + + // use funcs table emitted in IR so builtins can be looked up by name + let mut funcs: Vec = Vec::new(); + if let Some(fv) = v.get("funcs") { + funcs = serde_json::from_value(fv.clone()).unwrap_or_default(); + } + + let vm_val = run_bytecode_eval( + &expr_code, + &x_vals, + &p_vals, + &rateiv_vals, + 0.0, + &mut locals_slice, + &funcs, + &builtins, + ); + + if (ast_val.as_number() - vm_val).abs() > 1e-12 { + panic!( + "parity mismatch: ast={} vm={} funcs={:?} code={:?}", + ast_val.as_number(), + vm_val, + funcs, + expr_code + ); + } + } +} diff --git a/src/exa_wasm/interpreter/parser.rs b/src/exa_wasm/interpreter/parser.rs new file mode 100644 index 00000000..39528379 --- /dev/null +++ b/src/exa_wasm/interpreter/parser.rs @@ -0,0 +1,733 @@ +use crate::exa_wasm::interpreter::ast::{Expr, ParseError, Token}; + +// Tokenizer + recursive-descent parser +pub fn tokenize(s: &str) -> Vec { + let mut toks = Vec::new(); + let mut chars = s.chars().peekable(); + while let Some(&c) = chars.peek() { + if c.is_whitespace() { + chars.next(); + continue; + } + // Numbers: start with digit, or a dot followed by a digit (e.g. .5) + if c.is_ascii_digit() + || (c == '.' && { + // lookahead: only treat '.' 
as start of number when followed by a digit + let mut tmp = chars.clone(); + // consume current '.' + tmp.next(); + if let Some(&d) = tmp.peek() { + d.is_ascii_digit() + } else { + false + } + }) + { + let mut num = String::new(); + while let Some(&d) = chars.peek() { + if d.is_ascii_digit() + || d == '.' + || d == 'e' + || d == 'E' + || ((d == '+' || d == '-') && (num.ends_with('e') || num.ends_with('E'))) + { + num.push(d); + chars.next(); + } else { + break; + } + } + if let Ok(v) = num.parse::() { + toks.push(Token::Num(v)); + } + continue; + } + if c.is_ascii_alphabetic() || c == '_' { + let mut id = String::new(); + while let Some(&d) = chars.peek() { + if d.is_ascii_alphanumeric() || d == '_' { + id.push(d); + chars.next(); + } else { + break; + } + } + // treat true/false as boolean tokens + if id.eq_ignore_ascii_case("true") { + toks.push(Token::Bool(true)); + } else if id.eq_ignore_ascii_case("false") { + toks.push(Token::Bool(false)); + } else { + toks.push(Token::Ident(id)); + } + continue; + } + match c { + '[' => { + toks.push(Token::LBracket); + chars.next(); + } + '{' => { + toks.push(Token::LBrace); + chars.next(); + } + '}' => { + toks.push(Token::RBrace); + chars.next(); + } + '?' => { + toks.push(Token::Question); + chars.next(); + } + ':' => { + toks.push(Token::Colon); + chars.next(); + } + ']' => { + toks.push(Token::RBracket); + chars.next(); + } + '(' => { + toks.push(Token::LParen); + chars.next(); + } + ')' => { + toks.push(Token::RParen); + chars.next(); + } + ',' => { + toks.push(Token::Comma); + chars.next(); + } + ';' => { + toks.push(Token::Semicolon); + chars.next(); + } + '+' | '-' | '*' | '/' => { + toks.push(Token::Op(c)); + chars.next(); + } + '^' => { + toks.push(Token::Op('^')); + chars.next(); + } + '.' 
=> { + toks.push(Token::Dot); + chars.next(); + } + '<' => { + chars.next(); + if let Some(&'=') = chars.peek() { + chars.next(); + toks.push(Token::Le); + } else { + toks.push(Token::Lt); + } + } + '>' => { + chars.next(); + if let Some(&'=') = chars.peek() { + chars.next(); + toks.push(Token::Ge); + } else { + toks.push(Token::Gt); + } + } + '=' => { + chars.next(); + if let Some(&'=') = chars.peek() { + chars.next(); + toks.push(Token::EqEq); + } else { + toks.push(Token::Assign); + } + } + '!' => { + chars.next(); + if let Some(&'=') = chars.peek() { + chars.next(); + toks.push(Token::Ne); + } else { + toks.push(Token::Bang); + } + } + '&' => { + chars.next(); + if let Some(&'&') = chars.peek() { + chars.next(); + toks.push(Token::And); + } + } + '|' => { + chars.next(); + if let Some(&'|') = chars.peek() { + chars.next(); + toks.push(Token::Or); + } + } + _ => { + chars.next(); + } + } + } + toks +} + +pub struct Parser { + tokens: Vec, + pos: usize, + expected: Vec, +} +impl Parser { + pub fn new(tokens: Vec) -> Self { + Self { + tokens, + pos: 0, + expected: Vec::new(), + } + } + fn expected_push(&mut self, s: &str) { + if !self.expected.contains(&s.to_string()) { + self.expected.push(s.to_string()); + } + } + fn peek(&self) -> Option<&Token> { + self.tokens.get(self.pos) + } + fn next(&mut self) -> Option<&Token> { + let r = self.tokens.get(self.pos); + if r.is_some() { + self.pos += 1; + } + r + } + pub fn parse_expr(&mut self) -> Option { + self.parse_ternary() + } + fn parse_ternary(&mut self) -> Option { + let cond = self.parse_or()?; + if let Some(Token::Question) = self.peek().cloned() { + self.next(); + let then_branch = self.parse_expr()?; + if let Some(Token::Colon) = self.peek().cloned() { + self.next(); + let else_branch = self.parse_expr()?; + return Some(Expr::Ternary { + cond: Box::new(cond), + then_branch: Box::new(then_branch), + else_branch: Box::new(else_branch), + }); + } else { + self.expected_push(":"); + return None; + } + } + 
Some(cond) + } + pub fn parse_expr_result(&mut self) -> Result { + if let Some(expr) = self.parse_expr() { + Ok(expr) + } else { + Err(ParseError { + pos: self.pos, + found: self.peek().cloned(), + expected: self.expected.clone(), + }) + } + } + fn parse_or(&mut self) -> Option { + let mut node = self.parse_and()?; + while let Some(Token::Or) = self.peek().cloned() { + self.next(); + let rhs = self.parse_and()?; + node = Expr::BinaryOp { + lhs: Box::new(node), + op: "||".to_string(), + rhs: Box::new(rhs), + }; + } + Some(node) + } + fn parse_and(&mut self) -> Option { + let mut node = self.parse_eq()?; + while let Some(Token::And) = self.peek().cloned() { + self.next(); + let rhs = self.parse_eq()?; + node = Expr::BinaryOp { + lhs: Box::new(node), + op: "&&".to_string(), + rhs: Box::new(rhs), + }; + } + Some(node) + } + fn parse_eq(&mut self) -> Option { + let mut node = self.parse_cmp()?; + loop { + match self.peek() { + Some(Token::EqEq) => { + self.next(); + let rhs = self.parse_cmp()?; + node = Expr::BinaryOp { + lhs: Box::new(node), + op: "==".to_string(), + rhs: Box::new(rhs), + }; + } + Some(Token::Ne) => { + self.next(); + let rhs = self.parse_cmp()?; + node = Expr::BinaryOp { + lhs: Box::new(node), + op: "!=".to_string(), + rhs: Box::new(rhs), + }; + } + _ => break, + } + } + Some(node) + } + fn parse_cmp(&mut self) -> Option { + let mut node = self.parse_add_sub()?; + loop { + match self.peek() { + Some(Token::Lt) => { + self.next(); + let rhs = self.parse_add_sub()?; + node = Expr::BinaryOp { + lhs: Box::new(node), + op: "<".to_string(), + rhs: Box::new(rhs), + }; + } + Some(Token::Gt) => { + self.next(); + let rhs = self.parse_add_sub()?; + node = Expr::BinaryOp { + lhs: Box::new(node), + op: ">".to_string(), + rhs: Box::new(rhs), + }; + } + Some(Token::Le) => { + self.next(); + let rhs = self.parse_add_sub()?; + node = Expr::BinaryOp { + lhs: Box::new(node), + op: "<=".to_string(), + rhs: Box::new(rhs), + }; + } + Some(Token::Ge) => { + self.next(); + 
let rhs = self.parse_add_sub()?; + node = Expr::BinaryOp { + lhs: Box::new(node), + op: ">=".to_string(), + rhs: Box::new(rhs), + }; + } + _ => break, + } + } + Some(node) + } + fn parse_add_sub(&mut self) -> Option { + let mut node = self.parse_mul_div()?; + while let Some(tok) = self.peek() { + match tok { + Token::Op('+') => { + self.next(); + let rhs = self.parse_mul_div()?; + node = Expr::BinaryOp { + lhs: Box::new(node), + op: "+".to_string(), + rhs: Box::new(rhs), + }; + } + Token::Op('-') => { + self.next(); + let rhs = self.parse_mul_div()?; + node = Expr::BinaryOp { + lhs: Box::new(node), + op: "-".to_string(), + rhs: Box::new(rhs), + }; + } + _ => break, + } + } + Some(node) + } + fn parse_mul_div(&mut self) -> Option { + let mut node = self.parse_power()?; + while let Some(tok) = self.peek() { + match tok { + Token::Op('*') => { + self.next(); + let rhs = self.parse_unary()?; + node = Expr::BinaryOp { + lhs: Box::new(node), + op: "*".to_string(), + rhs: Box::new(rhs), + }; + } + Token::Op('/') => { + self.next(); + let rhs = self.parse_unary()?; + node = Expr::BinaryOp { + lhs: Box::new(node), + op: "/".to_string(), + rhs: Box::new(rhs), + }; + } + _ => break, + } + } + Some(node) + } + fn parse_power(&mut self) -> Option { + let node = self.parse_unary()?; + if let Some(Token::Op('^')) = self.peek() { + self.next(); + let rhs = self.parse_power()?; + return Some(Expr::BinaryOp { + lhs: Box::new(node), + op: "^".to_string(), + rhs: Box::new(rhs), + }); + } + Some(node) + } + fn parse_unary(&mut self) -> Option { + if let Some(Token::Op('-')) = self.peek() { + self.next(); + let rhs = self.parse_unary()?; + return Some(Expr::UnaryOp { + op: '-'.to_string(), + rhs: Box::new(rhs), + }); + } + if let Some(Token::Bang) = self.peek() { + self.next(); + let rhs = self.parse_unary()?; + return Some(Expr::UnaryOp { + op: '!'.to_string(), + rhs: Box::new(rhs), + }); + } + self.parse_primary() + } + fn parse_primary(&mut self) -> Option { + let tok = 
self.next().cloned()?; + let mut node = match tok { + Token::Num(v) => Expr::Number(v), + Token::Bool(b) => Expr::Bool(b), + Token::Ident(id) => { + // function call? + if let Some(Token::LParen) = self.peek().cloned() { + self.next(); + let mut args: Vec = Vec::new(); + if let Some(Token::RParen) = self.peek().cloned() { + self.next(); + Expr::Call { + name: id.clone(), + args, + } + } else { + loop { + if let Some(expr) = self.parse_expr() { + args.push(expr); + } else { + self.expected_push("expression"); + return None; + } + match self.peek().cloned() { + Some(Token::Comma) => { + self.next(); + continue; + } + Some(Token::RParen) => { + self.next(); + break; + } + _ => { + self.expected_push(",|)"); + return None; + } + } + } + // after parsing args, produce the Call node + Expr::Call { + name: id.clone(), + args, + } + } + // indexed access? + } else if let Some(Token::LBracket) = self.peek().cloned() { + self.next(); + // parse index expression + let idx = self.parse_expr()?; + if let Some(Token::RBracket) = self.peek().cloned() { + self.next(); + Expr::Indexed(id.clone(), Box::new(idx)) + } else { + self.expected_push("]"); + return None; + } + } else { + Expr::Ident(id.clone()) + } + } + Token::LParen => { + let expr = self.parse_expr(); + if let Some(Token::RParen) = self.peek().cloned() { + self.next(); + if let Some(e) = expr { + e + } else { + self.expected_push("expression"); + return None; + } + } else { + self.expected_push(")"); + return None; + } + } + _ => { + self.expected_push("number|identifier|'('"); + return None; + } + }; + + // method call chaining: .name(args?) 
+ loop { + if let Some(Token::Dot) = self.peek().cloned() { + self.next(); + let name = if let Some(Token::Ident(n)) = self.next().cloned() { + n + } else { + self.expected_push("identifier"); + return None; + }; + let mut args: Vec = Vec::new(); + if let Some(Token::LParen) = self.peek().cloned() { + self.next(); + if let Some(Token::RParen) = self.peek().cloned() { + self.next(); + } else { + loop { + if let Some(expr) = self.parse_expr() { + args.push(expr); + } else { + self.expected_push("expression"); + return None; + } + match self.peek().cloned() { + Some(Token::Comma) => { + self.next(); + continue; + } + Some(Token::RParen) => { + self.next(); + break; + } + _ => { + self.expected_push(",|)"); + return None; + } + } + } + } + } + node = Expr::MethodCall { + receiver: Box::new(node), + name, + args, + }; + continue; + } + break; + } + + Some(node) + } +} + +// Statement parsing (small recursive-descent on top of the expression parser) +impl Parser { + pub fn parse_statements(&mut self) -> Option> { + let mut stmts = Vec::new(); + while let Some(tok) = self.peek() { + match tok { + Token::RBrace => break, + _ => { + if let Some(s) = self.parse_statement() { + stmts.push(s); + continue; + } else { + return None; + } + } + } + } + Some(stmts) + } + + fn parse_statement(&mut self) -> Option { + use crate::exa_wasm::interpreter::ast::{Lhs, Stmt}; + // handle `if` as identifier token + if let Some(Token::Ident(id)) = self.peek().cloned() { + if id == "if" { + // consume 'if' + self.next(); + // allow optional parens around condition + let cond = if let Some(Token::LParen) = self.peek().cloned() { + self.next(); + let e = self.parse_expr()?; + if let Some(Token::RParen) = self.peek().cloned() { + self.next(); + } else { + self.expected_push(")"); + return None; + } + e + } else { + self.parse_expr()? 
+ }; + // then branch must be a block + let then_branch = if let Some(Token::LBrace) = self.peek().cloned() { + self.next(); + let mut pstmts = Vec::new(); + while let Some(tok) = self.peek().cloned() { + if let Token::RBrace = tok { + self.next(); + break; + } + pstmts.push(self.parse_statement()?); + } + Stmt::Block(pstmts) + } else { + // single statement as then branch + self.parse_statement() + .map(Box::new) + .map(|b| *b) + .unwrap_or(Stmt::Block(vec![])) + }; + // optional else + let else_branch = if let Some(Token::Ident(eid)) = self.peek().cloned() { + if eid == "else" { + self.next(); + if let Some(Token::LBrace) = self.peek().cloned() { + self.next(); + let mut estmts = Vec::new(); + while let Some(tok) = self.peek().cloned() { + if let Token::RBrace = tok { + self.next(); + break; + } + estmts.push(self.parse_statement()?); + } + Some(Box::new(Stmt::Block(estmts))) + } else if let Some(Token::Ident(_)) = self.peek().cloned() { + Some(Box::new(self.parse_statement()?)) + } else { + None + } + } else { + None + } + } else { + None + }; + return Some(Stmt::If { + cond, + then_branch: Box::new(then_branch), + else_branch, + }); + } + } + + // Attempt assignment: lookahead without consuming + if let Some(Token::Ident(_)) = self.peek() { + // lookahead for simple `Ident =` or `Ident [ ... 
] =` + let mut is_assign = false; + // check immediate next token + if let Some(next_tok) = self.tokens.get(self.pos + 1) { + match next_tok { + Token::Assign => is_assign = true, + Token::LBracket => { + // find matching RBracket + let mut depth = 0isize; + let mut j = self.pos + 1; + while j < self.tokens.len() { + match self.tokens[j] { + Token::LBracket => depth += 1, + Token::RBracket => { + depth -= 1; + if depth == 0 { + // check token after RBracket + if let Some(tok_after) = self.tokens.get(j + 1) { + if let Token::Assign = tok_after { + is_assign = true; + } + } + break; + } + } + _ => {} + } + j += 1; + } + } + _ => {} + } + } + + if is_assign { + // parse lhs + let lhs = if let Some(Token::Ident(name)) = self.next().cloned() { + if let Some(Token::LBracket) = self.peek().cloned() { + self.next(); + let idx = self.parse_expr()?; + if let Some(Token::RBracket) = self.peek().cloned() { + self.next(); + Lhs::Indexed(name, Box::new(idx)) + } else { + self.expected_push("]"); + return None; + } + } else { + Lhs::Ident(name) + } + } else { + return None; + }; + // expect assign + if let Some(Token::Assign) = self.peek().cloned() { + self.next(); + let rhs = self.parse_expr()?; + // expect semicolon + if let Some(Token::Semicolon) = self.peek().cloned() { + self.next(); + } else { + self.expected_push(";"); + return None; + } + return Some(Stmt::Assign(lhs, rhs)); + } + } + } + + // Expression statement: expr ; + let expr = self.parse_expr()?; + if let Some(Token::Semicolon) = self.peek().cloned() { + self.next(); + } else { + self.expected_push(";"); + return None; + } + Some(Stmt::Expr(expr)) + } +} diff --git a/src/exa_wasm/interpreter/registry.rs b/src/exa_wasm/interpreter/registry.rs new file mode 100644 index 00000000..1c675297 --- /dev/null +++ b/src/exa_wasm/interpreter/registry.rs @@ -0,0 +1,105 @@ +use once_cell::sync::Lazy; +use std::collections::HashMap; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::Mutex; + +use 
crate::exa_wasm::interpreter::ast::{Expr, Stmt}; + +#[derive(Clone, Debug)] +pub struct RegistryEntry { + // statement-level representations for closures; each Vec contains + // the top-level statements parsed from the corresponding closure + pub diffeq_stmts: Vec, + pub out_stmts: Vec, + pub init_stmts: Vec, + pub lag: HashMap, + pub fa: HashMap, + // prelude assignments executed before dx evaluation: ordered (name, expr) + pub prelude: Vec<(String, Expr)>, + pub pmap: HashMap, + pub nstates: usize, + pub _nouteqs: usize, + // optional compiled bytecode blobs for closures (index -> opcode sequence) + pub bytecode_diffeq: + std::collections::HashMap>, + // optional compiled function-level bytecode for diffeq as a single code vector + pub bytecode_diffeq_func: Vec, + // support for out/init/lag/fa as maps of index -> opcode sequences + pub bytecode_out: std::collections::HashMap>, + pub bytecode_init: std::collections::HashMap>, + pub bytecode_lag: std::collections::HashMap>, + pub bytecode_fa: std::collections::HashMap>, + // local slot names in evaluation order + pub locals: Vec, + // builtin function table emitted by the compiler/emit_ir + pub funcs: Vec, +} + +static EXPR_REGISTRY: Lazy>> = + Lazy::new(|| Mutex::new(HashMap::new())); + +static NEXT_EXPR_ID: Lazy = Lazy::new(|| AtomicUsize::new(1)); + +thread_local! 
{ + static CURRENT_EXPR_ID: std::cell::Cell> = std::cell::Cell::new(None); + static LAST_RUNTIME_ERROR: std::cell::RefCell> = std::cell::RefCell::new(None); +} + +pub fn set_current_expr_id(id: Option) -> Option { + let prev = CURRENT_EXPR_ID.with(|c| { + let p = c.get(); + c.set(id); + p + }); + prev +} + +pub fn current_expr_id() -> Option { + CURRENT_EXPR_ID.with(|c| c.get()) +} + +pub fn set_runtime_error(msg: String) { + LAST_RUNTIME_ERROR.with(|c| { + *c.borrow_mut() = Some(msg); + }); +} + +pub fn take_runtime_error() -> Option { + LAST_RUNTIME_ERROR.with(|c| c.borrow_mut().take()) +} + +pub fn register_entry(entry: RegistryEntry) -> usize { + let id = NEXT_EXPR_ID.fetch_add(1, Ordering::SeqCst); + let mut guard = EXPR_REGISTRY.lock().unwrap(); + guard.insert(id, entry); + id +} + +pub fn unregister_model(id: usize) { + let mut guard = EXPR_REGISTRY.lock().unwrap(); + guard.remove(&id); +} + +pub fn get_entry(id: usize) -> Option { + let guard = EXPR_REGISTRY.lock().unwrap(); + guard.get(&id).cloned() +} + +pub fn ode_for_id(id: usize) -> Option { + if let Some(entry) = get_entry(id) { + let nstates = entry.nstates; + let nouteqs = entry._nouteqs; + let ode = crate::simulator::equation::ODE::with_registry_id( + crate::exa_wasm::interpreter::dispatch::diffeq_dispatch, + crate::exa_wasm::interpreter::dispatch::lag_dispatch, + crate::exa_wasm::interpreter::dispatch::fa_dispatch, + crate::exa_wasm::interpreter::dispatch::init_dispatch, + crate::exa_wasm::interpreter::dispatch::out_dispatch, + (nstates, nouteqs), + Some(id), + ); + Some(ode) + } else { + None + } +} diff --git a/src/exa_wasm/interpreter/typecheck.rs b/src/exa_wasm/interpreter/typecheck.rs new file mode 100644 index 00000000..c7687727 --- /dev/null +++ b/src/exa_wasm/interpreter/typecheck.rs @@ -0,0 +1,189 @@ +use crate::exa_wasm::interpreter::ast::{Expr, Lhs, Stmt}; + +#[derive(Debug, PartialEq)] +pub enum Type { + Number, + Bool, +} + +pub enum TypeError { + UnknownFunction(String), + Arity { + 
name: String, + expected: String, + got: usize, + }, + IndexNotNumeric, + AssignBooleanToIndexed(String), + Msg(String), +} + +impl std::fmt::Debug for TypeError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + TypeError::UnknownFunction(n) => write!(f, "UnknownFunction({})", n), + TypeError::Arity { + name, + expected, + got, + } => write!( + f, + "Arity {{ name: {}, expected: {}, got: {} }}", + name, expected, got + ), + TypeError::IndexNotNumeric => write!(f, "IndexNotNumeric"), + TypeError::AssignBooleanToIndexed(n) => write!(f, "AssignBooleanToIndexed({})", n), + TypeError::Msg(s) => write!(f, "Msg({})", s), + } + } +} + +impl From for TypeError { + fn from(s: String) -> Self { + TypeError::Msg(s) + } +} + +fn type_of_binary_op(_lhs: &Type, op: &str, _rhs: &Type) -> Result { + use Type::*; + match op { + "&&" | "||" => Ok(Bool), + "<" | ">" | "<=" | ">=" | "==" | "!=" => Ok(Bool), + "+" | "-" | "*" | "/" | "^" => Ok(Number), + _ => Ok(Number), + } +} + +// Minimal conservative type checker +pub fn check_expr(expr: &Expr) -> Result { + use Expr::*; + match expr { + Bool(_) => Ok(Type::Bool), + Number(_) => Ok(Type::Number), + Ident(_) => Ok(Type::Number), + Param(_) => Ok(Type::Number), + Indexed(_, idx) => match check_expr(idx)? { + Type::Number => Ok(Type::Number), + _ => Err(TypeError::IndexNotNumeric), + }, + UnaryOp { op, rhs } => { + let t = check_expr(rhs)?; + match op.as_str() { + "!" 
=> Ok(Type::Bool), + "-" => Ok(Type::Number), + _ => Ok(t), + } + } + BinaryOp { lhs, op, rhs } => { + let lt = check_expr(lhs)?; + let rt = check_expr(rhs)?; + type_of_binary_op(<, op, &rt) + } + Call { name, args } => { + // ensure args type-check + for a in args.iter() { + let _ = check_expr(a)?; + } + // check known builtin and arity via shared builtins module + if !crate::exa_wasm::interpreter::builtins::is_known_function(name) { + return Err(TypeError::UnknownFunction(name.clone())); + } + if let Some(range) = crate::exa_wasm::interpreter::builtins::arg_count_range(name) { + if !range.contains(&args.len()) { + let lo = *range.start(); + let hi = *range.end(); + let expect = if lo == hi { + lo.to_string() + } else { + format!("{}..={}", lo, hi) + }; + return Err(TypeError::Arity { + name: name.clone(), + expected: expect, + got: args.len(), + }); + } + } + Ok(Type::Number) + } + MethodCall { + receiver, + name: _, + args, + } => { + let _ = check_expr(receiver)?; + for a in args.iter() { + let _ = check_expr(a)?; + } + Ok(Type::Number) + } + Ternary { + cond, + then_branch, + else_branch, + } => match check_expr(cond)? { + Type::Bool | Type::Number => { + let t1 = check_expr(then_branch)?; + let t2 = check_expr(else_branch)?; + if t1 == t2 { + Ok(t1) + } else { + Ok(Type::Number) + } + } + }, + } +} + +pub fn check_stmt(stmt: &Stmt) -> Result<(), TypeError> { + use Stmt::*; + match stmt { + Expr(e) => { + let _ = check_expr(e)?; + Ok(()) + } + Assign(lhs, rhs) => match lhs { + Lhs::Ident(_) => { + let _ = check_expr(rhs)?; + Ok(()) + } + Lhs::Indexed(name, idx_expr) => { + match check_expr(idx_expr)? { + Type::Number => {} + _ => return Err(TypeError::IndexNotNumeric), + } + match check_expr(rhs)? { + Type::Number => Ok(()), + Type::Bool => Err(TypeError::AssignBooleanToIndexed(name.clone())), + } + } + }, + Block(v) => { + for s in v.iter() { + check_stmt(s)?; + } + Ok(()) + } + If { + cond, + then_branch, + else_branch, + } => { + match check_expr(cond)? 
{ + Type::Bool | Type::Number => {} + } + check_stmt(then_branch)?; + if let Some(eb) = else_branch { + check_stmt(eb)?; + } + Ok(()) + } + } +} + +pub fn check_statements(stmts: &[Stmt]) -> Result<(), TypeError> { + for s in stmts.iter() { + check_stmt(s)?; + } + Ok(()) +} diff --git a/src/exa_wasm/interpreter/vm.rs b/src/exa_wasm/interpreter/vm.rs new file mode 100644 index 00000000..c50923ee --- /dev/null +++ b/src/exa_wasm/interpreter/vm.rs @@ -0,0 +1,483 @@ +use serde::{Deserialize, Serialize}; + +/// Production-grade opcode set for the exa_wasm VM. +/// Keep names compatible with earlier POC where reasonable. +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +pub enum Opcode { + // stack and constants + PushConst(f64), // push constant + LoadParam(usize), // push p[idx] + LoadX(usize), // push x[idx] + LoadRateiv(usize), // push rateiv[idx] + LoadLocal(usize), // push local slot + LoadT, // push t + // dynamic indexed loads/stores (index evaluated at runtime) + LoadParamDyn, // pop index -> push p[idx] + LoadXDyn, // pop index -> push x[idx] + LoadRateivDyn, // pop index -> push rateiv[idx] + + // arithmetic + Add, + Sub, + Mul, + Div, + Pow, + // pop top of stack (discard) + Pop, + + // comparisons / logical (push 0.0/1.0) + Lt, + Gt, + Le, + Ge, + Eq, + Ne, + + // control flow + Jump(usize), // absolute pc + JumpIfFalse(usize), // pop cond, if false jump + + // builtin call: index into func table, arg count + CallBuiltin(usize, usize), + + // stores + StoreDx(usize), // pop value and assign to dx[index] + StoreX(usize), // pop value into x[index] + StoreY(usize), // pop value into y[index] + StoreLocal(usize), // pop value into local slot + // dynamic stores: pop value then pop index (index is f64 -> usize) + StoreDxDyn, // pop value, pop index -> assign to dx[idx] + StoreXDyn, // pop value, pop index -> assign to x[idx] + StoreYDyn, // pop value, pop index -> assign to y[idx] +} + +/// Execute a sequence of opcodes with full VM context. 
+/// `assign_indexed` is called for dx/x/y assignments (name, idx, val). +pub fn run_bytecode_full( + code: &[Opcode], + x: &[f64], + p: &[f64], + rateiv: &[f64], + t: f64, + locals: &mut [f64], + funcs: &Vec, + builtins_dispatch: &dyn Fn(&str, &[f64]) -> f64, + mut assign_indexed: F, +) where + F: FnMut(&str, usize, f64), +{ + let mut stack: Vec = Vec::new(); + let mut pc: usize = 0; + let code_len = code.len(); + while pc < code_len { + match &code[pc] { + Opcode::PushConst(v) => { + stack.push(*v); + pc += 1; + } + Opcode::LoadParam(i) => { + let v = if *i < p.len() { p[*i] } else { 0.0 }; + stack.push(v); + pc += 1; + } + Opcode::LoadX(i) => { + let v = if *i < x.len() { x[*i] } else { 0.0 }; + stack.push(v); + pc += 1; + } + Opcode::LoadRateiv(i) => { + let v = if *i < rateiv.len() { rateiv[*i] } else { 0.0 }; + stack.push(v); + pc += 1; + } + Opcode::LoadParamDyn => { + // index is expected on stack as f64 + let idxf = stack.pop().unwrap_or(0.0); + let idx = idxf as usize; + let v = if idx < p.len() { p[idx] } else { 0.0 }; + stack.push(v); + pc += 1; + } + Opcode::LoadXDyn => { + let idxf = stack.pop().unwrap_or(0.0); + let idx = idxf as usize; + let v = if idx < x.len() { x[idx] } else { 0.0 }; + stack.push(v); + pc += 1; + } + Opcode::LoadRateivDyn => { + let idxf = stack.pop().unwrap_or(0.0); + let idx = idxf as usize; + let v = if idx < rateiv.len() { rateiv[idx] } else { 0.0 }; + stack.push(v); + pc += 1; + } + Opcode::LoadLocal(i) => { + let v = if *i < locals.len() { locals[*i] } else { 0.0 }; + stack.push(v); + pc += 1; + } + Opcode::LoadT => { + stack.push(t); + pc += 1; + } + Opcode::Add => { + let b = stack.pop().unwrap_or(0.0); + let a = stack.pop().unwrap_or(0.0); + stack.push(a + b); + pc += 1; + } + Opcode::Sub => { + let b = stack.pop().unwrap_or(0.0); + let a = stack.pop().unwrap_or(0.0); + stack.push(a - b); + pc += 1; + } + Opcode::Mul => { + let b = stack.pop().unwrap_or(0.0); + let a = stack.pop().unwrap_or(0.0); + stack.push(a * b); + 
pc += 1; + } + Opcode::Div => { + let b = stack.pop().unwrap_or(0.0); + let a = stack.pop().unwrap_or(0.0); + stack.push(a / b); + pc += 1; + } + Opcode::Pow => { + let b = stack.pop().unwrap_or(0.0); + let a = stack.pop().unwrap_or(0.0); + stack.push(a.powf(b)); + pc += 1; + } + Opcode::Lt => { + let b = stack.pop().unwrap_or(0.0); + let a = stack.pop().unwrap_or(0.0); + stack.push(if a < b { 1.0 } else { 0.0 }); + pc += 1; + } + Opcode::Gt => { + let b = stack.pop().unwrap_or(0.0); + let a = stack.pop().unwrap_or(0.0); + stack.push(if a > b { 1.0 } else { 0.0 }); + pc += 1; + } + Opcode::Le => { + let b = stack.pop().unwrap_or(0.0); + let a = stack.pop().unwrap_or(0.0); + stack.push(if a <= b { 1.0 } else { 0.0 }); + pc += 1; + } + Opcode::Ge => { + let b = stack.pop().unwrap_or(0.0); + let a = stack.pop().unwrap_or(0.0); + stack.push(if a >= b { 1.0 } else { 0.0 }); + pc += 1; + } + Opcode::Eq => { + let b = stack.pop().unwrap_or(0.0); + let a = stack.pop().unwrap_or(0.0); + stack.push(if a == b { 1.0 } else { 0.0 }); + pc += 1; + } + Opcode::Ne => { + let b = stack.pop().unwrap_or(0.0); + let a = stack.pop().unwrap_or(0.0); + stack.push(if a != b { 1.0 } else { 0.0 }); + pc += 1; + } + Opcode::Jump(addr) => { + pc = *addr; + } + Opcode::JumpIfFalse(addr) => { + let c = stack.pop().unwrap_or(0.0); + if c == 0.0 { + pc = *addr; + } else { + pc += 1; + } + } + Opcode::CallBuiltin(func_idx, argc) => { + // pop args in reverse order + let mut args: Vec = Vec::with_capacity(*argc); + for _ in 0..*argc { + args.push(stack.pop().unwrap_or(0.0)); + } + args.reverse(); + let func_name = funcs.get(*func_idx).map(|s| s.as_str()).unwrap_or(""); + let res = builtins_dispatch(func_name, &args); + stack.push(res); + pc += 1; + } + Opcode::StoreDx(i) => { + let v = stack.pop().unwrap_or(0.0); + assign_indexed("dx", *i, v); + pc += 1; + } + Opcode::StoreX(i) => { + let v = stack.pop().unwrap_or(0.0); + assign_indexed("x", *i, v); + pc += 1; + } + Opcode::StoreY(i) => { + let v = 
stack.pop().unwrap_or(0.0); + assign_indexed("y", *i, v); + pc += 1; + } + Opcode::StoreDxDyn => { + // pop value then index + let v = stack.pop().unwrap_or(0.0); + let idxf = stack.pop().unwrap_or(0.0); + let idx = idxf as usize; + assign_indexed("dx", idx, v); + pc += 1; + } + Opcode::StoreXDyn => { + let v = stack.pop().unwrap_or(0.0); + let idxf = stack.pop().unwrap_or(0.0); + let idx = idxf as usize; + assign_indexed("x", idx, v); + pc += 1; + } + Opcode::StoreYDyn => { + let v = stack.pop().unwrap_or(0.0); + let idxf = stack.pop().unwrap_or(0.0); + let idx = idxf as usize; + assign_indexed("y", idx, v); + pc += 1; + } + Opcode::StoreLocal(i) => { + let v = stack.pop().unwrap_or(0.0); + if *i < locals.len() { + locals[*i] = v; + } + pc += 1; + } + Opcode::Pop => { + let _ = stack.pop(); + pc += 1; + } + } + } +} + +/// Backwards-compatible lightweight runner used by some unit tests and the +/// legacy emit POC. Runs a minimal subset (params + arithmetic + StoreDx). +pub fn run_bytecode(code: &[Opcode], p: &[f64], mut assign_dx: F) +where + F: FnMut(usize, f64), +{ + // emulate a minimal environment + let x: Vec = Vec::new(); + let rateiv: Vec = Vec::new(); + let mut locals: Vec = Vec::new(); + let funcs: Vec = Vec::new(); + let builtins = |_: &str, _: &[f64]| -> f64 { 0.0 }; + run_bytecode_full( + code, + &x, + p, + &rateiv, + 0.0, + &mut locals, + &funcs, + &builtins, + |n, i, v| { + if n == "dx" { + assign_dx(i, v); + } + }, + ); +} + +/// Run a sequence of opcodes and return the top-of-stack value at the end. +/// This is useful for bytecode fragments that compute an expression value +/// (e.g., lag/fa entries) rather than performing stores. 
+pub fn run_bytecode_eval( + code: &[Opcode], + x: &[f64], + p: &[f64], + rateiv: &[f64], + t: f64, + locals: &mut [f64], + funcs: &Vec, + builtins_dispatch: &dyn Fn(&str, &[f64]) -> f64, +) -> f64 { + let mut stack: Vec = Vec::new(); + let mut pc: usize = 0; + let code_len = code.len(); + while pc < code_len { + match &code[pc] { + Opcode::PushConst(v) => { + stack.push(*v); + pc += 1; + } + Opcode::LoadParam(i) => { + let v = if *i < p.len() { p[*i] } else { 0.0 }; + stack.push(v); + pc += 1; + } + Opcode::LoadX(i) => { + let v = if *i < x.len() { x[*i] } else { 0.0 }; + stack.push(v); + pc += 1; + } + Opcode::LoadRateiv(i) => { + let v = if *i < rateiv.len() { rateiv[*i] } else { 0.0 }; + stack.push(v); + pc += 1; + } + Opcode::LoadParamDyn => { + let idxf = stack.pop().unwrap_or(0.0); + let idx = idxf as usize; + let v = if idx < p.len() { p[idx] } else { 0.0 }; + stack.push(v); + pc += 1; + } + Opcode::LoadXDyn => { + let idxf = stack.pop().unwrap_or(0.0); + let idx = idxf as usize; + let v = if idx < x.len() { x[idx] } else { 0.0 }; + stack.push(v); + pc += 1; + } + Opcode::LoadRateivDyn => { + let idxf = stack.pop().unwrap_or(0.0); + let idx = idxf as usize; + let v = if idx < rateiv.len() { rateiv[idx] } else { 0.0 }; + stack.push(v); + pc += 1; + } + Opcode::LoadLocal(i) => { + let v = if *i < locals.len() { locals[*i] } else { 0.0 }; + stack.push(v); + pc += 1; + } + Opcode::LoadT => { + stack.push(t); + pc += 1; + } + Opcode::Add => { + let b = stack.pop().unwrap_or(0.0); + let a = stack.pop().unwrap_or(0.0); + stack.push(a + b); + pc += 1; + } + Opcode::Sub => { + let b = stack.pop().unwrap_or(0.0); + let a = stack.pop().unwrap_or(0.0); + stack.push(a - b); + pc += 1; + } + Opcode::Mul => { + let b = stack.pop().unwrap_or(0.0); + let a = stack.pop().unwrap_or(0.0); + stack.push(a * b); + pc += 1; + } + Opcode::Div => { + let b = stack.pop().unwrap_or(0.0); + let a = stack.pop().unwrap_or(0.0); + stack.push(a / b); + pc += 1; + } + Opcode::Pow => { + let 
b = stack.pop().unwrap_or(0.0); + let a = stack.pop().unwrap_or(0.0); + stack.push(a.powf(b)); + pc += 1; + } + Opcode::Lt => { + let b = stack.pop().unwrap_or(0.0); + let a = stack.pop().unwrap_or(0.0); + stack.push(if a < b { 1.0 } else { 0.0 }); + pc += 1; + } + Opcode::Gt => { + let b = stack.pop().unwrap_or(0.0); + let a = stack.pop().unwrap_or(0.0); + stack.push(if a > b { 1.0 } else { 0.0 }); + pc += 1; + } + Opcode::Le => { + let b = stack.pop().unwrap_or(0.0); + let a = stack.pop().unwrap_or(0.0); + stack.push(if a <= b { 1.0 } else { 0.0 }); + pc += 1; + } + Opcode::Ge => { + let b = stack.pop().unwrap_or(0.0); + let a = stack.pop().unwrap_or(0.0); + stack.push(if a >= b { 1.0 } else { 0.0 }); + pc += 1; + } + Opcode::Eq => { + let b = stack.pop().unwrap_or(0.0); + let a = stack.pop().unwrap_or(0.0); + stack.push(if a == b { 1.0 } else { 0.0 }); + pc += 1; + } + Opcode::Ne => { + let b = stack.pop().unwrap_or(0.0); + let a = stack.pop().unwrap_or(0.0); + stack.push(if a != b { 1.0 } else { 0.0 }); + pc += 1; + } + Opcode::Jump(addr) => { + pc = *addr; + } + Opcode::JumpIfFalse(addr) => { + let c = stack.pop().unwrap_or(0.0); + if c == 0.0 { + pc = *addr; + } else { + pc += 1; + } + } + Opcode::CallBuiltin(func_idx, argc) => { + let mut args: Vec = Vec::with_capacity(*argc); + for _ in 0..*argc { + args.push(stack.pop().unwrap_or(0.0)); + } + args.reverse(); + let func_name = funcs.get(*func_idx).map(|s| s.as_str()).unwrap_or(""); + let res = builtins_dispatch(func_name, &args); + stack.push(res); + pc += 1; + } + Opcode::StoreDx(i) => { + // for eval, treat like push value (no-op) + let _ = stack.pop().unwrap_or(0.0); + pc += 1; + } + Opcode::StoreX(i) => { + let _ = stack.pop().unwrap_or(0.0); + pc += 1; + } + Opcode::StoreY(i) => { + let _ = stack.pop().unwrap_or(0.0); + pc += 1; + } + Opcode::StoreLocal(i) => { + let v = stack.pop().unwrap_or(0.0); + if *i < locals.len() { + locals[*i] = v; + } + pc += 1; + } + Opcode::StoreDxDyn | Opcode::StoreXDyn | 
Opcode::StoreYDyn => { + // pop value then index and ignore for eval + let _v = stack.pop().unwrap_or(0.0); + let _idxf = stack.pop().unwrap_or(0.0); + pc += 1; + } + Opcode::Pop => { + let _ = stack.pop(); + pc += 1; + } + } + } + + stack.pop().unwrap_or(0.0) +} diff --git a/src/exa_wasm/mod.rs b/src/exa_wasm/mod.rs new file mode 100644 index 00000000..d8ac8052 --- /dev/null +++ b/src/exa_wasm/mod.rs @@ -0,0 +1,11 @@ +//! WASM-compatible `exa` alternative. +//! +//! This module contains a small IR emitter and an interpreter that can run +//! user-defined models in WASM hosts without requiring cargo compilation or +//! dynamic library loading. It's gated under the `exa-wasm` cargo feature. + +pub mod build; +pub mod interpreter; + +pub use build::emit_ir; +pub use interpreter::{load_ir_ode, ode_for_id, unregister_model}; diff --git a/src/lib.rs b/src/lib.rs index 58f01d8b..f8cf655e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -2,6 +2,7 @@ pub mod data; pub mod error; #[cfg(feature = "exa")] pub mod exa; +pub mod exa_wasm; pub mod optimize; pub mod simulator; @@ -16,6 +17,12 @@ pub use crate::simulator::equation::{self, ODE}; pub use error::PharmsolError; #[cfg(feature = "exa")] pub use exa::*; +// When the `exa` (native) feature is enabled prefer its exports at crate root to +// avoid ambiguous glob re-exports between `exa` and `exa_wasm` (they both expose +// `build` and `interpreter` modules). When `exa` is not enabled, re-export +// `exa_wasm` at the crate root so its API is available. 
+#[cfg(not(feature = "exa"))] +pub use exa_wasm::*; pub use nalgebra::dmatrix; pub use std::collections::HashMap; diff --git a/src/simulator/equation/ode/mod.rs b/src/simulator/equation/ode/mod.rs index 8e1ce471..784037f8 100644 --- a/src/simulator/equation/ode/mod.rs +++ b/src/simulator/equation/ode/mod.rs @@ -33,6 +33,8 @@ pub struct ODE { init: Init, out: Out, neqs: Neqs, + // Optional registry id pointing to interpreter expressions + registry_id: Option, } impl ODE { @@ -44,6 +46,28 @@ impl ODE { init, out, neqs, + registry_id: None, + } + } + + /// Create an ODE with an associated interpreter registry id. + pub fn with_registry_id( + diffeq: DiffEq, + lag: Lag, + fa: Fa, + init: Init, + out: Out, + neqs: Neqs, + registry_id: Option, + ) -> Self { + Self { + diffeq, + lag, + fa, + init, + out, + neqs, + registry_id, } } } @@ -199,6 +223,28 @@ impl Equation for ODE { support_point: &Vec, error_models: Option<&ErrorModels>, ) -> Result<(Self::P, Option), PharmsolError> { + // Ensure the interpreter dispatchers use this ODE's registry id (if any). + // We set the thread-local current id for the duration of this call and + // restore it on exit via a small RAII guard. When the `exa` feature is + // disabled these are no-ops. + // Set the current interpreter registry id for both possible interpreter + // implementations (native `exa` and `exa_wasm`). Store previous ids and + // restore them on Drop. Using a single guard type avoids type-mismatch + // issues across cfg branches. + struct RestoreGuard { + exa_wasm_prev: Option, + } + impl Drop for RestoreGuard { + fn drop(&mut self) { + // Always restore the exa_wasm interpreter id if present. + let _ = crate::exa_wasm::interpreter::set_current_expr_id(self.exa_wasm_prev); + } + } + + // Native `exa` does not provide an interpreter registry in this branch. 
+ let exa_wasm_prev = crate::exa_wasm::interpreter::set_current_expr_id(self.registry_id); + let _restore_current = RestoreGuard { exa_wasm_prev }; + // let lag = self.get_lag(support_point); // let fa = self.get_fa(support_point); let mut output = Self::P::new(self.nparticles()); @@ -217,7 +263,15 @@ impl Equation for ODE { Some((self.fa(), self.lag(), support_point, covariates)), true, ); + // If interpreter produced a runtime error while computing lag/fa, propagate it + if let Some(err) = crate::exa_wasm::interpreter::take_runtime_error() { + return Err(PharmsolError::OtherError(err)); + } + let init_state = self.initial_state(support_point, covariates, occasion.index()); + if let Some(err) = crate::exa_wasm::interpreter::take_runtime_error() { + return Err(PharmsolError::OtherError(err)); + } let problem = OdeBuilder::::new() .atol(vec![ATOL]) .rtol(RTOL) @@ -230,8 +284,7 @@ impl Equation for ODE { support_point.clone(), //TODO: Avoid cloning the support point covariates, infusions, - self.initial_state(support_point, covariates, occasion.index()) - .into(), + init_state.into(), ))?; let mut solver: Bdf< @@ -268,6 +321,9 @@ impl Equation for ODE { zero_vector.clone(), covariates, ); + if let Some(err) = crate::exa_wasm::interpreter::take_runtime_error() { + return Err(PharmsolError::OtherError(err)); + } // Call the differential equation closure with bolus (self.diffeq)( @@ -279,6 +335,9 @@ impl Equation for ODE { zero_vector.clone(), covariates, ); + if let Some(err) = crate::exa_wasm::interpreter::take_runtime_error() { + return Err(PharmsolError::OtherError(err)); + } // The difference between the two states is the actual bolus effect // Apply the computed changes to the state @@ -299,6 +358,9 @@ impl Equation for ODE { covariates, &mut y, ); + if let Some(err) = crate::exa_wasm::interpreter::take_runtime_error() { + return Err(PharmsolError::OtherError(err)); + } let pred = y[observation.outeq()]; let pred = observation.to_prediction(pred, 
solver.state().y.as_slice().to_vec()); @@ -315,6 +377,13 @@ impl Equation for ODE { match solver.set_stop_time(next_event.time()) { Ok(_) => loop { let ret = solver.step(); + // If the interpreter set a runtime error during evaluation inside + // the ODE step, surface it here. + if let Some(err) = + crate::exa_wasm::interpreter::take_runtime_error() + { + return Err(PharmsolError::OtherError(err)); + } match ret { Ok(OdeSolverStopReason::InternalTimestep) => continue, Ok(OdeSolverStopReason::TstopReached) => break,