diff --git a/Cargo.toml b/Cargo.toml
index 69a8b0d2a..2341051bd 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -29,7 +29,7 @@ tracing-subscriber = { version = "0.3.19", features = [
 ] }
 faer = "0.23.1"
 faer-ext = { version = "0.7.1", features = ["nalgebra", "ndarray"] }
-pharmsol = "=0.21.0"
+pharmsol = "=0.22.0"
 rand = "0.9.0"
 anyhow = "1.0.100"
 rayon = "1.10.0"
diff --git a/src/algorithms/mod.rs b/src/algorithms/mod.rs
index ab5c16912..35eb2bd34 100644
--- a/src/algorithms/mod.rs
+++ b/src/algorithms/mod.rs
@@ -1,6 +1,7 @@
 use std::fs;
 use std::path::Path;
 
+use crate::routines::math::logsumexp_rows;
 use crate::routines::output::NPResult;
 use crate::routines::settings::Settings;
 use crate::structs::psi::Psi;
@@ -36,35 +37,63 @@ pub trait Algorithms: Sync + Send + 'static {
         // Count problematic values in psi
         let mut nan_count = 0;
         let mut inf_count = 0;
+        let is_log_space = match self.psi().space() {
+            crate::structs::psi::Space::Linear => false,
+            crate::structs::psi::Space::Log => true,
+        };
         let psi = self.psi().matrix().as_ref().into_ndarray();
-        // First coerce all NaN and infinite in psi to 0.0
+        // Count NaN and problematic infinite values in psi
+        // (in log-space, NEG_INFINITY is a valid encoding of zero probability)
        for i in 0..psi.nrows() {
             for j in 0..self.psi().matrix().ncols() {
                 let val = psi.get((i, j)).unwrap();
                 if val.is_nan() {
                     nan_count += 1;
-                    // *val = 0.0;
                 } else if val.is_infinite() {
-                    inf_count += 1;
-                    // *val = 0.0;
+                    // In log-space, NEG_INFINITY is valid (represents zero probability)
+                    // Only count positive infinity as problematic
+                    if !is_log_space || val.is_sign_positive() {
+                        inf_count += 1;
+                    }
                 }
             }
         }
 
         if nan_count + inf_count > 0 {
             tracing::warn!(
-                "Psi matrix contains {} NaN, {} Infinite values of {} total values",
+                "Psi matrix contains {} NaN, {} problematic Infinite values of {} total values",
                 nan_count,
                 inf_count,
                 psi.ncols() * psi.nrows()
             );
         }
 
-        let (_, col) = psi.dim();
-        let ecol: ArrayBase<OwnedRepr<f64>, Dim<[usize; 1]>> = Array::ones(col);
-        let plam = psi.dot(&ecol);
-        let w = 1. / &plam;
+        // Calculate row sums: for linear space, an ordinary sum; for log-space, logsumexp
+        let plam: ArrayBase<OwnedRepr<f64>, Dim<[usize; 1]>> = if is_log_space {
+            // For log-space, use logsumexp for each row
+            Array::from_vec(logsumexp_rows(psi.nrows(), psi.ncols(), |i, j| psi[(i, j)]))
+        } else {
+            // For linear space, sum each row
+            let (_, col) = psi.dim();
+            let ecol: ArrayBase<OwnedRepr<f64>, Dim<[usize; 1]>> = Array::ones(col);
+            psi.dot(&ecol)
+        };
+
+        // Check for subjects with zero probability
+        // In log-space: -inf means zero probability
+        // In linear space: 0 means zero probability
+        let w: ArrayBase<OwnedRepr<f64>, Dim<[usize; 1]>> = if is_log_space {
+            // For log-space, check whether the logsumexp result is -inf
+            Array::from_shape_fn(plam.len(), |i| {
+                if plam[i].is_infinite() && plam[i].is_sign_negative() {
+                    f64::INFINITY // Will be flagged as problematic
+                } else {
+                    1.0 // Valid
+                }
+            })
+        } else {
+            1. / &plam
+        };
 
         // Get the index of each element in `w` that is NaN or infinite
         let indices: Vec<usize> = w
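For intuition, the zero-probability check above reduces to the following standalone sketch (hypothetical helper, not part of the patch): a linear-space row fails when its sum is zero, while a log-space row fails only when every entry is negative infinity.

```rust
/// Sketch: which rows of psi (one row per subject) carry zero total probability.
fn zero_probability_rows(psi: &[Vec<f64>], is_log_space: bool) -> Vec<usize> {
    psi.iter()
        .enumerate()
        .filter(|(_, row)| {
            if is_log_space {
                // logsumexp of a row is -inf only if every entry is -inf
                row.iter().all(|&x| x == f64::NEG_INFINITY)
            } else {
                // In linear space, 1/sum is non-finite when the row sums to zero
                let sum: f64 = row.iter().sum();
                !(1.0 / sum).is_finite()
            }
        })
        .map(|(i, _)| i)
        .collect()
}
```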
diff --git a/src/algorithms/npag.rs b/src/algorithms/npag.rs
index 68ed04693..e83f20d64 100644
--- a/src/algorithms/npag.rs
+++ b/src/algorithms/npag.rs
@@ -1,8 +1,9 @@
 use crate::algorithms::{Status, StopReason};
 use crate::prelude::algorithms::Algorithms;
-pub use crate::routines::estimation::ipm::burke;
+pub use crate::routines::estimation::ipm::{burke, burke_ipm, burke_log};
 pub use crate::routines::estimation::qr;
+use crate::routines::math::logsumexp;
 use crate::routines::settings::Settings;
 
 use crate::routines::output::{cycles::CycleLog, cycles::NPCycle, NPResult};
@@ -160,8 +161,24 @@ impl Algorithms for NPAG {
         if (self.last_objf - self.objf).abs() <= THETA_G && self.eps > THETA_E {
             self.eps /= 2.;
             if self.eps <= THETA_E {
-                let pyl = psi * w.weights();
-                self.f1 = pyl.iter().map(|x| x.ln()).sum();
+                // Compute f1 = sum(log(pyl)) where pyl = psi * w
+                self.f1 = if self.psi.space() == crate::structs::psi::Space::Log {
+                    // For log-space: f1 = sum_i(logsumexp(log_psi[i,:] + log(w)))
+                    let log_w: Vec<f64> = w.weights().iter().map(|&x| x.ln()).collect();
+                    (0..psi.nrows())
+                        .map(|i| {
+                            let combined: Vec<f64> = (0..psi.ncols())
+                                .map(|j| *psi.get(i, j) + log_w[j])
+                                .collect();
+                            logsumexp(&combined)
+                        })
+                        .sum()
+                } else {
+                    // For linear space: f1 = sum(log(psi * w))
+                    let pyl = psi * w.weights();
+                    pyl.iter().map(|x| x.ln()).sum()
+                };
+
                 if (self.f1 - self.f0).abs() <= THETA_F {
                     tracing::info!("The model converged after {} cycles", self.cycle,);
                     self.set_status(Status::Stop(StopReason::Converged));
@@ -204,24 +221,20 @@ impl Algorithms for NPAG {
             &self.error_models,
             self.cycle == 1 && self.settings.config().progress,
             self.cycle != 1,
+            self.settings.advanced().space,
         )?;
 
         if let Err(err) = self.validate_psi() {
             bail!(err);
         }
 
-        (self.lambda, _) = match burke(&self.psi) {
-            Ok((lambda, objf)) => (lambda, objf),
-            Err(err) => {
-                bail!("Error in IPM during estimation: {:?}", err);
-            }
-        };
+        (self.lambda, _) = burke_ipm(&self.psi)
+            .map_err(|err| anyhow::anyhow!("Error in IPM during estimation: {:?}", err))?;
 
         Ok(())
     }
 
     fn condensation(&mut self) -> Result<()> {
         // Filter out the support points with lambda < max(lambda)/1000
-
         let max_lambda = self
             .lambda
             .iter()
@@ -273,15 +286,9 @@ impl Algorithms for NPAG {
         self.psi.filter_column_indices(keep.as_slice());
 
         self.validate_psi()?;
-        (self.lambda, self.objf) = match burke(&self.psi) {
-            Ok((lambda, objf)) => (lambda, objf),
-            Err(err) => {
-                return Err(anyhow::anyhow!(
-                    "Error in IPM during condensation: {:?}",
-                    err
-                ));
-            }
-        };
+
+        (self.lambda, self.objf) = burke_ipm(&self.psi)
+            .map_err(|err| anyhow::anyhow!("Error in IPM during condensation: {:?}", err))?;
         self.w = self.lambda.clone();
         Ok(())
     }
@@ -298,8 +305,6 @@ impl Algorithms for NPAG {
                 }
             })
             .try_for_each(|(outeq, em)| -> Result<()> {
-                // OPTIMIZATION
-
                 let gamma_up = em.factor()? * (1.0 + self.gamma_delta[outeq]);
                 let gamma_down = em.factor()? / (1.0 + self.gamma_delta[outeq]);
@@ -316,7 +321,9 @@ impl Algorithms for NPAG {
                     &error_model_up,
                     false,
                     true,
+                    self.settings.advanced().space,
                 )?;
+
                 let psi_down = calculate_psi(
                     &self.equation,
                     &self.data,
@@ -324,20 +331,15 @@ impl Algorithms for NPAG {
                     &error_model_down,
                     false,
                     true,
+                    self.settings.advanced().space,
                 )?;
 
-                let (lambda_up, objf_up) = match burke(&psi_up) {
-                    Ok((lambda, objf)) => (lambda, objf),
-                    Err(err) => {
-                        bail!("Error in IPM during optim: {:?}", err);
-                    }
-                };
-                let (lambda_down, objf_down) = match burke(&psi_down) {
-                    Ok((lambda, objf)) => (lambda, objf),
-                    Err(err) => {
-                        bail!("Error in IPM during optim: {:?}", err);
-                    }
-                };
+                let (lambda_up, objf_up) = burke_ipm(&psi_up)
+                    .map_err(|err| anyhow::anyhow!("Error in IPM during optim: {:?}", err))?;
+
+                let (lambda_down, objf_down) = burke_ipm(&psi_down)
+                    .map_err(|err| anyhow::anyhow!("Error in IPM during optim: {:?}", err))?;
+
                 if objf_up > self.objf {
                     self.error_models.set_factor(outeq, gamma_up)?;
                     self.objf = objf_up;
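The two f1 branches above compute the same quantity, since log(Σ_j ψ_ij · w_j) = logsumexp_j(log ψ_ij + log w_j). A minimal self-contained sketch of that identity on plain slices (not the crate's types):

```rust
/// Linear-space f1: sum over subjects of log(psi_row · w).
fn f1_linear(psi: &[Vec<f64>], w: &[f64]) -> f64 {
    psi.iter()
        .map(|row| row.iter().zip(w).map(|(p, wj)| p * wj).sum::<f64>().ln())
        .sum()
}

/// Log-space f1: the same value via logsumexp(log_psi_row + ln w),
/// which never forms the (possibly underflowing) products directly.
fn f1_log(log_psi: &[Vec<f64>], w: &[f64]) -> f64 {
    log_psi
        .iter()
        .map(|row| {
            let terms: Vec<f64> = row.iter().zip(w).map(|(lp, wj)| lp + wj.ln()).collect();
            let m = terms.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
            m + terms.iter().map(|t| (t - m).exp()).sum::<f64>().ln()
        })
        .sum()
}
```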
diff --git a/src/algorithms/npod.rs b/src/algorithms/npod.rs
index ed962d971..0cb11da4e 100644
--- a/src/algorithms/npod.rs
+++ b/src/algorithms/npod.rs
@@ -1,35 +1,31 @@
 use crate::algorithms::StopReason;
 use crate::routines::initialization::sample_space;
+use crate::routines::math::logsumexp;
 use crate::routines::output::{cycles::CycleLog, cycles::NPCycle, NPResult};
+use crate::structs::psi::calculate_psi;
 use crate::structs::weights::Weights;
 use crate::{
     algorithms::Status,
     prelude::{
         algorithms::Algorithms,
         routines::{
-            estimation::{ipm::burke, qr},
+            estimation::{ipm::burke_ipm, qr},
             settings::Settings,
         },
     },
-    structs::{
-        psi::{calculate_psi, Psi},
-        theta::Theta,
-    },
+    structs::{psi::Psi, theta::Theta},
 };
 use pharmsol::SppOptimizer;
 
 use anyhow::bail;
 use anyhow::Result;
 use faer_ext::IntoNdarray;
+use pharmsol::prelude::{data::Data, simulator::Equation};
 use pharmsol::{prelude::ErrorModel, ErrorModels};
-use pharmsol::{
-    prelude::{data::Data, simulator::Equation},
-    Subject,
-};
 
 use ndarray::{
     parallel::prelude::{IntoParallelRefMutIterator, ParallelIterator},
-    Array, Array1, ArrayBase, Dim, OwnedRepr,
+    Array1,
 };
 
 const THETA_F: f64 = 1e-2;
@@ -207,27 +203,22 @@ impl Algorithms for NPOD {
     }
 
     fn estimation(&mut self) -> Result<()> {
-        let error_model: ErrorModels = self.error_models.clone();
-
         self.psi = calculate_psi(
             &self.equation,
             &self.data,
             &self.theta,
-            &error_model,
+            &self.error_models,
             self.cycle == 1 && self.settings.config().progress,
             self.cycle != 1,
+            self.settings.advanced().space,
         )?;
 
         if let Err(err) = self.validate_psi() {
             bail!(err);
         }
 
-        (self.lambda, _) = match burke(&self.psi) {
-            Ok((lambda, objf)) => (lambda, objf),
-            Err(err) => {
-                bail!(err);
-            }
-        };
+        (self.lambda, _) = burke_ipm(&self.psi)
+            .map_err(|err| anyhow::anyhow!("Error in IPM during estimation: {:?}", err))?;
 
         Ok(())
     }
@@ -280,12 +271,8 @@ impl Algorithms for NPOD {
         self.theta.filter_indices(keep.as_slice());
         self.psi.filter_column_indices(keep.as_slice());
 
-        (self.lambda, self.objf) = match burke(&self.psi) {
-            Ok((lambda, objf)) => (lambda, objf),
-            Err(err) => {
-                return Err(anyhow::anyhow!("Error in IPM: {:?}", err));
-            }
-        };
+        (self.lambda, self.objf) = burke_ipm(&self.psi)
+            .map_err(|err| anyhow::anyhow!("Error in IPM during condensation: {:?}", err))?;
         self.w = self.lambda.clone();
         Ok(())
     }
@@ -320,6 +307,7 @@ impl Algorithms for NPOD {
                     &error_model_up,
                     false,
                     true,
+                    self.settings.advanced().space,
                 )?;
                 let psi_down = calculate_psi(
                     &self.equation,
                     &self.data,
@@ -328,20 +316,15 @@ impl Algorithms for NPOD {
                     &error_model_down,
                     false,
                     true,
+                    self.settings.advanced().space,
                 )?;
-                let (lambda_up, objf_up) = match burke(&psi_up) {
-                    Ok((lambda, objf)) => (lambda, objf),
-                    Err(err) => {
-                        bail!("Error in IPM during optim: {:?}", err);
-                    }
-                };
-                let (lambda_down, objf_down) = match burke(&psi_down) {
-                    Ok((lambda, objf)) => (lambda, objf),
-                    Err(err) => {
-                        bail!("Error in IPM during optim: {:?}", err);
-                    }
-                };
+                let (lambda_up, objf_up) = burke_ipm(&psi_up)
+                    .map_err(|err| anyhow::anyhow!("Error in IPM during optim: {:?}", err))?;
+
+                let (lambda_down, objf_down) = burke_ipm(&psi_down)
+                    .map_err(|err| anyhow::anyhow!("Error in IPM during optim: {:?}", err))?;
+
                 if objf_up > self.objf {
                     self.error_models.set_factor(outeq, gamma_up)?;
                     self.objf = objf_up;
@@ -368,9 +351,28 @@ impl Algorithms for NPOD {
 
     fn expansion(&mut self) -> Result<()> {
         // If no stop signal, add new point to theta based on the optimization of the D function
-        let psi = self.psi().matrix().as_ref().into_ndarray().to_owned();
+        // Note: SppOptimizer expects linear-space pyl for the D-optimizer.
+        // If we are in log-space, we must convert pyl back to linear space.
+
+        let psi_mat = self.psi().matrix().as_ref().into_ndarray().to_owned();
         let w: Array1<f64> = self.w.clone().iter().collect();
-        let pyl = psi.dot(&w);
+
+        // Compute pyl = P(Y|L) for each subject
+        // In log-space, use logsumexp and then exp to recover linear-space pyl
+        let pyl = match self.settings.advanced().space {
+            crate::structs::psi::Space::Log => {
+                let log_w: Array1<f64> = w.iter().map(|&x| x.ln()).collect();
+                let mut pyl = Array1::zeros(psi_mat.nrows());
+                for i in 0..psi_mat.nrows() {
+                    let combined: Vec<f64> = (0..psi_mat.ncols())
+                        .map(|j| psi_mat[[i, j]] + log_w[j])
+                        .collect();
+                    pyl[i] = logsumexp(&combined).exp();
+                }
+                pyl
+            }
+            crate::structs::psi::Space::Linear => psi_mat.dot(&w),
+        };
 
         // Add new point to theta based on the optimization of the D function
         let error_model: ErrorModels = self.error_models.clone();
@@ -397,48 +399,3 @@ impl Algorithms for NPOD {
         Ok(())
     }
 }
-
-impl NPOD {
-    fn validate_psi(&mut self) -> Result<()> {
-        let mut psi = self.psi().matrix().as_ref().into_ndarray().to_owned();
-        // First coerce all NaN and infinite in psi to 0.0
-        if psi.iter().any(|x| x.is_nan() || x.is_infinite()) {
-            tracing::warn!("Psi contains NaN or Inf values, coercing to 0.0");
-            for i in 0..psi.nrows() {
-                for j in 0..psi.ncols() {
-                    let val = psi.get_mut((i, j)).unwrap();
-                    if val.is_nan() || val.is_infinite() {
-                        *val = 0.0;
-                    }
-                }
-            }
-        }
-
-        // Calculate the sum of each column in psi
-        let (_, col) = psi.dim();
-        let ecol: ArrayBase<OwnedRepr<f64>, Dim<[usize; 1]>> = Array::ones(col);
-        let plam = psi.dot(&ecol);
-        let w = 1. / &plam;
-
-        // Get the index of each element in `w` that is NaN or infinite
-        let indices: Vec<usize> = w
-            .iter()
-            .enumerate()
-            .filter(|(_, x)| x.is_nan() || x.is_infinite())
-            .map(|(i, _)| i)
-            .collect::<Vec<usize>>();
-
-        // If any elements in `w` are NaN or infinite, return the subject IDs for each index
-        if !indices.is_empty() {
-            let subject: Vec<&Subject> = self.data.subjects();
-            let zero_probability_subjects: Vec<&String> =
-                indices.iter().map(|&i| subject[i].id()).collect();
-
-            return Err(anyhow::anyhow!(
-                "The probability of one or more subjects, given the model, is zero. The following subjects have zero probability: {:?}",
-                zero_probability_subjects
-            ));
-        }
-
-        Ok(())
-    }
-}
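The expansion step's log-space branch is the same logsumexp identity followed by a final exp, so any underflow is deferred to the very last operation rather than occurring inside the sum. A sketch for a single subject (plain slices, hypothetical helper name):

```rust
/// Sketch: linear-space P(Y|L) for one subject from a log-space psi row.
/// The shift-by-max keeps the intermediate sum well scaled; only the
/// outermost exp can underflow, and only when the true value is tiny.
fn pyl_from_log_row(log_psi_row: &[f64], w: &[f64]) -> f64 {
    let terms: Vec<f64> = log_psi_row
        .iter()
        .zip(w)
        .map(|(lp, wj)| lp + wj.ln())
        .collect();
    let m = terms.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    (m + terms.iter().map(|t| (t - m).exp()).sum::<f64>().ln()).exp()
}
```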
diff --git a/src/algorithms/postprob.rs b/src/algorithms/postprob.rs
index 496b36e28..a909bfce7 100644
--- a/src/algorithms/postprob.rs
+++ b/src/algorithms/postprob.rs
@@ -14,7 +14,7 @@ use pharmsol::prelude::{
     simulator::Equation,
 };
 
-use crate::routines::estimation::ipm::burke;
+use crate::routines::estimation::ipm::burke_ipm;
 use crate::routines::initialization;
 use crate::routines::output::{cycles::CycleLog, NPResult};
 use crate::routines::settings::Settings;
@@ -126,8 +126,10 @@ impl Algorithms for POSTPROB {
             &self.error_models,
             false,
             false,
+            self.settings.advanced().space,
         )?;
-        (self.w, self.objf) = burke(&self.psi).context("Error in IPM")?;
+
+        (self.w, self.objf) = burke_ipm(&self.psi).context("Error in IPM")?;
         Ok(())
     }
diff --git a/src/bestdose/posterior.rs b/src/bestdose/posterior.rs
index 9109f65c1..87b9a35fc 100644
--- a/src/bestdose/posterior.rs
+++ b/src/bestdose/posterior.rs
@@ -53,12 +53,13 @@ use anyhow::Result;
 use faer::Mat;
 
-use crate::algorithms::npag::burke;
 use crate::algorithms::npag::NPAG;
 use crate::algorithms::Algorithms;
 use crate::algorithms::Status;
 use crate::prelude::*;
+use crate::routines::estimation::ipm::burke_ipm;
 use crate::structs::psi::calculate_psi;
+use crate::structs::psi::Space;
 use crate::structs::theta::Theta;
 use crate::structs::weights::Weights;
 use pharmsol::prelude::*;
@@ -95,14 +96,24 @@ pub fn npagfull11_filter(
     past_data: &Data,
     eq: &ODE,
     error_models: &ErrorModels,
+    space: Space,
 ) -> Result<(Theta, Weights, Weights)> {
     tracing::info!("Stage 1.1: NPAGFULL11 Bayesian filtering");
 
     // Calculate psi matrix P(data|theta_i) for all support points
-    let psi = calculate_psi(eq, past_data, population_theta, error_models, false, true)?;
+    // Use log-space or linear space depending on the setting
+    let psi = calculate_psi(
+        eq,
+        past_data,
+        population_theta,
+        error_models,
+        false,
+        true,
+        space,
+    )?;
 
     // First burke call to get initial posterior probabilities
-    let (initial_weights, _) = burke(&psi)?;
+    let (initial_weights, _) = burke_ipm(&psi)?;
 
     // NPAGFULL11 filtering: Keep all points within 1e-100 of the maximum weight
     // This is different from NPAG's condensation - NO QR decomposition here!
@@ -325,6 +336,7 @@ pub fn calculate_two_step_posterior(
         past_data,
         eq,
         error_models,
+        settings.advanced().space,
     )?;
 
     // Step 1.2: NPAGFULL refinement
diff --git a/src/routines/condensation/mod.rs b/src/routines/condensation/mod.rs
index d01533b6b..2ef44fe6b 100644
--- a/src/routines/condensation/mod.rs
+++ b/src/routines/condensation/mod.rs
@@ -1,4 +1,4 @@
-use crate::algorithms::npag::{burke, qr};
+use crate::routines::estimation::{ipm::burke_ipm, qr};
 use crate::structs::psi::Psi;
 use crate::structs::theta::Theta;
 use crate::structs::weights::Weights;
@@ -93,8 +93,8 @@ pub fn condense_support_points(
     filtered_theta.filter_indices(&keep_qr);
     filtered_psi.filter_column_indices(&keep_qr);
 
-    // Step 3: Recalculate weights with Burke's IPM
-    let (final_weights, objf) = burke(&filtered_psi)?;
+    // Step 3: Recalculate weights with Burke's IPM (auto-dispatches based on psi.space())
+    let (final_weights, objf) = burke_ipm(&filtered_psi)?;
 
     tracing::debug!(
         "Condensation complete: {} -> {} support points (objective: {:.4})",
diff --git a/src/routines/estimation/ipm.rs b/src/routines/estimation/ipm.rs
index fbb1768b2..d297babdb 100644
--- a/src/routines/estimation/ipm.rs
+++ b/src/routines/estimation/ipm.rs
@@ -278,6 +278,297 @@ pub fn burke(psi: &Psi) -> anyhow::Result<(Weights, f64)> {
     Ok((lam.into(), obj))
 }
 
+/// Applies Burke's Interior Point Method (IPM) operating in log space.
+///
+/// This version works with log-likelihoods directly, which provides better numerical
+/// stability when dealing with very small probabilities (many observations or extreme
+/// parameter values).
+///
+/// The objective function to maximize is:
+/// f(x) = Σ(log(Σ(exp(log_ψ_ij) * x_j))) for i = 1 to n_sub
+///      = Σ(logsumexp(log_ψ_ij + log(x_j)))
+///
+/// subject to:
+/// 1. x_j ≥ 0 for all j = 1 to n_point,
+/// 2. Σ(x_j) = 1.
+///
+/// # Arguments
+///
+/// * `log_psi` - A reference to a Psi structure containing log-likelihoods.
+///
+/// # Returns
+///
+/// On success, returns a tuple `(weights, obj)` where:
+/// - [Weights] contains the optimized weights (probabilities) for each support point.
+/// - `obj` is the value of the objective function at the solution.
+///
+/// # Errors
+///
+/// This function returns an error if any step in the optimization fails.
+pub fn burke_log(log_psi: &Psi) -> anyhow::Result<(Weights, f64)> {
+    let log_psi_mat = log_psi.matrix();
+
+    // Validate that all entries are finite
+    for row in log_psi_mat.row_iter() {
+        for &x in row.iter() {
+            if !x.is_finite() {
+                bail!("Input log-psi matrix must have finite entries");
+            }
+        }
+    }
+
+    let (n_sub, n_point) = log_psi_mat.shape();
+
+    if n_sub == 0 || n_point == 0 {
+        bail!("Input matrix cannot be empty");
+    }
+
+    // Convert log_psi to a linear-space psi for the IPM iterations.
+    // The IPM itself must work in linear space, but a naive exp() of the
+    // log-likelihoods could underflow to zero.
+    //
+    // Key insight: the IPM needs psi * lam, which in log space would be
+    // logsumexp(log_psi + log_lam); however, the internal IPM computations
+    // (Hessian, gradients) are much more complex to express in log space.
+    //
+    // Strategy: convert log_psi to psi using exp of a row-shifted version,
+    // keep track of the shift, and adjust the objective function at the end.
+
+    // Find the maximum log-likelihood per row to prevent underflow
+    let row_max: Vec<f64> = (0..n_sub)
+        .map(|i| {
+            (0..n_point)
+                .map(|j| *log_psi_mat.get(i, j))
+                .fold(f64::NEG_INFINITY, f64::max)
+        })
+        .collect();
+
+    // Create shifted psi matrix: psi_shifted[i,j] = exp(log_psi[i,j] - row_max[i])
+    // This ensures the maximum value in each row is 1.0, preventing underflow
+    let psi_shifted: Mat<f64> = Mat::from_fn(n_sub, n_point, |i, j| {
+        let log_val = *log_psi_mat.get(i, j);
+        (log_val - row_max[i]).exp()
+    });
+
+    // Now run the standard IPM on the shifted matrix
+    let ecol: Col<f64> = Col::from_fn(n_point, |_| 1.0);
+    let erow: Row<f64> = Row::from_fn(n_sub, |_| 1.0);
+
+    let mut plam: Col<f64> = &psi_shifted * &ecol;
+    let eps: f64 = 1e-8;
+    let mut sig: f64 = 0.0;
+
+    let mut lam = ecol.clone();
+
+    let mut w: Col<f64> = Col::from_fn(plam.nrows(), |i| 1.0 / plam.get(i));
+
+    let mut ptw: Col<f64> = psi_shifted.transpose() * &w;
+
+    let ptw_max = ptw.iter().fold(f64::NEG_INFINITY, |acc, &x| x.max(acc));
+    let shrink = 2.0 * ptw_max;
+    lam *= shrink;
+    plam *= shrink;
+    w /= shrink;
+    ptw /= shrink;
+
+    let mut y: Col<f64> = &ecol - &ptw;
+    let mut r: Col<f64> = Col::from_fn(n_sub, |i| erow.get(i) - w.get(i) * plam.get(i));
+    let mut norm_r: f64 = r.iter().fold(0.0, |max, &val| max.max(val.abs()));
+
+    let sum_log_plam: f64 = plam.iter().map(|x| x.ln()).sum();
+    let sum_log_w: f64 = w.iter().map(|x| x.ln()).sum();
+    let mut gap: f64 = (sum_log_w + sum_log_plam).abs() / (1.0 + sum_log_plam.abs());
+
+    let mut mu = lam.transpose() * &y / n_point as f64;
+
+    let mut psi_inner: Mat<f64> = Mat::zeros(n_sub, n_point);
+
+    let n_threads = faer::get_global_parallelism().degree();
+    let mut output: Vec<Mat<f64>> = (0..n_threads).map(|_| Mat::zeros(n_sub, n_sub)).collect();
+
+    let mut h: Mat<f64> = Mat::zeros(n_sub, n_sub);
+
+    while mu > eps || norm_r > eps || gap > eps {
+        let smu = sig * mu;
+        let inner = Col::from_fn(lam.nrows(), |i| lam.get(i) / y.get(i));
+        let w_plam = Col::from_fn(plam.nrows(), |i| plam.get(i) / w.get(i));
+
+        // Scale columns and compute the H matrix
+        if psi_shifted.ncols() > n_threads * 128 {
+            psi_inner
+                .par_col_partition_mut(n_threads)
+                .zip(psi_shifted.par_col_partition(n_threads))
+                .zip(inner.par_partition(n_threads))
+                .zip(output.par_iter_mut())
+                .for_each(|(((mut psi_inner, psi_part), inner_part), output)| {
+                    psi_inner
+                        .as_mut()
+                        .col_iter_mut()
+                        .zip(psi_part.col_iter())
+                        .zip(inner_part.iter())
+                        .for_each(|((col, psi_col), inner_val)| {
+                            col.iter_mut().zip(psi_col.iter()).for_each(|(x, psi_val)| {
+                                *x = psi_val * inner_val;
+                            });
+                        });
+                    faer::linalg::matmul::triangular::matmul(
+                        output.as_mut(),
+                        faer::linalg::matmul::triangular::BlockStructure::TriangularLower,
+                        faer::Accum::Replace,
+                        &psi_inner,
+                        faer::linalg::matmul::triangular::BlockStructure::Rectangular,
+                        psi_part.transpose(),
+                        faer::linalg::matmul::triangular::BlockStructure::Rectangular,
+                        1.0,
+                        faer::Par::Seq,
+                    );
+                });
+
+            let mut first_iter = true;
+            for out in &output {
+                if first_iter {
+                    h.copy_from(out);
+                    first_iter = false;
+                } else {
+                    h += out;
+                }
+            }
+        } else {
+            psi_inner
+                .as_mut()
+                .col_iter_mut()
+                .zip(psi_shifted.col_iter())
+                .zip(inner.iter())
+                .for_each(|((col, psi_col), inner_val)| {
+                    col.iter_mut().zip(psi_col.iter()).for_each(|(x, psi_val)| {
+                        *x = psi_val * inner_val;
+                    });
+                });
+            faer::linalg::matmul::triangular::matmul(
+                h.as_mut(),
+                faer::linalg::matmul::triangular::BlockStructure::TriangularLower,
+                faer::Accum::Replace,
+                &psi_inner,
+                faer::linalg::matmul::triangular::BlockStructure::Rectangular,
+                psi_shifted.transpose(),
+                faer::linalg::matmul::triangular::BlockStructure::Rectangular,
+                1.0,
+                faer::Par::Seq,
+            );
+        }
+
+        for i in 0..h.nrows() {
+            h[(i, i)] += w_plam[i];
+        }
+
+        let uph = match h.llt(faer::Side::Lower) {
+            Ok(llt) => llt,
+            Err(_) => {
+                bail!("Error during Cholesky decomposition in log-space IPM. The matrix might not be positive definite.")
+            }
+        };
+        let uph = uph.L().transpose().to_owned();
+
+        let smuyinv: Col<f64> = Col::from_fn(ecol.nrows(), |i| smu * (ecol[i] / y[i]));
+        let psi_dot_muyinv: Col<f64> = &psi_shifted * &smuyinv;
+        let rhsdw: Row<f64> = Row::from_fn(erow.ncols(), |i| erow[i] / w[i] - psi_dot_muyinv[i]);
+
+        let mut dw = Mat::from_fn(rhsdw.ncols(), 1, |i, _j| *rhsdw.get(i));
+
+        solve_lower_triangular_in_place(uph.transpose().as_ref(), dw.as_mut(), faer::Par::rayon(0));
+        solve_upper_triangular_in_place(uph.as_ref(), dw.as_mut(), faer::Par::rayon(0));
+
+        let dw = dw.col(0);
+        let dy = -(psi_shifted.transpose() * dw);
+
+        let inner_times_dy = Col::from_fn(ecol.nrows(), |i| inner[i] * dy[i]);
+        let dlam: Row<f64> =
+            Row::from_fn(ecol.nrows(), |i| smuyinv[i] - lam[i] - inner_times_dy[i]);
+
+        let ratio_dlam_lam = Row::from_fn(lam.nrows(), |i| dlam[i] / lam[i]);
+        let min_ratio_dlam = ratio_dlam_lam.iter().cloned().fold(f64::INFINITY, f64::min);
+        let mut alfpri: f64 = -1.0 / min_ratio_dlam.min(-0.5);
+        alfpri = (0.99995 * alfpri).min(1.0);
+
+        let ratio_dy_y = Row::from_fn(y.nrows(), |i| dy[i] / y[i]);
+        let min_ratio_dy = ratio_dy_y.iter().cloned().fold(f64::INFINITY, f64::min);
+        let ratio_dw_w = Row::from_fn(dw.nrows(), |i| dw[i] / w[i]);
+        let min_ratio_dw = ratio_dw_w.iter().cloned().fold(f64::INFINITY, f64::min);
+        let mut alfdual = -1.0 / min_ratio_dy.min(-0.5);
+        alfdual = alfdual.min(-1.0 / min_ratio_dw.min(-0.5));
+        alfdual = (0.99995 * alfdual).min(1.0);
+
+        lam += alfpri * dlam.transpose();
+        w += alfdual * dw;
+        y += alfdual * &dy;
+
+        mu = lam.transpose() * &y / n_point as f64;
+        plam = &psi_shifted * &lam;
+
+        r = Col::from_fn(n_sub, |i| erow.get(i) - w.get(i) * plam.get(i));
+        ptw -= alfdual * dy;
+
+        norm_r = r.norm_max();
+        let sum_log_plam: f64 = plam.iter().map(|x| x.ln()).sum();
+        let sum_log_w: f64 = w.iter().map(|x| x.ln()).sum();
+        gap = (sum_log_w + sum_log_plam).abs() / (1.0 + sum_log_plam.abs());
+
+        if mu < eps && norm_r > eps {
+            sig = 1.0;
+        } else {
+            let candidate1 = (1.0 - alfpri).powi(2);
+            let candidate2 = (1.0 - alfdual).powi(2);
+            let candidate3 = (norm_r - mu) / (norm_r + 100.0 * mu);
+            sig = candidate1.max(candidate2).max(candidate3).min(0.3);
+        }
+    }
+
+    // Scale lam
+    lam /= n_sub as f64;
+
+    // Compute the objective function value in log space:
+    // obj = sum_i(log(sum_j(psi[i,j] * lam[j])))
+    //     = sum_i(log(sum_j(exp(log_psi[i,j]) * lam[j])))
+    //     = sum_i(logsumexp(log_psi[i,j] + log(lam[j])))
+    //
+    // Since we worked with the shifted psi:
+    // obj = sum_i(row_max[i] + log(sum_j(psi_shifted[i,j] * lam[j])))
+    //     = sum_i(row_max[i]) + sum_i(log(plam[i]))
+    let plam_final = &psi_shifted * &lam;
+    let obj: f64 =
+        row_max.iter().sum::<f64>() + plam_final.iter().map(|x| x.ln()).sum::<f64>();
+
+    // Normalize lam to sum to 1
+    let lam_sum: f64 = lam.iter().sum();
+    lam = &lam / lam_sum;
+
+    Ok((lam.into(), obj))
+}
+
+/// Unified IPM dispatch that automatically chooses the correct algorithm based on the psi [Space].
+///
+/// This function checks the [Space] of `psi` and calls the appropriate Burke IPM:
+/// - `Space::Log`: calls `burke_log`
+/// - `Space::Linear`: calls `burke`
+///
+/// # Arguments
+///
+/// * `psi` - A reference to a Psi structure (either linear or log space)
+///
+/// # Returns
+///
+/// On success, returns a tuple `(weights, obj)` where:
+/// - [Weights] contains the optimized weights (probabilities) for each support point.
+/// - `obj` is the value of the objective function at the solution.
+///
+/// # Errors
+///
+/// Returns an error if the underlying IPM optimization fails.
+pub fn burke_ipm(psi: &Psi) -> anyhow::Result<(Weights, f64)> {
+    match psi.space() {
+        crate::structs::psi::Space::Linear => burke(psi),
+        crate::structs::psi::Space::Log => burke_log(psi),
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
@@ -514,4 +805,174 @@ mod tests {
         // The objective function should be finite
         assert!(obj.is_finite(), "Objective function should be finite");
     }
+
+    // ========== Log-space IPM tests ==========
+
+    #[test]
+    fn test_burke_log_identity() {
+        // Test with an identity-like matrix converted to log space:
+        // log(1) = 0 on the diagonal; true zeros would be -inf, so we use
+        // a very negative value instead
+        use ndarray::Array2;
+
+        let n = 10;
+        let log_mat = Array2::from_shape_fn((n, n), |(i, j)| {
+            if i == j {
+                0.0 // log(1) = 0
+            } else {
+                -30.0 // very small probability, exp(-30) ≈ 0
+            }
+        });
+
+        let mat = Mat::from_fn(n, n, |i, j| log_mat[(i, j)]);
+        let psi = Psi::new_log(mat);
+
+        let (lam, obj) = burke_log(&psi).unwrap();
+
+        // For an identity-like matrix, weights should be roughly equal
+        let expected = 1.0 / n as f64;
+        for i in 0..n {
+            assert_relative_eq!(lam[i], expected, epsilon = 1e-6);
+        }
+
+        // Check that lambda sums to 1
+        assert_relative_eq!(lam.iter().sum::<f64>(), 1.0, epsilon = 1e-10);
+
+        // Objective should be finite
+        assert!(obj.is_finite(), "Objective function should be finite");
+    }
+
+    #[test]
+    fn test_burke_log_uniform() {
+        // Test with a uniform matrix in log space: log(1) = 0 everywhere
+        use ndarray::Array2;
+
+        let n_sub = 10;
+        let n_point = 10;
+        let log_mat = Array2::from_shape_fn((n_sub, n_point), |_| 0.0); // log(1) = 0
+
+        let mat = Mat::from_fn(n_sub, n_point, |i, j| log_mat[(i, j)]);
+        let psi = Psi::new_log(mat);
+
+        let (lam, obj) = burke_log(&psi).unwrap();
+
+        // Check that lambda sums to 1
+        assert_relative_eq!(lam.iter().sum::<f64>(), 1.0, epsilon = 1e-10);
+
+        // For a uniform matrix, all weights should be equal
+        let expected = 1.0 / n_point as f64;
+        for i in 0..n_point {
+            assert_relative_eq!(lam[i], expected, epsilon = 1e-6);
+        }
+
+        // Objective should be finite
+        assert!(obj.is_finite(), "Objective function should be finite");
+    }
+
+    #[test]
+    fn test_burke_log_consistency_with_regular() {
+        // Test that burke_log produces the same results as burke
+        // when given equivalent inputs
+        use ndarray::Array2;
+
+        let n_sub = 5;
+        let n_point = 8;
+
+        // Create a linear-space psi matrix with positive values
+        let regular_mat = Array2::from_shape_fn((n_sub, n_point), |(i, j)| {
+            0.5 + 0.1 * (i as f64) + 0.05 * (j as f64)
+        });
+        let regular_psi = Psi::from(regular_mat.clone());
+
+        // Create the equivalent log-space matrix
+        let log_mat = regular_mat.mapv(|x| x.ln());
+
+        let mat = Mat::from_fn(n_sub, n_point, |i, j| log_mat[(i, j)]);
+        let psi = Psi::new_log(mat);
+
+        // Run both algorithms
+        let (lam_regular, obj_regular) = burke(&regular_psi).unwrap();
+        let (lam_log, obj_log) = burke_log(&psi).unwrap();
+
+        // The weights should be very similar
+        for i in 0..n_point {
+            assert_relative_eq!(lam_regular[i], lam_log[i], epsilon = 1e-6);
+        }
+
+        // The objective functions should be very similar
+        assert_relative_eq!(obj_regular, obj_log, epsilon = 1e-6);
+    }
+
+    #[test]
+    fn test_burke_log_handles_very_small_likelihoods() {
+        // Test that the log-space IPM handles very small likelihoods that would
+        // underflow in linear space
+        use ndarray::Array2;
+
+        let n_sub = 5;
+        let n_point = 5;
+
+        // Create log-likelihoods that would underflow if exponentiated directly
+        // These represent likelihoods of exp(-500) ≈ 10^(-217)
+        let log_mat = Array2::from_shape_fn((n_sub, n_point), |(i, j)| {
+            -500.0 + (i as f64) * 0.1 + (j as f64) * 0.05
+        });
+
+        let mat = Mat::from_fn(n_sub, n_point, |i, j| log_mat[(i, j)]);
+        let psi = Psi::new_log(mat);
+
+        // This should succeed without underflow issues
+        let result = burke_log(&psi);
+        assert!(
+            result.is_ok(),
+            "Log-space IPM should handle very small likelihoods"
+        );
+
+        let (lam, obj) = result.unwrap();
+
+        // Check basic properties
+        assert_relative_eq!(lam.iter().sum::<f64>(), 1.0, epsilon = 1e-10);
+        assert!(obj.is_finite(), "Objective function should be finite");
+
+        // All weights should be non-negative
+        for i in 0..n_point {
+            assert!(lam[i] >= 0.0, "Lambda values should be non-negative");
+        }
+    }
+
+    #[test]
+    fn test_burke_log_with_varying_magnitudes() {
+        // Test with log-likelihoods of varying magnitudes
+        use ndarray::Array2;
+
+        let n_sub = 8;
+        let n_point = 12;
+
+        // Create varying log-likelihoods
+        let log_mat = Array2::from_shape_fn((n_sub, n_point), |(i, j)| {
+            // Range from roughly -100 to -10, with column 0 having higher values (better fit)
+            if j == 0 {
+                -10.0 - (i as f64)
+            } else {
+                -50.0 - (i as f64) - (j as f64)
+            }
+        });
+
+        let mat = Mat::from_fn(n_sub, n_point, |i, j| log_mat[(i, j)]);
+        let psi = Psi::new_log(mat);
+
+        let (lam, obj) = burke_log(&psi).unwrap();
+
+        // Check basic properties
+        assert_relative_eq!(lam.iter().sum::<f64>(), 1.0, epsilon = 1e-10);
+        assert!(obj.is_finite(), "Objective function should be finite");
+
+        // The first support point should have a higher weight since it has
+        // higher log-likelihoods
+        assert!(
+            lam[0] > lam[1],
+            "First support point should have higher weight"
+        );
+    }
 }
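The row-max shift that burke_log relies on is justified by a scaling argument: multiplying row i of psi by a constant exp(c_i) leaves the optimal weights unchanged (each subject's objective term shifts by a constant) and adds Σ_i c_i to the objective, which is exactly the `row_max.iter().sum()` correction applied after the solve. A per-row sketch of that bookkeeping (not part of the patch):

```rust
/// Sketch: one subject's contribution to the objective, before and after the
/// row-max shift. `corrected` equals ln(Σ_j ψ_ij · λ_j) exactly.
fn row_objective_with_shift(log_row: &[f64], lam: &[f64]) -> (f64, f64) {
    let c = log_row.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let plam_shifted: f64 = log_row
        .iter()
        .zip(lam)
        .map(|(lp, l)| (lp - c).exp() * l)
        .sum();
    let shifted = plam_shifted.ln();
    let corrected = c + shifted; // ln(Σ exp(log_row_j) · λ_j)
    (shifted, corrected)
}
```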
diff --git a/src/routines/estimation/qr.rs b/src/routines/estimation/qr.rs
index acc104d26..bda5e21e0 100644
--- a/src/routines/estimation/qr.rs
+++ b/src/routines/estimation/qr.rs
@@ -3,30 +3,70 @@ use anyhow::{bail, Result};
 use faer::linalg::solvers::ColPivQr;
 use faer::Mat;
 
+/// Compute the log-sum-exp of a slice for numerical stability
+#[inline]
+fn logsumexp(values: &[f64]) -> f64 {
+    let max_val = values.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
+    if max_val.is_infinite() {
+        return max_val;
+    }
+    max_val
+        + values
+            .iter()
+            .map(|&x| (x - max_val).exp())
+            .sum::<f64>()
+            .ln()
+}
+
 /// Perform a QR decomposition on the Psi matrix
 ///
 /// Normalizes each row of the matrix to sum to 1 before decomposition.
+/// For log-space matrices, applies softmax normalization (converts log-probabilities
+/// to probabilities normalized per row), then proceeds with QR as usual.
 /// Returns the R matrix from QR decomposition and the column permutation vector.
 ///
 /// # Arguments
-/// * `psi` - The Psi matrix to decompose
+/// * `psi` - The Psi matrix to decompose (can be in linear or log space)
 ///
 /// # Returns
 /// * Tuple containing the R matrix (as [faer::Mat]) and permutation vector (as [Vec<usize>])
-/// * Error if any row in the matrix sums to zero
+/// * Error if any row in the matrix sums to zero (or is all -inf in log space)
 pub fn qrd(psi: &Psi) -> Result<(Mat<f64>, Vec<usize>)> {
     let mut mat = psi.matrix().to_owned();
 
-    // Normalize the rows to sum to 1
-    for (index, row) in mat.row_iter_mut().enumerate() {
-        let row_sum: f64 = row.as_ref().iter().sum();
+    match psi.space() {
+        crate::structs::psi::Space::Linear => {
+            // For linear space: normalize rows to sum to 1
+            for (index, row) in mat.row_iter_mut().enumerate() {
+                let row_sum: f64 = row.as_ref().iter().sum();
 
-        // Check if the row sum is zero
-        if row_sum.abs() == 0.0 {
-            bail!("In psi, the row with index {} sums to zero", index);
+                if row_sum.abs() == 0.0 {
+                    bail!("In psi, the row with index {} sums to zero", index);
+                }
+                row.iter_mut().for_each(|x| *x /= row_sum);
+            }
         }
-        row.iter_mut().for_each(|x| *x /= row_sum);
-    }
+        crate::structs::psi::Space::Log => {
+            // For log-space: apply softmax normalization
+            // softmax(x)_i = exp(x_i) / sum(exp(x_j)) = exp(x_i - logsumexp(x))
+            for index in 0..mat.nrows() {
+                let log_values: Vec<f64> = (0..mat.ncols()).map(|j| *mat.get(index, j)).collect();
+                let log_sum = logsumexp(&log_values);
+
+                if log_sum.is_infinite() && log_sum < 0.0 {
+                    bail!(
+                        "In log_psi, the row with index {} has all -inf values (zero probability)",
+                        index
+                    );
+                }
+
+                // Convert to normalized probabilities via softmax
+                for j in 0..mat.ncols() {
+                    *mat.get_mut(index, j) = (log_values[j] - log_sum).exp();
+                }
+            }
+        }
+    };
 
     // Perform column pivoted QR decomposition
     let qr: ColPivQr<f64> = mat.col_piv_qr();
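The two qrd branches produce the same normalized rows: for a positive row p with log-row l = ln(p), p_j / Σ_k p_k = exp(l_j - logsumexp(l)). A standalone sketch of that equivalence (not part of the patch):

```rust
/// Linear-space row normalization: divide by the row sum.
fn normalize_linear(row: &[f64]) -> Vec<f64> {
    let s: f64 = row.iter().sum();
    row.iter().map(|x| x / s).collect()
}

/// Log-space row normalization via softmax; agrees with normalize_linear
/// on exp(log_row), but never forms the raw (possibly underflowing) sums.
fn normalize_log(log_row: &[f64]) -> Vec<f64> {
    let m = log_row.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let lse = m + log_row.iter().map(|x| (x - m).exp()).sum::<f64>().ln();
    log_row.iter().map(|x| (x - lse).exp()).collect()
}
```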
diff --git a/src/routines/math.rs b/src/routines/math.rs
new file mode 100644
index 000000000..f606f2a77
--- /dev/null
+++ b/src/routines/math.rs
@@ -0,0 +1,178 @@
+//! Mathematical utility functions for numerical stability
+//!
+//! This module provides stable implementations of common numerical operations.
+
+/// Compute the log-sum-exp of a slice of values in a numerically stable way.
+///
+/// The log-sum-exp is defined as: `log(sum(exp(x_i)))` for all elements `x_i`.
+///
+/// This implementation uses the "shift by max" trick to avoid overflow:
+/// `logsumexp(x) = max(x) + log(sum(exp(x_i - max(x))))`
+///
+/// # Arguments
+/// * `values` - A slice of f64 values (typically log-likelihoods)
+///
+/// # Returns
+/// The log-sum-exp of the values. Returns `f64::NEG_INFINITY` if all values are `-inf`.
+///
+/// # Example
+/// ```ignore
+/// let log_probs = vec![-1.0, -2.0, -3.0];
+/// let result = logsumexp(&log_probs);
+/// // result ≈ log(exp(-1) + exp(-2) + exp(-3)) ≈ -0.5924
+/// ```
+#[inline]
+pub fn logsumexp(values: &[f64]) -> f64 {
+    if values.is_empty() {
+        return f64::NEG_INFINITY;
+    }
+
+    let max_val = values.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
+
+    if max_val.is_infinite() && max_val.is_sign_negative() {
+        // All values are -inf, return -inf
+        f64::NEG_INFINITY
+    } else if max_val.is_infinite() && max_val.is_sign_positive() {
+        // At least one value is +inf
+        f64::INFINITY
+    } else {
+        max_val
+            + values
+                .iter()
+                .map(|&x| (x - max_val).exp())
+                .sum::<f64>()
+                .ln()
+    }
+}
+
+/// Compute the weighted log-sum-exp: `logsumexp(log_values + log_weights)`.
+///
+/// This computes `log(sum(values_i * weights_i))` when `log_values` contains log-likelihoods.
+/// Equivalent to `log(sum(exp(log_values_i) * weights_i))`.
+///
+/// # Arguments
+/// * `log_values` - A slice of log-values (e.g., log-likelihoods)
+/// * `log_weights` - A slice of log-weights (must have the same length as `log_values`)
+///
+/// # Returns
+/// The weighted log-sum-exp. Panics if the slices have different lengths.
+#[inline]
+pub fn logsumexp_weighted(log_values: &[f64], log_weights: &[f64]) -> f64 {
+    assert_eq!(
+        log_values.len(),
+        log_weights.len(),
+        "log_values and log_weights must have the same length"
+    );
+
+    let combined: Vec<f64> = log_values
+        .iter()
+        .zip(log_weights.iter())
+        .map(|(&lv, &lw)| lv + lw)
+        .collect();
+
+    logsumexp(&combined)
+}
+
+/// Compute the log-sum-exp for each row of a matrix represented as a closure.
+///
+/// # Arguments
+/// * `nrows` - Number of rows
+/// * `ncols` - Number of columns
+/// * `get_value` - Closure that returns the value at (row, col)
+///
+/// # Returns
+/// A vector of log-sum-exp values, one per row.
+pub fn logsumexp_rows<F>(nrows: usize, ncols: usize, get_value: F) -> Vec<f64>
+where
+    F: Fn(usize, usize) -> f64,
+{
+    (0..nrows)
+        .map(|i| {
+            let row: Vec<f64> = (0..ncols).map(|j| get_value(i, j)).collect();
+            logsumexp(&row)
+        })
+        .collect()
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_logsumexp_basic() {
+        let values = vec![-1.0, -2.0, -3.0];
+        let result = logsumexp(&values);
+        // log(exp(-1) + exp(-2) + exp(-3)) ≈ -0.5924
+        let expected = ((-1.0_f64).exp() + (-2.0_f64).exp() + (-3.0_f64).exp()).ln();
+        assert!((result - expected).abs() < 1e-10);
+    }
+
+    #[test]
+    fn test_logsumexp_single_value() {
+        let values = vec![-5.0];
+        let result = logsumexp(&values);
+        assert!((result - (-5.0)).abs() < 1e-10);
+    }
+
+    #[test]
+    fn test_logsumexp_empty() {
+        let values: Vec<f64> = vec![];
+        let result = logsumexp(&values);
+        assert!(result.is_infinite() && result.is_sign_negative());
+    }
+
+    #[test]
+    fn test_logsumexp_all_neg_inf() {
+        let values = vec![f64::NEG_INFINITY, f64::NEG_INFINITY];
+        let result = logsumexp(&values);
+        assert!(result.is_infinite() && result.is_sign_negative());
+    }
+
+    #[test]
+    fn test_logsumexp_with_neg_inf() {
+        // logsumexp([-inf, 0]) = log(0 + 1) = 0
+        let values = vec![f64::NEG_INFINITY, 0.0];
+        let result = logsumexp(&values);
+        assert!((result - 0.0).abs() < 1e-10);
+    }
+
+    #[test]
+    fn test_logsumexp_large_values() {
+        // Test numerical stability with large values
+        let values = vec![1000.0, 1001.0, 1002.0];
+        let result = logsumexp(&values);
+        // Should be close to 1002 + log(exp(-2) + exp(-1) + 1) ≈ 1002.41
+        let expected = 1002.0 + ((-2.0_f64).exp() + (-1.0_f64).exp() + 1.0).ln();
+        assert!((result - expected).abs() < 1e-10);
+    }
+
+    #[test]
+    fn test_logsumexp_very_negative() {
+        // Test with very negative values that would underflow a naive implementation
+        let values = vec![-1000.0, -1001.0, -1002.0];
+        let result = logsumexp(&values);
+        let expected = -1000.0 + (1.0 + (-1.0_f64).exp() + (-2.0_f64).exp()).ln();
+        assert!((result - expected).abs() < 1e-10);
+    }
+
+    #[test]
+    fn test_logsumexp_weighted() {
+        let log_values = vec![-1.0, -2.0];
+        let log_weights = vec![0.0, 0.0]; // weights = 1
+        let result = logsumexp_weighted(&log_values, &log_weights);
+        let expected = logsumexp(&log_values);
+        assert!((result - expected).abs() < 1e-10);
+    }
+
+    #[test]
+    fn test_logsumexp_rows() {
+        let matrix = vec![vec![-1.0, -2.0], vec![-3.0, -4.0]];
+        let result = logsumexp_rows(2, 2, |i, j| matrix[i][j]);
+
+        let expected_0 = logsumexp(&[-1.0, -2.0]);
+        let expected_1 = logsumexp(&[-3.0, -4.0]);
+
+        assert!((result[0] - expected_0).abs() < 1e-10);
+        assert!((result[1] - expected_1).abs() < 1e-10);
+    }
+}
diff --git a/src/routines/mod.rs b/src/routines/mod.rs
index af25d67e6..ac626d2c9 100644
--- a/src/routines/mod.rs
+++ b/src/routines/mod.rs
@@ -8,6 +8,8 @@ pub mod expansion;
 pub mod initialization;
 // Routines for logging
 pub mod logger;
+// Mathematical utilities
+pub mod math;
 // Routines for output
 pub mod output;
 // Routines for settings
diff --git a/src/routines/output/mod.rs b/src/routines/output/mod.rs
index 2a9e43d69..533137b20 100644
--- a/src/routines/output/mod.rs
+++ b/src/routines/output/mod.rs
@@ -22,8 +22,6 @@ pub mod cycles;
 pub mod posterior;
 pub mod predictions;
 
-use posterior::posterior;
-
 /// Defines the result objects from an NPAG run
 /// An [NPResult] contains the necessary information to generate predictions and summary statistics
 #[derive(Debug, Serialize)]
@@ -61,7 +59,7 @@ impl NPResult {
         cyclelog: CycleLog,
     ) -> Result<Self> {
         // Calculate the posterior probabilities
-        let posterior = posterior(&psi, &w)
+        let posterior = Posterior::calculate(&psi, &w)
             .context("Failed to calculate posterior during initialization of NPResult")?;
 
         let result = Self {
diff --git a/src/routines/output/posterior.rs b/src/routines/output/posterior.rs
index 008ce16c1..3204acc9c 100644
--- a/src/routines/output/posterior.rs
+++ b/src/routines/output/posterior.rs
@@ -2,6 +2,7 @@ pub use anyhow::{bail, Result};
 use faer::Mat;
 use serde::{Deserialize, Serialize};
 
+use crate::routines::math::logsumexp;
 use crate::structs::{psi::Psi, weights::Weights};
 
 /// Posterior probabilities for each support points
@@ -37,10 +38,37 @@ impl Posterior {
         }
 
         let psi_matrix = psi.matrix();
-        let py = psi_matrix * w.weights();
-
+        let is_log_space = match psi.space() {
+            crate::structs::psi::Space::Linear => false,
+            crate::structs::psi::Space::Log => true,
+        };
+
+        // Calculate py[i] = sum_j(psi[i,j] * w[j]) for each subject i
+        // In log-space: py[i] = logsumexp_j(log_psi[i,j] + log(w[j]))
+        let py: Vec<f64> = if is_log_space {
+            let log_w: Vec<f64> = (0..w.len()).map(|j| w.weights().get(j).ln()).collect();
+            (0..psi_matrix.nrows())
+                .map(|i| {
+                    let combined: Vec<f64> = (0..psi_matrix.ncols())
+                        .map(|j| *psi_matrix.get(i, j) + log_w[j])
+                        .collect();
+                    logsumexp(&combined)
+                })
+                .collect()
+        } else {
+            let py_mat = psi_matrix * w.weights();
+            (0..py_mat.nrows()).map(|i| *py_mat.get(i)).collect()
+        };
+
+        // Calculate posterior[i,j] = psi[i,j] * w[j] / py[i]
+        // In log-space: posterior[i,j] = exp(log_psi[i,j] + log(w[j]) - log_py[i])
         let posterior = Mat::from_fn(psi_matrix.nrows(), psi_matrix.ncols(), |i, j| {
-            psi_matrix.get(i, j) * w.weights().get(j) / py.get(i)
+            if is_log_space {
+                let log_w_j = w.weights().get(j).ln();
+                (*psi_matrix.get(i, j) + log_w_j - py[i]).exp()
+            } else {
+                psi_matrix.get(i, j) * w.weights().get(j) / py[i]
+            }
         });
 
         Ok(posterior.into())
@@ -180,25 +208,3 @@ impl<'de> Deserialize<'de> for Posterior {
         deserializer.deserialize_seq(PosteriorVisitor)
     }
 }
-
-/// Calculates the posterior probabilities for each support point given the weights
-///
-/// The shape is the same as [Psi], and thus subjects are the rows and support points are the columns.
-pub fn posterior(psi: &Psi, w: &Weights) -> Result<Posterior> {
-    if psi.matrix().ncols() != w.len() {
-        bail!(
-            "Number of rows in psi ({}) and number of weights ({}) do not match.",
-            psi.matrix().nrows(),
-            w.len()
-        );
-    }
-
-    let psi_matrix = psi.matrix();
-    let py = psi_matrix * w.weights();
-
-    let posterior = Mat::from_fn(psi_matrix.nrows(), psi_matrix.ncols(), |i, j| {
-        psi_matrix.get(i, j) * w.weights().get(j) / py.get(i)
-    });
-
-    Ok(posterior.into())
-}
diff --git a/src/routines/settings.rs b/src/routines/settings.rs
index 3471a683f..46dcb1515 100644
--- a/src/routines/settings.rs
+++ b/src/routines/settings.rs
@@ -1,6 +1,7 @@
 use crate::algorithms::Algorithm;
 use crate::routines::initialization::Prior;
 use crate::routines::output::OutputFile;
+use crate::structs::psi::Space;
 use anyhow::{bail, Result};
 use pharmsol::prelude::data::ErrorModels;
 
@@ -270,6 +271,13 @@ pub struct Advanced {
     ///
     /// This is used in the [NPOD](crate::algorithms::npod) algorithm, specifically in the [D-optimizer](crate::routines::optimization::d_optimizer)
     pub tolerance: f64,
+    /// The computational [Space] for likelihoods, for improved numerical stability
+    ///
+    /// When set to [Space::Log], likelihoods are computed and stored in log space throughout
+    /// the algorithm. This prevents underflow issues when dealing with many observations or
+    /// extreme parameter values; the log-sum-exp trick maintains numerical stability in
+    /// weighted sum operations. The default is [Space::Log] for its better numerical properties.
+    pub space: Space,
 }
 
 impl Default for Advanced {
@@ -278,6 +286,7 @@ impl Default for Advanced {
             min_distance: 1e-4,
             nm_steps: 100,
             tolerance: 1e-6,
+            space: Space::Log,
         }
     }
}
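Configuring the space is a one-field change on the public `Advanced` struct introduced above. A usage sketch, assuming only what the patch shows (public fields and the `Default` impl):

```rust
use crate::routines::settings::Advanced;
use crate::structs::psi::Space;

// The default now selects log-space likelihoods:
let advanced = Advanced::default();
assert_eq!(advanced.space, Space::Log);

// Opting back into linear-space likelihoods:
let linear = Advanced {
    space: Space::Linear,
    ..Advanced::default()
};
```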
diff --git a/src/structs/psi.rs b/src/structs/psi.rs
index c63756cae..bbc986cd6 100644
--- a/src/structs/psi.rs
+++ b/src/structs/psi.rs
@@ -4,7 +4,7 @@ use faer::Mat;
 use faer_ext::IntoFaer;
 use faer_ext::IntoNdarray;
 use ndarray::{Array2, ArrayView2};
-use pharmsol::prelude::simulator::psi;
+use pharmsol::prelude::simulator::{log_psi, psi};
 use pharmsol::Data;
 use pharmsol::Equation;
 use pharmsol::ErrorModels;
@@ -12,21 +12,95 @@ use serde::{Deserialize, Serialize};
 
 use super::theta::Theta;
 
+/// Enum to represent whether the [Psi] matrix is in linear space or log space
+#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
+pub enum Space {
+    /// Linear space (for regular likelihoods)
+    Linear,
+    /// Log space (for log-likelihoods)
+    Log,
+}
+
 /// [Psi] is a structure that holds the likelihood for each subject (row), for each support point (column)
+///
+/// The matrix can store either regular likelihoods or log-likelihoods depending on how it was
+/// constructed. Use [Psi::space] to determine which representation is stored.
 #[derive(Debug, Clone, PartialEq)]
 pub struct Psi {
     matrix: Mat<f64>,
+    space: Space,
 }
 
 impl Psi {
     pub fn new() -> Self {
-        Psi { matrix: Mat::new() }
+        Psi {
+            matrix: Mat::new(),
+            space: Space::Linear,
+        }
+    }
+
+    /// Create a new Psi in linear space
+    pub fn new_linear(mat: Mat<f64>) -> Self {
+        Psi {
+            matrix: mat,
+            space: Space::Linear,
+        }
+    }
+
+    /// Create a new Psi in log space
+    pub fn new_log(mat: Mat<f64>) -> Self {
+        Psi {
+            matrix: mat,
+            space: Space::Log,
+        }
     }
 
     pub fn matrix(&self) -> &Mat<f64> {
         &self.matrix
     }
 
+    /// Get the [Space] (Linear or Log) of the Psi matrix
+    pub fn space(&self) -> Space {
+        self.space
+    }
+
+    /// Set the [Space] (Linear or Log) of the Psi matrix
+    ///
+    /// Note: This does not update the actual matrix values, only the space flag.
+    pub fn set_space(&mut self, space: Space) {
+        self.space = space;
+    }
+
+    /// Convert the Psi matrix to the specified [Space] (Linear or Log).
+    /// This modifies the matrix values accordingly.
+    pub fn to_space(&mut self, space: Space) -> &mut Self {
+        match (space, self.space) {
+            (Space::Linear, Space::Log) => {
+                // Convert from log to linear
+                for col in self.matrix.col_iter_mut() {
+                    col.iter_mut().for_each(|val| {
+                        *val = val.exp();
+                    });
+                }
+            }
+            (Space::Log, Space::Linear) => {
+                // Convert from linear to log
+                for col in self.matrix.col_iter_mut() {
+                    col.iter_mut().for_each(|val| {
+                        *val = val.ln();
+                    });
+                }
+            }
+            _ => {
+                // No conversion needed
+            }
+        }
+
+        self.space = space;
+        self
+    }
+
     pub fn nspp(&self) -> usize {
         self.matrix.nrows()
     }
@@ -101,7 +175,10 @@ impl Psi {
         // Create matrix from rows
         let mat = Mat::from_fn(nrows, ncols, |i, j| rows[i][j]);
 
-        Ok(Psi { matrix: mat })
+        Ok(Psi {
+            matrix: mat,
+            space: Space::Linear,
+        })
     }
 }
 
@@ -114,27 +191,39 @@ impl Default for Psi {
 impl From<Array2<f64>> for Psi {
     fn from(array: Array2<f64>) -> Self {
         let matrix = array.view().into_faer().to_owned();
-        Psi { matrix }
+        Psi {
+            matrix,
+            space: Space::Linear,
+        }
     }
 }
 
 impl From<Mat<f64>> for Psi {
     fn from(matrix: Mat<f64>) -> Self {
-        Psi { matrix }
+        Psi {
+            matrix,
+            space: Space::Linear,
+        }
     }
 }
 
 impl From<ArrayView2<'_, f64>> for Psi {
     fn from(array_view: ArrayView2<'_, f64>) -> Self {
         let matrix = array_view.into_faer().to_owned();
-        Psi { matrix }
+        Psi {
+            matrix,
+            space: Space::Linear,
+        }
     }
 }
 
 impl From<&Array2<f64>> for Psi {
     fn from(array: &Array2<f64>) -> Self {
         let matrix = array.view().into_faer().to_owned();
-        Psi { matrix }
+        Psi {
+            matrix,
+            space: Space::Linear,
+        }
     }
 }
@@ -160,7 +249,7 @@ impl Serialize for Psi {
 }
 
 impl<'de> Deserialize<'de> for Psi {
-    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
+    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
     where
         D: serde::Deserializer<'de>,
     {
@@ -208,7 +297,9 @@ impl<'de> Deserialize<'de> for Psi {
         // Create matrix from rows
         let mat = Mat::from_fn(nrows, ncols, |i, j| rows[i][j]);
 
-        Ok(Psi { matrix: mat })
+        let psi = Psi::new_linear(mat);
+
+        Ok(psi)
     }
 }
 
@@ -216,6 +307,7 @@ impl<'de> Deserialize<'de> for Psi {
     }
 }
 
+/// Calculate the likelihood matrix in the requested [Space]
 pub(crate) fn calculate_psi(
     equation: &impl Equation,
     subjects: &Data,
@@ -223,17 +315,33 @@ pub(crate) fn calculate_psi(
     error_models: &ErrorModels,
     progress: bool,
     cache: bool,
+    space: Space,
 ) -> Result<Psi> {
-    let psi_ndarray = psi(
-        equation,
-        subjects,
-        &theta.matrix().clone().as_ref().into_ndarray().to_owned(),
-        error_models,
-        progress,
-        cache,
-    )?;
-
-    Ok(psi_ndarray.view().into())
+    let psi_mat = match space {
+        Space::Linear => psi(
+            equation,
+            subjects,
+            &theta.matrix().clone().as_ref().into_ndarray().to_owned(),
+            error_models,
+            progress,
+            cache,
+        ),
+        Space::Log => log_psi(
+            equation,
+            subjects,
+            &theta.matrix().clone().as_ref().into_ndarray().to_owned(),
+            error_models,
+            progress,
+            cache,
+        ),
+    }?;
+
+    let psi = Psi {
+        matrix: psi_mat.view().into_faer().to_owned(),
+        space,
+    };
+
+    Ok(psi)
 }
 
 #[cfg(test)]
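Finally, a round-trip sketch of the new Psi space API, using only the constructors and conversion defined in the patch:

```rust
use faer::Mat;

// Build a small linear-space Psi, convert to log space and back.
let mat = Mat::from_fn(2, 3, |_, _| 0.5_f64);
let mut psi = Psi::new_linear(mat);
assert_eq!(psi.space(), Space::Linear);

psi.to_space(Space::Log); // entries become ln(0.5)
assert_eq!(psi.space(), Space::Log);

psi.to_space(Space::Linear); // entries return to 0.5 (up to rounding)
assert!((psi.matrix().get(0, 0) - 0.5).abs() < 1e-12);
```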