diff --git a/examples/one_compartment.rs b/examples/one_compartment.rs index 4c0d13c1..ebebfa4e 100644 --- a/examples/one_compartment.rs +++ b/examples/one_compartment.rs @@ -67,28 +67,28 @@ fn main() -> Result<(), pharmsol::PharmsolError> { let ke = 1.022; // Elimination rate constant let v = 194.0; // Volume of distribution - // Compute likelihoods and predictions for both models - let analytical_likelihoods = an.estimate_likelihood(&subject, &vec![ke, v], &ems, false)?; + // Compute log-likelihoods and predictions for both models + let analytical_log_lik = an.estimate_log_likelihood(&subject, &vec![ke, v], &ems, false)?; let analytical_predictions = an.estimate_predictions(&subject, &vec![ke, v])?; - let ode_likelihoods = ode.estimate_likelihood(&subject, &vec![ke, v], &ems, false)?; + let ode_log_lik = ode.estimate_log_likelihood(&subject, &vec![ke, v], &ems, false)?; let ode_predictions = ode.estimate_predictions(&subject, &vec![ke, v])?; // Print comparison table - println!("\n┌───────────┬─────────────────┬─────────────────┬─────────────────────┐"); - println!("│ │ Analytical │ ODE │ Difference │"); - println!("├───────────┼─────────────────┼─────────────────┼─────────────────────┤"); + println!("\n┌─────────────────┬─────────────────┬─────────────────┬─────────────────────┐"); + println!("│ │ Analytical │ ODE │ Difference │"); + println!("├─────────────────┼─────────────────┼─────────────────┼─────────────────────┤"); println!( - "│ Likelihood│ {:>15.6} │ {:>15.6} │ {:>19.2e} │", - analytical_likelihoods, - ode_likelihoods, - analytical_likelihoods - ode_likelihoods + "│ Log-Likelihood │ {:>15.6} │ {:>15.6} │ {:>19.2e} │", + analytical_log_lik, + ode_log_lik, + analytical_log_lik - ode_log_lik ); - println!("├───────────┼─────────────────┼─────────────────┼─────────────────────┤"); - println!("│ Time │ Prediction │ Prediction │ │"); - println!("├───────────┼─────────────────┼─────────────────┼─────────────────────┤"); + println!("├─────────────────┼─────────────────┼─────────────────┼─────────────────────┤"); + println!("│ Time │ Prediction │ Prediction │ │"); + println!("├─────────────────┼─────────────────┼─────────────────┼─────────────────────┤"); let times = analytical_predictions.flat_times(); let analytical_preds = analytical_predictions.flat_predictions(); @@ -101,12 +101,12 @@ fn main() -> Result<(), pharmsol::PharmsolError> { { let diff = a - b; println!( - "│ {:>9.2} │ {:>15.9} │ {:>15.9} │ {:>19.2e} │", + "│ {:>15.2} │ {:>15.9} │ {:>15.9} │ {:>19.2e} │", t, a, b, diff ); } - println!("└───────────┴─────────────────┴─────────────────┴─────────────────────┘\n"); + println!("└─────────────────┴─────────────────┴─────────────────┴─────────────────────┘\n"); Ok(()) } diff --git a/examples/pf.rs b/examples/pf.rs index 9e0b4446..25235d98 100644 --- a/examples/pf.rs +++ b/examples/pf.rs @@ -37,11 +37,11 @@ fn main() { .unwrap(); let ll = sde - .estimate_likelihood(&subject, &vec![1.0], &ems, false) + .estimate_log_likelihood(&subject, &vec![1.0], &ems, false) .unwrap(); dbg!(sde - .estimate_likelihood(&subject, &vec![1.0], &ems, false) + .estimate_log_likelihood(&subject, &vec![1.0], &ems, false) .unwrap()); println!("{ll:#?}"); } diff --git a/src/lib.rs b/src/lib.rs index a88e465d..58f01d8b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -30,7 +30,7 @@ pub mod prelude { pub use crate::simulator::{ equation, equation::Equation, - likelihood::{log_psi, psi, PopulationPredictions, Prediction, SubjectPredictions}, + likelihood::{psi, PopulationPredictions, Prediction, SubjectPredictions}, }; } pub mod models { diff --git a/src/optimize/spp.rs b/src/optimize/spp.rs index ac19bd3b..dffce51a 100644 --- a/src/optimize/spp.rs +++ b/src/optimize/spp.rs @@ -11,7 +11,7 @@ pub struct SppOptimizer<'a, E: Equation> { equation: &'a E, data: &'a Data, sig: &'a ErrorModels, - pyl: &'a Array1, + log_pyl: &'a Array1, } impl CostFunction for SppOptimizer<'_, E> { @@ -20,22 +20,23 @@ impl CostFunction for SppOptimizer<'_, E> { fn cost(&self, spp: &Self::Param) -> Result { let theta = Array1::from(spp.clone()).insert_axis(Axis(0)); - let psi = psi(self.equation, self.data, &theta, self.sig, false, false)?; + let log_psi = psi(self.equation, self.data, &theta, self.sig, false, false)?; - if psi.ncols() > 1 { - tracing::error!("Psi in SppOptimizer has more than one column"); + if log_psi.ncols() > 1 { + tracing::error!("log_psi in SppOptimizer has more than one column"); } - if psi.nrows() != self.pyl.len() { + if log_psi.nrows() != self.log_pyl.len() { tracing::error!( - "Psi in SppOptimizer has {} rows, but spp has {}", - psi.nrows(), - self.pyl.len() + "log_psi in SppOptimizer has {} rows, but spp has {}", + log_psi.nrows(), + self.log_pyl.len() ); } - let nsub = psi.nrows() as f64; + let nsub = log_psi.nrows() as f64; let mut sum = -nsub; - for (p_i, pyl_i) in psi.iter().zip(self.pyl.iter()) { - sum += p_i / pyl_i; + // Convert log-likelihoods back to likelihood ratio: exp(log_psi - log_pyl) + for (log_p_i, log_pyl_i) in log_psi.iter().zip(self.log_pyl.iter()) { + sum += (log_p_i - log_pyl_i).exp(); } Ok(-sum) } @@ -46,13 +47,13 @@ impl<'a, E: Equation> SppOptimizer<'a, E> { equation: &'a E, data: &'a Data, sig: &'a ErrorModels, - pyl: &'a Array1, + log_pyl: &'a Array1, ) -> Self { Self { equation, data, sig, - pyl, + log_pyl, } } pub fn optimize_point(self, spp: Array1) -> Result, Error> { diff --git a/src/simulator/equation/analytical/mod.rs b/src/simulator/equation/analytical/mod.rs index 5db77979..6a5e13d4 100644 --- a/src/simulator/equation/analytical/mod.rs +++ b/src/simulator/equation/analytical/mod.rs @@ -191,7 +191,7 @@ impl EquationPriv for Analytical { let pred = y[observation.outeq()]; let pred = observation.to_prediction(pred, x.as_slice().to_vec()); if let Some(error_models) = error_models { - likelihood.push(pred.likelihood(error_models)?); + likelihood.push(pred.log_likelihood(error_models)?); } output.add_prediction(pred); Ok(()) @@ -283,16 +283,6 @@ mod tests { } } impl Equation for Analytical { - fn estimate_likelihood( - &self, - subject: &Subject, - support_point: &Vec, - error_models: &ErrorModels, - cache: bool, - ) -> Result { - _estimate_likelihood(self, subject, support_point, error_models, cache) - } - fn estimate_log_likelihood( &self, subject: &Subject, @@ -341,18 +331,3 @@ fn _subject_predictions( ) -> Result { Ok(ode.simulate_subject(subject, support_point, None)?.0) } - -fn _estimate_likelihood( - ode: &Analytical, - subject: &Subject, - support_point: &Vec, - error_models: &ErrorModels, - cache: bool, -) -> Result { - let ypred = if cache { - _subject_predictions(ode, subject, support_point) - } else { - _subject_predictions_no_cache(ode, subject, support_point) - }?; - ypred.likelihood(error_models) -} diff --git a/src/simulator/equation/mod.rs b/src/simulator/equation/mod.rs index 70219080..3fa2f835 100644 --- a/src/simulator/equation/mod.rs +++ b/src/simulator/equation/mod.rs @@ -171,32 +171,10 @@ pub(crate) trait EquationPriv: EquationTypes { /// and estimate parameters. #[allow(private_bounds)] pub trait Equation: EquationPriv + 'static + Clone + Sync { - /// Estimate the likelihood of the subject given the support point and error model. - /// - /// This function calculates how likely the observed data is given the model - /// parameters and error model. It may use caching for performance. - /// - /// # Parameters - /// - `subject`: The subject data - /// - `support_point`: The parameter values - /// - `error_model`: The error model - /// - `cache`: Whether to use caching - /// - /// # Returns - /// The likelihood value (product of individual observation likelihoods) - fn estimate_likelihood( - &self, - subject: &Subject, - support_point: &Vec, - error_models: &ErrorModels, - cache: bool, - ) -> Result; - /// Estimate the log-likelihood of the subject given the support point and error model. /// /// This function calculates the log of how likely the observed data is given the model - /// parameters and error model. It is numerically more stable than `estimate_likelihood` - /// for extreme values or many observations. + /// parameters and error model. It is numerically stable for extreme values or many observations. /// /// # Parameters /// - `subject`: The subject data @@ -282,7 +260,8 @@ pub trait Equation: EquationPriv + 'static + Clone + Sync { )?; } } - let ll = error_models.map(|_| likelihood.iter().product::()); + // Return log-likelihood (sum of log-likelihoods) when error_models is provided + let ll = error_models.map(|_| likelihood.iter().sum::()); Ok((output, ll)) } } diff --git a/src/simulator/equation/ode/mod.rs b/src/simulator/equation/ode/mod.rs index 4c005ee7..1f1623af 100644 --- a/src/simulator/equation/ode/mod.rs +++ b/src/simulator/equation/ode/mod.rs @@ -69,21 +69,6 @@ fn spphash(spp: &[f64]) -> u64 { /// Hash a subject ID string to u64 for cache key generation. -fn _estimate_likelihood( - ode: &ODE, - subject: &Subject, - support_point: &Vec, - error_models: &ErrorModels, - cache: bool, -) -> Result { - let ypred = if cache { - _subject_predictions(ode, subject, support_point) - } else { - _subject_predictions_no_cache(ode, subject, support_point) - }?; - ypred.likelihood(error_models) -} - #[inline(always)] #[cached( ty = "UnboundCache<(u64, u64), SubjectPredictions>", @@ -174,16 +159,6 @@ impl EquationPriv for ODE { } impl Equation for ODE { - fn estimate_likelihood( - &self, - subject: &Subject, - support_point: &Vec, - error_models: &ErrorModels, - cache: bool, - ) -> Result { - _estimate_likelihood(self, subject, support_point, error_models, cache) - } - fn estimate_log_likelihood( &self, subject: &Subject, @@ -324,7 +299,7 @@ impl Equation for ODE { let pred = observation.to_prediction(pred, solver.state().y.as_slice().to_vec()); if let Some(error_models) = error_models { - likelihood.push(pred.likelihood(error_models)?); + likelihood.push(pred.log_likelihood(error_models)?); } output.add_prediction(pred); } @@ -379,7 +354,7 @@ impl Equation for ODE { } } } - let ll = error_models.map(|_| likelihood.iter().product::()); + let ll = error_models.map(|_| likelihood.iter().sum::()); Ok((output, ll)) } } diff --git a/src/simulator/equation/sde/mod.rs b/src/simulator/equation/sde/mod.rs index 63d2cf82..aab4eab2 100644 --- a/src/simulator/equation/sde/mod.rs +++ b/src/simulator/equation/sde/mod.rs @@ -306,23 +306,31 @@ impl EquationPriv for SDE { //e = y[t] .- x[:,1] // q = pdf.(Distributions.Normal(0, 0.5), e) if let Some(em) = error_models { - let mut q: Vec = Vec::with_capacity(self.nparticles); + // Compute log-likelihoods for each particle + let mut log_q: Vec = Vec::with_capacity(self.nparticles); pred.iter().for_each(|p| { - let lik = p.likelihood(em); - match lik { - Ok(l) => q.push(l), - Err(e) => panic!("Error in likelihood calculation: {:?}", e), + let log_lik = p.log_likelihood(em); + match log_lik { + Ok(l) => log_q.push(l), + Err(e) => panic!("Error in log-likelihood calculation: {:?}", e), } }); - let sum_q: f64 = q.iter().sum(); - let w: Vec = q.iter().map(|qi| qi / sum_q).collect(); + + // Use log-sum-exp trick for numerical stability + let max_log_q = log_q.iter().cloned().fold(f64::NEG_INFINITY, f64::max); + let sum_exp: f64 = log_q.iter().map(|&lq| (lq - max_log_q).exp()).sum(); + let log_sum_q = max_log_q + sum_exp.ln(); + + // Compute normalized weights from log-likelihoods + let w: Vec = log_q.iter().map(|&lq| (lq - log_sum_q).exp()).collect(); let i = sysresample(&w); let a: Vec> = i.iter().map(|&i| x[i].clone()).collect(); *x = a; - likelihood.push(sum_q / self.nparticles as f64); - // let qq: Vec = i.iter().map(|&i| q[i]).collect(); - // likelihood.push(qq.iter().sum::() / self.nparticles as f64); + + // Push the average likelihood (in regular space) for final computation + // log(mean(likelihood)) = log(sum(exp(log_lik))) - log(n) + likelihood.push((log_sum_q - (self.nparticles as f64).ln()).exp()); } Ok(()) } @@ -351,7 +359,7 @@ impl EquationPriv for SDE { } impl Equation for SDE { - /// Estimates the likelihood of observed data given a model and parameters. + /// Estimates the log-likelihood of observed data given a model and parameters. /// /// # Arguments /// @@ -363,31 +371,20 @@ impl Equation for SDE { /// # Returns /// /// The log-likelihood of the observed data given the model and parameters. - fn estimate_likelihood( + fn estimate_log_likelihood( &self, subject: &Subject, support_point: &Vec, error_models: &ErrorModels, cache: bool, ) -> Result { - if cache { + // For SDE, the particle filter computes likelihood in regular space internally. + // We take the log of the final likelihood. + let lik = if cache { _estimate_likelihood(self, subject, support_point, error_models) } else { _estimate_likelihood_no_cache(self, subject, support_point, error_models) - } - } - - fn estimate_log_likelihood( - &self, - subject: &Subject, - support_point: &Vec, - error_models: &ErrorModels, - cache: bool, - ) -> Result { - // For SDE, the particle filter computes likelihood in regular space. - // We take the log of the cached/computed likelihood. - // Note: For extreme underflow cases, this may return -inf. - let lik = self.estimate_likelihood(subject, support_point, error_models, cache)?; + }?; if lik > 0.0 { Ok(lik.ln()) } else { diff --git a/src/simulator/likelihood/mod.rs b/src/simulator/likelihood/mod.rs index b8938d54..5cac7758 100644 --- a/src/simulator/likelihood/mod.rs +++ b/src/simulator/likelihood/mod.rs @@ -12,9 +12,6 @@ use statrs::distribution::Normal; mod progress; -const FRAC_1_SQRT_2PI: f64 = - std::f64::consts::FRAC_2_SQRT_PI * std::f64::consts::FRAC_1_SQRT_2 / 2.0; - // ln(2π) = ln(2) + ln(π) ≈ 1.8378770664093453 const LOG_2PI: f64 = 1.8378770664093453_f64; @@ -43,29 +40,6 @@ impl Predictions for SubjectPredictions { } impl SubjectPredictions { - /// Calculate the likelihood of the predictions given an error model. - /// - /// This multiplies the likelihood of each prediction to get the joint likelihood. - /// - /// # Parameters - /// - `error_model`: The error model to use for calculating the likelihood - /// - /// # Returns - /// The product of all individual prediction likelihoods - pub fn likelihood(&self, error_models: &ErrorModels) -> Result { - match self.predictions.is_empty() { - true => Ok(1.0), - false => self - .predictions - .iter() - .filter(|p| p.observation.is_some()) - .map(|p| p.likelihood(error_models)) - .collect::, _>>() - .map(|likelihoods| likelihoods.iter().product()) - .map_err(PharmsolError::from), - } - } - /// Calculate the log-likelihood of the predictions given an error model. /// /// This sums the log-likelihood of each prediction to get the joint log-likelihood. @@ -135,12 +109,6 @@ impl SubjectPredictions { } } -/// Probability density function of the normal distribution -#[inline(always)] -fn normpdf(obs: f64, pred: f64, sigma: f64) -> f64 { - (FRAC_1_SQRT_2PI / sigma) * (-((obs - pred) * (obs - pred)) / (2.0 * sigma * sigma)).exp() -} - /// Log of the probability density function of the normal distribution. /// /// This is numerically stable and avoids underflow for extreme values. @@ -192,12 +160,6 @@ fn lognormccdf(obs: f64, pred: f64, sigma: f64) -> Result } } -#[inline(always)] -fn normcdf(obs: f64, pred: f64, sigma: f64) -> Result { - let norm = Normal::new(pred, sigma).map_err(|_| ErrorModelError::NegativeSigma)?; - Ok(norm.cdf(obs)) -} - impl From> for SubjectPredictions { fn from(predictions: Vec) -> Self { Self { @@ -230,78 +192,6 @@ impl From> for PopulationPredictions { } } -/// Calculate the psi matrix for maximum likelihood estimation. -/// -/// # Parameters -/// - `equation`: The equation to use for simulation -/// - `subjects`: The subject data -/// - `support_points`: The support points to evaluate -/// - `error_model`: The error model to use -/// - `progress`: Whether to show a progress bar -/// - `cache`: Whether to use caching -/// -/// # Returns -/// A 2D array of likelihoods -pub fn psi( - equation: &impl Equation, - subjects: &Data, - support_points: &Array2, - error_models: &ErrorModels, - progress: bool, - cache: bool, -) -> Result, PharmsolError> { - let mut psi: Array2 = Array2::default((subjects.len(), support_points.nrows()).f()); - - let subjects = subjects.subjects(); - - let progress_tracker = if progress { - let total = subjects.len() * support_points.nrows(); - println!( - "Simulating {} subjects with {} support points each...", - subjects.len(), - support_points.nrows() - ); - Some(ProgressTracker::new(total)) - } else { - None - }; - - let result: Result<(), PharmsolError> = psi - .axis_iter_mut(Axis(0)) - .into_par_iter() - .enumerate() - .try_for_each(|(i, mut row)| { - row.axis_iter_mut(Axis(0)) - .into_par_iter() - .enumerate() - .try_for_each(|(j, mut element)| { - let subject = subjects.get(i).unwrap(); - match equation.estimate_likelihood( - subject, - support_points.row(j).to_vec().as_ref(), - error_models, - cache, - ) { - Ok(likelihood) => { - element.fill(likelihood); - if let Some(ref tracker) = progress_tracker { - tracker.inc(); - } - } - Err(e) => return Err(e), - }; - Ok(()) - }) - }); - - if let Some(tracker) = progress_tracker { - tracker.finish(); - } - - result?; - Ok(psi) -} - /// Calculate the log-likelihood matrix for all subjects and support points. /// /// This function computes log-likelihoods directly in log-space, which is numerically @@ -319,7 +209,7 @@ pub fn psi( /// /// # Returns /// A 2D array of log-likelihoods with shape (n_subjects, n_support_points) -pub fn log_psi( +pub fn psi( equation: &impl Equation, subjects: &Data, support_points: &Array2, @@ -444,32 +334,6 @@ impl Prediction { self.observation.map(|obs| (self.prediction - obs).powi(2)) } - /// Calculate the likelihood of this prediction given an error model. - /// - /// Returns an error if the observation is missing or if the likelihood is either zero or non-finite. - pub fn likelihood(&self, error_models: &ErrorModels) -> Result { - if self.observation.is_none() { - return Err(PharmsolError::MissingObservation); - } - - let sigma = error_models.sigma(self)?; - - //TODO: For the BLOQ and ALOQ cases, we should be using the LOQ values, not the observation values. - let likelihood = match self.censoring { - Censor::None => normpdf(self.observation.unwrap(), self.prediction, sigma), - Censor::BLOQ => normcdf(self.observation.unwrap(), self.prediction, sigma)?, - Censor::ALOQ => 1.0 - normcdf(self.observation.unwrap(), self.prediction, sigma)?, - }; - - if likelihood.is_finite() { - return Ok(likelihood); - } else if likelihood == 0.0 { - return Err(PharmsolError::ZeroLikelihood); - } else { - return Err(PharmsolError::NonFiniteLikelihood(likelihood)); - } - } - /// Calculate the log-likelihood of this prediction given an error model. /// /// This method is numerically stable and avoids underflow issues that can occur @@ -569,14 +433,14 @@ mod tests { use crate::Censor; #[test] - fn empty_predictions_have_neutral_likelihood() { + fn empty_predictions_have_neutral_log_likelihood() { let preds = SubjectPredictions::default(); let errors = ErrorModels::new(); - assert_eq!(preds.likelihood(&errors).unwrap(), 1.0); + assert_eq!(preds.log_likelihood(&errors).unwrap(), 0.0); // log(1) = 0 } #[test] - fn likelihood_combines_observations() { + fn log_likelihood_combines_observations() { let mut preds = SubjectPredictions::default(); let obs = Observation::new(0.0, Some(1.0), 0, None, 0, Censor::None); preds.add_prediction(obs.to_prediction(1.0, vec![])); @@ -584,11 +448,13 @@ mod tests { let error_model = ErrorModel::additive(ErrorPoly::new(1.0, 0.0, 0.0, 0.0), 0.0); let errors = ErrorModels::new().add(0, error_model).unwrap(); - assert!(preds.likelihood(&errors).unwrap() > 0.0); + // log_likelihood should be finite and negative (log of a probability < 1) + let ll = preds.log_likelihood(&errors).unwrap(); + assert!(ll.is_finite()); } #[test] - fn test_log_likelihood_equals_log_of_likelihood() { + fn test_log_likelihood_basic() { // Create a prediction with an observation let prediction = Prediction { time: 1.0, @@ -609,17 +475,11 @@ mod tests { ) .unwrap(); - let lik = prediction.likelihood(&error_models).unwrap(); let log_lik = prediction.log_likelihood(&error_models).unwrap(); - // log_likelihood should equal ln(likelihood) - let expected_log_lik = lik.ln(); - assert!( - (log_lik - expected_log_lik).abs() < 1e-10, - "log_likelihood ({}) should equal ln(likelihood) ({})", - log_lik, - expected_log_lik - ); + // log_likelihood should be finite and negative (log of probability < 1) + assert!(log_lik.is_finite()); + assert!(log_lik < 0.0, "log_likelihood should be negative"); } #[test] @@ -644,9 +504,6 @@ mod tests { ) .unwrap(); - // Regular likelihood will be extremely small but non-zero - let lik = prediction.likelihood(&error_models).unwrap(); - // log_likelihood should give a finite (very negative) value let log_lik = prediction.log_likelihood(&error_models).unwrap(); @@ -655,23 +512,11 @@ mod tests { log_lik < -100.0, "log_likelihood should be very negative for large mismatch" ); - - // They should match: log_lik ≈ ln(lik) - if lik > 0.0 && lik.ln().is_finite() { - let diff = (log_lik - lik.ln()).abs(); - assert!( - diff < 1e-6, - "log_likelihood ({}) should equal ln(likelihood) ({}) for non-extreme cases, diff={}", - log_lik, - lik.ln(), - diff - ); - } } #[test] fn test_log_likelihood_extreme_underflow() { - // Test with truly extreme values where regular likelihood underflows to 0 + // Test with truly extreme values where regular likelihood would underflow to 0 let prediction = Prediction { time: 1.0, observation: Some(10.0), @@ -691,10 +536,7 @@ mod tests { ) .unwrap(); - // Regular likelihood may underflow to 0 - let _lik_result = prediction.likelihood(&error_models); - - // log_likelihood should still work + // log_likelihood should still work even for extreme values let log_lik = prediction.log_likelihood(&error_models).unwrap(); assert!( @@ -744,17 +586,11 @@ mod tests { ) .unwrap(); - let lik = subject_predictions.likelihood(&error_models).unwrap(); let log_lik = subject_predictions.log_likelihood(&error_models).unwrap(); - // Sum of log likelihoods should equal log of product of likelihoods - let expected_log_lik = lik.ln(); - assert!( - (log_lik - expected_log_lik).abs() < 1e-10, - "Subject log_likelihood ({}) should equal ln(likelihood) ({})", - log_lik, - expected_log_lik - ); + // log_likelihood should be finite and negative + assert!(log_lik.is_finite()); + assert!(log_lik < 0.0, "Combined log_likelihood should be negative"); } #[test] @@ -764,12 +600,16 @@ mod tests { let pred = 0.0; let sigma = 1.0; - let pdf = normpdf(obs, pred, sigma); let log_pdf = lognormpdf(obs, pred, sigma); + // For standard normal at x=0: pdf = 1/sqrt(2*pi) ≈ 0.3989 + // log(pdf) ≈ -0.9189 + let expected = -0.5 * LOG_2PI; // -0.5 * ln(2π) assert!( - (log_pdf - pdf.ln()).abs() < 1e-12, - "lognormpdf should equal ln(normpdf)" + (log_pdf - expected).abs() < 1e-12, + "lognormpdf at x=0 should equal -0.5*ln(2π), got {} expected {}", + log_pdf, + expected ); } } diff --git a/tests/ode_optimizations.rs b/tests/ode_optimizations.rs index 024405fb..b343e455 100644 --- a/tests/ode_optimizations.rs +++ b/tests/ode_optimizations.rs @@ -881,21 +881,20 @@ fn likelihood_calculation_matches_analytical() { let params = vec![0.1, 50.0]; let ll_analytical = analytical - .estimate_likelihood(&subject, ¶ms, &error_models, false) - .expect("analytical likelihood"); + .estimate_log_likelihood(&subject, ¶ms, &error_models, false) + .expect("analytical log-likelihood"); let ll_ode = ode - .estimate_likelihood(&subject, ¶ms, &error_models, false) - .expect("ode likelihood"); + .estimate_log_likelihood(&subject, ¶ms, &error_models, false) + .expect("ode log-likelihood"); let ll_diff = (ll_analytical - ll_ode).abs(); - let ll_rel_diff = ll_diff / ll_analytical.abs().max(1e-10); - + // For log-likelihoods, we compare absolute differences rather than relative assert!( - ll_rel_diff < 0.01, // Within 1% - "Likelihoods should match: analytical={:.6}, ode={:.6}, rel_diff={:.2e}", + ll_diff < 0.01, // Within 0.01 log-likelihood units + "Log-likelihoods should match: analytical={:.6}, ode={:.6}, diff={:.2e}", ll_analytical, ll_ode, - ll_rel_diff + ll_diff ); }