From bbb4334847dc6ed4de4a33ad78364a36eb541bb8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Juli=C3=A1n=20D=2E=20Ot=C3=A1lvaro?= Date: Thu, 2 Apr 2026 10:22:10 +0100 Subject: [PATCH 1/2] feat: add parametric algorithms to unified platform --- examples/bimodal_ke_saem/main.rs | 297 +++++ examples/bimodal_ke_saem/run_saemix.R | 126 ++ src/algorithms/mod.rs | 19 +- src/algorithms/parametric/algorithm.rs | 455 +++++++ src/algorithms/parametric/focei.rs | 477 +++++++ src/algorithms/parametric/mod.rs | 16 + src/algorithms/parametric/saem.rs | 1106 +++++++++++++++++ src/api/estimation_problem.rs | 31 + src/api/fit.rs | 11 +- src/api/mod.rs | 7 +- src/api/saem_config.rs | 124 ++ src/compile/validation.rs | 12 +- src/estimation/mod.rs | 1 + src/estimation/nonparametric/engine.rs | 8 +- src/estimation/parametric/assembler.rs | 295 +++++ src/estimation/parametric/compiler.rs | 170 +++ src/estimation/parametric/effects.rs | 300 +++++ src/estimation/parametric/engine.rs | 37 + src/estimation/parametric/individual.rs | 294 +++++ src/estimation/parametric/integration.rs | 10 + .../integration/importance_sampling.rs | 284 +++++ src/estimation/parametric/likelihood.rs | 376 ++++++ src/estimation/parametric/mod.rs | 73 ++ src/estimation/parametric/population.rs | 441 +++++++ src/estimation/parametric/posthoc.rs | 157 +++ src/estimation/parametric/predictions.rs | 307 +++++ src/estimation/parametric/reporting.rs | 168 +++ src/estimation/parametric/sampling.rs | 7 + src/estimation/parametric/sampling/kernels.rs | 388 ++++++ src/estimation/parametric/state.rs | 356 ++++++ src/estimation/parametric/statistics.rs | 297 +++++ src/estimation/parametric/sufficient_stats.rs | 368 ++++++ src/estimation/parametric/summaries.rs | 64 + src/estimation/parametric/transforms.rs | 435 +++++++ src/estimation/parametric/uncertainty.rs | 240 ++++ src/estimation/parametric/workspace.rs | 426 +++++++ src/lib.rs | 43 +- src/output/mod.rs | 1 + src/output/parametric.rs | 58 + src/output/writer.rs | 47 +- src/results/artifacts.rs | 9 + src/results/diagnostics.rs | 108 ++ src/results/fit_result.rs | 23 +- src/results/mod.rs | 6 +- src/results/predictions.rs | 28 +- tests/acceptance_baseline_tests.rs | 185 ++- tests/api_smoke_tests.rs | 548 +++++++- tests/output_writer_tests.rs | 92 +- tests/parametric_compiler_tests.rs | 210 ++++ tests/parametric_workspace_tests.rs | 105 ++ tests/results_summary_tests.rs | 144 ++- tests/saem_tests.rs | 711 +++++++++++ .../saem_validation/component_reference.json | 30 + tests/saem_validation/generate_reference.R | 509 ++++++++ tests/saem_validation/mod.rs | 32 + tests/saem_validation/onecomp_iv_data.csv | 121 ++ .../saem_validation/onecomp_iv_reference.json | 97 ++ tests/saem_validation/reference.rs | 192 +++ tests/saem_validation/reference_saemix.R | 111 ++ tests/saem_validation/tests.rs | 845 +++++++++++++ tests/saem_validation/theo_data.csv | 121 ++ tests/saem_validation/theo_reference.json | 70 ++ tests/saem_validation_tests.rs | 7 + 63 files changed, 12598 insertions(+), 38 deletions(-) create mode 100644 examples/bimodal_ke_saem/main.rs create mode 100644 examples/bimodal_ke_saem/run_saemix.R create mode 100644 src/algorithms/parametric/algorithm.rs create mode 100644 src/algorithms/parametric/focei.rs create mode 100644 src/algorithms/parametric/mod.rs create mode 100644 src/algorithms/parametric/saem.rs create mode 100644 src/api/saem_config.rs create mode 100644 src/estimation/parametric/assembler.rs create mode 100644 src/estimation/parametric/compiler.rs create mode 100644 
src/estimation/parametric/effects.rs create mode 100644 src/estimation/parametric/engine.rs create mode 100644 src/estimation/parametric/individual.rs create mode 100644 src/estimation/parametric/integration.rs create mode 100644 src/estimation/parametric/integration/importance_sampling.rs create mode 100644 src/estimation/parametric/likelihood.rs create mode 100644 src/estimation/parametric/mod.rs create mode 100644 src/estimation/parametric/population.rs create mode 100644 src/estimation/parametric/posthoc.rs create mode 100644 src/estimation/parametric/predictions.rs create mode 100644 src/estimation/parametric/reporting.rs create mode 100644 src/estimation/parametric/sampling.rs create mode 100644 src/estimation/parametric/sampling/kernels.rs create mode 100644 src/estimation/parametric/state.rs create mode 100644 src/estimation/parametric/statistics.rs create mode 100644 src/estimation/parametric/sufficient_stats.rs create mode 100644 src/estimation/parametric/summaries.rs create mode 100644 src/estimation/parametric/transforms.rs create mode 100644 src/estimation/parametric/uncertainty.rs create mode 100644 src/estimation/parametric/workspace.rs create mode 100644 src/output/parametric.rs create mode 100644 tests/parametric_compiler_tests.rs create mode 100644 tests/parametric_workspace_tests.rs create mode 100644 tests/saem_tests.rs create mode 100644 tests/saem_validation/component_reference.json create mode 100644 tests/saem_validation/generate_reference.R create mode 100644 tests/saem_validation/mod.rs create mode 100644 tests/saem_validation/onecomp_iv_data.csv create mode 100644 tests/saem_validation/onecomp_iv_reference.json create mode 100644 tests/saem_validation/reference.rs create mode 100644 tests/saem_validation/reference_saemix.R create mode 100644 tests/saem_validation/tests.rs create mode 100644 tests/saem_validation/theo_data.csv create mode 100644 tests/saem_validation/theo_reference.json create mode 100644 tests/saem_validation_tests.rs diff --git a/examples/bimodal_ke_saem/main.rs b/examples/bimodal_ke_saem/main.rs new file mode 100644 index 000000000..3b660b673 --- /dev/null +++ b/examples/bimodal_ke_saem/main.rs @@ -0,0 +1,297 @@ +//! Run SAEM on the bimodal_ke dataset +//! +//! This example demonstrates using the SAEM algorithm for a simple +//! one-compartment model with elimination rate constant (ke) and volume (v). +//! +//! Run with: cargo run --example bimodal_ke_saem --release + +use anyhow::Result; +use pharmsol::{ResidualErrorModel, ResidualErrorModels}; +use pmcore::prelude::*; + +/// Create analytical one-compartment model (much faster than ODE) +fn create_equation() -> equation::Analytical { + equation::Analytical::new( + |x, p, t, rateiv, _cov| { + let mut xout = x.clone(); + fetch_params!(p, ke, _v); + xout[0] = x[0] * (-ke * t).exp() + rateiv[1] / ke * (1.0 - (-ke * t).exp()); + xout + }, + |_p, _t, _cov| {}, + |_p, _t, _cov| lag! {}, + |_p, _t, _cov| fa! 
{}, + |_p, _t, _cov, _x| {}, + |x, p, _t, _cov, y| { + fetch_params!(p, _ke, v); + y[1] = x[0] / v; + }, + ) +} + +fn main() -> Result<()> { + // Load data + let data = data::read_pmetrics("examples/bimodal_ke/bimodal_ke.csv")?; + println!("Loaded {} subjects", data.len()); + + // Create model + let eq = create_equation(); + + // Parameter ranges + // NPAG found: ke mean=0.191 (range 0.01-0.98), v mean=107 (range 67-209) + // SAEM needs reasonable starting bounds since it initializes at midpoint + // Use ranges that center near the expected values + let observations = ObservationSpec::new() + .add_channel(ObservationChannel::continuous(1, "cp")) + .with_residual_error_models( + ResidualErrorModels::new().add(1, ResidualErrorModel::proportional(0.1)), + ); + + let model = ModelDefinition::builder(eq) + .parameters( + ParameterSpace::new() + .add(ParameterSpec::bounded("ke", 0.01, 0.5)) + .add(ParameterSpec::bounded("v", 50.0, 180.0)), + ) + .observations(observations) + .build()?; + + println!("Running SAEM algorithm..."); + + let mut fit_result = EstimationProblem::builder(model, data) + .method(EstimationMethod::Parametric(ParametricMethod::Saem( + SaemOptions::default(), + ))) + .output(OutputPlan { + write: true, + path: Some("examples/bimodal_ke_saem/output/".to_string()), + }) + .run()?; + + // Write all output files + fit_result.write_outputs()?; + println!("\nOutput files written to: examples/bimodal_ke_saem/output/"); + + // Print comprehensive results summary (matching R saemix format) + let result = fit_result + .as_parametric() + .expect("SAEM example should produce a parametric result"); + print_saem_report(result); + + Ok(()) +} + +/// Print a comprehensive SAEM report matching R saemix output format +fn print_saem_report(result: &pmcore::prelude::ParametricWorkspace) { + let n_subjects = result.data().len(); + // Count observations from all occasions + let n_obs: usize = result + .data() + .subjects() + .iter() + .flat_map(|s| s.occasions()) + .flat_map(|o| o.events()) + .filter(|e| matches!(e, pharmsol::Event::Observation(_))) + .count(); + let param_names = result.population().param_names(); + let n_params = param_names.len(); + let mu = result.mu(); + let omega = result.omega(); + + println!("\n{}", "=".repeat(60)); + println!("{:^60}", "SAEM Algorithm Results"); + println!("{}", "=".repeat(60)); + + // Dataset characteristics + println!("\n{}", "-".repeat(60)); + println!("{:^60}", "Data"); + println!("{}", "-".repeat(60)); + println!(" Number of subjects: {}", n_subjects); + println!(" Number of observations: {}", n_obs); + println!( + " Average obs/subject: {:.1}", + n_obs as f64 / n_subjects as f64 + ); + + // Algorithm info + println!("\n{}", "-".repeat(60)); + println!("{:^60}", "Algorithm"); + println!("{}", "-".repeat(60)); + println!(" Iterations completed: {}", result.iterations()); + println!(" Status: {:?}", result.status()); + + // Fixed effects (population means) + println!("\n{}", "-".repeat(60)); + println!("{:^60}", "Fixed Effects (Population Means)"); + println!("{}", "-".repeat(60)); + println!(" {:12} {:>12} {:>12}", "Parameter", "Estimate", ""); + println!(" {:12} {:>12} {:>12}", "---------", "--------", ""); + for (i, name) in param_names.iter().enumerate() { + println!(" {:12} {:>12.4}", name, mu[i]); + } + + // Variance of random effects + println!("\n{}", "-".repeat(60)); + println!("{:^60}", "Variance of Random Effects"); + println!("{}", "-".repeat(60)); + println!(" {:12} {:>12}", "Parameter", "Estimate"); + println!(" {:12} {:>12}", 
"---------", "--------"); + for (i, name) in param_names.iter().enumerate() { + let var = omega[(i, i)]; + println!(" omega2.{:<4} {:>12.4}", name, var); + } + + // Covariances (if any non-zero off-diagonal) + let mut has_covariances = false; + for i in 0..n_params { + for j in (i + 1)..n_params { + if omega[(i, j)].abs() > 1e-10 { + if !has_covariances { + println!("\n Covariances:"); + has_covariances = true; + } + println!( + " cov.{}.{:<6} {:>12.4}", + param_names[i], + param_names[j], + omega[(i, j)] + ); + } + } + } + + // Correlation matrix + println!("\n{}", "-".repeat(60)); + println!("{:^60}", "Correlation Matrix of Random Effects"); + println!("{}", "-".repeat(60)); + + // Header + print!(" {:12}", ""); + for name in ¶m_names { + print!(" {:>10}", name); + } + println!(); + + // Matrix rows + for i in 0..n_params { + print!(" {:12}", param_names[i]); + for j in 0..n_params { + let sd_i = omega[(i, i)].sqrt(); + let sd_j = omega[(j, j)].sqrt(); + let corr = if sd_i > 0.0 && sd_j > 0.0 { + omega[(i, j)] / (sd_i * sd_j) + } else if i == j { + 1.0 + } else { + 0.0 + }; + print!(" {:>10.4}", corr); + } + println!(); + } + + // Residual error + println!("\n{}", "-".repeat(60)); + println!("{:^60}", "Residual Error"); + println!("{}", "-".repeat(60)); + println!(" σ estimates: {:?}", result.sigma().as_vec()); + + // Statistical criteria + println!("\n{}", "-".repeat(60)); + println!("{:^60}", "Statistical Criteria"); + println!("{}", "-".repeat(60)); + + let _ll = -result.objf() / 2.0; // Convert -2LL to LL + let n_fixed = n_params; + let n_random = n_params; // diagonal omega + let n_resid = 1; // residual error + let n_total_params = n_fixed + n_random + n_resid; + + let aic = result.objf() + 2.0 * n_total_params as f64; + let bic = result.objf() + (n_total_params as f64) * (n_subjects as f64).ln(); + + println!(" -2LL = {:.4}", result.objf()); + println!(" AIC = {:.4}", aic); + println!(" BIC = {:.4}", bic); + + // Individual parameters (first 10 subjects) on the canonical ψ-space result surface. 
+    println!("\n{}", "-".repeat(60));
+    println!("{:^60}", "Individual Parameters (first 10 subjects)");
+    println!("{}", "-".repeat(60));
+
+    // Header
+    print!(" {:>4}", "ID");
+    for name in &param_names {
+        print!(" {:>12}", name);
+    }
+    println!();
+
+    // Get individual estimates
+    let individuals = result.individual_estimates();
+    let show_count = std::cmp::min(10, individuals.nsubjects());
+
+    for i in 0..show_count {
+        if let Some(ind) = individuals.get(i) {
+            print!(" {:>4}", i + 1);
+            for j in 0..n_params {
+                print!(" {:>12.6}", ind.psi()[j]);
+            }
+            println!();
+        }
+    }
+
+    // Summary statistics for each parameter
+    println!("\n{}", "-".repeat(60));
+    println!("{:^60}", "Summary of Individual Estimates");
+    println!("{}", "-".repeat(60));
+
+    for (p, name) in param_names.iter().enumerate() {
+        let mut values: Vec<f64> = Vec::new();
+        for i in 0..individuals.nsubjects() {
+            if let Some(ind) = individuals.get(i) {
+                values.push(ind.psi()[p]);
+            }
+        }
+
+        if !values.is_empty() {
+            values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
+            let n = values.len();
+            let min = values[0];
+            let max = values[n - 1];
+            let mean: f64 = values.iter().sum::<f64>() / n as f64;
+            let median = if n % 2 == 0 {
+                (values[n / 2 - 1] + values[n / 2]) / 2.0
+            } else {
+                values[n / 2]
+            };
+            let q1 = values[n / 4];
+            let q3 = values[3 * n / 4];
+            let variance: f64 = values.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / n as f64;
+            let sd = variance.sqrt();
+
+            println!("\n --- {} ---", name);
+            println!(" Min: {:>10.5}", min);
+            println!(" 1st Qu: {:>10.5}", q1);
+            println!(" Median: {:>10.5}", median);
+            println!(" Mean: {:>10.5}", mean);
+            println!(" 3rd Qu: {:>10.5}", q3);
+            println!(" Max: {:>10.5}", max);
+            println!(" SD: {:>10.5}", sd);
+        }
+    }
+
+    // Derived statistics
+    println!("\n{}", "-".repeat(60));
+    println!("{:^60}", "Population Parameter Summary");
+    println!("{}", "-".repeat(60));
+    println!(" {:12} {:>10} {:>10}", "Parameter", "SD(ω)", "CV(%)");
+    println!(" {:12} {:>10} {:>10}", "---------", "------", "-----");
+    let cvs = result.cv_percent();
+    for (i, name) in param_names.iter().enumerate() {
+        let omega2 = omega[(i, i)];
+        let sd = omega2.sqrt();
+        println!(" {:12} {:>10.4} {:>10.1}", name, sd, cvs[i]);
+    }
+
+    println!("\n{}", "=".repeat(60));
+}
diff --git a/examples/bimodal_ke_saem/run_saemix.R b/examples/bimodal_ke_saem/run_saemix.R
new file mode 100644
index 000000000..6db68f9d8
--- /dev/null
+++ b/examples/bimodal_ke_saem/run_saemix.R
@@ -0,0 +1,126 @@
+# Run bimodal_ke dataset with R saemix
+# This script tests the R SAEM implementation on the bimodal_ke dataset
+# to compare with the Rust implementation
+
+library(saemix)
+
+# Read and prepare the data
+raw_data <- read.csv("../bimodal_ke/bimodal_ke.csv")
+
+# Filter observation records (EVID == 0) and get necessary columns
+obs_data <- raw_data[raw_data$EVID == 0, ]
+
+# saemix needs: ID, TIME, dose, observation
+# Get doses for each subject
+dose_data <- raw_data[raw_data$EVID == 1, c("ID", "DOSE")]
+names(dose_data) <- c("ID", "DOSE")
+
+# Merge dose with observations
+saemix_data <- merge(obs_data[, c("ID", "TIME", "OUT")], dose_data, by = "ID")
+saemix_data <- saemix_data[order(saemix_data$ID, saemix_data$TIME), ]
+names(saemix_data) <- c("id", "time", "conc", "dose")
+
+# Ensure numeric types
+saemix_data$id <- as.integer(saemix_data$id)
+saemix_data$time <- as.numeric(saemix_data$time)
+saemix_data$conc <- as.numeric(saemix_data$conc)
+saemix_data$dose <- as.numeric(saemix_data$dose)
+
+cat("Data summary:\n")
+cat("Number of subjects:", length(unique(saemix_data$id)), "\n") +cat("Total observations:", nrow(saemix_data), "\n") +cat("\nData types:\n") +print(sapply(saemix_data, class)) +cat("\nFirst 20 rows:\n") +print(head(saemix_data, 20)) + +# Create saemix data object +saemix.data <- saemixData( + name.data = saemix_data, + header = TRUE, + name.group = c("id"), + name.predictors = c("dose", "time"), + name.response = c("conc") +) + +# Define the one-compartment IV bolus model +# For IV bolus: C = (Dose/V) * exp(-ke * t) +# Note: This dataset uses a 0.5h infusion, but we'll approximate as bolus +# The model: C = (Dose/V) * exp(-ke * t) +one_cpt_model <- function(psi, id, xidep) { + dose <- xidep[, 1] + time <- xidep[, 2] + ke <- psi[id, 1] + V <- psi[id, 2] + + # One compartment with IV bolus + ypred <- (dose / V) * exp(-ke * time) + return(ypred) +} + +# Create saemix model +# psi0: initial estimates [ke, V] +# transform.par: 1 = log transform (lognormal distribution) +# NPAG found: ke mean=0.191, v mean=107 +saemix.model <- saemixModel( + model = one_cpt_model, + modeltype = "structural", + description = "One-compartment IV bolus model", + psi0 = matrix(c(0.2, 110), + ncol = 2, byrow = TRUE, + dimnames = list(NULL, c("ke", "V")) + ), + transform.par = c(1, 1), # 1 = log-transform (lognormal) + covariance.model = matrix(c(1, 0, 0, 1), ncol = 2, byrow = TRUE), + omega.init = matrix(c(0.5, 0, 0, 0.5), ncol = 2, byrow = TRUE), + error.model = "proportional" # gamma * ypred +) + +# Run SAEM +cat("\n\nRunning SAEM algorithm...\n\n") +saemix.fit <- saemix( + saemix.model, + saemix.data, + list( + seed = 12345, + nbiter.saemix = c(300, 100), # K1 burn-in, K2 estimation + nb.chains = 3, + print = TRUE, + save = FALSE, + save.graphs = FALSE + ) +) + +# Print results +cat("\n\n========== SAEM Results ==========\n") +print(saemix.fit) + +# Get population parameters +cat("\n--- Population Parameters ---\n") +cat("Fixed effects (mu):\n") +print(saemix.fit@results@fixed.effects) +cat("\nOmega (variance of random effects):\n") +print(saemix.fit@results@omega) +cat("\nResidual error:\n") +print(saemix.fit@results@respar) + +# Get individual parameters +cat("\n--- Individual Parameters (first 10 subjects) ---\n") +indiv_params <- psi(saemix.fit) +print(head(indiv_params, 10)) + +# Summary statistics of individual ke +cat("\n--- Summary of individual ke estimates ---\n") +print(summary(indiv_params$ke)) + +# Check for bimodality +cat("\n--- Distribution of ke ---\n") +cat("Mean:", mean(indiv_params$ke), "\n") +cat("Median:", median(indiv_params$ke), "\n") +cat("SD:", sd(indiv_params$ke), "\n") +cat("Min:", min(indiv_params$ke), "\n") +cat("Max:", max(indiv_params$ke), "\n") + +# Objective function +cat("\n--- Objective Function ---\n") +cat("-2LL:", -2 * saemix.fit@results@ll.is, "\n") diff --git a/src/algorithms/mod.rs b/src/algorithms/mod.rs index 74aff54af..512ef36d5 100644 --- a/src/algorithms/mod.rs +++ b/src/algorithms/mod.rs @@ -18,6 +18,7 @@ use serde::{Deserialize, Serialize}; // Module organization for algorithm types pub mod nonparametric; +pub mod parametric; #[derive(Debug, Clone)] pub(crate) struct NonparametricAlgorithmInput { @@ -105,26 +106,38 @@ impl NonparametricAlgorithmInput { /// Algorithm type enumeration /// -/// This enum represents the algorithms available in the structure branch. +/// This enum represents all available algorithms in PMcore, both non-parametric and parametric. 
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq)] pub enum Algorithm { + // === Non-parametric algorithms === /// Non-Parametric Adaptive Grid NPAG, /// Non-Parametric Optimal Design NPOD, /// Posterior Probability calculation POSTPROB, + + // === Parametric algorithms === + /// Stochastic Approximation Expectation-Maximization + SAEM, + /// First-Order Conditional Estimation with Interaction + FOCEI, + /// Iterative Two-Stage Bayesian + IT2B, } impl Algorithm { /// Check if this is a non-parametric algorithm pub fn is_nonparametric(&self) -> bool { - matches!(self, Algorithm::NPAG | Algorithm::NPOD | Algorithm::POSTPROB) + matches!( + self, + Algorithm::NPAG | Algorithm::NPOD | Algorithm::POSTPROB + ) } /// Check if this is a parametric algorithm pub fn is_parametric(&self) -> bool { - false + matches!(self, Algorithm::SAEM | Algorithm::FOCEI | Algorithm::IT2B) } } diff --git a/src/algorithms/parametric/algorithm.rs b/src/algorithms/parametric/algorithm.rs new file mode 100644 index 000000000..cb089a65f --- /dev/null +++ b/src/algorithms/parametric/algorithm.rs @@ -0,0 +1,455 @@ +//! Parametric algorithm trait definition +//! +//! This module defines the [`ParametricAlgorithm`] trait that all parametric +//! population estimation algorithms must implement. + +use std::fs; +use std::path::Path; + +use anyhow::{Context, Result}; +use pharmsol::{Data, Equation, ResidualErrorModels}; + +use crate::api::{ + EstimationMethod, EstimationProblem, OutputPlan, ParametricMethod, RuntimeOptions, SaemConfig, +}; +use crate::compile::CompiledProblem; +use crate::estimation::parametric::{ + build_parametric_covariate_context, IndividualEstimates, + ParameterTransform as AlgorithmParameterTransform, ParametricCovariateContext, + ParametricWorkspace, Population, SufficientStats, +}; +use crate::model::{ParameterDomain, ParameterSpace}; +use crate::output::shared::RunConfiguration; + +use super::super::Status; + +#[derive(Debug, Clone)] +pub(crate) struct ParametricAlgorithmInput { + pub method: ParametricMethod, + pub equation: E, + pub data: Data, + pub parameter_space: ParameterSpace, + pub covariate_context: ParametricCovariateContext, + pub residual_error_models: ResidualErrorModels, + pub output: OutputPlan, + pub runtime: RuntimeOptions, +} + +impl ParametricAlgorithmInput { + pub(crate) fn from_compiled_problem(problem: CompiledProblem) -> Result { + let method = match problem.method() { + EstimationMethod::Parametric(method) => method, + other => anyhow::bail!( + "parametric dispatcher received non-parametric method: {:?}", + other + ), + }; + + let output = problem.output_plan().clone(); + let runtime = problem.runtime_options().clone(); + let covariate_context = build_parametric_covariate_context( + &problem.model.covariates, + &problem.design.structured_covariates, + ); + let (model, data) = problem.into_parts(); + + let residual_error_models = model + .observations + .residual_error_models + .clone() + .ok_or_else(|| { + anyhow::anyhow!("parametric algorithms require residual_error_models") + })?; + + Ok(Self { + method, + equation: model.equation, + data, + parameter_space: model.parameters, + covariate_context, + residual_error_models, + output, + runtime, + }) + } + + pub(crate) fn algorithm(&self) -> crate::algorithms::Algorithm { + self.method.algorithm() + } + + pub(crate) fn run_configuration(&self) -> RunConfiguration { + RunConfiguration::new( + self.algorithm(), + &self.output, + &self.runtime, + self.parameter_space + .iter() + .map(|parameter| parameter.name.clone()) 
+ .collect(), + ) + } + + pub(crate) fn saem_config(&self) -> &SaemConfig { + &self.runtime.tuning.saem + } + + pub(crate) fn initial_population(&self) -> Result { + Population::from_parameter_space(self.parameter_space.clone()) + } + + pub(crate) fn parameter_transforms(&self) -> Vec { + self.parameter_space + .iter() + .map(to_parameter_transform) + .collect() + } +} + +/// Configuration specific to parametric algorithms +#[derive(Debug, Clone)] +pub struct ParametricConfig { + /// Maximum number of iterations + pub max_iterations: usize, + /// Number of burn-in iterations (for SAEM) + pub burn_in: usize, + /// Number of MCMC chains per subject + pub n_chains: usize, + /// Number of samples per chain per iteration + pub n_samples: usize, + /// Convergence tolerance for parameters + pub parameter_tolerance: f64, + /// Convergence tolerance for objective function + pub objective_tolerance: f64, + /// Whether to use simulated annealing in SAEM + pub use_annealing: bool, + /// Initial temperature for simulated annealing + pub initial_temperature: f64, +} + +impl Default for ParametricConfig { + fn default() -> Self { + Self { + max_iterations: 500, + burn_in: 200, + n_chains: 1, + n_samples: 1, + parameter_tolerance: 1e-4, + objective_tolerance: 1e-4, + use_annealing: false, + initial_temperature: 1.0, + } + } +} + +/// Trait defining the interface for parametric population algorithms +/// +/// This trait provides the common structure for algorithms that estimate +/// population parameters assuming a continuous (typically multivariate normal) +/// distribution for the random effects. +/// +/// # Algorithm Workflow +/// +/// 1. **Initialize**: Set up initial population parameters +/// 2. **E-step**: Compute or sample from conditional distribution p(η|y,θ) +/// 3. **M-step**: Update population parameters from E-step results +/// 4. **Evaluate**: Check convergence criteria +/// 5. **Repeat** until convergence or max iterations +/// +/// # Type Parameters +/// +/// * `E` - The equation type implementing pharmacokinetic/pharmacodynamic model +pub trait ParametricAlgorithm: Sync + Send { + /// Get the equation/model + fn equation(&self) -> &E; + + /// Get the data + fn data(&self) -> &Data; + + /// Get the current population parameters + fn population(&self) -> &Population; + + /// Get a mutable reference to population parameters + fn population_mut(&mut self) -> &mut Population; + + /// Get the current individual estimates + fn individual_estimates(&self) -> &IndividualEstimates; + + /// Get the current iteration number + fn iteration(&self) -> usize; + + /// Increment the iteration counter and return new value + fn increment_iteration(&mut self) -> usize; + + /// Get the current objective function value (-2LL) + fn objective_function(&self) -> f64; + + /// Get the current algorithm status + fn status(&self) -> &Status; + + /// Set the algorithm status + fn set_status(&mut self, status: Status); + + // ========== Algorithm Steps ========== + + /// Initialize the algorithm + /// + /// Sets up initial population parameters, prepares data structures, + /// and performs any pre-processing required before the main loop. 
+ fn initialize(&mut self) -> Result<()> { + // Remove stop file if it exists + if Path::new("stop").exists() { + tracing::info!("Removing existing stop file prior to run"); + fs::remove_file("stop").context("Unable to remove previous stop file")?; + } + self.set_status(Status::Continue); + Ok(()) + } + + /// Perform the E-step (Expectation step) + /// + /// This step computes or samples from the conditional distribution of + /// individual parameters given the observations and current population parameters. + /// + /// - **FOCEI**: Finds the MAP estimate (mode) with a local curvature approximation + /// - **SAEM**: Samples from p(η|y,θ) using MCMC + fn e_step(&mut self) -> Result<()>; + + /// Perform the M-step (Maximization step) + /// + /// Updates the population parameters (μ, Ω) based on the E-step results. + /// + /// - **FOCEI**: Uses the subject modes and local curvature information + /// - **SAEM**: Uses sufficient statistics from MCMC samples + fn m_step(&mut self) -> Result<()>; + + /// Evaluate convergence and update status + /// + /// Checks various convergence criteria and determines whether to continue + /// or stop the algorithm. + fn evaluate(&mut self) -> Result; + + /// Log the current iteration state + fn log_iteration(&mut self); + + /// Perform a single iteration of the algorithm + /// + /// Default implementation calls E-step, M-step, logging, and evaluation. + fn next_iteration(&mut self) -> Result { + let iter = self.increment_iteration(); + + let span = tracing::info_span!("", "{}", format!("Iteration {}", iter)); + let _enter = span.enter(); + + self.e_step()?; + self.m_step()?; + self.log_iteration(); + self.evaluate() + } + + /// Run the full estimation procedure + /// + /// Initializes the algorithm and iterates until convergence or stopping criteria. + fn fit(&mut self) -> Result> { + self.initialize()?; + + loop { + match self.next_iteration()? { + Status::Continue => continue, + Status::Stop(_) => break, + } + } + + self.into_result() + } + + /// Convert the algorithm state into a result object + fn into_result(&self) -> Result>; + + // ========== Optional Methods ========== + + /// Get sufficient statistics (for SAEM-like algorithms) + fn sufficient_stats(&self) -> Option<&SufficientStats> { + None + } + + /// Perform optimization of error model parameters + /// + /// Some algorithms may optimize error model parameters alongside population parameters. + fn optimize_error_model(&mut self) -> Result<()> { + // Default: no optimization + Ok(()) + } + + /// Apply constraints to population parameters + /// + /// Ensures parameters stay within bounds and covariance matrix remains positive definite. + fn apply_constraints(&mut self) -> Result<()> { + // Default: no additional constraints + Ok(()) + } +} + +/// Dispatch function for parametric algorithms +/// +/// Creates the appropriate algorithm instance based on settings. 
+pub fn dispatch_parametric_algorithm( + problem: EstimationProblem, +) -> Result>> { + let compiled = problem.compile()?; + let input = ParametricAlgorithmInput::from_compiled_problem(compiled)?; + + dispatch_parametric_algorithm_input(input) +} + +pub(crate) fn dispatch_parametric_algorithm_input( + input: ParametricAlgorithmInput, +) -> Result>> { + use super::focei::FoceiAlgorithm; + use super::saem::FSAEM; + + match input.method { + ParametricMethod::Saem(_) => { + let saem = FSAEM::create(input)?; + Ok(saem as Box>) + } + ParametricMethod::Focei(_) => { + let focei = FoceiAlgorithm::create(input)?; + Ok(focei as Box>) + } + ParametricMethod::It2b(_) => { + // TODO: Implement IT2B + anyhow::bail!("IT2B algorithm not yet implemented") + } + } +} + +pub(crate) fn run_parametric_algorithm( + input: ParametricAlgorithmInput, +) -> Result> { + let mut algorithm = dispatch_parametric_algorithm_input(input)?; + algorithm.fit() +} + +fn to_parameter_transform(parameter: &crate::model::ParameterSpec) -> AlgorithmParameterTransform { + match parameter.transform { + crate::model::ParameterTransform::Identity => AlgorithmParameterTransform::None, + crate::model::ParameterTransform::LogNormal => AlgorithmParameterTransform::LogNormal, + crate::model::ParameterTransform::Logit => { + let (lower, upper) = bounded_domain(parameter); + AlgorithmParameterTransform::Logit { lower, upper } + } + crate::model::ParameterTransform::Probit => { + let (lower, upper) = bounded_domain(parameter); + AlgorithmParameterTransform::Probit { lower, upper } + } + } +} + +fn bounded_domain(parameter: &crate::model::ParameterSpec) -> (f64, f64) { + match parameter.domain { + ParameterDomain::Bounded { lower, upper } => (lower, upper), + ParameterDomain::Positive { lower, upper } => (lower.unwrap_or(0.0), upper.unwrap_or(1.0)), + ParameterDomain::Unbounded { lower, upper } => (lower.unwrap_or(0.0), upper.unwrap_or(1.0)), + } +} + +#[cfg(test)] +mod tests { + use std::collections::HashMap; + + use super::ParametricAlgorithmInput; + use anyhow::Result; + use pharmsol::{AssayErrorModel, ErrorPoly, ResidualErrorModel, ResidualErrorModels, Subject}; + + use crate::api::{EstimationMethod, EstimationProblem, ParametricMethod, SaemOptions}; + use crate::model::{ + CovariateEffectsSpec, CovariateSpec, ModelDefinition, ObservationChannel, ObservationSpec, + ParameterSpace, ParameterSpec, + }; + use crate::prelude::*; + + fn equation() -> equation::ODE { + equation::ODE::new( + |x, p, _t, dx, b, _rateiv, _cov| { + fetch_params!(p, ke); + dx[0] = -ke * x[0] + b[0]; + }, + |_p, _t, _cov| lag! {}, + |_p, _t, _cov| fa! {}, + |_p, _t, _cov, _x| {}, + |x, p, _t, _cov, y| { + fetch_params!(p, v); + y[0] = x[0] / v; + }, + ) + } + + #[test] + fn compiled_parametric_input_preserves_structured_covariates() -> Result<()> { + let data = pharmsol::Data::new(vec![Subject::builder("1") + .covariate("wt", 0.0, 70.0) + .covariate("study_day", 0.0, 1.0) + .bolus(0.0, 100.0, 0) + .observation(1.0, 10.0, 0) + .reset() + .covariate("wt", 0.0, 70.0) + .covariate("study_day", 0.0, 2.0) + .bolus(0.0, 100.0, 0) + .observation(1.0, 8.0, 0) + .build()]); + let assay_error = AssayErrorModel::additive(ErrorPoly::new(0.0, 0.10, 0.0, 0.0), 2.0); + let residual_error = + ResidualErrorModels::new().add(0, ResidualErrorModel::combined(0.5, 0.1)); + let observations = ObservationSpec::new() + .add_channel(ObservationChannel::continuous(0, "cp")) + .with_assay_error_models(AssayErrorModels::new().add(0, assay_error)?) 
+ .with_residual_error_models(residual_error); + + let model = ModelDefinition::builder(equation()) + .parameters( + ParameterSpace::new() + .add(ParameterSpec::positive("ke")) + .add(ParameterSpec::positive("v")), + ) + .observations(observations) + .covariates(CovariateSpec::Structured(CovariateEffectsSpec { + subject_effects: Some(CovariateModel::new( + vec!["ke", "v"], + vec!["wt"], + vec![vec![true], vec![false]], + )?), + occasion_effects: Some(CovariateModel::new( + vec!["ke", "v"], + vec!["study_day"], + vec![vec![true], vec![false]], + )?), + })) + .build()?; + + let compiled = EstimationProblem::builder(model, data) + .method(EstimationMethod::Parametric(ParametricMethod::Saem( + SaemOptions, + ))) + .build()? + .compile()?; + + let input = ParametricAlgorithmInput::from_compiled_problem(compiled)?; + + assert!(input.covariate_context.subject_model.is_some()); + assert!(input.covariate_context.occasion_model.is_some()); + assert_eq!(input.covariate_context.subject_covariates.len(), 1); + assert_eq!(input.covariate_context.occasion_covariates.len(), 2); + + let expected = HashMap::from([(String::from("wt"), 70.0)]); + assert_eq!(input.covariate_context.subject_covariates[0], expected); + assert_eq!( + input.covariate_context.occasion_covariates[0], + HashMap::from([(String::from("study_day"), 1.0)]) + ); + assert_eq!( + input.covariate_context.occasion_covariates[1], + HashMap::from([(String::from("study_day"), 2.0)]) + ); + Ok(()) + } +} diff --git a/src/algorithms/parametric/focei.rs b/src/algorithms/parametric/focei.rs new file mode 100644 index 000000000..60254fd96 --- /dev/null +++ b/src/algorithms/parametric/focei.rs @@ -0,0 +1,477 @@ +//! FOCEI (First-Order Conditional Estimation with Interaction) algorithm. + +use std::collections::HashMap; + +use anyhow::Result; +use faer::linalg::solvers::DenseSolveCore; +use faer::{Col, Mat}; +use ndarray::Array2; +use pharmsol::{Data, Equation, ResidualErrorModels}; + +use crate::algorithms::{Status, StopReason}; +use crate::estimation::parametric::{ + assemble_parametric_result, batch_log_likelihood_from_eta, covariance_from_individual_etas, + covariance_from_subject_means, covariate_state, ensure_positive_definite_covariance, + estimate_beta, focei_linearization_uncertainty, recenter_individual_estimates, + residual_error_estimates_from_models, subject_mean_phi, subject_objective_from_eta, Individual, + IndividualEstimates, LikelihoodEstimates, ParameterTransform, ParametricCovariateContext, + ParametricIterationLog, ParametricResultInput, ParametricWorkspace, Population, +}; +use crate::model::CovariateModel; +use crate::output::shared::RunConfiguration; + +use super::algorithm::{ParametricAlgorithm, ParametricAlgorithmInput, ParametricConfig}; + +pub struct FoceiAlgorithm { + equation: E, + data: Data, + run_configuration: RunConfiguration, + population: Population, + individual_estimates: IndividualEstimates, + iteration: usize, + objf: f64, + prev_objf: f64, + status: Status, + config: ParametricConfig, + transforms: Vec, + residual_error_models: ResidualErrorModels, + iteration_log: ParametricIterationLog, + subject_covariate_model: Option, + subject_covariates: Vec>, + occasion_covariate_model: Option, + occasion_covariates: Vec>, +} + +impl FoceiAlgorithm { + pub(crate) fn create(input: ParametricAlgorithmInput) -> Result> { + let run_configuration = input.run_configuration(); + let population = input.initial_population()?; + let transforms = input.parameter_transforms(); + let residual_error_models = 
input.residual_error_models.clone(); + let config = ParametricConfig { + max_iterations: input.runtime.cycles.max(1), + objective_tolerance: input.runtime.convergence.likelihood, + ..ParametricConfig::default() + }; + let ParametricAlgorithmInput { + equation, + data, + covariate_context, + .. + } = input; + let ParametricCovariateContext { + subject_model: mut subject_covariate_model, + occasion_model: occasion_covariate_model, + subject_covariates, + occasion_covariates, + } = covariate_context; + + if let Some(model) = subject_covariate_model.as_mut() { + let initialize_intercepts = + (0..model.beta().nrows()).all(|index| model.beta()[index].abs() < 1e-12); + if initialize_intercepts { + let intercepts = (0..population.npar()) + .map(|index| population.mu()[index]) + .collect::>(); + model.set_intercepts(&intercepts)?; + } + } + + Ok(Box::new(Self { + equation, + data, + run_configuration, + population, + individual_estimates: IndividualEstimates::new(), + iteration: 0, + objf: f64::INFINITY, + prev_objf: f64::INFINITY, + status: Status::Continue, + config, + transforms, + residual_error_models, + iteration_log: ParametricIterationLog::new(), + subject_covariate_model, + subject_covariates, + occasion_covariate_model, + occasion_covariates, + })) + } + + fn current_eta_matrix(&self) -> Array2 { + let n_subjects = self.data.subjects().len(); + let n_params = self.population.npar(); + let mut eta_matrix = Array2::zeros((n_subjects, n_params)); + + for subject_index in 0..n_subjects { + if let Some(individual) = self.individual_estimates.get(subject_index) { + for param_index in 0..n_params { + eta_matrix[[subject_index, param_index]] = individual.eta()[param_index]; + } + } + } + + eta_matrix + } + + fn invert_omega(&self) -> Mat { + match self.population.omega().clone().llt(faer::Side::Lower) { + Ok(cholesky) => cholesky.inverse(), + Err(_) => { + let n_params = self.population.npar(); + Mat::from_fn(n_params, n_params, |row, col| { + if row == col { + 1.0 / self.population.omega()[(row, row)].max(1e-8) + } else { + 0.0 + } + }) + } + } + } + + fn find_map_estimate( + &self, + subject_index: usize, + eta_matrix: &mut Array2, + omega_inv: &Mat, + subject_means: &[Col], + ) -> Result { + let n_params = self.population.npar(); + let mut current_ll = batch_log_likelihood_from_eta( + &self.equation, + &self.data, + &self.residual_error_models, + &self.transforms, + eta_matrix, + subject_means, + )?[subject_index]; + let mut current_objf = + subject_objective_from_eta(subject_index, eta_matrix, current_ll, omega_inv); + let mut step_sizes = (0..n_params) + .map(|param_index| { + self.population.omega()[(param_index, param_index)] + .sqrt() + .max(1e-3) + * 0.5 + }) + .collect::>(); + + for _ in 0..4 { + let mut improved = false; + + for param_index in 0..n_params { + let baseline = eta_matrix[[subject_index, param_index]]; + let mut best_value = baseline; + let mut best_ll = current_ll; + let mut best_objf = current_objf; + + for direction in [-1.0, 1.0] { + let mut candidate = eta_matrix.clone(); + candidate[[subject_index, param_index]] = + baseline + direction * step_sizes[param_index]; + let proposed_ll = batch_log_likelihood_from_eta( + &self.equation, + &self.data, + &self.residual_error_models, + &self.transforms, + &candidate, + subject_means, + )?[subject_index]; + let proposed_objf = subject_objective_from_eta( + subject_index, + &candidate, + proposed_ll, + omega_inv, + ); + + if proposed_objf < best_objf { + best_value = candidate[[subject_index, param_index]]; + best_ll = 
proposed_ll; + best_objf = proposed_objf; + } + } + + if best_objf < current_objf { + eta_matrix[[subject_index, param_index]] = best_value; + current_ll = best_ll; + current_objf = best_objf; + improved = true; + } + } + + if improved { + continue; + } + + for step_size in &mut step_sizes { + *step_size *= 0.5; + } + + if step_sizes.iter().all(|step_size| *step_size < 1e-6) { + break; + } + } + + let eta = Col::from_fn(n_params, |param_index| { + eta_matrix[[subject_index, param_index]] + }); + let phi = Col::from_fn(n_params, |param_index| { + subject_means[subject_index][param_index] + eta[param_index] + }); + let mut individual = + Individual::new(self.data.subjects()[subject_index].id().clone(), eta, phi)?; + individual.set_objective_function(current_objf); + Ok(individual) + } + + fn update_population_parameters(&mut self) -> Result<()> { + let n_subjects = self.individual_estimates.nsubjects(); + let n_params = self.population.npar(); + + if n_subjects == 0 { + return Ok(()); + } + + if let Some(model) = self.subject_covariate_model.clone() { + let target_beta = + estimate_beta(&model, &self.subject_covariates, &self.individual_estimates)?; + let mut updated_model = model; + updated_model.set_beta(target_beta)?; + + let mu = Col::from_fn(n_params, |index| { + updated_model + .intercept(index) + .unwrap_or(self.population.mu()[index]) + }); + let subject_means = subject_mean_phi( + &mu, + n_subjects, + Some(&updated_model), + &self.subject_covariates, + ); + + self.population.update_mu(mu)?; + self.individual_estimates = + recenter_individual_estimates(&self.individual_estimates, &subject_means)?; + self.subject_covariate_model = Some(updated_model); + + let omega = covariance_from_subject_means(&self.individual_estimates, &subject_means)?; + self.population + .update_omega(ensure_positive_definite_covariance(&omega))?; + } else { + let mu = Col::from_fn(n_params, |param_index| { + self.individual_estimates + .iter() + .map(|individual| individual.psi()[param_index]) + .sum::() + / n_subjects as f64 + }); + + let subject_means = subject_mean_phi(&mu, n_subjects, None, &self.subject_covariates); + self.population.update_mu(mu)?; + self.individual_estimates = + recenter_individual_estimates(&self.individual_estimates, &subject_means)?; + + let omega = covariance_from_individual_etas(&self.individual_estimates); + self.population + .update_omega(ensure_positive_definite_covariance(&omega))?; + } + + let eta_matrix = self.current_eta_matrix(); + let subject_means = subject_mean_phi( + self.population.mu(), + self.data.subjects().len(), + self.subject_covariate_model.as_ref(), + &self.subject_covariates, + ); + let omega_inv = self.invert_omega(); + let log_likelihoods = batch_log_likelihood_from_eta( + &self.equation, + &self.data, + &self.residual_error_models, + &self.transforms, + &eta_matrix, + &subject_means, + )?; + let mut updated = Vec::with_capacity(n_subjects); + let mut total_objf = 0.0; + + for subject_index in 0..n_subjects { + let individual = self.individual_estimates.get(subject_index).unwrap(); + let objf = subject_objective_from_eta( + subject_index, + &eta_matrix, + log_likelihoods[subject_index], + &omega_inv, + ); + let mut rebuilt = Individual::new( + individual.subject_id().to_string(), + individual.eta().clone(), + individual.psi().clone(), + )?; + rebuilt.set_objective_function(objf); + total_objf += objf; + updated.push(rebuilt); + } + + self.individual_estimates = IndividualEstimates::from_vec(updated); + self.prev_objf = self.objf; + self.objf = total_objf; + 
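+        // The per-subject objectives returned by `subject_objective_from_eta` are summed
+        // into `self.objf`, which `objective_function()` reports as the -2LL approximation
+        // and `into_result()` converts to a log-likelihood (`-objf / 2.0`).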
+ Ok(()) + } +} + +impl ParametricAlgorithm for FoceiAlgorithm { + fn equation(&self) -> &E { + &self.equation + } + + fn data(&self) -> &Data { + &self.data + } + + fn population(&self) -> &Population { + &self.population + } + + fn population_mut(&mut self) -> &mut Population { + &mut self.population + } + + fn individual_estimates(&self) -> &IndividualEstimates { + &self.individual_estimates + } + + fn iteration(&self) -> usize { + self.iteration + } + + fn increment_iteration(&mut self) -> usize { + self.iteration += 1; + self.iteration + } + + fn objective_function(&self) -> f64 { + self.objf + } + + fn status(&self) -> &Status { + &self.status + } + + fn set_status(&mut self, status: Status) { + self.status = status; + } + + fn e_step(&mut self) -> Result<()> { + let subjects = self.data.subjects(); + let mut individuals = Vec::with_capacity(subjects.len()); + let mut eta_matrix = self.current_eta_matrix(); + let omega_inv = self.invert_omega(); + let subject_means = subject_mean_phi( + self.population.mu(), + self.data.subjects().len(), + self.subject_covariate_model.as_ref(), + &self.subject_covariates, + ); + + for (subject_index, _subject) in subjects.iter().enumerate() { + let individual = + self.find_map_estimate(subject_index, &mut eta_matrix, &omega_inv, &subject_means)?; + individuals.push(individual); + } + + self.individual_estimates = IndividualEstimates::from_vec(individuals); + Ok(()) + } + + fn m_step(&mut self) -> Result<()> { + self.update_population_parameters() + } + + fn evaluate(&mut self) -> Result { + if std::path::Path::new("stop").exists() { + self.status = Status::Stop(StopReason::Stopped); + return Ok(self.status.clone()); + } + + if self.iteration >= self.config.max_iterations { + self.status = Status::Stop(StopReason::MaxCycles); + return Ok(self.status.clone()); + } + + if self.prev_objf.is_finite() { + let objf_change = (self.objf - self.prev_objf).abs(); + let relative_change = objf_change / self.prev_objf.abs().max(1.0); + if relative_change < self.config.objective_tolerance { + self.status = Status::Stop(StopReason::Converged); + return Ok(self.status.clone()); + } + } + + if !self.objf.is_finite() { + self.status = Status::Stop(StopReason::Converged); + return Ok(self.status.clone()); + } + + self.status = Status::Continue; + Ok(self.status.clone()) + } + + fn log_iteration(&mut self) { + self.iteration_log + .log_iteration(self.iteration, self.objf, &self.population, &self.status); + tracing::info!( + "FOCEI iteration {}: -2LL = {:.4} (change: {:.4})", + self.iteration, + self.objf, + self.objf - self.prev_objf + ); + + tracing::debug!("Population mean (mu): {:?}", self.population.mu()); + tracing::debug!("Population SD: {:?}", self.population.standard_deviations()); + + let pop_var = faer::Col::from_fn(self.population.npar(), |index| { + self.population.omega()[(index, index)] + }); + if let Some(shrinkage) = self.individual_estimates.shrinkage(&pop_var) { + tracing::debug!("Shrinkage: {:?}", shrinkage); + } + } + + fn into_result(&self) -> Result> { + let likelihoods = LikelihoodEstimates { + ll_linearization: Some(-self.objf / 2.0), + ..LikelihoodEstimates::new() + }; + let uncertainty_estimates = focei_linearization_uncertainty( + &self.population, + self.individual_estimates.nsubjects(), + ); + let sigma = residual_error_estimates_from_models(&self.residual_error_models); + + assemble_parametric_result(ParametricResultInput { + equation: &self.equation, + data: &self.data, + population: &self.population, + individual_estimates: 
&self.individual_estimates, + objf: self.objf, + iterations: self.iteration, + status: &self.status, + run_configuration: self.run_configuration.clone(), + iteration_log: self.iteration_log.clone(), + likelihood_estimates: likelihoods, + uncertainty_estimates, + sigma, + transforms: &self.transforms, + covariates: Some(covariate_state( + self.subject_covariate_model.as_ref(), + &self.subject_covariates, + self.occasion_covariate_model.as_ref(), + &self.occasion_covariates, + )), + }) + } +} diff --git a/src/algorithms/parametric/mod.rs b/src/algorithms/parametric/mod.rs new file mode 100644 index 000000000..ef5a44e62 --- /dev/null +++ b/src/algorithms/parametric/mod.rs @@ -0,0 +1,16 @@ +//! Parametric algorithm implementations for the unified estimation platform. +//! +//! Supported today: +//! - SAEM +//! - FOCEI +//! +//! IT2B remains intentionally deferred. + +mod algorithm; +pub mod focei; +pub mod saem; + +pub(crate) use algorithm::run_parametric_algorithm; +pub(crate) use algorithm::ParametricAlgorithmInput; +pub use algorithm::{dispatch_parametric_algorithm, ParametricAlgorithm, ParametricConfig}; +pub use saem::{FSaemConfig, FSAEM}; diff --git a/src/algorithms/parametric/saem.rs b/src/algorithms/parametric/saem.rs new file mode 100644 index 000000000..e9bef3c69 --- /dev/null +++ b/src/algorithms/parametric/saem.rs @@ -0,0 +1,1106 @@ +//! f-SAEM (fast Stochastic Approximation Expectation-Maximization) Algorithm +//! +//! This module implements the f-SAEM algorithm for maximum likelihood estimation +//! in nonlinear mixed-effects models. +//! +//! # Algorithm Overview +//! +//! f-SAEM is an enhanced version of SAEM that uses four MCMC kernels for improved +//! mixing and faster convergence: +//! +//! 1. **Kernel 1 (Prior)**: Full multivariate proposals from N(0, Ω) +//! 2. **Kernel 2 (Component-wise)**: Single-component adaptive random walk +//! 3. **Kernel 3 (Block)**: Block random walk with varying block sizes +//! 4. **Kernel 4 (MAP-based)**: Proposals centered at MAP with Laplace covariance +//! +//! # Algorithm Phases +//! +//! ## Phase 1: Burn-in (K₁ iterations) +//! - Step size γₖ = 1 (full updates) +//! - Simulated annealing on variance (floor shrinking) +//! - All four MCMC kernels active +//! +//! ## Phase 2: Estimation (K₂ iterations) +//! - Decreasing step size γₖ = 1/(k - K₁) +//! - Sufficient statistics converge to true expectations +//! - MAP kernel may be disabled for efficiency +//! +//! # Mathematical Background +//! +//! SAEM replaces the intractable E-step with stochastic approximation: +//! +//! ```text +//! E-step: Draw φ⁽ᵏ⁾ ~ p(φ | y, θ⁽ᵏ⁻¹⁾) using MCMC +//! SA-step: sₖ = sₖ₋₁ + γₖ(S(φ⁽ᵏ⁾) - sₖ₋₁) +//! M-step: θ⁽ᵏ⁾ = argmax_θ Q(θ, sₖ) +//! ``` +//! +//! For normal random effects, sufficient statistics are: +//! - S₁ = Σᵢ φᵢ (sum of parameters) +//! - S₂ = Σᵢ φᵢφᵢᵀ (sum of outer products) +//! +//! And the M-step has closed-form solutions: +//! - μ = S₁ / n +//! - Ω = S₂ / n - μμᵀ +//! +//! # References +//! +//! - Kuhn & Lavielle (2005). "Maximum likelihood estimation in nonlinear +//! mixed effects models." Computational Statistics & Data Analysis. +//! - Comets et al. (2017). "Parameter estimation in nonlinear mixed effect +//! models using saemix." Journal of Statistical Software. +//! - Lavielle, M. (2015). "Mixed Effects Models for the Population Approach." +//! Chapman & Hall/CRC. 
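+//!
+//! # Step-Size Schedule
+//!
+//! As a quick reference, the step size γₖ used by the SA-step follows a
+//! three-phase schedule (a sketch of `current_step_size`; `n_pure_burn` and
+//! K₁ = `FSaemConfig::n_burn_in()` come from the runtime configuration):
+//!
+//! ```text
+//! γₖ = 0             k ≤ n_pure_burn         (pure burn-in: MCMC only, no statistic update)
+//! γₖ = 1             n_pure_burn < k ≤ K₁    (exploration / simulated annealing phase)
+//! γₖ = 1/(k - K₁)    k > K₁                  (stochastic approximation phase)
+//! ```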
+ +use anyhow::Result; +use faer::linalg::solvers::DenseSolveCore; +use faer::{Col, Mat}; +use ndarray::Array2; +use pharmsol::{Data, Equation, Event, ResidualErrorModels}; +use rand::SeedableRng; +use rand_chacha::ChaCha8Rng; +use std::collections::HashMap; + +use crate::algorithms::{Status, StopReason}; +use crate::api::SaemConfig; +use crate::estimation::parametric::{ + advance_saem_chains, blended_subject_covariate_m_step, covariate_state, + ensure_positive_definite_covariance, estimate_initial_sigma_sq, finalize_saem_result, + initialize_population_in_phi_space, phi_to_psi, recenter_individual_estimates, + refresh_saem_objective_history, residual_error_estimates_from_observed_outeqs, + sample_eta_from_population, subject_mean_phi, transform_label, + update_residual_error_from_individuals, ChainState, Individual, IndividualEstimates, + KernelConfig, ParameterTransform, ParametricCovariateContext, ParametricIterationLog, + ParametricResultInput, PhiVector, Population, SaemFinalizeInput, SaemMcmcState, + SufficientStats, UncertaintyEstimates, +}; +use crate::model::CovariateModel; +use crate::output::shared::RunConfiguration; + +use super::algorithm::{ParametricAlgorithm, ParametricAlgorithmInput, ParametricConfig}; + +/// f-SAEM algorithm configuration +/// +/// This structure can be constructed from [`SaemConfig`] to ensure +/// consistency with user-facing configuration. +#[derive(Debug, Clone)] +pub struct FSaemConfig { + /// Base parametric algorithm configuration + pub base: ParametricConfig, + /// MCMC kernel configuration + pub kernel_config: KernelConfig, + /// Number of MCMC chains per subject + pub n_chains: usize, + /// Random seed for reproducibility + pub seed: u64, + /// Number of pure burn-in iterations with γ=0 (nbiter.burn in R saemix) + /// During this phase, MCMC runs but statistics are not accumulated + /// R saemix default: 5 + pub n_pure_burn: usize, + /// Number of SA iterations with γ=1 (exploration phase after burn-in) + /// R saemix: this is K₁ - nbiter.burn, where K₁ = nbiter.saemix[1] + /// R saemix default: 300 - 5 = 295 + pub n_sa: usize, + /// Total number of iterations (K₁ + K₂) + /// R saemix default: 300 + 100 = 400 + pub n_iterations: usize, + /// Simulated annealing decay rate for variance floor (alpha.sa in R) + /// R saemix default: 0.97 + pub sa_alpha: f64, + /// Minimum variance for simulated annealing + pub sa_min_var: f64, + /// Number of SA iterations for variance floor decay (nbiter.sa in R) + /// R saemix default: K₁/2 = 150 + pub n_sa_variance: usize, +} + +impl FSaemConfig { + /// Get total burn-in iterations (pure burn + SA phase) + /// This is equivalent to nbiter.saemix[1] in R + pub fn n_burn_in(&self) -> usize { + self.n_pure_burn + self.n_sa + } +} + +impl Default for FSaemConfig { + /// Default configuration matching R saemix defaults exactly + /// + /// R saemix defaults: + /// - nbiter.saemix = c(300, 100) → K₁=300, K₂=100, total=400 + /// - nbiter.burn = 5 + /// - nbiter.sa = K₁/2 = 150 (for variance floor decay) + /// - alpha.sa = 0.97 + /// - nb.chains = 1 + /// - nbiter.mcmc = c(2, 2, 2, 0) + fn default() -> Self { + Self { + base: ParametricConfig::default(), + kernel_config: KernelConfig::default(), + n_chains: 1, + seed: 123456, // Match R saemix default seed + // R saemix: nbiter.burn = 5 (pure burn-in with γ=0) + n_pure_burn: 5, + // R saemix: K₁ = 300 (exploration phase with γ=1) + // n_sa = K₁ - n_pure_burn = 300 - 5 = 295 + n_sa: 295, + // R saemix: K₁ + K₂ = 300 + 100 = 400 + n_iterations: 400, + sa_alpha: 0.97, + 
sa_min_var: 1e-6, + // R saemix: nbiter.sa = K₁/2 = 150 (variance floor decay iterations) + n_sa_variance: 150, + } + } +} + +impl FSaemConfig { + /// Create internal configuration from the user-facing SAEM config. + pub fn from_saem_config(config: &SaemConfig) -> Self { + let k1 = config.k1_iterations; + let k2 = config.k2_iterations; + + Self { + base: ParametricConfig { + max_iterations: k1 + k2, + burn_in: config.burn_in, + n_chains: config.n_chains, + n_samples: config.mcmc_iterations, + parameter_tolerance: 1e-4, + objective_tolerance: 1e-4, + use_annealing: config.sa_iterations > 0, + initial_temperature: 1.0, + }, + kernel_config: KernelConfig { + n_kernel1: 2, + n_kernel2: 2, + n_kernel3: 2, + n_kernel4: if config.n_kernels >= 4 { 0 } else { 0 }, // Kernel 4 disabled by default + map_iterations: 0, + rw_step_size: config.mcmc_step_size, + target_acceptance: 0.4, + rw_init: 0.5, + }, + n_chains: config.n_chains, + seed: config.seed, + n_pure_burn: config.burn_in, + n_sa: k1.saturating_sub(config.burn_in), + n_iterations: k1 + k2, + sa_alpha: config.sa_cooling_factor, + sa_min_var: config.omega_min_variance, + // R saemix: nbiter.sa = K₁/2 for variance floor decay + n_sa_variance: if config.sa_iterations > 0 { + config.sa_iterations + } else { + k1 / 2 + }, + } + } + + /// Create configuration with auto-scaling for small datasets + /// + /// R saemix automatically increases chains when N < 50: + /// nb.chains = ceiling(50/N) + pub fn from_saem_config_with_auto_chains(saem_config: &SaemConfig, n_subjects: usize) -> Self { + let mut config = Self::from_saem_config(saem_config); + + // R saemix: if N < 50, auto-increase chains + if n_subjects < 50 && saem_config.n_chains == 1 { + config.n_chains = ((50.0 / n_subjects as f64).ceil() as usize).max(1); + tracing::info!( + "Auto-scaled MCMC chains from 1 to {} for small dataset (N={})", + config.n_chains, + n_subjects + ); + } + + config + } +} + +/// f-SAEM algorithm state +pub struct FSAEM { + /// Pharmacokinetic/pharmacodynamic model + equation: E, + /// Population data + data: Data, + /// Run configuration derived from the unified API surface + run_configuration: RunConfiguration, + /// Current population parameters (μ, Ω) + population: Population, + /// Individual parameter estimates from last iteration + individual_estimates: IndividualEstimates, + /// Accumulated sufficient statistics + sufficient_stats: SufficientStats, + /// Current iteration number + iteration: usize, + /// Current objective function value (-2LL approximation) + objf: f64, + /// Previous objective function value (for convergence) + prev_objf: f64, + /// Algorithm status + status: Status, + /// f-SAEM specific configuration + config: FSaemConfig, + /// Chain states for each subject and chain + chain_states: Vec>, + /// Current residual error variance (σ² estimated in M-step) + sigma_sq: f64, + /// Sufficient statistic for residual error (statrese in R saemix) + /// This is Σ(weighted residuals²) and gets updated via stochastic approximation + statrese: f64, + /// Residual error models for prediction-based sigma calculation + /// Used for both M-step residual weighting and likelihood computation + /// Uses pharmsol's ResidualErrorModels which computes sigma from prediction + residual_error_models: ResidualErrorModels, + /// Variance floor for simulated annealing + variance_floor: Col, + /// Random number generator + rng: ChaCha8Rng, + /// Shared iteration log used for reporting and written outputs. 
+ iteration_log: ParametricIterationLog, + /// Parameter transforms (φ ↔ ψ conversions) + /// Maps between unconstrained space (φ) and constrained space (ψ) + transforms: Vec, + /// Adaptive proposal scales for Kernel 2 & 3 random walks (domega2 in R saemix) + /// Initialized as sqrt(diag(Ω)) * rw_init and adapted based on acceptance rates + /// This decouples proposal width from current Ω, preventing collapse + domega2: Col, + /// Structured subject-level covariate model used to compute subject-specific mean φ. + subject_covariate_model: Option, + /// Structured subject-level covariate values keyed by covariate name. + subject_covariates: Vec>, + /// Structured occasion-level covariate model preserved in fitted state for IOV-ready paths. + occasion_covariate_model: Option, + /// Structured occasion-level covariate values keyed by covariate name. + occasion_covariates: Vec>, +} + +impl FSAEM { + /// Create a new f-SAEM algorithm instance from the unified SAEM config. + /// + /// This constructor reads configuration from [`SaemConfig`] to ensure + /// the algorithm behaves according to user-specified parameters. + /// For small datasets (N < 50), MCMC chains are automatically scaled + /// following R saemix behavior. + pub(crate) fn create(input: ParametricAlgorithmInput) -> Result> { + let n_subjects = input.data.subjects().len(); + let config = + FSaemConfig::from_saem_config_with_auto_chains(input.saem_config(), n_subjects); + Self::create_with_config(input, config) + } + + /// Create with custom configuration (advanced users) + /// + /// Use this when you need fine-grained control over algorithm parameters + /// beyond what [`SaemConfig`] provides. + pub(crate) fn create_with_config( + input: ParametricAlgorithmInput, + config: FSaemConfig, + ) -> Result> { + let run_configuration = input.run_configuration(); + // Initialize population from the unified parameter space. + let population = input.initial_population()?; + let transforms = input.parameter_transforms(); + let ParametricAlgorithmInput { + equation, + data, + covariate_context, + residual_error_models, + .. 
+ } = input; + let ParametricCovariateContext { + subject_model: subject_covariate_model, + occasion_model: occasion_covariate_model, + subject_covariates, + occasion_covariates, + } = covariate_context; + let n_params = population.npar(); + let n_subjects = data.subjects().len(); + + // Initialize sufficient statistics + let sufficient_stats = SufficientStats::new(n_params); + + // Initialize residual error variance + let sigma_sq = estimate_initial_sigma_sq(&residual_error_models); + // Initialize statrese - sufficient statistic for residual error + // Start with initial sigma² * n_obs (will be divided back to get sigma) + let n_obs_estimate = data + .subjects() + .iter() + .map(|s| { + s.occasions() + .iter() + .flat_map(|o| o.events()) + .filter(|e| matches!(e, Event::Observation(_))) + .count() + }) + .sum::(); + let statrese = sigma_sq * n_obs_estimate as f64; + + // Initialize chain states (one set of chains per subject) + let chain_states: Vec> = (0..n_subjects) + .map(|_| { + (0..config.n_chains) + .map(|_| ChainState::new(Col::zeros(n_params))) + .collect() + }) + .collect(); + + // Initialize variance floor for simulated annealing + let variance_floor = Col::from_fn(n_params, |i| population.omega()[(i, i)]); + + // Initialize adaptive proposal scales (domega2 in R saemix) + // R saemix: domega2 = sqrt(diag(omega.eta)) * rw.init + let domega2 = Col::from_fn(n_params, |i| { + population.omega()[(i, i)].sqrt() * config.kernel_config.rw_init + }); + + // Random number generator + let rng = ChaCha8Rng::seed_from_u64(config.seed); + + Ok(Box::new(Self { + equation, + data, + run_configuration, + population, + individual_estimates: IndividualEstimates::new(), + sufficient_stats, + iteration: 0, + objf: f64::INFINITY, + prev_objf: f64::INFINITY, + status: Status::Continue, + config, + chain_states, + sigma_sq, + statrese, + residual_error_models, + variance_floor, + rng, + iteration_log: ParametricIterationLog::new(), + transforms, + domega2, + subject_covariate_model, + subject_covariates, + occasion_covariate_model, + occasion_covariates, + })) + } + + /// Check if currently in burn-in phase (Phase 1 + Phase 2) + /// - Phase 1: Pure burn-in (γ=0) + /// - Phase 2: SA phase (γ=1) + /// Check if currently in pure burn-in phase (γ=0, no stat updates) + pub fn is_pure_burn_in(&self) -> bool { + self.iteration <= self.config.n_pure_burn + } + + pub fn is_burn_in(&self) -> bool { + self.iteration <= self.config.n_burn_in() + } + + /// Check if currently in SA phase (γ=1 with simulated annealing) + pub fn is_sa_phase(&self) -> bool { + self.iteration > self.config.n_pure_burn && self.iteration <= self.config.n_burn_in() + } + + /// Check if variance floor should be applied (R saemix: kiter <= nbiter.sa) + /// + /// In R saemix, `nbiter.sa` controls how long the variance floor decays. + /// This is separate from the step size schedule. 
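+    ///
+    /// As a rough illustration (the numbers are only an example): with
+    /// `sa_cooling_factor` = 0.97 and an initial diagonal entry ω² = 0.25, the
+    /// floor after k decay iterations is roughly 0.25 · 0.97^k, and it is never
+    /// allowed below `omega_min_variance` (see `apply_simulated_annealing`).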
+ pub fn is_variance_floor_active(&self) -> bool { + self.iteration <= self.config.n_sa_variance + } + + /// Get the current step size γₖ + /// + /// Matches R saemix reference implementation: + /// - Phase 1 (kiter <= nbiter.burn): γ = 0 (pure MCMC, no stat update) + /// - Phase 2 (kiter <= nbiter.sa): γ = 1 (SA phase with annealing) + /// - Phase 3 (kiter > nbiter.sa): γ = 1/(k - n_burn_in + 1) (stochastic approx) + pub fn current_step_size(&self) -> f64 { + if self.is_pure_burn_in() { + // Phase 1: Pure burn-in - no statistics update (γ=0) + 0.0 + } else if self.is_sa_phase() { + // Phase 2: SA phase - full updates (γ=1) with simulated annealing + 1.0 + } else { + // Phase 3: Stochastic approximation - decreasing step size + let post_burnin = self.iteration - self.config.n_burn_in(); + 1.0 / (post_burnin as f64).max(1.0) + } + } + + /// Run E-step: Sample from p(φ | y, θ) using vectorized MCMC kernels + /// + /// This implementation matches R saemix's approach: + /// 1. All proposals are generated for all subjects at once + /// 2. Likelihoods are computed in parallel batch + /// 3. Accept/reject is vectorized + fn e_step_impl(&mut self) -> Result<()> { + let n_params = self.population.npar(); + let n_subjects = self.data.len(); + + // Get Cholesky of Ω and Ω⁻¹ using faer's built-in methods + let omega = self.population.omega(); + let llt = omega + .llt(faer::Side::Lower) + .map_err(|_| anyhow::anyhow!("Omega not positive definite"))?; + let chol_omega = llt.L().to_owned(); + let omega_inv = llt.inverse(); + + // Subject-specific population means in φ space. + let mean_phi = self.current_subject_mean_phi(); + + // Current η for all subjects (N × P matrix) + // η = φ - μ, so φ = μ + η + let mut eta_matrix: Array2 = Array2::zeros((n_subjects, n_params)); + for i in 0..n_subjects { + if !self.chain_states[i].is_empty() { + for j in 0..n_params { + eta_matrix[[i, j]] = self.chain_states[i][0].eta[j]; + } + } + } + + let SaemMcmcState { + eta_matrix, + log_likelihoods: current_ll, + log_priors: current_log_prior, + } = advance_saem_chains( + &self.equation, + &self.data, + &self.residual_error_models, + &self.transforms, + &mean_phi, + &chol_omega, + &omega_inv, + &self.config.kernel_config, + self.iteration, + &mut self.domega2, + &mut self.rng, + eta_matrix, + )?; + + // Update chain states with final η + for i in 0..n_subjects { + let eta = Col::from_fn(n_params, |j| eta_matrix[[i, j]]); + if self.chain_states[i].is_empty() { + self.chain_states[i].push(ChainState { + eta: eta.clone(), + log_likelihood: current_ll[i], + log_prior: current_log_prior[i], + }); + } else { + self.chain_states[i][0].eta = eta.clone(); + self.chain_states[i][0].log_likelihood = current_ll[i]; + self.chain_states[i][0].log_prior = current_log_prior[i]; + } + } + + // Accumulate sufficient statistics + let mut new_stats = SufficientStats::new(n_params); + let mut individuals = Vec::with_capacity(n_subjects); + + for i in 0..n_subjects { + let phi = Col::from_fn(n_params, |j| mean_phi[i][j] + eta_matrix[[i, j]]); + new_stats.accumulate(&phi)?; + + let eta = Col::from_fn(n_params, |j| eta_matrix[[i, j]]); + let subject_id = self.data.subjects()[i].id().clone(); + // Note: We store phi (unconstrained) in Individual. The field is named "psi" + // but for SAEM it contains the φ values. Transform to ψ when needed. 
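+            // Illustration only: with a log-normal transform ψ = exp(φ), a stored
+            // φ of 2.0 corresponds to ψ ≈ 7.389 on the natural scale (the
+            // assembler test exercises exactly this conversion).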
+ let individual = Individual::new(subject_id, eta, phi)?; + individuals.push(individual); + } + + // Stochastic approximation update + let step_size = self.current_step_size(); + self.sufficient_stats + .stochastic_update(&new_stats, step_size)?; + + // Update individual estimates + self.individual_estimates = IndividualEstimates::from_vec(individuals); + + Ok(()) + } + + /// Run M-step: Update population parameters from sufficient statistics + /// + /// During pure burn-in (γ=0), we skip parameter updates entirely. + /// This matches R saemix behavior where nbiter.burn is pure MCMC exploration. + fn m_step_impl(&mut self) -> Result<()> { + // During pure burn-in, don't update parameters - just explore + if self.is_pure_burn_in() { + let subject_means = self.current_subject_mean_phi(); + refresh_saem_objective_history( + &mut self.objf, + &mut self.prev_objf, + false, + &self.equation, + &self.data, + &self.residual_error_models, + &self.transforms, + &self.population, + &self.individual_estimates, + &subject_means, + ); + self.iteration_log.log_iteration( + self.iteration, + self.objf, + &self.population, + &self.status, + ); + return Ok(()); + } + + let (mu, omega) = if self.subject_covariate_model.is_some() { + self.compute_covariate_aware_m_step()? + } else { + self.sufficient_stats.compute_m_step()? + }; + + // Apply simulated annealing to variance during SA variance phase + // R saemix: applies floor decay for nbiter.sa iterations (typically K₁/2) + let omega_constrained = if self.is_variance_floor_active() { + self.apply_simulated_annealing(&omega) + } else { + omega + }; + + // Update population parameters + self.population.update_mu(mu)?; + + // Ensure Omega remains positive definite before updating + // R saemix uses cutoff on diagonal: domega <- cutoff(diag(omega), .Machine$double.eps) + let omega_pd = ensure_positive_definite_covariance(&omega_constrained); + self.population.update_omega(omega_pd)?; + + if self.subject_covariate_model.is_some() { + let subject_means = self.current_subject_mean_phi(); + self.recenter_subject_effects(&subject_means)?; + } + + // Update residual error (simplified - could be more sophisticated) + self.update_residual_error()?; + + let subject_means = self.current_subject_mean_phi(); + refresh_saem_objective_history( + &mut self.objf, + &mut self.prev_objf, + true, + &self.equation, + &self.data, + &self.residual_error_models, + &self.transforms, + &self.population, + &self.individual_estimates, + &subject_means, + ); + self.iteration_log + .log_iteration(self.iteration, self.objf, &self.population, &self.status); + + Ok(()) + } + + /// Apply simulated annealing to variance (prevent premature convergence) + /// + /// Matches R saemix: diag.omega = max(diag.omega.new, floor * alpha) + /// This ensures variance doesn't collapse too quickly during burn-in + fn apply_simulated_annealing(&mut self, omega: &Mat) -> Mat { + let n = omega.nrows(); + let mut omega_sa = omega.clone(); + + // R formula: take max of new variance vs decayed floor + for i in 0..n { + let decayed_floor = self.variance_floor[i] * self.config.sa_alpha; + // New variance should be at least the decayed floor + omega_sa[(i, i)] = omega[(i, i)].max(decayed_floor); + // Update floor for next iteration (don't let it go below minimum) + self.variance_floor[i] = decayed_floor.max(self.config.sa_min_var); + } + + omega_sa + } + + /// Update residual error variance from residuals + /// + /// Uses **prediction-based** weighting to match R saemix behavior: + /// - For constant error: σ² = Σ(y 
- f)² / n + /// - For proportional error: σ² = Σ(y - f)² / f² / n (f = prediction) + /// - For combined error: uses current σ estimate for weighting + /// + /// Follows R saemix approach: + /// 1. Compute current residuals: statr = Σ(weighted_res²) + /// 2. Update sufficient statistic via SA: statrese = statrese + γ*(statr - statrese) + /// 3. Compute sig² = statrese / nobs + /// 4. During SA phase: pres = max(pres * alpha, sqrt(sig²)) + /// 5. After SA phase: normal stochastic approximation + fn update_residual_error(&mut self) -> Result<()> { + let step_size = self.current_step_size(); + let use_annealed_sigma_floor = self.is_variance_floor_active(); + let allow_sigma_update = !self.is_pure_burn_in(); + let update = update_residual_error_from_individuals( + &self.equation, + &self.data, + &mut self.residual_error_models, + &self.transforms, + &self.individual_estimates, + step_size, + self.sigma_sq, + self.statrese, + use_annealed_sigma_floor, + self.config.sa_alpha, + allow_sigma_update, + )?; + + self.sigma_sq = update.sigma_sq; + self.statrese = update.statrese; + Ok(()) + } + + /// Get current residual error standard deviation + pub fn sigma(&self) -> f64 { + self.sigma_sq.sqrt() + } +} + +impl ParametricAlgorithm for FSAEM { + fn equation(&self) -> &E { + &self.equation + } + + fn data(&self) -> &Data { + &self.data + } + + fn population(&self) -> &Population { + &self.population + } + + fn population_mut(&mut self) -> &mut Population { + &mut self.population + } + + fn individual_estimates(&self) -> &IndividualEstimates { + &self.individual_estimates + } + + fn iteration(&self) -> usize { + self.iteration + } + + fn increment_iteration(&mut self) -> usize { + self.iteration += 1; + self.iteration + } + + fn objective_function(&self) -> f64 { + self.objf + } + + fn status(&self) -> &Status { + &self.status + } + + fn set_status(&mut self, status: Status) { + self.status = status; + } + + fn initialize(&mut self) -> Result<()> { + tracing::info!("Initializing f-SAEM algorithm"); + tracing::info!( + "Configuration: {} pure burn-in, {} SA, {} total iterations, {} chains", + self.config.n_pure_burn, + self.config.n_sa, + self.config.n_iterations, + self.config.n_chains + ); + + let initialized_population = + initialize_population_in_phi_space(&mut self.population, &self.transforms)?; + let mu_psi = initialized_population.mu_psi; + let mu_phi = initialized_population.mu_phi; + let omega_phi = initialized_population.omega_phi; + + if let Some(model) = self.subject_covariate_model.as_mut() { + let initialize_intercepts = + (0..model.beta().nrows()).all(|index| model.beta()[index].abs() < 1e-12); + if initialize_intercepts { + let intercepts = (0..self.population.npar()) + .map(|index| self.population.mu()[index]) + .collect::>(); + model.set_intercepts(&intercepts)?; + } + } + + let n_params = self.population.npar(); + + // Show transforms being used + let transform_names: Vec<&str> = self.transforms.iter().map(transform_label).collect(); + + // Print to stderr for real-time feedback + eprintln!( + "\n╔══════════════════════════════════════════════════════════════════════════════╗" + ); + eprintln!( + "║ f-SAEM Algorithm ║" + ); + eprintln!( + "╠══════════════════════════════════════════════════════════════════════════════╣" + ); + eprintln!( + "║ Phases: {} burn-in → {} SA → {} estimation │ {} chains │ {} subjects ", + self.config.n_pure_burn, + self.config.n_sa, + self.config.n_iterations - self.config.n_burn_in(), + self.config.n_chains, + self.data.subjects().len() + ); + eprintln!("║ 
Transforms: {:?}", transform_names); + eprintln!( + "║ Initial μ(ψ): {:?}", + mu_psi + .as_slice() + .iter() + .map(|value| format!("{:.4}", value)) + .collect::>() + ); + eprintln!( + "║ Initial μ(φ): {:?}", + mu_phi + .as_slice() + .iter() + .map(|value| format!("{:.4}", value)) + .collect::>() + ); + eprintln!( + "║ Initial ω²(φ): {:?}", + (0..n_params) + .map(|i| format!("{:.4}", omega_phi[(i, i)])) + .collect::>() + ); + eprintln!( + "╚══════════════════════════════════════════════════════════════════════════════╝\n" + ); + + // Initialize chain states from population distribution + let subjects = self.data.subjects(); + + for i in 0..subjects.len() { + for chain_idx in 0..self.config.n_chains { + // Sample initial η from N(0, Ω) + let eta = sample_eta_from_population(&self.population, &mut self.rng); + self.chain_states[i][chain_idx] = ChainState::new(eta); + } + } + + // Initialize sufficient statistics + self.sufficient_stats = SufficientStats::new(n_params); + + // Initialize variance floor + for i in 0..n_params { + self.variance_floor[i] = self.population.omega()[(i, i)]; + } + + Ok(()) + } + + fn e_step(&mut self) -> Result<()> { + self.e_step_impl() + } + + fn m_step(&mut self) -> Result<()> { + self.m_step_impl() + } + + fn evaluate(&mut self) -> Result { + // Check for stop file + if std::path::Path::new("stop").exists() { + self.status = Status::Stop(StopReason::Stopped); + return Ok(self.status.clone()); + } + + // Check max iterations + if self.iteration >= self.config.n_iterations { + self.status = Status::Stop(StopReason::MaxCycles); + return Ok(self.status.clone()); + } + + // Only check convergence after burn-in + if !self.is_burn_in() && self.iteration > self.config.n_burn_in() + 10 { + // Check objective function convergence + let objf_change = (self.objf - self.prev_objf).abs() / (1.0 + self.prev_objf.abs()); + if objf_change < self.config.base.objective_tolerance { + self.status = Status::Stop(StopReason::Converged); + return Ok(self.status.clone()); + } + } + + self.status = Status::Continue; + Ok(self.status.clone()) + } + + fn log_iteration(&mut self) { + use std::io::Write; + + let phase = if self.is_pure_burn_in() { + "burn-in" + } else if self.is_sa_phase() { + "SA" + } else { + "est" + }; + + // Log via tracing (for file output) + tracing::info!( + "f-SAEM iter {} ({}): -2LL ≈ {:.4}, γ = {:.4}, σ = {:.4}", + self.iteration, + phase, + self.objf, + self.current_step_size(), + self.sigma() + ); + + // Also print progress to stdout for real-time feedback (every 10 iterations or key moments) + let should_print = self.iteration == 1 + || self.iteration == self.config.n_pure_burn + || self.iteration == self.config.n_burn_in() + || self.iteration % 10 == 0 + || self.iteration == self.config.n_iterations; + + if should_print { + // Transform μ from φ space to ψ space for display + // μ in storage is in φ space, but users want to see ψ (natural scale) + let mu_phi = PhiVector::from(self.population.mu()); + let mu_psi = phi_to_psi(&self.transforms, &mu_phi); + + // Format parameter values (showing ψ = natural scale) + let mu_str: Vec = mu_psi + .as_slice() + .iter() + .map(|value| format!("{:.4}", value)) + .collect(); + let omega_diag_str: Vec = self + .population + .variances_as_vec() + .iter() + .map(|v| format!("{:.4}", v)) + .collect(); + + eprintln!( + "[SAEM {:>4}/{}] {:>7} | -2LL: {:>12.4} | γ: {:.3} | σ: {:.4} | μ(ψ): [{}] | ω²: [{}]", + self.iteration, + self.config.n_iterations, + phase, + self.objf, + self.current_step_size(), + self.sigma(), + 
mu_str.join(", "), + omega_diag_str.join(", ") + ); + let _ = std::io::stderr().flush(); + } + + if self.iteration % 10 == 0 || self.iteration <= 5 { + // Show both φ and ψ spaces in debug + let mu_phi = PhiVector::from(self.population.mu()); + let mu_psi = phi_to_psi(&self.transforms, &mu_phi); + tracing::debug!(" μ(φ): {:?}", self.population.mu_as_vec()); + tracing::debug!(" μ(ψ): {:?}", mu_psi.as_slice()); + tracing::debug!(" diag(Ω): {:?}", self.population.variances_as_vec()); + } + } + + fn into_result(&self) -> Result> { + let observed_outeqs = self + .data + .subjects() + .iter() + .flat_map(|subject| subject.occasions().iter()) + .flat_map(|occasion| occasion.iter()) + .filter_map(|event| match event { + Event::Observation(observation) => Some(observation.outeq()), + _ => None, + }) + .collect::>(); + + finalize_saem_result( + ParametricResultInput { + equation: &self.equation, + data: &self.data, + population: &self.population, + individual_estimates: &self.individual_estimates, + objf: self.objf, + iterations: self.iteration, + status: &self.status, + run_configuration: self.run_configuration.clone(), + iteration_log: self.iteration_log.clone(), + likelihood_estimates: Default::default(), + uncertainty_estimates: UncertaintyEstimates::new(), + sigma: residual_error_estimates_from_observed_outeqs( + &self.residual_error_models, + &observed_outeqs, + ), + transforms: &self.transforms, + covariates: Some(covariate_state( + self.subject_covariate_model.as_ref(), + &self.subject_covariates, + self.occasion_covariate_model.as_ref(), + &self.occasion_covariates, + )), + }, + SaemFinalizeInput { + chain_states: &self.chain_states, + residual_error_models: &self.residual_error_models, + seed: self.config.seed, + }, + ) + } + + fn sufficient_stats(&self) -> Option<&SufficientStats> { + Some(&self.sufficient_stats) + } +} + +impl FSAEM { + fn current_subject_mean_phi(&self) -> Vec> { + subject_mean_phi( + self.population.mu(), + self.data.subjects().len(), + self.subject_covariate_model.as_ref(), + &self.subject_covariates, + ) + } + + fn compute_covariate_aware_m_step(&mut self) -> Result<(Col, Mat)> { + let step_size = self.current_step_size(); + let Some(model) = self.subject_covariate_model.clone() else { + return self.sufficient_stats.compute_m_step(); + }; + + let (updated_model, _subject_means, mu, omega) = blended_subject_covariate_m_step( + &model, + &self.subject_covariates, + &self.individual_estimates, + &self.population, + step_size, + )?; + self.subject_covariate_model = Some(updated_model); + + Ok((mu, omega)) + } + + fn recenter_subject_effects(&mut self, subject_means: &[Col]) -> Result<()> { + self.individual_estimates = + recenter_individual_estimates(&self.individual_estimates, subject_means)?; + + for (subject_index, individual) in self.individual_estimates.iter().enumerate() { + if let Some(chain_state) = self + .chain_states + .get_mut(subject_index) + .and_then(|states| states.get_mut(0)) + { + chain_state.eta = individual.eta().clone(); + } + } + + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::estimation::parametric::estimate_beta; + use crate::estimation::parametric::StepSizeSchedule; + + #[test] + fn test_cholesky_faer() { + // Test that faer's built-in Cholesky works correctly + let mat = Mat::from_fn(2, 2, |i, j| if i == j { 2.0 } else { 0.5 }); + + let l = mat.llt(faer::Side::Lower).unwrap().L().to_owned(); + + // Verify L * Lᵀ = mat + for i in 0..2 { + for j in 0..2 { + let mut sum: f64 = 0.0; + for k in 0..2 { + sum += l[(i, k)] * 
l[(j, k)]; + } + assert!((sum - mat[(i, j)]).abs() < 1e-10); + } + } + } + + #[test] + fn test_invert_faer() { + // Test that faer's built-in inversion works correctly + let mat = Mat::from_fn(2, 2, |i, j| if i == j { 2.0 } else { 0.5 }); + + let inv = mat.llt(faer::Side::Lower).unwrap().inverse(); + + // Verify mat * inv = I + for i in 0..2 { + for j in 0..2 { + let mut sum: f64 = 0.0; + for k in 0..2 { + sum += mat[(i, k)] * inv[(k, j)]; + } + let expected: f64 = if i == j { 1.0 } else { 0.0 }; + assert!((sum - expected).abs() < 1e-10); + } + } + } + + #[test] + fn test_step_size_schedule() { + let schedule = StepSizeSchedule::new_saem(100, 200); + + // During burn-in: step_size should be 1.0 + assert_eq!(schedule.step_size(1), 1.0); + assert_eq!(schedule.step_size(50), 1.0); + assert_eq!(schedule.step_size(99), 1.0); + + // At burn-in boundary: step_size(100) = 1/(100-100+1) = 1.0 + assert_eq!(schedule.step_size(100), 1.0); + + // After burn-in: step_size should decrease + // step_size(110) = 1/(110-100+1) = 1/11 + // step_size(150) = 1/(150-100+1) = 1/51 + assert!(schedule.step_size(150) < schedule.step_size(110)); + assert!((schedule.step_size(110) - 1.0 / 11.0).abs() < 1e-10); + } + + #[test] + fn test_estimate_beta_recovers_subject_covariate_effects() { + let model = CovariateModel::new(vec!["CL", "V"], vec!["WT"], vec![vec![true], vec![false]]) + .unwrap(); + + let subject_covariates = vec![ + HashMap::from([(String::from("WT"), 60.0)]), + HashMap::from([(String::from("WT"), 80.0)]), + ]; + let individuals = IndividualEstimates::from_vec(vec![ + Individual::new( + "1", + Col::from_fn(2, |_| 0.0), + Col::from_fn(2, |index| if index == 0 { 11.0 } else { 50.0 }), + ) + .unwrap(), + Individual::new( + "2", + Col::from_fn(2, |_| 0.0), + Col::from_fn(2, |index| if index == 0 { 13.0 } else { 50.0 }), + ) + .unwrap(), + ]); + + let beta = estimate_beta(&model, &subject_covariates, &individuals).unwrap(); + + assert!((beta[0] - 5.0).abs() < 1e-5); + assert!((beta[1] - 0.1).abs() < 1e-6); + assert!((beta[2] - 50.0).abs() < 1e-5); + } +} diff --git a/src/api/estimation_problem.rs b/src/api/estimation_problem.rs index ea0de580a..b7879aad7 100644 --- a/src/api/estimation_problem.rs +++ b/src/api/estimation_problem.rs @@ -3,18 +3,21 @@ use pharmsol::{Data, Equation}; use serde::Serialize; use crate::algorithms::Algorithm; +use crate::api::SaemConfig; use crate::estimation::nonparametric::Prior; use crate::model::ModelDefinition; #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum EstimationMethod { Nonparametric(NonparametricMethod), + Parametric(ParametricMethod), } impl EstimationMethod { pub fn algorithm(self) -> Algorithm { match self { EstimationMethod::Nonparametric(method) => method.algorithm(), + EstimationMethod::Parametric(method) => method.algorithm(), } } } @@ -36,6 +39,23 @@ impl NonparametricMethod { } } +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum ParametricMethod { + Saem(SaemOptions), + Focei(FoceiOptions), + It2b(It2bOptions), +} + +impl ParametricMethod { + pub fn algorithm(self) -> Algorithm { + match self { + ParametricMethod::Saem(_) => Algorithm::SAEM, + ParametricMethod::Focei(_) => Algorithm::FOCEI, + ParametricMethod::It2b(_) => Algorithm::IT2B, + } + } +} + #[derive(Debug, Clone, Copy, Default, PartialEq, Eq)] pub struct NpagOptions; @@ -45,6 +65,15 @@ pub struct NpodOptions; #[derive(Debug, Clone, Copy, Default, PartialEq, Eq)] pub struct PostProbOptions; +#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)] +pub struct SaemOptions; + +#[derive(Debug, 
Clone, Copy, Default, PartialEq, Eq)] +pub struct FoceiOptions; + +#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)] +pub struct It2bOptions; + #[derive(Debug, Clone, PartialEq, Serialize)] pub struct OutputPlan { pub write: bool, @@ -125,6 +154,7 @@ pub struct AlgorithmTuning { pub min_distance: f64, pub nm_steps: usize, pub tolerance: f64, + pub saem: SaemConfig, } impl Default for AlgorithmTuning { @@ -133,6 +163,7 @@ impl Default for AlgorithmTuning { min_distance: 1e-4, nm_steps: 100, tolerance: 1e-6, + saem: SaemConfig::default(), } } } diff --git a/src/api/fit.rs b/src/api/fit.rs index 0787bf645..832f131c4 100644 --- a/src/api/fit.rs +++ b/src/api/fit.rs @@ -1,8 +1,8 @@ use anyhow::Result; use pharmsol::equation::Equation; -use crate::api::estimation_problem::EstimationProblem; -use crate::estimation::nonparametric; +use crate::api::estimation_problem::{EstimationMethod, EstimationProblem}; +use crate::estimation::{nonparametric, parametric}; use crate::results::FitResult; pub fn fit( @@ -12,6 +12,11 @@ pub fn fit( problem.initialize_logs()?; } + let method = problem.method; let compiled = problem.compile()?; - nonparametric::fit(compiled) + + match method { + EstimationMethod::Nonparametric(_) => nonparametric::fit(compiled), + EstimationMethod::Parametric(_) => parametric::fit(compiled), + } } diff --git a/src/api/mod.rs b/src/api/mod.rs index cb757ba1a..fe7806c5b 100644 --- a/src/api/mod.rs +++ b/src/api/mod.rs @@ -1,11 +1,14 @@ pub mod estimation_problem; pub mod fit; pub mod model_definition; +pub mod saem_config; pub use estimation_problem::{ AlgorithmTuning, ConvergenceOptions, EstimationMethod, EstimationProblem, - EstimationProblemBuilder, LoggingLevel, LoggingOptions, NonparametricMethod, NpagOptions, - NpodOptions, OutputPlan, PostProbOptions, RuntimeOptions, + EstimationProblemBuilder, FoceiOptions, It2bOptions, LoggingLevel, LoggingOptions, + NonparametricMethod, NpagOptions, NpodOptions, OutputPlan, + ParametricMethod, PostProbOptions, RuntimeOptions, SaemOptions, }; pub use fit::fit; pub use model_definition::{ModelDefinition, ModelDefinitionBuilder}; +pub use saem_config::SaemConfig; diff --git a/src/api/saem_config.rs b/src/api/saem_config.rs new file mode 100644 index 000000000..7f73f373c --- /dev/null +++ b/src/api/saem_config.rs @@ -0,0 +1,124 @@ +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Deserialize, Clone, Serialize)] +#[serde(deny_unknown_fields, default)] +pub struct SaemConfig { + pub k1_iterations: usize, + pub k2_iterations: usize, + pub burn_in: usize, + pub sa_iterations: usize, + pub sa_cooling_factor: f64, + pub mcmc_step_size: f64, + pub rw_init: f64, + pub n_chains: usize, + pub mcmc_iterations: usize, + pub omega_min_variance: f64, + pub use_gibbs: bool, + pub n_kernels: usize, + pub transform_par: Vec, + pub compute_map: bool, + pub compute_fim: bool, + pub compute_ll_is: bool, + pub compute_ll_gq: bool, + pub n_mc_is: usize, + pub nu_is: usize, + pub n_nodes_gq: usize, + pub n_sd_gq: f64, + pub display_progress: usize, + pub seed: u64, + pub fix_seed: bool, +} + +impl Default for SaemConfig { + fn default() -> Self { + Self { + k1_iterations: 300, + k2_iterations: 100, + burn_in: 5, + sa_iterations: 0, + sa_cooling_factor: 0.97, + mcmc_step_size: 0.4, + rw_init: 0.5, + n_chains: 1, + mcmc_iterations: 1, + omega_min_variance: 1e-6, + use_gibbs: false, + n_kernels: 4, + transform_par: vec![], + compute_map: true, + compute_fim: true, + compute_ll_is: true, + compute_ll_gq: false, + n_mc_is: 5000, + nu_is: 4, + n_nodes_gq: 12, + 
n_sd_gq: 4.0, + display_progress: 10, + seed: 123456, + fix_seed: true, + } + } +} + +impl SaemConfig { + pub fn total_iterations(&self) -> usize { + self.k1_iterations + self.k2_iterations + } + + pub fn is_exploration_phase(&self, iteration: usize) -> bool { + iteration <= self.k1_iterations + } + + pub fn is_smoothing_phase(&self, iteration: usize) -> bool { + iteration > self.k1_iterations + } + + pub fn is_sa_active(&self, iteration: usize) -> bool { + self.sa_iterations > 0 && iteration <= self.sa_iterations + } + + pub fn sa_temperature(&self, iteration: usize) -> f64 { + if self.is_sa_active(iteration) { + self.sa_cooling_factor.powi(iteration as i32) + } else { + 1.0 + } + } + + pub fn step_size(&self, iteration: usize) -> f64 { + if iteration <= self.k1_iterations { + 1.0 + } else { + let k_smooth = iteration - self.k1_iterations; + 1.0 / (k_smooth as f64 + 1.0) + } + } + + pub fn get_transform(&self, param_idx: usize) -> u8 { + self.transform_par.get(param_idx).copied().unwrap_or(1) + } + + pub fn get_transforms(&self, n_params: usize) -> Vec { + let mut transforms = self.transform_par.clone(); + while transforms.len() < n_params { + transforms.push(1); + } + transforms.truncate(n_params); + transforms + } + + pub fn infer_transforms_from_ranges(&mut self, ranges: &[(f64, f64)]) { + self.transform_par = ranges + .iter() + .map(|(lower, upper)| { + if *lower >= 0.0 && *upper > 0.0 && lower.is_finite() && upper.is_finite() { + 1 + } else if (*lower - 0.0).abs() < 1e-10 && (*upper - 1.0).abs() < 1e-10 { + 3 + } else { + 0 + } + }) + .collect(); + } +} diff --git a/src/compile/validation.rs b/src/compile/validation.rs index 9ded6f62c..74a2efbcd 100644 --- a/src/compile/validation.rs +++ b/src/compile/validation.rs @@ -1,7 +1,7 @@ use anyhow::{bail, Result}; use pharmsol::equation::Equation; -use crate::api::EstimationProblem; +use crate::api::{EstimationMethod, EstimationProblem}; pub fn validate_problem(problem: &EstimationProblem) -> Result<()> { if problem.model.parameters.is_empty() { @@ -16,7 +16,15 @@ pub fn validate_problem(problem: &EstimationProblem) -> Result<( bail!("at least one observation channel is required"); } - problem.model.parameters.finite_ranges()?; + if let EstimationMethod::Parametric(_) = problem.method { + if problem.model.observations.residual_error_models.is_none() { + bail!("parametric methods require residual error models in ObservationSpec"); + } + } + + if let EstimationMethod::Nonparametric(_) = problem.method { + problem.model.parameters.finite_ranges()?; + } Ok(()) } diff --git a/src/estimation/mod.rs b/src/estimation/mod.rs index d97ede1fc..27321739a 100644 --- a/src/estimation/mod.rs +++ b/src/estimation/mod.rs @@ -1 +1,2 @@ pub mod nonparametric; +pub mod parametric; diff --git a/src/estimation/nonparametric/engine.rs b/src/estimation/nonparametric/engine.rs index a81b3e563..c04864469 100644 --- a/src/estimation/nonparametric/engine.rs +++ b/src/estimation/nonparametric/engine.rs @@ -14,7 +14,13 @@ impl NonparametricEngine { pub fn fit( problem: CompiledProblem, ) -> Result> { - let EstimationMethod::Nonparametric(method) = problem.method(); + let method = match problem.method() { + EstimationMethod::Nonparametric(method) => method, + other => anyhow::bail!( + "nonparametric engine received parametric method: {:?}", + other + ), + }; let output = problem.output_plan().clone(); let runtime = problem.runtime_options().clone(); let (model, data) = problem.into_parts(); diff --git a/src/estimation/parametric/assembler.rs 
b/src/estimation/parametric/assembler.rs new file mode 100644 index 000000000..6140b75a8 --- /dev/null +++ b/src/estimation/parametric/assembler.rs @@ -0,0 +1,295 @@ +use anyhow::Result; +use pharmsol::{Data, Equation}; + +use crate::algorithms::Status; +use crate::estimation::parametric::{ + phi_to_psi_vec, ChainState, CovariateState, Individual, IndividualEffectsState, + IndividualEstimates, LikelihoodEstimates, ParameterTransform, ParametricIterationLog, + ParametricModelState, ParametricWorkspace, Population, ResidualErrorEstimates, + UncertaintyEstimates, +}; +use crate::output::shared::RunConfiguration; + +use super::posthoc::{eta_samples_by_subject, saem_posthoc_likelihood}; + +pub(crate) struct ParametricResultInput<'a, E: Equation> { + pub equation: &'a E, + pub data: &'a Data, + pub population: &'a Population, + pub individual_estimates: &'a IndividualEstimates, + pub objf: f64, + pub iterations: usize, + pub status: &'a Status, + pub run_configuration: RunConfiguration, + pub iteration_log: ParametricIterationLog, + pub likelihood_estimates: LikelihoodEstimates, + pub uncertainty_estimates: UncertaintyEstimates, + pub sigma: ResidualErrorEstimates, + pub transforms: &'a [ParameterTransform], + pub covariates: Option, +} + +pub(crate) struct SaemFinalizeInput<'a> { + pub chain_states: &'a [Vec], + pub residual_error_models: &'a pharmsol::ResidualErrorModels, + pub seed: u64, +} + +pub(crate) fn assemble_parametric_result( + input: ParametricResultInput<'_, E>, +) -> Result> { + assemble_parametric_workspace(ParametricWorkspaceInput { + equation: input.equation, + data: input.data, + population: input.population, + individual_estimates: input.individual_estimates, + objf: input.objf, + iterations: input.iterations, + status: input.status, + run_configuration: input.run_configuration, + iteration_log: input.iteration_log, + likelihood_estimates: input.likelihood_estimates, + uncertainty_estimates: input.uncertainty_estimates, + sigma: input.sigma, + transforms: input.transforms, + covariates: input.covariates, + }) +} + +pub(crate) fn finalize_saem_result( + input: ParametricResultInput<'_, E>, + finalize: SaemFinalizeInput<'_>, +) -> Result> { + let eta_samples = eta_samples_by_subject(finalize.chain_states); + let (likelihood_estimates, minus2ll) = saem_posthoc_likelihood( + input.equation, + input.data, + finalize.residual_error_models, + input.transforms, + input.population, + input.individual_estimates, + &eta_samples, + finalize.seed, + )?; + tracing::info!("-2LL computed by importance sampling: {:.4}", minus2ll); + + assemble_parametric_result(ParametricResultInput { + equation: input.equation, + data: input.data, + population: input.population, + individual_estimates: input.individual_estimates, + objf: minus2ll, + iterations: input.iterations, + status: input.status, + run_configuration: input.run_configuration, + iteration_log: input.iteration_log, + likelihood_estimates, + uncertainty_estimates: input.uncertainty_estimates, + sigma: input.sigma, + transforms: input.transforms, + covariates: input.covariates, + }) +} + +struct ParametricWorkspaceInput<'a, E: Equation> { + equation: &'a E, + data: &'a Data, + population: &'a Population, + individual_estimates: &'a IndividualEstimates, + objf: f64, + iterations: usize, + status: &'a Status, + run_configuration: RunConfiguration, + iteration_log: ParametricIterationLog, + likelihood_estimates: LikelihoodEstimates, + uncertainty_estimates: UncertaintyEstimates, + sigma: ResidualErrorEstimates, + transforms: &'a 
[ParameterTransform], + covariates: Option, +} + +fn assemble_parametric_workspace( + input: ParametricWorkspaceInput<'_, E>, +) -> Result> { + let population_psi = build_population_in_psi_space(input.population, input.transforms)?; + let individual_estimates_psi = + build_individual_estimates_in_psi_space(input.individual_estimates, input.transforms)?; + let mut state = ParametricModelState::from_population_and_sigma(&population_psi, &input.sigma); + if let Some(covariates) = input.covariates { + state.covariates = covariates; + } + let individuals = IndividualEffectsState::from_individual_estimates(&individual_estimates_psi); + + Ok(ParametricWorkspace::new( + state, + individuals, + input.equation.clone(), + input.data.clone(), + population_psi, + individual_estimates_psi, + input.objf, + input.iterations, + input.status.clone(), + input.run_configuration, + input.iteration_log, + input.likelihood_estimates, + input.uncertainty_estimates, + input.sigma, + None, + )) +} + +fn build_population_in_psi_space( + population: &Population, + transforms: &[ParameterTransform], +) -> Result { + let mean_psi = phi_to_psi_vec(transforms, population.mu()); + Population::new( + mean_psi, + population.omega().clone(), + population.parameters().clone(), + ) +} + +fn build_individual_estimates_in_psi_space( + individual_estimates: &IndividualEstimates, + transforms: &[ParameterTransform], +) -> Result { + let individuals = individual_estimates + .iter() + .map(|individual| { + let psi = phi_to_psi_vec(transforms, individual.psi()); + let mut rebuilt = Individual::new( + individual.subject_id().to_string(), + individual.eta().clone(), + psi, + )?; + if let Some(objf) = individual.objective_function() { + rebuilt.set_objective_function(objf); + } + Ok(rebuilt) + }) + .collect::>>()?; + + Ok(IndividualEstimates::from_vec(individuals)) +} + +#[cfg(test)] +mod tests { + use super::{assemble_parametric_result, ParametricResultInput}; + use anyhow::Result; + use faer::{Col, Mat}; + use pharmsol::{Data, Subject}; + + use crate::algorithms::{Status, StopReason}; + use crate::api::{ + AlgorithmTuning, ConvergenceOptions, LoggingLevel, LoggingOptions, OutputPlan, + RuntimeOptions, + }; + use crate::estimation::parametric::{ + Individual, IndividualEstimates, LikelihoodEstimates, ParameterTransform, Population, + ResidualErrorEstimates, UncertaintyEstimates, + }; + use crate::model::{ParameterSpace, ParameterSpec}; + use crate::output::shared::RunConfiguration; + use crate::prelude::*; + + fn equation() -> equation::ODE { + equation::ODE::new( + |x, p, _t, dx, b, _rateiv, _cov| { + fetch_params!(p, ke); + dx[0] = -ke * x[0] + b[0]; + }, + |_p, _t, _cov| lag! {}, + |_p, _t, _cov| fa! 
{}, + |_p, _t, _cov, _x| {}, + |x, p, _t, _cov, y| { + fetch_params!(p, v); + y[0] = x[0] / v; + }, + ) + } + + fn data() -> Data { + let subject = Subject::builder("1") + .bolus(0.0, 100.0, 0) + .observation(1.0, 10.0, 0) + .observation(2.0, 8.0, 0) + .build(); + + Data::new(vec![subject]) + } + + #[test] + fn saem_assembler_converts_phi_space_outputs_to_psi_space() -> Result<()> { + let parameters = ParameterSpace::new() + .add(ParameterSpec::bounded("ke", 0.1, 1.0)) + .add(ParameterSpec::bounded("v", 1.0, 20.0)); + let population = Population::new( + Col::from_fn(2, |index| if index == 0 { 0.0 } else { 2.0 }), + Mat::from_fn(2, 2, |row, col| if row == col { 0.25 } else { 0.0 }), + parameters, + )?; + let individual = Individual::new( + "1", + Col::from_fn(2, |_| 0.0), + Col::from_fn(2, |index| if index == 0 { 0.1 } else { 2.1 }), + )?; + let run_configuration = RunConfiguration::new( + Algorithm::SAEM, + &OutputPlan::disabled(), + &RuntimeOptions { + cycles: 100, + cache: true, + progress: false, + idelta: 0.12, + tad: 0.0, + prior: None, + logging: LoggingOptions { + initialize: false, + level: LoggingLevel::Info, + write: false, + stdout: false, + }, + convergence: ConvergenceOptions::default(), + tuning: AlgorithmTuning::default(), + }, + vec!["ke".to_string(), "v".to_string()], + ); + let equation = equation(); + let data = data(); + let individual_estimates = IndividualEstimates::from_vec(vec![individual]); + + let result = assemble_parametric_result(ParametricResultInput { + equation: &equation, + data: &data, + population: &population, + individual_estimates: &individual_estimates, + objf: 120.0, + iterations: 12, + status: &Status::Stop(StopReason::Converged), + run_configuration, + iteration_log: { + let mut log = ParametricIterationLog::new(); + log.log_iteration(1, 130.0, &population, &Status::Continue); + log + }, + likelihood_estimates: LikelihoodEstimates { + ll_importance_sampling: Some(-60.0), + is_n_samples: Some(1000), + ..LikelihoodEstimates::new() + }, + uncertainty_estimates: UncertaintyEstimates::new(), + sigma: ResidualErrorEstimates::additive(0.5), + transforms: &[ParameterTransform::LogNormal, ParameterTransform::LogNormal], + covariates: None, + })?; + + assert!((result.mu()[0] - 1.0).abs() < 1e-12); + assert!((result.mu()[1] - 7.38905609893065).abs() < 1e-12); + assert_eq!(result.iteration_log().len(), 1); + assert_eq!(result.likelihoods().is_n_samples, Some(1000)); + assert_eq!(result.sigma().additive, Some(0.5)); + Ok(()) + } +} diff --git a/src/estimation/parametric/compiler.rs b/src/estimation/parametric/compiler.rs new file mode 100644 index 000000000..99a683ca5 --- /dev/null +++ b/src/estimation/parametric/compiler.rs @@ -0,0 +1,170 @@ +use pharmsol::Equation; + +use crate::compile::CompiledProblem; +use crate::estimation::parametric::state::{ + CovariateEffectsSnapshot, CovariateState, FixedEffects, ParametricModelState, + ParametricTransformKind, PsiVector, RandomEffects, ResidualState, TransformSet, +}; +use crate::model::{ + CovariateSpec, ParameterDomain, ParameterSpace, ParameterVariability, RandomEffectsSpec, + VariabilityModel, +}; + +pub fn compile_model_state(problem: &CompiledProblem) -> ParametricModelState { + let parameter_names = problem + .model + .parameters + .iter() + .map(|parameter| parameter.name.clone()) + .collect::>(); + + let initial_values = problem + .model + .parameters + .iter() + .map(initial_value) + .collect::>(); + + let transforms = problem + .model + .parameters + .iter() + .map(|parameter| 
ParametricTransformKind::from(¶meter.transform)) + .collect::>(); + + let n_parameters = parameter_names.len(); + let covariance = identity_matrix(n_parameters); + let standard_deviations = vec![1.0; n_parameters]; + let variability = resolve_variability_model(problem); + + let covariates = match &problem.model.covariates { + CovariateSpec::InEquation => CovariateState { + subject_effects: None, + occasion_effects: None, + }, + CovariateSpec::Structured(spec) => CovariateState { + subject_effects: spec.subject_effects.as_ref().map(|model| { + CovariateEffectsSnapshot::from_model( + model, + problem + .design + .structured_covariates + .subject_rows + .iter() + .map(|row| row.values.clone()) + .collect(), + ) + }), + occasion_effects: spec.occasion_effects.as_ref().map(|model| { + CovariateEffectsSnapshot::from_model( + model, + problem + .design + .structured_covariates + .occasion_rows + .iter() + .map(|row| row.values.clone()) + .collect(), + ) + }), + }, + }; + + ParametricModelState { + fixed_effects: FixedEffects { + parameter_names, + population_mean: PsiVector(initial_values), + }, + random_effects: RandomEffects { + covariance: covariance.clone(), + standard_deviations, + correlation: covariance, + }, + residual: ResidualState { values: Vec::new() }, + transforms: TransformSet { transforms }, + covariates, + variability, + } +} + +fn resolve_variability_model(problem: &CompiledProblem) -> VariabilityModel { + let n_parameters = problem.model.parameters.len(); + let derived_subject = derived_subject_mask(&problem.model.parameters); + let derived_occasion = derived_occasion_mask(&problem.model.parameters); + + let mut subject = problem.model.variability.subject.clone(); + if subject.enabled_for.len() != n_parameters { + subject.enabled_for = derived_subject; + } + + let occasion = match &problem.model.variability.occasion { + Some(spec) => { + let mut spec = spec.clone(); + if spec.enabled_for.len() != n_parameters { + spec.enabled_for = derived_occasion.clone(); + } + Some(spec) + } + None if derived_occasion.iter().any(|enabled| *enabled) => Some(RandomEffectsSpec { + enabled_for: derived_occasion, + covariance: subject.covariance.clone(), + }), + None => None, + }; + + VariabilityModel { subject, occasion } +} + +fn derived_subject_mask(parameter_space: &ParameterSpace) -> Vec { + parameter_space + .iter() + .map(|parameter| { + matches!( + parameter.variability, + ParameterVariability::Subject | ParameterVariability::SubjectAndOccasion + ) + }) + .collect() +} + +fn derived_occasion_mask(parameter_space: &ParameterSpace) -> Vec { + parameter_space + .iter() + .map(|parameter| { + matches!( + parameter.variability, + ParameterVariability::Occasion | ParameterVariability::SubjectAndOccasion + ) + }) + .collect() +} + +fn initial_value(parameter: &crate::model::ParameterSpec) -> f64 { + if let Some(initial) = parameter.initial { + return initial; + } + + match parameter.domain { + ParameterDomain::Bounded { lower, upper } => (lower + upper) / 2.0, + ParameterDomain::Positive { lower, upper } => match (lower, upper) { + (Some(lower), Some(upper)) => (lower + upper) / 2.0, + (Some(lower), None) => lower.max(1.0), + (None, Some(upper)) => upper / 2.0, + (None, None) => 1.0, + }, + ParameterDomain::Unbounded { lower, upper } => match (lower, upper) { + (Some(lower), Some(upper)) => (lower + upper) / 2.0, + _ => 0.0, + }, + } +} + +fn identity_matrix(size: usize) -> Vec> { + (0..size) + .map(|row| { + (0..size) + .map(|col| if row == col { 1.0 } else { 0.0 }) + .collect() + }) + 
.collect() +} diff --git a/src/estimation/parametric/effects.rs b/src/estimation/parametric/effects.rs new file mode 100644 index 000000000..3136ca172 --- /dev/null +++ b/src/estimation/parametric/effects.rs @@ -0,0 +1,300 @@ +use std::collections::HashMap; + +use anyhow::Result; +use faer::linalg::solvers::DenseSolveCore; +use faer::{Col, Mat}; + +use crate::compile::StructuredCovariateDesign; +use crate::estimation::parametric::{IndividualEstimates, Population}; +use crate::model::{CovariateModel, CovariateSpec}; + +use super::state::{CovariateEffectsSnapshot, CovariateState}; + +#[derive(Debug, Clone, Default)] +pub(crate) struct ParametricCovariateContext { + pub subject_model: Option, + pub subject_covariates: Vec>, + pub occasion_model: Option, + pub occasion_covariates: Vec>, +} + +pub(crate) fn build_parametric_covariate_context( + covariates: &CovariateSpec, + structured_covariates: &StructuredCovariateDesign, +) -> ParametricCovariateContext { + match covariates { + CovariateSpec::InEquation => ParametricCovariateContext::default(), + CovariateSpec::Structured(spec) => ParametricCovariateContext { + subject_model: spec.subject_effects.clone(), + subject_covariates: subject_covariate_maps(structured_covariates), + occasion_model: spec.occasion_effects.clone(), + occasion_covariates: occasion_covariate_maps(structured_covariates), + }, + } +} + +pub(crate) fn recenter_individual_estimates( + individual_estimates: &IndividualEstimates, + subject_means: &[Col], +) -> Result { + let n_params = individual_estimates + .get(0) + .map(|individual| individual.npar()) + .unwrap_or(0); + let mut recentered = Vec::with_capacity(individual_estimates.nsubjects()); + + for (subject_index, individual) in individual_estimates.iter().enumerate() { + let eta = Col::from_fn(n_params, |param_index| { + individual.psi()[param_index] - subject_means[subject_index][param_index] + }); + let mut rebuilt = crate::estimation::parametric::Individual::new( + individual.subject_id().to_string(), + eta, + individual.psi().clone(), + )?; + if let Some(objf) = individual.objective_function() { + rebuilt.set_objective_function(objf); + } + recentered.push(rebuilt); + } + + Ok(IndividualEstimates::from_vec(recentered)) +} + +pub(crate) fn covariance_from_individual_etas( + individual_estimates: &IndividualEstimates, +) -> Mat { + let n_subjects = individual_estimates.nsubjects(); + let n_params = individual_estimates + .get(0) + .map(|individual| individual.npar()) + .unwrap_or(0); + + if n_subjects == 0 || n_params == 0 { + return Mat::zeros(n_params, n_params); + } + + let denom = n_subjects.max(1) as f64; + Mat::from_fn(n_params, n_params, |row, col| { + individual_estimates + .iter() + .map(|individual| individual.eta()[row] * individual.eta()[col]) + .sum::() + / denom + }) +} + +pub(crate) fn subject_covariate_maps( + structured_covariates: &StructuredCovariateDesign, +) -> Vec> { + covariate_maps( + &structured_covariates.subject_columns, + structured_covariates + .subject_rows + .iter() + .map(|row| row.values.as_slice()), + ) +} + +pub(crate) fn occasion_covariate_maps( + structured_covariates: &StructuredCovariateDesign, +) -> Vec> { + covariate_maps( + &structured_covariates.occasion_columns, + structured_covariates + .occasion_rows + .iter() + .map(|row| row.values.as_slice()), + ) +} + +pub(crate) fn subject_mean_phi( + population_mean: &Col, + n_subjects: usize, + model: Option<&CovariateModel>, + subject_covariates: &[HashMap], +) -> Vec> { + match model { + Some(model) => (0..n_subjects) + 
.map(|subject_index| { + let empty = HashMap::new(); + let covariates = subject_covariates.get(subject_index).unwrap_or(&empty); + model.compute_mu(covariates) + }) + .collect(), + None => (0..n_subjects).map(|_| population_mean.clone()).collect(), + } +} + +pub(crate) fn covariate_state( + subject_model: Option<&CovariateModel>, + subject_covariates: &[HashMap], + occasion_model: Option<&CovariateModel>, + occasion_covariates: &[HashMap], +) -> CovariateState { + CovariateState { + subject_effects: covariate_snapshot(subject_model, subject_covariates), + occasion_effects: covariate_snapshot(occasion_model, occasion_covariates), + } +} + +fn covariate_snapshot( + model: Option<&CovariateModel>, + covariates: &[HashMap], +) -> Option { + model.map(|model| { + CovariateEffectsSnapshot::from_model( + model, + covariates + .iter() + .map(|row| { + model + .covariate_names() + .iter() + .map(|name| row.get(name).copied()) + .collect() + }) + .collect(), + ) + }) +} + +fn covariate_maps<'a>( + columns: &[String], + rows: impl Iterator]>, +) -> Vec> { + rows.map(|row| { + columns + .iter() + .cloned() + .zip(row.iter().copied()) + .filter_map(|(name, value)| value.map(|value| (name, value))) + .collect() + }) + .collect() +} + +pub(crate) fn estimate_beta( + model: &CovariateModel, + subject_covariates: &[HashMap], + individual_estimates: &IndividualEstimates, +) -> Result> { + let n_subjects = individual_estimates.nsubjects(); + let n_params = model.n_params(); + let design = model.build_design_matrix(subject_covariates); + let responses = Col::from_fn(n_subjects * n_params, |row_index| { + let subject_index = row_index / n_params; + let param_index = row_index % n_params; + individual_estimates.get(subject_index).unwrap().psi()[param_index] + }); + + let estimated_indices = model.estimated_beta_indices(); + if estimated_indices.is_empty() { + return Ok(model.beta().clone()); + } + + let fixed_indices = (0..model.beta().nrows()) + .filter(|index| !estimated_indices.contains(index)) + .collect::>(); + let mut normal_matrix = Mat::::zeros(estimated_indices.len(), estimated_indices.len()); + let mut normal_rhs = Col::::zeros(estimated_indices.len()); + + for row_index in 0..design.nrows() { + let offset = fixed_indices.iter().fold(0.0, |acc, fixed_index| { + acc + design[(row_index, *fixed_index)] * model.beta()[*fixed_index] + }); + let adjusted_response = responses[row_index] - offset; + + for (lhs_position, lhs_index) in estimated_indices.iter().enumerate() { + let lhs_value = design[(row_index, *lhs_index)]; + normal_rhs[lhs_position] += lhs_value * adjusted_response; + + for (rhs_position, rhs_index) in estimated_indices.iter().enumerate() { + normal_matrix[(lhs_position, rhs_position)] += + lhs_value * design[(row_index, *rhs_index)]; + } + } + } + + for diagonal in 0..estimated_indices.len() { + normal_matrix[(diagonal, diagonal)] += 1e-8; + } + + let solver = normal_matrix + .llt(faer::Side::Lower) + .map_err(|_| anyhow::anyhow!("covariate normal equations are singular"))?; + let inverse = solver.inverse(); + let estimated_beta = Col::from_fn(estimated_indices.len(), |row_index| { + (0..estimated_indices.len()) + .map(|col_index| inverse[(row_index, col_index)] * normal_rhs[col_index]) + .sum() + }); + + let mut beta = model.beta().clone(); + for (position, beta_index) in estimated_indices.iter().enumerate() { + beta[*beta_index] = estimated_beta[position]; + } + + Ok(beta) +} + +pub(crate) fn blended_subject_covariate_m_step( + model: &CovariateModel, + subject_covariates: &[HashMap], + 
individual_estimates: &IndividualEstimates, + population: &Population, + step_size: f64, +) -> Result<(CovariateModel, Vec>, Col, Mat)> { + let target_beta = estimate_beta(model, subject_covariates, individual_estimates)?; + let current_beta = model.beta().clone(); + let updated_beta = Col::from_fn(target_beta.nrows(), |index| { + current_beta[index] + step_size * (target_beta[index] - current_beta[index]) + }); + let mut updated_model = model.clone(); + updated_model.set_beta(updated_beta)?; + + let subject_means = subject_covariates + .iter() + .map(|covariates| updated_model.compute_mu(covariates)) + .collect::>(); + let mu = Col::from_fn(population.npar(), |index| { + updated_model + .intercept(index) + .unwrap_or(population.mu()[index]) + }); + let omega = covariance_from_subject_means(individual_estimates, &subject_means)?; + + Ok((updated_model, subject_means, mu, omega)) +} + +pub(crate) fn covariance_from_subject_means( + individual_estimates: &IndividualEstimates, + subject_means: &[Col], +) -> Result> { + let n_subjects = individual_estimates.nsubjects(); + let n_params = individual_estimates + .get(0) + .map(|individual| individual.npar()) + .unwrap_or(0); + + if n_subjects == 0 || n_params == 0 { + return Ok(Mat::zeros(n_params, n_params)); + } + + let mut omega = Mat::::zeros(n_params, n_params); + for subject_index in 0..n_subjects { + let phi = individual_estimates.get(subject_index).unwrap().psi(); + for row in 0..n_params { + let eta_row = phi[row] - subject_means[subject_index][row]; + for col in 0..n_params { + let eta_col = phi[col] - subject_means[subject_index][col]; + omega[(row, col)] += eta_row * eta_col; + } + } + } + + let denom = n_subjects as f64; + Ok(Mat::from_fn(n_params, n_params, |row, col| { + omega[(row, col)] / denom + })) +} diff --git a/src/estimation/parametric/engine.rs b/src/estimation/parametric/engine.rs new file mode 100644 index 000000000..c414ea2a3 --- /dev/null +++ b/src/estimation/parametric/engine.rs @@ -0,0 +1,37 @@ +use anyhow::Result; +use pharmsol::Equation; + +use crate::algorithms::parametric::{run_parametric_algorithm, ParametricAlgorithmInput}; +use crate::api::EstimationMethod; +use crate::compile::CompiledProblem; +use crate::estimation::parametric::compiler::compile_model_state; +use crate::estimation::parametric::workspace::ParametricWorkspace; +use crate::results::FitResult; + +#[derive(Debug, Default, Clone, Copy)] +pub struct ParametricEngine; + +impl ParametricEngine { + pub fn fit( + problem: CompiledProblem, + ) -> Result> { + let compiled_state = compile_model_state(&problem); + let occasion_design = problem.design.occasions.clone(); + if !matches!(problem.method(), EstimationMethod::Parametric(_)) { + anyhow::bail!( + "parametric engine received non-parametric method: {:?}", + problem.method() + ); + } + let input = ParametricAlgorithmInput::from_compiled_problem(problem)?; + let workspace = run_parametric_algorithm(input)?; + Ok(workspace.with_compiled_state(compiled_state, &occasion_design)) + } +} + +pub fn fit( + problem: CompiledProblem, +) -> Result> { + let workspace = ParametricEngine::fit(problem)?; + Ok(workspace.into_fit_result()) +} diff --git a/src/estimation/parametric/individual.rs b/src/estimation/parametric/individual.rs new file mode 100644 index 000000000..a43b1e477 --- /dev/null +++ b/src/estimation/parametric/individual.rs @@ -0,0 +1,294 @@ +//! Individual parameter estimates. 
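+//!
+//! In the SAEM code path an [`Individual`] pairs the random effects η with the
+//! subject-level parameter vector (φ internally, converted to ψ for reporting),
+//! built in the E-step as φᵢ = μᵢ + ηᵢ.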
+ +use anyhow::{bail, Result}; +use faer::{Col, Mat}; +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone)] +pub struct Individual { + subject_id: String, + eta: Col, + psi: Col, + conditional_variance: Option>, + objective_function: Option, +} + +impl Individual { + pub fn new(subject_id: impl Into, eta: Col, psi: Col) -> Result { + if eta.nrows() != psi.nrows() { + bail!( + "Random effects length ({}) must match parameter length ({})", + eta.nrows(), + psi.nrows() + ); + } + + Ok(Self { + subject_id: subject_id.into(), + eta, + psi, + conditional_variance: None, + objective_function: None, + }) + } + + pub fn with_variance( + subject_id: impl Into, + eta: Col, + psi: Col, + variance: Mat, + ) -> Result { + let n = eta.nrows(); + if variance.nrows() != n || variance.ncols() != n { + bail!( + "Variance matrix dimensions ({}x{}) must match parameter count ({})", + variance.nrows(), + variance.ncols(), + n + ); + } + + let mut individual = Self::new(subject_id, eta, psi)?; + individual.conditional_variance = Some(variance); + Ok(individual) + } + + pub fn subject_id(&self) -> &str { + &self.subject_id + } + + pub fn eta(&self) -> &Col { + &self.eta + } + + pub fn psi(&self) -> &Col { + &self.psi + } + + pub fn conditional_variance(&self) -> Option<&Mat> { + self.conditional_variance.as_ref() + } + + pub fn objective_function(&self) -> Option { + self.objective_function + } + + pub fn npar(&self) -> usize { + self.eta.nrows() + } + + pub fn standard_errors(&self) -> Option> { + self.conditional_variance + .as_ref() + .map(|var| Col::from_fn(self.npar(), |i| var[(i, i)].sqrt())) + } + + pub fn set_conditional_variance(&mut self, variance: Mat) -> Result<()> { + let n = self.npar(); + if variance.nrows() != n || variance.ncols() != n { + bail!( + "Variance matrix dimensions ({}x{}) must match parameter count ({})", + variance.nrows(), + variance.ncols(), + n + ); + } + self.conditional_variance = Some(variance); + Ok(()) + } + + pub fn set_objective_function(&mut self, objf: f64) { + self.objective_function = Some(objf); + } +} + +#[derive(Debug, Clone, Default)] +pub struct IndividualEstimates { + estimates: Vec, +} + +impl IndividualEstimates { + pub fn new() -> Self { + Self { + estimates: Vec::new(), + } + } + + pub fn from_vec(estimates: Vec) -> Self { + Self { estimates } + } + + pub fn add(&mut self, individual: Individual) { + self.estimates.push(individual); + } + + pub fn nsubjects(&self) -> usize { + self.estimates.len() + } + + pub fn get(&self, index: usize) -> Option<&Individual> { + self.estimates.get(index) + } + + pub fn get_by_id(&self, id: &str) -> Option<&Individual> { + self.estimates + .iter() + .find(|estimate| estimate.subject_id() == id) + } + + pub fn iter(&self) -> impl Iterator { + self.estimates.iter() + } + + pub fn eta_matrix(&self) -> Option> { + if self.estimates.is_empty() { + return None; + } + + let n_subjects = self.estimates.len(); + let n_params = self.estimates[0].npar(); + + Some(Mat::from_fn(n_subjects, n_params, |i, j| { + self.estimates[i].eta()[j] + })) + } + + pub fn psi_matrix(&self) -> Option> { + if self.estimates.is_empty() { + return None; + } + + let n_subjects = self.estimates.len(); + let n_params = self.estimates[0].npar(); + + Some(Mat::from_fn(n_subjects, n_params, |i, j| { + self.estimates[i].psi()[j] + })) + } + + pub fn eta_mean(&self) -> Option> { + if self.estimates.is_empty() { + return None; + } + + let n_subjects = self.estimates.len() as f64; + let n_params = self.estimates[0].npar(); + + Some(Col::from_fn(n_params, 
|j| { + self.estimates + .iter() + .map(|estimate| estimate.eta()[j]) + .sum::() + / n_subjects + })) + } + + pub fn eta_covariance(&self) -> Option> { + let mean = self.eta_mean()?; + let n_subjects = self.estimates.len() as f64; + let n_params = self.estimates[0].npar(); + + Some(Mat::from_fn(n_params, n_params, |i, j| { + self.estimates + .iter() + .map(|estimate| (estimate.eta()[i] - mean[i]) * (estimate.eta()[j] - mean[j])) + .sum::() + / (n_subjects - 1.0) + })) + } + + pub fn shrinkage(&self, population_variance: &Col) -> Option> { + let eta_cov = self.eta_covariance()?; + let n_params = self.estimates[0].npar(); + + Some(Col::from_fn(n_params, |i| { + let eta_var = eta_cov[(i, i)]; + let pop_var = population_variance[i]; + if pop_var > 0.0 { + 1.0 - (eta_var / pop_var) + } else { + 0.0 + } + })) + } +} + +impl Serialize for Individual { + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + + let mut state = serializer.serialize_struct("Individual", 4)?; + state.serialize_field("subject_id", &self.subject_id)?; + + let eta_vec: Vec = (0..self.eta.nrows()).map(|i| self.eta[i]).collect(); + state.serialize_field("eta", &eta_vec)?; + + let psi_vec: Vec = (0..self.psi.nrows()).map(|i| self.psi[i]).collect(); + state.serialize_field("psi", &psi_vec)?; + + state.serialize_field("objective_function", &self.objective_function)?; + state.end() + } +} + +impl<'de> Deserialize<'de> for Individual { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + #[derive(Deserialize)] + struct IndividualData { + subject_id: String, + eta: Vec, + psi: Vec, + objective_function: Option, + } + + let data = IndividualData::deserialize(deserializer)?; + + let eta = Col::from_fn(data.eta.len(), |i| data.eta[i]); + let psi = Col::from_fn(data.psi.len(), |i| data.psi[i]); + + let mut individual = + Individual::new(data.subject_id, eta, psi).map_err(serde::de::Error::custom)?; + individual.objective_function = data.objective_function; + Ok(individual) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_individual_creation() { + let eta = Col::from_fn(2, |i| if i == 0 { 0.1 } else { -0.2 }); + let psi = Col::from_fn(2, |i| if i == 0 { 5.5 } else { 45.0 }); + + let ind = Individual::new("SUBJ001", eta, psi).unwrap(); + + assert_eq!(ind.subject_id(), "SUBJ001"); + assert_eq!(ind.npar(), 2); + assert_eq!(ind.eta()[0], 0.1); + assert_eq!(ind.psi()[1], 45.0); + } + + #[test] + fn test_individual_estimates_collection() { + let mut estimates = IndividualEstimates::new(); + + for i in 0..3 { + let eta = Col::from_fn(2, |j| (i as f64) * 0.1 + (j as f64) * 0.05); + let psi = Col::from_fn(2, |j| 5.0 + (i as f64) + (j as f64) * 10.0); + let ind = Individual::new(format!("SUBJ{:03}", i), eta, psi).unwrap(); + estimates.add(ind); + } + + assert_eq!(estimates.nsubjects(), 3); + assert!(estimates.get_by_id("SUBJ001").is_some()); + assert!(estimates.eta_matrix().is_some()); + } +} diff --git a/src/estimation/parametric/integration.rs b/src/estimation/parametric/integration.rs new file mode 100644 index 000000000..7db18cba7 --- /dev/null +++ b/src/estimation/parametric/integration.rs @@ -0,0 +1,10 @@ +//! Numerical integration methods for marginal likelihood estimation. +//! +//! This module provides building blocks for computing marginal likelihoods in +//! mixed-effects models by integrating over random effects. 
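+//!
+//! Importance sampling approximates each subject's marginal likelihood, the
+//! integral over the random effects of likelihood times prior, by averaging
+//! weights w_k = p(y | phi_k) * p(phi_k) / q(phi_k) for draws phi_k from a
+//! proposal q. The averaging is done in log space for numerical stability.
+//! A minimal, self-contained sketch of that log-mean-exp step (illustration
+//! only; the estimator below implements the same idea inline):
+//!
+//! ```ignore
+//! /// Computes log(mean(exp(w))) for the given log-weights without overflow.
+//! fn log_mean_exp(log_weights: &[f64]) -> f64 {
+//!     let max = log_weights.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
+//!     if !max.is_finite() {
+//!         return f64::NEG_INFINITY;
+//!     }
+//!     let sum: f64 = log_weights.iter().map(|w| (w - max).exp()).sum();
+//!     max + sum.ln() - (log_weights.len() as f64).ln()
+//! }
+//! ```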
+ +mod importance_sampling; + +pub(crate) use importance_sampling::{ + ImportanceSamplingConfig, ImportanceSamplingEstimator, SubjectConditionalPosterior, +}; diff --git a/src/estimation/parametric/integration/importance_sampling.rs b/src/estimation/parametric/integration/importance_sampling.rs new file mode 100644 index 000000000..435b49110 --- /dev/null +++ b/src/estimation/parametric/integration/importance_sampling.rs @@ -0,0 +1,284 @@ +//! Importance Sampling for Marginal Likelihood Estimation +//! +//! This module implements importance sampling (IS) for estimating the marginal +//! log-likelihood in mixed-effects models, matching R saemix's `llis.saemix` function. + +use anyhow::Result; +use faer::linalg::solvers::DenseSolveCore; +use faer::{Col, Mat}; +use rand::SeedableRng; +use rand_chacha::ChaCha8Rng; +use rand_distr::{Distribution, StudentT}; +use statrs::function::gamma::ln_gamma; + +use crate::estimation::parametric::ParameterTransform; +use pharmsol::{Equation, Predictions, ResidualErrorModels, Subject}; + +#[derive(Debug, Clone)] +pub struct ImportanceSamplingConfig { + pub n_samples: usize, + pub nu: f64, + pub seed: u64, +} + +impl Default for ImportanceSamplingConfig { + fn default() -> Self { + Self { + n_samples: 5000, + nu: 4.0, + seed: 123456, + } + } +} + +impl ImportanceSamplingConfig { + pub fn saemix_defaults() -> Self { + Self::default() + } + + pub fn with_n_samples(mut self, n_samples: usize) -> Self { + self.n_samples = n_samples; + self + } + + pub fn with_nu(mut self, nu: f64) -> Self { + self.nu = nu; + self + } + + pub fn with_seed(mut self, seed: u64) -> Self { + self.seed = seed; + self + } +} + +#[derive(Debug, Clone)] +pub struct SubjectConditionalPosterior { + pub mean: Col, + pub variance: Vec, +} + +impl SubjectConditionalPosterior { + pub fn from_mcmc_samples( + eta_samples: &[Col], + mu_phi: &Col, + omega: &Mat, + ) -> Self { + let n_params = mu_phi.nrows(); + let n_samples = eta_samples.len(); + + if n_samples == 0 { + return Self { + mean: mu_phi.clone(), + variance: (0..n_params).map(|j| omega[(j, j)].max(1e-6)).collect(), + }; + } + + let mut mean_eta = vec![0.0; n_params]; + for eta in eta_samples { + for j in 0..n_params { + mean_eta[j] += eta[j]; + } + } + for value in &mut mean_eta { + *value /= n_samples as f64; + } + + let mean = Col::from_fn(n_params, |j| mu_phi[j] + mean_eta[j]); + + let variance = if n_samples > 1 { + let mut var = vec![0.0; n_params]; + for eta in eta_samples { + for j in 0..n_params { + let diff = eta[j] - mean_eta[j]; + var[j] += diff * diff; + } + } + for value in &mut var { + *value = (*value / (n_samples - 1) as f64).max(1e-6); + } + var + } else { + (0..n_params).map(|j| omega[(j, j)].max(1e-6)).collect() + }; + + Self { mean, variance } + } + + pub fn new(mean: Col, variance: Vec) -> Self { + Self { mean, variance } + } +} + +pub struct ImportanceSamplingEstimator<'a, E: Equation> { + config: ImportanceSamplingConfig, + equation: &'a E, + error_models: &'a ResidualErrorModels, + transforms: &'a [ParameterTransform], + mu_phi: &'a Col, + #[allow(dead_code)] + omega: &'a Mat, + omega_inv: Mat, + log_prior_const: f64, +} + +impl<'a, E: Equation> ImportanceSamplingEstimator<'a, E> { + pub fn new( + config: ImportanceSamplingConfig, + equation: &'a E, + error_models: &'a ResidualErrorModels, + transforms: &'a [ParameterTransform], + mu_phi: &'a Col, + omega: &'a Mat, + ) -> Result { + let omega_inv = omega + .llt(faer::Side::Lower) + .map_err(|_| anyhow::anyhow!("Omega not positive definite"))? 
+ .inverse(); + let n_params = mu_phi.nrows(); + let log_det_omega = omega.determinant().ln(); + let log_prior_const = log_det_omega + (n_params as f64) * (2.0 * std::f64::consts::PI).ln(); + + Ok(Self { + config, + equation, + error_models, + transforms, + mu_phi, + omega, + omega_inv, + log_prior_const, + }) + } + + pub fn estimate_subject_ll( + &self, + subject: &Subject, + conditional: &SubjectConditionalPosterior, + rng: &mut impl rand::Rng, + ) -> f64 { + let n_params = self.mu_phi.nrows(); + let n_samples = self.config.n_samples; + let nu = self.config.nu; + + let t_dist = StudentT::new(nu).expect("Invalid nu for StudentT"); + let mut log_weights = Vec::with_capacity(n_samples); + + for _ in 0..n_samples { + let r: Vec = (0..n_params).map(|_| t_dist.sample(rng)).collect(); + let phi_sample = Col::from_fn(n_params, |j| { + conditional.mean[j] + conditional.variance[j].sqrt() * r[j] + }); + let psi_sample: Vec = (0..n_params) + .map(|j| self.transforms[j].phi_to_psi(phi_sample[j])) + .collect(); + + let log_lik = match self.equation.estimate_predictions(subject, &psi_sample) { + Ok(predictions) => { + let obs_pred_pairs = + predictions + .get_predictions() + .into_iter() + .filter_map(|pred| { + pred.observation() + .map(|obs| (pred.outeq(), obs, pred.prediction())) + }); + self.error_models.total_log_likelihood(obs_pred_pairs) + } + Err(_) => continue, + }; + + let eta = Col::from_fn(n_params, |j| phi_sample[j] - self.mu_phi[j]); + let mut quad_form = 0.0; + for j in 0..n_params { + for k in 0..n_params { + quad_form += eta[j] * self.omega_inv[(j, k)] * eta[k]; + } + } + let log_prior = -0.5 * (quad_form + self.log_prior_const); + + let mut log_proposal = 0.0; + for j in 0..n_params { + let log_t_pdf = ln_gamma((nu + 1.0) / 2.0) + - ln_gamma(nu / 2.0) + - 0.5 * (nu * std::f64::consts::PI).ln() + - ((nu + 1.0) / 2.0) * (1.0 + r[j] * r[j] / nu).ln(); + log_proposal += log_t_pdf; + } + log_proposal -= 0.5 * conditional.variance.iter().map(|v| v.ln()).sum::(); + + log_weights.push(log_lik + log_prior - log_proposal); + } + + if log_weights.is_empty() { + return f64::NEG_INFINITY; + } + + let max_weight = log_weights + .iter() + .cloned() + .fold(f64::NEG_INFINITY, f64::max); + if !max_weight.is_finite() { + return f64::NEG_INFINITY; + } + + let sum_exp: f64 = log_weights.iter().map(|&w| (w - max_weight).exp()).sum(); + max_weight + sum_exp.ln() - (log_weights.len() as f64).ln() + } + + pub fn estimate_minus2ll( + &self, + subjects: Vec<&Subject>, + conditionals: &[SubjectConditionalPosterior], + ) -> f64 { + let mut rng = ChaCha8Rng::seed_from_u64(self.config.seed); + let mut total_ll = 0.0; + + for (subject, conditional) in subjects.iter().zip(conditionals.iter()) { + total_ll += self.estimate_subject_ll(subject, conditional, &mut rng); + } + + -2.0 * total_ll + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_config_defaults() { + let config = ImportanceSamplingConfig::default(); + assert_eq!(config.n_samples, 5000); + assert_eq!(config.nu, 4.0); + } + + #[test] + fn test_conditional_posterior_from_empty() { + let mu_phi = Col::from_fn(2, |i| (i + 1) as f64); + let omega = Mat::from_fn(2, 2, |i, j| if i == j { 0.5 } else { 0.0 }); + + let conditional = SubjectConditionalPosterior::from_mcmc_samples(&[], &mu_phi, &omega); + + assert_eq!(conditional.mean[0], 1.0); + assert_eq!(conditional.mean[1], 2.0); + assert!((conditional.variance[0] - 0.5).abs() < 1e-6); + } + + #[test] + fn test_conditional_posterior_from_samples() { + let mu_phi = Col::from_fn(2, |_| 0.0); + 
let omega = Mat::from_fn(2, 2, |i, j| if i == j { 1.0 } else { 0.0 }); + + let samples = vec![ + Col::from_fn(2, |_| 1.0), + Col::from_fn(2, |_| 2.0), + Col::from_fn(2, |_| 3.0), + ]; + + let conditional = SubjectConditionalPosterior::from_mcmc_samples(&samples, &mu_phi, &omega); + + assert!((conditional.mean[0] - 2.0).abs() < 1e-6); + assert!((conditional.variance[0] - 1.0).abs() < 1e-6); + } +} diff --git a/src/estimation/parametric/likelihood.rs b/src/estimation/parametric/likelihood.rs new file mode 100644 index 000000000..24b99481c --- /dev/null +++ b/src/estimation/parametric/likelihood.rs @@ -0,0 +1,376 @@ +use anyhow::Result; +use faer::linalg::solvers::DenseSolveCore; +use faer::{Col, Mat}; +use ndarray::Array2; +use pharmsol::{Data, Equation, Event, Predictions, ResidualErrorModels, Subject}; + +use crate::estimation::parametric::{ + phi_to_psi, ImportanceSamplingConfig, ImportanceSamplingEstimator, IndividualEstimates, + LikelihoodEstimates, ParameterTransform, PhiVector, Population, SubjectConditionalPosterior, +}; + +#[derive(Debug, Clone, Copy, PartialEq)] +pub struct ResidualErrorUpdate { + pub sigma_sq: f64, + pub statrese: f64, + pub n_observations: usize, +} + +pub fn batch_log_likelihood_from_eta( + equation: &E, + data: &Data, + error_models: &ResidualErrorModels, + transforms: &[ParameterTransform], + eta_matrix: &Array2, + mean_phi: &[Col], +) -> Result> { + let n_subjects = eta_matrix.nrows(); + let n_params = eta_matrix.ncols(); + + let mut psi_params = Array2::::zeros((n_subjects, n_params)); + for subject_index in 0..n_subjects { + let phi = PhiVector( + (0..n_params) + .map(|param_index| { + mean_phi[subject_index][param_index] + eta_matrix[[subject_index, param_index]] + }) + .collect(), + ); + let psi = phi_to_psi(transforms, &phi); + for (param_index, value) in psi.as_slice().iter().copied().enumerate() { + psi_params[[subject_index, param_index]] = value; + } + } + + pharmsol::prelude::simulator::log_likelihood_batch(equation, data, &psi_params, error_models) + .map_err(|error| anyhow::anyhow!("Likelihood computation failed: {}", error)) +} + +pub fn approximate_objective_from_individuals( + equation: &E, + data: &Data, + error_models: &ResidualErrorModels, + transforms: &[ParameterTransform], + population: &Population, + individual_estimates: &IndividualEstimates, + mean_phi: &[Col], +) -> f64 { + let n_subjects = individual_estimates.nsubjects(); + if n_subjects == 0 { + return f64::INFINITY; + } + let n_params = population.npar(); + + let mut eta_matrix = Array2::::zeros((n_subjects, n_params)); + for subject_index in 0..n_subjects { + if let Some(individual) = individual_estimates.get(subject_index) { + for param_index in 0..n_params { + eta_matrix[[subject_index, param_index]] = individual.eta()[param_index]; + } + } + } + + let log_likelihoods = match batch_log_likelihood_from_eta( + equation, + data, + error_models, + transforms, + &eta_matrix, + mean_phi, + ) { + Ok(log_likelihoods) => log_likelihoods, + Err(_) => return f64::INFINITY, + }; + + let omega_inv = match population.omega().llt(faer::Side::Lower) { + Ok(llt) => llt.inverse(), + Err(_) => return f64::INFINITY, + }; + + let mut total_ll = 0.0; + for (subject_index, log_likelihood) in log_likelihoods.iter().enumerate() { + if !log_likelihood.is_finite() { + continue; + } + + let mut prior_term = 0.0; + if let Some(individual) = individual_estimates.get(subject_index) { + let eta = individual.eta(); + for row in 0..n_params { + for col in 0..n_params { + prior_term += eta[row] * omega_inv[(row, 
col)] * eta[col]; + } + } + } + + total_ll += log_likelihood - 0.5 * prior_term; + } + + -2.0 * total_ll +} + +pub(crate) fn refresh_saem_objective_history( + objf: &mut f64, + prev_objf: &mut f64, + preserve_previous: bool, + equation: &E, + data: &Data, + error_models: &ResidualErrorModels, + transforms: &[ParameterTransform], + population: &Population, + individual_estimates: &IndividualEstimates, + mean_phi: &[Col], +) { + if preserve_previous { + *prev_objf = *objf; + } + + *objf = approximate_objective_from_individuals( + equation, + data, + error_models, + transforms, + population, + individual_estimates, + mean_phi, + ); +} + +pub fn subject_objective_from_eta( + subject_index: usize, + eta_matrix: &Array2, + log_likelihood: f64, + omega_inv: &Mat, +) -> f64 { + let quad = subject_eta_quadratic_form(subject_index, eta_matrix, omega_inv); + + -2.0 * (log_likelihood - 0.5 * quad) +} + +pub fn subject_log_prior_from_eta( + subject_index: usize, + eta_matrix: &Array2, + omega_inv: &Mat, +) -> f64 { + -0.5 * subject_eta_quadratic_form(subject_index, eta_matrix, omega_inv) +} + +pub fn log_priors_from_eta_matrix(eta_matrix: &Array2, omega_inv: &Mat) -> Vec { + (0..eta_matrix.nrows()) + .map(|subject_index| subject_log_prior_from_eta(subject_index, eta_matrix, omega_inv)) + .collect() +} + +fn subject_eta_quadratic_form( + subject_index: usize, + eta_matrix: &Array2, + omega_inv: &Mat, +) -> f64 { + let n_params = eta_matrix.ncols(); + (0..n_params) + .flat_map(|row| { + (0..n_params).map(move |col| { + eta_matrix[[subject_index, row]] + * omega_inv[(row, col)] + * eta_matrix[[subject_index, col]] + }) + }) + .sum::() +} + +pub fn estimate_initial_sigma_sq(_error_models: &ResidualErrorModels) -> f64 { + 1.0 +} + +pub fn sync_error_models_with_sigma(error_models: &mut ResidualErrorModels, sigma_sq: f64) { + error_models.update_sigma(sigma_sq.sqrt()); +} + +pub fn update_residual_error_from_individuals( + equation: &E, + data: &Data, + error_models: &mut ResidualErrorModels, + transforms: &[ParameterTransform], + individual_estimates: &IndividualEstimates, + step_size: f64, + sigma_sq: f64, + statrese: f64, + use_annealed_sigma_floor: bool, + sa_alpha: f64, + allow_sigma_update: bool, +) -> Result { + let mut sum_weighted_sq_residuals = 0.0; + let mut n_observations = 0; + + for (subject_index, subject) in data.subjects().iter().enumerate() { + let Some(individual) = individual_estimates.get(subject_index) else { + continue; + }; + + let phi = PhiVector::from(individual.psi()); + let params = phi_to_psi(transforms, &phi).0; + + let Ok(predictions) = equation.estimate_predictions(subject, ¶ms) else { + continue; + }; + + let observations: Vec<_> = subject + .occasions() + .iter() + .flat_map(|occasion| occasion.events().iter()) + .filter_map(|event| { + if let Event::Observation(observation) = event { + observation + .value() + .map(|value| (value, observation.outeq())) + } else { + None + } + }) + .collect(); + + for ((observation_value, outeq), prediction) in observations + .iter() + .zip(predictions.get_predictions().iter()) + { + let predicted_value = prediction.prediction(); + if let Some(error_model) = error_models.get(*outeq) { + sum_weighted_sq_residuals += + error_model.weighted_squared_residual(*observation_value, predicted_value); + } else { + let residual = observation_value - predicted_value; + sum_weighted_sq_residuals += residual * residual; + } + n_observations += 1; + } + } + + if n_observations == 0 { + return Ok(ResidualErrorUpdate { + sigma_sq, + statrese, + 
n_observations, + }); + } + + let updated_statrese = if step_size > 0.0 { + statrese + step_size * (sum_weighted_sq_residuals - statrese) + } else { + statrese + }; + let sig2 = updated_statrese / n_observations as f64; + let updated_sigma_sq = if use_annealed_sigma_floor { + let decayed_sigma = sigma_sq.sqrt() * sa_alpha; + decayed_sigma.max(sig2.sqrt()).powi(2) + } else if allow_sigma_update { + sig2 + } else { + sigma_sq + }; + + sync_error_models_with_sigma(error_models, updated_sigma_sq); + + Ok(ResidualErrorUpdate { + sigma_sq: updated_sigma_sq, + statrese: updated_statrese, + n_observations, + }) +} + +pub fn importance_sampling_likelihood_estimates( + equation: &E, + subjects: Vec<&Subject>, + error_models: &ResidualErrorModels, + transforms: &[ParameterTransform], + mu_phi: &Col, + omega: &Mat, + conditionals: &[SubjectConditionalPosterior], + config: ImportanceSamplingConfig, +) -> Result { + let estimator = ImportanceSamplingEstimator::new( + config.clone(), + equation, + error_models, + transforms, + mu_phi, + omega, + )?; + let minus2ll = estimator.estimate_minus2ll(subjects, conditionals); + + let mut estimates = LikelihoodEstimates::new(); + if minus2ll.is_finite() { + estimates.ll_importance_sampling = Some(-minus2ll / 2.0); + estimates.is_n_samples = Some(config.n_samples); + } + + Ok(estimates) +} + +pub fn subject_conditionals_from_eta_samples( + eta_samples_by_subject: &[Vec>], + fallback_individuals: Option<&IndividualEstimates>, + mu_phi: &Col, + omega: &Mat, +) -> Vec { + eta_samples_by_subject + .iter() + .enumerate() + .map(|(index, eta_samples)| { + if eta_samples.is_empty() { + fallback_subject_conditional(fallback_individuals, index, mu_phi, omega) + } else { + SubjectConditionalPosterior::from_mcmc_samples(eta_samples, mu_phi, omega) + } + }) + .collect() +} + +fn fallback_subject_conditional( + fallback_individuals: Option<&IndividualEstimates>, + index: usize, + mu_phi: &Col, + omega: &Mat, +) -> SubjectConditionalPosterior { + let mean = fallback_individuals + .and_then(|individuals| individuals.get(index)) + .map(|individual| individual.psi().clone()) + .unwrap_or_else(|| mu_phi.clone()); + let variance = (0..mu_phi.nrows()) + .map(|parameter_index| omega[(parameter_index, parameter_index)].max(1e-6)) + .collect(); + + SubjectConditionalPosterior::new(mean, variance) +} + +#[cfg(test)] +mod tests { + use faer::{Col, Mat}; + + use super::subject_conditionals_from_eta_samples; + use crate::estimation::parametric::{Individual, IndividualEstimates}; + + #[test] + fn falls_back_to_individual_phi_when_no_chain_samples_exist() { + let mu_phi = Col::from_fn(2, |index| if index == 0 { 0.0 } else { 1.0 }); + let omega = Mat::from_fn(2, 2, |row, col| if row == col { 0.25 } else { 0.0 }); + let individual = Individual::new( + "1", + Col::from_fn(2, |_| 0.0), + Col::from_fn(2, |index| if index == 0 { 0.2 } else { 1.3 }), + ) + .expect("valid individual"); + let individuals = IndividualEstimates::from_vec(vec![individual]); + + let conditionals = subject_conditionals_from_eta_samples( + &[Vec::new()], + Some(&individuals), + &mu_phi, + &omega, + ); + + assert_eq!(conditionals.len(), 1); + assert!((conditionals[0].mean[0] - 0.2).abs() < 1e-12); + assert!((conditionals[0].mean[1] - 1.3).abs() < 1e-12); + assert_eq!(conditionals[0].variance, vec![0.25, 0.25]); + } +} diff --git a/src/estimation/parametric/mod.rs b/src/estimation/parametric/mod.rs new file mode 100644 index 000000000..e2e6631b1 --- /dev/null +++ b/src/estimation/parametric/mod.rs @@ -0,0 +1,73 @@ +mod 
assembler; +mod compiler; +mod effects; +mod engine; +mod individual; +mod integration; +mod likelihood; +mod population; +mod posthoc; +mod predictions; +mod reporting; +mod sampling; +mod state; +mod statistics; +mod sufficient_stats; +mod summaries; +mod transforms; +mod uncertainty; +mod workspace; + +pub(crate) use assembler::{ + assemble_parametric_result, finalize_saem_result, ParametricResultInput, SaemFinalizeInput, +}; +pub use compiler::compile_model_state; +pub(crate) use effects::{ + blended_subject_covariate_m_step, build_parametric_covariate_context, + covariance_from_individual_etas, covariance_from_subject_means, covariate_state, estimate_beta, + recenter_individual_estimates, subject_mean_phi, ParametricCovariateContext, +}; +pub use engine::{fit, ParametricEngine}; +pub use individual::{Individual, IndividualEstimates}; +pub(crate) use integration::{ + ImportanceSamplingConfig, ImportanceSamplingEstimator, SubjectConditionalPosterior, +}; +pub(crate) use likelihood::refresh_saem_objective_history; +pub use likelihood::{ + approximate_objective_from_individuals, batch_log_likelihood_from_eta, + estimate_initial_sigma_sq, importance_sampling_likelihood_estimates, + log_priors_from_eta_matrix, subject_conditionals_from_eta_samples, subject_log_prior_from_eta, + subject_objective_from_eta, sync_error_models_with_sigma, + update_residual_error_from_individuals, ResidualErrorUpdate, +}; +pub(crate) use population::ensure_positive_definite_covariance; +pub use population::{CovarianceStructure, Population}; +pub use posthoc::{aic, bic, cache_predictions, shrinkage, statistics, write_statistics}; +pub use predictions::{ParametricPredictionRow, ParametricPredictions, PredictionSummary}; +pub use reporting::{FimMethod, LikelihoodEstimates, ParametricIterationLog, UncertaintyEstimates}; +pub(crate) use sampling::{ + advance_saem_chains, sample_eta_from_population, ChainState, KernelConfig, SaemMcmcState, +}; +pub use state::{ + CovariateEffectsSnapshot, CovariateState, EtaTable, EtaVector, FixedEffects, + IndividualEffectsState, KappaVector, OccasionKappa, OccasionKappaTable, ParametricModelState, + ParametricTransformKind, PhiTable, PhiVector, PsiTable, PsiVector, RandomEffects, + ResidualState, TransformSet, +}; +pub(crate) use statistics::{ + residual_error_estimates_from_models, residual_error_estimates_from_observed_outeqs, +}; +pub use statistics::{ParametricStatistics, ResidualErrorEstimates}; +pub use sufficient_stats::{StepSizeSchedule, SufficientStats}; +pub use summaries::{fit_summary, individual_summaries, population_summary}; +pub(crate) use transforms::initialize_population_in_phi_space; +pub use transforms::{ + default_phi_variance, phi_to_psi, phi_to_psi_vec, psi_to_phi, psi_to_phi_vec, transform_label, + transforms_from_saemix_codes, ParameterTransform, +}; +pub(crate) use uncertainty::focei_linearization_uncertainty; +pub use uncertainty::{ + estimates as uncertainty_estimates, fim, fim_inverse, fim_method, has_fim, has_standard_errors, + rse_mu, se_mu, se_omega, +}; +pub use workspace::ParametricWorkspace; diff --git a/src/estimation/parametric/population.rs b/src/estimation/parametric/population.rs new file mode 100644 index 000000000..45d714eaf --- /dev/null +++ b/src/estimation/parametric/population.rs @@ -0,0 +1,441 @@ +//! Parametric population representation. 
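+//!
+//! A `Population` stores the fixed-effect mean vector (`mu`) together with the
+//! random-effect covariance matrix (`omega`) and its structural constraint
+//! (full, diagonal, or block-diagonal). During SAEM both are re-estimated from
+//! the accumulated sufficient statistics as mu = s1 / n and
+//! omega = s2 / n - mu * mu^T, after which the structural constraint is
+//! re-applied. A scalar sketch of that moment update (illustration only; `s1`,
+//! `s2`, and `n` stand in for the values held by `SufficientStats`):
+//!
+//! ```ignore
+//! // One-parameter version of the update in `update_from_sufficient_stats`.
+//! fn scalar_moment_update(s1: f64, s2: f64, n: f64) -> (f64, f64) {
+//!     let mu = s1 / n;            // first moment  -> population mean
+//!     let var = s2 / n - mu * mu; // second moment -> population variance
+//!     (mu, var)
+//! }
+//! ```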
+ +use anyhow::{bail, Result}; +use faer::{Col, Mat}; +use serde::{Deserialize, Serialize}; + +use crate::model::{ParameterDomain, ParameterSpace, ParameterTransform}; + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)] +pub enum CovarianceStructure { + #[default] + Full, + Diagonal, + BlockDiagonal(Vec), +} + +#[derive(Debug, Clone)] +pub struct Population { + mu: Col, + omega: Mat, + parameters: ParameterSpace, + structure: CovarianceStructure, +} + +impl Population { + pub fn new( + mu: Col, + omega: Mat, + parameters: impl Into, + ) -> Result { + let parameters = parameters.into(); + let n = mu.nrows(); + + if omega.nrows() != omega.ncols() { + bail!( + "Covariance matrix must be square, got {}x{}", + omega.nrows(), + omega.ncols() + ); + } + + if omega.nrows() != n { + bail!( + "Covariance matrix dimension ({}) must match mean vector length ({})", + omega.nrows(), + n + ); + } + + if parameters.len() != n { + bail!( + "Number of parameters ({}) must match mean vector length ({})", + parameters.len(), + n + ); + } + + Ok(Self { + mu, + omega, + parameters, + structure: CovarianceStructure::Full, + }) + } + + pub fn new_diagonal( + mu: Col, + variances: Col, + parameters: impl Into, + ) -> Result { + let n = mu.nrows(); + + if variances.nrows() != n { + bail!( + "Variances length ({}) must match mean vector length ({})", + variances.nrows(), + n + ); + } + + let omega = Mat::from_fn(n, n, |i, j| if i == j { variances[i] } else { 0.0 }); + + let mut pop = Self::new(mu, omega, parameters)?; + pop.structure = CovarianceStructure::Diagonal; + Ok(pop) + } + + pub fn from_parameter_space(parameters: impl Into) -> Result { + let parameters = parameters.into(); + let n = parameters.len(); + + if n == 0 { + bail!("Cannot create population with zero parameters"); + } + + let mu = Col::from_fn(n, |i| { + let param = ¶meters.iter().nth(i).unwrap(); + let (lower, upper) = parameter_bounds(param); + (lower + upper) / 2.0 + }); + + let omega = Mat::from_fn(n, n, |i, j| { + if i == j { + let param = ¶meters.iter().nth(i).unwrap(); + let (lower, upper) = parameter_bounds(param); + let range = upper - lower; + (range / 4.0).powi(2) + } else { + 0.0 + } + }); + + Self::new(mu, omega, parameters) + } + + pub fn mu(&self) -> &Col { + &self.mu + } + + pub fn mu_mut(&mut self) -> &mut Col { + &mut self.mu + } + + pub fn omega(&self) -> &Mat { + &self.omega + } + + pub fn omega_mut(&mut self) -> &mut Mat { + &mut self.omega + } + + pub fn parameters(&self) -> &ParameterSpace { + &self.parameters + } + + pub fn npar(&self) -> usize { + self.mu.nrows() + } + + pub fn structure(&self) -> &CovarianceStructure { + &self.structure + } + + pub fn set_structure(&mut self, structure: CovarianceStructure) { + self.structure = structure; + } + + pub fn param_names(&self) -> Vec { + self.parameters.names() + } + + pub fn standard_deviations(&self) -> Col { + Col::from_fn(self.npar(), |i| self.omega[(i, i)].sqrt()) + } + + pub fn correlation_matrix(&self) -> Mat { + let n = self.npar(); + let sds = self.standard_deviations(); + + Mat::from_fn(n, n, |i, j| { + if sds[i] > 0.0 && sds[j] > 0.0 { + self.omega[(i, j)] / (sds[i] * sds[j]) + } else if i == j { + 1.0 + } else { + 0.0 + } + }) + } + + pub fn coefficient_of_variation(&self) -> Col { + Col::from_fn(self.npar(), |i| { + let omega_ii = self.omega[(i, i)]; + let sd = omega_ii.sqrt(); + match self.parameters.items[i].transform { + ParameterTransform::Identity => { + let mu = self.mu[i].abs(); + if mu > 1e-10 { + 100.0 * sd / mu + } else { + f64::NAN + 
} + } + ParameterTransform::LogNormal => ((omega_ii.exp() - 1.0).sqrt()) * 100.0, + ParameterTransform::Logit | ParameterTransform::Probit => f64::NAN, + } + }) + } + + pub fn update_from_sufficient_stats( + &mut self, + stats: &crate::estimation::parametric::SufficientStats, + ) { + let n = stats.count() as f64; + + for i in 0..self.npar() { + self.mu[i] = stats.s1()[i] / n; + } + + for i in 0..self.npar() { + for j in 0..self.npar() { + self.omega[(i, j)] = stats.s2()[(i, j)] / n - self.mu[i] * self.mu[j]; + } + } + + self.apply_structure_constraint(); + } + + fn apply_structure_constraint(&mut self) { + match &self.structure { + CovarianceStructure::Full => {} + CovarianceStructure::Diagonal => { + for i in 0..self.npar() { + for j in 0..self.npar() { + if i != j { + self.omega[(i, j)] = 0.0; + } + } + } + } + CovarianceStructure::BlockDiagonal(blocks) => { + let mut current_start = 0; + let mut block_ranges: Vec<(usize, usize)> = Vec::new(); + + for &block_size in blocks { + block_ranges.push((current_start, current_start + block_size)); + current_start += block_size; + } + + for i in 0..self.npar() { + for j in 0..self.npar() { + let in_same_block = block_ranges + .iter() + .any(|&(start, end)| i >= start && i < end && j >= start && j < end); + + if !in_same_block { + self.omega[(i, j)] = 0.0; + } + } + } + } + } + } + + pub fn update_mu(&mut self, mu: Col) -> Result<()> { + if mu.nrows() != self.npar() { + bail!( + "Mean vector length ({}) doesn't match population size ({})", + mu.nrows(), + self.npar() + ); + } + self.mu = mu; + Ok(()) + } + + pub fn update_omega(&mut self, omega: Mat) -> Result<()> { + let n = self.npar(); + if omega.nrows() != n || omega.ncols() != n { + bail!( + "Omega dimensions ({}x{}) don't match population size ({})", + omega.nrows(), + omega.ncols(), + n + ); + } + self.omega = omega; + self.apply_structure_constraint(); + Ok(()) + } + + pub fn mu_as_vec(&self) -> Vec { + (0..self.npar()).map(|i| self.mu[i]).collect() + } + + pub fn variances_as_vec(&self) -> Vec { + (0..self.npar()).map(|i| self.omega[(i, i)]).collect() + } +} + +impl Default for Population { + fn default() -> Self { + Self { + mu: Col::zeros(0), + omega: Mat::zeros(0, 0), + parameters: ParameterSpace::new(), + structure: CovarianceStructure::Full, + } + } +} + +pub(crate) fn ensure_positive_definite_covariance(omega: &Mat) -> Mat { + let n = omega.nrows(); + let min_var = 1e-8; + let mut result = omega.clone(); + + for index in 0..n { + if result[(index, index)] < min_var { + result[(index, index)] = min_var; + } + } + + if result.llt(faer::Side::Lower).is_err() { + let ridge = result + .diagonal() + .column_vector() + .iter() + .cloned() + .fold(0.0_f64, f64::max) + * 0.01 + + min_var; + for index in 0..n { + result[(index, index)] += ridge; + } + tracing::debug!("Added ridge {:.2e} to Omega diagonal to ensure PD", ridge); + } + + result +} + +fn parameter_bounds(parameter: &crate::model::ParameterSpec) -> (f64, f64) { + match parameter.domain { + ParameterDomain::Bounded { lower, upper } => (lower, upper), + ParameterDomain::Positive { lower, upper } => { + (lower.unwrap_or(0.0), upper.unwrap_or(1.0e6)) + } + ParameterDomain::Unbounded { lower, upper } => { + (lower.unwrap_or(-1.0e6), upper.unwrap_or(1.0e6)) + } + } +} + +impl Serialize for Population { + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + + let mut state = serializer.serialize_struct("Population", 4)?; + + let mu_vec: Vec = 
(0..self.mu.nrows()).map(|i| self.mu[i]).collect(); + state.serialize_field("mu", &mu_vec)?; + + let omega_vec: Vec> = (0..self.omega.nrows()) + .map(|i| { + (0..self.omega.ncols()) + .map(|j| self.omega[(i, j)]) + .collect() + }) + .collect(); + state.serialize_field("omega", &omega_vec)?; + + state.serialize_field("parameters", &self.parameters)?; + state.serialize_field("structure", &self.structure)?; + state.end() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_population_creation() { + let params = ParameterSpace::new() + .add(crate::model::ParameterSpec::bounded("CL", 0.1, 10.0)) + .add(crate::model::ParameterSpec::bounded("V", 1.0, 100.0)); + + let mu = Col::from_fn(2, |i| if i == 0 { 5.0 } else { 50.0 }); + let omega = Mat::from_fn(2, 2, |i, j| if i == j { 0.1 } else { 0.0 }); + + let pop = Population::new(mu, omega, params).unwrap(); + + assert_eq!(pop.npar(), 2); + assert_eq!(pop.mu()[0], 5.0); + assert_eq!(pop.omega()[(0, 0)], 0.1); + } + + #[test] + fn test_from_parameter_space() { + let params = ParameterSpace::new() + .add(crate::model::ParameterSpec::bounded("CL", 0.0, 10.0)) + .add(crate::model::ParameterSpec::bounded("V", 0.0, 100.0)); + + let pop = Population::from_parameter_space(params).unwrap(); + + assert_eq!(pop.mu()[0], 5.0); + assert_eq!(pop.mu()[1], 50.0); + } + + #[test] + fn test_diagonal_structure() { + let params = ParameterSpace::new() + .add(crate::model::ParameterSpec::bounded("CL", 0.1, 10.0)) + .add(crate::model::ParameterSpec::bounded("V", 1.0, 100.0)); + + let mu = Col::from_fn(2, |_| 1.0); + let variances = Col::from_fn(2, |_| 0.1); + + let pop = Population::new_diagonal(mu, variances, params).unwrap(); + + assert_eq!(*pop.structure(), CovarianceStructure::Diagonal); + assert_eq!(pop.omega()[(0, 1)], 0.0); + assert_eq!(pop.omega()[(1, 0)], 0.0); + } + + #[test] + fn test_coefficient_of_variation_uses_identity_formula_for_identity_parameters() { + let params = + ParameterSpace::new().add(crate::model::ParameterSpec::bounded("V", 50.0, 180.0)); + + let population = Population::new( + Col::from_fn(1, |_| 100.0), + Mat::from_fn(1, 1, |_, _| 25.0), + params, + ) + .unwrap(); + + let cv = population.coefficient_of_variation(); + + assert!((cv[0] - 5.0).abs() < 1e-10); + } + + #[test] + fn test_coefficient_of_variation_uses_lognormal_formula_for_lognormal_parameters() { + let params = ParameterSpace::new().add(crate::model::ParameterSpec::positive("V")); + + let population = Population::new( + Col::from_fn(1, |_| 4.0), + Mat::from_fn(1, 1, |_, _| 0.25), + params, + ) + .unwrap(); + + let cv = population.coefficient_of_variation(); + let expected = ((0.25_f64.exp() - 1.0).sqrt()) * 100.0; + + assert!((cv[0] - expected).abs() < 1e-10); + } +} diff --git a/src/estimation/parametric/posthoc.rs b/src/estimation/parametric/posthoc.rs new file mode 100644 index 000000000..f7a567d00 --- /dev/null +++ b/src/estimation/parametric/posthoc.rs @@ -0,0 +1,157 @@ +use anyhow::Result; +use faer::Col; +use pharmsol::{Data, Equation, Event, ResidualErrorModels}; + +use crate::estimation::parametric::{ + importance_sampling_likelihood_estimates, subject_conditionals_from_eta_samples, ChainState, + ImportanceSamplingConfig, IndividualEstimates, LikelihoodEstimates, ParameterTransform, + ParametricPredictions, ParametricStatistics, ParametricWorkspace, Population, +}; + +pub fn cache_predictions( + result: &mut ParametricWorkspace, + idelta: f64, + tad: f64, +) -> Result<()> { + let sigma_val = result.sigma().additive.or(result.sigma().proportional); + 
let predictions = ParametricPredictions::calculate( + result.equation(), + result.data(), + result.population(), + result.individual_estimates(), + sigma_val, + idelta, + tad, + )?; + result.set_predictions(predictions); + Ok(()) +} + +pub fn statistics(result: &ParametricWorkspace) -> ParametricStatistics { + let n_observations = result + .data() + .subjects() + .iter() + .flat_map(|subject| subject.occasions()) + .flat_map(|occasion| occasion.events()) + .filter(|event| matches!(event, Event::Observation(_))) + .count(); + + ParametricStatistics::from_result( + result.population(), + result.individual_estimates(), + result.objf(), + result.iterations(), + result.converged(), + result.data().len(), + n_observations, + result.likelihoods().ll_importance_sampling, + result.likelihoods().ll_linearization, + result.likelihoods().ll_gaussian_quadrature, + result.sigma().as_vec(), + ) +} + +pub fn write_statistics(result: &ParametricWorkspace) -> Result<()> { + let stats = statistics(result); + stats.write(result.output_folder())?; + stats.write_shrinkage(result.output_folder(), &result.population().param_names())?; + Ok(()) +} + +pub(crate) fn eta_samples_by_subject(chain_states: &[Vec]) -> Vec>> { + chain_states + .iter() + .map(|states| states.iter().map(|state| state.eta.clone()).collect()) + .collect() +} + +pub(crate) fn saem_posthoc_likelihood( + equation: &E, + data: &Data, + error_models: &ResidualErrorModels, + transforms: &[ParameterTransform], + population: &Population, + individual_estimates: &IndividualEstimates, + eta_samples_by_subject: &[Vec>], + seed: u64, +) -> Result<(LikelihoodEstimates, f64)> { + let conditionals = subject_conditionals_from_eta_samples( + eta_samples_by_subject, + Some(individual_estimates), + population.mu(), + population.omega(), + ); + let likelihood_estimates = importance_sampling_likelihood_estimates( + equation, + data.subjects().iter().copied().collect(), + error_models, + transforms, + population.mu(), + population.omega(), + &conditionals, + ImportanceSamplingConfig::saemix_defaults() + .with_n_samples(10000) + .with_seed(seed + 12345), + )?; + let minus2ll = likelihood_estimates + .best_objf() + .unwrap_or(f64::NEG_INFINITY); + + Ok((likelihood_estimates, minus2ll)) +} + +pub fn shrinkage(result: &ParametricWorkspace) -> Option> { + let n = result.population().npar(); + let pop_var = Col::from_fn(n, |index| result.population().omega()[(index, index)]); + result.individual_estimates().shrinkage(&pop_var) +} + +pub fn aic(result: &ParametricWorkspace) -> f64 { + let n_params = result.population().npar(); + let n_fixed = n_params; + let n_random = n_params * (n_params + 1) / 2; + let k = n_fixed + n_random; + result.best_objf() + 2.0 * k as f64 +} + +pub fn bic(result: &ParametricWorkspace) -> f64 { + let n_subjects = result.data().subjects().len(); + let n_params = result.population().npar(); + let n_fixed = n_params; + let n_random = n_params * (n_params + 1) / 2; + let k = n_fixed + n_random; + result.best_objf() + (k as f64) * (n_subjects as f64).ln() +} + +#[cfg(test)] +mod tests { + use super::eta_samples_by_subject; + use crate::estimation::parametric::ChainState; + use faer::Col; + + #[test] + fn test_eta_samples_by_subject_preserves_chain_order() { + let chain_states = vec![ + vec![ + ChainState::new(Col::from_fn(2, |index| if index == 0 { 1.0 } else { 2.0 })), + ChainState::new(Col::from_fn(2, |index| if index == 0 { 3.0 } else { 4.0 })), + ], + vec![ChainState::new(Col::from_fn(2, |index| { + if index == 0 { + 5.0 + } else { + 6.0 + } + 
}))], + ]; + + let samples = eta_samples_by_subject(&chain_states); + + assert_eq!(samples.len(), 2); + assert_eq!(samples[0].len(), 2); + assert_eq!(samples[0][0][0], 1.0); + assert_eq!(samples[0][1][1], 4.0); + assert_eq!(samples[1][0][0], 5.0); + } +} diff --git a/src/estimation/parametric/predictions.rs b/src/estimation/parametric/predictions.rs new file mode 100644 index 000000000..84f410431 --- /dev/null +++ b/src/estimation/parametric/predictions.rs @@ -0,0 +1,307 @@ +//! Parametric algorithm predictions. + +use anyhow::{Context, Result}; +use pharmsol::{Censor, Data, Equation, Predictions as PredTrait}; +use serde::{Deserialize, Serialize}; + +use crate::estimation::parametric::{IndividualEstimates, Population}; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ParametricPredictionRow { + id: String, + time: f64, + outeq: usize, + block: usize, + obs: Option, + cens: Censor, + ppred: f64, + ipred: f64, + ires: Option, + iwres: Option, +} + +impl ParametricPredictionRow { + pub fn id(&self) -> &str { + &self.id + } + + pub fn time(&self) -> f64 { + self.time + } + + pub fn outeq(&self) -> usize { + self.outeq + } + + pub fn block(&self) -> usize { + self.block + } + + pub fn obs(&self) -> Option { + self.obs + } + + pub fn censoring(&self) -> Censor { + self.cens + } + + pub fn ppred(&self) -> f64 { + self.ppred + } + + pub fn ipred(&self) -> f64 { + self.ipred + } + + pub fn ires(&self) -> Option { + self.ires + } + + pub fn iwres(&self) -> Option { + self.iwres + } +} + +#[derive(Debug, Clone, Default, Serialize, Deserialize)] +pub struct ParametricPredictions { + predictions: Vec, +} + +impl IntoIterator for ParametricPredictions { + type Item = ParametricPredictionRow; + type IntoIter = std::vec::IntoIter; + + fn into_iter(self) -> Self::IntoIter { + self.predictions.into_iter() + } +} + +impl ParametricPredictions { + pub fn new() -> Self { + Self { + predictions: Vec::new(), + } + } + + pub fn add(&mut self, row: ParametricPredictionRow) { + self.predictions.push(row); + } + + pub fn predictions(&self) -> &[ParametricPredictionRow] { + &self.predictions + } + + pub fn len(&self) -> usize { + self.predictions.len() + } + + pub fn is_empty(&self) -> bool { + self.predictions.is_empty() + } + + pub fn calculate( + equation: &E, + data: &Data, + population: &Population, + individual_estimates: &IndividualEstimates, + sigma: Option, + idelta: f64, + tad: f64, + ) -> Result { + let mut container = Self::new(); + let expanded_data = data.clone().expand(idelta, tad); + let subjects = expanded_data.subjects(); + + let mu: Vec = (0..population.npar()).map(|i| population.mu()[i]).collect(); + + for subject in subjects.iter() { + let individual = individual_estimates + .iter() + .find(|ind| ind.subject_id() == subject.id()); + + let psi: Vec = match individual { + Some(ind) => (0..ind.npar()).map(|i| ind.psi()[i]).collect(), + None => mu.clone(), + }; + + let ppred_result = equation + .simulate_subject(subject, &mu, None) + .context(format!( + "Failed to simulate subject {} with population parameters", + subject.id() + ))?; + let ppred_vec = ppred_result.0.get_predictions(); + + let ipred_result = equation + .simulate_subject(subject, &psi, None) + .context(format!( + "Failed to simulate subject {} with individual parameters", + subject.id() + ))?; + let ipred_vec = ipred_result.0.get_predictions(); + + for (ppred, ipred) in ppred_vec.iter().zip(ipred_vec.iter()) { + let obs = ppred.observation(); + let (ires, iwres) = if let Some(y) = obs { + let res = y - 
ipred.prediction(); + let wres = sigma.map(|s| if s > 0.0 { res / s } else { f64::NAN }); + (Some(res), wres) + } else { + (None, None) + }; + + container.add(ParametricPredictionRow { + id: subject.id().clone(), + time: ppred.time(), + outeq: ppred.outeq(), + block: ppred.occasion(), + obs, + cens: ppred.censoring(), + ppred: ppred.prediction(), + ipred: ipred.prediction(), + ires, + iwres, + }); + } + } + + Ok(container) + } +} + +#[derive(Debug, Clone, Default, Serialize)] +pub struct PredictionSummary { + pub n_obs: usize, + pub mean_ppred: f64, + pub mean_ipred: f64, + pub mean_abs_ires: f64, + pub rmse_ipred: f64, + pub corr_obs_ipred: f64, +} + +impl ParametricPredictions { + pub fn summary(&self) -> PredictionSummary { + let obs_rows: Vec<_> = self + .predictions + .iter() + .filter(|r| r.obs.is_some()) + .collect(); + let n = obs_rows.len(); + + if n == 0 { + return PredictionSummary::default(); + } + + let sum_ppred: f64 = obs_rows.iter().map(|r| r.ppred).sum(); + let sum_ipred: f64 = obs_rows.iter().map(|r| r.ipred).sum(); + let sum_abs_ires: f64 = obs_rows + .iter() + .filter_map(|r| r.ires.map(|v| v.abs())) + .sum(); + let sum_sq_ires: f64 = obs_rows.iter().filter_map(|r| r.ires.map(|v| v * v)).sum(); + + let obs_vec: Vec = obs_rows.iter().filter_map(|r| r.obs).collect(); + let ipred_vec: Vec = obs_rows.iter().map(|r| r.ipred).collect(); + + let mean_obs = obs_vec.iter().sum::() / n as f64; + let mean_ipred = ipred_vec.iter().sum::() / n as f64; + + let mut cov = 0.0; + let mut var_obs = 0.0; + let mut var_ipred = 0.0; + + for (obs, ipred) in obs_vec.iter().zip(ipred_vec.iter()) { + let d_obs = obs - mean_obs; + let d_ipred = ipred - mean_ipred; + cov += d_obs * d_ipred; + var_obs += d_obs * d_obs; + var_ipred += d_ipred * d_ipred; + } + + let corr = if var_obs > 0.0 && var_ipred > 0.0 { + cov / (var_obs.sqrt() * var_ipred.sqrt()) + } else { + 0.0 + }; + + PredictionSummary { + n_obs: n, + mean_ppred: sum_ppred / n as f64, + mean_ipred: sum_ipred / n as f64, + mean_abs_ires: sum_abs_ires / n as f64, + rmse_ipred: (sum_sq_ires / n as f64).sqrt(), + corr_obs_ipred: corr, + } + } +} + +#[cfg(test)] +mod tests { + use super::ParametricPredictions; + use crate::estimation::parametric::{Individual, IndividualEstimates, Population}; + use crate::model::{ParameterSpace, ParameterSpec}; + use crate::prelude::*; + use anyhow::Result; + use faer::{Col, Mat}; + use pharmsol::{Data, Subject}; + + fn equation() -> equation::ODE { + equation::ODE::new( + |x, p, _t, dx, b, _rateiv, _cov| { + fetch_params!(p, ke); + dx[0] = -ke * x[0] + b[0]; + }, + |_p, _t, _cov| lag! {}, + |_p, _t, _cov| fa! 
{}, + |_p, _t, _cov, _x| {}, + |x, p, _t, _cov, y| { + fetch_params!(p, v); + y[0] = x[0] / v; + }, + ) + } + + fn data() -> Data { + let subject = Subject::builder("1") + .bolus(0.0, 100.0, 0) + .observation(1.0, 10.0, 0) + .build(); + Data::new(vec![subject]) + } + + #[test] + fn calculate_uses_canonical_psi_space_individual_parameters() -> Result<()> { + let parameters = ParameterSpace::new() + .add(ParameterSpec::bounded("ke", 0.1, 1.0)) + .add(ParameterSpec::bounded("v", 1.0, 20.0)); + let population = Population::new( + Col::from_fn(2, |index| if index == 0 { 0.5 } else { 10.0 }), + Mat::from_fn(2, 2, |row, col| if row == col { 0.1 } else { 0.0 }), + parameters, + )?; + let individuals = IndividualEstimates::from_vec(vec![Individual::new( + "1", + Col::from_fn(2, |_| 0.0), + Col::from_fn(2, |index| if index == 0 { 0.5 } else { 10.0 }), + )?]); + + let predictions = ParametricPredictions::calculate( + &equation(), + &data(), + &population, + &individuals, + None, + 1.0, + 0.0, + )?; + + let first_observation = predictions + .predictions() + .iter() + .find(|row| row.obs().is_some()) + .expect("prediction row with observation"); + + assert!((first_observation.ipred() - first_observation.ppred()).abs() < 1e-12); + Ok(()) + } +} diff --git a/src/estimation/parametric/reporting.rs b/src/estimation/parametric/reporting.rs new file mode 100644 index 000000000..dbeba710c --- /dev/null +++ b/src/estimation/parametric/reporting.rs @@ -0,0 +1,168 @@ +use anyhow::Result; +use csv::WriterBuilder; +use faer::{Col, Mat}; +use serde::ser::SerializeStruct; +use serde::{Deserialize, Serialize, Serializer}; + +use crate::algorithms::Status; +use crate::estimation::parametric::Population; +use crate::output::OutputFile; + +#[derive(Debug, Clone, Default, Serialize)] +pub struct LikelihoodEstimates { + pub ll_linearization: Option, + pub ll_importance_sampling: Option, + pub ll_gaussian_quadrature: Option, + pub is_n_samples: Option, + pub gq_n_points: Option, +} + +impl LikelihoodEstimates { + pub fn new() -> Self { + Self::default() + } + + pub fn best_estimate(&self) -> Option { + self.ll_gaussian_quadrature + .or(self.ll_importance_sampling) + .or(self.ll_linearization) + } + + pub fn best_objf(&self) -> Option { + self.best_estimate().map(|ll| -2.0 * ll) + } +} + +#[derive(Debug, Clone, Default)] +pub struct UncertaintyEstimates { + pub fim: Option>, + pub fim_inverse: Option>, + pub se_mu: Option>, + pub se_omega: Option>, + pub rse_mu: Option>, + pub fim_method: Option, +} + +#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)] +pub enum FimMethod { + Observed, + Expected, + StochasticApproximation, + Linearization, +} + +#[derive(Debug, Clone, Default)] +pub struct ParametricIterationLog { + iterations: Vec, + objf: Vec, + mu_history: Vec>, + omega_diag_history: Vec>, + status: Vec, +} + +impl ParametricIterationLog { + pub fn new() -> Self { + Self::default() + } + + pub fn log_iteration( + &mut self, + iteration: usize, + objf: f64, + population: &Population, + status: &Status, + ) { + self.iterations.push(iteration); + self.objf.push(objf); + self.mu_history + .push((0..population.npar()).map(|i| population.mu()[i]).collect()); + self.omega_diag_history.push( + (0..population.npar()) + .map(|i| population.omega()[(i, i)]) + .collect(), + ); + self.status.push(format!("{:?}", status)); + } + + pub fn len(&self) -> usize { + self.iterations.len() + } + + pub fn is_empty(&self) -> bool { + self.iterations.is_empty() + } + + pub fn objf_history(&self) -> &[f64] { + &self.objf + } + 
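+    /// Writes the log as `iterations.csv` in `folder`, one row per logged
+    /// iteration with the columns
+    /// `iteration, objf, mu_<param>..., omega_<param>..., status`, where the
+    /// `mu_*` and `omega_*` columns hold the population mean and the diagonal
+    /// of Omega at that iteration, formatted to six decimals.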
+ pub fn write(&self, folder: &str, param_names: &[String]) -> Result<()> { + if self.is_empty() { + return Ok(()); + } + + let outputfile = OutputFile::new(folder, "iterations.csv")?; + let mut writer = WriterBuilder::new() + .has_headers(true) + .from_writer(outputfile.file()); + + let n_params = self + .mu_history + .first() + .map(|values| values.len()) + .unwrap_or(0); + let mut header = vec!["iteration".to_string(), "objf".to_string()]; + for name in param_names { + header.push(format!("mu_{}", name)); + } + for name in param_names { + header.push(format!("omega_{}", name)); + } + header.push("status".to_string()); + writer.write_record(&header)?; + + for index in 0..self.iterations.len() { + let mut row = vec![ + self.iterations[index].to_string(), + format!("{:.6}", self.objf[index]), + ]; + + for parameter_index in 0..n_params { + row.push(format!( + "{:.6}", + self.mu_history[index].get(parameter_index).unwrap_or(&0.0) + )); + } + + for parameter_index in 0..n_params { + row.push(format!( + "{:.6}", + self.omega_diag_history[index] + .get(parameter_index) + .unwrap_or(&0.0) + )); + } + + row.push(self.status[index].clone()); + writer.write_record(&row)?; + } + + writer.flush()?; + Ok(()) + } +} + +impl Serialize for ParametricIterationLog { + fn serialize(&self, serializer: S) -> std::result::Result + where + S: Serializer, + { + let mut state = serializer.serialize_struct("ParametricIterationLog", 5)?; + state.serialize_field("iterations", &self.iterations)?; + state.serialize_field("objf", &self.objf)?; + state.serialize_field("mu_history", &self.mu_history)?; + state.serialize_field("omega_diag_history", &self.omega_diag_history)?; + state.serialize_field("status", &self.status)?; + state.end() + } +} diff --git a/src/estimation/parametric/sampling.rs b/src/estimation/parametric/sampling.rs new file mode 100644 index 000000000..28bf42aef --- /dev/null +++ b/src/estimation/parametric/sampling.rs @@ -0,0 +1,7 @@ +//! MCMC sampling infrastructure for parametric algorithms. + +mod kernels; + +pub(crate) use kernels::{ + advance_saem_chains, sample_eta_from_population, ChainState, KernelConfig, SaemMcmcState, +}; diff --git a/src/estimation/parametric/sampling/kernels.rs b/src/estimation/parametric/sampling/kernels.rs new file mode 100644 index 000000000..10ef0814e --- /dev/null +++ b/src/estimation/parametric/sampling/kernels.rs @@ -0,0 +1,388 @@ +//! Core MCMC types for SAEM algorithm. 
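+//!
+//! The kernels below are Metropolis-Hastings updates of the subject-level random
+//! effects: kernel 1 proposes from the prior N(0, Omega) and accepts on the
+//! likelihood ratio alone, while kernels 2 and 3 are random-walk proposals judged
+//! on likelihood plus log-prior, with proposal scales adapted toward a target
+//! acceptance rate. A minimal sketch of the log-space accept step shared by the
+//! kernels (illustration only):
+//!
+//! ```ignore
+//! use rand::Rng;
+//!
+//! /// Accepts the proposal when ln(u) < log_alpha for u ~ Uniform(0, 1).
+//! fn mh_accept(log_alpha: f64, rng: &mut impl Rng) -> bool {
+//!     let u: f64 = rng.random();
+//!     log_alpha.is_finite() && u.ln() < log_alpha
+//! }
+//! ```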
+ +use anyhow::Result; +use faer::Col; +use faer::Mat; +use ndarray::Array2; +use pharmsol::{Data, Equation, ResidualErrorModels}; +use rand::Rng; +use rand_distr::{Distribution, Normal}; + +use crate::estimation::parametric::{ + batch_log_likelihood_from_eta, log_priors_from_eta_matrix, ParameterTransform, Population, +}; + +#[derive(Debug, Clone)] +pub struct KernelConfig { + pub n_kernel1: usize, + pub n_kernel2: usize, + pub n_kernel3: usize, + pub n_kernel4: usize, + pub map_iterations: usize, + pub rw_step_size: f64, + pub target_acceptance: f64, + pub rw_init: f64, +} + +impl Default for KernelConfig { + fn default() -> Self { + Self { + n_kernel1: 2, + n_kernel2: 2, + n_kernel3: 2, + n_kernel4: 0, + map_iterations: 0, + rw_step_size: 0.4, + target_acceptance: 0.4, + rw_init: 0.5, + } + } +} + +#[derive(Debug, Clone)] +pub struct ChainState { + pub eta: Col, + pub log_likelihood: f64, + pub log_prior: f64, +} + +impl Default for ChainState { + fn default() -> Self { + Self { + eta: Col::zeros(0), + log_likelihood: f64::NEG_INFINITY, + log_prior: f64::NEG_INFINITY, + } + } +} + +impl ChainState { + pub fn new(eta: Col) -> Self { + Self { + eta, + log_likelihood: f64::NEG_INFINITY, + log_prior: f64::NEG_INFINITY, + } + } +} + +#[derive(Debug, Clone)] +pub struct SaemMcmcState { + pub eta_matrix: Array2, + pub log_likelihoods: Vec, + pub log_priors: Vec, +} + +pub(crate) fn advance_saem_chains( + equation: &E, + data: &Data, + error_models: &ResidualErrorModels, + transforms: &[ParameterTransform], + mean_phi: &[Col], + chol_omega: &Mat, + omega_inv: &Mat, + kernel_config: &KernelConfig, + iteration: usize, + domega2: &mut Col, + rng: &mut impl Rng, + mut eta_matrix: Array2, +) -> Result { + let n_subjects = eta_matrix.nrows(); + let n_params = eta_matrix.ncols(); + let normal = Normal::new(0.0, 1.0).unwrap(); + + let mut current_ll = batch_log_likelihood_from_eta( + equation, + data, + error_models, + transforms, + &eta_matrix, + mean_phi, + )?; + let mut current_log_prior = log_priors_from_eta_matrix(&eta_matrix, omega_inv); + + for _ in 0..kernel_config.n_kernel1 { + let proposed_eta = prior_proposals(chol_omega, n_subjects, rng, &normal); + let proposed_ll = batch_log_likelihood_from_eta( + equation, + data, + error_models, + transforms, + &proposed_eta, + mean_phi, + )?; + + for subject_index in 0..n_subjects { + let log_alpha = proposed_ll[subject_index] - current_ll[subject_index]; + let u: f64 = rng.random(); + if log_alpha.is_finite() && u.ln() < log_alpha { + for param_index in 0..n_params { + eta_matrix[[subject_index, param_index]] = + proposed_eta[[subject_index, param_index]]; + } + current_ll[subject_index] = proposed_ll[subject_index]; + current_log_prior[subject_index] = + subject_log_prior(subject_index, &eta_matrix, omega_inv); + } + } + } + + if kernel_config.n_kernel2 > 0 { + let mut accepted = vec![0usize; n_params]; + let mut total = vec![0usize; n_params]; + + for _ in 0..kernel_config.n_kernel2 { + for param_index in 0..n_params { + let mut proposed_eta = eta_matrix.clone(); + for subject_index in 0..n_subjects { + let perturbation = normal.sample(rng) * domega2[param_index]; + proposed_eta[[subject_index, param_index]] += perturbation; + } + + let proposed_ll = batch_log_likelihood_from_eta( + equation, + data, + error_models, + transforms, + &proposed_eta, + mean_phi, + )?; + let proposed_log_prior = log_priors_from_eta_matrix(&proposed_eta, omega_inv); + + for subject_index in 0..n_subjects { + let log_alpha = (proposed_ll[subject_index] + + 
proposed_log_prior[subject_index]) + - (current_ll[subject_index] + current_log_prior[subject_index]); + let u: f64 = rng.random(); + if log_alpha.is_finite() && u.ln() < log_alpha { + eta_matrix[[subject_index, param_index]] = + proposed_eta[[subject_index, param_index]]; + current_ll[subject_index] = proposed_ll[subject_index]; + current_log_prior[subject_index] = proposed_log_prior[subject_index]; + accepted[param_index] += 1; + } + total[param_index] += 1; + } + } + } + + adapt_proposal_scales(domega2, &accepted, &total, kernel_config); + } + + if kernel_config.n_kernel3 > 0 { + let mut accepted = vec![0usize; n_params]; + let mut total = vec![0usize; n_params]; + + for _ in 0..kernel_config.n_kernel3 { + let block_indices = block_indices_for_iteration(iteration, n_params, rng); + let mut proposed_eta = eta_matrix.clone(); + + for subject_index in 0..n_subjects { + for ¶m_index in &block_indices { + let perturbation = normal.sample(rng) * domega2[param_index]; + proposed_eta[[subject_index, param_index]] += perturbation; + } + } + + let proposed_ll = batch_log_likelihood_from_eta( + equation, + data, + error_models, + transforms, + &proposed_eta, + mean_phi, + )?; + let proposed_log_prior = log_priors_from_eta_matrix(&proposed_eta, omega_inv); + + for subject_index in 0..n_subjects { + let log_alpha = (proposed_ll[subject_index] + proposed_log_prior[subject_index]) + - (current_ll[subject_index] + current_log_prior[subject_index]); + let u: f64 = rng.random(); + if log_alpha.is_finite() && u.ln() < log_alpha { + for ¶m_index in &block_indices { + eta_matrix[[subject_index, param_index]] = + proposed_eta[[subject_index, param_index]]; + accepted[param_index] += 1; + } + current_ll[subject_index] = proposed_ll[subject_index]; + current_log_prior[subject_index] = proposed_log_prior[subject_index]; + } + for ¶m_index in &block_indices { + total[param_index] += 1; + } + } + } + + adapt_proposal_scales(domega2, &accepted, &total, kernel_config); + } + + Ok(SaemMcmcState { + eta_matrix, + log_likelihoods: current_ll, + log_priors: current_log_prior, + }) +} + +pub(crate) fn sample_eta_from_population(population: &Population, rng: &mut impl Rng) -> Col { + let n = population.npar(); + let normal = Normal::new(0.0, 1.0).unwrap(); + let z: Vec = (0..n).map(|_| normal.sample(rng)).collect(); + + let omega = population.omega(); + let chol = match omega.llt(faer::Side::Lower) { + Ok(llt) => llt.L().to_owned(), + Err(_) => { + let mut diagonal = Mat::zeros(n, n); + for i in 0..n { + diagonal[(i, i)] = omega[(i, i)].sqrt().max(1e-6); + } + diagonal + } + }; + + let mut eta = Col::zeros(n); + for i in 0..n { + for j in 0..=i { + eta[i] += chol[(i, j)] * z[j]; + } + } + + eta +} + +fn prior_proposals( + chol_omega: &Mat, + n_subjects: usize, + rng: &mut impl Rng, + normal: &Normal, +) -> Array2 { + let n_params = chol_omega.nrows(); + let mut proposed_eta = Array2::zeros((n_subjects, n_params)); + + for subject_index in 0..n_subjects { + let z: Vec = (0..n_params).map(|_| normal.sample(rng)).collect(); + for row in 0..n_params { + let mut sum = 0.0; + for col in 0..=row { + sum += chol_omega[(row, col)] * z[col]; + } + proposed_eta[[subject_index, row]] = sum; + } + } + + proposed_eta +} + +fn subject_log_prior(subject_index: usize, eta_matrix: &Array2, omega_inv: &Mat) -> f64 { + let n_params = eta_matrix.ncols(); + let quadratic = (0..n_params) + .flat_map(|row| { + (0..n_params).map(move |col| { + eta_matrix[[subject_index, row]] + * omega_inv[(row, col)] + * eta_matrix[[subject_index, col]] + }) + 
}) + .sum::(); + + -0.5 * quadratic +} + +fn block_indices_for_iteration( + iteration: usize, + n_params: usize, + rng: &mut impl Rng, +) -> Vec { + let block_size = ((iteration % n_params.max(2).saturating_sub(1)).max(1) + 1).min(n_params); + if block_size >= n_params { + return (0..n_params).collect(); + } + + let mut indices: Vec = (0..n_params).collect(); + for offset in 0..block_size { + let u: f64 = rng.random(); + let remaining = n_params - offset; + let swap_offset = (u * remaining as f64).floor() as usize; + let swap_index = offset + swap_offset.min(remaining - 1); + indices.swap(offset, swap_index); + } + + indices[..block_size].to_vec() +} + +fn adapt_proposal_scales( + domega2: &mut Col, + accepted: &[usize], + total: &[usize], + kernel_config: &KernelConfig, +) { + for param_index in 0..domega2.nrows() { + if total[param_index] > 0 { + let acc_rate = accepted[param_index] as f64 / total[param_index] as f64; + domega2[param_index] *= + 1.0 + kernel_config.rw_step_size * (acc_rate - kernel_config.target_acceptance); + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::estimation::parametric::Population; + use crate::model::{ParameterSpace, ParameterSpec}; + use faer::Mat; + + #[test] + fn test_chain_state() { + let eta = Col::from_fn(3, |i| i as f64 * 0.1); + let state = ChainState::new(eta); + assert_eq!(state.log_likelihood, f64::NEG_INFINITY); + assert_eq!(state.log_prior, f64::NEG_INFINITY); + } + + #[test] + fn test_sample_eta_from_population_matches_dimension() { + let parameters = ParameterSpace::new() + .add(ParameterSpec::bounded("ke", 0.1, 1.0)) + .add(ParameterSpec::bounded("v", 1.0, 20.0)); + let population = Population::new( + Col::from_fn(2, |_| 0.0), + Mat::from_fn(2, 2, |row, col| if row == col { 0.5 } else { 0.0 }), + parameters, + ) + .unwrap(); + let mut rng = rand::rng(); + + let eta = sample_eta_from_population(&population, &mut rng); + + assert_eq!(eta.nrows(), 2); + } + + #[test] + fn test_adapt_proposal_scales_uses_acceptance_gap() { + let mut domega2 = Col::from_fn(2, |_| 1.0); + let config = KernelConfig { + rw_step_size: 0.5, + target_acceptance: 0.4, + ..KernelConfig::default() + }; + + adapt_proposal_scales(&mut domega2, &[8, 1], &[10, 10], &config); + + assert!(domega2[0] > 1.0); + assert!(domega2[1] < 1.0); + } + + #[test] + fn test_block_indices_respect_parameter_count() { + let mut rng = rand::rng(); + + let indices = block_indices_for_iteration(5, 4, &mut rng); + + assert!(!indices.is_empty()); + assert!(indices.len() <= 4); + for index in &indices { + assert!(*index < 4); + } + } +} diff --git a/src/estimation/parametric/state.rs b/src/estimation/parametric/state.rs new file mode 100644 index 000000000..911a3336e --- /dev/null +++ b/src/estimation/parametric/state.rs @@ -0,0 +1,356 @@ +use faer::{Col, Mat}; +use serde::{Deserialize, Serialize}; + +use crate::compile::OccasionDesign; +use crate::estimation::parametric::transforms::ParameterTransform; +use crate::estimation::parametric::ResidualErrorEstimates; +use crate::estimation::parametric::{IndividualEstimates, Population}; +use crate::model::{ + CovariateModel, ParameterTransform as ModelParameterTransform, VariabilityModel, +}; + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct PhiVector(pub Vec); + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct PsiVector(pub Vec); + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct EtaVector(pub Vec); + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub 
struct KappaVector(pub Vec); + +impl PhiVector { + pub fn to_col(&self) -> Col { + Col::from_fn(self.0.len(), |index| self.0[index]) + } + + pub fn as_slice(&self) -> &[f64] { + &self.0 + } +} + +impl PsiVector { + pub fn to_col(&self) -> Col { + Col::from_fn(self.0.len(), |index| self.0[index]) + } + + pub fn as_slice(&self) -> &[f64] { + &self.0 + } +} + +impl EtaVector { + pub fn to_col(&self) -> Col { + Col::from_fn(self.0.len(), |index| self.0[index]) + } +} + +impl KappaVector { + pub fn to_col(&self) -> Col { + Col::from_fn(self.0.len(), |index| self.0[index]) + } +} + +impl From<&Col> for PhiVector { + fn from(value: &Col) -> Self { + Self(col_to_vec(value)) + } +} + +impl From<&Col> for PsiVector { + fn from(value: &Col) -> Self { + Self(col_to_vec(value)) + } +} + +impl From<&Col> for EtaVector { + fn from(value: &Col) -> Self { + Self(col_to_vec(value)) + } +} + +impl From<&Col> for KappaVector { + fn from(value: &Col) -> Self { + Self(col_to_vec(value)) + } +} + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct PhiTable(pub Vec>); + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct PsiTable(pub Vec>); + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct EtaTable(pub Vec>); + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct OccasionKappa { + pub subject_index: usize, + pub occasion_index: usize, + pub values: KappaVector, +} + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct OccasionKappaTable(pub Vec); + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct FixedEffects { + pub parameter_names: Vec, + pub population_mean: PsiVector, +} + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct RandomEffects { + pub covariance: Vec>, + pub standard_deviations: Vec, + pub correlation: Vec>, +} + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct ResidualState { + pub values: Vec, +} + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct TransformSet { + pub transforms: Vec, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum ParametricTransformKind { + Identity, + LogNormal, + Logit, + Probit, +} + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct CovariateState { + pub subject_effects: Option, + pub occasion_effects: Option, +} + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct CovariateEffectsSnapshot { + pub parameter_names: Vec, + pub column_names: Vec, + pub covariate_mask: Vec>, + pub coefficients: Vec, + pub estimate_coefficients: Vec, + pub values: Vec>>, +} + +impl CovariateEffectsSnapshot { + pub fn from_model(model: &CovariateModel, values: Vec>>) -> Self { + Self { + parameter_names: model.param_names().to_vec(), + column_names: model.covariate_names().to_vec(), + covariate_mask: model.covariate_mask().to_vec(), + coefficients: (0..model.beta().nrows()) + .map(|index| model.beta()[index]) + .collect(), + estimate_coefficients: model.estimate_beta().to_vec(), + values, + } + } +} + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct IndividualEffectsState { + pub subject_eta: EtaTable, + pub subject_psi: PsiTable, + pub occasion_kappa: Option, +} + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct ParametricModelState { + pub fixed_effects: FixedEffects, + pub random_effects: RandomEffects, + pub residual: ResidualState, + pub transforms: TransformSet, + pub covariates: 
CovariateState, + pub variability: VariabilityModel, +} + +impl ParametricModelState { + pub fn from_population_and_sigma( + population: &Population, + sigma: &ResidualErrorEstimates, + ) -> Self { + let parameter_names = population.param_names(); + let n_parameters = parameter_names.len(); + + Self { + fixed_effects: FixedEffects { + parameter_names, + population_mean: PsiVector(col_to_vec(population.mu())), + }, + random_effects: RandomEffects { + covariance: mat_to_nested_vec(population.omega()), + standard_deviations: col_to_vec(&population.standard_deviations()), + correlation: mat_to_nested_vec(&population.correlation_matrix()), + }, + residual: ResidualState { + values: residual_values(sigma), + }, + transforms: TransformSet { + transforms: vec![ParametricTransformKind::Identity; n_parameters], + }, + covariates: CovariateState { + subject_effects: None, + occasion_effects: None, + }, + variability: VariabilityModel::default(), + } + } + + pub fn merged(self, fitted: Self) -> Self { + let covariates = CovariateState { + subject_effects: fitted + .covariates + .subject_effects + .or(self.covariates.subject_effects), + occasion_effects: fitted + .covariates + .occasion_effects + .or(self.covariates.occasion_effects), + }; + + Self { + fixed_effects: fitted.fixed_effects, + random_effects: fitted.random_effects, + residual: fitted.residual, + transforms: self.transforms, + covariates, + variability: self.variability, + } + } +} + +impl IndividualEffectsState { + pub fn from_individual_estimates(individual_estimates: &IndividualEstimates) -> Self { + Self::from_individual_estimates_with_occasion_design( + individual_estimates, + &[], + &VariabilityModel::default(), + ) + } + + pub fn from_individual_estimates_with_occasion_design( + individual_estimates: &IndividualEstimates, + occasions: &[OccasionDesign], + variability: &VariabilityModel, + ) -> Self { + let subject_eta = individual_estimates + .iter() + .map(|individual| col_to_vec(individual.eta())) + .collect(); + let subject_psi = individual_estimates + .iter() + .map(|individual| col_to_vec(individual.psi())) + .collect(); + let n_parameters = individual_estimates + .get(0) + .map(|individual| individual.npar()) + .unwrap_or_else(|| variability.subject.enabled_for.len()); + + Self { + subject_eta: EtaTable(subject_eta), + subject_psi: PsiTable(subject_psi), + occasion_kappa: occasion_kappa_table(occasions, variability, n_parameters), + } + } + + pub fn with_occasion_design( + mut self, + occasions: &[OccasionDesign], + variability: &VariabilityModel, + n_parameters: usize, + ) -> Self { + self.occasion_kappa = occasion_kappa_table(occasions, variability, n_parameters); + self + } +} + +impl From<&ParameterTransform> for ParametricTransformKind { + fn from(transform: &ParameterTransform) -> Self { + match transform { + ParameterTransform::None => Self::Identity, + ParameterTransform::LogNormal => Self::LogNormal, + ParameterTransform::Logit { .. } => Self::Logit, + ParameterTransform::Probit { .. 
} => Self::Probit, + } + } +} + +impl From<&ModelParameterTransform> for ParametricTransformKind { + fn from(transform: &ModelParameterTransform) -> Self { + match transform { + ModelParameterTransform::Identity => Self::Identity, + ModelParameterTransform::LogNormal => Self::LogNormal, + ModelParameterTransform::Logit => Self::Logit, + ModelParameterTransform::Probit => Self::Probit, + } + } +} + +fn col_to_vec(col: &Col) -> Vec { + (0..col.nrows()).map(|index| col[index]).collect() +} + +fn mat_to_nested_vec(mat: &Mat) -> Vec> { + (0..mat.nrows()) + .map(|row| (0..mat.ncols()).map(|col| mat[(row, col)]).collect()) + .collect() +} + +fn residual_values(residual: &ResidualErrorEstimates) -> Vec { + residual.as_vec() +} + +fn occasion_kappa_table( + occasions: &[OccasionDesign], + variability: &VariabilityModel, + n_parameters: usize, +) -> Option { + let occasion = variability.occasion.as_ref()?; + if occasions.is_empty() || !occasion.enabled_for.iter().any(|enabled| *enabled) { + return None; + } + + Some(OccasionKappaTable( + occasions + .iter() + .map(|occasion_design| OccasionKappa { + subject_index: occasion_design.subject_index, + occasion_index: occasion_design.occasion_index, + values: KappaVector(vec![0.0; n_parameters]), + }) + .collect(), + )) +} + +#[cfg(test)] +mod tests { + use super::{PhiVector, PsiVector}; + use faer::Col; + + #[test] + fn typed_vectors_roundtrip_through_col() { + let values = Col::from_fn(3, |index| match index { + 0 => 1.0, + 1 => 2.0, + _ => 3.0, + }); + + let phi = PhiVector::from(&values); + let psi = PsiVector::from(&values); + + assert_eq!(phi.to_col(), values); + assert_eq!(psi.to_col(), values); + assert_eq!(phi.as_slice(), &[1.0, 2.0, 3.0]); + } +} diff --git a/src/estimation/parametric/statistics.rs b/src/estimation/parametric/statistics.rs new file mode 100644 index 000000000..6791862ae --- /dev/null +++ b/src/estimation/parametric/statistics.rs @@ -0,0 +1,297 @@ +//! Statistical summaries for parametric algorithm results. 
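+//!
+//! A brief sketch of the conventions used by `ParametricStatistics::from_result`
+//! below (assuming the reported log-likelihood is on the natural-log scale): the
+//! objective function is `-2 * LL`, `AIC = -2LL + 2k`, and
+//! `BIC = -2LL + k * ln(n_subjects)`, where `k` counts fixed effects,
+//! random-effect variances, and residual error terms. For example, with
+//! `LL = -250` and `k = 5` over 20 subjects: `-2LL = 500`, `AIC = 510`, and
+//! `BIC = 500 + 5 * ln(20) ~= 514.98`. Per-parameter eta-shrinkage is taken from
+//! `IndividualEstimates::shrinkage`, which is expected to follow the usual
+//! `1 - SD(eta_hat) / omega` definition.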
+ +use anyhow::Result; +use csv::WriterBuilder; +use pharmsol::{ResidualErrorModel, ResidualErrorModels}; +use serde::Serialize; + +use crate::estimation::parametric::{IndividualEstimates, Population}; +use crate::output::OutputFile; + +#[derive(Debug, Clone, Default, Serialize)] +pub struct ParametricStatistics { + pub n_subjects: usize, + pub n_observations: usize, + pub n_fixed: usize, + pub n_random: usize, + pub n_total_params: usize, + pub iterations: usize, + pub converged: bool, + pub objf: f64, + pub ll_is: Option, + pub ll_lin: Option, + pub ll_gq: Option, + pub aic: f64, + pub bic: f64, + pub eta_shrinkage: Vec, + pub eta_shrinkage_overall: f64, + pub sigma: Vec, + pub mu: Vec, + pub omega_diag: Vec, + pub omega_sd: Vec, + pub cv_percent: Vec, +} + +impl ParametricStatistics { + pub fn from_result( + population: &Population, + individual_estimates: &IndividualEstimates, + objf: f64, + iterations: usize, + converged: bool, + n_subjects: usize, + n_observations: usize, + ll_is: Option, + ll_lin: Option, + ll_gq: Option, + sigma: Vec, + ) -> Self { + let n_fixed = population.npar(); + let n_random = n_fixed; + let n_total = n_fixed + n_random + sigma.len(); + + let mu: Vec = (0..n_fixed).map(|i| population.mu()[i]).collect(); + let omega_diag: Vec = (0..n_fixed).map(|i| population.omega()[(i, i)]).collect(); + let omega_sd: Vec = omega_diag.iter().map(|v| v.sqrt()).collect(); + let cv_percent: Vec = population + .coefficient_of_variation() + .iter() + .copied() + .collect(); + + let pop_var = faer::Col::from_fn(n_fixed, |i| omega_diag[i]); + let shrinkage_opt = individual_estimates.shrinkage(&pop_var); + let eta_shrinkage: Vec = shrinkage_opt + .map(|shrinkage| (0..shrinkage.nrows()).map(|i| shrinkage[i]).collect()) + .unwrap_or_else(|| vec![f64::NAN; n_fixed]); + + let eta_shrinkage_overall = if !eta_shrinkage.is_empty() { + eta_shrinkage.iter().filter(|v| !v.is_nan()).sum::() + / eta_shrinkage.iter().filter(|v| !v.is_nan()).count().max(1) as f64 + } else { + f64::NAN + }; + + let best_ll = ll_gq.or(ll_is).or(ll_lin).unwrap_or(-objf / 2.0); + let best_objf = -2.0 * best_ll; + let aic = best_objf + 2.0 * n_total as f64; + let bic = best_objf + (n_total as f64) * (n_subjects as f64).ln(); + + Self { + n_subjects, + n_observations, + n_fixed, + n_random, + n_total_params: n_total, + iterations, + converged, + objf, + ll_is, + ll_lin, + ll_gq, + aic, + bic, + eta_shrinkage, + eta_shrinkage_overall, + sigma, + mu, + omega_diag, + omega_sd, + cv_percent, + } + } + + pub fn write(&self, folder: &str) -> Result<()> { + let outputfile = OutputFile::new(folder, "statistics.csv")?; + let mut writer = WriterBuilder::new() + .has_headers(true) + .from_writer(outputfile.file()); + + writer.write_record(["metric", "value"])?; + writer.write_record(["n_subjects", &self.n_subjects.to_string()])?; + writer.write_record(["n_observations", &self.n_observations.to_string()])?; + writer.write_record(["n_fixed_params", &self.n_fixed.to_string()])?; + writer.write_record(["n_random_params", &self.n_random.to_string()])?; + writer.write_record(["n_total_params", &self.n_total_params.to_string()])?; + writer.write_record(["iterations", &self.iterations.to_string()])?; + writer.write_record(["converged", &self.converged.to_string()])?; + writer.write_record(["objf", &format!("{:.6}", self.objf)])?; + if let Some(ll) = self.ll_is { + writer.write_record(["ll_is", &format!("{:.6}", ll)])?; + } + if let Some(ll) = self.ll_lin { + writer.write_record(["ll_lin", &format!("{:.6}", ll)])?; + } + if let Some(ll) = 
self.ll_gq { + writer.write_record(["ll_gq", &format!("{:.6}", ll)])?; + } + writer.write_record(["aic", &format!("{:.4}", self.aic)])?; + writer.write_record(["bic", &format!("{:.4}", self.bic)])?; + writer.write_record([ + "eta_shrinkage_overall", + &format!("{:.4}", self.eta_shrinkage_overall), + ])?; + + for (i, s) in self.sigma.iter().enumerate() { + let key = if self.sigma.len() == 1 { + "sigma".to_string() + } else { + format!("sigma_{}", i + 1) + }; + writer.write_record([&key, &format!("{:.6}", s)])?; + } + + writer.flush()?; + tracing::debug!("Statistics written to {:?}", outputfile.relative_path()); + Ok(()) + } + + pub fn write_shrinkage(&self, folder: &str, param_names: &[String]) -> Result<()> { + let outputfile = OutputFile::new(folder, "shrinkage.csv")?; + let mut writer = WriterBuilder::new() + .has_headers(true) + .from_writer(outputfile.file()); + + writer.write_record(["parameter", "shrinkage"])?; + + for (name, shrink) in param_names.iter().zip(self.eta_shrinkage.iter()) { + writer.write_record([name, &format!("{:.6}", shrink)])?; + } + + writer.flush()?; + tracing::debug!("Shrinkage written to {:?}", outputfile.relative_path()); + Ok(()) + } +} + +#[derive(Debug, Clone, Default, Serialize)] +pub struct ResidualErrorEstimates { + pub additive: Option, + pub proportional: Option, + pub combined: Option<(f64, f64)>, + pub model_type: String, +} + +impl ResidualErrorEstimates { + pub fn additive(sigma: f64) -> Self { + Self { + additive: Some(sigma), + proportional: None, + combined: None, + model_type: "additive".to_string(), + } + } + + pub fn proportional(sigma: f64) -> Self { + Self { + additive: None, + proportional: Some(sigma), + combined: None, + model_type: "proportional".to_string(), + } + } + + pub fn combined(additive: f64, proportional: f64) -> Self { + Self { + additive: Some(additive), + proportional: Some(proportional), + combined: Some((additive, proportional)), + model_type: "combined".to_string(), + } + } + + pub fn as_vec(&self) -> Vec { + match (&self.additive, &self.proportional) { + (Some(a), Some(b)) => vec![*a, *b], + (Some(a), None) => vec![*a], + (None, Some(b)) => vec![*b], + (None, None) => vec![], + } + } + + pub fn write(&self, folder: &str) -> Result<()> { + let outputfile = OutputFile::new(folder, "residual_error.csv")?; + let mut writer = WriterBuilder::new() + .has_headers(true) + .from_writer(outputfile.file()); + + writer.write_record(["parameter", "value", "description"])?; + writer.write_record(["model_type", &self.model_type, ""])?; + + if let Some(additive) = self.additive { + writer.write_record([ + "sigma_add", + &format!("{:.6}", additive), + "Additive error SD", + ])?; + } + + if let Some(proportional) = self.proportional { + writer.write_record([ + "sigma_prop", + &format!("{:.6}", proportional), + "Proportional error coefficient", + ])?; + } + + writer.flush()?; + tracing::debug!("Sigma written to {:?}", outputfile.relative_path()); + Ok(()) + } +} + +pub fn residual_error_estimates_from_models( + error_models: &ResidualErrorModels, +) -> ResidualErrorEstimates { + let models = error_models + .iter() + .map(|(_, model)| *model) + .collect::>(); + + let Some(first) = models.first().copied() else { + return ResidualErrorEstimates::default(); + }; + + if !models.iter().all(|model| *model == first) { + return ResidualErrorEstimates::default(); + } + + match first { + ResidualErrorModel::Constant { a } => ResidualErrorEstimates::additive(a), + ResidualErrorModel::Proportional { b } => ResidualErrorEstimates::proportional(b), + 
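+        // Combined models carry both terms: `a` is the additive SD and `b` the
+        // proportional coefficient, matching the `sigma_add` / `sigma_prop`
+        // rows written by `ResidualErrorEstimates::write`.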
ResidualErrorModel::Combined { a, b } => ResidualErrorEstimates::combined(a, b), + ResidualErrorModel::Exponential { .. } => ResidualErrorEstimates { + model_type: "exponential".to_string(), + ..ResidualErrorEstimates::default() + }, + } +} + +pub fn residual_error_estimates_from_observed_outeqs( + error_models: &ResidualErrorModels, + observed_outeqs: &[usize], +) -> ResidualErrorEstimates { + let models = error_models + .iter() + .filter(|(outeq, _)| observed_outeqs.contains(outeq)) + .map(|(_, model)| *model) + .collect::>(); + + let Some(first) = models.first().copied() else { + return ResidualErrorEstimates::default(); + }; + + if !models.iter().all(|model| *model == first) { + return ResidualErrorEstimates::default(); + } + + match first { + ResidualErrorModel::Constant { a } => ResidualErrorEstimates::additive(a), + ResidualErrorModel::Proportional { b } => ResidualErrorEstimates::proportional(b), + ResidualErrorModel::Combined { a, b } => ResidualErrorEstimates::combined(a, b), + ResidualErrorModel::Exponential { .. } => ResidualErrorEstimates { + model_type: "exponential".to_string(), + ..ResidualErrorEstimates::default() + }, + } +} diff --git a/src/estimation/parametric/sufficient_stats.rs b/src/estimation/parametric/sufficient_stats.rs new file mode 100644 index 000000000..d1ed37623 --- /dev/null +++ b/src/estimation/parametric/sufficient_stats.rs @@ -0,0 +1,368 @@ +use anyhow::{bail, Result}; +use faer::{Col, Mat}; +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone)] +pub struct SufficientStats { + s1: Col, + s2: Mat, + s3: Col, + stat_rese: f64, + count: usize, + n_obs: usize, +} + +impl SufficientStats { + pub fn new(n_params: usize) -> Self { + Self { + s1: Col::zeros(n_params), + s2: Mat::zeros(n_params, n_params), + s3: Col::zeros(n_params), + stat_rese: 0.0, + count: 0, + n_obs: 0, + } + } + + pub fn reset(&mut self) { + let n = self.s1.nrows(); + self.s1 = Col::zeros(n); + self.s2 = Mat::zeros(n, n); + self.s3 = Col::zeros(n); + self.stat_rese = 0.0; + self.count = 0; + self.n_obs = 0; + } + + pub fn npar(&self) -> usize { + self.s1.nrows() + } + + pub fn s1(&self) -> &Col { + &self.s1 + } + + pub fn s1_mut(&mut self) -> &mut Col { + &mut self.s1 + } + + pub fn s2(&self) -> &Mat { + &self.s2 + } + + pub fn s2_mut(&mut self) -> &mut Mat { + &mut self.s2 + } + + pub fn count(&self) -> usize { + self.count + } + + pub fn s3(&self) -> &Col { + &self.s3 + } + + pub fn s3_mut(&mut self) -> &mut Col { + &mut self.s3 + } + + pub fn stat_rese(&self) -> f64 { + self.stat_rese + } + + pub fn set_stat_rese(&mut self, value: f64) { + self.stat_rese = value; + } + + pub fn add_stat_rese(&mut self, value: f64) { + self.stat_rese += value; + } + + pub fn n_obs(&self) -> usize { + self.n_obs + } + + pub fn set_n_obs(&mut self, n: usize) { + self.n_obs = n; + } + + pub fn add_n_obs(&mut self, n: usize) { + self.n_obs += n; + } + + pub fn accumulate(&mut self, psi: &Col) -> Result<()> { + let n = self.npar(); + + if psi.nrows() != n { + bail!( + "Parameter vector length ({}) doesn't match statistics dimension ({})", + psi.nrows(), + n + ); + } + + for i in 0..n { + self.s1[i] += psi[i]; + } + + for i in 0..n { + for j in 0..n { + self.s2[(i, j)] += psi[i] * psi[j]; + } + } + + for i in 0..n { + self.s3[i] += psi[i] * psi[i]; + } + + self.count += 1; + + Ok(()) + } + + pub fn accumulate_batch(&mut self, samples: &[Col]) -> Result<()> { + for sample in samples { + self.accumulate(sample)?; + } + Ok(()) + } + + pub fn stochastic_update(&mut self, new_stats: &SufficientStats, 
step_size: f64) -> Result<()> { + if self.npar() != new_stats.npar() { + bail!( + "Statistics dimension mismatch: {} vs {}", + self.npar(), + new_stats.npar() + ); + } + + if step_size == 0.0 { + return Ok(()); + } + + let n = self.npar(); + + for i in 0..n { + self.s1[i] += step_size * (new_stats.s1[i] - self.s1[i]); + } + + for i in 0..n { + for j in 0..n { + self.s2[(i, j)] += step_size * (new_stats.s2[(i, j)] - self.s2[(i, j)]); + } + } + + for i in 0..n { + self.s3[i] += step_size * (new_stats.s3[i] - self.s3[i]); + } + + self.stat_rese += step_size * (new_stats.stat_rese - self.stat_rese); + self.count = ((1.0 - step_size) * self.count as f64 + step_size * new_stats.count as f64) + .round() as usize; + self.n_obs = ((1.0 - step_size) * self.n_obs as f64 + step_size * new_stats.n_obs as f64) + .round() as usize; + + Ok(()) + } + + pub fn compute_m_step(&self) -> Result<(Col, Mat)> { + if self.count == 0 { + bail!("Cannot compute M-step with zero samples"); + } + + let n = self.npar(); + let count_f64 = self.count as f64; + let mu = Col::from_fn(n, |i| self.s1[i] / count_f64); + let omega = Mat::from_fn(n, n, |i, j| self.s2[(i, j)] / count_f64 - mu[i] * mu[j]); + + Ok((mu, omega)) + } + + pub fn merge(&mut self, other: &SufficientStats) -> Result<()> { + if self.npar() != other.npar() { + bail!( + "Cannot merge statistics with different dimensions: {} vs {}", + self.npar(), + other.npar() + ); + } + + let n = self.npar(); + + for i in 0..n { + self.s1[i] += other.s1[i]; + } + + for i in 0..n { + for j in 0..n { + self.s2[(i, j)] += other.s2[(i, j)]; + } + } + + for i in 0..n { + self.s3[i] += other.s3[i]; + } + + self.stat_rese += other.stat_rese; + self.count += other.count; + self.n_obs += other.n_obs; + + Ok(()) + } +} + +impl Default for SufficientStats { + fn default() -> Self { + Self::new(0) + } +} + +impl Serialize for SufficientStats { + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + + let mut state = serializer.serialize_struct("SufficientStats", 6)?; + let s1_vec: Vec = (0..self.s1.nrows()).map(|i| self.s1[i]).collect(); + state.serialize_field("s1", &s1_vec)?; + + let s2_vec: Vec> = (0..self.s2.nrows()) + .map(|i| (0..self.s2.ncols()).map(|j| self.s2[(i, j)]).collect()) + .collect(); + state.serialize_field("s2", &s2_vec)?; + + let s3_vec: Vec = (0..self.s3.nrows()).map(|i| self.s3[i]).collect(); + state.serialize_field("s3", &s3_vec)?; + state.serialize_field("stat_rese", &self.stat_rese)?; + state.serialize_field("count", &self.count)?; + state.serialize_field("n_obs", &self.n_obs)?; + + state.end() + } +} + +impl<'de> Deserialize<'de> for SufficientStats { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + #[derive(Deserialize)] + struct SufficientStatsData { + s1: Vec, + s2: Vec>, + #[serde(default)] + s3: Option>, + #[serde(default)] + stat_rese: f64, + count: usize, + #[serde(default)] + n_obs: usize, + } + + let data = SufficientStatsData::deserialize(deserializer)?; + let n = data.s1.len(); + let s1 = Col::from_fn(n, |i| data.s1[i]); + + if data.s2.len() != n { + return Err(serde::de::Error::custom( + "S2 row count doesn't match S1 length", + )); + } + + let s2 = Mat::from_fn(n, n, |i, j| { + if j < data.s2[i].len() { + data.s2[i][j] + } else { + 0.0 + } + }); + + let s3 = match data.s3 { + Some(s3_data) if s3_data.len() == n => Col::from_fn(n, |i| s3_data[i]), + _ => Col::zeros(n), + }; + + Ok(SufficientStats { + s1, + s2, + s3, + 
stat_rese: data.stat_rese, + count: data.count, + n_obs: data.n_obs, + }) + } +} + +#[derive(Debug, Clone, Copy)] +pub enum StepSizeSchedule { + Constant(f64), + Harmonic, + RobbinsMonro { a: f64, b: f64 }, + PolyakRuppert { start_averaging: usize }, +} + +impl StepSizeSchedule { + pub fn new_saem(n_burn_in: usize, _n_stochastic: usize) -> Self { + StepSizeSchedule::PolyakRuppert { + start_averaging: n_burn_in, + } + } + + pub fn step_size(&self, k: usize) -> f64 { + match self { + StepSizeSchedule::Constant(gamma) => *gamma, + StepSizeSchedule::Harmonic => 1.0 / k as f64, + StepSizeSchedule::RobbinsMonro { a, b } => a / (k as f64 + b), + StepSizeSchedule::PolyakRuppert { start_averaging } => { + if k < *start_averaging { + 1.0 + } else { + 1.0 / (k - start_averaging + 1) as f64 + } + } + } + } +} + +impl Default for StepSizeSchedule { + fn default() -> Self { + StepSizeSchedule::PolyakRuppert { + start_averaging: 100, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_sufficient_stats_accumulation() { + let mut stats = SufficientStats::new(2); + let sample1 = Col::from_fn(2, |i| if i == 0 { 1.0 } else { 2.0 }); + let sample2 = Col::from_fn(2, |i| if i == 0 { 3.0 } else { 4.0 }); + + stats.accumulate(&sample1).unwrap(); + stats.accumulate(&sample2).unwrap(); + + assert_eq!(stats.count(), 2); + assert_eq!(stats.s1()[0], 4.0); + assert_eq!(stats.s1()[1], 6.0); + } + + #[test] + fn test_m_step_computation() { + let mut stats = SufficientStats::new(2); + + for i in 0..3 { + let sample = Col::from_fn(2, |j| (2 * i + j + 1) as f64); + stats.accumulate(&sample).unwrap(); + } + + let (mu, _omega) = stats.compute_m_step().unwrap(); + assert!((mu[0] - 3.0).abs() < 1e-10); + assert!((mu[1] - 4.0).abs() < 1e-10); + } +} diff --git a/src/estimation/parametric/summaries.rs b/src/estimation/parametric/summaries.rs new file mode 100644 index 000000000..1277bc977 --- /dev/null +++ b/src/estimation/parametric/summaries.rs @@ -0,0 +1,64 @@ +use pharmsol::{Data, Equation, Event}; + +use crate::estimation::parametric::ParametricWorkspace; +use crate::results::{FitSummary, IndividualSummary, ParameterSummary, PopulationSummary}; + +pub fn fit_summary(result: &ParametricWorkspace) -> FitSummary { + FitSummary { + objective_function: result.objf(), + converged: result.converged(), + iterations: result.iterations(), + subject_count: result.data().subjects().len(), + observation_count: count_observations(result.data()), + parameter_count: result.population().npar(), + algorithm: format!("{:?}", result.algorithm()), + } +} + +pub fn population_summary(result: &ParametricWorkspace) -> PopulationSummary { + let names = result.population().param_names(); + let sds = result.standard_deviations(); + let cvs = result.cv_percent(); + + let parameters = names + .into_iter() + .enumerate() + .map(|(index, name)| ParameterSummary { + name, + mean: result.mu()[index], + median: result.mu()[index], + sd: sds[index], + cv_percent: cvs[index], + }) + .collect(); + + PopulationSummary { parameters } +} + +pub fn individual_summaries( + result: &ParametricWorkspace, +) -> Vec { + let parameter_names = result.population().param_names(); + + result + .individual_estimates() + .iter() + .map(|individual| IndividualSummary { + id: individual.subject_id().to_string(), + parameter_names: parameter_names.clone(), + estimates: individual.psi().iter().copied().collect(), + standard_errors: individual + .standard_errors() + .map(|errors| errors.iter().copied().collect()), + }) + .collect() +} + +fn 
count_observations(data: &Data) -> usize { + data.subjects() + .iter() + .flat_map(|subject| subject.occasions()) + .flat_map(|occasion| occasion.events()) + .filter(|event| matches!(event, Event::Observation(_))) + .count() +} diff --git a/src/estimation/parametric/transforms.rs b/src/estimation/parametric/transforms.rs new file mode 100644 index 000000000..9c1c8e425 --- /dev/null +++ b/src/estimation/parametric/transforms.rs @@ -0,0 +1,435 @@ +use faer::Col; +use serde::{Deserialize, Serialize}; +use statrs::distribution::{Continuous, ContinuousCDF, Normal}; + +use super::state::{PhiVector, PsiVector}; +use super::Population; +use anyhow::Result; +use faer::Mat; + +pub(crate) struct InitializedPopulationInPhiSpace { + pub mu_psi: PsiVector, + pub mu_phi: PhiVector, + pub omega_phi: Mat, +} + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub enum ParameterTransform { + None, + LogNormal, + Logit { lower: f64, upper: f64 }, + Probit { lower: f64, upper: f64 }, +} + +impl Default for ParameterTransform { + fn default() -> Self { + ParameterTransform::None + } +} + +impl ParameterTransform { + pub fn logit_unit() -> Self { + ParameterTransform::Logit { + lower: 0.0, + upper: 1.0, + } + } + + pub fn logit(lower: f64, upper: f64) -> Self { + ParameterTransform::Logit { lower, upper } + } + + pub fn probit(lower: f64, upper: f64) -> Self { + ParameterTransform::Probit { lower, upper } + } + + pub fn psi_to_phi(&self, psi: f64) -> f64 { + match self { + ParameterTransform::None => psi, + ParameterTransform::LogNormal => psi.ln(), + ParameterTransform::Logit { lower, upper } => { + let normalized = (psi - lower) / (upper - psi); + normalized.ln() + } + ParameterTransform::Probit { lower, upper } => { + let normalized = (psi - lower) / (upper - lower); + probit(normalized) + } + } + } + + pub fn phi_to_psi(&self, phi: f64) -> f64 { + match self { + ParameterTransform::None => phi, + ParameterTransform::LogNormal => phi.exp(), + ParameterTransform::Logit { lower, upper } => { + let exp_phi = phi.exp(); + lower + (upper - lower) * exp_phi / (1.0 + exp_phi) + } + ParameterTransform::Probit { lower, upper } => { + let normalized = normal_cdf(phi); + lower + (upper - lower) * normalized + } + } + } + + pub fn dpsi_dphi(&self, phi: f64) -> f64 { + match self { + ParameterTransform::None => 1.0, + ParameterTransform::LogNormal => phi.exp(), + ParameterTransform::Logit { lower, upper } => { + let exp_phi = phi.exp(); + let denom = (1.0 + exp_phi).powi(2); + (upper - lower) * exp_phi / denom + } + ParameterTransform::Probit { lower, upper } => (upper - lower) * normal_pdf(phi), + } + } + + pub fn log_jacobian(&self, phi: f64) -> f64 { + match self { + ParameterTransform::None => 0.0, + ParameterTransform::LogNormal => phi, + ParameterTransform::Logit { lower, upper } => { + (upper - lower).ln() + phi - 2.0 * (1.0 + phi.exp()).ln() + } + ParameterTransform::Probit { lower, upper } => { + (upper - lower).ln() + log_normal_pdf(phi) + } + } + } + + pub fn is_valid_psi(&self, psi: f64) -> bool { + match self { + ParameterTransform::None => true, + ParameterTransform::LogNormal => psi > 0.0, + ParameterTransform::Logit { lower, upper } => psi > *lower && psi < *upper, + ParameterTransform::Probit { lower, upper } => psi > *lower && psi < *upper, + } + } + + pub fn psi_bounds(&self) -> Option<(f64, f64)> { + match self { + ParameterTransform::None => None, + ParameterTransform::LogNormal => Some((0.0, f64::INFINITY)), + ParameterTransform::Logit { lower, upper } => Some((*lower, *upper)), + 
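+            // Probit shares the same bounded support as logit; only the link
+            // differs (inverse normal CDF instead of log-odds).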
ParameterTransform::Probit { lower, upper } => Some((*lower, *upper)), + } + } + + pub fn to_saemix_code(&self) -> u8 { + match self { + ParameterTransform::None => 0, + ParameterTransform::LogNormal => 1, + ParameterTransform::Probit { .. } => 2, + ParameterTransform::Logit { .. } => 3, + } + } + + pub fn from_saemix_code(code: u8) -> Self { + match code { + 0 => ParameterTransform::None, + 1 => ParameterTransform::LogNormal, + 2 => ParameterTransform::Probit { + lower: 0.0, + upper: 1.0, + }, + 3 => ParameterTransform::Logit { + lower: 0.0, + upper: 1.0, + }, + _ => ParameterTransform::None, + } + } +} + +pub fn transforms_from_saemix_codes(codes: &[u8]) -> Vec { + codes + .iter() + .map(|&code| ParameterTransform::from_saemix_code(code)) + .collect() +} + +pub fn phi_to_psi_vec(transforms: &[ParameterTransform], phi: &Col) -> Col { + Col::from_fn(phi.nrows(), |index| { + transforms[index].phi_to_psi(phi[index]) + }) +} + +pub fn psi_to_phi_vec(transforms: &[ParameterTransform], psi: &Col) -> Col { + Col::from_fn(psi.nrows(), |index| { + transforms[index].psi_to_phi(psi[index]) + }) +} + +pub fn phi_to_psi(transforms: &[ParameterTransform], phi: &PhiVector) -> PsiVector { + PsiVector::from(&phi_to_psi_vec(transforms, &phi.to_col())) +} + +pub fn psi_to_phi(transforms: &[ParameterTransform], psi: &PsiVector) -> PhiVector { + PhiVector::from(&psi_to_phi_vec(transforms, &psi.to_col())) +} + +pub fn transform_label(transform: &ParameterTransform) -> &'static str { + match transform { + ParameterTransform::None => "Normal", + ParameterTransform::LogNormal => "LogNormal", + ParameterTransform::Logit { .. } => "Logit", + ParameterTransform::Probit { .. } => "Probit", + } +} + +pub fn default_phi_variance(transform: &ParameterTransform) -> Option { + match transform { + ParameterTransform::LogNormal => { + let cv: f64 = 0.5; + Some((1.0 + cv * cv).ln()) + } + ParameterTransform::Logit { .. } | ParameterTransform::Probit { .. 
} => Some(1.0), + ParameterTransform::None => None, + } +} + +pub(crate) fn initialize_population_in_phi_space( + population: &mut Population, + transforms: &[ParameterTransform], +) -> Result { + let mu_psi = PsiVector::from(population.mu()); + let mu_phi = psi_to_phi(transforms, &mu_psi); + population.update_mu(mu_phi.to_col())?; + + let n_params = population.npar(); + let mut omega_phi = population.omega().clone(); + for i in 0..n_params { + if let Some(default_variance) = default_phi_variance(&transforms[i]) { + omega_phi[(i, i)] = default_variance; + } + for j in 0..n_params { + if i != j { + omega_phi[(i, j)] = 0.0; + } + } + } + population.update_omega(omega_phi.clone())?; + + Ok(InitializedPopulationInPhiSpace { + mu_psi, + mu_phi, + omega_phi, + }) +} + +#[inline] +fn standard_normal() -> Normal { + Normal::standard() +} + +#[inline] +fn normal_pdf(x: f64) -> f64 { + standard_normal().pdf(x) +} + +#[inline] +fn log_normal_pdf(x: f64) -> f64 { + standard_normal().ln_pdf(x) +} + +#[inline] +fn normal_cdf(x: f64) -> f64 { + standard_normal().cdf(x) +} + +#[inline] +fn probit(p: f64) -> f64 { + standard_normal().inverse_cdf(p) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::estimation::parametric::Population; + use crate::model::{ParameterSpace, ParameterSpec}; + + const EPSILON: f64 = 1e-10; + + #[test] + fn test_none_transform() { + let t = ParameterTransform::None; + assert!((t.psi_to_phi(5.0) - 5.0).abs() < EPSILON); + assert!((t.phi_to_psi(5.0) - 5.0).abs() < EPSILON); + assert!((t.dpsi_dphi(5.0) - 1.0).abs() < EPSILON); + } + + #[test] + fn test_lognormal_transform() { + let t = ParameterTransform::LogNormal; + assert!((t.psi_to_phi(1.0) - 0.0).abs() < EPSILON); + assert!((t.phi_to_psi(0.0) - 1.0).abs() < EPSILON); + assert!((t.psi_to_phi(std::f64::consts::E) - 1.0).abs() < EPSILON); + assert!((t.phi_to_psi(1.0) - std::f64::consts::E).abs() < EPSILON); + + let psi = 2.5; + let phi = t.psi_to_phi(psi); + assert!((t.phi_to_psi(phi) - psi).abs() < EPSILON); + } + + #[test] + fn test_logit_transform() { + let t = ParameterTransform::Logit { + lower: 0.0, + upper: 1.0, + }; + assert!((t.psi_to_phi(0.5) - 0.0).abs() < EPSILON); + assert!((t.phi_to_psi(0.0) - 0.5).abs() < EPSILON); + + let psi = 0.7; + let phi = t.psi_to_phi(psi); + assert!((t.phi_to_psi(phi) - psi).abs() < EPSILON); + } + + #[test] + fn test_logit_custom_bounds() { + let t = ParameterTransform::Logit { + lower: 10.0, + upper: 100.0, + }; + assert!((t.psi_to_phi(55.0) - 0.0).abs() < EPSILON); + assert!((t.phi_to_psi(0.0) - 55.0).abs() < EPSILON); + + let psi = 30.0; + let phi = t.psi_to_phi(psi); + assert!((t.phi_to_psi(phi) - psi).abs() < EPSILON); + } + + #[test] + fn test_probit_transform() { + let t = ParameterTransform::Probit { + lower: 0.0, + upper: 1.0, + }; + assert!(t.psi_to_phi(0.5).abs() < 1e-6); + + let probit_07 = probit(0.7); + assert!((probit_07 - 0.524).abs() < 0.01); + + let psi = 0.7; + let phi = t.psi_to_phi(psi); + let psi_back = t.phi_to_psi(phi); + assert!((psi_back - psi).abs() < 1e-4); + } + + #[test] + fn test_jacobian_lognormal() { + let t = ParameterTransform::LogNormal; + let phi: f64 = 1.0; + let expected = phi.exp(); + assert!((t.dpsi_dphi(phi) - expected).abs() < EPSILON); + assert!((t.log_jacobian(phi) - phi).abs() < EPSILON); + } + + #[test] + fn test_saemix_codes() { + assert_eq!(ParameterTransform::None.to_saemix_code(), 0); + assert_eq!(ParameterTransform::LogNormal.to_saemix_code(), 1); + assert_eq!( + ParameterTransform::Probit { + lower: 0.0, + upper: 1.0 + } + 
.to_saemix_code(), + 2 + ); + assert_eq!( + ParameterTransform::Logit { + lower: 0.0, + upper: 1.0 + } + .to_saemix_code(), + 3 + ); + } + + #[test] + fn test_validity_checks() { + let t = ParameterTransform::LogNormal; + assert!(!t.is_valid_psi(-1.0)); + assert!(!t.is_valid_psi(0.0)); + assert!(t.is_valid_psi(0.1)); + assert!(t.is_valid_psi(100.0)); + + let t = ParameterTransform::Logit { + lower: 0.0, + upper: 1.0, + }; + assert!(!t.is_valid_psi(-0.1)); + assert!(!t.is_valid_psi(0.0)); + assert!(t.is_valid_psi(0.5)); + assert!(!t.is_valid_psi(1.0)); + assert!(!t.is_valid_psi(1.1)); + } + + #[test] + fn test_transforms_from_saemix_codes() { + let transforms = transforms_from_saemix_codes(&[0u8, 1u8, 2u8, 3u8]); + assert!(matches!(transforms[0], ParameterTransform::None)); + assert!(matches!(transforms[1], ParameterTransform::LogNormal)); + assert!(matches!(transforms[2], ParameterTransform::Probit { .. })); + assert!(matches!(transforms[3], ParameterTransform::Logit { .. })); + } + + #[test] + fn test_phi_psi_vector_roundtrip() { + let transforms = vec![ + ParameterTransform::None, + ParameterTransform::LogNormal, + ParameterTransform::logit(0.0, 1.0), + ]; + let psi = Col::from_fn(3, |index| match index { + 0 => 2.0, + 1 => 3.0, + _ => 0.25, + }); + + let phi = psi_to_phi_vec(&transforms, &psi); + let back = phi_to_psi_vec(&transforms, &phi); + + for index in 0..psi.nrows() { + assert!((psi[index] - back[index]).abs() < 1e-10); + } + } + + #[test] + fn test_typed_phi_psi_roundtrip() { + let transforms = vec![ + ParameterTransform::None, + ParameterTransform::LogNormal, + ParameterTransform::logit(0.0, 1.0), + ]; + let psi = PsiVector(vec![2.0, 3.0, 0.25]); + + let phi = psi_to_phi(&transforms, &psi); + let back = phi_to_psi(&transforms, &phi); + + assert_eq!(psi.0.len(), back.0.len()); + for index in 0..psi.0.len() { + assert!((psi.0[index] - back.0[index]).abs() < 1e-10); + } + } + + #[test] + fn initialize_population_converts_to_phi_space_and_resets_covariance() { + let parameters = ParameterSpace::new() + .add(ParameterSpec::bounded("ke", 0.1, 1.0)) + .add(ParameterSpec::bounded("v", 1.0, 20.0)); + let mut population = Population::from_parameter_space(parameters).unwrap(); + let transforms = vec![ParameterTransform::LogNormal, ParameterTransform::None]; + + let initialized = initialize_population_in_phi_space(&mut population, &transforms).unwrap(); + + assert!((initialized.mu_psi.0[0] - 0.55).abs() < 1e-12); + assert!((initialized.mu_phi.0[0] - 0.55_f64.ln()).abs() < 1e-12); + assert!((population.mu()[0] - 0.55_f64.ln()).abs() < 1e-12); + assert_eq!(initialized.omega_phi[(0, 1)], 0.0); + assert_eq!(initialized.omega_phi[(1, 0)], 0.0); + } +} diff --git a/src/estimation/parametric/uncertainty.rs b/src/estimation/parametric/uncertainty.rs new file mode 100644 index 000000000..3b1920b9c --- /dev/null +++ b/src/estimation/parametric/uncertainty.rs @@ -0,0 +1,240 @@ +use anyhow::Result; +use faer::linalg::solvers::DenseSolveCore; +use faer::{Col, Mat}; +use pharmsol::Equation; + +use super::{FimMethod, ParametricWorkspace, Population, UncertaintyEstimates}; + +impl UncertaintyEstimates { + pub fn new() -> Self { + Self::default() + } + + pub fn has_fim(&self) -> bool { + self.fim.is_some() + } + + pub fn has_standard_errors(&self) -> bool { + self.se_mu.is_some() + } + + pub fn fim_inverse(&self) -> Option<&Mat> { + self.fim_inverse.as_ref() + } + + pub fn se_mu(&self) -> Option<&Col> { + self.se_mu.as_ref() + } + + pub fn se_omega(&self) -> Option<&Mat> { + self.se_omega.as_ref() + } + + 
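+    // Relative standard errors are expressed in percent, i.e.
+    // `100 * SE(mu_i) / |mu_i|`; see `from_fim_inverse` below for how they are
+    // derived from the diagonal of the inverse FIM.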
pub fn rse_mu(&self) -> Option<&Col> { + self.rse_mu.as_ref() + } + + pub fn fim_method(&self) -> Option { + self.fim_method + } + + pub fn from_fim_inverse( + population: &Population, + fim_inverse: Mat, + fim_method: FimMethod, + ) -> Self { + let n_params = population.npar(); + let se_mu = + (fim_inverse.nrows() >= n_params && fim_inverse.ncols() >= n_params).then(|| { + Col::from_fn(n_params, |index| { + fim_inverse[(index, index)].max(0.0).sqrt() + }) + }); + let rse_mu = se_mu.as_ref().map(|se| { + Col::from_fn(n_params, |index| { + let mu = population.mu()[index].abs(); + if mu > f64::EPSILON { + 100.0 * se[index] / mu + } else { + 0.0 + } + }) + }); + let se_omega = if fim_inverse.nrows() >= 2 * n_params && fim_inverse.ncols() >= 2 * n_params + { + Some(Mat::from_fn(n_params, n_params, |row, col| { + if row == col { + fim_inverse[(n_params + row, n_params + col)] + .max(0.0) + .sqrt() + } else { + 0.0 + } + })) + } else { + None + }; + + Self { + fim: None, + fim_inverse: Some(fim_inverse), + se_mu, + se_omega, + rse_mu, + fim_method: Some(fim_method), + } + } + + pub fn from_fim(population: &Population, fim: Mat, fim_method: FimMethod) -> Result { + let fim_inverse = fim + .clone() + .llt(faer::Side::Lower) + .map_err(|_| anyhow::anyhow!("FIM is not positive definite"))? + .inverse(); + let mut estimates = Self::from_fim_inverse(population, fim_inverse, fim_method); + estimates.fim = Some(fim); + Ok(estimates) + } +} + +pub fn focei_linearization_uncertainty( + population: &Population, + n_subjects: usize, +) -> UncertaintyEstimates { + let n_params = population.npar(); + if n_subjects == 0 || n_params == 0 { + return UncertaintyEstimates::new(); + } + + let n_subjects = n_subjects as f64; + let omega_inv = inverse_or_diagonal(population.omega()); + let fim = Mat::from_fn(2 * n_params, 2 * n_params, |row, col| { + if row < n_params && col < n_params { + n_subjects * omega_inv[(row, col)] + } else if row == col && row >= n_params { + let variance = population.omega()[(row - n_params, row - n_params)].max(1e-8); + n_subjects / (2.0 * variance.powi(2)) + } else { + 0.0 + } + }); + + UncertaintyEstimates::from_fim(population, fim, FimMethod::Linearization).unwrap_or_else(|_| { + let fim_inverse = Mat::from_fn(2 * n_params, 2 * n_params, |row, col| { + if row < n_params && col < n_params { + population.omega()[(row, col)] / n_subjects + } else if row == col && row >= n_params { + let variance = population.omega()[(row - n_params, row - n_params)].max(1e-8); + 2.0 * variance.powi(2) / n_subjects + } else { + 0.0 + } + }); + UncertaintyEstimates::from_fim_inverse(population, fim_inverse, FimMethod::Linearization) + }) +} + +fn inverse_or_diagonal(matrix: &Mat) -> Mat { + match matrix.clone().llt(faer::Side::Lower) { + Ok(cholesky) => cholesky.inverse(), + Err(_) => Mat::from_fn(matrix.nrows(), matrix.ncols(), |row, col| { + if row == col { + 1.0 / matrix[(row, row)].max(1e-8) + } else { + 0.0 + } + }), + } +} + +pub fn estimates(workspace: &ParametricWorkspace) -> &UncertaintyEstimates { + workspace.uncertainty() +} + +pub fn has_fim(workspace: &ParametricWorkspace) -> bool { + workspace.uncertainty().has_fim() +} + +pub fn has_standard_errors(workspace: &ParametricWorkspace) -> bool { + workspace.uncertainty().has_standard_errors() +} + +pub fn se_mu(workspace: &ParametricWorkspace) -> Option<&Col> { + workspace.uncertainty().se_mu() +} + +pub fn fim(workspace: &ParametricWorkspace) -> Option<&Mat> { + workspace.uncertainty().fim.as_ref() +} + +pub fn fim_inverse(workspace: 
&ParametricWorkspace) -> Option<&Mat> { + workspace.uncertainty().fim_inverse() +} + +pub fn se_omega(workspace: &ParametricWorkspace) -> Option<&Mat> { + workspace.uncertainty().se_omega() +} + +pub fn rse_mu(workspace: &ParametricWorkspace) -> Option<&Col> { + workspace.uncertainty().rse_mu() +} + +pub fn fim_method(workspace: &ParametricWorkspace) -> Option { + workspace.uncertainty().fim_method() +} + +#[cfg(test)] +mod tests { + use super::{focei_linearization_uncertainty, UncertaintyEstimates}; + use crate::estimation::parametric::{FimMethod, Population}; + use crate::model::{ParameterSpace, ParameterSpec}; + use faer::{Col, Mat}; + + #[test] + fn derives_standard_errors_from_inverse_fim() { + let parameters = ParameterSpace::new() + .add(ParameterSpec::bounded("ke", 0.1, 1.0)) + .add(ParameterSpec::bounded("v", 1.0, 20.0)); + let population = Population::new( + Col::from_fn(2, |index| if index == 0 { 0.5 } else { 10.0 }), + Mat::from_fn(2, 2, |row, col| if row == col { 0.2 } else { 0.0 }), + parameters, + ) + .unwrap(); + let fim_inverse = Mat::from_fn( + 4, + 4, + |row, col| if row == col { (row + 1) as f64 } else { 0.0 }, + ); + + let estimates = + UncertaintyEstimates::from_fim_inverse(&population, fim_inverse, FimMethod::Observed); + + assert!(estimates.has_standard_errors()); + assert_eq!(estimates.fim_method(), Some(FimMethod::Observed)); + assert!((estimates.se_mu().unwrap()[0] - 1.0).abs() < 1e-12); + assert!((estimates.se_mu().unwrap()[1] - 2.0_f64.sqrt()).abs() < 1e-12); + assert!((estimates.rse_mu().unwrap()[0] - 200.0).abs() < 1e-12); + assert!(estimates.se_omega().is_some()); + } + + #[test] + fn focei_linearization_uncertainty_exposes_fim_and_standard_errors() { + let parameters = ParameterSpace::new() + .add(ParameterSpec::bounded("ke", 0.1, 1.0)) + .add(ParameterSpec::bounded("v", 1.0, 20.0)); + let population = Population::new( + Col::from_fn(2, |index| if index == 0 { 0.5 } else { 10.0 }), + Mat::from_fn(2, 2, |row, col| if row == col { 0.2 } else { 0.0 }), + parameters, + ) + .unwrap(); + + let estimates = focei_linearization_uncertainty(&population, 4); + + assert!(estimates.has_fim()); + assert!(estimates.has_standard_errors()); + assert_eq!(estimates.fim_method(), Some(FimMethod::Linearization)); + assert!(estimates.fim_inverse().is_some()); + } +} diff --git a/src/estimation/parametric/workspace.rs b/src/estimation/parametric/workspace.rs new file mode 100644 index 000000000..c0dec888a --- /dev/null +++ b/src/estimation/parametric/workspace.rs @@ -0,0 +1,426 @@ +use pharmsol::Equation; + +use crate::algorithms::{Status, StopReason}; +use crate::compile::OccasionDesign; +use crate::estimation::parametric::{ + IndividualEstimates, LikelihoodEstimates, ParametricIterationLog, ParametricPredictions, + Population, ResidualErrorEstimates, UncertaintyEstimates, +}; +use crate::output::shared::RunConfiguration; +use crate::results::FitResult; +use pharmsol::Data; + +use super::state::{IndividualEffectsState, ParametricModelState}; + +#[derive(Debug)] +pub struct ParametricWorkspace { + state: ParametricModelState, + individuals: IndividualEffectsState, + equation: E, + data: Data, + population: Population, + individual_estimates: IndividualEstimates, + objf: f64, + iterations: usize, + status: Status, + run_configuration: RunConfiguration, + iteration_log: ParametricIterationLog, + likelihoods: LikelihoodEstimates, + uncertainty: UncertaintyEstimates, + sigma: ResidualErrorEstimates, + predictions: Option, +} + +impl ParametricWorkspace { + 
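+    // The workspace bundles everything downstream writers need once a
+    // parametric fit has finished: the fitted `Population`, per-subject
+    // `IndividualEstimates`, likelihood and uncertainty estimates, and the run
+    // configuration. A hypothetical consumer (using only the accessors and
+    // writers defined below) might do:
+    //
+    //     if workspace.should_write_outputs() {
+    //         workspace.write_population()?;
+    //         workspace.write_individual_parameters()?;
+    //         workspace.write_iteration_log()?;
+    //     }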
#[allow(clippy::too_many_arguments)] + pub(crate) fn new( + state: ParametricModelState, + individuals: IndividualEffectsState, + equation: E, + data: Data, + population: Population, + individual_estimates: IndividualEstimates, + objf: f64, + iterations: usize, + status: Status, + run_configuration: RunConfiguration, + iteration_log: ParametricIterationLog, + likelihoods: LikelihoodEstimates, + uncertainty: UncertaintyEstimates, + sigma: ResidualErrorEstimates, + predictions: Option, + ) -> Self { + Self { + state, + individuals, + equation, + data, + population, + individual_estimates, + objf, + iterations, + status, + run_configuration, + iteration_log, + likelihoods, + uncertainty, + sigma, + predictions, + } + } + + pub fn state(&self) -> &ParametricModelState { + &self.state + } + + pub fn population(&self) -> &Population { + &self.population + } + + pub fn mu(&self) -> &faer::Col { + self.population.mu() + } + + pub fn omega(&self) -> &faer::Mat { + self.population.omega() + } + + pub fn individual_estimates(&self) -> &IndividualEstimates { + &self.individual_estimates + } + + pub fn objf(&self) -> f64 { + self.objf + } + + pub fn best_objf(&self) -> f64 { + self.likelihoods.best_objf().unwrap_or(self.objf) + } + + pub fn iterations(&self) -> usize { + self.iterations + } + + pub fn converged(&self) -> bool { + self.status == Status::Stop(StopReason::Converged) + } + + pub fn status(&self) -> &Status { + &self.status + } + + pub(crate) fn run_configuration(&self) -> &RunConfiguration { + &self.run_configuration + } + + pub(crate) fn algorithm(&self) -> crate::algorithms::Algorithm { + self.run_configuration.algorithm + } + + pub(crate) fn output_folder(&self) -> &str { + self.run_configuration.output_path() + } + + pub(crate) fn should_write_outputs(&self) -> bool { + self.run_configuration.should_write_outputs() + } + + pub(crate) fn prediction_interval(&self) -> (f64, f64) { + ( + self.run_configuration.runtime.idelta, + self.run_configuration.runtime.tad, + ) + } + + pub fn data(&self) -> &Data { + &self.data + } + + pub fn sigma(&self) -> &ResidualErrorEstimates { + &self.sigma + } + + pub fn predictions(&self) -> Option<&ParametricPredictions> { + self.predictions.as_ref() + } + + pub fn set_predictions(&mut self, predictions: ParametricPredictions) { + self.predictions = Some(predictions); + } + + pub fn individuals(&self) -> &IndividualEffectsState { + &self.individuals + } + + pub fn likelihoods(&self) -> &LikelihoodEstimates { + &self.likelihoods + } + + pub fn uncertainty(&self) -> &UncertaintyEstimates { + &self.uncertainty + } + + pub fn equation(&self) -> &E { + &self.equation + } + + pub fn standard_deviations(&self) -> faer::Col { + self.population.standard_deviations() + } + + pub fn cv_percent(&self) -> faer::Col { + self.population.coefficient_of_variation() + } + + pub fn correlation_matrix(&self) -> faer::Mat { + self.population.correlation_matrix() + } + + pub fn iteration_log(&self) -> &ParametricIterationLog { + &self.iteration_log + } + + pub fn write_covariates(&self) -> anyhow::Result<()> { + use csv::WriterBuilder; + use pharmsol::Event; + + tracing::debug!("Writing covariates..."); + + let outputfile = crate::output::OutputFile::new(self.output_folder(), "covariates.csv")?; + let mut writer = WriterBuilder::new() + .has_headers(true) + .from_writer(outputfile.file()); + + let mut covariate_names = std::collections::HashSet::new(); + for subject in self.data.subjects() { + for occasion in subject.occasions() { + let covmap = 
occasion.covariates().covariates(); + for cov_name in covmap.keys() { + covariate_names.insert(cov_name.clone()); + } + } + } + let mut covariate_names: Vec = covariate_names.into_iter().collect(); + covariate_names.sort(); + + if covariate_names.is_empty() { + return Ok(()); + } + + let mut headers = vec!["id", "time", "block"]; + headers.extend(covariate_names.iter().map(|s| s.as_str())); + writer.write_record(&headers)?; + + for subject in self.data.subjects() { + for occasion in subject.occasions() { + let covmap = occasion.covariates().covariates(); + + for event in occasion.iter() { + let time = match event { + Event::Bolus(bolus) => bolus.time(), + Event::Infusion(infusion) => infusion.time(), + Event::Observation(observation) => observation.time(), + }; + + let mut row: Vec = vec![ + subject.id().clone(), + time.to_string(), + occasion.index().to_string(), + ]; + + for cov_name in &covariate_names { + if let Some(cov) = covmap.get(cov_name) { + if let Ok(value) = cov.interpolate(time) { + row.push(value.to_string()); + } else { + row.push(String::new()); + } + } else { + row.push(String::new()); + } + } + + writer.write_record(&row)?; + } + } + } + + writer.flush()?; + Ok(()) + } + + pub fn write_population(&self) -> anyhow::Result<()> { + use csv::WriterBuilder; + + let outputfile = crate::output::OutputFile::new(self.output_folder(), "population.csv")?; + let mut writer = WriterBuilder::new() + .has_headers(true) + .from_writer(outputfile.file()); + + let header = ["parameter", "mu", "omega_diag", "sd", "cv_percent"]; + writer.write_record(header)?; + + let names = self.population.param_names(); + let sds = self.standard_deviations(); + let cvs = self.cv_percent(); + + for (i, name) in names.iter().enumerate() { + let row = vec![ + name.clone(), + self.population.mu()[i].to_string(), + self.population.omega()[(i, i)].to_string(), + sds[i].to_string(), + cvs[i].to_string(), + ]; + writer.write_record(&row)?; + } + writer.flush()?; + + let outputfile = crate::output::OutputFile::new(self.output_folder(), "correlation.csv")?; + let mut writer = WriterBuilder::new().from_writer(outputfile.file()); + let corr = self.correlation_matrix(); + let names = self.population.param_names(); + let mut header = vec!["".to_string()]; + header.extend(names.clone()); + writer.write_record(&header)?; + for (i, name) in names.iter().enumerate() { + let mut row = vec![name.clone()]; + for j in 0..corr.ncols() { + row.push(format!("{:.4}", corr[(i, j)])); + } + writer.write_record(&row)?; + } + writer.flush()?; + Ok(()) + } + + pub fn write_uncertainty(&self) -> anyhow::Result<()> { + use csv::WriterBuilder; + + if !self.uncertainty.has_fim() && !self.uncertainty.has_standard_errors() { + return Ok(()); + } + + let outputfile = crate::output::OutputFile::new(self.output_folder(), "uncertainty.csv")?; + let mut writer = WriterBuilder::new() + .has_headers(true) + .from_writer(outputfile.file()); + writer.write_record(["kind", "parameter", "value"])?; + + if let Some(method) = self.uncertainty.fim_method() { + writer.write_record(["fim_method", "", &format!("{:?}", method)])?; + } + + for (index, name) in self.population.param_names().iter().enumerate() { + writer.write_record(["mu", name, &format!("{:.6}", self.population.mu()[index])])?; + + if let Some(se_mu) = self.uncertainty.se_mu() { + writer.write_record(["se_mu", name, &format!("{:.6}", se_mu[index])])?; + } + + if let Some(rse_mu) = self.uncertainty.rse_mu() { + writer.write_record(["rse_mu", name, &format!("{:.6}", rse_mu[index])])?; + } + + 
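+            // Only the diagonal of the omega standard-error matrix is
+            // reported below; off-diagonal standard errors are not written.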
if let Some(se_omega) = self.uncertainty.se_omega() { + writer.write_record([ + "se_omega_diag", + name, + &format!("{:.6}", se_omega[(index, index)]), + ])?; + } + } + + writer.flush()?; + Ok(()) + } + + pub fn write_individual_parameters(&self) -> anyhow::Result<()> { + use csv::WriterBuilder; + + let outputfile = + crate::output::OutputFile::new(self.output_folder(), "individual_parameters.csv")?; + let mut writer = WriterBuilder::new() + .has_headers(true) + .from_writer(outputfile.file()); + + let names = self.population.param_names(); + let mut header = vec!["id".to_string()]; + for name in &names { + header.push(format!("psi_{}", name)); + } + writer.write_record(&header)?; + + for ind in self.individual_estimates.iter() { + let mut row = vec![ind.subject_id().to_string()]; + for i in 0..ind.npar() { + row.push(ind.psi()[i].to_string()); + } + writer.write_record(&row)?; + } + writer.flush()?; + Ok(()) + } + + pub fn write_individual_effects(&self) -> anyhow::Result<()> { + use csv::WriterBuilder; + + let outputfile = + crate::output::OutputFile::new(self.output_folder(), "individual_effects.csv")?; + let mut writer = WriterBuilder::new() + .has_headers(true) + .from_writer(outputfile.file()); + + let names = self.population.param_names(); + let mut header = vec!["id".to_string()]; + for name in &names { + header.push(format!("eta_{}", name)); + } + if self + .individual_estimates + .iter() + .any(|i| i.objective_function().is_some()) + { + header.push("objf".to_string()); + } + writer.write_record(&header)?; + + for ind in self.individual_estimates.iter() { + let mut row = vec![ind.subject_id().to_string()]; + for i in 0..ind.npar() { + row.push(ind.eta()[i].to_string()); + } + if let Some(objf) = ind.objective_function() { + row.push(objf.to_string()); + } + writer.write_record(&row)?; + } + writer.flush()?; + Ok(()) + } + + pub fn write_iteration_log(&self) -> anyhow::Result<()> { + self.iteration_log.write( + self.output_folder(), + &self.run_configuration.parameter_names, + ) + } + + pub fn with_compiled_state( + mut self, + compiled_state: ParametricModelState, + occasions: &[OccasionDesign], + ) -> Self { + self.individuals = self.individuals.with_occasion_design( + occasions, + &compiled_state.variability, + self.population.npar(), + ); + self.state = compiled_state.merged(self.state); + self + } + + pub fn into_fit_result(self) -> FitResult { + FitResult::Parametric(self) + } +} diff --git a/src/lib.rs b/src/lib.rs index b11062943..f2b5ae582 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,11 +1,22 @@ //! PMcore is a framework for developing and running population pharmacokinetic algorithms //! -//! The framework exposes a unified modeling platform for PMcore algorithms. +//! The framework supports both **non-parametric** and **parametric** approaches to population modeling, +//! allowing for flexible estimation of population parameter distributions. //! //! # Algorithm Types //! -//! The structure branch keeps the unified modeling/compilation layer and the baseline -//! non-parametric algorithms already present in PMcore. +//! ## Non-Parametric Algorithms +//! Represent the population distribution as a discrete set of support points with associated weights. +//! - NPAG (Non-Parametric Adaptive Grid) +//! - NPOD (Non-Parametric Optimal Design) +//! - And others... +//! +//! ## Parametric Algorithms +//! Represent the population distribution as a continuous distribution (typically multivariate normal). +//! - SAEM (Stochastic Approximation Expectation-Maximization) +//! 
- FOCEI (First-Order Conditional Estimation with Interaction) +//! - IT2B (Iterative Two-Stage Bayesian) +//! - And others... //! //! # Public API //! @@ -57,15 +68,26 @@ pub mod prelude { pub use crate::algorithms::Algorithm; pub use crate::api::fit; pub use crate::api::{ - AlgorithmTuning, ConvergenceOptions, EstimationMethod, EstimationProblem, LoggingLevel, - LoggingOptions, ModelDefinition, NonparametricMethod, NpagOptions, NpodOptions, - OutputPlan, PostProbOptions, RuntimeOptions, + AlgorithmTuning, ConvergenceOptions, EstimationMethod, EstimationProblem, FoceiOptions, + It2bOptions, LoggingLevel, LoggingOptions, ModelDefinition, NonparametricMethod, + NpagOptions, NpodOptions, OutputPlan, + ParametricMethod, PostProbOptions, RuntimeOptions, SaemOptions, }; pub use crate::compile::{CompiledProblem, DesignContext, ObservationIndex}; pub use crate::estimation::nonparametric::{ CycleLog, NPCycle, NPPredictions, NonparametricEngine, NonparametricWorkspace, Posterior, Psi, Theta, Weights, }; + pub use crate::estimation::parametric::{ + aic, bic, cache_predictions, compile_model_state, fim, fim_inverse, fim_method, has_fim, + has_standard_errors, importance_sampling_likelihood_estimates, phi_to_psi, psi_to_phi, + rse_mu, se_mu, se_omega, shrinkage, statistics, subject_conditionals_from_eta_samples, + uncertainty_estimates, write_statistics, CovarianceStructure, EtaTable, EtaVector, + FimMethod, FixedEffects, Individual, IndividualEffectsState, IndividualEstimates, + KappaVector, OccasionKappa, OccasionKappaTable, ParameterTransform, ParametricEngine, + ParametricModelState, ParametricTransformKind, ParametricWorkspace, PhiTable, PhiVector, + Population, PsiTable, PsiVector, RandomEffects, ResidualState, TransformSet, + }; pub use crate::model::{ ContinuousObservationSpec, CovariateEffectsSpec, CovariateModel, CovariateSpec, ModelMetadata, ObservationChannel, ObservationLikelihood, ObservationSpec, ParameterDomain, @@ -82,6 +104,15 @@ pub mod prelude { pub use crate::estimation::nonparametric::{read_prior, Prior}; + // Parametric specific + pub use crate::estimation::parametric::{StepSizeSchedule, SufficientStats}; + + // Output types + pub use crate::estimation::parametric::ParametricIterationLog; + + // Internal tuning types still used by the new runtime surface. 
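+    // As an illustrative sketch (not a prescription), `SaemConfig` is normally reached through
+    // `RuntimeOptions { tuning: AlgorithmTuning { saem: SaemConfig { .. }, .. }, .. }` when
+    // building an `EstimationProblem`; the integration tests in this patch use exactly that shape.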
+ pub use crate::api::SaemConfig; + pub mod simulator { pub use pharmsol::prelude::simulator::*; } diff --git a/src/output/mod.rs b/src/output/mod.rs index a80a2878d..f1c957690 100644 --- a/src/output/mod.rs +++ b/src/output/mod.rs @@ -1,6 +1,7 @@ mod file; pub(crate) mod logging; pub mod nonparametric; +pub mod parametric; pub mod shared; pub mod writer; diff --git a/src/output/parametric.rs b/src/output/parametric.rs new file mode 100644 index 000000000..b636dc152 --- /dev/null +++ b/src/output/parametric.rs @@ -0,0 +1,58 @@ +use anyhow::Result; + +use crate::estimation::parametric::{self as posthoc, ParametricWorkspace}; +use crate::output::shared::shared_output_file_names; + +pub(crate) fn output_file_names( + result: &ParametricWorkspace, +) -> Vec { + let mut files = shared_output_file_names(); + files.extend( + [ + "population.csv", + "correlation.csv", + "individual_parameters.csv", + "individual_effects.csv", + "iterations.csv", + "statistics.csv", + "shrinkage.csv", + "residual_error.csv", + ] + .into_iter() + .map(str::to_string), + ); + + if result.uncertainty().has_fim() || result.uncertainty().has_standard_errors() { + files.push("uncertainty.csv".to_string()); + } + + let has_covariates = result.data().subjects().iter().any(|subject| { + subject + .occasions() + .iter() + .any(|occasion| !occasion.covariates().covariates().is_empty()) + }); + if has_covariates { + files.push("covariates.csv".to_string()); + } + + files.sort(); + files.dedup(); + files +} + +pub fn write_parametric_workspace_outputs( + result: &mut ParametricWorkspace, +) -> Result<()> { + let (idelta, tad) = result.prediction_interval(); + result.write_population()?; + result.write_individual_parameters()?; + result.write_individual_effects()?; + result.write_iteration_log()?; + posthoc::cache_predictions(result, idelta, tad)?; + posthoc::write_statistics(result)?; + result.write_uncertainty()?; + result.sigma().write(result.output_folder())?; + result.write_covariates()?; + Ok(()) +} diff --git a/src/output/writer.rs b/src/output/writer.rs index 60ef89a9f..f6980b889 100644 --- a/src/output/writer.rs +++ b/src/output/writer.rs @@ -4,9 +4,11 @@ use serde::Serialize; use crate::estimation::nonparametric as np_estimation; use crate::estimation::nonparametric::NonparametricWorkspace; -use crate::output::{nonparametric as np_output, shared}; +use crate::estimation::parametric as param_estimation; +use crate::estimation::parametric::ParametricWorkspace; +use crate::output::{nonparametric as np_output, parametric, shared}; use crate::results::FitResult; -use crate::results::{nonparametric_diagnostics, FitSummary}; +use crate::results::{nonparametric_diagnostics, parametric_diagnostics, FitSummary}; #[derive(Debug, Clone, Serialize)] struct SharedPredictionRow { @@ -26,6 +28,7 @@ struct SharedPredictionRow { pub fn write_result(result: &mut FitResult) -> Result<()> { match result { FitResult::Nonparametric(inner) => write_nonparametric_result(inner)?, + FitResult::Parametric(inner) => write_parametric_workspace_result(inner)?, } Ok(()) @@ -67,6 +70,46 @@ pub fn write_nonparametric_result( Ok(()) } +pub fn write_parametric_workspace_result( + result: &mut ParametricWorkspace, +) -> Result<()> { + if !result.should_write_outputs() { + return Ok(()); + } + + let folder = result.output_folder().to_string(); + shared::write_settings(&folder, result.run_configuration())?; + shared::write_summary(&folder, ¶metric_summary(result))?; + shared::write_diagnostics(&folder, ¶metric_diagnostics(result))?; + 
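+    // Shared artifacts (settings, summary, diagnostics) are written first; the
+    // parametric-specific CSVs and the shared predictions table follow below.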
parametric::write_parametric_workspace_outputs(result)?; + + if let Some(predictions) = result.predictions() { + let rows = predictions + .predictions() + .iter() + .map(|row| SharedPredictionRow { + id: row.id().to_string(), + time: row.time(), + outeq: row.outeq(), + block: row.block(), + obs: row.obs(), + cens: row.censoring(), + pred_population: row.ppred(), + pred_individual: row.ipred(), + residual_population: row.obs().map(|obs| obs - row.ppred()), + residual_individual: row.ires(), + source_method: "parametric".to_string(), + }); + shared::write_csv_rows(&folder, "predictions.csv", rows)?; + } + + Ok(()) +} + fn nonparametric_summary(result: &NonparametricWorkspace) -> FitSummary { np_estimation::fit_summary(result) } + +fn parametric_summary(result: &ParametricWorkspace) -> FitSummary { + param_estimation::fit_summary(result) +} diff --git a/src/results/artifacts.rs b/src/results/artifacts.rs index abd817c3e..541a2e246 100644 --- a/src/results/artifacts.rs +++ b/src/results/artifacts.rs @@ -5,6 +5,7 @@ use pharmsol::Equation; use serde::{Deserialize, Serialize}; use crate::estimation::nonparametric::NonparametricWorkspace; +use crate::estimation::parametric::ParametricWorkspace; use crate::output::shared::shared_output_file_names; #[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)] @@ -26,6 +27,14 @@ pub(crate) fn nonparametric_artifacts( ) } +pub(crate) fn parametric_artifacts(result: &ParametricWorkspace) -> ArtifactIndex { + artifact_index( + result.output_folder(), + result.should_write_outputs(), + crate::output::parametric::output_file_names(result), + ) +} + fn artifact_index( folder: &str, should_write_outputs: bool, diff --git a/src/results/diagnostics.rs b/src/results/diagnostics.rs index cfeaf2b5a..1f11b1919 100644 --- a/src/results/diagnostics.rs +++ b/src/results/diagnostics.rs @@ -4,6 +4,7 @@ use pharmsol::Equation; use serde::{Deserialize, Serialize}; use crate::estimation::nonparametric::NonparametricWorkspace; +use crate::estimation::parametric::ParametricWorkspace; #[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)] pub struct DiagnosticsBundle { @@ -57,3 +58,110 @@ pub(crate) fn nonparametric_diagnostics( estimator_metadata, } } + +pub(crate) fn parametric_diagnostics( + result: &ParametricWorkspace, +) -> DiagnosticsBundle { + let mut warnings = Vec::new(); + let mut deferred_features = Vec::new(); + let mut convergence_notes = Vec::new(); + let mut estimator_metadata = BTreeMap::new(); + + if result.converged() { + convergence_notes.push("Estimator reported convergence.".to_string()); + } else { + warnings.push("Estimator stopped without convergence.".to_string()); + convergence_notes.push(format!("Final status: {:?}", result.status())); + } + + if result.state().variability.occasion.is_some() { + deferred_features.push("occasion_inference".to_string()); + warnings.push( + "Occasion variability is represented in the compiled and fitted state, but occasion-level inference remains deferred; occasion_kappa entries are structural placeholders.".to_string(), + ); + estimator_metadata.insert("occasion_inference".to_string(), "deferred".to_string()); + } else { + estimator_metadata.insert( + "occasion_inference".to_string(), + "not_requested".to_string(), + ); + } + + estimator_metadata.insert("algorithm".to_string(), format!("{:?}", result.algorithm())); + estimator_metadata.insert("status".to_string(), format!("{:?}", result.status())); + estimator_metadata.insert( + "outputs_requested".to_string(), + 
result.should_write_outputs().to_string(), + ); + estimator_metadata.insert( + "iteration_log_entries".to_string(), + result.iteration_log().len().to_string(), + ); + estimator_metadata.insert( + "prediction_cache".to_string(), + if result.predictions().is_some() { + "available".to_string() + } else { + "not_materialized".to_string() + }, + ); + estimator_metadata.insert( + "residual_error_model".to_string(), + if result.sigma().model_type.is_empty() { + "none".to_string() + } else { + result.sigma().model_type.clone() + }, + ); + estimator_metadata.insert( + "residual_error_output".to_string(), + if !result.should_write_outputs() { + "disabled".to_string() + } else if result.sigma().model_type.is_empty() { + "not_available".to_string() + } else { + "expected".to_string() + }, + ); + estimator_metadata.insert( + "uncertainty_method".to_string(), + result + .uncertainty() + .fim_method() + .map(|method| format!("{:?}", method)) + .unwrap_or_else(|| "none".to_string()), + ); + estimator_metadata.insert( + "uncertainty_output".to_string(), + if !result.should_write_outputs() { + "disabled".to_string() + } else if result.uncertainty().has_fim() || result.uncertainty().has_standard_errors() { + "expected".to_string() + } else { + "not_available".to_string() + }, + ); + estimator_metadata.insert( + "likelihood_best_objf".to_string(), + result.best_objf().to_string(), + ); + estimator_metadata.insert( + "objective_source".to_string(), + if result.likelihoods().ll_gaussian_quadrature.is_some() { + "gaussian_quadrature".to_string() + } else if result.likelihoods().ll_importance_sampling.is_some() { + "importance_sampling".to_string() + } else if result.likelihoods().ll_linearization.is_some() { + "linearization".to_string() + } else { + "algorithm_state".to_string() + }, + ); + + DiagnosticsBundle { + warnings, + deferred_features, + convergence_notes, + estimator_metadata, + } +} diff --git a/src/results/fit_result.rs b/src/results/fit_result.rs index e7d76fb1c..3bde6f9c1 100644 --- a/src/results/fit_result.rs +++ b/src/results/fit_result.rs @@ -2,27 +2,32 @@ use anyhow::Result; use pharmsol::Equation; use crate::estimation::nonparametric::NonparametricWorkspace; -use crate::estimation::nonparametric; +use crate::estimation::parametric::ParametricWorkspace; +use crate::estimation::{nonparametric, parametric}; use crate::results::{ - nonparametric_artifacts, nonparametric_diagnostics, nonparametric_predictions, ArtifactIndex, + nonparametric_artifacts, nonparametric_diagnostics, nonparametric_predictions, + parametric_artifacts, parametric_diagnostics, parametric_predictions, ArtifactIndex, DiagnosticsBundle, FitSummary, IndividualSummary, PopulationSummary, PredictionsBundle, }; #[derive(Debug)] pub enum FitResult { Nonparametric(NonparametricWorkspace), + Parametric(ParametricWorkspace), } impl FitResult { pub fn objf(&self) -> f64 { match self { Self::Nonparametric(result) => result.objf(), + Self::Parametric(result) => result.objf(), } } pub fn converged(&self) -> bool { match self { Self::Nonparametric(result) => result.converged(), + Self::Parametric(result) => result.converged(), } } @@ -33,42 +38,56 @@ impl FitResult { pub fn summary(&self) -> FitSummary { match self { Self::Nonparametric(result) => nonparametric::fit_summary(result), + Self::Parametric(result) => parametric::fit_summary(result), } } pub fn population_summary(&self) -> PopulationSummary { match self { Self::Nonparametric(result) => nonparametric::population_summary(result), + Self::Parametric(result) => 
parametric::population_summary(result), } } pub fn individual_summaries(&self) -> Vec { match self { Self::Nonparametric(result) => nonparametric::individual_summaries(result), + Self::Parametric(result) => parametric::individual_summaries(result), } } pub fn diagnostics(&self) -> DiagnosticsBundle { match self { Self::Nonparametric(result) => nonparametric_diagnostics(result), + Self::Parametric(result) => parametric_diagnostics(result), } } pub fn predictions(&self) -> PredictionsBundle { match self { Self::Nonparametric(result) => nonparametric_predictions(result), + Self::Parametric(result) => parametric_predictions(result), } } pub fn artifacts(&self) -> ArtifactIndex { match self { Self::Nonparametric(result) => nonparametric_artifacts(result), + Self::Parametric(result) => parametric_artifacts(result), } } pub fn as_nonparametric(&self) -> Option<&NonparametricWorkspace> { match self { Self::Nonparametric(result) => Some(result), + Self::Parametric(_) => None, + } + } + + pub fn as_parametric(&self) -> Option<&ParametricWorkspace> { + match self { + Self::Nonparametric(_) => None, + Self::Parametric(result) => Some(result), } } } diff --git a/src/results/mod.rs b/src/results/mod.rs index 3ff684906..5ee1aadc1 100644 --- a/src/results/mod.rs +++ b/src/results/mod.rs @@ -10,6 +10,6 @@ pub use fit_result::FitResult; pub use predictions::PredictionsBundle; pub use summary::{FitSummary, IndividualSummary, ParameterSummary, PopulationSummary}; -pub(crate) use artifacts::nonparametric_artifacts; -pub(crate) use diagnostics::nonparametric_diagnostics; -pub(crate) use predictions::nonparametric_predictions; +pub(crate) use artifacts::{nonparametric_artifacts, parametric_artifacts}; +pub(crate) use diagnostics::{nonparametric_diagnostics, parametric_diagnostics}; +pub(crate) use predictions::{nonparametric_predictions, parametric_predictions}; diff --git a/src/results/predictions.rs b/src/results/predictions.rs index 8f3118999..7a3640a82 100644 --- a/src/results/predictions.rs +++ b/src/results/predictions.rs @@ -2,7 +2,8 @@ use pharmsol::Equation; use serde::{Deserialize, Serialize}; use crate::estimation::nonparametric::NonparametricWorkspace; -use crate::results::nonparametric_artifacts; +use crate::estimation::parametric::ParametricWorkspace; +use crate::results::{nonparametric_artifacts, parametric_artifacts}; #[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)] pub struct PredictionsBundle { @@ -36,3 +37,28 @@ pub(crate) fn nonparametric_predictions( artifact, } } + +pub(crate) fn parametric_predictions( + result: &ParametricWorkspace, +) -> PredictionsBundle { + let artifact = parametric_artifacts(result) + .files + .into_iter() + .find(|file| file == "predictions.csv"); + + if let Some(predictions) = result.predictions() { + return PredictionsBundle { + available: true, + row_count: Some(predictions.predictions().len()), + source: Some("in_memory".to_string()), + artifact, + }; + } + + PredictionsBundle { + available: artifact.is_some(), + row_count: None, + source: artifact.as_ref().map(|_| "artifact".to_string()), + artifact, + } +} diff --git a/tests/acceptance_baseline_tests.rs b/tests/acceptance_baseline_tests.rs index 91462faad..a22ffb3c4 100644 --- a/tests/acceptance_baseline_tests.rs +++ b/tests/acceptance_baseline_tests.rs @@ -1,7 +1,11 @@ use anyhow::Result; -use pharmsol::{AssayErrorModel, ErrorPoly}; +use pharmsol::{ResidualErrorModel, ResidualErrorModels}; use pmcore::prelude::*; +#[allow(dead_code)] +#[path = "saem_validation/reference.rs"] +mod 
saem_reference; + fn bimodal_ode_equation() -> equation::ODE { ode! { diffeq: |x, p, _t, dx, b, rateiv, _cov| { @@ -16,6 +20,41 @@ fn bimodal_ode_equation() -> equation::ODE { .with_solver(OdeSolver::ExplicitRk(ExplicitRkTableau::Tsit45)) } +fn simple_focei_equation() -> equation::ODE { + equation::ODE::new( + |x, p, _t, dx, b, _rateiv, _cov| { + fetch_params!(p, ke); + dx[0] = -ke * x[0] + b[0]; + }, + |_p, _t, _cov| lag! {}, + |_p, _t, _cov| fa! {}, + |_p, _t, _cov, _x| {}, + |x, p, _t, _cov, y| { + fetch_params!(p, v); + y[0] = x[0] / v; + }, + ) +} + +fn bimodal_analytical_equation() -> equation::Analytical { + equation::Analytical::new( + |x, p, t, rateiv, _cov| { + let mut xout = x.clone(); + fetch_params!(p, ke, _v); + xout[0] = x[0] * (-ke * t).exp() + rateiv[1] / ke * (1.0 - (-ke * t).exp()); + xout + }, + |_p, _t, _cov| {}, + |_p, _t, _cov| lag! {}, + |_p, _t, _cov| fa! {}, + |_p, _t, _cov, _x| {}, + |x, p, _t, _cov, y| { + fetch_params!(p, _ke, v); + y[1] = x[0] / v; + }, + ) +} + fn bimodal_data() -> Result { Ok(data::read_pmetrics("examples/bimodal_ke/bimodal_ke.csv")?) } @@ -38,12 +77,59 @@ fn bimodal_npag_model() -> Result> { .build() } -fn assert_close(actual: f64, expected: f64, tolerance: f64, label: &str) { - let delta = (actual - expected).abs(); - assert!( - delta <= tolerance, - "{label}: expected {expected}, got {actual}, delta {delta} > {tolerance}" - ); +fn bimodal_saem_problem() -> Result> { + let observations = ObservationSpec::new() + .add_channel(ObservationChannel::continuous(1, "cp")) + .with_residual_error_models( + ResidualErrorModels::new().add(1, ResidualErrorModel::proportional(0.1)), + ); + + let model = ModelDefinition::builder(bimodal_analytical_equation()) + .parameters( + ParameterSpace::new() + .add(ParameterSpec::bounded("ke", 0.01, 0.5)) + .add(ParameterSpec::bounded("v", 50.0, 180.0)), + ) + .observations(observations) + .build()?; + + EstimationProblem::builder(model, bimodal_data()?) + .method(EstimationMethod::Parametric(ParametricMethod::Saem( + SaemOptions::default(), + ))) + .output(OutputPlan::disabled()) + .runtime(RuntimeOptions { + progress: false, + ..RuntimeOptions::default() + }) + .build() +} + +fn canonical_focei_data() -> Data { + Subject::builder("1") + .bolus(0.0, 100.0, 0) + .observation(1.0, 10.0, 0) + .observation(2.0, 8.0, 0) + .build() + .into() +} + +fn canonical_focei_model() -> Result> { + let assay_error = AssayErrorModel::additive(ErrorPoly::new(0.0, 0.10, 0.0, 0.0), 2.0); + let residual_error = ResidualErrorModels::new().add(0, ResidualErrorModel::combined(0.5, 0.1)); + let observations = ObservationSpec::new() + .add_channel(ObservationChannel::continuous(0, "cp")) + .with_assay_error_models(AssayErrorModels::new().add(0, assay_error)?) + .with_residual_error_models(residual_error); + + ModelDefinition::builder(simple_focei_equation()) + .parameters( + ParameterSpace::new() + .add(ParameterSpec::bounded("ke", 0.1, 1.0)) + .add(ParameterSpec::bounded("v", 1.0, 20.0)), + ) + .observations(observations) + .build() } #[test] @@ -65,7 +151,8 @@ fn test_acceptance_baseline_npag_bimodal_ke() -> Result<()> { .as_nonparametric() .expect("NPAG acceptance baseline should yield a nonparametric result"); - assert_close( + // This is the canonical rewrite-blocking nonparametric baseline for the bimodal_ke example. 
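+    // `saem_reference::assert_close` is assumed to mirror the removed local helper above:
+    // it checks `(actual - expected).abs() <= tolerance` and reports the label on failure.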
+ saem_reference::assert_close( summary.objective_function, -425.60904902364695, 1e-6, @@ -74,13 +161,13 @@ fn test_acceptance_baseline_npag_bimodal_ke() -> Result<()> { assert!(summary.converged); assert_eq!(summary.iterations, 288); assert_eq!(result.get_theta().nspp(), 46); - assert_close( + saem_reference::assert_close( population.parameters[0].mean, 0.187047284678325, 1e-6, "npag.ke.mean", ); - assert_close( + saem_reference::assert_close( population.parameters[1].mean, 107.94241284196241, 1e-6, @@ -88,3 +175,81 @@ fn test_acceptance_baseline_npag_bimodal_ke() -> Result<()> { ); Ok(()) } + +#[test] +fn test_acceptance_baseline_saem_bimodal_ke() -> Result<()> { + let result = bimodal_saem_problem()?.run()?; + let summary = result.summary(); + let result = result + .as_parametric() + .expect("SAEM acceptance baseline should yield a parametric result"); + + let mu_psi: Vec = (0..result.population().npar()) + .map(|index| result.population().mu()[index]) + .collect(); + let omega_diag: Vec = (0..result.population().npar()) + .map(|index| result.population().omega()[(index, index)]) + .collect(); + + // This is the canonical rewrite-blocking parametric baseline for the bimodal_ke example. + saem_reference::assert_close( + summary.objective_function, + -144.18431437030802, + 1e-6, + "saem.objf", + ); + assert!(!summary.converged); + assert_eq!(summary.iterations, 400); + saem_reference::assert_vec_close( + &mu_psi, + &[0.18709059357497426, 105.26324936442889], + 1e-6, + "saem.mu_psi", + ); + saem_reference::assert_vec_close( + &omega_diag, + &[0.026795214431165025, 489.84880731024896], + 1e-6, + "saem.omega_diag", + ); + saem_reference::assert_close(result.sigma().as_vec()[0], 0.102540, 1e-4, "saem.sigma"); + Ok(()) +} + +#[test] +fn test_acceptance_baseline_focei_onecomp() -> Result<()> { + let result = EstimationProblem::builder(canonical_focei_model()?, canonical_focei_data()) + .method(EstimationMethod::Parametric(ParametricMethod::Focei( + FoceiOptions, + ))) + .output(OutputPlan::disabled()) + .runtime(RuntimeOptions { + cycles: 3, + progress: false, + ..RuntimeOptions::default() + }) + .run()?; + let result = result + .as_parametric() + .expect("FOCEI acceptance baseline should yield a parametric result"); + + let mu: Vec = (0..result.population().npar()) + .map(|index| result.population().mu()[index]) + .collect(); + let omega_diag: Vec = (0..result.population().npar()) + .map(|index| result.population().omega()[(index, index)]) + .collect(); + + // FOCEI is deterministic on this simple canonical path, so the baseline is exact. 
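+    // `assert_vec_close` is assumed to apply the same absolute-tolerance check element-wise
+    // to each component of the expected vectors below.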
+ saem_reference::assert_close(result.objf(), 73.802216624458, 1e-9, "focei.objf"); + saem_reference::assert_vec_close(&mu, &[1.0, 10.5], 1e-12, "focei.mu"); + saem_reference::assert_vec_close(&omega_diag, &[1e-8, 1e-8], 1e-12, "focei.omega_diag"); + assert_eq!(result.sigma().combined, Some((0.5, 0.1))); + assert!(has_fim(result)); + assert_eq!(fim_method(result), Some(FimMethod::Linearization)); + let se_mu = se_mu(result).expect("FOCEI should expose standard errors on the shared surface"); + assert_eq!(se_mu.nrows(), 2); + assert!(se_mu[0].is_finite()); + assert!(se_mu[1].is_finite()); + Ok(()) +} diff --git a/tests/api_smoke_tests.rs b/tests/api_smoke_tests.rs index 5c0679414..6c73157d8 100644 --- a/tests/api_smoke_tests.rs +++ b/tests/api_smoke_tests.rs @@ -1,5 +1,5 @@ use anyhow::Result; -use pharmsol::{AssayErrorModel, ErrorPoly}; +use pharmsol::{AssayErrorModel, ErrorPoly, ResidualErrorModel, ResidualErrorModels}; use pmcore::prelude::*; fn simple_equation() -> equation::ODE { @@ -28,6 +28,57 @@ fn simple_data() -> Data { Data::new(vec![subject]) } +fn structured_parametric_data() -> Data { + let first = Subject::builder("1") + .covariate("wt", 0.0, 60.0) + .bolus(0.0, 100.0, 0) + .observation(1.0, 12.0, 0) + .observation(2.0, 8.5, 0) + .observation(4.0, 4.8, 0) + .build(); + + let second = Subject::builder("2") + .covariate("wt", 0.0, 90.0) + .bolus(0.0, 100.0, 0) + .observation(1.0, 9.5, 0) + .observation(2.0, 6.4, 0) + .observation(4.0, 3.1, 0) + .build(); + + Data::new(vec![first, second]) +} + +fn structured_multi_occasion_parametric_data() -> Data { + let subject = Subject::builder("1") + .covariate("wt", 0.0, 70.0) + .covariate("study_day", 0.0, 1.0) + .bolus(0.0, 100.0, 0) + .observation(1.0, 10.0, 0) + .reset() + .covariate("wt", 0.0, 70.0) + .covariate("study_day", 0.0, 2.0) + .bolus(0.0, 100.0, 0) + .observation(1.0, 8.0, 0) + .build(); + + Data::new(vec![subject]) +} + +fn assert_subject_covariate_snapshot(result: &ParametricWorkspace) { + let covariates = result + .state() + .covariates + .subject_effects + .as_ref() + .expect("structured subject covariates should be preserved in the fitted state"); + + assert!(result.objf().is_finite()); + assert_eq!(covariates.parameter_names, vec!["ke", "v"]); + assert_eq!(covariates.column_names, vec!["wt"]); + assert_eq!(covariates.covariate_mask, vec![vec![true], vec![false]]); + assert_eq!(covariates.values, vec![vec![Some(60.0)], vec![Some(90.0)]]); +} + #[test] fn test_model_definition_builder() -> Result<()> { let assay_error = AssayErrorModel::additive(ErrorPoly::new(0.0, 0.10, 0.0, 0.0), 2.0); @@ -87,3 +138,498 @@ fn test_unified_fit_nonparametric_smoke() -> Result<()> { assert_eq!(result.individual_summaries().len(), 1); Ok(()) } + +#[test] +fn test_parametric_problem_requires_residual_error_models() -> Result<()> { + let assay_error = AssayErrorModel::additive(ErrorPoly::new(0.0, 0.10, 0.0, 0.0), 2.0); + let observations = ObservationSpec::new() + .add_channel(ObservationChannel::continuous(0, "cp")) + .with_assay_error_models(AssayErrorModels::new().add(0, assay_error)?); + + let model = ModelDefinition::builder(simple_equation()) + .parameters( + ParameterSpace::new() + .add(ParameterSpec::bounded("ke", 0.1, 1.0)) + .add(ParameterSpec::bounded("v", 1.0, 20.0)), + ) + .observations(observations) + .build()?; + + let problem = EstimationProblem::builder(model, simple_data()) + .method(EstimationMethod::Parametric(ParametricMethod::Saem( + SaemOptions, + ))) + .output(OutputPlan::disabled()) + .build()?; + + 
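+    // Building the problem succeeds; the missing residual error models are only
+    // rejected once the parametric fit is actually run.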
assert!(problem.run().is_err()); + Ok(()) +} + +#[test] +fn test_parametric_problem_accepts_residual_error_models() -> Result<()> { + let assay_error = AssayErrorModel::additive(ErrorPoly::new(0.0, 0.10, 0.0, 0.0), 2.0); + let residual_error = ResidualErrorModels::new().add(0, ResidualErrorModel::combined(0.5, 0.1)); + let observations = ObservationSpec::new() + .add_channel(ObservationChannel::continuous(0, "cp")) + .with_assay_error_models(AssayErrorModels::new().add(0, assay_error)?) + .with_residual_error_models(residual_error); + + let model = ModelDefinition::builder(simple_equation()) + .parameters( + ParameterSpace::new() + .add(ParameterSpec::bounded("ke", 0.1, 1.0)) + .add(ParameterSpec::bounded("v", 1.0, 20.0)), + ) + .observations(observations) + .build()?; + + let problem = EstimationProblem::builder(model, simple_data()) + .method(EstimationMethod::Parametric(ParametricMethod::Saem( + SaemOptions, + ))) + .output(OutputPlan::disabled()) + .build()?; + + let compiled = problem.compile()?; + assert!(compiled.model.observations.residual_error_models.is_some()); + Ok(()) +} + +#[test] +fn test_unified_fit_parametric_structured_covariates_smoke() -> Result<()> { + let assay_error = AssayErrorModel::additive(ErrorPoly::new(0.0, 0.10, 0.0, 0.0), 2.0); + let residual_error = ResidualErrorModels::new().add(0, ResidualErrorModel::combined(0.5, 0.1)); + let observations = ObservationSpec::new() + .add_channel(ObservationChannel::continuous(0, "cp")) + .with_assay_error_models(AssayErrorModels::new().add(0, assay_error)?) + .with_residual_error_models(residual_error); + + let model = ModelDefinition::builder(simple_equation()) + .parameters( + ParameterSpace::new() + .add(ParameterSpec::bounded("ke", 0.1, 1.0)) + .add(ParameterSpec::bounded("v", 5.0, 20.0)), + ) + .observations(observations) + .covariates(CovariateSpec::Structured(CovariateEffectsSpec { + subject_effects: Some(CovariateModel::new( + vec!["ke", "v"], + vec!["wt"], + vec![vec![true], vec![false]], + )?), + occasion_effects: None, + })) + .build()?; + + let result = EstimationProblem::builder(model, structured_parametric_data()) + .method(EstimationMethod::Parametric(ParametricMethod::Saem( + SaemOptions, + ))) + .output(OutputPlan::disabled()) + .runtime(RuntimeOptions { + progress: false, + tuning: AlgorithmTuning { + saem: SaemConfig { + k1_iterations: 2, + k2_iterations: 1, + burn_in: 1, + mcmc_iterations: 1, + n_kernels: 1, + compute_map: false, + compute_fim: false, + compute_ll_is: false, + compute_ll_gq: false, + n_mc_is: 32, + ..SaemConfig::default() + }, + ..AlgorithmTuning::default() + }, + ..RuntimeOptions::default() + }) + .run()?; + + let result = result + .as_parametric() + .expect("SAEM should yield a parametric result"); + assert_subject_covariate_snapshot(result); + assert_eq!( + result + .state() + .covariates + .subject_effects + .as_ref() + .expect("structured subject covariates should be preserved in the fitted state") + .coefficients + .len(), + 3 + ); + Ok(()) +} + +#[test] +fn test_unified_fit_parametric_focei_smoke() -> Result<()> { + let assay_error = AssayErrorModel::additive(ErrorPoly::new(0.0, 0.10, 0.0, 0.0), 2.0); + let residual_error = ResidualErrorModels::new().add(0, ResidualErrorModel::combined(0.5, 0.1)); + let observations = ObservationSpec::new() + .add_channel(ObservationChannel::continuous(0, "cp")) + .with_assay_error_models(AssayErrorModels::new().add(0, assay_error)?) 
+ .with_residual_error_models(residual_error); + + let model = ModelDefinition::builder(simple_equation()) + .parameters( + ParameterSpace::new() + .add(ParameterSpec::bounded("ke", 0.1, 1.0)) + .add(ParameterSpec::bounded("v", 1.0, 20.0)), + ) + .observations(observations) + .build()?; + + let result = EstimationProblem::builder(model, simple_data()) + .method(EstimationMethod::Parametric(ParametricMethod::Focei( + FoceiOptions, + ))) + .output(OutputPlan::disabled()) + .runtime(RuntimeOptions { + cycles: 3, + progress: false, + ..RuntimeOptions::default() + }) + .run()?; + + let result = result + .as_parametric() + .expect("FOCEI should yield a parametric result"); + + assert!(result.objf().is_finite()); + assert_eq!(result.population().param_names(), vec!["ke", "v"]); + assert_eq!(result.individual_estimates().nsubjects(), 1); + assert_eq!(result.sigma().combined, Some((0.5, 0.1))); + assert!(has_fim(result)); + assert_eq!(fim_method(result), Some(FimMethod::Linearization)); + Ok(()) +} + +#[test] +fn test_unified_fit_parametric_focei_structured_covariates_smoke() -> Result<()> { + let assay_error = AssayErrorModel::additive(ErrorPoly::new(0.0, 0.10, 0.0, 0.0), 2.0); + let residual_error = ResidualErrorModels::new().add(0, ResidualErrorModel::combined(0.5, 0.1)); + let observations = ObservationSpec::new() + .add_channel(ObservationChannel::continuous(0, "cp")) + .with_assay_error_models(AssayErrorModels::new().add(0, assay_error)?) + .with_residual_error_models(residual_error); + + let model = ModelDefinition::builder(simple_equation()) + .parameters( + ParameterSpace::new() + .add(ParameterSpec::bounded("ke", 0.1, 1.0)) + .add(ParameterSpec::bounded("v", 5.0, 20.0)), + ) + .observations(observations) + .covariates(CovariateSpec::Structured(CovariateEffectsSpec { + subject_effects: Some(CovariateModel::new( + vec!["ke", "v"], + vec!["wt"], + vec![vec![true], vec![false]], + )?), + occasion_effects: None, + })) + .build()?; + + let result = EstimationProblem::builder(model, structured_parametric_data()) + .method(EstimationMethod::Parametric(ParametricMethod::Focei( + FoceiOptions, + ))) + .output(OutputPlan::disabled()) + .runtime(RuntimeOptions { + cycles: 3, + progress: false, + ..RuntimeOptions::default() + }) + .run()?; + + let result = result + .as_parametric() + .expect("FOCEI should yield a parametric result"); + assert_subject_covariate_snapshot(result); + assert_eq!(result.individual_estimates().nsubjects(), 2); + Ok(()) +} + +#[test] +fn test_unified_fit_parametric_focei_preserves_occasion_covariates() -> Result<()> { + let assay_error = AssayErrorModel::additive(ErrorPoly::new(0.0, 0.10, 0.0, 0.0), 2.0); + let residual_error = ResidualErrorModels::new().add(0, ResidualErrorModel::combined(0.5, 0.1)); + let observations = ObservationSpec::new() + .add_channel(ObservationChannel::continuous(0, "cp")) + .with_assay_error_models(AssayErrorModels::new().add(0, assay_error)?) 
+ .with_residual_error_models(residual_error); + + let model = ModelDefinition::builder(simple_equation()) + .parameters( + ParameterSpace::new() + .add(ParameterSpec::bounded("ke", 0.1, 1.0)) + .add(ParameterSpec::bounded("v", 5.0, 20.0)), + ) + .observations(observations) + .covariates(CovariateSpec::Structured(CovariateEffectsSpec { + subject_effects: Some(CovariateModel::new( + vec!["ke", "v"], + vec!["wt"], + vec![vec![true], vec![false]], + )?), + occasion_effects: Some(CovariateModel::new( + vec!["ke", "v"], + vec!["study_day"], + vec![vec![true], vec![false]], + )?), + })) + .build()?; + + let result = EstimationProblem::builder(model, structured_multi_occasion_parametric_data()) + .method(EstimationMethod::Parametric(ParametricMethod::Focei( + FoceiOptions, + ))) + .output(OutputPlan::disabled()) + .runtime(RuntimeOptions { + cycles: 2, + progress: false, + ..RuntimeOptions::default() + }) + .run()?; + + let result = result + .as_parametric() + .expect("FOCEI should yield a parametric result"); + let occasion = result + .state() + .covariates + .occasion_effects + .as_ref() + .expect("occasion covariates should be preserved in the fitted state"); + + assert!(result.objf().is_finite()); + assert_eq!(occasion.column_names, vec!["study_day"]); + assert_eq!(occasion.parameter_names, vec!["ke", "v"]); + assert_eq!(occasion.values, vec![vec![Some(1.0)], vec![Some(2.0)]]); + Ok(()) +} + +#[test] +fn test_unified_fit_parametric_saem_preserves_occasion_covariates() -> Result<()> { + let assay_error = AssayErrorModel::additive(ErrorPoly::new(0.0, 0.10, 0.0, 0.0), 2.0); + let residual_error = ResidualErrorModels::new().add(0, ResidualErrorModel::combined(0.5, 0.1)); + let observations = ObservationSpec::new() + .add_channel(ObservationChannel::continuous(0, "cp")) + .with_assay_error_models(AssayErrorModels::new().add(0, assay_error)?) 
+ .with_residual_error_models(residual_error); + + let model = ModelDefinition::builder(simple_equation()) + .parameters( + ParameterSpace::new() + .add(ParameterSpec { + name: "ke".to_string(), + domain: ParameterDomain::Bounded { + lower: 0.1, + upper: 1.0, + }, + transform: ModelParameterTransform::Identity, + initial: Some(0.4), + estimate: true, + variability: ParameterVariability::SubjectAndOccasion, + }) + .add(ParameterSpec::bounded("v", 5.0, 20.0)), + ) + .observations(observations) + .covariates(CovariateSpec::Structured(CovariateEffectsSpec { + subject_effects: Some(CovariateModel::new( + vec!["ke", "v"], + vec!["wt"], + vec![vec![true], vec![false]], + )?), + occasion_effects: Some(CovariateModel::new( + vec!["ke", "v"], + vec!["study_day"], + vec![vec![true], vec![false]], + )?), + })) + .build()?; + + let result = EstimationProblem::builder(model, structured_multi_occasion_parametric_data()) + .method(EstimationMethod::Parametric(ParametricMethod::Saem( + SaemOptions, + ))) + .output(OutputPlan::disabled()) + .runtime(RuntimeOptions { + progress: false, + tuning: AlgorithmTuning { + saem: SaemConfig { + k1_iterations: 2, + k2_iterations: 1, + burn_in: 1, + mcmc_iterations: 1, + n_kernels: 1, + compute_map: false, + compute_fim: false, + compute_ll_is: false, + compute_ll_gq: false, + n_mc_is: 32, + ..SaemConfig::default() + }, + ..AlgorithmTuning::default() + }, + ..RuntimeOptions::default() + }) + .run()?; + + let result = result + .as_parametric() + .expect("SAEM should yield a parametric result"); + let occasion = result + .state() + .covariates + .occasion_effects + .as_ref() + .expect("occasion covariates should be preserved in the fitted state"); + let occasion_kappa = result + .individuals() + .occasion_kappa + .as_ref() + .expect("occasion effect slots should exist for occasion-enabled SAEM models"); + + assert!(result.objf().is_finite()); + assert_eq!(occasion.column_names, vec!["study_day"]); + assert_eq!(occasion.parameter_names, vec!["ke", "v"]); + assert_eq!(occasion.values, vec![vec![Some(1.0)], vec![Some(2.0)]]); + assert_eq!(occasion_kappa.0.len(), 2); + assert_eq!(occasion_kappa.0[0].subject_index, 0); + assert_eq!(occasion_kappa.0[0].occasion_index, 0); + assert_eq!(occasion_kappa.0[1].occasion_index, 1); + assert_eq!(occasion_kappa.0[0].values.0, vec![0.0, 0.0]); + assert_eq!( + result + .state() + .variability + .occasion + .as_ref() + .expect("occasion variability should be present") + .enabled_for, + vec![true, false] + ); + Ok(()) +} + +#[test] +fn test_problem_compile_preserves_runtime_configuration() -> Result<()> { + let assay_error = AssayErrorModel::additive(ErrorPoly::new(0.0, 0.10, 0.0, 0.0), 2.0); + let observations = ObservationSpec::new() + .add_channel(ObservationChannel::continuous(0, "cp")) + .with_assay_error_models(AssayErrorModels::new().add(0, assay_error)?); + + let model = ModelDefinition::builder(simple_equation()) + .parameters( + ParameterSpace::new() + .add(ParameterSpec::bounded("ke", 0.1, 1.0)) + .add(ParameterSpec::bounded("v", 1.0, 20.0)), + ) + .observations(observations) + .build()?; + + let runtime = RuntimeOptions { + cycles: 7, + cache: false, + progress: false, + idelta: 0.5, + tad: 24.0, + prior: None, + logging: LoggingOptions { + initialize: false, + level: LoggingLevel::Debug, + write: true, + stdout: false, + }, + convergence: ConvergenceOptions { + likelihood: 1e-5, + pyl: 5e-3, + eps: 2e-3, + }, + tuning: AlgorithmTuning { + min_distance: 2e-4, + nm_steps: 222, + tolerance: 3e-6, + saem: SaemConfig { + 
k1_iterations: 111, + k2_iterations: 22, + ..SaemConfig::default() + }, + }, + }; + + let compiled = EstimationProblem::builder(model, simple_data()) + .method(EstimationMethod::Nonparametric(NonparametricMethod::Npag( + NpagOptions, + ))) + .output(OutputPlan::disabled()) + .runtime(runtime) + .build()? + .compile()?; + + assert_eq!(compiled.method().algorithm(), Algorithm::NPAG); + assert!(!compiled.output_plan().write); + assert_eq!(compiled.runtime_options().cycles, 7); + assert!(!compiled.runtime_options().cache); + assert!(!compiled.runtime_options().progress); + assert_eq!(compiled.runtime_options().idelta, 0.5); + assert_eq!(compiled.runtime_options().tad, 24.0); + assert_eq!( + compiled.runtime_options().logging.level, + LoggingLevel::Debug + ); + assert!(compiled.runtime_options().logging.write); + assert!(!compiled.runtime_options().logging.stdout); + assert_eq!(compiled.runtime_options().convergence.likelihood, 1e-5); + assert_eq!(compiled.runtime_options().convergence.pyl, 5e-3); + assert_eq!(compiled.runtime_options().convergence.eps, 2e-3); + assert_eq!(compiled.runtime_options().tuning.min_distance, 2e-4); + assert_eq!(compiled.runtime_options().tuning.nm_steps, 222); + assert_eq!(compiled.runtime_options().tuning.tolerance, 3e-6); + assert_eq!(compiled.runtime_options().tuning.saem.k1_iterations, 111); + assert_eq!(compiled.runtime_options().tuning.saem.k2_iterations, 22); + Ok(()) +} + +#[test] +fn test_problem_can_initialize_logs_without_old_settings_api() -> Result<()> { + let assay_error = AssayErrorModel::additive(ErrorPoly::new(0.0, 0.10, 0.0, 0.0), 2.0); + let observations = ObservationSpec::new() + .add_channel(ObservationChannel::continuous(0, "cp")) + .with_assay_error_models(AssayErrorModels::new().add(0, assay_error)?); + + let model = ModelDefinition::builder(simple_equation()) + .parameters( + ParameterSpace::new() + .add(ParameterSpec::bounded("ke", 0.1, 1.0)) + .add(ParameterSpec::bounded("v", 1.0, 20.0)), + ) + .observations(observations) + .build()?; + + let problem = EstimationProblem::builder(model, simple_data()) + .method(EstimationMethod::Nonparametric(NonparametricMethod::Npag( + NpagOptions, + ))) + .output(OutputPlan::disabled()) + .runtime(RuntimeOptions { + logging: LoggingOptions { + initialize: true, + level: LoggingLevel::Info, + write: false, + stdout: false, + }, + ..RuntimeOptions::default() + }) + .build()?; + + problem.initialize_logs()?; + Ok(()) +} diff --git a/tests/output_writer_tests.rs b/tests/output_writer_tests.rs index bdb1514c5..32e49cbdd 100644 --- a/tests/output_writer_tests.rs +++ b/tests/output_writer_tests.rs @@ -1,5 +1,5 @@ use anyhow::Result; -use pharmsol::{AssayErrorModel, ErrorPoly}; +use pharmsol::{AssayErrorModel, ErrorPoly, ResidualErrorModel, ResidualErrorModels}; use pmcore::prelude::*; use std::path::PathBuf; use std::time::{SystemTime, UNIX_EPOCH}; @@ -114,3 +114,93 @@ fn test_fit_result_writes_shared_output_files() -> Result<()> { let _ = std::fs::remove_dir_all(output_dir); Ok(()) } + +#[test] +fn test_parametric_outputs_use_split_individual_files() -> Result<()> { + let output_dir = temp_output_dir(); + let assay_error = AssayErrorModel::additive(ErrorPoly::new(0.0, 0.10, 0.0, 0.0), 2.0); + let residual_error = ResidualErrorModels::new().add(0, ResidualErrorModel::combined(0.5, 0.1)); + let observations = ObservationSpec::new() + .add_channel(ObservationChannel::continuous(0, "cp")) + .with_assay_error_models(AssayErrorModels::new().add(0, assay_error)?) 
+ .with_residual_error_models(residual_error); + + let model = ModelDefinition::builder(simple_equation()) + .parameters( + ParameterSpace::new() + .add(ParameterSpec::bounded("ke", 0.1, 1.0)) + .add(ParameterSpec::bounded("v", 1.0, 20.0)), + ) + .observations(observations) + .build()?; + + let mut result = EstimationProblem::builder(model, simple_data()) + .method(EstimationMethod::Parametric(ParametricMethod::Focei( + FoceiOptions, + ))) + .output(OutputPlan { + write: true, + path: Some(output_dir.to_string_lossy().to_string()), + }) + .runtime(RuntimeOptions { + cycles: 3, + progress: false, + ..RuntimeOptions::default() + }) + .run()?; + + result.write_outputs()?; + + assert!(output_dir.join("individual_parameters.csv").exists()); + assert!(output_dir.join("individual_effects.csv").exists()); + assert!(output_dir.join("predictions.csv").exists()); + assert!(output_dir.join("residual_error.csv").exists()); + assert!(!output_dir.join("individual.csv").exists()); + assert!(!output_dir.join("pred.csv").exists()); + assert!(!output_dir.join("covs.csv").exists()); + assert!(!output_dir.join("sigma.csv").exists()); + + let artifacts = result.artifacts(); + assert!(artifacts + .files + .iter() + .any(|file| file == "individual_parameters.csv")); + assert!(artifacts + .files + .iter() + .any(|file| file == "individual_effects.csv")); + assert!(artifacts.files.iter().any(|file| file == "predictions.csv")); + assert!(artifacts + .files + .iter() + .any(|file| file == "residual_error.csv")); + assert!(artifacts + .expected_files + .iter() + .any(|file| file == "predictions.csv")); + assert!(artifacts + .shared_expected_files + .iter() + .any(|file| file == "predictions.csv")); + assert!(artifacts + .method_specific_expected_files + .iter() + .any(|file| file == "residual_error.csv")); + assert!(artifacts + .method_specific_expected_files + .iter() + .any(|file| file == "individual_parameters.csv")); + assert!(artifacts.missing_files.is_empty()); + assert!(!artifacts.files.iter().any(|file| file == "individual.csv")); + assert!(!artifacts.files.iter().any(|file| file == "pred.csv")); + assert!(!artifacts.files.iter().any(|file| file == "covs.csv")); + assert!(!artifacts.files.iter().any(|file| file == "sigma.csv")); + + let predictions = result.predictions(); + assert!(predictions.available); + assert_eq!(predictions.artifact.as_deref(), Some("predictions.csv")); + assert_eq!(predictions.source.as_deref(), Some("in_memory")); + + let _ = std::fs::remove_dir_all(output_dir); + Ok(()) +} diff --git a/tests/parametric_compiler_tests.rs b/tests/parametric_compiler_tests.rs new file mode 100644 index 000000000..e3e56d711 --- /dev/null +++ b/tests/parametric_compiler_tests.rs @@ -0,0 +1,210 @@ +use anyhow::Result; +use pharmsol::{AssayErrorModel, ErrorPoly, ResidualErrorModel, ResidualErrorModels}; +use pmcore::prelude::*; + +fn simple_equation() -> equation::ODE { + equation::ODE::new( + |x, p, _t, dx, b, _rateiv, _cov| { + fetch_params!(p, ke); + dx[0] = -ke * x[0] + b[0]; + }, + |_p, _t, _cov| lag! {}, + |_p, _t, _cov| fa! 
{}, + |_p, _t, _cov, _x| {}, + |x, p, _t, _cov, y| { + fetch_params!(p, v); + y[0] = x[0] / v; + }, + ) +} + +fn simple_data() -> Data { + let subject = Subject::builder("1") + .covariate("wt", 0.0, 70.0) + .bolus(0.0, 100.0, 0) + .observation(1.0, 10.0, 0) + .observation(2.0, 8.0, 0) + .build(); + + Data::new(vec![subject]) +} + +fn multi_occasion_covariate_data() -> Data { + let subject = Subject::builder("1") + .covariate("wt", 0.0, 70.0) + .covariate("study_day", 0.0, 1.0) + .bolus(0.0, 100.0, 0) + .observation(1.0, 10.0, 0) + .reset() + .covariate("wt", 0.0, 70.0) + .covariate("study_day", 0.0, 2.0) + .bolus(0.0, 100.0, 0) + .observation(1.0, 8.0, 0) + .build(); + + Data::new(vec![subject]) +} + +#[test] +fn test_parametric_compiler_extracts_model_intent() -> Result<()> { + let assay_error = AssayErrorModel::additive(ErrorPoly::new(0.0, 0.10, 0.0, 0.0), 2.0); + let residual_error = ResidualErrorModels::new().add(0, ResidualErrorModel::combined(0.5, 0.1)); + let observations = ObservationSpec::new() + .add_channel(ObservationChannel::continuous(0, "cp")) + .with_assay_error_models(AssayErrorModels::new().add(0, assay_error)?) + .with_residual_error_models(residual_error); + + let model = ModelDefinition::builder(simple_equation()) + .parameters( + ParameterSpace::new() + .add(ParameterSpec { + name: "ke".to_string(), + domain: ParameterDomain::Positive { + lower: Some(0.0), + upper: Some(1.0), + }, + transform: ModelParameterTransform::LogNormal, + initial: Some(0.4), + estimate: true, + variability: ParameterVariability::SubjectAndOccasion, + }) + .add(ParameterSpec { + name: "v".to_string(), + domain: ParameterDomain::Bounded { + lower: 1.0, + upper: 20.0, + }, + transform: ModelParameterTransform::Identity, + initial: None, + estimate: true, + variability: ParameterVariability::FixedOnly, + }), + ) + .observations(observations) + .covariates(CovariateSpec::Structured(CovariateEffectsSpec { + subject_effects: Some(CovariateModel::new( + vec!["ke", "v"], + vec!["wt"], + vec![vec![true], vec![false]], + )?), + occasion_effects: None, + })) + .build()?; + + let compiled = EstimationProblem::builder(model, simple_data()) + .method(EstimationMethod::Parametric(ParametricMethod::Saem( + SaemOptions, + ))) + .output(OutputPlan::disabled()) + .build()? 
+ .compile()?; + + let state = compile_model_state(&compiled); + + assert_eq!(state.fixed_effects.parameter_names, vec!["ke", "v"]); + assert_eq!(state.fixed_effects.population_mean.0, vec![0.4, 10.5]); + assert_eq!( + state.transforms.transforms[0], + ParametricTransformKind::LogNormal + ); + assert_eq!( + state.transforms.transforms[1], + ParametricTransformKind::Identity + ); + assert_eq!(state.variability.subject.enabled_for, vec![true, false]); + assert_eq!( + state + .variability + .occasion + .as_ref() + .expect("occasion variability should be derived from parameter roles") + .enabled_for, + vec![true, false] + ); + assert!(state.covariates.subject_effects.is_some()); + assert!(state.covariates.occasion_effects.is_none()); + assert_eq!( + state + .covariates + .subject_effects + .as_ref() + .unwrap() + .column_names, + vec!["wt"] + ); + assert_eq!( + state + .covariates + .subject_effects + .as_ref() + .unwrap() + .parameter_names, + vec!["ke", "v"] + ); + assert_eq!( + state + .covariates + .subject_effects + .as_ref() + .unwrap() + .covariate_mask, + vec![vec![true], vec![false]] + ); + assert_eq!( + state.covariates.subject_effects.as_ref().unwrap().values, + vec![vec![Some(70.0)]] + ); + Ok(()) +} + +#[test] +fn test_parametric_compiler_extracts_occasion_covariates() -> Result<()> { + let assay_error = AssayErrorModel::additive(ErrorPoly::new(0.0, 0.10, 0.0, 0.0), 2.0); + let residual_error = ResidualErrorModels::new().add(0, ResidualErrorModel::combined(0.5, 0.1)); + let observations = ObservationSpec::new() + .add_channel(ObservationChannel::continuous(0, "cp")) + .with_assay_error_models(AssayErrorModels::new().add(0, assay_error)?) + .with_residual_error_models(residual_error); + + let model = ModelDefinition::builder(simple_equation()) + .parameters( + ParameterSpace::new() + .add(ParameterSpec::bounded("ke", 0.1, 1.0)) + .add(ParameterSpec::bounded("v", 1.0, 20.0)), + ) + .observations(observations) + .covariates(CovariateSpec::Structured(CovariateEffectsSpec { + subject_effects: Some(CovariateModel::new( + vec!["ke", "v"], + vec!["wt"], + vec![vec![true], vec![false]], + )?), + occasion_effects: Some(CovariateModel::new( + vec!["ke", "v"], + vec!["study_day"], + vec![vec![true], vec![false]], + )?), + })) + .build()?; + + let compiled = EstimationProblem::builder(model, multi_occasion_covariate_data()) + .method(EstimationMethod::Parametric(ParametricMethod::Saem( + SaemOptions, + ))) + .output(OutputPlan::disabled()) + .build()? + .compile()?; + + let state = compile_model_state(&compiled); + let occasion = state + .covariates + .occasion_effects + .as_ref() + .expect("occasion covariates should be compiled into parametric state"); + + assert_eq!(occasion.column_names, vec!["study_day"]); + assert_eq!(occasion.parameter_names, vec!["ke", "v"]); + assert_eq!(occasion.covariate_mask, vec![vec![true], vec![false]]); + assert_eq!(occasion.values, vec![vec![Some(1.0)], vec![Some(2.0)]]); + Ok(()) +} diff --git a/tests/parametric_workspace_tests.rs b/tests/parametric_workspace_tests.rs new file mode 100644 index 000000000..15338a2f5 --- /dev/null +++ b/tests/parametric_workspace_tests.rs @@ -0,0 +1,105 @@ +use anyhow::Result; +use pharmsol::{AssayErrorModel, ErrorPoly, ResidualErrorModel, ResidualErrorModels}; +use pmcore::prelude::*; + +fn simple_equation() -> equation::ODE { + equation::ODE::new( + |x, p, _t, dx, b, _rateiv, _cov| { + fetch_params!(p, ke); + dx[0] = -ke * x[0] + b[0]; + }, + |_p, _t, _cov| lag! {}, + |_p, _t, _cov| fa! 
{}, + |_p, _t, _cov, _x| {}, + |x, p, _t, _cov, y| { + fetch_params!(p, v); + y[0] = x[0] / v; + }, + ) +} + +fn multi_occasion_data() -> Data { + let subject = Subject::builder("1") + .bolus(0.0, 100.0, 0) + .observation(1.0, 10.0, 0) + .observation(2.0, 8.0, 0) + .reset() + .bolus(0.0, 100.0, 0) + .observation(1.0, 9.0, 0) + .observation(2.0, 7.5, 0) + .build(); + + Data::new(vec![subject]) +} + +#[test] +fn test_parametric_workspace_preserves_occasion_effect_slots() -> Result<()> { + let assay_error = AssayErrorModel::additive(ErrorPoly::new(0.0, 0.10, 0.0, 0.0), 2.0); + let residual_error = ResidualErrorModels::new().add(0, ResidualErrorModel::combined(0.5, 0.1)); + let observations = ObservationSpec::new() + .add_channel(ObservationChannel::continuous(0, "cp")) + .with_assay_error_models(AssayErrorModels::new().add(0, assay_error)?) + .with_residual_error_models(residual_error); + + let model = ModelDefinition::builder(simple_equation()) + .parameters( + ParameterSpace::new() + .add(ParameterSpec { + name: "ke".to_string(), + domain: ParameterDomain::Bounded { + lower: 0.1, + upper: 1.0, + }, + transform: ModelParameterTransform::Identity, + initial: Some(0.4), + estimate: true, + variability: ParameterVariability::SubjectAndOccasion, + }) + .add(ParameterSpec::bounded("v", 1.0, 20.0)), + ) + .observations(observations) + .build()?; + + let compiled = EstimationProblem::builder(model, multi_occasion_data()) + .method(EstimationMethod::Parametric(ParametricMethod::Focei( + FoceiOptions, + ))) + .output(OutputPlan::disabled()) + .runtime(RuntimeOptions { + cycles: 1, + progress: false, + ..RuntimeOptions::default() + }) + .build()? + .compile()?; + + let workspace = ParametricEngine::fit(compiled)?; + let occasion_kappa = workspace + .individuals() + .occasion_kappa + .as_ref() + .expect("occasion effect slots should exist for occasion-enabled models"); + + assert_eq!(occasion_kappa.0.len(), 2); + assert_eq!(occasion_kappa.0[0].subject_index, 0); + assert_eq!(occasion_kappa.0[0].occasion_index, 0); + assert_eq!(occasion_kappa.0[1].occasion_index, 1); + assert_eq!(occasion_kappa.0[0].values.0, vec![0.0, 0.0]); + assert_eq!( + workspace.state().variability.subject.enabled_for, + vec![true, true] + ); + assert_eq!( + workspace + .state() + .variability + .occasion + .as_ref() + .expect("occasion variability should be present") + .enabled_for, + vec![true, false] + ); + assert_eq!(workspace.sigma().combined, Some((0.5, 0.1))); + assert!(workspace.uncertainty().has_standard_errors()); + Ok(()) +} diff --git a/tests/results_summary_tests.rs b/tests/results_summary_tests.rs index 283b244d9..61e188e84 100644 --- a/tests/results_summary_tests.rs +++ b/tests/results_summary_tests.rs @@ -1,5 +1,5 @@ use anyhow::Result; -use pharmsol::{AssayErrorModel, ErrorPoly}; +use pharmsol::{AssayErrorModel, ErrorPoly, ResidualErrorModel, ResidualErrorModels}; use pmcore::prelude::*; fn simple_equation() -> equation::ODE { @@ -28,6 +28,20 @@ fn simple_data() -> Data { Data::new(vec![subject]) } +fn multi_occasion_data() -> Data { + let subject = Subject::builder("1") + .bolus(0.0, 100.0, 0) + .observation(1.0, 10.0, 0) + .observation(2.0, 8.0, 0) + .reset() + .bolus(0.0, 100.0, 0) + .observation(1.0, 9.0, 0) + .observation(2.0, 7.5, 0) + .build(); + + Data::new(vec![subject]) +} + #[test] fn test_nonparametric_fit_result_summary_surface() -> Result<()> { let assay_error = AssayErrorModel::additive(ErrorPoly::new(0.0, 0.10, 0.0, 0.0), 2.0); @@ -84,3 +98,131 @@ fn test_nonparametric_fit_result_summary_surface() 
-> Result<()> { assert!(result.artifacts().expected_files.is_empty()); Ok(()) } + +#[test] +fn test_parametric_fit_result_diagnostics_expose_iov_boundary() -> Result<()> { + let assay_error = AssayErrorModel::additive(ErrorPoly::new(0.0, 0.10, 0.0, 0.0), 2.0); + let residual_error = ResidualErrorModels::new().add(0, ResidualErrorModel::combined(0.5, 0.1)); + let observations = ObservationSpec::new() + .add_channel(ObservationChannel::continuous(0, "cp")) + .with_assay_error_models(AssayErrorModels::new().add(0, assay_error)?) + .with_residual_error_models(residual_error); + + let model = ModelDefinition::builder(simple_equation()) + .parameters( + ParameterSpace::new() + .add(ParameterSpec { + name: "ke".to_string(), + domain: ParameterDomain::Bounded { + lower: 0.1, + upper: 1.0, + }, + transform: ModelParameterTransform::Identity, + initial: Some(0.4), + estimate: true, + variability: ParameterVariability::SubjectAndOccasion, + }) + .add(ParameterSpec::bounded("v", 1.0, 20.0)), + ) + .observations(observations) + .build()?; + + let result = EstimationProblem::builder(model, multi_occasion_data()) + .method(EstimationMethod::Parametric(ParametricMethod::Focei( + FoceiOptions, + ))) + .output(OutputPlan::disabled()) + .runtime(RuntimeOptions { + cycles: 1, + progress: false, + ..RuntimeOptions::default() + }) + .run()?; + + let diagnostics = result.diagnostics(); + assert!(diagnostics + .warnings + .iter() + .any(|warning| warning.contains("occasion-level inference remains deferred"))); + assert_eq!( + diagnostics.estimator_metadata.get("algorithm"), + Some(&"FOCEI".to_string()) + ); + assert_eq!( + diagnostics.estimator_metadata.get("occasion_inference"), + Some(&"deferred".to_string()) + ); + assert_eq!( + diagnostics.estimator_metadata.get("outputs_requested"), + Some(&"false".to_string()) + ); + assert_eq!( + diagnostics.estimator_metadata.get("residual_error_output"), + Some(&"disabled".to_string()) + ); + assert_eq!( + diagnostics.estimator_metadata.get("uncertainty_output"), + Some(&"disabled".to_string()) + ); + assert!(diagnostics + .deferred_features + .iter() + .any(|feature| feature == "occasion_inference")); + Ok(()) +} + +#[test] +fn test_parametric_population_summary_uses_transform_aware_cv() -> Result<()> { + let assay_error = AssayErrorModel::additive(ErrorPoly::new(0.0, 0.10, 0.0, 0.0), 2.0); + let residual_error = ResidualErrorModels::new().add(0, ResidualErrorModel::combined(0.5, 0.1)); + let observations = ObservationSpec::new() + .add_channel(ObservationChannel::continuous(0, "cp")) + .with_assay_error_models(AssayErrorModels::new().add(0, assay_error)?) 
+ .with_residual_error_models(residual_error); + + let model = ModelDefinition::builder(simple_equation()) + .parameters( + ParameterSpace::new() + .add(ParameterSpec::bounded("ke", 0.1, 1.0)) + .add(ParameterSpec::bounded("v", 1.0, 20.0)), + ) + .observations(observations) + .build()?; + + let result = EstimationProblem::builder(model, simple_data()) + .method(EstimationMethod::Parametric(ParametricMethod::Focei( + FoceiOptions, + ))) + .output(OutputPlan::disabled()) + .runtime(RuntimeOptions { + cycles: 3, + progress: false, + ..RuntimeOptions::default() + }) + .run()?; + + let summary = result.population_summary(); + let diagnostics = result.diagnostics(); + let v_summary = summary + .parameters + .iter() + .find(|parameter| parameter.name == "v") + .expect("v should be present in the population summary"); + let expected_cv = 100.0 * v_summary.sd / v_summary.mean.abs(); + + assert!((v_summary.cv_percent - expected_cv).abs() < 1e-10); + assert!(v_summary.cv_percent.is_finite()); + assert_eq!( + diagnostics.estimator_metadata.get("residual_error_model"), + Some(&"combined".to_string()) + ); + assert_eq!( + diagnostics.estimator_metadata.get("residual_error_output"), + Some(&"disabled".to_string()) + ); + assert_eq!( + diagnostics.estimator_metadata.get("uncertainty_output"), + Some(&"disabled".to_string()) + ); + Ok(()) +} diff --git a/tests/saem_tests.rs b/tests/saem_tests.rs new file mode 100644 index 000000000..eb901bf8b --- /dev/null +++ b/tests/saem_tests.rs @@ -0,0 +1,711 @@ +//! SAEM Algorithm Validation Tests +//! +//! These tests validate the f-SAEM implementation against known results +//! from the saemix R package. + +use anyhow::Result; +use pharmsol::{ + AssayErrorModel, AssayErrorModels, Equation, ResidualErrorModel, ResidualErrorModels, +}; +use pmcore::algorithms::parametric::dispatch_parametric_algorithm; +use pmcore::model::ParameterTransform as ModelParameterTransform; +use pmcore::prelude::*; + +#[derive(Clone)] +struct SaemTestProblemConfig { + parameter_space: ParameterSpace, + residual_error: ResidualErrorModels, + output: OutputPlan, + runtime: RuntimeOptions, +} + +impl SaemTestProblemConfig { + fn new(parameter_space: ParameterSpace, residual_error: ResidualErrorModels) -> Self { + Self { + parameter_space, + residual_error, + output: OutputPlan { + write: false, + path: None, + }, + runtime: RuntimeOptions::default(), + } + } +} + +fn bounded_parameter_space(bounds: &[(&str, f64, f64)]) -> ParameterSpace { + bounds + .iter() + .fold(ParameterSpace::new(), |space, (name, lower, upper)| { + space.add(ParameterSpec::bounded(*name, *lower, *upper)) + }) +} + +fn apply_saem_transforms(parameter_space: &ParameterSpace, saem: &SaemConfig) -> ParameterSpace { + parameter_space + .iter() + .enumerate() + .fold(ParameterSpace::new(), |space, (index, parameter)| { + space.add(ParameterSpec { + transform: match saem.get_transform(index) { + 1 => ModelParameterTransform::LogNormal, + 2 => ModelParameterTransform::Probit, + 3 => ModelParameterTransform::Logit, + _ => ModelParameterTransform::Identity, + }, + ..parameter.clone() + }) + }) +} + +fn run_saem_problem( + config: SaemTestProblemConfig, + equation: E, + data: Data, +) -> Result> { + build_saem_problem(config, equation, data)?.run() +} + +fn build_saem_problem( + config: SaemTestProblemConfig, + equation: E, + data: Data, +) -> Result> { + let parameters = apply_saem_transforms(&config.parameter_space, &config.runtime.tuning.saem); + + let observations = ObservationSpec::new() + 
.add_channel(ObservationChannel::continuous(0, "obs")) + .with_residual_error_models(config.residual_error); + + let model = ModelDefinition::builder(equation) + .parameters(parameters) + .observations(observations) + .build()?; + + EstimationProblem::builder(model, data) + .method(EstimationMethod::Parametric(ParametricMethod::Saem( + SaemOptions, + ))) + .output(config.output) + .runtime(config.runtime) + .build() +} + +/// Test data: Theophylline pharmacokinetics +/// 12 subjects with oral theophylline dosing +/// This is the classic example from the saemix package +fn create_theo_data() -> Data { + // Theophylline data (subset matching saemix theo.saemix dataset) + // Format: (id, dose, times, concentrations) + let subjects_data = vec![ + // Subject 1 + ( + 1, + 319.992, + vec![0.25, 0.57, 1.12, 2.02, 3.82, 5.10, 7.03, 9.05, 12.12, 24.37], + vec![2.84, 6.57, 10.50, 9.66, 8.58, 8.36, 7.47, 6.89, 5.94, 3.28], + ), + // Subject 2 + ( + 2, + 318.560, + vec![0.27, 0.52, 1.00, 1.92, 3.50, 5.02, 7.03, 9.00, 12.00, 24.30], + vec![1.72, 7.91, 8.31, 8.33, 6.85, 6.08, 5.40, 4.55, 3.01, 0.90], + ), + // Subject 3 + ( + 3, + 319.365, + vec![0.27, 0.58, 1.02, 2.02, 3.62, 5.08, 7.07, 9.00, 12.15, 24.17], + vec![4.40, 6.90, 8.20, 7.80, 7.50, 6.20, 5.30, 4.90, 3.70, 1.05], + ), + // Subject 4 + ( + 4, + 319.992, + vec![0.35, 0.60, 1.07, 2.13, 3.50, 5.02, 7.02, 9.02, 11.98, 24.65], + vec![1.89, 4.60, 8.60, 8.38, 7.54, 6.88, 5.78, 5.33, 4.19, 1.15], + ), + // Subject 5 + ( + 5, + 320.619, + vec![0.30, 0.52, 1.00, 2.02, 3.50, 5.02, 7.02, 9.10, 12.00, 24.35], + vec![2.02, 5.63, 11.40, 9.33, 8.74, 7.56, 7.09, 5.90, 4.37, 1.57], + ), + // Subject 6 + ( + 6, + 320.619, + vec![0.27, 0.58, 1.15, 2.03, 3.57, 5.00, 7.00, 9.22, 12.10, 23.85], + vec![1.29, 3.08, 6.44, 6.32, 5.53, 4.94, 4.02, 3.46, 2.78, 0.92], + ), + // Subject 7 + ( + 7, + 277.767, + vec![0.25, 0.50, 1.02, 2.02, 3.48, 5.00, 6.98, 9.00, 12.05, 24.22], + vec![3.59, 6.11, 7.56, 6.54, 5.37, 4.84, 4.02, 3.83, 2.81, 0.85], + ), + // Subject 8 + ( + 8, + 276.514, + vec![0.25, 0.52, 0.98, 2.02, 3.53, 5.05, 7.15, 9.07, 12.10, 24.12], + vec![0.73, 4.00, 6.81, 8.00, 7.09, 5.89, 5.22, 4.75, 3.41, 0.96], + ), + // Subject 9 + ( + 9, + 299.550, + vec![0.30, 0.63, 1.05, 2.02, 3.53, 5.02, 7.17, 8.80, 11.60, 24.43], + vec![3.15, 6.96, 9.70, 9.52, 7.17, 6.28, 5.28, 4.66, 3.82, 1.15], + ), + // Subject 10 + ( + 10, + 298.297, + vec![0.37, 0.77, 1.02, 2.05, 3.55, 5.05, 7.08, 9.00, 12.12, 24.08], + vec![7.37, 9.03, 10.21, 9.18, 8.02, 7.14, 6.08, 5.54, 4.57, 1.17], + ), + // Subject 11 + ( + 11, + 300.176, + vec![0.25, 0.50, 0.98, 1.98, 3.60, 5.02, 7.03, 9.03, 12.12, 24.28], + vec![0.92, 2.63, 6.85, 9.05, 7.90, 7.44, 6.13, 5.31, 4.10, 1.44], + ), + // Subject 12 + ( + 12, + 298.297, + vec![0.25, 0.52, 1.00, 2.07, 3.50, 4.95, 7.00, 9.02, 12.00, 24.15], + vec![1.11, 6.33, 9.99, 9.37, 8.50, 6.89, 5.94, 5.26, 4.35, 1.25], + ), + ]; + + let subjects: Vec = subjects_data + .into_iter() + .map(|(id, dose, times, concs)| { + let mut builder = Subject::builder(id.to_string()).bolus(0.0, dose, 0); // Oral dose at time 0 + + for (t, c) in times.into_iter().zip(concs.into_iter()) { + builder = builder.observation(t, c, 0); + } + + builder.build() + }) + .collect(); + + Data::new(subjects) +} + +/// One-compartment model with first-order absorption +/// dA/dt = -ka*A + dose (absorption compartment) +/// dC/dt = ka*A - ke*C (central compartment) +/// +/// Parameters: ka (absorption rate), V (volume), CL (clearance) +/// ke = CL/V (elimination rate constant) +fn 
create_one_compartment_absorption_model() -> equation::ODE { + equation::ODE::new( + // ODE system: dx/dt + |x, p, _t, dx, b, _rateiv, _cov| { + // Parameters: ka, V, CL + // x[0] = drug amount in absorption compartment + // x[1] = drug amount in central compartment + fetch_params!(p, ka, v, cl); + let ke = cl / v; + + // Absorption compartment (b[0] is the bolus input) + dx[0] = -ka * x[0] + b[0]; + // Central compartment + dx[1] = ka * x[0] - ke * x[1]; + }, + // Lag time function + |_p, _t, _cov| lag! {}, + // Bioavailability function + |_p, _t, _cov| fa! {}, + // Secondary equations + |_p, _t, _cov, _x| {}, + // Output equation: observed concentration + |x, p, _t, _cov, y| { + fetch_params!(p, _ka, v, _cl); + y[0] = x[1] / v; // Concentration = amount / volume + }, + ) +} + +/// Test that SAEM converges for a simple one-compartment model +/// +/// This test validates: +/// 1. The algorithm runs without errors +/// 2. Population parameters are in reasonable range +/// 3. Results are qualitatively similar to saemix reference +#[test] +#[ignore = "SAEM integration test - run with --ignored"] +fn test_saem_theophylline_convergence() -> Result<()> { + // Create model + let eq = create_one_compartment_absorption_model(); + + // Create data + let data = create_theo_data(); + + // Parameter ranges based on typical theophylline PK + // ka: absorption rate (0.5 - 3 /hr typical) + // V: volume of distribution (20-50 L typical for adult) + // CL: clearance (1-5 L/hr typical) + let params = bounded_parameter_space(&[("ka", 0.5, 3.0), ("v", 10.0, 50.0), ("cl", 0.5, 5.0)]); + + // Residual error model (parametric algorithms use ResidualErrorModels) + let residual_error = ResidualErrorModels::new().add(0, ResidualErrorModel::constant(0.5)); + + // Create SAEM settings + let config = SaemTestProblemConfig::new(params, residual_error); + + // Run through the unified platform entry path. + let result = run_saem_problem(config, eq, data)?; + let result = result + .as_parametric() + .expect("SAEM should yield a parametric result"); + + // Basic convergence checks + println!("SAEM completed in {} iterations", result.iterations()); + println!("Objective function: {:.2}", result.objf()); + + // The algorithm should complete + assert!( + result.iterations() > 0, + "Algorithm should complete at least one cycle" + ); + + // Objective function should be finite + assert!( + result.objf().is_finite(), + "Objective function should be finite" + ); + + // Expected approximate values from saemix (log scale for params): + // ka ≈ 1.5-2.0, V ≈ 30-35, CL ≈ 2.5-3.5 + // These are approximate - MCMC methods have variance + + Ok(()) +} + +/// Test SAEM on simple synthetic data with known parameters +/// +/// This test uses synthetic data generated from known parameter values +/// to verify the algorithm can recover the true population parameters. +#[test] +#[ignore = "NPAG used as placeholder - test SAEM when fully wired"] +fn test_saem_parameter_recovery_simple() -> Result<()> { + // Create a simple one-compartment elimination model (no absorption) + let eq = equation::ODE::new( + |x, p, _t, dx, _b, _rateiv, _cov| { + fetch_params!(p, ke); + dx[0] = -ke * x[0]; + }, + |_p, _t, _cov| lag! {}, + |_p, _t, _cov| fa! 
{}, + |_p, _t, _cov, _x| {}, + |x, p, _t, _cov, y| { + fetch_params!(p, v); + y[0] = x[0] / v; + }, + ); + + // True population parameters + let true_ke: f64 = 0.5; // elimination rate constant + let true_v: f64 = 10.0; // volume + let true_omega_ke: f64 = 0.04; // ~20% CV for ke + let true_omega_v: f64 = 0.09; // ~30% CV for V + + // Generate synthetic subjects with random effects + use rand::SeedableRng; + use rand_chacha::ChaCha8Rng; + use rand_distr::{Distribution, Normal}; + + let mut rng = ChaCha8Rng::seed_from_u64(42); + let normal = Normal::new(0.0, 1.0).unwrap(); + + let n_subjects = 20; + let dose = 100.0; + let times = vec![0.5, 1.0, 2.0, 4.0, 8.0, 12.0]; + + let subjects: Vec = (0..n_subjects) + .map(|id| { + // Individual parameters with log-normal distribution + let eta_ke = normal.sample(&mut rng) * true_omega_ke.sqrt(); + let eta_v = normal.sample(&mut rng) * true_omega_v.sqrt(); + let ind_ke = true_ke * f64::exp(eta_ke); + let ind_v = true_v * f64::exp(eta_v); + + // Generate observations + let mut builder = Subject::builder(id.to_string()).bolus(0.0, dose, 0); + + for &t in × { + // C(t) = (Dose/V) * exp(-ke * t) + let conc = (dose / ind_v) * f64::exp(-ind_ke * t); + // Add some measurement noise (~5%) + let noise = 1.0 + normal.sample(&mut rng) * 0.05; + builder = builder.observation(t, conc * noise, 0); + } + + builder.build() + }) + .collect(); + + let data = Data::new(subjects); + + // Set up parameters with reasonable ranges + let params = bounded_parameter_space(&[("ke", 0.1, 1.0), ("v", 5.0, 20.0)]); + + // Error model for non-parametric + let em = AssayErrorModel::additive(ErrorPoly::new(0.5, 0.0, 0.0, 0.0), 1.0); + let ems = AssayErrorModels::new().add(0, em).unwrap(); + + // Create settings - use NPAG for now as SAEM isn't fully wired up + let observations = ObservationSpec::new() + .add_channel(ObservationChannel::continuous(0, "obs")) + .with_assay_error_models(ems); + let model = ModelDefinition::builder(eq) + .parameters(params) + .observations(observations) + .build()?; + let mut runtime = RuntimeOptions::default(); + runtime.cycles = 50; + + let result = EstimationProblem::builder(model, data) + .method(EstimationMethod::Nonparametric(NonparametricMethod::Npag( + NpagOptions, + ))) + .output(OutputPlan { + write: false, + path: None, + }) + .runtime(runtime) + .build()? 
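+        // build()? only assembles the estimation problem; run()? then executes NPAG
+        // (used as a stand-in until SAEM is fully wired) and returns a FitResult that is
+        // unwrapped below with as_nonparametric().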
+ .run()?; + let result = result + .as_nonparametric() + .expect("NPAG should yield a nonparametric result"); + + println!("Test completed with objf: {:.2}", result.objf()); + + // Basic sanity checks + assert!(result.objf().is_finite()); + + Ok(()) +} + +/// Unit test for SAEM sufficient statistics accumulation +#[test] +fn test_sufficient_statistics() { + use faer::Col; + use pmcore::prelude::SufficientStats; + + let mut stats = SufficientStats::new(2); + + // Add some samples + let sample1 = Col::from_fn(2, |i| if i == 0 { 1.0 } else { 2.0 }); + let sample2 = Col::from_fn(2, |i| if i == 0 { 3.0 } else { 4.0 }); + let sample3 = Col::from_fn(2, |i| if i == 0 { 5.0 } else { 6.0 }); + + stats.accumulate(&sample1).unwrap(); + stats.accumulate(&sample2).unwrap(); + stats.accumulate(&sample3).unwrap(); + + assert_eq!(stats.count(), 3); + + // Check sufficient statistics + // S1 = sum of samples + assert!((stats.s1()[0] - 9.0).abs() < 1e-10); // 1 + 3 + 5 + assert!((stats.s1()[1] - 12.0).abs() < 1e-10); // 2 + 4 + 6 + + // Compute M-step + let (mu, omega) = stats.compute_m_step().unwrap(); + + // Mean should be [3, 4] + assert!((mu[0] - 3.0).abs() < 1e-10); + assert!((mu[1] - 4.0).abs() < 1e-10); + + // Variance should be sample variance + // Var = E[X²] - E[X]² + // For column 0: E[X²] = (1+9+25)/3 = 35/3, E[X]² = 9, Var = 35/3 - 9 = 8/3 + let expected_var_0 = 8.0 / 3.0; + assert!((omega[(0, 0)] - expected_var_0).abs() < 1e-10); +} + +/// Unit test for stochastic approximation step size schedule +#[test] +fn test_step_size_schedule() { + use pmcore::prelude::StepSizeSchedule; + + // Test SAEM-style schedule + let schedule = StepSizeSchedule::new_saem(100, 200); + + // During burn-in (iterations 1-100), step size should be 1.0 + assert!((schedule.step_size(50) - 1.0).abs() < 1e-10); + assert!((schedule.step_size(100) - 1.0).abs() < 1e-10); + + // After burn-in, step size should decrease + assert!(schedule.step_size(101) < 1.0); + assert!(schedule.step_size(200) < schedule.step_size(101)); +} + +/// Test SAEM algorithm initialization and basic structure +#[test] +fn test_saem_initialization() -> Result<()> { + // Simple one-compartment model + let eq = equation::ODE::new( + |x, p, _t, dx, _b, _rateiv, _cov| { + fetch_params!(p, ke); + dx[0] = -ke * x[0]; + }, + |_p, _t, _cov| lag! {}, + |_p, _t, _cov| fa! 
{}, + |_p, _t, _cov, _x| {}, + |x, p, _t, _cov, y| { + fetch_params!(p, v); + y[0] = x[0] / v; + }, + ); + + // Create minimal test data (3 subjects) + let subjects: Vec = (0..3) + .map(|id| { + Subject::builder(id.to_string()) + .bolus(0.0, 100.0, 0) + .observation(1.0, 5.0, 0) + .observation(4.0, 2.0, 0) + .build() + }) + .collect(); + let data = Data::new(subjects); + + // Parameters + let params = bounded_parameter_space(&[("ke", 0.1, 1.0), ("v", 5.0, 20.0)]); + + // Residual error model for SAEM (prediction-based sigma) + let residual_error = ResidualErrorModels::new().add(0, ResidualErrorModel::constant(1.0)); + + // Create SAEM settings + let config = SaemTestProblemConfig::new(params, residual_error); + + // Create algorithm via dispatch + let algorithm = dispatch_parametric_algorithm(build_saem_problem( + config.clone(), + eq.clone(), + data.clone(), + )?)?; + + // Check basic initialization + assert_eq!(algorithm.iteration(), 0); + assert_eq!(algorithm.population().npar(), 2); // ke and v + assert!(algorithm.objective_function().is_infinite()); // Not computed yet + + println!("SAEM initialized successfully!"); + println!( + " Population mean: {:?}", + (0..algorithm.population().npar()) + .map(|i| algorithm.population().mu()[i]) + .collect::>() + ); + println!( + " Population omega diagonal: {:?}", + (0..algorithm.population().npar()) + .map(|i| algorithm.population().omega()[(i, i)]) + .collect::>() + ); + + Ok(()) +} + +/// Test that SAEM can run a few iterations without crashing +#[test] +fn test_saem_runs_iterations() -> Result<()> { + // Simple one-compartment model + let eq = equation::ODE::new( + |x, p, _t, dx, _b, _rateiv, _cov| { + fetch_params!(p, ke); + dx[0] = -ke * x[0]; + }, + |_p, _t, _cov| lag! {}, + |_p, _t, _cov| fa! {}, + |_p, _t, _cov, _x| {}, + |x, p, _t, _cov, y| { + fetch_params!(p, v); + y[0] = x[0] / v; + }, + ); + + // Create test data with multiple subjects + let subjects: Vec = (0..5) + .map(|id| { + // Generate data consistent with ke~0.5, v~10 + let true_ke = 0.3 + 0.2 * (id as f64 / 5.0); // Vary by subject + let true_v = 8.0 + 4.0 * (id as f64 / 5.0); + let dose = 100.0; + + Subject::builder(id.to_string()) + .bolus(0.0, dose, 0) + .observation(1.0, (dose / true_v) * f64::exp(-true_ke * 1.0), 0) + .observation(2.0, (dose / true_v) * f64::exp(-true_ke * 2.0), 0) + .observation(4.0, (dose / true_v) * f64::exp(-true_ke * 4.0), 0) + .observation(8.0, (dose / true_v) * f64::exp(-true_ke * 8.0), 0) + .build() + }) + .collect(); + let data = Data::new(subjects); + + // Parameters + let params = bounded_parameter_space(&[("ke", 0.1, 1.0), ("v", 5.0, 20.0)]); + + // Residual error model (parametric algorithms use ResidualErrorModels) + let residual_error = ResidualErrorModels::new().add(0, ResidualErrorModel::constant(1.0)); + + // Create SAEM settings + let config = SaemTestProblemConfig::new(params, residual_error); + + // Create algorithm + let mut algorithm = dispatch_parametric_algorithm(build_saem_problem(config, eq, data)?)?; + + // Initialize + algorithm.initialize()?; + println!("After initialization:"); + println!(" Iteration: {}", algorithm.iteration()); + + // Run a few iterations + for i in 1..=5 { + let status = algorithm.next_iteration()?; + println!("After iteration {}:", i); + println!(" Objective: {:.4}", algorithm.objective_function()); + println!(" Status: {:?}", status); + + if let pmcore::algorithms::Status::Stop(_) = status { + break; + } + } + + // Should have run some iterations + assert!(algorithm.iteration() > 0); + + // Objective 
function should be computed (finite) + println!("\nFinal state:"); + println!(" Total iterations: {}", algorithm.iteration()); + println!( + " Objective function: {:.4}", + algorithm.objective_function() + ); + + // Print final population parameters + let pop = algorithm.population(); + println!(" Population mean (mu):"); + for i in 0..pop.npar() { + println!(" param[{}] = {:.4}", i, pop.mu()[i]); + } + + Ok(()) +} + +/// Test SAEM convergence on a simple IV bolus model with known parameters +/// +/// Uses the full algorithm run (400 iterations by default) to verify that +/// SAEM recovers population parameters within 20% of truth. +/// This test takes ~2-5 minutes due to full SAEM convergence. +#[test] +#[ignore = "Full SAEM convergence test (~2-5 min) - run with --ignored"] +fn test_saem_convergence() -> Result<()> { + // Simple one-compartment model + let eq = equation::ODE::new( + |x, p, _t, dx, _b, _rateiv, _cov| { + fetch_params!(p, ke); + dx[0] = -ke * x[0]; + }, + |_p, _t, _cov| lag! {}, + |_p, _t, _cov| fa! {}, + |_p, _t, _cov, _x| {}, + |x, p, _t, _cov, y| { + fetch_params!(p, v); + y[0] = x[0] / v; + }, + ); + + // True values for simulation + let true_ke = 0.4; + let true_v = 10.0; + let dose = 100.0; + + // Create test data with 20 subjects with some variability + let subjects: Vec = (0..20) + .map(|id| { + // Add some deterministic variability (based on id) + let eta_ke = 0.15 * ((id as f64 * 0.5).sin()); + let eta_v = 0.15 * ((id as f64 * 0.7).cos()); + let subj_ke = true_ke * f64::exp(eta_ke); + let subj_v = true_v * f64::exp(eta_v); + + Subject::builder(id.to_string()) + .bolus(0.0, dose, 0) + .observation(0.5, (dose / subj_v) * f64::exp(-subj_ke * 0.5), 0) + .observation(1.0, (dose / subj_v) * f64::exp(-subj_ke * 1.0), 0) + .observation(2.0, (dose / subj_v) * f64::exp(-subj_ke * 2.0), 0) + .observation(4.0, (dose / subj_v) * f64::exp(-subj_ke * 4.0), 0) + .observation(8.0, (dose / subj_v) * f64::exp(-subj_ke * 8.0), 0) + .observation(12.0, (dose / subj_v) * f64::exp(-subj_ke * 12.0), 0) + .build() + }) + .collect(); + let data = Data::new(subjects); + + // Use tighter bounds centered around true values + let params = bounded_parameter_space(&[("ke", 0.1, 0.8), ("v", 5.0, 15.0)]); + + // Residual error model (parametric algorithms use ResidualErrorModels) + let residual_error = ResidualErrorModels::new().add(0, ResidualErrorModel::constant(0.1)); + + // Create SAEM settings (default: 400 iterations = 5 burn-in + 295 SA + 100 estimation) + let config = SaemTestProblemConfig::new(params, residual_error); + + // Run through the unified platform entry path. 
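+    // run_saem_problem() applies the configured SAEM transforms to the parameter space,
+    // builds the EstimationProblem with ParametricMethod::Saem and calls run(); the
+    // FitResult is downcast with as_parametric() below to read the population estimates
+    // and the objective function.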
+ let result = run_saem_problem(config, eq, data)?; + let result = result + .as_parametric() + .expect("SAEM should yield a parametric result"); + + // Final results + let ke_est = result.population().mu()[0]; + let v_est = result.population().mu()[1]; + + println!("=== SAEM Convergence Test Results ==="); + println!("True values: ke={}, v={}", true_ke, true_v); + println!( + "Estimated: ke={:.4} (err: {:.1}%), v={:.4} (err: {:.1}%)", + ke_est, + 100.0 * (ke_est - true_ke).abs() / true_ke, + v_est, + 100.0 * (v_est - true_v).abs() / true_v, + ); + println!("Objective: {:.4}", result.objf()); + println!("Iterations: {}", result.iterations()); + + // With 400 iterations and noiseless deterministic data, estimates should be close + let ke_rel_err = (ke_est - true_ke).abs() / true_ke; + let v_rel_err = (v_est - true_v).abs() / true_v; + + assert!( + ke_rel_err < 0.20, + "ke estimate {:.4} too far from truth {} (rel err: {:.1}%)", + ke_est, + true_ke, + ke_rel_err * 100.0 + ); + assert!( + v_rel_err < 0.20, + "v estimate {:.4} too far from truth {} (rel err: {:.1}%)", + v_est, + true_v, + v_rel_err * 100.0 + ); + + Ok(()) +} diff --git a/tests/saem_validation/component_reference.json b/tests/saem_validation/component_reference.json new file mode 100644 index 000000000..b4ca50e1d --- /dev/null +++ b/tests/saem_validation/component_reference.json @@ -0,0 +1,30 @@ +{ + "transforms": { + "log_normal": { + "psi": [0.1, 0.5, 1, 2, 5], + "phi": [-2.3026, -0.6931, 0, 0.6931, 1.6094] + } + }, + "sufficient_stats": { + "phi_samples": [ + [1, 2], + [3, 4], + [5, 6] + ], + "s1": [9, 12], + "s2": [ + [35, 44], + [44, 56] + ], + "mu": [3, 4], + "omega": [ + [2.6667, 2.6667], + [2.6667, 2.6667] + ] + }, + "step_size": { + "n_burn": 100, + "n_smooth": 200, + "schedule": [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0.5, 0.3333, 0.25, 0.2, 0.1667, 0.1429, 0.125, 0.1111, 0.1, 0.0909, 0.0833, 0.0769, 0.0714, 0.0667, 0.0625, 0.0588, 0.0556, 0.0526, 0.05, 0.0476, 0.0455, 0.0435, 0.0417, 0.04, 0.0385, 0.037, 0.0357, 0.0345, 0.0333, 0.0323, 0.0312, 0.0303, 0.0294, 0.0286, 0.0278, 0.027, 0.0263, 0.0256, 0.025, 0.0244, 0.0238, 0.0233, 0.0227, 0.0222, 0.0217, 0.0213, 0.0208, 0.0204, 0.02, 0.0196, 0.0192, 0.0189, 0.0185, 0.0182, 0.0179, 0.0175, 0.0172, 0.0169, 0.0167, 0.0164, 0.0161, 0.0159, 0.0156, 0.0154, 0.0152, 0.0149, 0.0147, 0.0145, 0.0143, 0.0141, 0.0139, 0.0137, 0.0135, 0.0133, 0.0132, 0.013, 0.0128, 0.0127, 0.0125, 0.0123, 0.0122, 0.012, 0.0119, 0.0118, 0.0116, 0.0115, 0.0114, 0.0112, 0.0111, 0.011, 0.0109, 0.0108, 0.0106, 0.0105, 0.0104, 0.0103, 0.0102, 0.0101, 0.01, 0.0099, 0.0098, 0.0097, 0.0096, 0.0095, 0.0094, 0.0093, 0.0093, 0.0092, 0.0091, 0.009, 0.0089, 0.0088, 0.0088, 0.0087, 0.0086, 0.0085, 0.0085, 0.0084, 0.0083, 0.0083, 0.0082, 0.0081, 0.0081, 0.008, 0.0079, 0.0079, 0.0078, 0.0078, 0.0077, 0.0076, 0.0076, 0.0075, 0.0075, 0.0074, 0.0074, 0.0073, 0.0072, 0.0072, 0.0071, 0.0071, 0.007, 0.007, 0.0069, 0.0069, 0.0068, 0.0068, 0.0068, 0.0067, 0.0067, 0.0066, 0.0066, 0.0065, 0.0065, 0.0065, 0.0064, 0.0064, 0.0063, 0.0063, 0.0062, 0.0062, 0.0062, 0.0061, 0.0061, 0.0061, 0.006, 0.006, 0.006, 0.0059, 0.0059, 0.0058, 0.0058, 0.0058, 0.0057, 0.0057, 0.0057, 0.0056, 0.0056, 0.0056, 0.0056, 0.0055, 0.0055, 0.0055, 0.0054, 0.0054, 0.0054, 0.0053, 0.0053, 0.0053, 
0.0053, 0.0052, 0.0052, 0.0052, 0.0052, 0.0051, 0.0051, 0.0051, 0.0051, 0.005, 0.005] + } +} diff --git a/tests/saem_validation/generate_reference.R b/tests/saem_validation/generate_reference.R new file mode 100644 index 000000000..9f6a0db06 --- /dev/null +++ b/tests/saem_validation/generate_reference.R @@ -0,0 +1,509 @@ +#!/usr/bin/env Rscript +# SAEM Validation Reference Generator +# This script generates reference values from saemix for comparison with PMcore Rust implementation +# +# Usage: Rscript generate_reference.R +# Output: JSON files with reference values for each test case + +library(saemix) +library(jsonlite) + +# Get the script directory (works when run via Rscript) +get_script_dir <- function() { + args <- commandArgs(trailingOnly = FALSE) + file_arg <- "--file=" + match <- grep(file_arg, args) + if (length(match) > 0) { + return(dirname(normalizePath(sub(file_arg, "", args[match])))) + } + return(getwd()) +} + +output_dir <- get_script_dir() +setwd(output_dir) + +cat("=== SAEM Validation Reference Generator ===\n") +cat("Output directory:", getwd(), "\n\n") + +# ============================================================================= +# TEST CASE 1: One-Compartment IV Bolus (Simple) +# This matches the PMcore test_saem_convergence test +# ============================================================================= + +generate_one_compartment_iv_reference <- function() { + cat("--- Test Case 1: One-Compartment IV Bolus ---\n") + + # True parameters (matching PMcore test) + true_ke <- 0.4 + true_v <- 10.0 + dose <- 100.0 + + # Generate synthetic data (20 subjects, same as PMcore) + set.seed(42) + n_subjects <- 20 + times <- c(0.5, 1.0, 2.0, 4.0, 8.0, 12.0) + + # Population variability (CV = 30%) + omega_ke <- 0.09 # CV = sqrt(exp(0.09)-1) ≈ 30% + omega_v <- 0.09 + + # Residual error (additive, SD = 0.5) + sigma <- 0.5 + + # Create data frame + data_rows <- list() + for (id in 1:n_subjects) { + # Random effects from normal distribution + eta_ke <- rnorm(1, 0, sqrt(omega_ke)) + eta_v <- rnorm(1, 0, sqrt(omega_v)) + subj_ke <- true_ke * exp(eta_ke) + subj_v <- true_v * exp(eta_v) + + for (t in times) { + conc <- (dose / subj_v) * exp(-subj_ke * t) + # Add residual error + conc <- conc + rnorm(1, 0, sigma) + # Ensure non-negative (censoring) + conc <- max(conc, 0.01) + data_rows[[length(data_rows) + 1]] <- data.frame( + Id = id, + Time = t, + Dose = ifelse(t == times[1], dose, 0), + Concentration = conc + ) + } + } + + sim_data <- do.call(rbind, data_rows) + + # Save data for Rust + write.csv(sim_data, "onecomp_iv_data.csv", row.names = FALSE) + cat(" Saved data to onecomp_iv_data.csv\n") + + # Create saemix data object + saemix_data <- saemixData( + name.data = sim_data, + header = TRUE, + name.group = c("Id"), + name.predictors = c("Time"), + name.response = c("Concentration"), + name.X = "Time" + ) + + # One-compartment IV bolus model + # C(t) = (Dose/V) * exp(-ke * t) + model_1cpt_iv <- function(psi, id, xidep) { + tim <- xidep[, 1] + ke <- psi[id, 1] + V <- psi[id, 2] + # Dose is 100 for all subjects at t=0 + ypred <- (dose / V) * exp(-ke * tim) + return(ypred) + } + + # Model specification - log-normal for both parameters + saemix_model <- saemixModel( + model = model_1cpt_iv, + description = "One-compartment IV bolus", + psi0 = matrix(c(0.45, 10.0), # Initial values (geometric mean of bounds) + ncol = 2, byrow = TRUE, + dimnames = list(NULL, c("ke", "V")) + ), + transform.par = c(1, 1), # Both log-normal + covariance.model = matrix(c(1, 0, 0, 1), ncol = 2, byrow = 
TRUE), + omega.init = matrix(c(0.1, 0, 0, 0.1), ncol = 2, byrow = TRUE), + error.model = "constant" + ) + + # SAEM options - use more iterations for better convergence + saemix_options <- list( + seed = 12345, + nbiter.burn = 50, # More pure burn-in + nbiter.saemix = c(200, 100), # 200 SA + 100 smoothing + nb.chains = 3, + nbiter.mcmc = c(3, 3, 3, 0), # More MCMC iterations per kernel + proba.mcmc = 0.4, + stepsize.rw = 0.4, + alpha.sa = 0.97, + rw.ini = 0.5, + save = FALSE, + save.graphs = FALSE, + print = FALSE + ) + + # Run SAEM + cat(" Running saemix...\n") + saemix_fit <- saemix(saemix_model, saemix_data, saemix_options) + saemix_fit <- map.saemix(saemix_fit) + + # Extract results + results <- list( + test_case = "one_compartment_iv", + description = "One-compartment IV bolus with log-normal parameters", + true_values = list( + ke = true_ke, + v = true_v + ), + + # NOTE: saemix @results@fixed.effects are in PSI (natural) space, + # NOT in PHI (log) space. transphi() was already applied in main.R line 224. + # Population parameters in psi space (natural) - directly from fixed.effects + mu_psi = as.numeric(saemix_fit@results@fixed.effects), + # Population parameters in phi space (log-transformed) + mu_phi = as.numeric(log(saemix_fit@results@fixed.effects)), + + # Covariance matrix + omega = as.matrix(saemix_fit@results@omega), + omega_diag = as.numeric(diag(saemix_fit@results@omega)), + + # Residual error + sigma = as.numeric(saemix_fit@results@respar[1]), + + # Likelihood (linearization approximation) + ll_lin = saemix_fit@results@ll.lin, + objf = -2 * saemix_fit@results@ll.lin, + + # Individual estimates (MAP) + map_psi = as.matrix(saemix_fit@results@map.psi), + map_eta = as.matrix(saemix_fit@results@map.eta), + cond_mean_phi = as.matrix(saemix_fit@results@cond.mean.phi), + + # Settings for reproducibility + settings = list( + seed = 12345, + n_burn = 50, + n_sa = 200, + n_smooth = 100, + n_chains = 3, + transform_par = c(1, 1), # log-normal + error_model = "constant", + initial_psi = c(0.45, 10.0), + initial_omega_diag = c(0.1, 0.1) + ), + + # Data info + n_subjects = n_subjects, + n_observations = nrow(sim_data) + ) + + # Save results + write_json(results, "onecomp_iv_reference.json", pretty = TRUE, auto_unbox = TRUE) + cat(" Saved results to onecomp_iv_reference.json\n") + + # Print summary + cat("\n === Results Summary ===\n") + cat(" True ke:", true_ke, " Estimated (psi):", results$mu_psi[1], "\n") + cat(" True V:", true_v, " Estimated (psi):", results$mu_psi[2], "\n") + cat(" Omega diagonal:", results$omega_diag, "\n") + cat(" Sigma:", results$sigma, "\n") + cat(" -2LL:", results$objf, "\n\n") + + return(results) +} + +# ============================================================================= +# TEST CASE 2: Theophylline (Standard Reference) +# This is the classic NLME example from saemix +# ============================================================================= + +generate_theophylline_reference <- function() { + cat("--- Test Case 2: Theophylline (Standard) ---\n") + + # Load built-in theophylline data + data(theo.saemix) + + # Save for Rust (need to restructure for PMcore format) + # PMcore needs: ID, TIME, AMT/DOSE, DV, EVID, CMT + theo_export <- data.frame( + ID = theo.saemix$Id, + TIME = theo.saemix$Time, + DOSE = theo.saemix$Dose, + DV = theo.saemix$Concentration, + Weight = theo.saemix$Weight, + Sex = theo.saemix$Sex + ) + write.csv(theo_export, "theo_data.csv", row.names = FALSE) + cat(" Saved data to theo_data.csv\n") + + # Create saemix data object + 
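+  # Both Dose and Time are declared as predictors so that the model function defined
+  # below can read the dose from xidep[, 1] and the time from xidep[, 2].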
theo_saemix_data <- saemixData( + name.data = theo.saemix, + header = TRUE, + sep = " ", + na = NA, + name.group = c("Id"), + name.predictors = c("Dose", "Time"), + name.response = c("Concentration"), + name.covariates = c("Weight", "Sex"), + units = list(x = "hr", y = "mg/L", covariates = c("kg", "-")), + name.X = "Time" + ) + + # One-compartment model with first-order absorption + model_1cpt_oral <- function(psi, id, xidep) { + dose <- xidep[, 1] + tim <- xidep[, 2] + ka <- psi[id, 1] + V <- psi[id, 2] + CL <- psi[id, 3] + k <- CL / V + ypred <- dose * ka / (V * (ka - k)) * (exp(-k * tim) - exp(-ka * tim)) + return(ypred) + } + + # Model specification + theo_model <- saemixModel( + model = model_1cpt_oral, + description = "One-compartment model with first-order absorption", + psi0 = matrix(c(1.5, 32, 3.0), + ncol = 3, byrow = TRUE, + dimnames = list(NULL, c("ka", "V", "CL")) + ), + transform.par = c(1, 1, 1), # All log-normal + covariance.model = matrix(c(1, 0, 0, 0, 1, 0, 0, 0, 1), ncol = 3, byrow = TRUE), + omega.init = matrix(c(0.5, 0, 0, 0, 0.05, 0, 0, 0, 0.1), ncol = 3, byrow = TRUE), + error.model = "constant" + ) + + # SAEM options - more burn-in for stability + theo_options <- list( + seed = 12345, + nbiter.burn = 50, + nbiter.saemix = c(300, 100), # 300 SA + 100 smoothing + nb.chains = 3, + nbiter.mcmc = c(3, 3, 3, 0), + proba.mcmc = 0.4, + stepsize.rw = 0.4, + alpha.sa = 0.97, + rw.ini = 0.5, + save = FALSE, + save.graphs = FALSE, + print = FALSE + ) + + # Run SAEM + cat(" Running saemix...\n") + theo_fit <- saemix(theo_model, theo_saemix_data, theo_options) + theo_fit <- map.saemix(theo_fit) + + # Extract results + results <- list( + test_case = "theophylline", + description = "One-compartment oral absorption (ka, V, CL)", + + # NOTE: saemix fixed.effects are in PSI (natural) space + # Population parameters in psi space (natural) + mu_psi = as.numeric(theo_fit@results@fixed.effects), + # Population parameters in phi space (log-transformed) + mu_phi = as.numeric(log(theo_fit@results@fixed.effects)), + + # Covariance + omega = as.matrix(theo_fit@results@omega), + omega_diag = as.numeric(diag(theo_fit@results@omega)), + + # Residual error + sigma = as.numeric(theo_fit@results@respar[1]), + + # Likelihood + ll_lin = theo_fit@results@ll.lin, + objf = -2 * theo_fit@results@ll.lin, + + # Individual estimates + map_psi = as.matrix(theo_fit@results@map.psi), + map_eta = as.matrix(theo_fit@results@map.eta), + cond_mean_phi = as.matrix(theo_fit@results@cond.mean.phi), + + # Settings + settings = list( + seed = 12345, + n_burn = 50, + n_sa = 300, + n_smooth = 100, + n_chains = 3, + transform_par = c(1, 1, 1), + error_model = "constant", + initial_psi = c(1.5, 32.0, 3.0), + initial_omega_diag = c(0.5, 0.05, 0.1) + ), + + # Data info + n_subjects = theo_saemix_data@N, + n_observations = theo_saemix_data@ntot.obs + ) + + write_json(results, "theo_reference.json", pretty = TRUE, auto_unbox = TRUE) + cat(" Saved results to theo_reference.json\n") + + cat("\n === Results Summary ===\n") + cat(" ka (psi):", results$mu_psi[1], " V (psi):", results$mu_psi[2], " CL (psi):", results$mu_psi[3], "\n") + cat(" Omega diagonal:", results$omega_diag, "\n") + cat(" Sigma:", results$sigma, "\n") + cat(" -2LL:", results$objf, "\n\n") + + return(results) +} + +# ============================================================================= +# TEST CASE 3: Bimodal Ke (PMcore Internal Dataset) +# Tests algorithm behavior with multimodal distributions +# 
============================================================================= + +generate_bimodal_ke_reference <- function() { + cat("--- Test Case 3: Bimodal Ke ---\n") + + # Check if data file exists + data_file <- "../../../examples/bimodal_ke/bimodal_ke.csv" + if (!file.exists(data_file)) { + cat(" WARNING: bimodal_ke.csv not found, skipping this test case\n\n") + return(NULL) + } + + # Load PMcore bimodal_ke data + bimodal_data <- read.csv(data_file) + cat(" Loaded", nrow(bimodal_data), "rows from bimodal_ke.csv\n") + + # Restructure for saemix (needs specific column names) + # Assume PMcore format has ID, TIME, DV, AMT, EVID, etc. + # We need to adapt based on actual format + + cat(" Column names:", paste(names(bimodal_data), collapse = ", "), "\n") + + # This will depend on the actual format of bimodal_ke.csv + # For now, skip if format is not compatible + cat(" TODO: Implement bimodal_ke reference generation\n\n") + return(NULL) +} + +# ============================================================================= +# Component-Level Tests +# These test individual components for exact matching +# ============================================================================= + +generate_component_tests <- function() { + cat("--- Component-Level Reference Values ---\n") + + # Test 1: Parameter transformation verification + cat(" 1. Parameter transforms:\n") + + # Log-normal transform + psi_values <- c(0.1, 0.5, 1.0, 2.0, 5.0) + phi_values <- log(psi_values) + cat( + " Log-normal: psi=", paste(psi_values, collapse = ","), + " -> phi=", paste(round(phi_values, 6), collapse = ","), "\n" + ) + + # Test 2: Sufficient statistics computation + cat(" 2. Sufficient statistics:\n") + + # Simple 2-parameter example + phi_samples <- matrix(c( + 1.0, 2.0, + 3.0, 4.0, + 5.0, 6.0 + ), nrow = 3, byrow = TRUE) + + s1 <- colSums(phi_samples) # [9, 12] + s2 <- t(phi_samples) %*% phi_samples # [[35, 44], [44, 56]] + mu <- s1 / 3 # [3, 4] + omega <- s2 / 3 - mu %*% t(mu) # Sample variance + + cat(" S1:", paste(s1, collapse = ","), "\n") + cat(" S2 diag:", paste(diag(s2), collapse = ","), "\n") + cat(" mu:", paste(mu, collapse = ","), "\n") + cat(" omega diag:", paste(round(diag(omega), 6), collapse = ","), "\n") + + # Test 3: Step size schedule + cat(" 3. 
Step size schedule (n_burn=100, n_smooth=200):\n") + + n_burn <- 100 + n_smooth <- 200 + n_total <- n_burn + n_smooth + + # R saemix step sizes + stepsize <- rep(1, n_total) + stepsize[(n_burn + 1):n_total] <- 1 / (1:n_smooth) + + test_iters <- c(1, 50, 100, 101, 150, 200, 300) + for (k in test_iters) { + if (k <= n_total) { + cat(" iter", k, ": gamma =", round(stepsize[k], 6), "\n") + } + } + + # Save component tests + component_results <- list( + transforms = list( + log_normal = list( + psi = psi_values, + phi = phi_values + ) + ), + sufficient_stats = list( + phi_samples = phi_samples, + s1 = as.numeric(s1), + s2 = as.matrix(s2), + mu = as.numeric(mu), + omega = as.matrix(omega) + ), + step_size = list( + n_burn = n_burn, + n_smooth = n_smooth, + schedule = as.numeric(stepsize) + ) + ) + + write_json(component_results, "component_reference.json", pretty = TRUE, auto_unbox = TRUE) + cat(" Saved to component_reference.json\n\n") + + return(component_results) +} + +# ============================================================================= +# Run all test case generators +# ============================================================================= + +main <- function() { + results <- list() + + # Component tests (always run) + results$components <- generate_component_tests() + + # Full algorithm tests + results$onecomp_iv <- tryCatch( + generate_one_compartment_iv_reference(), + error = function(e) { + cat(" ERROR:", e$message, "\n") + return(NULL) + } + ) + + results$theophylline <- tryCatch( + generate_theophylline_reference(), + error = function(e) { + cat(" ERROR:", e$message, "\n") + return(NULL) + } + ) + + results$bimodal_ke <- tryCatch( + generate_bimodal_ke_reference(), + error = function(e) { + cat(" ERROR:", e$message, "\n") + return(NULL) + } + ) + + cat("=== Reference Generation Complete ===\n") + cat("Generated files:\n") + cat(" - component_reference.json\n") + if (!is.null(results$onecomp_iv)) cat(" - onecomp_iv_reference.json\n") + if (!is.null(results$theophylline)) cat(" - theo_reference.json\n") + + return(results) +} + +# Run if executed directly +if (!interactive()) { + main() +} diff --git a/tests/saem_validation/mod.rs b/tests/saem_validation/mod.rs new file mode 100644 index 000000000..2a0306805 --- /dev/null +++ b/tests/saem_validation/mod.rs @@ -0,0 +1,32 @@ +//! SAEM Validation Test Module +//! +//! This module contains comprehensive tests to validate the PMcore SAEM implementation +//! against the R saemix reference implementation. +//! +//! # Test Structure +//! +//! 1. **Component Tests**: Verify individual components match R exactly +//! - Parameter transformations (φ ↔ ψ) +//! - Sufficient statistics computation +//! - Step size schedule +//! +//! 2. **Integration Tests**: Verify algorithm phases +//! - E-step MCMC sampling +//! - M-step parameter updates +//! +//! 3. **End-to-End Tests**: Compare full algorithm results +//! - One-compartment IV (synthetic) +//! - Theophylline (standard reference) +//! +//! # Running Tests +//! +//! ```bash +//! # Run all validation tests +//! cargo test saem_validation -- --nocapture +//! +//! # Run specific test +//! cargo test test_component_transforms -- --nocapture +//! 
``` + +pub mod reference; +pub mod tests; diff --git a/tests/saem_validation/onecomp_iv_data.csv b/tests/saem_validation/onecomp_iv_data.csv new file mode 100644 index 000000000..1dd4527bd --- /dev/null +++ b/tests/saem_validation/onecomp_iv_data.csv @@ -0,0 +1,121 @@ +"Id","Time","Dose","Concentration" +1,0.5,100,8.94197432932875 +1,1,0,6.79494359054814 +1,2,0,3.74518173504493 +1,4,0,1.00663150339752 +1,8,0,0.850556377282524 +1,12,0,0.01 +2,0.5,100,7.71604145058557 +2,1,0,6.03978137792969 +2,2,0,1.65841440154282 +2,4,0,0.403875497110581 +2,8,0,0.01 +2,12,0,0.319519404380605 +3,0.5,100,17.2446351811133 +3,1,0,16.0269920825166 +3,2,0,10.4898103507497 +3,4,0,4.21478988680277 +3,8,0,1.08883573458709 +3,12,0,0.877664843826841 +4,0.5,100,7.86449200425688 +4,1,0,4.73339935531452 +4,2,0,3.0008887234611 +4,4,0,0.354744650282417 +4,8,0,0.267737082915282 +4,12,0,0.354791370534054 +5,0.5,100,9.39040071149986 +5,1,0,6.09748636862728 +5,2,0,3.63848443245282 +5,4,0,0.927951726274944 +5,8,0,0.01 +5,12,0,0.0352646055035824 +6,0.5,100,9.38743641208784 +6,1,0,6.91862633816261 +6,2,0,4.07421904971097 +6,4,0,2.248171587455 +6,8,0,0.01 +6,12,0,0.78958669594749 +7,0.5,100,7.05165927274128 +7,1,0,5.38836934556378 +7,2,0,4.85531499542277 +7,4,0,2.33548916831754 +7,8,0,0.538689608347473 +7,12,0,0.259349269231009 +8,0.5,100,6.12080022025654 +8,1,0,6.10333630814859 +8,2,0,3.46667039752461 +8,4,0,1.46147320107431 +8,8,0,0.483407515257422 +8,12,0,0.726938115600454 +9,0.5,100,5.92844197147128 +9,1,0,5.42414207668721 +9,2,0,4.01638674284061 +9,4,0,2.22955016846248 +9,8,0,0.01 +9,12,0,0.097571854855143 +10,0.5,100,10.1880083655839 +10,1,0,8.50880704570658 +10,2,0,5.45787286265875 +10,4,0,2.16576985349746 +10,8,0,0.01 +10,12,0,0.01 +11,0.5,100,6.79966094496247 +11,1,0,4.87028541174783 +11,2,0,2.02963701305263 +11,4,0,1.0515173445629 +11,8,0,0.01 +11,12,0,0.01 +12,0.5,100,6.69405870008875 +12,1,0,4.36533150122211 +12,2,0,3.03679140887892 +12,4,0,1.63641347362555 +12,8,0,0.01 +12,12,0,0.01 +13,0.5,100,13.4758493753567 +13,1,0,11.9789178139699 +13,2,0,9.36455403947386 +13,4,0,5.4802238303314 +13,8,0,1.08499728251102 +13,12,0,1.4319823157086 +14,0.5,100,8.01396866627334 +14,1,0,6.92161899647598 +14,2,0,5.12685037430179 +14,4,0,2.67389960805448 +14,8,0,0.692900773885856 +14,12,0,0.244393973286055 +15,0.5,100,8.95547040271268 +15,1,0,8.0410836463311 +15,2,0,5.56931348285513 +15,4,0,4.26832861392107 +15,8,0,0.0505749325026397 +15,12,0,0.252109981970554 +16,0.5,100,13.7425037170214 +16,1,0,11.540980540947 +16,2,0,9.32350901629286 +16,4,0,5.37910529414248 +16,8,0,1.70570851851514 +16,12,0,0.01 +17,0.5,100,8.53392433318334 +17,1,0,6.93660989364689 +17,2,0,5.44513954166353 +17,4,0,3.69041851644507 +17,8,0,1.75314768077686 +17,12,0,0.01 +18,0.5,100,5.51428690124503 +18,1,0,4.71344254850447 +18,2,0,3.17830280506375 +18,4,0,1.04421487948215 +18,8,0,0.095183672329188 +18,12,0,0.0530321767674149 +19,0.5,100,5.76047703889415 +19,1,0,4.81268180414357 +19,2,0,3.8808362790532 +19,4,0,1.21443381571917 +19,8,0,0.403748176451192 +19,12,0,0.01 +20,0.5,100,7.93896646728667 +20,1,0,5.53403612247915 +20,2,0,3.48385730825527 +20,4,0,0.720418462562687 +20,8,0,0.01 +20,12,0,0.655777480045852 diff --git a/tests/saem_validation/onecomp_iv_reference.json b/tests/saem_validation/onecomp_iv_reference.json new file mode 100644 index 000000000..7ba2f5fd0 --- /dev/null +++ b/tests/saem_validation/onecomp_iv_reference.json @@ -0,0 +1,97 @@ +{ + "test_case": "one_compartment_iv", + "description": "One-compartment IV bolus with log-normal parameters", + 
"true_values": { + "ke": 0.4, + "v": 10 + }, + "mu_psi": [0.4234, 9.5222], + "mu_phi": [-0.8595, 2.2536], + "omega": [ + [0.121, 0], + [0, 0.0818] + ], + "omega_diag": [0.121, 0.0818], + "sigma": 0.4532, + "ll_lin": -137.7624, + "objf": 275.5248, + "map_psi": [ + [0.554, 8.5378], + [0.7718, 8.6251], + [0.3689, 4.631], + [0.6793, 9.4882], + [0.6285, 8.0932], + [0.4729, 8.7009], + [0.3193, 12.1953], + [0.3975, 12.3618], + [0.3381, 13.368], + [0.4346, 7.8005], + [0.6511, 10.7588], + [0.4818, 12.3623], + [0.2683, 6.4178], + [0.3247, 10.4447], + [0.2948, 9.4399], + [0.2765, 6.3863], + [0.2565, 10.6035], + [0.4476, 13.7764], + [0.3911, 13.6646], + [0.5777, 9.6357] + ], + "map_eta": [ + [0.2689, -0.1091], + [0.6004, -0.099], + [-0.1378, -0.7209], + [0.4728, -0.0036], + [0.395, -0.1626], + [0.1106, -0.0902], + [-0.282, 0.2474], + [-0.0631, 0.261], + [-0.225, 0.3392], + [0.0262, -0.1994], + [0.4304, 0.1221], + [0.1292, 0.261], + [-0.4563, -0.3946], + [-0.2655, 0.0925], + [-0.3621, -0.0087], + [-0.4261, -0.3995], + [-0.5011, 0.1076], + [0.0556, 0.3693], + [-0.0794, 0.3612], + [0.3107, 0.0118] + ], + "cond_mean_phi": [ + [-0.5708, 2.1344], + [-0.2672, 2.1601], + [-1.0006, 1.5359], + [-0.3937, 2.2556], + [-0.4664, 2.0898], + [-0.7479, 2.1638], + [-1.1549, 2.5096], + [-0.9405, 2.524], + [-1.072, 2.5892], + [-0.8383, 2.0565], + [-0.4058, 2.3669], + [-0.7185, 2.5097], + [-1.3139, 1.8587], + [-1.138, 2.352], + [-1.2073, 2.2396], + [-1.2859, 1.8539], + [-1.3528, 2.3606], + [-0.8129, 2.6267], + [-0.9457, 2.6151], + [-0.5563, 2.2705] + ], + "settings": { + "seed": 12345, + "n_burn": 50, + "n_sa": 200, + "n_smooth": 100, + "n_chains": 3, + "transform_par": [1, 1], + "error_model": "constant", + "initial_psi": [0.45, 10], + "initial_omega_diag": [0.1, 0.1] + }, + "n_subjects": 20, + "n_observations": 120 +} diff --git a/tests/saem_validation/reference.rs b/tests/saem_validation/reference.rs new file mode 100644 index 000000000..ea7762c03 --- /dev/null +++ b/tests/saem_validation/reference.rs @@ -0,0 +1,192 @@ +//! Reference data structures for loading R saemix results +//! +//! 
These structures match the JSON output from generate_reference.R + +use serde::Deserialize; +use std::path::Path; + +/// Reference results from R saemix for a single test case +#[derive(Debug, Deserialize)] +pub struct SaemixReference { + pub test_case: String, + pub description: String, + + /// True parameter values (if synthetic data) + #[serde(default)] + pub true_values: Option, + + /// Population mean in φ (transformed/unconstrained) space + pub mu_phi: Vec, + + /// Population mean in ψ (original/constrained) space + pub mu_psi: Vec, + + /// Full covariance matrix Ω + pub omega: Vec>, + + /// Diagonal of Ω (variances) + pub omega_diag: Vec, + + /// Residual error standard deviation + pub sigma: f64, + + /// Log-likelihood (linearization approximation) + pub ll_lin: f64, + + /// Objective function (-2LL) + pub objf: f64, + + /// MAP individual parameter estimates (n_subjects × n_params) + pub map_psi: Vec>, + + /// MAP random effects (n_subjects × n_params) + pub map_eta: Vec>, + + /// Conditional mean in φ space + pub cond_mean_phi: Vec>, + + /// Settings used for this run + pub settings: SaemixSettings, + + /// Number of subjects + pub n_subjects: usize, + + /// Total number of observations + pub n_observations: usize, +} + +#[derive(Debug, Deserialize, Default)] +pub struct TrueValues { + #[serde(default)] + pub ke: Option, + #[serde(default)] + pub v: Option, + #[serde(default)] + pub ka: Option, + #[serde(default)] + pub cl: Option, +} + +#[derive(Debug, Deserialize)] +pub struct SaemixSettings { + pub seed: u64, + pub n_burn: usize, + pub n_sa: usize, + pub n_smooth: usize, + pub n_chains: usize, + pub transform_par: Vec, + pub error_model: String, + pub initial_psi: Vec, + pub initial_omega_diag: Vec, +} + +/// Component-level reference values for exact matching +#[derive(Debug, Deserialize)] +pub struct ComponentReference { + pub transforms: TransformReference, + pub sufficient_stats: SufficientStatsReference, + pub step_size: StepSizeReference, +} + +#[derive(Debug, Deserialize)] +pub struct TransformReference { + pub log_normal: LogNormalTransform, +} + +#[derive(Debug, Deserialize)] +pub struct LogNormalTransform { + pub psi: Vec, + pub phi: Vec, +} + +#[derive(Debug, Deserialize)] +pub struct SufficientStatsReference { + pub phi_samples: Vec>, + pub s1: Vec, + pub s2: Vec>, + pub mu: Vec, + pub omega: Vec>, +} + +#[derive(Debug, Deserialize)] +pub struct StepSizeReference { + pub n_burn: usize, + pub n_smooth: usize, + pub schedule: Vec, +} + +/// Load reference from JSON file +pub fn load_reference>(path: P) -> Result { + let content = std::fs::read_to_string(path.as_ref()) + .map_err(|e| format!("Failed to read file: {}", e))?; + serde_json::from_str(&content).map_err(|e| format!("Failed to parse JSON: {}", e)) +} + +/// Load component reference from JSON file +pub fn load_component_reference>(path: P) -> Result { + let content = std::fs::read_to_string(path.as_ref()) + .map_err(|e| format!("Failed to read file: {}", e))?; + serde_json::from_str(&content).map_err(|e| format!("Failed to parse JSON: {}", e)) +} + +/// Assertion helper: check values are close within relative tolerance +pub fn assert_close(actual: f64, expected: f64, rtol: f64, name: &str) { + let abs_expected = expected.abs().max(1e-10); + let rel_error = (actual - expected).abs() / abs_expected; + assert!( + rel_error < rtol, + "{}: actual={:.6}, expected={:.6}, rel_error={:.4} (tolerance={:.4})", + name, + actual, + expected, + rel_error, + rtol + ); +} + +/// Assertion helper: check vectors are close +pub 
fn assert_vec_close(actual: &[f64], expected: &[f64], rtol: f64, name: &str) { + assert_eq!(actual.len(), expected.len(), "{}: length mismatch", name); + for (i, (a, e)) in actual.iter().zip(expected.iter()).enumerate() { + assert_close(*a, *e, rtol, &format!("{}[{}]", name, i)); + } +} + +/// Assertion helper: check matrices are close (flattened comparison) +pub fn assert_matrix_close(actual: &[Vec], expected: &[Vec], rtol: f64, name: &str) { + assert_eq!(actual.len(), expected.len(), "{}: row count mismatch", name); + for (i, (row_a, row_e)) in actual.iter().zip(expected.iter()).enumerate() { + assert_eq!( + row_a.len(), + row_e.len(), + "{}: row {} length mismatch", + name, + i + ); + for (j, (a, e)) in row_a.iter().zip(row_e.iter()).enumerate() { + assert_close(*a, *e, rtol, &format!("{}[{},{}]", name, i, j)); + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_assert_close() { + // Should pass + assert_close(1.0, 1.0, 0.01, "exact"); + assert_close(1.01, 1.0, 0.02, "within 2%"); + + // Should fail - commented out to not break tests + // assert_close(1.1, 1.0, 0.05, "beyond 5%"); + } + + #[test] + fn test_assert_vec_close() { + let a = vec![1.0, 2.0, 3.0]; + let b = vec![1.01, 2.02, 3.03]; + assert_vec_close(&a, &b, 0.02, "test_vec"); + } +} diff --git a/tests/saem_validation/reference_saemix.R b/tests/saem_validation/reference_saemix.R new file mode 100644 index 000000000..6407b0313 --- /dev/null +++ b/tests/saem_validation/reference_saemix.R @@ -0,0 +1,111 @@ +# Reference SAEM implementation using saemix R package +# This script runs the theophylline example and saves results for comparison + +library(saemix) + +# Load theophylline data +data(theo.saemix) + +# Create saemix data object +saemix.data <- saemixData( + name.data = theo.saemix, + header = TRUE, + sep = " ", + na = NA, + name.group = c("Id"), + name.predictors = c("Dose", "Time"), + name.response = c("Concentration"), + name.covariates = c("Weight", "Sex"), + units = list(x = "hr", y = "mg/L", covariates = c("kg", "-")), + name.X = "Time" +) + +# One-compartment model with first-order absorption +model1cpt <- function(psi, id, xidep) { + dose <- xidep[, 1] + tim <- xidep[, 2] + ka <- psi[id, 1] + V <- psi[id, 2] + CL <- psi[id, 3] + k <- CL / V + ypred <- dose * ka / (V * (ka - k)) * (exp(-k * tim) - exp(-ka * tim)) + return(ypred) +} + +# Model specification +# Initial values: ka=1, V=20, CL=0.5 +# All parameters log-transformed (transform.par=c(1,1,1)) +# Diagonal covariance matrix for random effects +saemix.model <- saemixModel( + model = model1cpt, + description = "One-compartment model with first-order absorption", + psi0 = matrix(c(1., 20, 0.5), + ncol = 3, byrow = TRUE, + dimnames = list(NULL, c("ka", "V", "CL")) + ), + transform.par = c(1, 1, 1), # Log transform all parameters + covariance.model = matrix(c(1, 0, 0, 0, 1, 0, 0, 0, 1), ncol = 3, byrow = TRUE), + omega.init = matrix(c(1, 0, 0, 0, 1, 0, 0, 0, 1), ncol = 3, byrow = TRUE), + error.model = "constant" +) + +# SAEM options +# algorithm: c(nburning, nexploration, nsmoothing) +# Using f-SAEM style with burn-in then SA +saemix.options <- list( + algorithm = c(1, 1, 1), # Run SAEM + seed = 12345, + nbiter.saemix = c(300, 100), # 300 burn-in, 100 SA iterations + nb.chains = 3, + save = FALSE, + save.graphs = FALSE, + print = FALSE +) + +# Run SAEM +cat("Running saemix...\n") +saemix.fit <- saemix(saemix.model, saemix.data, saemix.options) + +# Extract results +cat("\n=== SAEMIX Results ===\n") +cat("\nPopulation parameters (fixed 
effects on log scale):\n") +print(saemix.fit@results@fixed.effects) + +cat("\nPopulation parameters (original scale):\n") +psi_pop <- exp(saemix.fit@results@fixed.effects) +names(psi_pop) <- c("ka", "V", "CL") +print(psi_pop) + +cat("\nRandom effect variances (omega^2):\n") +omega <- saemix.fit@results@omega +print(diag(omega)) + +cat("\nResidual error (sigma):\n") +print(saemix.fit@results@respar) + +cat("\nObjective function (-2LL):\n") +print(saemix.fit@results@ll.lin * -2) + +cat("\nIndividual parameters (first 5 subjects):\n") +# Get MAP estimates +saemix.fit <- map.saemix(saemix.fit) +head(saemix.fit@results@map.psi, 5) + +# Save results for comparison +results <- list( + mu_log = as.numeric(saemix.fit@results@fixed.effects), + mu = as.numeric(exp(saemix.fit@results@fixed.effects)), + omega_diag = as.numeric(diag(saemix.fit@results@omega)), + sigma = as.numeric(saemix.fit@results@respar[1]), + objf = as.numeric(saemix.fit@results@ll.lin * -2) +) + +# Save to JSON for Rust test to read +library(jsonlite) +write_json(results, "saemix_results.json", pretty = TRUE, auto_unbox = TRUE) +cat("\nResults saved to saemix_results.json\n") + +# Also save the data in a format the Rust test can read +theo_data <- theo.saemix[, c("Id", "Dose", "Time", "Concentration")] +write.csv(theo_data, "theo_data.csv", row.names = FALSE) +cat("Data saved to theo_data.csv\n") diff --git a/tests/saem_validation/tests.rs b/tests/saem_validation/tests.rs new file mode 100644 index 000000000..4ae2ad7e7 --- /dev/null +++ b/tests/saem_validation/tests.rs @@ -0,0 +1,845 @@ +//! SAEM Validation Tests +//! +//! These tests compare PMcore SAEM against R saemix reference values. + +use super::reference::*; +use anyhow::Result; +use pharmsol::Equation; +use pmcore::model::ParameterTransform as ModelParameterTransform; +use pmcore::prelude::*; + +#[derive(Clone)] +struct SaemValidationConfig { + parameter_space: ParameterSpace, + residual_error: ResidualErrorModels, + output: OutputPlan, + runtime: RuntimeOptions, +} + +impl SaemValidationConfig { + fn new(parameter_space: ParameterSpace, residual_error: ResidualErrorModels) -> Self { + Self { + parameter_space, + residual_error, + output: OutputPlan { + write: false, + path: None, + }, + runtime: RuntimeOptions::default(), + } + } +} + +fn bounded_parameter_space(bounds: &[(&str, f64, f64)]) -> ParameterSpace { + bounds + .iter() + .fold(ParameterSpace::new(), |space, (name, lower, upper)| { + space.add(ParameterSpec::bounded(*name, *lower, *upper)) + }) +} + +fn apply_saem_transforms(parameter_space: &ParameterSpace, saem: &SaemConfig) -> ParameterSpace { + parameter_space + .iter() + .enumerate() + .fold(ParameterSpace::new(), |space, (index, parameter)| { + space.add(ParameterSpec { + transform: match saem.get_transform(index) { + 1 => ModelParameterTransform::LogNormal, + 2 => ModelParameterTransform::Probit, + 3 => ModelParameterTransform::Logit, + _ => ModelParameterTransform::Identity, + }, + ..parameter.clone() + }) + }) +} + +fn run_saem_problem( + config: SaemValidationConfig, + equation: E, + data: Data, +) -> Result> { + build_saem_problem(config, equation, data)?.run() +} + +fn build_saem_problem( + config: SaemValidationConfig, + equation: E, + data: Data, +) -> Result> { + let parameters = apply_saem_transforms(&config.parameter_space, &config.runtime.tuning.saem); + + let observations = ObservationSpec::new() + .add_channel(ObservationChannel::continuous(0, "obs")) + .with_residual_error_models(config.residual_error); + + let model = 
ModelDefinition::builder(equation) + .parameters(parameters) + .observations(observations) + .build()?; + + EstimationProblem::builder(model, data) + .method(EstimationMethod::Parametric(ParametricMethod::Saem( + SaemOptions, + ))) + .output(config.output) + .runtime(config.runtime) + .build() +} + +// Path to reference data files (relative to tests/ directory) +const VALIDATION_DIR: &str = "tests/saem_validation"; + +// ============================================================================= +// Basic Model Tests +// Verify the model produces correct predictions +// ============================================================================= + +/// Test that the ODE model produces correct predictions +#[test] +fn test_ode_predictions() { + use pmcore::prelude::*; + + println!("=== Testing ODE Predictions ==="); + + // Simple one-compartment IV bolus model + // CRITICAL: b[0] is the bolus input term - must be included! + let eq = equation::ODE::new( + |x, p, _t, dx, b, _rateiv, _cov| { + fetch_params!(p, ke, _v); + dx[0] = -ke * x[0] + b[0]; // b[0] is the bolus input + }, + |_p, _t, _cov| lag! {}, + |_p, _t, _cov| fa! {}, + |_p, _t, _cov, _x| {}, + |x, p, _t, _cov, y| { + fetch_params!(p, _ke, v); + y[0] = x[0] / v; + }, + ); + + // Create a simple subject with one bolus and observations + let subject = Subject::builder("test") + .bolus(0.0, 100.0, 0) // 100 units at t=0 into compartment 0 + .observation(1.0, 0.0, 0) // placeholder observation at t=1 + .observation(2.0, 0.0, 0) // placeholder observation at t=2 + .build(); + + // Parameters: ke=0.4, V=10 + let params = vec![0.4, 10.0]; + + // Get predictions + let predictions = eq.estimate_predictions(&subject, ¶ms).unwrap(); + let preds: Vec = predictions + .get_predictions() + .iter() + .map(|p| p.prediction()) + .collect(); + + println!(" Predictions: {:?}", preds); + + // Expected: C(t) = (Dose/V) * exp(-ke*t) = (100/10) * exp(-0.4*t) + // C(1) = 10 * exp(-0.4) = 10 * 0.6703 = 6.703 + // C(2) = 10 * exp(-0.8) = 10 * 0.4493 = 4.493 + let expected_c1 = 10.0 * (-0.4_f64).exp(); + let expected_c2 = 10.0 * (-0.8_f64).exp(); + + println!( + " Expected: C(1)={:.4}, C(2)={:.4}", + expected_c1, expected_c2 + ); + + assert!(preds.len() >= 2, "Should have at least 2 predictions"); + assert!( + (preds[0] - expected_c1).abs() < 0.1, + "C(1) mismatch: got {}, expected {}", + preds[0], + expected_c1 + ); + assert!( + (preds[1] - expected_c2).abs() < 0.1, + "C(2) mismatch: got {}, expected {}", + preds[1], + expected_c2 + ); + + println!(" ✓ ODE predictions correct\n"); +} + +// ============================================================================= +// Component-Level Tests +// These verify exact matching of individual components +// ============================================================================= + +/// Test parameter transformations match R exactly +#[test] +fn test_component_transforms() { + use pmcore::prelude::ParameterTransform; + + println!("=== Testing Parameter Transforms ==="); + + // Test log-normal transform (code 1 in saemix) + let transform = ParameterTransform::LogNormal; + + // Test values from R reference + let psi_values: Vec = vec![0.1, 0.5, 1.0, 2.0, 5.0]; + let expected_phi: Vec = psi_values.iter().map(|&x| x.ln()).collect(); + + for (i, &psi) in psi_values.iter().enumerate() { + let phi = transform.psi_to_phi(psi); + let psi_back = transform.phi_to_psi(phi); + + println!( + " psi={:.4} -> phi={:.6} (expected: {:.6}) -> psi_back={:.6}", + psi, phi, expected_phi[i], psi_back + ); + + // Check forward 
transform + assert_close(phi, expected_phi[i], 1e-10, &format!("phi[{}]", i)); + + // Check round-trip + assert_close(psi_back, psi, 1e-10, &format!("psi_roundtrip[{}]", i)); + } + + println!(" ✓ Log-normal transforms match R\n"); + + // Test identity transform (code 0) + let identity = ParameterTransform::None; + for &val in &[-1.0, 0.0, 1.0, 5.0] { + assert_close(identity.psi_to_phi(val), val, 1e-15, "identity_forward"); + assert_close(identity.phi_to_psi(val), val, 1e-15, "identity_inverse"); + } + println!(" ✓ Identity transforms match\n"); +} + +/// Test sufficient statistics computation matches R +#[test] +fn test_component_sufficient_stats() { + use faer::Col; + use pmcore::prelude::SufficientStats; + + println!("=== Testing Sufficient Statistics ==="); + + // Test data from R reference (generate_reference.R) + let samples = vec![vec![1.0, 2.0], vec![3.0, 4.0], vec![5.0, 6.0]]; + + // Expected values from R + let expected_s1 = vec![9.0, 12.0]; + let expected_mu = vec![3.0, 4.0]; + // Omega = S2/n - mu*mu' + // E[X^2] for col 0: (1+9+25)/3 = 35/3 + // Var = 35/3 - 9 = 8/3 ≈ 2.6667 + let expected_omega_00 = 8.0 / 3.0; + + let mut stats = SufficientStats::new(2); + + for sample in &samples { + let phi = Col::from_fn(2, |i| sample[i]); + stats.accumulate(&phi).unwrap(); + } + + // Check S1 + println!( + " S1: [{:.4}, {:.4}] (expected: {:?})", + stats.s1()[0], + stats.s1()[1], + expected_s1 + ); + assert_close(stats.s1()[0], expected_s1[0], 1e-10, "S1[0]"); + assert_close(stats.s1()[1], expected_s1[1], 1e-10, "S1[1]"); + + // Compute M-step + let (mu, omega) = stats.compute_m_step().unwrap(); + + // Check mu + println!( + " mu: [{:.4}, {:.4}] (expected: {:?})", + mu[0], mu[1], expected_mu + ); + assert_close(mu[0], expected_mu[0], 1e-10, "mu[0]"); + assert_close(mu[1], expected_mu[1], 1e-10, "mu[1]"); + + // Check omega diagonal + println!( + " omega[0,0]: {:.6} (expected: {:.6})", + omega[(0, 0)], + expected_omega_00 + ); + assert_close(omega[(0, 0)], expected_omega_00, 1e-10, "omega[0,0]"); + + println!(" ✓ Sufficient statistics match R\n"); +} + +/// Test step size schedule matches R saemix +#[test] +fn test_component_step_size() { + use pmcore::prelude::StepSizeSchedule; + + println!("=== Testing Step Size Schedule ==="); + + // PMcore PolyakRuppert schedule: + // - k < start_averaging: gamma = 1.0 (burn-in/exploration) + // - k >= start_averaging: gamma = 1/(k - start_averaging + 1) (smoothing) + + let n_burn = 100; // Total burn-in iterations + let n_smooth = 200; + let schedule = StepSizeSchedule::new_saem(n_burn, n_smooth); + + // Test values based on PMcore's actual implementation + let test_cases = vec![ + (1, 1.0, "burn-in start"), + (50, 1.0, "burn-in middle"), + (99, 1.0, "last burn-in"), // 99 < 100, so gamma = 1.0 + (100, 1.0, "first smoothing"), // 100 >= 100, gamma = 1/(100-100+1) = 1.0 + (101, 0.5, "second smoothing"), // gamma = 1/(101-100+1) = 0.5 + (102, 1.0 / 3.0, "third smoothing"), // gamma = 1/3 + (200, 0.01, "late smoothing"), // gamma = 1/(200-100+1) = 1/101 ≈ 0.0099 + ]; + + for (iter, expected, desc) in test_cases { + let actual = schedule.step_size(iter); + println!( + " iter {}: gamma={:.6} (expected: {:.6}) - {}", + iter, actual, expected, desc + ); + + // Allow some tolerance for numerical differences + assert_close(actual, expected, 0.01, &format!("step_size({})", iter)); + } + + println!(" ✓ Step size schedule matches PMcore behavior\n"); +} + +// ============================================================================= +// Integration Tests +// 
These test algorithm phases against R +// ============================================================================= + +/// Test SAEM algorithm initialization +#[test] +fn test_saem_initialization() -> Result<()> { + use pmcore::algorithms::parametric::dispatch_parametric_algorithm; + use pmcore::prelude::*; + + println!("=== Testing SAEM Initialization ==="); + + // Create simple model - b[0] is the bolus input term! + let eq = equation::ODE::new( + |x, p, _t, dx, b, _rateiv, _cov| { + fetch_params!(p, ke, _v); + dx[0] = -ke * x[0] + b[0]; + }, + |_p, _t, _cov| lag! {}, + |_p, _t, _cov| fa! {}, + |_p, _t, _cov, _x| {}, + |x, p, _t, _cov, y| { + fetch_params!(p, _ke, v); + y[0] = x[0] / v; + }, + ); + + // Create test data + let subjects: Vec<Subject> = (0..3) + .map(|id| { + Subject::builder(id.to_string()) + .bolus(0.0, 100.0, 0) + .observation(1.0, 5.0, 0) + .observation(4.0, 2.0, 0) + .build() + }) + .collect(); + let data = Data::new(subjects); + + // Parameters - log-normal by default + let parameter_space = bounded_parameter_space(&[("ke", 0.1, 1.0), ("v", 5.0, 20.0)]); + + // Residual error model (parametric algorithms use ResidualErrorModels) + use pharmsol::{ResidualErrorModel, ResidualErrorModels}; + let residual_error = ResidualErrorModels::new().add(0, ResidualErrorModel::constant(0.5)); + + let config = SaemValidationConfig::new(parameter_space, residual_error); + + let mut algorithm = dispatch_parametric_algorithm(build_saem_problem(config, eq, data)?)?; + + // Before initialization, mu is in ψ (natural) space + // Population is initialized with midpoints: ke = (0.1+1.0)/2 = 0.55, v = (5+20)/2 = 12.5 + assert_eq!(algorithm.iteration(), 0, "Initial iteration should be 0"); + assert_eq!(algorithm.population().npar(), 2, "Should have 2 parameters"); + + let mu_ke_pre = algorithm.population().mu()[0]; + let mu_v_pre = algorithm.population().mu()[1]; + println!(" Before init - mu[ke] (ψ space): {:.4}", mu_ke_pre); + println!(" Before init - mu[v] (ψ space): {:.4}", mu_v_pre); + + // These are arithmetic midpoints in ψ space + assert_close(mu_ke_pre, 0.55, 0.01, "mu_ke pre-init"); + assert_close(mu_v_pre, 12.5, 0.01, "mu_v pre-init"); + + // Initialize - this transforms mu from ψ to φ space + algorithm.initialize()?; + + // After initialization, mu is in φ (transformed) space + let mu_ke = algorithm.population().mu()[0]; + let mu_v = algorithm.population().mu()[1]; + + println!(" After init - mu[ke] (φ space): {:.4}", mu_ke); + println!(" After init - mu[v] (φ space): {:.4}", mu_v); + + // For LogNormal: φ = ln(ψ) + // ln(0.55) ≈ -0.598, ln(12.5) ≈ 2.526 + assert!(mu_ke < 0.0, "Log of ke (0.55) should be negative"); + assert!(mu_v > 0.0, "Log of v (12.5) should be positive"); + + println!(" ✓ SAEM initialization correct\n"); + Ok(()) +} + +/// Test SAEM runs multiple iterations without errors +#[test] +fn test_saem_iterations() -> Result<()> { + use pmcore::algorithms::parametric::dispatch_parametric_algorithm; + use pmcore::prelude::*; + + println!("=== Testing SAEM Iterations ==="); + + // Simple model - b[0] is the bolus input term! + let eq = equation::ODE::new( + |x, p, _t, dx, b, _rateiv, _cov| { + fetch_params!(p, ke, _v); + dx[0] = -ke * x[0] + b[0]; + }, + |_p, _t, _cov| lag! {}, + |_p, _t, _cov| fa! 
{}, + |_p, _t, _cov, _x| {}, + |x, p, _t, _cov, y| { + fetch_params!(p, _ke, v); + y[0] = x[0] / v; + }, + ); + + // Create synthetic data with known parameters + let true_ke = 0.4; + let true_v = 10.0; + let dose = 100.0; + + let subjects: Vec<Subject> = (0..10) + .map(|id| { + let eta_ke = 0.1 * ((id as f64 * 0.5).sin()); + let eta_v = 0.1 * ((id as f64 * 0.7).cos()); + let subj_ke = true_ke * f64::exp(eta_ke); + let subj_v = true_v * f64::exp(eta_v); + + Subject::builder(id.to_string()) + .bolus(0.0, dose, 0) + .observation(1.0, (dose / subj_v) * f64::exp(-subj_ke * 1.0), 0) + .observation(2.0, (dose / subj_v) * f64::exp(-subj_ke * 2.0), 0) + .observation(4.0, (dose / subj_v) * f64::exp(-subj_ke * 4.0), 0) + .build() + }) + .collect(); + let data = Data::new(subjects); + + let parameter_space = bounded_parameter_space(&[("ke", 0.1, 0.8), ("v", 5.0, 15.0)]); + + // Residual error model (parametric algorithms use ResidualErrorModels) + use pharmsol::{ResidualErrorModel, ResidualErrorModels}; + let residual_error = ResidualErrorModels::new().add(0, ResidualErrorModel::constant(0.1)); + + let config = SaemValidationConfig::new(parameter_space, residual_error); + + let mut algorithm = dispatch_parametric_algorithm(build_saem_problem(config, eq, data)?)?; + algorithm.initialize()?; + + // Run enough iterations to get past burn-in (default: 20 pure burn + 80 SA = 100) + // We'll run 120 iterations to enter the stochastic approximation phase + let mut prev_objf = f64::INFINITY; + let n_iter = 120; + + for i in 1..=n_iter { + let _status = algorithm.next_iteration()?; + let objf = algorithm.objective_function(); + + if i <= 5 || i == 21 || i == 100 || i == n_iter { + let pop = algorithm.population(); + let ke_psi = f64::exp(pop.mu()[0]); // Convert back to ψ space + let v_psi = f64::exp(pop.mu()[1]); + + println!( + " Iter {}: objf={:.2}, ke={:.4} (true: {}), v={:.4} (true: {})", + i, objf, ke_psi, true_ke, v_psi, true_v + ); + } + + prev_objf = objf; + } + + // After some iterations, objective should be finite + assert!( + prev_objf.is_finite(), + "Objective should be finite after iterations" + ); + + // Parameters should be in reasonable range + let final_pop = algorithm.population(); + let final_ke = f64::exp(final_pop.mu()[0]); + let final_v = f64::exp(final_pop.mu()[1]); + + println!( + "\n Final: ke={:.4} (true: {}), v={:.4} (true: {})", + final_ke, true_ke, final_v, true_v + ); + + // Allow generous tolerance during burn-in + assert!( + final_ke > 0.05 && final_ke < 2.0, + "ke estimate out of reasonable range" + ); + assert!( + final_v > 1.0 && final_v < 50.0, + "v estimate out of reasonable range" + ); + + println!(" ✓ SAEM iterations completed successfully\n"); + Ok(()) +} + +// ============================================================================= +// End-to-End Validation Tests (against R reference) +// ============================================================================= + +/// Test one-compartment IV bolus against R reference +/// Requires: onecomp_iv_reference.json from generate_reference.R +#[test] +fn test_validate_onecomp_iv() -> Result<()> { + use pmcore::prelude::*; + + println!("=== Validating One-Compartment IV vs R Reference ==="); + + // Load R reference + let ref_path = format!("{}/onecomp_iv_reference.json", VALIDATION_DIR); + let reference = load_reference(&ref_path) + .map_err(|e| anyhow::anyhow!("Failed to load reference: {}", e))?; + + println!( + " Reference loaded: {} subjects, {} observations", + reference.n_subjects, reference.n_observations + ); + + // 
Create model (same as R) - b[0] is the bolus input term! + let dose = 100.0; + let eq = equation::ODE::new( + |x, p, _t, dx, b, _rateiv, _cov| { + fetch_params!(p, ke, _v); + dx[0] = -ke * x[0] + b[0]; + }, + |_p, _t, _cov| lag! {}, + |_p, _t, _cov| fa! {}, + |_p, _t, _cov, _x| {}, + |x, p, _t, _cov, y| { + fetch_params!(p, _ke, v); + y[0] = x[0] / v; + }, + ); + + // Generate same data as R (from reference true values) + let true_ke = reference + .true_values + .as_ref() + .and_then(|t| t.ke) + .unwrap_or(0.4); + let true_v = reference + .true_values + .as_ref() + .and_then(|t| t.v) + .unwrap_or(10.0); + let times = vec![0.5, 1.0, 2.0, 4.0, 8.0, 12.0]; + + // Match R's deterministic random effects + let subjects: Vec<Subject> = (0..reference.n_subjects) + .map(|id| { + let eta_ke = 0.15 * ((id as f64) * 0.5).sin(); + let eta_v = 0.15 * ((id as f64) * 0.7).cos(); + let subj_ke = true_ke * f64::exp(eta_ke); + let subj_v = true_v * f64::exp(eta_v); + + let mut builder = Subject::builder(id.to_string()).bolus(0.0, dose, 0); + for &t in &times { + let conc = (dose / subj_v) * f64::exp(-subj_ke * t); + builder = builder.observation(t, conc, 0); + } + builder.build() + }) + .collect(); + let data = Data::new(subjects); + + // Match R settings + let parameter_space = bounded_parameter_space(&[("ke", 0.1, 0.8), ("v", 5.0, 15.0)]); + + // Residual error model (parametric algorithms use ResidualErrorModels) + use pharmsol::{ResidualErrorModel, ResidualErrorModels}; + let residual_error = ResidualErrorModels::new().add(0, ResidualErrorModel::constant(0.1)); + + let config = SaemValidationConfig::new(parameter_space, residual_error); + + // Run algorithm + let fit_result = run_saem_problem(config, eq, data)?; + let result = fit_result + .as_parametric() + .expect("SAEM validation should produce a parametric result"); + + // Compare results + println!("\n === Comparison with R Reference ==="); + + // Population mean (μ is returned in ψ space by into_result()) + let rust_mu_psi: Vec<f64> = (0..result.population().npar()) + .map(|i| result.population().mu()[i]) + .collect(); + + println!(" mu (ψ space, Rust): {:?}", rust_mu_psi); + println!(" mu (ψ space, R): {:?}", reference.mu_psi); + + // Compare with tolerance (comparing ψ values) + let mu_rtol = 0.10; // 10% relative tolerance + for (i, (r, ref_val)) in rust_mu_psi.iter().zip(reference.mu_psi.iter()).enumerate() { + let rel_err = (*r - *ref_val).abs() / ref_val.abs().max(1e-10); + println!( + " mu[{}]: Rust={:.4}, R={:.4}, rel_err={:.2}%", + i, + r, + ref_val, + rel_err * 100.0 + ); + + if rel_err > mu_rtol { + println!(" WARNING: Exceeds {}% tolerance!", mu_rtol * 100.0); + } + } + + // Omega diagonal + let rust_omega_diag: Vec<f64> = (0..result.population().npar()) + .map(|i| result.population().omega()[(i, i)]) + .collect(); + + println!(" omega_diag (Rust): {:?}", rust_omega_diag); + println!(" omega_diag (R): {:?}", reference.omega_diag); + + // Objective function + println!(" objf (Rust): {:.2}", result.objf()); + println!(" objf (R): {:.2}", reference.objf); + + let objf_rel_err = (result.objf() - reference.objf).abs() / reference.objf.abs(); + println!(" objf rel_err: {:.2}%", objf_rel_err * 100.0); + + // Assertions (with generous tolerance for stochastic algorithm) + assert_vec_close(&rust_mu_psi, &reference.mu_psi, 0.20, "mu_psi"); + + println!("\n ✓ One-compartment IV validation complete\n"); + Ok(()) +} + +/// Test theophylline against R reference +/// Requires: theo_reference.json from generate_reference.R +#[test] +#[ignore = "Full theophylline validation 
(~5 min) - run with --ignored"] +fn test_validate_theophylline() -> Result<()> { + use pmcore::prelude::*; + + println!("=== Validating Theophylline vs R Reference ==="); + + // Load R reference + let ref_path = format!("{}/theo_reference.json", VALIDATION_DIR); + let reference = load_reference(&ref_path) + .map_err(|e| anyhow::anyhow!("Failed to load reference: {}", e))?; + + println!( + " Reference: {} subjects, {} observations", + reference.n_subjects, reference.n_observations + ); + println!(" R results - mu_psi: {:?}", reference.mu_psi); + println!(" R results - omega_diag: {:?}", reference.omega_diag); + println!(" R results - sigma: {:.4}", reference.sigma); + println!(" R results - objf: {:.2}", reference.objf); + + // One-compartment model with first-order absorption + // Matches R saemix model: dose*ka/(V*(ka-k))*(exp(-k*t) - exp(-ka*t)) + let eq = equation::ODE::new( + |x, p, _t, dx, b, _rateiv, _cov| { + fetch_params!(p, ka, v, cl); + let ke = cl / v; + dx[0] = -ka * x[0] + b[0]; // absorption compartment + bolus + dx[1] = ka * x[0] - ke * x[1]; // central compartment + }, + |_p, _t, _cov| lag! {}, + |_p, _t, _cov| fa! {}, + |_p, _t, _cov, _x| {}, + |x, p, _t, _cov, y| { + fetch_params!(p, _ka, v, _cl); + y[0] = x[1] / v; + }, + ); + + // Theophylline data (12 subjects, same as saemix theo.saemix dataset) + let subjects_data: Vec<(u32, f64, Vec, Vec)> = vec![ + ( + 1, + 319.992, + vec![0.25, 0.57, 1.12, 2.02, 3.82, 5.10, 7.03, 9.05, 12.12, 24.37], + vec![2.84, 6.57, 10.50, 9.66, 8.58, 8.36, 7.47, 6.89, 5.94, 3.28], + ), + ( + 2, + 318.560, + vec![0.27, 0.52, 1.00, 1.92, 3.50, 5.02, 7.03, 9.00, 12.00, 24.30], + vec![1.72, 7.91, 8.31, 8.33, 6.85, 6.08, 5.40, 4.55, 3.01, 0.90], + ), + ( + 3, + 319.365, + vec![0.27, 0.58, 1.02, 2.02, 3.62, 5.08, 7.07, 9.00, 12.15, 24.17], + vec![4.40, 6.90, 8.20, 7.80, 7.50, 6.20, 5.30, 4.90, 3.70, 1.05], + ), + ( + 4, + 319.992, + vec![0.35, 0.60, 1.07, 2.13, 3.50, 5.02, 7.02, 9.02, 11.98, 24.65], + vec![1.89, 4.60, 8.60, 8.38, 7.54, 6.88, 5.78, 5.33, 4.19, 1.15], + ), + ( + 5, + 320.619, + vec![0.30, 0.52, 1.00, 2.02, 3.50, 5.02, 7.02, 9.10, 12.00, 24.35], + vec![2.02, 5.63, 11.40, 9.33, 8.74, 7.56, 7.09, 5.90, 4.37, 1.57], + ), + ( + 6, + 320.619, + vec![0.27, 0.58, 1.15, 2.03, 3.57, 5.00, 7.00, 9.22, 12.10, 23.85], + vec![1.29, 3.08, 6.44, 6.32, 5.53, 4.94, 4.02, 3.46, 2.78, 0.92], + ), + ( + 7, + 277.767, + vec![0.25, 0.50, 1.02, 2.02, 3.48, 5.00, 6.98, 9.00, 12.05, 24.22], + vec![3.59, 6.11, 7.56, 6.54, 5.37, 4.84, 4.02, 3.83, 2.81, 0.85], + ), + ( + 8, + 276.514, + vec![0.25, 0.52, 0.98, 2.02, 3.53, 5.05, 7.15, 9.07, 12.10, 24.12], + vec![0.73, 4.00, 6.81, 8.00, 7.09, 5.89, 5.22, 4.75, 3.41, 0.96], + ), + ( + 9, + 299.550, + vec![0.30, 0.63, 1.05, 2.02, 3.53, 5.02, 7.17, 8.80, 11.60, 24.43], + vec![3.15, 6.96, 9.70, 9.52, 7.17, 6.28, 5.28, 4.66, 3.82, 1.15], + ), + ( + 10, + 298.297, + vec![0.37, 0.77, 1.02, 2.05, 3.55, 5.05, 7.08, 9.00, 12.12, 24.08], + vec![7.37, 9.03, 10.21, 9.18, 8.02, 7.14, 6.08, 5.54, 4.57, 1.17], + ), + ( + 11, + 300.176, + vec![0.25, 0.50, 0.98, 1.98, 3.60, 5.02, 7.03, 9.03, 12.12, 24.28], + vec![0.92, 2.63, 6.85, 9.05, 7.90, 7.44, 6.13, 5.31, 4.10, 1.44], + ), + ( + 12, + 298.297, + vec![0.25, 0.52, 1.00, 2.07, 3.50, 4.95, 7.00, 9.02, 12.00, 24.15], + vec![1.11, 6.33, 9.99, 9.37, 8.50, 6.89, 5.94, 5.26, 4.35, 1.25], + ), + ]; + + let subjects: Vec = subjects_data + .into_iter() + .map(|(id, dose, times, concs)| { + let mut builder = Subject::builder(id.to_string()).bolus(0.0, dose, 0); + for (t, c) in 
times.into_iter().zip(concs.into_iter()) { + builder = builder.observation(t, c, 0); + } + builder.build() + }) + .collect(); + let data = Data::new(subjects); + + // Parameter ranges (matching R saemix initial values region) + let parameter_space = + bounded_parameter_space(&[("ka", 0.5, 3.0), ("v", 15.0, 50.0), ("cl", 0.5, 5.0)]); + + // Residual error model (constant, matching R) + use pharmsol::{ResidualErrorModel, ResidualErrorModels}; + let residual_error = ResidualErrorModels::new().add(0, ResidualErrorModel::constant(1.0)); + + let config = SaemValidationConfig::new(parameter_space, residual_error); + + // Run SAEM + let fit_result = run_saem_problem(config, eq, data)?; + let result = fit_result + .as_parametric() + .expect("SAEM validation should produce a parametric result"); + + // Compare results + println!("\n === Comparison with R Reference ==="); + + let rust_mu_psi: Vec<f64> = (0..result.population().npar()) + .map(|i| result.population().mu()[i]) + .collect(); + + println!(" mu (ψ space, Rust): {:?}", rust_mu_psi); + println!(" mu (ψ space, R): {:?}", reference.mu_psi); + + // Compare each parameter + let param_names = ["ka", "V", "CL"]; + for (i, (r, ref_val)) in rust_mu_psi.iter().zip(reference.mu_psi.iter()).enumerate() { + let rel_err = (*r - *ref_val).abs() / ref_val.abs().max(1e-10); + println!( + " {}: Rust={:.4}, R={:.4}, rel_err={:.2}%", + param_names[i], + r, + ref_val, + rel_err * 100.0 + ); + } + + // Omega diagonal + let rust_omega_diag: Vec<f64> = (0..result.population().npar()) + .map(|i| result.population().omega()[(i, i)]) + .collect(); + println!(" omega_diag (Rust): {:?}", rust_omega_diag); + println!(" omega_diag (R): {:?}", reference.omega_diag); + + // Objective function + println!(" objf (Rust): {:.2}", result.objf()); + println!(" objf (R): {:.2}", reference.objf); + + // Assertions: population means should be within 25% for a 3-parameter stochastic algorithm + let tolerance = 0.25; + assert_vec_close( + &rust_mu_psi, + &reference.mu_psi, + tolerance, + "mu_psi (theophylline)", + ); + + println!("\n ✓ Theophylline validation complete\n"); + Ok(()) +} + +// ============================================================================= +// Helper Tests +// ============================================================================= + +/// Verify test data files exist +#[test] +fn test_validation_directory_exists() { + let path = std::path::Path::new(VALIDATION_DIR); + if !path.exists() { + println!("Validation directory does not exist: {}", VALIDATION_DIR); + println!("Run from PMcore root directory"); + } +} diff --git a/tests/saem_validation/theo_data.csv b/tests/saem_validation/theo_data.csv new file mode 100644 index 000000000..1d9ca612a --- /dev/null +++ b/tests/saem_validation/theo_data.csv @@ -0,0 +1,121 @@ +"ID","TIME","DOSE","DV","Weight","Sex" +1,0.25,319.992,2.84,79.6,1 +1,0.57,319.992,6.57,79.6,1 +1,1.12,319.992,10.5,79.6,1 +1,2.02,319.992,9.66,79.6,1 +1,3.82,319.992,8.58,79.6,1 +1,5.1,319.992,8.36,79.6,1 +1,7.03,319.992,7.47,79.6,1 +1,9.05,319.992,6.89,79.6,1 +1,12.12,319.992,5.94,79.6,1 +1,24.37,319.992,3.28,79.6,1 +2,0.27,318.56,1.72,72.4,1 +2,0.52,318.56,7.91,72.4,1 +2,1,318.56,8.31,72.4,1 +2,1.92,318.56,8.33,72.4,1 +2,3.5,318.56,6.85,72.4,1 +2,5.02,318.56,6.08,72.4,1 +2,7.03,318.56,5.4,72.4,1 +2,9,318.56,4.55,72.4,1 +2,12,318.56,3.01,72.4,1 +2,24.3,318.56,0.9,72.4,1 +3,0.27,319.365,4.4,70.5,1 +3,0.58,319.365,6.9,70.5,1 +3,1.02,319.365,8.2,70.5,1 +3,2.02,319.365,7.8,70.5,1 +3,3.62,319.365,7.5,70.5,1 +3,5.08,319.365,6.2,70.5,1 
+3,7.07,319.365,5.3,70.5,1 +3,9,319.365,4.9,70.5,1 +3,12.15,319.365,3.7,70.5,1 +3,24.17,319.365,1.05,70.5,1 +4,0.35,319.88,1.89,72.7,1 +4,0.6,319.88,4.6,72.7,1 +4,1.07,319.88,8.6,72.7,1 +4,2.13,319.88,8.38,72.7,1 +4,3.5,319.88,7.54,72.7,1 +4,5.02,319.88,6.88,72.7,1 +4,7.02,319.88,5.78,72.7,1 +4,9.02,319.88,5.33,72.7,1 +4,11.98,319.88,4.19,72.7,1 +4,24.65,319.88,1.15,72.7,1 +5,0.3,319.956,2.02,54.6,0 +5,0.52,319.956,5.63,54.6,0 +5,1,319.956,11.4,54.6,0 +5,2.02,319.956,9.33,54.6,0 +5,3.5,319.956,8.74,54.6,0 +5,5.02,319.956,7.56,54.6,0 +5,7.02,319.956,7.09,54.6,0 +5,9.1,319.956,5.9,54.6,0 +5,12,319.956,4.37,54.6,0 +5,24.35,319.956,1.57,54.6,0 +6,0.27,320,1.29,80,1 +6,0.58,320,3.08,80,1 +6,1.15,320,6.44,80,1 +6,2.03,320,6.32,80,1 +6,3.57,320,5.53,80,1 +6,5,320,4.94,80,1 +6,7,320,4.02,80,1 +6,9.22,320,3.46,80,1 +6,12.1,320,2.78,80,1 +6,23.85,320,0.92,80,1 +7,0.25,319.77,0.85,64.6,0 +7,0.5,319.77,2.35,64.6,0 +7,1.02,319.77,5.02,64.6,0 +7,2.02,319.77,6.58,64.6,0 +7,3.48,319.77,7.09,64.6,0 +7,5,319.77,6.66,64.6,0 +7,6.98,319.77,5.25,64.6,0 +7,9,319.77,4.39,64.6,0 +7,12.05,319.77,3.53,64.6,0 +7,24.22,319.77,1.15,64.6,0 +8,0.25,319.365,3.05,70.5,1 +8,0.52,319.365,3.05,70.5,1 +8,0.98,319.365,7.31,70.5,1 +8,2.02,319.365,7.56,70.5,1 +8,3.53,319.365,6.59,70.5,1 +8,5.05,319.365,5.88,70.5,1 +8,7.15,319.365,4.73,70.5,1 +8,9.07,319.365,4.57,70.5,1 +8,12.1,319.365,3,70.5,1 +8,24.12,319.365,1.25,70.5,1 +9,0.3,267.84,7.37,86.4,1 +9,0.63,267.84,9.03,86.4,1 +9,1.05,267.84,7.14,86.4,1 +9,2.02,267.84,6.33,86.4,1 +9,3.53,267.84,5.66,86.4,1 +9,5.02,267.84,5.67,86.4,1 +9,7.17,267.84,4.24,86.4,1 +9,8.8,267.84,4.11,86.4,1 +9,11.6,267.84,3.16,86.4,1 +9,24.43,267.84,1.12,86.4,1 +10,0.37,320.1,2.89,58.2,0 +10,0.77,320.1,5.22,58.2,0 +10,1.02,320.1,6.41,58.2,0 +10,2.05,320.1,7.83,58.2,0 +10,3.55,320.1,10.21,58.2,0 +10,5.05,320.1,9.18,58.2,0 +10,7.08,320.1,8.02,58.2,0 +10,9.38,320.1,7.14,58.2,0 +10,12.1,320.1,5.68,58.2,0 +10,23.7,320.1,2.42,58.2,0 +11,0.25,319.8,4.86,65,0 +11,0.5,319.8,7.24,65,0 +11,0.98,319.8,8,65,0 +11,1.98,319.8,6.81,65,0 +11,3.6,319.8,5.87,65,0 +11,5.02,319.8,5.22,65,0 +11,7.03,319.8,4.45,65,0 +11,9.03,319.8,3.62,65,0 +11,12.12,319.8,2.69,65,0 +11,24.08,319.8,0.86,65,0 +12,0.25,320.65,1.25,60.5,0 +12,0.5,320.65,3.96,60.5,0 +12,1,320.65,7.82,60.5,0 +12,2,320.65,9.72,60.5,0 +12,3.52,320.65,9.75,60.5,0 +12,5.07,320.65,8.57,60.5,0 +12,7.07,320.65,6.59,60.5,0 +12,9.03,320.65,6.11,60.5,0 +12,12.05,320.65,4.57,60.5,0 +12,24.15,320.65,1.17,60.5,0 diff --git a/tests/saem_validation/theo_reference.json b/tests/saem_validation/theo_reference.json new file mode 100644 index 000000000..7eb210d3c --- /dev/null +++ b/tests/saem_validation/theo_reference.json @@ -0,0 +1,70 @@ +{ + "test_case": "theophylline", + "description": "One-compartment oral absorption (ka, V, CL)", + "mu_psi": [1.5752, 31.5399, 2.7479], + "mu_phi": [0.4544, 3.4513, 1.0108], + "omega": [ + [0.4026, 0, 0], + [0, 0.0188, 0], + [0, 0, 0.0684] + ], + "omega_diag": [0.4026, 0.0188, 0.0684], + "sigma": 0.7338, + "ll_lin": -172.0081, + "objf": 344.0162, + "map_psi": [ + [1.7268, 28.9774, 1.7266], + [1.9312, 31.9102, 3.1752], + [2.2839, 33.4085, 2.8363], + [1.198, 31.3477, 2.7048], + [1.5255, 27.5808, 2.3695], + [1.068, 38.2357, 3.9864], + [0.7339, 33.503, 3.1823], + [1.3354, 34.7554, 3.2453], + [6.0641, 31.7074, 2.8831], + [0.768, 26.838, 1.8832], + [3.3267, 36.2483, 3.665], + [0.9526, 26.1025, 2.423] + ], + "map_eta": [ + [0.0918, -0.0847, -0.4647], + [0.2037, 0.0117, 0.1445], + [0.3715, 0.0576, 0.0317], + [-0.2738, -0.0061, -0.0158], + [-0.0321, 
-0.1341, -0.1481], + [-0.3886, 0.1925, 0.372], + [-0.7637, 0.0604, 0.1468], + [-0.1652, 0.0971, 0.1664], + [1.348, 0.0053, 0.048], + [-0.7184, -0.1614, -0.3779], + [0.7476, 0.1391, 0.288], + [-0.5029, -0.1892, -0.1258] + ], + "cond_mean_phi": [ + [0.5525, 3.3671, 0.5421], + [0.6603, 3.4682, 1.1381], + [0.832, 3.514, 1.0354], + [0.1592, 3.4388, 0.9982], + [0.4227, 3.3147, 0.8644], + [0.0692, 3.6485, 1.3715], + [-0.3144, 3.5112, 1.1502], + [0.3001, 3.5551, 1.1645], + [1.8878, 3.4633, 1.0583], + [-0.2652, 3.2872, 0.6233], + [1.2058, 3.5915, 1.2913], + [-0.0572, 3.2556, 0.8927] + ], + "settings": { + "seed": 12345, + "n_burn": 50, + "n_sa": 300, + "n_smooth": 100, + "n_chains": 3, + "transform_par": [1, 1, 1], + "error_model": "constant", + "initial_psi": [1.5, 32, 3], + "initial_omega_diag": [0.5, 0.05, 0.1] + }, + "n_subjects": 12, + "n_observations": 120 +} diff --git a/tests/saem_validation_tests.rs b/tests/saem_validation_tests.rs new file mode 100644 index 000000000..a7c80cf30 --- /dev/null +++ b/tests/saem_validation_tests.rs @@ -0,0 +1,7 @@ +//! SAEM Validation Tests +//! +//! Tests that compare PMcore SAEM against R saemix reference values. +//! Run with: cargo test --test saem_validation_tests + +mod saem_validation; +pub use saem_validation::*; From 886881fb84cd2acfe32c8ed6038b6ca4cc8c6d8e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Juli=C3=A1n=20D=2E=20Ot=C3=A1lvaro?= Date: Thu, 2 Apr 2026 10:26:50 +0100 Subject: [PATCH 2/2] style: run cargo fmt --- src/api/mod.rs | 4 ++-- src/lib.rs | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/api/mod.rs b/src/api/mod.rs index fe7806c5b..af7787464 100644 --- a/src/api/mod.rs +++ b/src/api/mod.rs @@ -6,8 +6,8 @@ pub mod saem_config; pub use estimation_problem::{ AlgorithmTuning, ConvergenceOptions, EstimationMethod, EstimationProblem, EstimationProblemBuilder, FoceiOptions, It2bOptions, LoggingLevel, LoggingOptions, - NonparametricMethod, NpagOptions, NpodOptions, OutputPlan, - ParametricMethod, PostProbOptions, RuntimeOptions, SaemOptions, + NonparametricMethod, NpagOptions, NpodOptions, OutputPlan, ParametricMethod, PostProbOptions, + RuntimeOptions, SaemOptions, }; pub use fit::fit; pub use model_definition::{ModelDefinition, ModelDefinitionBuilder}; diff --git a/src/lib.rs b/src/lib.rs index f2b5ae582..a93d0953d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -70,8 +70,8 @@ pub mod prelude { pub use crate::api::{ AlgorithmTuning, ConvergenceOptions, EstimationMethod, EstimationProblem, FoceiOptions, It2bOptions, LoggingLevel, LoggingOptions, ModelDefinition, NonparametricMethod, - NpagOptions, NpodOptions, OutputPlan, - ParametricMethod, PostProbOptions, RuntimeOptions, SaemOptions, + NpagOptions, NpodOptions, OutputPlan, ParametricMethod, PostProbOptions, RuntimeOptions, + SaemOptions, }; pub use crate::compile::{CompiledProblem, DesignContext, ObservationIndex}; pub use crate::estimation::nonparametric::{