diff --git a/examples/api_comparison.rs b/examples/api_comparison.rs new file mode 100644 index 00000000..58b87f06 --- /dev/null +++ b/examples/api_comparison.rs @@ -0,0 +1,174 @@ +//! Comparison example demonstrating both the old (tuple-based) and new (builder) APIs. +//! +//! This example shows that both APIs produce identical results, ensuring backward compatibility +//! while providing a more ergonomic builder pattern for new code. +//! +//! The new builder API uses the type-state pattern to enforce required fields at compile time, +//! while making optional fields (lag, fa, init) truly optional with sensible defaults. + +use pharmsol::prelude::models::one_compartment; +use pharmsol::*; + +fn main() { + println!("=== API Comparison: Old (Tuple) vs New (Builder) ===\n"); + + // Create a simple subject for testing + let subject = Subject::builder("comparison_test") + .infusion(0.0, 500.0, 0, 0.5) + .observation(0.5, 0.0, 0) + .observation(1.0, 0.0, 0) + .observation(2.0, 0.0, 0) + .observation(4.0, 0.0, 0) + .observation(8.0, 0.0, 0) + .build(); + + let params = vec![1.02282724609375, 194.51904296875]; // ke, v + + println!("--- ODE Models ---"); + + let ode_old = equation::ODE::new( + |x, p, _t, dx, _b, rateiv, _cov| { + fetch_params!(p, ke, _v); + dx[0] = -ke * x[0] + rateiv[0]; + }, + |_p, _t, _cov| lag! {}, + |_p, _t, _cov| fa! {}, + |_p, _t, _cov, _x| {}, + |x, p, _t, _cov, y| { + fetch_params!(p, _ke, v); + y[0] = x[0] / v; + }, + (1, 1), + ); + + let ode_minimal = equation::ODE::builder() + .diffeq(|x, p, _t, dx, _b, rateiv, _cov| { + fetch_params!(p, ke, _v); + dx[0] = -ke * x[0] + rateiv[0]; + }) + .out(|x, p, _t, _cov, y| { + fetch_params!(p, _ke, v); + y[0] = x[0] / v; + }) + .nstates(1) + .nouteqs(1) + .build(); + + // Also show full specification with optional fields + let ode_full = equation::ODE::builder() + .diffeq(|x, p, _t, dx, _b, rateiv, _cov| { + fetch_params!(p, ke, _v); + dx[0] = -ke * x[0] + rateiv[0]; + }) + .out(|x, p, _t, _cov, y| { + fetch_params!(p, _ke, v); + y[0] = x[0] / v; + }) + .neqs(Neqs::new(1, 1)) // Can also use Neqs struct + .build(); + + // Compare predictions + let pred_old = ode_old.estimate_predictions(&subject, ¶ms).unwrap(); + let pred_minimal = ode_minimal.estimate_predictions(&subject, ¶ms).unwrap(); + let pred_full = ode_full.estimate_predictions(&subject, ¶ms).unwrap(); + + println!("ODE Predictions (Old API with all 6 args):"); + for p in pred_old.flat_predictions() { + print!("{:.9} ", p); + } + println!("\n"); + + println!("ODE Predictions (New Builder - minimal, 4 required fields only):"); + for p in pred_minimal.flat_predictions() { + print!("{:.9} ", p); + } + println!("\n"); + + println!("ODE Predictions (New Builder - all fields explicit):"); + for p in pred_full.flat_predictions() { + print!("{:.9} ", p); + } + println!("\n"); + + // Verify they match + let old_preds = pred_old.flat_predictions(); + let minimal_preds = pred_minimal.flat_predictions(); + let full_preds = pred_full.flat_predictions(); + + let all_match = old_preds + .iter() + .zip(minimal_preds.iter()) + .zip(full_preds.iter()) + .all(|((a, b), c)| (a - b).abs() < 1e-12 && (a - c).abs() < 1e-12); + + println!("All ODE APIs produce identical results: {} ✓", all_match); + println!(); + + // ========================================================================= + // Analytical: Old API (tuple-based) + // ========================================================================= + println!("--- Analytical Models ---"); + + let analytical_old = equation::Analytical::new( + one_compartment, + |_p, _t, _cov| {}, + |_p, _t, _cov| lag! {}, + |_p, _t, _cov| fa! {}, + |_p, _t, _cov, _x| {}, + |x, p, _t, _cov, y| { + fetch_params!(p, _ke, v); + y[0] = x[0] / v; + }, + (1, 1), // Old tuple-based Neqs + ); + + // ========================================================================= + // Analytical: New API (builder pattern) - MINIMAL version + // Required fields: eq, seq_eq, out, nstates, nouteqs + // ========================================================================= + let analytical_minimal = equation::Analytical::builder() + .eq(one_compartment) + .seq_eq(|_p, _t, _cov| {}) + .out(|x, p, _t, _cov, y| { + fetch_params!(p, _ke, v); + y[0] = x[0] / v; + }) + .nstates(1) + .nouteqs(1) + .build(); + + // Compare predictions + let an_pred_old = analytical_old + .estimate_predictions(&subject, ¶ms) + .unwrap(); + let an_pred_minimal = analytical_minimal + .estimate_predictions(&subject, ¶ms) + .unwrap(); + + println!("Analytical Predictions (Old API with all 7 args):"); + for p in an_pred_old.flat_predictions() { + print!("{:.9} ", p); + } + println!("\n"); + + println!("Analytical Predictions (New Builder - minimal, 5 required fields only):"); + for p in an_pred_minimal.flat_predictions() { + print!("{:.9} ", p); + } + println!("\n"); + + // Verify they match + let an_old_preds = an_pred_old.flat_predictions(); + let an_minimal_preds = an_pred_minimal.flat_predictions(); + + let analytical_match = an_old_preds + .iter() + .zip(an_minimal_preds.iter()) + .all(|(a, b)| (a - b).abs() < 1e-12); + + println!( + "Analytical APIs produce identical results: {} ✓", + analytical_match + ); + println!(); +} diff --git a/examples/bke.rs b/examples/bke.rs index fa800c4a..b7d4ae6a 100644 --- a/examples/bke.rs +++ b/examples/bke.rs @@ -13,34 +13,30 @@ fn main() { .missing_observation(12.0, 0) .build(); - let an = equation::Analytical::new( - one_compartment, - |_p, _t, _cov| {}, - |_p, _t, _cov| lag! {}, - |_p, _t, _cov| fa! {}, - |_p, _t, _cov, _x| {}, - |x, p, _t, _cov, y| { + let an = equation::Analytical::builder() + .eq(one_compartment) + .seq_eq(|_p, _t, _cov| {}) + .out(|x, p, _t, _cov, y| { fetch_params!(p, _ke, v); y[0] = x[0] / v; - }, - (1, 1), - ); + }) + .nstates(1) + .nouteqs(1) + .build(); - let ode = equation::ODE::new( - |x, p, _t, dx, _b, rateiv, _cov| { + let ode = equation::ODE::builder() + .diffeq(|x, p, _t, dx, _b, rateiv, _cov| { // fetch_cov!(cov, t, wt); fetch_params!(p, ke, _v); dx[0] = -ke * x[0] + rateiv[0]; - }, - |_p, _t, _cov| lag! {}, - |_p, _t, _cov| fa! {}, - |_p, _t, _cov, _x| {}, - |x, p, _t, _cov, y| { + }) + .out(|x, p, _t, _cov, y| { fetch_params!(p, _ke, v); y[0] = x[0] / v; - }, - (1, 1), - ); + }) + .nstates(1) + .nouteqs(1) + .build(); let mut ems = ErrorModels::new() .add( diff --git a/examples/error.rs b/examples/error.rs index 5ac397d5..87939199 100644 --- a/examples/error.rs +++ b/examples/error.rs @@ -4,8 +4,8 @@ use pharmsol::*; fn main() { - let ode = equation::ODE::new( - |x, p, t, dx, _b, rateiv, cov| { + let ode = equation::ODE::builder() + .diffeq(|x, p, t, dx, _b, rateiv, cov| { fetch_cov!(cov, t, WT); fetch_params!(p, CL0, V0, Vp0, Q0); @@ -20,18 +20,16 @@ fn main() { dx[0] = -Ke * x[0] - KCP * x[0] + KPC * x[1] + rateiv[0]; dx[1] = KCP * x[0] - KPC * x[1]; - }, - |p, _t, _cov| lag! {}, - |_p, _t, _cov| fa! {}, - |_p, _t, _cov, _x| {}, - |x, p, t, cov, y| { + }) + .out(|x, p, t, cov, y| { fetch_cov!(cov, t, WT); fetch_params!(p, CL0, V0, Vp0, Q0); let V = V0 / (WT / 85.0); y[0] = x[0] / V; - }, - (2, 1), - ); + }) + .nstates(2) + .nouteqs(1) + .build(); let subject = data::Subject::builder("id1") .infusion(0.0, 3235.0, 0, 0.005) diff --git a/examples/exa.rs b/examples/exa.rs index 685f488f..1e53cd76 100644 --- a/examples/exa.rs +++ b/examples/exa.rs @@ -16,20 +16,18 @@ fn main() { .build(); // Create ODE model directly - let ode = equation::ODE::new( - |x, p, _t, dx, b, rateiv, _cov| { + let ode = equation::ODE::builder() + .diffeq(|x, p, _t, dx, b, rateiv, _cov| { fetch_params!(p, ke, _v); dx[0] = -ke * x[0] + rateiv[0] + b[0]; - }, - |_p, _t, _cov| lag! {}, - |_p, _t, _cov| fa! {}, - |_p, _t, _cov, _x| {}, - |x, p, _t, _cov, y| { + }) + .out(|x, p, _t, _cov, y| { fetch_params!(p, _ke, v); y[0] = x[0] / v; - }, - (1, 1), - ); + }) + .nstates(1) + .nouteqs(1) + .build(); let test_dir = std::env::current_dir().expect("Failed to get current directory"); let model_output_path = test_dir.join("test_model.pkm"); diff --git a/examples/gendata.rs b/examples/gendata.rs index ba214a78..88114226 100644 --- a/examples/gendata.rs +++ b/examples/gendata.rs @@ -13,8 +13,10 @@ fn main() { .repeat(5, 0.2) .build(); - let sde = equation::SDE::new( - |x, p, _t, dx, _rateiv, _cov| { + // Type-state builder: only required fields + init (which is used here) + // lag and fa are optional and default to no-op + let sde = equation::SDE::builder() + .drift(|x, p, _t, dx, _rateiv, _cov| { // automatically defined fetch_params!(p, ke0); // let ke0 = 1.2; @@ -22,24 +24,24 @@ fn main() { let ke = x[1]; // user defined dx[0] = -ke * x[0]; - }, - |p, d| { + }) + .diffusion(|p, d| { fetch_params!(p, _ke0); d[1] = 0.1; - }, - |_p, _t, _cov| lag! {}, - |_p, _t, _cov| fa! {}, - |p, _t, _cov, x| { + }) + // init is specified because we need non-zero initial state + .init(|p, _t, _cov, x| { fetch_params!(p, ke0); x[1] = ke0; - }, - |x, p, _t, _cov, y| { + }) + .out(|x, p, _t, _cov, y| { fetch_params!(p, _ke0); y[0] = x[0] / 50.0; - }, - (2, 1), - 1, - ); + }) + .nstates(2) + .nouteqs(1) + .nparticles(1) + .build(); let ke_dist = rand_distr::Normal::new(1.2, 0.12).unwrap(); // let v_dist = rand_distr::Normal::new(50.0, 10.0).unwrap(); diff --git a/examples/nsim.rs b/examples/nsim.rs index 9e34a642..1e19f929 100644 --- a/examples/nsim.rs +++ b/examples/nsim.rs @@ -12,54 +12,32 @@ fn main() { .covariate("age", 0.0, 25.0) .build(); println!("{subject}"); - let ode = equation::ODE::new( - |x, p, t, dx, _b, _rateiv, cov| { + + // Type-state builder: only required fields + lag (which is used here) + // fa and init are optional and default to no-op + let ode = equation::ODE::builder() + .diffeq(|x, p, t, dx, _b, _rateiv, cov| { fetch_cov!(cov, t, wt, age); fetch_params!(p, ka, ke, _tlag, _v); // Secondary Eqs - let ke = ke * wt.powf(0.75) * (age / 25.0).powf(0.5); //Struct dx[0] = -ka * x[0]; dx[1] = ka * x[0] - ke * x[1]; - }, - |p, _t, _cov| { + }) + .lag(|p, _t, _cov| { fetch_params!(p, _ka, _ke, tlag, _v); lag! {0=>tlag} - }, - |_p, _t, _cov| fa! {}, - |_p, _t, _cov, _x| {}, - |x, p, _t, _cov, y| { + }) + // fa and init omitted - using defaults + .out(|x, p, _t, _cov, y| { fetch_params!(p, _ka, _ke, _tlag, v); y[0] = x[1] / v; - }, - (2, 1), - ); - // let ode = simulator::Equation::new_ode( - // |x, p, t, dx, _rateiv, cov| { - // fetch_cov!(cov, t, wt, age); - // fetch_params!(p, ka, ke, _tlag, _v); - // // Secondary Eqs - - // let ke = ke * wt.powf(0.75) * (age / 25.0).powf(0.5); - - // //Struct - // dx[0] = -ka * x[0]; - // dx[1] = ka * x[0] - ke * x[1]; - // }, - // |p| { - // fetch_params!(p, _ka, _ke, tlag, _v); - // lag! {0=>tlag} - // }, - // |_p, _t, _cov| fa! {}, - // |_p, _t, _cov, _x| {}, - // |x, p, _t, _cov, y| { - // fetch_params!(p, _ka, _ke, _tlag, v); - // y[0] = x[1] / v; - // }, - // (2, 1), - // ); + }) + .nstates(2) + .nouteqs(1) + .build(); let op = ode.estimate_predictions(&subject, &vec![0.3, 0.5, 0.1, 70.0]); println!("{op:#?}"); diff --git a/examples/ode_readme.rs b/examples/ode_readme.rs index 5dbe047f..24e9e44d 100644 --- a/examples/ode_readme.rs +++ b/examples/ode_readme.rs @@ -14,26 +14,26 @@ fn main() { .covariate("age", 0.0, 25.0) .build(); println!("{subject}"); - let ode = equation::ODE::new( - |x, p, _t, dx, _b, _rateiv, _cov| { + + let ode = equation::ODE::builder() + .diffeq(|x, p, _t, dx, _b, _rateiv, _cov| { // fetch_cov!(cov, t,); fetch_params!(p, ka, ke, _tlag, _v); //Struct dx[0] = -ka * x[0]; dx[1] = ka * x[0] - ke * x[1]; - }, - |p, _t, _cov| { + }) + .lag(|p, _t, _cov| { fetch_params!(p, _ka, _ke, tlag, _v); lag! {0=>tlag} - }, - |_p, _t, _cov| fa! {}, - |_p, _t, _cov, _x| {}, - |x, p, _t, _cov, y| { + }) + .out(|x, p, _t, _cov, y| { fetch_params!(p, _ka, _ke, _tlag, v); y[0] = x[1] / v; - }, - (2, 1), - ); + }) + .nstates(2) + .nouteqs(1) + .build(); let op = ode .estimate_predictions(&subject, &vec![0.3, 0.5, 0.1, 70.0]) diff --git a/examples/pf.rs b/examples/pf.rs index 9e0b4446..6a9787d2 100644 --- a/examples/pf.rs +++ b/examples/pf.rs @@ -10,24 +10,23 @@ fn main() { .observation(1.0, 7.5170, 0) .build(); - let sde = equation::SDE::new( - |x, p, _t, dx, _rateiv, _cov| { + let sde = equation::SDE::builder() + .drift(|x, p, _t, dx, _rateiv, _cov| { dx[0] = -x[0] * x[1]; // ke *x[0] dx[1] = -x[1] + p[0]; // mean reverting - }, - |_p, d| { + }) + .diffusion(|_p, d| { d[0] = 1.0; d[1] = 0.01; - }, - |_p, _t, _cov| lag! {}, - |_p, _t, _cov| fa! {}, - |_p, _t, _cov, x| x[1] = 1.0, - |x, _p, _t, _cov, y| { + }) + .init(|_p, _t, _cov, x| x[1] = 1.0) + .out(|x, _p, _t, _cov, y| { y[0] = x[0]; - }, - (2, 1), - 10000, - ); + }) + .nstates(2) + .nouteqs(1) + .nparticles(10000) + .build(); let ems = ErrorModels::new() .add( diff --git a/examples/sde.rs b/examples/sde.rs index c2274021..8f7719d1 100644 --- a/examples/sde.rs +++ b/examples/sde.rs @@ -1,25 +1,26 @@ use pharmsol::{prelude::data::read_pmetrics, *}; +// Simplified builder API: lag, fa, and init are now optional with sensible defaults +// Only required fields need to be specified + fn one_c_ode() -> ODE { - equation::ODE::new( - |x, p, _t, dx, _b, _rateiv, _cov| { + equation::ODE::builder() + .diffeq(|x, p, _t, dx, _b, _rateiv, _cov| { // fetch_cov!(cov, t, wt); fetch_params!(p, ke); dx[0] = -ke * x[0]; - }, - |_p, _t, _cov| lag! {}, - |_p, _t, _cov| fa! {}, - |_p, _t, _cov, _x| {}, - |x, _p, _t, _cov, y| { + }) + .out(|x, _p, _t, _cov, y| { y[0] = x[0] / 50.0; - }, - (1, 1), - ) + }) + .nstates(1) + .nouteqs(1) + .build() } fn one_c_sde() -> SDE { - equation::SDE::new( - |x, p, _t, dx, _rateiv, _cov| { + equation::SDE::builder() + .drift(|x, p, _t, dx, _rateiv, _cov| { // automatically defined fetch_params!(p, ke0, _ske); // let ke0 = 1.2; @@ -27,73 +28,71 @@ fn one_c_sde() -> SDE { let ke = x[1]; // user defined dx[0] = -ke * x[0]; - }, - |p, d| { + }) + .diffusion(|p, d| { fetch_params!(p, _ke0, ske); d[1] = ske; - }, - |_p, _t, _cov| lag! {}, - |_p, _t, _cov| fa! {}, - |p, _t, _cov, x| { + }) + // init is specified because we need non-zero initial state for the ke parameter + .init(|p, _t, _cov, x| { fetch_params!(p, ke0, _ske); x[1] = ke0; - }, - |x, p, _t, _cov, y| { + }) + .out(|x, p, _t, _cov, y| { fetch_params!(p, _ke0, _ske); y[0] = x[0] / 50.0; - }, - (2, 1), - 2, - ) + }) + .nstates(2) + .nouteqs(1) + .nparticles(2) + .build() } fn three_c_ode() -> ODE { - equation::ODE::new( - |x, p, _t, dx, _b, _rateiv, _cov| { + equation::ODE::builder() + .diffeq(|x, p, _t, dx, _b, _rateiv, _cov| { // fetch_cov!(cov, t, wt); fetch_params!(p, ka, ke, kcp, kpc, _vol); dx[0] = -ka * x[0]; dx[1] = ka * x[0] - (ke + kcp) * x[1] + kpc * x[2]; dx[2] = kcp * x[1] - kpc * x[2]; - }, - |_p, _t, _cov| lag! {}, - |_p, _t, _cov| fa! {}, - |_p, _t, _cov, _x| {}, - |x, p, _t, _cov, y| { + }) + .out(|x, p, _t, _cov, y| { fetch_params!(p, _ka, _ke, _kcp, _kpc, vol); y[0] = x[1] / vol; - }, - (3, 3), - ) + }) + .nstates(3) + .nouteqs(3) + .build() } fn three_c_sde() -> SDE { - equation::SDE::new( - |x, p, _t, dx, _rateiv, _cov| { + equation::SDE::builder() + .drift(|x, p, _t, dx, _rateiv, _cov| { fetch_params!(p, ka, ke0, kcp, kpc, _vol, _ske); dx[3] = -x[3] + ke0; let ke = x[3]; dx[0] = -ka * x[0]; dx[1] = ka * x[0] - (ke + kcp) * x[1] + kpc * x[2]; dx[2] = kcp * x[1] - kpc * x[2]; - }, - |p, d| { + }) + .diffusion(|p, d| { fetch_params!(p, _ka, _ke0, _kcp, _kpc, _vol, ske); d[3] = ske; - }, - |_p, _t, _cov| lag! {}, - |_p, _t, _cov| fa! {}, - |p, _t, _cov, x| { + }) + // init is specified because we need non-zero initial state + .init(|p, _t, _cov, x| { fetch_params!(p, _ka, ke0, _kcp, _kpc, _vol, _ske); x[3] = ke0; - }, - |x, p, _t, _cov, y| { + }) + .out(|x, p, _t, _cov, y| { fetch_params!(p, _ka, _ke0, _kcp0, _kpc0, vol, _ske); y[0] = x[1] / vol; - }, - (4, 1), - 2, - ) + }) + .nstates(4) + .nouteqs(1) + .nparticles(2) + .build() } fn main() { diff --git a/src/lib.rs b/src/lib.rs index a88e465d..15a43df2 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -12,7 +12,11 @@ pub use crate::data::*; pub use crate::equation::*; pub use crate::optimize::effect::get_e2; pub use crate::optimize::spp::SppOptimizer; -pub use crate::simulator::equation::{self, ODE}; +pub use crate::simulator::equation::analytical::{Analytical, AnalyticalBuilder}; +pub use crate::simulator::equation::ode::ODEBuilder; +pub use crate::simulator::equation::sde::{SDEBuilder, SDE}; +pub use crate::simulator::equation::{self, Missing, Provided, ODE}; +pub use crate::simulator::Neqs; pub use error::PharmsolError; #[cfg(feature = "exa")] pub use exa::*; @@ -29,6 +33,9 @@ pub mod prelude { pub mod simulator { pub use crate::simulator::{ equation, + equation::analytical::{Analytical, AnalyticalBuilder}, + equation::ode::ODEBuilder, + equation::sde::{SDEBuilder, SDE}, equation::Equation, likelihood::{log_psi, psi, PopulationPredictions, Prediction, SubjectPredictions}, }; diff --git a/src/simulator/equation/analytical/mod.rs b/src/simulator/equation/analytical/mod.rs index 5db77979..d7ed9456 100644 --- a/src/simulator/equation/analytical/mod.rs +++ b/src/simulator/equation/analytical/mod.rs @@ -2,6 +2,9 @@ pub mod one_compartment_models; pub mod three_compartment_models; pub mod two_compartment_models; +use std::collections::HashMap; +use std::marker::PhantomData; + use diffsol::{NalgebraContext, Vector, VectorHost}; pub use one_compartment_models::*; pub use three_compartment_models::*; @@ -9,7 +12,8 @@ pub use two_compartment_models::*; use crate::PharmsolError; use crate::{ - data::Covariates, simulator::*, Equation, EquationPriv, EquationTypes, Observation, Subject, + data::Covariates, simulator::*, Equation, EquationPriv, EquationTypes, Missing, Observation, + Provided, Subject, }; use cached::proc_macro::cached; use cached::UnboundCache; @@ -33,6 +37,8 @@ pub struct Analytical { impl Analytical { /// Create a new Analytical equation model. /// + /// For a more ergonomic API, consider using [`AnalyticalBuilder`] instead. + /// /// # Parameters /// - `eq`: The analytical equation function /// - `seq_eq`: The secondary equation function @@ -40,7 +46,7 @@ impl Analytical { /// - `fa`: The fraction absorbed function /// - `init`: The initial state function /// - `out`: The output equation function - /// - `neqs`: The number of states and output equations + /// - `neqs`: The number of states and output equations (can be a tuple or [`Neqs`]) pub fn new( eq: AnalyticalEq, seq_eq: SecEq, @@ -48,7 +54,7 @@ impl Analytical { fa: Fa, init: Init, out: Out, - neqs: Neqs, + neqs: impl Into, ) -> Self { Self { eq, @@ -57,7 +63,316 @@ impl Analytical { fa, init, out, - neqs, + neqs: neqs.into(), + } + } + + /// Returns a new [`AnalyticalBuilder`] for constructing an Analytical equation. + /// + /// # Example + /// ```ignore + /// use pharmsol::prelude::*; + /// + /// // Minimal builder - only required fields + /// let analytical = Analytical::builder() + /// .eq(one_compartment) + /// .seq_eq(|p, _t, _cov| {}) + /// .out(|x, p, _t, _cov, y| { y[0] = x[0] / p[1]; }) + /// .nstates(1) + /// .nouteqs(1) + /// .build(); + /// + /// // With optional fields + /// let analytical = Analytical::builder() + /// .eq(one_compartment) + /// .seq_eq(|p, _t, _cov| {}) + /// .out(|x, p, _t, _cov, y| { y[0] = x[0] / p[1]; }) + /// .nstates(1) + /// .nouteqs(1) + /// .lag(|p, _t, _cov| lag! { 0 => p[2] }) + /// .fa(|p, _t, _cov| fa! { 0 => 0.8 }) + /// .init(|p, _t, _cov, x| { x[0] = p[3]; }) + /// .build(); + /// ``` + pub fn builder() -> AnalyticalBuilder { + AnalyticalBuilder::new() + } +} + +// ============================================================================= +// Type-State Builder Pattern +// ============================================================================= + +// Note: Missing and Provided marker types are defined in the parent module +// and imported via `use crate::{..., Missing, Provided, ...}` + +/// Builder for constructing [`Analytical`] equations with compile-time validation. +/// +/// This builder uses the type-state pattern to ensure all required fields +/// are set before `build()` can be called. Optional fields (`lag`, `fa`, `init`) +/// have sensible defaults. +/// +/// # Required Fields (enforced at compile time) +/// - `eq`: The analytical equation function +/// - `seq_eq`: The secondary equation function +/// - `out`: Output equation function +/// - `nstates`: Number of state variables +/// - `nouteqs`: Number of output equations +/// +/// # Optional Fields (with defaults) +/// - `lag`: Lag time function (defaults to no lag) +/// - `fa`: Bioavailability function (defaults to 100% bioavailability) +/// - `init`: Initial state function (defaults to zero initial state) +/// +/// # Example +/// ```ignore +/// use pharmsol::prelude::*; +/// +/// // Minimal example - only required fields +/// let analytical = Analytical::builder() +/// .eq(one_compartment) +/// .seq_eq(|p, _t, _cov| {}) +/// .out(|x, p, _t, _cov, y| { y[0] = x[0] / p[1]; }) +/// .nstates(1) +/// .nouteqs(1) +/// .build(); +/// ``` +pub struct AnalyticalBuilder { + eq: Option, + seq_eq: Option, + lag: Option, + fa: Option, + init: Option, + out: Option, + nstates: Option, + nouteqs: Option, + _phantom: PhantomData<(EqState, SeqEqState, OutState, NStatesState, NOuteqsState)>, +} + +impl AnalyticalBuilder { + /// Creates a new AnalyticalBuilder with all required fields unset. + pub fn new() -> Self { + Self { + eq: None, + seq_eq: None, + lag: None, + fa: None, + init: None, + out: None, + nstates: None, + nouteqs: None, + _phantom: PhantomData, + } + } +} + +impl Default for AnalyticalBuilder { + fn default() -> Self { + Self::new() + } +} + +impl + AnalyticalBuilder +{ + /// Sets the lag time function (optional). + /// + /// If not set, defaults to no lag for any compartment. + pub fn lag(mut self, lag: Lag) -> Self { + self.lag = Some(lag); + self + } + + /// Sets the bioavailability function (optional). + /// + /// If not set, defaults to 100% bioavailability for all compartments. + pub fn fa(mut self, fa: Fa) -> Self { + self.fa = Some(fa); + self + } + + /// Sets the initial state function (optional). + /// + /// If not set, defaults to zero initial state for all compartments. + pub fn init(mut self, init: Init) -> Self { + self.init = Some(init); + self + } +} + +impl + AnalyticalBuilder +{ + /// Sets the analytical equation function (required). + pub fn eq( + self, + eq: AnalyticalEq, + ) -> AnalyticalBuilder { + AnalyticalBuilder { + eq: Some(eq), + seq_eq: self.seq_eq, + lag: self.lag, + fa: self.fa, + init: self.init, + out: self.out, + nstates: self.nstates, + nouteqs: self.nouteqs, + _phantom: PhantomData, + } + } +} + +impl + AnalyticalBuilder +{ + /// Sets the secondary equation function (required). + /// + /// This function is used to update parameters at each time step. + pub fn seq_eq( + self, + seq_eq: SecEq, + ) -> AnalyticalBuilder { + AnalyticalBuilder { + eq: self.eq, + seq_eq: Some(seq_eq), + lag: self.lag, + fa: self.fa, + init: self.init, + out: self.out, + nstates: self.nstates, + nouteqs: self.nouteqs, + _phantom: PhantomData, + } + } +} + +impl + AnalyticalBuilder +{ + /// Sets the output equation function (required). + pub fn out( + self, + out: Out, + ) -> AnalyticalBuilder { + AnalyticalBuilder { + eq: self.eq, + seq_eq: self.seq_eq, + lag: self.lag, + fa: self.fa, + init: self.init, + out: Some(out), + nstates: self.nstates, + nouteqs: self.nouteqs, + _phantom: PhantomData, + } + } +} + +impl + AnalyticalBuilder +{ + /// Sets the number of state variables (compartments) (required). + pub fn nstates( + self, + nstates: usize, + ) -> AnalyticalBuilder { + AnalyticalBuilder { + eq: self.eq, + seq_eq: self.seq_eq, + lag: self.lag, + fa: self.fa, + init: self.init, + out: self.out, + nstates: Some(nstates), + nouteqs: self.nouteqs, + _phantom: PhantomData, + } + } +} + +impl + AnalyticalBuilder +{ + /// Sets the number of output equations (required). + pub fn nouteqs( + self, + nouteqs: usize, + ) -> AnalyticalBuilder { + AnalyticalBuilder { + eq: self.eq, + seq_eq: self.seq_eq, + lag: self.lag, + fa: self.fa, + init: self.init, + out: self.out, + nstates: self.nstates, + nouteqs: Some(nouteqs), + _phantom: PhantomData, + } + } +} + +impl + AnalyticalBuilder +{ + /// Sets both nstates and nouteqs from a [`Neqs`] struct or tuple (required). + pub fn neqs( + self, + neqs: impl Into, + ) -> AnalyticalBuilder { + let neqs = neqs.into(); + AnalyticalBuilder { + eq: self.eq, + seq_eq: self.seq_eq, + lag: self.lag, + fa: self.fa, + init: self.init, + out: self.out, + nstates: Some(neqs.nstates), + nouteqs: Some(neqs.nouteqs), + _phantom: PhantomData, + } + } +} + +/// Default lag function: no lag for any compartment +fn default_lag(_p: &V, _t: f64, _cov: &Covariates) -> HashMap { + HashMap::new() +} + +/// Default fa function: 100% bioavailability for all compartments +fn default_fa(_p: &V, _t: f64, _cov: &Covariates) -> HashMap { + HashMap::new() +} + +/// Default init function: zero initial state +fn default_init(_p: &V, _t: f64, _cov: &Covariates, _x: &mut V) { + // State is already zero-initialized +} + +impl AnalyticalBuilder { + /// Builds the [`Analytical`] equation. + /// + /// This method is only available when all required fields have been set: + /// - `eq` + /// - `seq_eq` + /// - `out` + /// - `nstates` + /// - `nouteqs` + /// + /// Optional fields use defaults if not set: + /// - `lag`: No lag (empty HashMap) + /// - `fa`: 100% bioavailability (empty HashMap) + /// - `init`: Zero initial state + pub fn build(self) -> Analytical { + Analytical { + eq: self.eq.unwrap(), + seq_eq: self.seq_eq.unwrap(), + lag: self.lag.unwrap_or(default_lag), + fa: self.fa.unwrap_or(default_fa), + init: self.init.unwrap_or(default_init), + out: self.out.unwrap(), + neqs: Neqs::new(self.nstates.unwrap(), self.nouteqs.unwrap()), } } } @@ -100,12 +415,12 @@ impl EquationPriv for Analytical { #[inline(always)] fn get_nstates(&self) -> usize { - self.neqs.0 + self.neqs.nstates } #[inline(always)] fn get_nouteqs(&self) -> usize { - self.neqs.1 + self.neqs.nouteqs } #[inline(always)] fn solve( diff --git a/src/simulator/equation/mod.rs b/src/simulator/equation/mod.rs index 70219080..eb5a7fde 100644 --- a/src/simulator/equation/mod.rs +++ b/src/simulator/equation/mod.rs @@ -16,6 +16,22 @@ use crate::{ use super::likelihood::Prediction; +// ============================================================================= +// Type-State Builder Marker Types +// ============================================================================= + +/// Marker type indicating a required field is missing in a builder. +/// +/// This is used in the type-state builder pattern to enforce at compile time +/// that all required fields are set before `build()` can be called. +pub struct Missing; + +/// Marker type indicating a required field has been provided in a builder. +/// +/// This is used in the type-state builder pattern to track which required +/// fields have been set. +pub struct Provided; + /// Trait for state vectors that can receive bolus doses. pub trait State { /// Add a bolus dose to the state at the specified input compartment. diff --git a/src/simulator/equation/ode/mod.rs b/src/simulator/equation/ode/mod.rs index 4c005ee7..f0ed2933 100644 --- a/src/simulator/equation/ode/mod.rs +++ b/src/simulator/equation/ode/mod.rs @@ -1,5 +1,8 @@ mod closure; +use std::collections::HashMap; +use std::marker::PhantomData; + use crate::{ data::{Covariates, Infusion}, error_model::ErrorModels, @@ -18,7 +21,7 @@ use diffsol::{ }; use nalgebra::DVector; -use super::{Equation, EquationPriv, EquationTypes, State}; +use super::{Equation, EquationPriv, EquationTypes, Missing, Provided, State}; const RTOL: f64 = 1e-4; const ATOL: f64 = 1e-4; @@ -34,14 +37,315 @@ pub struct ODE { } impl ODE { - pub fn new(diffeq: DiffEq, lag: Lag, fa: Fa, init: Init, out: Out, neqs: Neqs) -> Self { + /// Creates a new ODE equation. + /// + /// For a more ergonomic API, consider using [`ODEBuilder`] instead. + /// + /// # Parameters + /// - `diffeq`: The differential equation closure + /// - `lag`: Lag time function + /// - `fa`: Bioavailability function + /// - `init`: Initial state function + /// - `out`: Output equation function + /// - `neqs`: Number of states and output equations (can be a tuple or [`Neqs`]) + pub fn new( + diffeq: DiffEq, + lag: Lag, + fa: Fa, + init: Init, + out: Out, + neqs: impl Into, + ) -> Self { Self { diffeq, lag, fa, init, out, - neqs, + neqs: neqs.into(), + } + } + + /// Returns a new [`ODEBuilder`] for constructing an ODE equation. + /// + /// # Example + /// ```ignore + /// use pharmsol::prelude::*; + /// + /// // Minimal builder - only required fields + /// let ode = ODE::builder() + /// .diffeq(diffeq) + /// .out(out) + /// .nstates(2) + /// .nouteqs(1) + /// .build(); + /// + /// // With optional fields + /// let ode = ODE::builder() + /// .diffeq(diffeq) + /// .out(out) + /// .nstates(2) + /// .nouteqs(1) + /// .lag(|p, _t, _cov| lag! { 0 => p[2] }) + /// .fa(|p, _t, _cov| fa! { 0 => 0.8 }) + /// .init(|p, _t, _cov, x| { x[0] = p[3]; }) + /// .build(); + /// ``` + pub fn builder() -> ODEBuilder { + ODEBuilder::new() + } +} + +// ============================================================================= +// Type-State Builder Pattern +// ============================================================================= + +// Note: Missing and Provided marker types are defined in the parent module +// and imported via `use super::{..., Missing, Provided, ...}` + +/// Builder for constructing [`ODE`] equations with compile-time validation. +/// +/// This builder uses the type-state pattern to ensure all required fields +/// are set before `build()` can be called. Optional fields (`lag`, `fa`, `init`) +/// have sensible defaults. +/// +/// # Required Fields (enforced at compile time) +/// - `diffeq`: The differential equation closure +/// - `out`: Output equation function +/// - `nstates`: Number of state variables +/// - `nouteqs`: Number of output equations +/// +/// # Optional Fields (with defaults) +/// - `lag`: Lag time function (defaults to no lag) +/// - `fa`: Bioavailability function (defaults to 100% bioavailability) +/// - `init`: Initial state function (defaults to zero initial state) +/// +/// # Example +/// ```ignore +/// use pharmsol::prelude::*; +/// +/// // Minimal example - only required fields +/// let ode = ODE::builder() +/// .diffeq(|x, p, _t, dx, _b, rateiv, _cov| { +/// fetch_params!(p, ke, _v); +/// dx[0] = -ke * x[0] + rateiv[0]; +/// }) +/// .out(|x, p, _t, _cov, y| { +/// fetch_params!(p, _ke, v); +/// y[0] = x[0] / v; +/// }) +/// .nstates(1) +/// .nouteqs(1) +/// .build(); +/// ``` +pub struct ODEBuilder { + diffeq: Option, + lag: Option, + fa: Option, + init: Option, + out: Option, + nstates: Option, + nouteqs: Option, + _phantom: PhantomData<(DiffEqState, OutState, NStatesState, NOuteqsState)>, +} + +impl ODEBuilder { + /// Creates a new ODEBuilder with all required fields unset. + pub fn new() -> Self { + Self { + diffeq: None, + lag: None, + fa: None, + init: None, + out: None, + nstates: None, + nouteqs: None, + _phantom: PhantomData, + } + } +} + +impl Default for ODEBuilder { + fn default() -> Self { + Self::new() + } +} + +impl + ODEBuilder +{ + /// Sets the lag time function (optional). + /// + /// If not set, defaults to no lag for any compartment. + pub fn lag(mut self, lag: Lag) -> Self { + self.lag = Some(lag); + self + } + + /// Sets the bioavailability function (optional). + /// + /// If not set, defaults to 100% bioavailability for all compartments. + pub fn fa(mut self, fa: Fa) -> Self { + self.fa = Some(fa); + self + } + + /// Sets the initial state function (optional). + /// + /// If not set, defaults to zero initial state for all compartments. + pub fn init(mut self, init: Init) -> Self { + self.init = Some(init); + self + } +} + +impl + ODEBuilder +{ + /// Sets the differential equation closure (required). + /// + /// This closure defines the system of ODEs: dx/dt = f(x, p, t, ...) + /// + /// # Parameters + /// The closure receives: + /// - `x`: Current state vector + /// - `p`: Parameter vector + /// - `t`: Current time + /// - `dx`: Output vector for derivatives (mutated by the closure) + /// - `bolus`: Bolus amounts + /// - `rateiv`: IV infusion rates + /// - `cov`: Covariates + pub fn diffeq( + self, + diffeq: DiffEq, + ) -> ODEBuilder { + ODEBuilder { + diffeq: Some(diffeq), + lag: self.lag, + fa: self.fa, + init: self.init, + out: self.out, + nstates: self.nstates, + nouteqs: self.nouteqs, + _phantom: PhantomData, + } + } +} + +impl + ODEBuilder +{ + /// Sets the output equation function (required). + /// + /// This closure computes observable outputs from the state. + pub fn out(self, out: Out) -> ODEBuilder { + ODEBuilder { + diffeq: self.diffeq, + lag: self.lag, + fa: self.fa, + init: self.init, + out: Some(out), + nstates: self.nstates, + nouteqs: self.nouteqs, + _phantom: PhantomData, + } + } +} + +impl ODEBuilder { + /// Sets the number of state variables (compartments) (required). + pub fn nstates( + self, + nstates: usize, + ) -> ODEBuilder { + ODEBuilder { + diffeq: self.diffeq, + lag: self.lag, + fa: self.fa, + init: self.init, + out: self.out, + nstates: Some(nstates), + nouteqs: self.nouteqs, + _phantom: PhantomData, + } + } +} + +impl ODEBuilder { + /// Sets the number of output equations (required). + pub fn nouteqs( + self, + nouteqs: usize, + ) -> ODEBuilder { + ODEBuilder { + diffeq: self.diffeq, + lag: self.lag, + fa: self.fa, + init: self.init, + out: self.out, + nstates: self.nstates, + nouteqs: Some(nouteqs), + _phantom: PhantomData, + } + } +} + +impl ODEBuilder { + /// Sets both nstates and nouteqs from a [`Neqs`] struct or tuple (required). + pub fn neqs( + self, + neqs: impl Into, + ) -> ODEBuilder { + let neqs = neqs.into(); + ODEBuilder { + diffeq: self.diffeq, + lag: self.lag, + fa: self.fa, + init: self.init, + out: self.out, + nstates: Some(neqs.nstates), + nouteqs: Some(neqs.nouteqs), + _phantom: PhantomData, + } + } +} + +/// Default lag function: no lag for any compartment +fn default_lag(_p: &V, _t: f64, _cov: &Covariates) -> HashMap { + HashMap::new() +} + +/// Default fa function: 100% bioavailability for all compartments +fn default_fa(_p: &V, _t: f64, _cov: &Covariates) -> HashMap { + HashMap::new() +} + +/// Default init function: zero initial state +fn default_init(_p: &V, _t: f64, _cov: &Covariates, _x: &mut V) { + // State is already zero-initialized +} + +impl ODEBuilder { + /// Builds the [`ODE`] equation. + /// + /// This method is only available when all required fields have been set: + /// - `diffeq` + /// - `out` + /// - `nstates` + /// - `nouteqs` + /// + /// Optional fields use defaults if not set: + /// - `lag`: No lag (empty HashMap) + /// - `fa`: 100% bioavailability (empty HashMap) + /// - `init`: Zero initial state + pub fn build(self) -> ODE { + ODE { + diffeq: self.diffeq.unwrap(), + lag: self.lag.unwrap_or(default_lag), + fa: self.fa.unwrap_or(default_fa), + init: self.init.unwrap_or(default_init), + out: self.out.unwrap(), + neqs: Neqs::new(self.nstates.unwrap(), self.nouteqs.unwrap()), } } } @@ -127,12 +431,12 @@ impl EquationPriv for ODE { } #[inline(always)] fn get_nstates(&self) -> usize { - self.neqs.0 + self.neqs.nstates } #[inline(always)] fn get_nouteqs(&self) -> usize { - self.neqs.1 + self.neqs.nouteqs } #[inline(always)] fn solve( diff --git a/src/simulator/equation/sde/mod.rs b/src/simulator/equation/sde/mod.rs index 63d2cf82..ec81fa47 100644 --- a/src/simulator/equation/sde/mod.rs +++ b/src/simulator/equation/sde/mod.rs @@ -1,5 +1,8 @@ mod em; +use std::collections::HashMap; +use std::marker::PhantomData; + use diffsol::{NalgebraContext, Vector}; use nalgebra::DVector; use ndarray::{concatenate, Array2, Axis}; @@ -14,7 +17,7 @@ use crate::{ error_model::ErrorModels, prelude::simulator::Prediction, simulator::{Diffusion, Drift, Fa, Init, Lag, Neqs, Out, V}, - Subject, + Missing, Provided, Subject, }; use diffsol::VectorCommon; @@ -93,6 +96,8 @@ pub struct SDE { impl SDE { /// Creates a new stochastic differential equation solver. /// + /// For a more ergonomic API, consider using [`SDEBuilder`] instead. + /// /// # Arguments /// /// * `drift` - Function defining the deterministic component of the SDE @@ -101,7 +106,7 @@ impl SDE { /// * `fa` - Function to compute bioavailability fractions /// * `init` - Function to initialize the system state /// * `out` - Function to compute output equations - /// * `neqs` - Tuple containing the number of state and output equations + /// * `neqs` - Number of states and output equations (can be a tuple or [`Neqs`]) /// * `nparticles` - Number of particles to use in the simulation /// /// # Returns @@ -115,7 +120,7 @@ impl SDE { fa: Fa, init: Init, out: Out, - neqs: Neqs, + neqs: impl Into, nparticles: usize, ) -> Self { Self { @@ -125,10 +130,378 @@ impl SDE { fa, init, out, - neqs, + neqs: neqs.into(), nparticles, } } + + /// Returns a new [`SDEBuilder`] for constructing an SDE equation. + /// + /// # Example + /// ```ignore + /// use pharmsol::prelude::*; + /// + /// // Minimal builder - only required fields + /// let sde = SDE::builder() + /// .drift(drift) + /// .diffusion(diffusion) + /// .out(out) + /// .nstates(2) + /// .nouteqs(1) + /// .nparticles(1000) + /// .build(); + /// + /// // With optional fields + /// let sde = SDE::builder() + /// .drift(drift) + /// .diffusion(diffusion) + /// .out(out) + /// .nstates(2) + /// .nouteqs(1) + /// .nparticles(1000) + /// .lag(|p, _t, _cov| lag! { 0 => p[2] }) + /// .fa(|p, _t, _cov| fa! { 0 => 0.8 }) + /// .init(|p, _t, _cov, x| { x[0] = p[3]; }) + /// .build(); + /// ``` + pub fn builder() -> SDEBuilder { + SDEBuilder::new() + } +} + +// ============================================================================= +// Type-State Builder Pattern +// ============================================================================= + +// Note: Missing and Provided marker types are defined in the parent module +// and imported via `use crate::{..., Missing, Provided, ...}` + +/// Builder for constructing [`SDE`] equations with compile-time validation. +/// +/// This builder uses the type-state pattern to ensure all required fields +/// are set before `build()` can be called. Optional fields (`lag`, `fa`, `init`) +/// have sensible defaults. +/// +/// # Required Fields (enforced at compile time) +/// - `drift`: The drift (deterministic) function +/// - `diffusion`: The diffusion (stochastic) function +/// - `out`: Output equation function +/// - `nstates`: Number of state variables +/// - `nouteqs`: Number of output equations +/// - `nparticles`: Number of particles for simulation +/// +/// # Optional Fields (with defaults) +/// - `lag`: Lag time function (defaults to no lag) +/// - `fa`: Bioavailability function (defaults to 100% bioavailability) +/// - `init`: Initial state function (defaults to zero initial state) +/// +/// # Example +/// ```ignore +/// use pharmsol::prelude::*; +/// +/// // Minimal example - only required fields +/// let sde = SDE::builder() +/// .drift(|x, p, t, dx, rateiv, cov| { /* ... */ }) +/// .diffusion(|x, p, t, dx, cov| { /* ... */ }) +/// .out(|x, p, _t, _cov, y| { y[0] = x[0] / p[1]; }) +/// .nstates(1) +/// .nouteqs(1) +/// .nparticles(1000) +/// .build(); +/// ``` +pub struct SDEBuilder< + DriftState, + DiffusionState, + OutState, + NStatesState, + NOuteqsState, + NParticlesState, +> { + drift: Option, + diffusion: Option, + lag: Option, + fa: Option, + init: Option, + out: Option, + nstates: Option, + nouteqs: Option, + nparticles: Option, + _phantom: PhantomData<( + DriftState, + DiffusionState, + OutState, + NStatesState, + NOuteqsState, + NParticlesState, + )>, +} + +impl SDEBuilder { + /// Creates a new SDEBuilder with all required fields unset. + pub fn new() -> Self { + Self { + drift: None, + diffusion: None, + lag: None, + fa: None, + init: None, + out: None, + nstates: None, + nouteqs: None, + nparticles: None, + _phantom: PhantomData, + } + } +} + +impl Default for SDEBuilder { + fn default() -> Self { + Self::new() + } +} + +impl + SDEBuilder +{ + /// Sets the lag time function (optional). + /// + /// If not set, defaults to no lag for any compartment. + pub fn lag(mut self, lag: Lag) -> Self { + self.lag = Some(lag); + self + } + + /// Sets the bioavailability function (optional). + /// + /// If not set, defaults to 100% bioavailability for all compartments. + pub fn fa(mut self, fa: Fa) -> Self { + self.fa = Some(fa); + self + } + + /// Sets the initial state function (optional). + /// + /// If not set, defaults to zero initial state for all compartments. + pub fn init(mut self, init: Init) -> Self { + self.init = Some(init); + self + } +} + +impl + SDEBuilder +{ + /// Sets the drift (deterministic) function (required). + /// + /// The drift function defines the deterministic component of the SDE: dx/dt = f(x, p, t, ...) + pub fn drift( + self, + drift: Drift, + ) -> SDEBuilder + { + SDEBuilder { + drift: Some(drift), + diffusion: self.diffusion, + lag: self.lag, + fa: self.fa, + init: self.init, + out: self.out, + nstates: self.nstates, + nouteqs: self.nouteqs, + nparticles: self.nparticles, + _phantom: PhantomData, + } + } +} + +impl + SDEBuilder +{ + /// Sets the diffusion (stochastic) function (required). + /// + /// The diffusion function defines the stochastic component of the SDE. + pub fn diffusion( + self, + diffusion: Diffusion, + ) -> SDEBuilder + { + SDEBuilder { + drift: self.drift, + diffusion: Some(diffusion), + lag: self.lag, + fa: self.fa, + init: self.init, + out: self.out, + nstates: self.nstates, + nouteqs: self.nouteqs, + nparticles: self.nparticles, + _phantom: PhantomData, + } + } +} + +impl + SDEBuilder +{ + /// Sets the output equation function (required). + pub fn out( + self, + out: Out, + ) -> SDEBuilder + { + SDEBuilder { + drift: self.drift, + diffusion: self.diffusion, + lag: self.lag, + fa: self.fa, + init: self.init, + out: Some(out), + nstates: self.nstates, + nouteqs: self.nouteqs, + nparticles: self.nparticles, + _phantom: PhantomData, + } + } +} + +impl + SDEBuilder +{ + /// Sets the number of state variables (compartments) (required). + pub fn nstates( + self, + nstates: usize, + ) -> SDEBuilder + { + SDEBuilder { + drift: self.drift, + diffusion: self.diffusion, + lag: self.lag, + fa: self.fa, + init: self.init, + out: self.out, + nstates: Some(nstates), + nouteqs: self.nouteqs, + nparticles: self.nparticles, + _phantom: PhantomData, + } + } +} + +impl + SDEBuilder +{ + /// Sets the number of output equations (required). + pub fn nouteqs( + self, + nouteqs: usize, + ) -> SDEBuilder + { + SDEBuilder { + drift: self.drift, + diffusion: self.diffusion, + lag: self.lag, + fa: self.fa, + init: self.init, + out: self.out, + nstates: self.nstates, + nouteqs: Some(nouteqs), + nparticles: self.nparticles, + _phantom: PhantomData, + } + } +} + +impl + SDEBuilder +{ + /// Sets both nstates and nouteqs from a [`Neqs`] struct or tuple (required). + pub fn neqs( + self, + neqs: impl Into, + ) -> SDEBuilder { + let neqs = neqs.into(); + SDEBuilder { + drift: self.drift, + diffusion: self.diffusion, + lag: self.lag, + fa: self.fa, + init: self.init, + out: self.out, + nstates: Some(neqs.nstates), + nouteqs: Some(neqs.nouteqs), + nparticles: self.nparticles, + _phantom: PhantomData, + } + } +} + +impl + SDEBuilder +{ + /// Sets the number of particles for simulation (required). + pub fn nparticles( + self, + nparticles: usize, + ) -> SDEBuilder + { + SDEBuilder { + drift: self.drift, + diffusion: self.diffusion, + lag: self.lag, + fa: self.fa, + init: self.init, + out: self.out, + nstates: self.nstates, + nouteqs: self.nouteqs, + nparticles: Some(nparticles), + _phantom: PhantomData, + } + } +} + +/// Default lag function: no lag for any compartment +fn default_lag(_p: &V, _t: f64, _cov: &Covariates) -> HashMap { + HashMap::new() +} + +/// Default fa function: 100% bioavailability for all compartments +fn default_fa(_p: &V, _t: f64, _cov: &Covariates) -> HashMap { + HashMap::new() +} + +/// Default init function: zero initial state +fn default_init(_p: &V, _t: f64, _cov: &Covariates, _x: &mut V) { + // State is already zero-initialized +} + +impl SDEBuilder { + /// Builds the [`SDE`] equation. + /// + /// This method is only available when all required fields have been set: + /// - `drift` + /// - `diffusion` + /// - `out` + /// - `nstates` + /// - `nouteqs` + /// - `nparticles` + /// + /// Optional fields use defaults if not set: + /// - `lag`: No lag (empty HashMap) + /// - `fa`: 100% bioavailability (empty HashMap) + /// - `init`: Zero initial state + pub fn build(self) -> SDE { + SDE { + drift: self.drift.unwrap(), + diffusion: self.diffusion.unwrap(), + lag: self.lag.unwrap_or(default_lag), + fa: self.fa.unwrap_or(default_fa), + init: self.init.unwrap_or(default_init), + out: self.out.unwrap(), + neqs: Neqs::new(self.nstates.unwrap(), self.nouteqs.unwrap()), + nparticles: self.nparticles.unwrap(), + } + } } /// State trait implementation for particle-based SDE simulation. @@ -237,12 +610,12 @@ impl EquationPriv for SDE { #[inline(always)] fn get_nstates(&self) -> usize { - self.neqs.0 + self.neqs.nstates } #[inline(always)] fn get_nouteqs(&self) -> usize { - self.neqs.1 + self.neqs.nouteqs } #[inline(always)] fn solve( diff --git a/src/simulator/mod.rs b/src/simulator/mod.rs index 39d2baab..d3f60a4f 100644 --- a/src/simulator/mod.rs +++ b/src/simulator/mod.rs @@ -198,16 +198,53 @@ pub type Fa = fn(&V, T, &Covariates) -> HashMap; /// The number of states and output equations of the model. /// -/// # Components -/// - The first element is the number of states -/// - The second element is the number of output equations -/// -/// This is used to initialize the state vector and the output vector. +/// This struct specifies the dimensions of the model: +/// - `nstates`: Number of state variables (compartments) +/// - `nouteqs`: Number of output equations (observable outputs) /// /// # Example /// ```ignore -/// let neqs = (2, 1); +/// use pharmsol::Neqs; +/// // A two-compartment model with one output equation +/// let neqs = Neqs::new(2, 1); +/// // Or using the tuple shorthand for backward compatibility +/// let neqs: Neqs = (2, 1).into(); /// ``` -/// This means that the system of equations has 2 states and there is only 1 output equation. -/// -pub type Neqs = (usize, usize); +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct Neqs { + /// Number of state variables (compartments) + pub nstates: usize, + /// Number of output equations (observable outputs) + pub nouteqs: usize, +} + +impl Neqs { + /// Create a new Neqs with the specified dimensions. + /// + /// # Parameters + /// - `nstates`: Number of state variables (compartments) + /// - `nouteqs`: Number of output equations (observable outputs) + #[inline] + pub const fn new(nstates: usize, nouteqs: usize) -> Self { + Self { nstates, nouteqs } + } +} + +/// Allow construction from tuple for backward compatibility +impl From<(usize, usize)> for Neqs { + #[inline] + fn from(tuple: (usize, usize)) -> Self { + Self { + nstates: tuple.0, + nouteqs: tuple.1, + } + } +} + +/// Allow conversion to tuple for backward compatibility +impl From for (usize, usize) { + #[inline] + fn from(neqs: Neqs) -> Self { + (neqs.nstates, neqs.nouteqs) + } +} diff --git a/tests/api_comparison.rs b/tests/api_comparison.rs new file mode 100644 index 00000000..3106329f --- /dev/null +++ b/tests/api_comparison.rs @@ -0,0 +1,580 @@ +//! Tests to verify that the old tuple-based API and new builder API produce identical results. +//! +//! This ensures backward compatibility while validating the new type-state builder pattern. +//! The new builder API enforces required fields at compile time and provides sensible defaults +//! for optional fields (lag, fa, init). + +use pharmsol::prelude::models::{one_compartment, one_compartment_with_absorption}; +use pharmsol::*; + +const TOLERANCE: f64 = 1e-12; + +/// Helper to assert that predictions from two models are identical +fn assert_predictions_match( + label: &str, + model1: &E1, + model2: &E2, + subject: &Subject, + params: &[f64], +) { + let params_vec: Vec = params.to_vec(); + + let pred1 = model1 + .estimate_predictions(subject, ¶ms_vec) + .expect("model1 predictions should succeed"); + let pred2 = model2 + .estimate_predictions(subject, ¶ms_vec) + .expect("model2 predictions should succeed"); + + let preds1 = pred1.get_predictions(); + let preds2 = pred2.get_predictions(); + + assert_eq!( + preds1.len(), + preds2.len(), + "{}: prediction count mismatch ({} vs {})", + label, + preds1.len(), + preds2.len() + ); + + for (idx, (p1, p2)) in preds1.iter().zip(preds2.iter()).enumerate() { + let diff = (p1.prediction() - p2.prediction()).abs(); + assert!( + diff < TOLERANCE, + "{}: prediction {} differs (old={:.15}, new={:.15}, diff={:.2e})", + label, + idx, + p1.prediction(), + p2.prediction(), + diff + ); + } +} + +// ============================================================================= +// ODE API COMPARISON TESTS +// ============================================================================= + +#[test] +fn ode_builder_matches_tuple_api_one_compartment() { + let subject = Subject::builder("ode_comparison") + .bolus(0.0, 100.0, 0) + .observation(1.0, 0.0, 0) + .observation(2.0, 0.0, 0) + .observation(4.0, 0.0, 0) + .observation(8.0, 0.0, 0) + .build(); + + // Old API with tuple + let ode_old = equation::ODE::new( + |x, p, _t, dx, b, rateiv, _cov| { + fetch_params!(p, ke, _v); + dx[0] = -ke * x[0] + b[0] + rateiv[0]; + }, + |_p, _t, _cov| lag! {}, + |_p, _t, _cov| fa! {}, + |_p, _t, _cov, _x| {}, + |x, p, _t, _cov, y| { + fetch_params!(p, _ke, v); + y[0] = x[0] / v; + }, + (1, 1), + ); + + // New builder API - minimal version (only required fields) + let ode_new = equation::ODE::builder() + .diffeq(|x, p, _t, dx, b, rateiv, _cov| { + fetch_params!(p, ke, _v); + dx[0] = -ke * x[0] + b[0] + rateiv[0]; + }) + .out(|x, p, _t, _cov, y| { + fetch_params!(p, _ke, v); + y[0] = x[0] / v; + }) + .nstates(1) + .nouteqs(1) + .build(); + + assert_predictions_match( + "ode_one_compartment", + &ode_old, + &ode_new, + &subject, + &[0.1, 50.0], + ); +} + +#[test] +fn ode_builder_matches_tuple_api_two_compartment() { + let subject = Subject::builder("ode_two_comp") + .bolus(0.0, 100.0, 0) + .observation(0.5, 0.0, 0) + .observation(1.0, 0.0, 0) + .observation(2.0, 0.0, 0) + .observation(4.0, 0.0, 0) + .observation(8.0, 0.0, 0) + .build(); + + // Old API with tuple + let ode_old = equation::ODE::new( + |x, p, _t, dx, b, _rateiv, _cov| { + fetch_params!(p, ka, ke, _v); + dx[0] = -ka * x[0] + b[0]; + dx[1] = ka * x[0] - ke * x[1]; + }, + |_p, _t, _cov| lag! {}, + |_p, _t, _cov| fa! {}, + |_p, _t, _cov, _x| {}, + |x, p, _t, _cov, y| { + fetch_params!(p, _ka, _ke, v); + y[0] = x[1] / v; + }, + (2, 1), + ); + + // New builder API - minimal version + let ode_new = equation::ODE::builder() + .diffeq(|x, p, _t, dx, b, _rateiv, _cov| { + fetch_params!(p, ka, ke, _v); + dx[0] = -ka * x[0] + b[0]; + dx[1] = ka * x[0] - ke * x[1]; + }) + .out(|x, p, _t, _cov, y| { + fetch_params!(p, _ka, _ke, v); + y[0] = x[1] / v; + }) + .nstates(2) + .nouteqs(1) + .build(); + + assert_predictions_match( + "ode_two_compartment", + &ode_old, + &ode_new, + &subject, + &[1.0, 0.1, 50.0], + ); +} + +#[test] +fn ode_builder_with_neqs_struct_matches_tuple() { + let subject = Subject::builder("ode_neqs_struct") + .bolus(0.0, 100.0, 0) + .observation(1.0, 0.0, 0) + .observation(2.0, 0.0, 0) + .observation(4.0, 0.0, 0) + .build(); + + // Old API with tuple + let ode_old = equation::ODE::new( + |x, p, _t, dx, b, _rateiv, _cov| { + fetch_params!(p, ke, _v); + dx[0] = -ke * x[0] + b[0]; + }, + |_p, _t, _cov| lag! {}, + |_p, _t, _cov| fa! {}, + |_p, _t, _cov, _x| {}, + |x, p, _t, _cov, y| { + fetch_params!(p, _ke, v); + y[0] = x[0] / v; + }, + (1, 1), + ); + + // New builder API with Neqs struct + let ode_new = equation::ODE::builder() + .diffeq(|x, p, _t, dx, b, _rateiv, _cov| { + fetch_params!(p, ke, _v); + dx[0] = -ke * x[0] + b[0]; + }) + .out(|x, p, _t, _cov, y| { + fetch_params!(p, _ke, v); + y[0] = x[0] / v; + }) + .neqs(Neqs::new(1, 1)) + .build(); + + assert_predictions_match( + "ode_neqs_struct", + &ode_old, + &ode_new, + &subject, + &[0.1, 50.0], + ); +} + +#[test] +fn ode_new_accepts_neqs_struct() { + let subject = Subject::builder("ode_new_neqs") + .bolus(0.0, 100.0, 0) + .observation(1.0, 0.0, 0) + .observation(2.0, 0.0, 0) + .build(); + + // Old API with tuple + let ode_tuple = equation::ODE::new( + |x, p, _t, dx, b, _rateiv, _cov| { + fetch_params!(p, ke, _v); + dx[0] = -ke * x[0] + b[0]; + }, + |_p, _t, _cov| lag! {}, + |_p, _t, _cov| fa! {}, + |_p, _t, _cov, _x| {}, + |x, p, _t, _cov, y| { + fetch_params!(p, _ke, v); + y[0] = x[0] / v; + }, + (1, 1), + ); + + // Old API with Neqs struct (new feature!) + let ode_neqs = equation::ODE::new( + |x, p, _t, dx, b, _rateiv, _cov| { + fetch_params!(p, ke, _v); + dx[0] = -ke * x[0] + b[0]; + }, + |_p, _t, _cov| lag! {}, + |_p, _t, _cov| fa! {}, + |_p, _t, _cov, _x| {}, + |x, p, _t, _cov, y| { + fetch_params!(p, _ke, v); + y[0] = x[0] / v; + }, + Neqs::new(1, 1), + ); + + assert_predictions_match( + "ode_new_with_neqs", + &ode_tuple, + &ode_neqs, + &subject, + &[0.1, 50.0], + ); +} + +// ============================================================================= +// ANALYTICAL API COMPARISON TESTS +// ============================================================================= + +#[test] +fn analytical_builder_matches_tuple_api() { + let subject = Subject::builder("analytical_comparison") + .bolus(0.0, 100.0, 0) + .observation(1.0, 0.0, 0) + .observation(2.0, 0.0, 0) + .observation(4.0, 0.0, 0) + .observation(8.0, 0.0, 0) + .build(); + + // Old API with tuple + let analytical_old = equation::Analytical::new( + one_compartment, + |_p, _t, _cov| {}, + |_p, _t, _cov| lag! {}, + |_p, _t, _cov| fa! {}, + |_p, _t, _cov, _x| {}, + |x, p, _t, _cov, y| { + fetch_params!(p, _ke, v); + y[0] = x[0] / v; + }, + (1, 1), + ); + + // New builder API - minimal version + let analytical_new = equation::Analytical::builder() + .eq(one_compartment) + .seq_eq(|_p, _t, _cov| {}) + .out(|x, p, _t, _cov, y| { + fetch_params!(p, _ke, v); + y[0] = x[0] / v; + }) + .nstates(1) + .nouteqs(1) + .build(); + + assert_predictions_match( + "analytical_one_compartment", + &analytical_old, + &analytical_new, + &subject, + &[0.1, 50.0], + ); +} + +#[test] +fn analytical_builder_matches_tuple_api_with_absorption() { + let subject = Subject::builder("analytical_absorption") + .bolus(0.0, 100.0, 0) + .observation(0.5, 0.0, 0) + .observation(1.0, 0.0, 0) + .observation(2.0, 0.0, 0) + .observation(4.0, 0.0, 0) + .observation(8.0, 0.0, 0) + .build(); + + // Old API with tuple + let analytical_old = equation::Analytical::new( + one_compartment_with_absorption, + |_p, _t, _cov| {}, + |_p, _t, _cov| lag! {}, + |_p, _t, _cov| fa! {}, + |_p, _t, _cov, _x| {}, + |x, p, _t, _cov, y| { + fetch_params!(p, _ka, _ke, v); + y[0] = x[1] / v; + }, + (2, 1), + ); + + // New builder API - minimal version + let analytical_new = equation::Analytical::builder() + .eq(one_compartment_with_absorption) + .seq_eq(|_p, _t, _cov| {}) + .out(|x, p, _t, _cov, y| { + fetch_params!(p, _ka, _ke, v); + y[0] = x[1] / v; + }) + .nstates(2) + .nouteqs(1) + .build(); + + assert_predictions_match( + "analytical_with_absorption", + &analytical_old, + &analytical_new, + &subject, + &[1.0, 0.1, 50.0], + ); +} + +#[test] +fn analytical_new_accepts_neqs_struct() { + let subject = Subject::builder("analytical_neqs") + .bolus(0.0, 100.0, 0) + .observation(1.0, 0.0, 0) + .observation(2.0, 0.0, 0) + .build(); + + // Old API with tuple + let analytical_tuple = equation::Analytical::new( + one_compartment, + |_p, _t, _cov| {}, + |_p, _t, _cov| lag! {}, + |_p, _t, _cov| fa! {}, + |_p, _t, _cov, _x| {}, + |x, p, _t, _cov, y| { + fetch_params!(p, _ke, v); + y[0] = x[0] / v; + }, + (1, 1), + ); + + // Old API with Neqs struct + let analytical_neqs = equation::Analytical::new( + one_compartment, + |_p, _t, _cov| {}, + |_p, _t, _cov| lag! {}, + |_p, _t, _cov| fa! {}, + |_p, _t, _cov, _x| {}, + |x, p, _t, _cov, y| { + fetch_params!(p, _ke, v); + y[0] = x[0] / v; + }, + Neqs::new(1, 1), + ); + + assert_predictions_match( + "analytical_new_with_neqs", + &analytical_tuple, + &analytical_neqs, + &subject, + &[0.1, 50.0], + ); +} + +// ============================================================================= +// TYPE-STATE BUILDER - MINIMAL API TESTS +// ============================================================================= + +#[test] +fn ode_builder_minimal_matches_full_explicit() { + let subject = Subject::builder("minimal_vs_full") + .bolus(0.0, 100.0, 0) + .observation(1.0, 0.0, 0) + .observation(2.0, 0.0, 0) + .build(); + + // Minimal builder (only required fields) + let ode_minimal = equation::ODE::builder() + .diffeq(|x, p, _t, dx, b, _rateiv, _cov| { + fetch_params!(p, ke, _v); + dx[0] = -ke * x[0] + b[0]; + }) + .out(|x, p, _t, _cov, y| { + fetch_params!(p, _ke, v); + y[0] = x[0] / v; + }) + .nstates(1) + .nouteqs(1) + .build(); + + // Full builder (all fields explicit, using defaults) + let ode_full = equation::ODE::builder() + .diffeq(|x, p, _t, dx, b, _rateiv, _cov| { + fetch_params!(p, ke, _v); + dx[0] = -ke * x[0] + b[0]; + }) + .out(|x, p, _t, _cov, y| { + fetch_params!(p, _ke, v); + y[0] = x[0] / v; + }) + .nstates(1) + .nouteqs(1) + .build(); + + assert_predictions_match( + "minimal_vs_full", + &ode_minimal, + &ode_full, + &subject, + &[0.1, 50.0], + ); +} + +#[test] +fn analytical_builder_minimal_matches_full_explicit() { + let subject = Subject::builder("analytical_minimal_vs_full") + .bolus(0.0, 100.0, 0) + .observation(1.0, 0.0, 0) + .observation(2.0, 0.0, 0) + .build(); + + // Minimal builder (only required fields) + let analytical_minimal = equation::Analytical::builder() + .eq(one_compartment) + .seq_eq(|_p, _t, _cov| {}) + .out(|x, p, _t, _cov, y| { + fetch_params!(p, _ke, v); + y[0] = x[0] / v; + }) + .nstates(1) + .nouteqs(1) + .build(); + + // Full builder (all fields explicit) + let analytical_full = equation::Analytical::builder() + .eq(one_compartment) + .seq_eq(|_p, _t, _cov| {}) + .out(|x, p, _t, _cov, y| { + fetch_params!(p, _ke, v); + y[0] = x[0] / v; + }) + .nstates(1) + .nouteqs(1) + .build(); + + assert_predictions_match( + "analytical_minimal_vs_full", + &analytical_minimal, + &analytical_full, + &subject, + &[0.1, 50.0], + ); +} + +// ============================================================================= +// NEQS STRUCT TESTS +// ============================================================================= + +#[test] +fn neqs_struct_conversion_from_tuple() { + let neqs: Neqs = (2, 3).into(); + assert_eq!(neqs.nstates, 2); + assert_eq!(neqs.nouteqs, 3); +} + +#[test] +fn neqs_struct_conversion_to_tuple() { + let neqs = Neqs::new(4, 5); + let tuple: (usize, usize) = neqs.into(); + assert_eq!(tuple, (4, 5)); +} + +#[test] +fn neqs_struct_new() { + let neqs = Neqs::new(1, 2); + assert_eq!(neqs.nstates, 1); + assert_eq!(neqs.nouteqs, 2); +} + +// ============================================================================= +// LIKELIHOOD COMPARISON TESTS +// ============================================================================= + +#[test] +fn likelihood_matches_between_apis() { + let subject = Subject::builder("likelihood_comparison") + .bolus(0.0, 100.0, 0) + .observation(1.0, 1.8, 0) + .observation(2.0, 1.6, 0) + .observation(4.0, 1.3, 0) + .observation(8.0, 0.8, 0) + .build(); + + let error_models = ErrorModels::new() + .add( + 0, + ErrorModel::additive(ErrorPoly::new(0.0, 0.1, 0.0, 0.0), 0.0), + ) + .unwrap(); + + let params = vec![0.1, 50.0]; + + // ODE: Old API + let ode_old = equation::ODE::new( + |x, p, _t, dx, b, _rateiv, _cov| { + fetch_params!(p, ke, _v); + dx[0] = -ke * x[0] + b[0]; + }, + |_p, _t, _cov| lag! {}, + |_p, _t, _cov| fa! {}, + |_p, _t, _cov, _x| {}, + |x, p, _t, _cov, y| { + fetch_params!(p, _ke, v); + y[0] = x[0] / v; + }, + (1, 1), + ); + + // ODE: New API - minimal version + let ode_new = equation::ODE::builder() + .diffeq(|x, p, _t, dx, b, _rateiv, _cov| { + fetch_params!(p, ke, _v); + dx[0] = -ke * x[0] + b[0]; + }) + .out(|x, p, _t, _cov, y| { + fetch_params!(p, _ke, v); + y[0] = x[0] / v; + }) + .nstates(1) + .nouteqs(1) + .build(); + + let ll_old = ode_old + .estimate_likelihood(&subject, ¶ms, &error_models, false) + .expect("old likelihood"); + + let ll_new = ode_new + .estimate_likelihood(&subject, ¶ms, &error_models, false) + .expect("new likelihood"); + + let diff = (ll_old - ll_new).abs(); + assert!( + diff < TOLERANCE, + "Likelihoods should match: old={:.15}, new={:.15}, diff={:.2e}", + ll_old, + ll_new, + diff + ); +} diff --git a/tests/numerical_stability.rs b/tests/numerical_stability.rs index 069e146e..876690a0 100644 --- a/tests/numerical_stability.rs +++ b/tests/numerical_stability.rs @@ -89,33 +89,29 @@ fn infusion_subject() -> Subject { } fn infusion_models() -> (equation::Analytical, equation::ODE) { - let analytical = equation::Analytical::new( - one_compartment, - |_p, _t, _cov| {}, - |_p, _t, _cov| lag! {}, - |_p, _t, _cov| fa! {}, - |_p, _t, _cov, _x| {}, - |x, p, _t, _cov, y| { + let analytical = equation::Analytical::builder() + .eq(one_compartment) + .seq_eq(|_p, _t, _cov| {}) + .out(|x, p, _t, _cov, y| { fetch_params!(p, _ke, v); y[0] = x[0] / v; - }, - (1, 1), - ); + }) + .nstates(1) + .nouteqs(1) + .build(); - let ode = equation::ODE::new( - |x, p, _t, dx, b, rateiv, _cov| { + let ode = equation::ODE::builder() + .diffeq(|x, p, _t, dx, b, rateiv, _cov| { fetch_params!(p, ke, _v); dx[0] = -ke * x[0] + rateiv[0] + b[0]; - }, - |_p, _t, _cov| lag! {}, - |_p, _t, _cov| fa! {}, - |_p, _t, _cov, _x| {}, - |x, p, _t, _cov, y| { + }) + .out(|x, p, _t, _cov, y| { fetch_params!(p, _ke, v); y[0] = x[0] / v; - }, - (1, 1), - ); + }) + .nstates(1) + .nouteqs(1) + .build(); (analytical, ode) } @@ -137,34 +133,30 @@ fn absorption_subject() -> Subject { } fn absorption_models() -> (equation::Analytical, equation::ODE) { - let analytical = equation::Analytical::new( - one_compartment_with_absorption, - |_p, _t, _cov| {}, - |_p, _t, _cov| lag! {}, - |_p, _t, _cov| fa! {}, - |_p, _t, _cov, _x| {}, - |x, p, _t, _cov, y| { + let analytical = equation::Analytical::builder() + .eq(one_compartment_with_absorption) + .seq_eq(|_p, _t, _cov| {}) + .out(|x, p, _t, _cov, y| { fetch_params!(p, _ka, _ke, v); y[0] = x[1] / v; - }, - (2, 1), - ); + }) + .nstates(2) + .nouteqs(1) + .build(); - let ode = equation::ODE::new( - |x, p, _t, dx, b, rateiv, _cov| { + let ode = equation::ODE::builder() + .diffeq(|x, p, _t, dx, b, rateiv, _cov| { fetch_params!(p, ka, ke, _v); dx[0] = -ka * x[0] + b[0]; dx[1] = ka * x[0] - ke * x[1] + rateiv[0] + b[1]; - }, - |_p, _t, _cov| lag! {}, - |_p, _t, _cov| fa! {}, - |_p, _t, _cov, _x| {}, - |x, p, _t, _cov, y| { + }) + .out(|x, p, _t, _cov, y| { fetch_params!(p, _ka, _ke, v); y[0] = x[1] / v; - }, - (2, 1), - ); + }) + .nstates(2) + .nouteqs(1) + .build(); (analytical, ode) } @@ -184,34 +176,30 @@ fn two_compartment_subject() -> Subject { } fn two_compartment_models() -> (equation::Analytical, equation::ODE) { - let analytical = equation::Analytical::new( - two_compartments, - |_p, _t, _cov| {}, - |_p, _t, _cov| lag! {}, - |_p, _t, _cov| fa! {}, - |_p, _t, _cov, _x| {}, - |x, p, _t, _cov, y| { + let analytical = equation::Analytical::builder() + .eq(two_compartments) + .seq_eq(|_p, _t, _cov| {}) + .out(|x, p, _t, _cov, y| { fetch_params!(p, _ke, _kcp, _kpc, v); y[0] = x[0] / v; - }, - (2, 1), - ); + }) + .nstates(2) + .nouteqs(1) + .build(); - let ode = equation::ODE::new( - |x, p, _t, dx, b, rateiv, _cov| { + let ode = equation::ODE::builder() + .diffeq(|x, p, _t, dx, b, rateiv, _cov| { fetch_params!(p, ke, kcp, kpc, _v); dx[0] = rateiv[0] - ke * x[0] - kcp * x[0] + kpc * x[1] + b[0]; dx[1] = kcp * x[0] - kpc * x[1] + b[1]; - }, - |_p, _t, _cov| lag! {}, - |_p, _t, _cov| fa! {}, - |_p, _t, _cov, _x| {}, - |x, p, _t, _cov, y| { + }) + .out(|x, p, _t, _cov, y| { fetch_params!(p, _ke, _kcp, _kpc, v); y[0] = x[0] / v; - }, - (2, 1), - ); + }) + .nstates(2) + .nouteqs(1) + .build(); (analytical, ode) }