diff --git a/.DS_Store b/.DS_Store new file mode 100644 index 00000000..0e12e990 Binary files /dev/null and b/.DS_Store differ diff --git a/Cargo.toml b/Cargo.toml index 1919f279..7778d001 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "pharmsol" -version = "0.7.7" +version = "0.7.6" edition = "2021" authors = ["Julián D. Otálvaro ", "Markus Hovd"] description = "Rust library for solving analytic and ode-defined pharmacometric models." diff --git a/examples/new_gd.rs b/examples/new_gd.rs new file mode 100644 index 00000000..a560a246 --- /dev/null +++ b/examples/new_gd.rs @@ -0,0 +1,60 @@ +use csv::ReaderBuilder; +use pharmsol::*; +use prelude::{data::read_pmetrics, simulator::Prediction}; + +fn main() { + // path to theta + let path = "../PMcore/outputs/theta.csv"; + // read theta into an Array2 + let mut rdr = ReaderBuilder::new().from_path(path).unwrap(); + + let mut spps = vec![]; + for result in rdr.records() { + let record = result.unwrap(); + let mut row = vec![]; + for field in record.iter() { + row.push(field.parse::().unwrap()); + } + spps.push(row); + } + + let sde = equation::SDE::new( + |x, p, _t, dx, _rateiv, _cov| { + // automatically defined + fetch_params!(p, ke0, _ske); + // let ke0 = 1.2; + dx[1] = -x[1] + ke0; + let ke = x[1]; + // user defined + dx[0] = -ke * x[0]; + }, + |p, d| { + fetch_params!(p, _ke0, ske); + d[1] = ske; + }, + |_p| lag! {}, + |_p| fa! {}, + |p, _t, _cov, x| { + fetch_params!(p, ke0, ske); + x[1] = ke0; + }, + |x, p, _t, _cov, y| { + fetch_params!(p, _ke0, _ske); + y[0] = x[0] / 50.0; + }, + (2, 1), + 1, + ); + + let data = read_pmetrics("../PMcore/examples/iov/test.csv").unwrap(); + + for (i, spp) in spps.iter().enumerate() { + for (j, subject) in data.get_subjects().iter().enumerate() { + let trajectories: ndarray::Array2 = + sde.estimate_predictions(&subject, &spp); + let trajectory = trajectories.row(0); + println!("{}, {}", i, j); + dbg!(trajectory); + } + } +} diff --git a/examples/sde.rs b/examples/sde.rs index f4b4eba2..b1463b22 100644 --- a/examples/sde.rs +++ b/examples/sde.rs @@ -165,7 +165,7 @@ fn main() { let ode = three_c_ode(); let sde = three_c_sde(); - let data = read_pmetrics("../PMcore/examples/w_vanco_sde/test.csv").unwrap(); + let data = read_pmetrics("../PMcore/examples/vanco_sde/data.csv").unwrap(); let subject = data.get_subject("51").unwrap(); let ode_predictions = ode.estimate_predictions(&subject, &spp_ode); diff --git a/examples/sde_paper/main.rs b/examples/sde_paper/main.rs new file mode 100644 index 00000000..34488fd1 --- /dev/null +++ b/examples/sde_paper/main.rs @@ -0,0 +1,73 @@ +use pharmsol::equation; +use pmcore::prelude::*; +use rand::rngs::StdRng; +use rand::SeedableRng; +use rand::weighted::WeightedIndex; +use rand_distr::(Distribution, Normal); + +fn model() -> equation::SDE { + let sde = equation::SDE::new( + drift: + diffeq: |x, p, _t, dx, rateiv, _cov| { + // fetch_cov!(cov, t, wt); + fetch_params!(p, ke, _v); + dx[0] = -ke * x[0] + rateiv[0]; + }, + lag: |_p| lag! {}, + fa: |_p| fa! {}, + init: |_p, _t, _cov, _x| {}, + out: |x, p, _t, _cov, y| { + fetch_params!(p, _ke, v); + y[0] = x[0] / v; + }, + neqs: (1, 1), + ); + +} + +let fn sample_k0(rng: &mut StdRng, n1: Normal, ) -> f64 { + +} + + +fn sample_v() + + + + + +const N_SAMPLES: usize = 100; + +fn main() { + let m1: f64 = 0.5; + let s1: f64 = 0.05; + let m2: f64 = 1.5; + let s2: f64 = 0.15; + + let n1 = Normal::new(m1,s1).unwrap(); + let n2 = Normal::new(m2,s2).unwrap(); + + // let weights = [0.5, 0.5]; + // let dist = WeightedIndex::new(&weights).unwrap(); + + let mut rng = seed_from_u64(state: 42); + let mut_k0_pop: Vec = Vec::new(); + let v_pop: Vec = Vec::new(); + + let n3 = Normal::new(mean: 0.0, std_dev: 1.0).unwrap(); + + for _ in 0..N_SAMPLES { + k0_pop.push(sample_k0(&mut rng, n1, n2)); + v_pop.push(sample) + } + + // plot the distributions + let trace + + + let k0_dist = 0.5 * rand_distr::Normal::new(mean: m1, std_dev: s2) + + 0.5 * rand_distr::Normal::new(mean: m2, std_dev: s2).unwrap(); + let seed: u64 = 42; + let rng: StdRng = rand::rngs::StdRng::seed_from_u64(state:: seed); + +} \ No newline at end of file diff --git a/src/simulator/equation/sde/em.rs b/src/simulator/equation/sde/em.rs index dcfa9c5f..29a24aa8 100644 --- a/src/simulator/equation/sde/em.rs +++ b/src/simulator/equation/sde/em.rs @@ -1,6 +1,7 @@ use crate::{ data::Covariates, simulator::{Diffusion, Drift}, + Infusion, }; use nalgebra::DVector; use rand::rng; @@ -13,7 +14,7 @@ pub struct EM { params: DVector, state: DVector, cov: Covariates, - rateiv: DVector, + infusions: Vec, rtol: f64, atol: f64, max_step: f64, @@ -28,7 +29,7 @@ impl EM { params: DVector, initial_state: DVector, cov: Covariates, - rateiv: DVector, + infusions: Vec, rtol: f64, atol: f64, ) -> Self { @@ -38,7 +39,7 @@ impl EM { params, state: initial_state, cov, - rateiv, + infusions, rtol, atol, max_step: 0.1, // Can be made configurable @@ -64,27 +65,35 @@ impl EM { new_dt } - fn euler_maruyama_step(&self, time: f64, dt: f64, state: &mut DVector) { + fn euler_maruyama_step(&self, time: f64, dt: f64, state: &mut DVector, sigma: f64) { let n = state.len(); + let mut rateiv = DVector::from_vec(vec![0.0, 0.0, 0.0]); + //TODO: This should be pre-calculated + for infusion in &self.infusions { + if time >= infusion.time() && time <= infusion.duration() + infusion.time() { + rateiv[infusion.input()] += infusion.amount() / infusion.duration(); + } + } let mut drift_term = DVector::zeros(n); (self.drift)( state, &self.params, time, &mut drift_term, - self.rateiv.clone(), + rateiv, &self.cov, ); let mut diffusion_term = DVector::zeros(n); (self.diffusion)(&self.params, &mut diffusion_term); - let mut rng = rng(); - let normal_dist = Normal::new(0.0, 1.0).unwrap(); + let mut _rng = rng(); + let _normal_dist = Normal::new(0.0, 1.0).unwrap(); for i in 0..n { state[i] += - drift_term[i] * dt + diffusion_term[i] * normal_dist.sample(&mut rng) * dt.sqrt(); + // drift_term[i] * dt + diffusion_term[i] * normal_dist.sample(&mut rng) * dt.sqrt(); + drift_term[i] * dt + diffusion_term[i] * sigma * dt.sqrt(); } } @@ -94,17 +103,24 @@ impl EM { let safety = 0.9; let mut times = vec![t0]; let mut solution = vec![self.state.clone()]; + let mut sigma = 0.00000001; + + let mut rng = rng(); + let normal_dist = Normal::new(0.0, 1.0).unwrap(); + // sigma = normal_dist.sample(&mut rng); // oops!!! this makes one sigma for ENTIRE time while t < tf { let mut y1 = self.state.clone(); let mut y2 = self.state.clone(); + sigma = normal_dist.sample(&mut rng); // here is correct, only affects one step of the variable time step + // Single step - self.euler_maruyama_step(t, dt, &mut y1); + self.euler_maruyama_step(t, dt, &mut y1, sigma); // Two half steps - self.euler_maruyama_step(t, dt / 2.0, &mut y2); - self.euler_maruyama_step(t + dt / 2.0, dt / 2.0, &mut y2); + self.euler_maruyama_step(t, dt / 2.0, &mut y2, sigma); + self.euler_maruyama_step(t + dt / 2.0, dt / 2.0, &mut y2, sigma); let error = self.calculate_error(&y1, &y2); diff --git a/src/simulator/equation/sde/mod.rs b/src/simulator/equation/sde/mod.rs index db438f13..9cbe6868 100644 --- a/src/simulator/equation/sde/mod.rs +++ b/src/simulator/equation/sde/mod.rs @@ -34,13 +34,6 @@ pub(crate) fn simulate_sde_event( if ti == tf { return x; } - let mut rateiv = V::from_vec(vec![0.0, 0.0, 0.0]); - //TODO: This should be pre-calculated - for infusion in infusions { - if tf >= infusion.time() && tf <= infusion.duration() + infusion.time() { - rateiv[infusion.input()] += infusion.amount() / infusion.duration(); - } - } let mut sde = em::EM::new( *drift, @@ -48,7 +41,7 @@ pub(crate) fn simulate_sde_event( DVector::from_column_slice(support_point), x, cov.clone(), - rateiv, + infusions.to_vec(), 1e-2, 1e-2, ); @@ -109,8 +102,29 @@ impl Predictions for Array2 { fn get_predictions(&self) -> Vec { //TODO: This is only returning the first particle, not the best, not the worst, THE FIRST // CHANGE THIS - let row = self.row(0).to_vec(); - row + // let row = self.row(0).to_vec(); + // row + // Make this return the mean prediction across all particles + if self.is_empty() || self.ncols() == 0 { + return Vec::new(); + } + + let mut result = Vec::with_capacity(self.ncols()); + + for col in 0..self.ncols() { + let column = self.column(col); + let mean_prediction: f64 = column + .iter() + .map(|pred: &Prediction| pred.prediction()) + .sum::() + / self.nrows() as f64; + + let mut prediction = column.first().unwrap().clone(); + prediction.set_prediction(mean_prediction); + result.push(prediction); + } + + result } } @@ -209,8 +223,16 @@ impl EquationPriv for SDE { // q = pdf.(Distributions.Normal(0, 0.5), e) if let Some(em) = error_model { let mut q: Vec = Vec::with_capacity(self.nparticles); - + // + // wmy centering_function is a running Chi^2 w/exp = support point + // let centering_function = p.pred[2]; // move this inside the iteration. + // + // pred.iter().for_each(|p| q.push(p.state[4] * p.likelihood(em))); + // + // note: Above doesn't compile b/c pred is private; but also: I'm not sure if pred is the state, and state is x ??? + // pred.iter().for_each(|p| q.push(p.likelihood(em))); + // let sum_q: f64 = q.iter().sum(); let w: Vec = q.iter().map(|qi| qi / sum_q).collect(); let i = sysresample(&w); diff --git a/src/simulator/likelihood.rs b/src/simulator/likelihood.rs index 6826f0d7..89b5ddb3 100644 --- a/src/simulator/likelihood.rs +++ b/src/simulator/likelihood.rs @@ -194,6 +194,9 @@ impl Prediction { pub fn prediction(&self) -> f64 { self.prediction } + pub fn set_prediction(&mut self, prediction: f64) { + self.prediction = prediction; + } pub fn outeq(&self) -> usize { self.outeq } diff --git a/subjects.csv b/subjects.csv new file mode 100644 index 00000000..b3135fb1 --- /dev/null +++ b/subjects.csv @@ -0,0 +1,101 @@ +ID,Ke,sKe,Vol +0, 1.1537057285734011, 0.10289541640270668, 50 +1, 1.4055612966624533, 0.09639474422747848, 50 +2, 1.1710303971859157, 0.08639424528661005, 50 +3, 1.3460198695803256, 0.08570562300129098, 50 +4, 1.1554824575854128, 0.0995401334124625, 50 +5, 1.2143588043852875, 0.09452739681289624, 50 +6, 1.2278932130799713, 0.10167991593451817, 50 +7, 1.098898765793177, 0.0979097597171234, 50 +8, 1.2676471420496378, 0.10998042907759475, 50 +9, 1.2396981758367795, 0.10112948002412282, 50 +10, 1.117977631957013, 0.09509903677831381, 50 +11, 1.1218128017472697, 0.11258246839062266, 50 +12, 1.0570770700535563, 0.10901940784275149, 50 +13, 1.3086459805837078, 0.10199380624314083, 50 +14, 1.224909824636831, 0.09176339607666309, 50 +15, 1.2364421542805297, 0.10195378693392695, 50 +16, 1.2447022980721614, 0.09816450833528423, 50 +17, 1.0066430977249738, 0.0937567623723915, 50 +18, 1.0248748122975728, 0.09697434986643566, 50 +19, 1.1738535522369364, 0.09860182726274573, 50 +20, 1.2876552425478005, 0.10579657258010625, 50 +21, 1.2486211245416499, 0.0931927131998619, 50 +22, 1.0112097287516961, 0.1154175328721026, 50 +23, 1.1767184172339717, 0.12118419195337941, 50 +24, 1.2351478381804541, 0.12062354460327637, 50 +25, 1.1787961570706984, 0.09555639445693169, 50 +26, 1.3458841273883968, 0.09905409014148021, 50 +27, 1.0630331359257796, 0.0933021368257885, 50 +28, 1.108062308587545, 0.09529187016455509, 50 +29, 1.340623771830854, 0.10486979259352251, 50 +30, 1.1869593362388007, 0.12413765264759334, 50 +31, 1.2813169033370082, 0.0948188595877337, 50 +32, 1.064347711058793, 0.09510807250779679, 50 +33, 1.0524279818040063, 0.09801471704249785, 50 +34, 1.2533274970403327, 0.10653972217518849, 50 +35, 1.2013492131344725, 0.08857802962888102, 50 +36, 1.3004905392045623, 0.10210248288307533, 50 +37, 1.1076842014795856, 0.08059643688007641, 50 +38, 1.2688132137947212, 0.11900907616490888, 50 +39, 1.3221373140196022, 0.08923677354687104, 50 +40, 1.1866381500162704, 0.1041857802145374, 50 +41, 1.1112976938814996, 0.09336801473753686, 50 +42, 1.2295540033914845, 0.12053241093171055, 50 +43, 1.3510846402579197, 0.09090671121805115, 50 +44, 1.2285269381332489, 0.09060328343782482, 50 +45, 1.2615316236096836, 0.08569954123287596, 50 +46, 1.355040886379684, 0.0981102526027017, 50 +47, 1.209599693647232, 0.0895166768928232, 50 +48, 1.0179441848057886, 0.06849061740076876, 50 +49, 1.0696351866993108, 0.10267688640425526, 50 +50, 1.278859849997356, 0.0986897719878365, 50 +51, 1.1268211653925553, 0.08556789523774115, 50 +52, 1.1797201873016918, 0.08986372996336187, 50 +53, 1.248185906357412, 0.10247472618104703, 50 +54, 1.0675677605376521, 0.11564596558466642, 50 +55, 0.9697137996356358, 0.11183105566166397, 50 +56, 1.0921249127084762, 0.11264559252006313, 50 +57, 1.2931213068145355, 0.09733962543048513, 50 +58, 0.886820731374182, 0.09596054209213793, 50 +59, 1.1972305791891324, 0.10091692382494825, 50 +60, 1.1438810654208762, 0.095634978222107, 50 +61, 1.2105942190436731, 0.1091189235433658, 50 +62, 1.1218180302299432, 0.11732029332477402, 50 +63, 1.2667711119801168, 0.09598261537724273, 50 +64, 1.225642941806333, 0.08616025858594184, 50 +65, 0.9987592457582799, 0.1013981877042509, 50 +66, 1.2138946485453377, 0.10062881241060935, 50 +67, 1.3995467472337397, 0.0837534197032114, 50 +68, 1.1503861962448882, 0.08970545547459753, 50 +69, 1.2917968488378215, 0.10198981216193112, 50 +70, 1.2778069149863, 0.09606849753661166, 50 +71, 1.080351648006885, 0.11385621954234637, 50 +72, 1.1972642754581033, 0.09820646537237751, 50 +73, 1.2699424429054966, 0.11132956173515586, 50 +74, 1.0382734580265272, 0.10695713528078032, 50 +75, 1.215294466880341, 0.10824664157811359, 50 +76, 1.0750734337834869, 0.09334672510188377, 50 +77, 1.1491816816250375, 0.09105906901035939, 50 +78, 1.243505049274979, 0.08460268066036611, 50 +79, 1.035306988798836, 0.09087617946948215, 50 +80, 1.0582060793991683, 0.10094839455166812, 50 +81, 1.1357471297990474, 0.10358493403627932, 50 +82, 1.1586266841792847, 0.1232733758135109, 50 +83, 1.1692415618123773, 0.10369932924182489, 50 +84, 1.089704993345683, 0.09485483949272429, 50 +85, 1.1426673589726786, 0.08286971279699235, 50 +86, 1.3071602622738743, 0.10378982443294256, 50 +87, 1.1658696086370257, 0.08730846319176624, 50 +88, 0.9167324903824648, 0.10579852144930121, 50 +89, 1.1508404841198825, 0.09245940690596459, 50 +90, 1.2280184498451163, 0.08582550001397626, 50 +91, 1.09848817578252, 0.10798970381960545, 50 +92, 1.2400275782782906, 0.07333044137551628, 50 +93, 1.1436910224831092, 0.09234258577097995, 50 +94, 1.1428469515924335, 0.10880822347600295, 50 +95, 1.0731800787050982, 0.10477870068503345, 50 +96, 1.1532202352777614, 0.10481938224654984, 50 +97, 1.0398723117148394, 0.10566419670499176, 50 +98, 1.1001955107995218, 0.10577401199277633, 50 +99, 0.9456245532389306, 0.0960663151430105, 50