diff --git a/examples/bimodal_ke/main.rs b/examples/bimodal_ke/main.rs
index 0ea9bb68c..bbb0a95fc 100644
--- a/examples/bimodal_ke/main.rs
+++ b/examples/bimodal_ke/main.rs
@@ -46,5 +46,9 @@ fn main() -> Result<()> {
     let mut result = algorithm.fit()?;
     result.write_outputs()?;
 
+    if let Some(m) = result.metrics() {
+        println!("{}", m);
+    }
+
     Ok(())
 }
diff --git a/src/routines/output/mod.rs b/src/routines/output/mod.rs
index 3835904d2..3c5d068fc 100644
--- a/src/routines/output/mod.rs
+++ b/src/routines/output/mod.rs
@@ -483,6 +483,24 @@ impl NPResult {
         tracing::debug!("Covariates written to {:?}", &outputfile.relative_path());
         Ok(())
     }
+
+    /// Get a reference to the predictions, if they have been calculated
+    pub fn predictions(&self) -> Option<&NPPredictions> {
+        self.predictions.as_ref()
+    }
+
+    /// Compute prediction metrics, calculating predictions first if needed
+    ///
+    /// Uses the `idelta` and `tad` values from the current [Settings].
+    /// Returns `None` if there are no valid observation-prediction pairs.
+    pub fn metrics(&mut self) -> Option<PredictionMetrics> {
+        if self.predictions.is_none() {
+            let idelta = self.settings.predictions().idelta;
+            let tad = self.settings.predictions().tad;
+            self.calculate_predictions(idelta, tad).ok()?;
+        }
+        self.predictions.as_ref().and_then(|p| p.metrics())
+    }
 }
 
 pub(crate) fn median(data: &[f64]) -> f64 {
diff --git a/src/routines/output/predictions.rs b/src/routines/output/predictions.rs
index 64f78e41b..f9c969688 100644
--- a/src/routines/output/predictions.rs
+++ b/src/routines/output/predictions.rs
@@ -237,4 +237,407 @@ impl NPPredictions {
 
         Ok(container)
     }
+
+    /// Compute prediction performance metrics for all prediction types
+    ///
+    /// Only uncensored observations (`Censor::None`) with a non-`None` observed value are included.
+    /// Returns `None` if there are no valid observation-prediction pairs.
+    pub fn metrics(&self) -> Option<PredictionMetrics> {
+        let cap = self.predictions.len();
+        let mut obs_vals = Vec::with_capacity(cap);
+        let mut pop_mean_vals = Vec::with_capacity(cap);
+        let mut pop_median_vals = Vec::with_capacity(cap);
+        let mut post_mean_vals = Vec::with_capacity(cap);
+        let mut post_median_vals = Vec::with_capacity(cap);
+        let mut subject_ids = std::collections::HashSet::new();
+
+        for row in &self.predictions {
+            if row.cens != Censor::None {
+                continue;
+            }
+            if let Some(o) = row.obs {
+                obs_vals.push(o);
+                pop_mean_vals.push(row.pop_mean);
+                pop_median_vals.push(row.pop_median);
+                post_mean_vals.push(row.post_mean);
+                post_median_vals.push(row.post_median);
+                subject_ids.insert(row.id.clone());
+            }
+        }
+
+        if obs_vals.is_empty() {
+            return None;
+        }
+
+        Some(PredictionMetrics {
+            n_subjects: subject_ids.len(),
+            pop_mean: ErrorMetrics::compute(&obs_vals, &pop_mean_vals),
+            pop_median: ErrorMetrics::compute(&obs_vals, &pop_median_vals),
+            post_mean: ErrorMetrics::compute(&obs_vals, &post_mean_vals),
+            post_median: ErrorMetrics::compute(&obs_vals, &post_median_vals),
+        })
+    }
 }
+
+/// Metrics for a single prediction type (e.g. population mean, posterior median)
+///
+/// Percentage metrics (`bias_pct`, `imprecision_pct`, `rmse_pct`) are computed only
+/// for observation-prediction pairs where obs > 0.
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct ErrorMetrics {
+    /// Number of observations used
+    pub n: usize,
+    /// Number of observations excluded from percentage metrics (obs <= 0)
+    pub n_excluded: usize,
+    /// Bias: mean(pred - obs)
+    pub bias: f64,
+    /// Imprecision: standard deviation of (pred - obs)
+    pub imprecision: f64,
+    /// Root mean squared error: sqrt(mean((pred - obs)²))
+    pub rmse: f64,
+    /// Coefficient of determination (R²)
+    pub r_squared: f64,
+    /// Relative bias (%): mean((pred - obs) / obs) * 100, for obs > 0
+    pub bias_pct: f64,
+    /// Relative imprecision (%): SD of ((pred - obs) / obs) * 100, for obs > 0
+    pub imprecision_pct: f64,
+    /// Relative RMSE (%): sqrt(mean(((pred - obs) / obs)²)) * 100, for obs > 0
+    pub rmse_pct: f64,
+}
+
+impl ErrorMetrics {
+    /// Compute error metrics from paired observations and predictions.
+    ///
+    /// Percentage metrics only include pairs where obs > 0.
+    fn compute(obs: &[f64], pred: &[f64]) -> Self {
+        let n = obs.len();
+        assert_eq!(n, pred.len());
+
+        if n == 0 {
+            return ErrorMetrics {
+                n: 0,
+                n_excluded: 0,
+                bias: f64::NAN,
+                imprecision: f64::NAN,
+                rmse: f64::NAN,
+                r_squared: f64::NAN,
+                bias_pct: f64::NAN,
+                imprecision_pct: f64::NAN,
+                rmse_pct: f64::NAN,
+            };
+        }
+
+        let nf = n as f64;
+        let mut sum_err = 0.0;
+        let mut sum_sq_err = 0.0;
+        let mut rel_errors: Vec<f64> = Vec::new();
+
+        for (&o, &p) in obs.iter().zip(pred.iter()) {
+            let err = p - o;
+            sum_err += err;
+            sum_sq_err += err * err;
+            if o > 0.0 {
+                rel_errors.push(err / o);
+            }
+        }
+
+        let bias = sum_err / nf;
+        let imprecision = (sum_sq_err / nf - bias * bias).max(0.0).sqrt();
+        let rmse = (sum_sq_err / nf).sqrt();
+
+        // R²: 1 - SS_res / SS_tot
+        let obs_mean = obs.iter().sum::<f64>() / nf;
+        let ss_tot: f64 = obs.iter().map(|&o| (o - obs_mean).powi(2)).sum();
+        let r_squared = if ss_tot > 0.0 {
+            1.0 - sum_sq_err / ss_tot
+        } else {
+            f64::NAN
+        };
+
+        let (bias_pct, imprecision_pct, rmse_pct) = if !rel_errors.is_empty() {
+            let n_rel = rel_errors.len() as f64;
+            let mean_rel: f64 = rel_errors.iter().sum::<f64>() / n_rel;
+            let mean_sq_rel: f64 = rel_errors.iter().map(|e| e * e).sum::<f64>() / n_rel;
+            let var_rel = (mean_sq_rel - mean_rel * mean_rel).max(0.0);
+            (
+                mean_rel * 100.0,
+                var_rel.sqrt() * 100.0,
+                mean_sq_rel.sqrt() * 100.0,
+            )
+        } else {
+            (f64::NAN, f64::NAN, f64::NAN)
+        };
+
+        let n_excluded = n - rel_errors.len();
+
+        ErrorMetrics {
+            n,
+            n_excluded,
+            bias,
+            imprecision,
+            rmse,
+            r_squared,
+            bias_pct,
+            imprecision_pct,
+            rmse_pct,
+        }
+    }
+}
+
+/// Prediction performance metrics for all prediction types
+///
+/// Contains [ErrorMetrics] for each of the four prediction types computed by `NPPredictions`:
+/// population mean, population median, posterior mean, and posterior median.
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct PredictionMetrics {
+    /// Number of unique subjects included
+    pub n_subjects: usize,
+    /// Metrics for population mean predictions
+    pub pop_mean: ErrorMetrics,
+    /// Metrics for population median predictions
+    pub pop_median: ErrorMetrics,
+    /// Metrics for posterior mean predictions
+    pub post_mean: ErrorMetrics,
+    /// Metrics for posterior median predictions
+    pub post_median: ErrorMetrics,
+}
+
+impl std::fmt::Display for PredictionMetrics {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        let m = [
+            &self.pop_mean,
+            &self.pop_median,
+            &self.post_mean,
+            &self.post_median,
+        ];
+        let w = 14; // column width
+
+        write!(
+            f,
+            "Prediction Metrics ({} subjects, {} observations",
+            self.n_subjects, m[0].n
+        )?;
+        if m[0].n_excluded > 0 {
+            write!(
+                f,
+                ", {} with obs <= 0 excluded from relative metrics",
+                m[0].n_excluded
+            )?;
+        }
+        writeln!(f, ")")?;
+        writeln!(
+            f,
+            "{:<16}{:>w$}{:>w$}{:>w$}{:>w$}",
+            "", "Pop. Mean", "Pop. Median", "Post. Mean", "Post. Median",
+        )?;
+
+        writeln!(
+            f,
+            "{:<16}{:>w$.4}{:>w$.4}{:>w$.4}{:>w$.4}",
+            "Bias", m[0].bias, m[1].bias, m[2].bias, m[3].bias,
+        )?;
+        writeln!(
+            f,
+            "{:<16}{:>w$.4}{:>w$.4}{:>w$.4}{:>w$.4}",
+            "Imprecision", m[0].imprecision, m[1].imprecision, m[2].imprecision, m[3].imprecision,
+        )?;
+        writeln!(
+            f,
+            "{:<16}{:>w$.4}{:>w$.4}{:>w$.4}{:>w$.4}",
+            "RMSE", m[0].rmse, m[1].rmse, m[2].rmse, m[3].rmse,
+        )?;
+        writeln!(
+            f,
+            "{:<16}{:>w$.4}{:>w$.4}{:>w$.4}{:>w$.4}",
+            "R²", m[0].r_squared, m[1].r_squared, m[2].r_squared, m[3].r_squared,
+        )?;
+
+        // Percentage metrics
+        writeln!(
+            f,
+            "{:<16}{:>w$.2}{:>w$.2}{:>w$.2}{:>w$.2}",
+            "Bias%", m[0].bias_pct, m[1].bias_pct, m[2].bias_pct, m[3].bias_pct,
+        )?;
+        writeln!(
+            f,
+            "{:<16}{:>w$.2}{:>w$.2}{:>w$.2}{:>w$.2}",
+            "Imprecision%",
+            m[0].imprecision_pct,
+            m[1].imprecision_pct,
+            m[2].imprecision_pct,
+            m[3].imprecision_pct,
+        )?;
+        write!(
+            f,
+            "{:<16}{:>w$.2}{:>w$.2}{:>w$.2}{:>w$.2}",
+            "RMSE%", m[0].rmse_pct, m[1].rmse_pct, m[2].rmse_pct, m[3].rmse_pct,
+        )
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    /// Helper: build an NPPredictions from raw tuples.
+    /// Each tuple is (id, obs, cens, pop_mean, pop_median, post_mean, post_median).
+    fn make_predictions(rows: &[(&str, Option<f64>, Censor, f64, f64, f64, f64)]) -> NPPredictions {
+        let mut preds = NPPredictions::new();
+        for (id, obs, cens, pm, pmed, ptm, ptmed) in rows {
+            preds.add(NPPredictionRow {
+                id: id.to_string(),
+                time: 0.0,
+                outeq: 1,
+                block: 1,
+                obs: *obs,
+                cens: *cens,
+                pop_mean: *pm,
+                pop_median: *pmed,
+                post_mean: *ptm,
+                post_median: *ptmed,
+            });
+        }
+        preds
+    }
+
+    // ── ErrorMetrics::compute ─────────────────────────────────────────────
+
+    #[test]
+    fn error_metrics_empty_input() {
+        let m = ErrorMetrics::compute(&[], &[]);
+        assert_eq!(m.n, 0);
+        assert_eq!(m.n_excluded, 0);
+        assert!(m.bias.is_nan());
+        assert!(m.rmse.is_nan());
+    }
+
+    #[test]
+    fn error_metrics_perfect_predictions() {
+        let obs = vec![1.0, 2.0, 3.0];
+        let pred = vec![1.0, 2.0, 3.0];
+        let m = ErrorMetrics::compute(&obs, &pred);
+
+        assert_eq!(m.n, 3);
+        assert_eq!(m.n_excluded, 0);
+        assert!((m.bias).abs() < 1e-12);
+        assert!((m.imprecision).abs() < 1e-12);
+        assert!((m.rmse).abs() < 1e-12);
+        assert!((m.r_squared - 1.0).abs() < 1e-12);
+        assert!((m.bias_pct).abs() < 1e-12);
+        assert!((m.imprecision_pct).abs() < 1e-12);
+        assert!((m.rmse_pct).abs() < 1e-12);
+    }
+
+    #[test]
+    fn error_metrics_constant_offset() {
+        // pred = obs + 1 for all points
+        let obs = vec![2.0, 4.0, 6.0];
+        let pred = vec![3.0, 5.0, 7.0];
+        let m = ErrorMetrics::compute(&obs, &pred);
+
+        assert_eq!(m.n, 3);
+        assert_eq!(m.n_excluded, 0);
+        assert!((m.bias - 1.0).abs() < 1e-12);
+        assert!((m.rmse - 1.0).abs() < 1e-12);
+        // Imprecision (SD of errors) should be 0 since all errors are identical
+        assert!(m.imprecision.abs() < 1e-12);
+    }
+
+    #[test]
+    fn error_metrics_excludes_non_positive_obs_from_pct() {
+        let obs = vec![0.0, -1.0, 2.0, 4.0];
+        let pred = vec![0.5, -0.5, 2.5, 4.5];
+        let m = ErrorMetrics::compute(&obs, &pred);
+
+        assert_eq!(m.n, 4);
+        assert_eq!(m.n_excluded, 2); // obs=0.0 and obs=-1.0
+
+        // Absolute metrics use all 4 pairs
+        assert!((m.bias - 0.5).abs() < 1e-12);
+
+        // Percentage metrics use only obs=2.0 and obs=4.0
+        // rel errors: 0.5/2.0 = 0.25, 0.5/4.0 = 0.125
+        let expected_bias_pct = (0.25 + 0.125) / 2.0 * 100.0;
+        assert!((m.bias_pct - expected_bias_pct).abs() < 1e-10);
+    }
+
+    #[test]
+    fn error_metrics_all_non_positive_obs() {
+        let obs = vec![0.0, -1.0];
+        let pred = vec![0.5, -0.5];
+        let m = ErrorMetrics::compute(&obs, &pred);
+
+        assert_eq!(m.n, 2);
+        assert_eq!(m.n_excluded, 2);
+        assert!(m.bias_pct.is_nan());
+        assert!(m.imprecision_pct.is_nan());
+        assert!(m.rmse_pct.is_nan());
+    }
+
+    #[test]
+    fn error_metrics_r_squared_constant_obs() {
+        // All obs the same → SS_tot = 0 → R² = NaN
+        let obs = vec![5.0, 5.0, 5.0];
+        let pred = vec![5.0, 6.0, 4.0];
+        let m = ErrorMetrics::compute(&obs, &pred);
+        assert!(m.r_squared.is_nan());
+    }
+
+    #[test]
+    fn metrics_returns_none_when_no_observations() {
+        let preds = make_predictions(&[("S1", None, Censor::None, 1.0, 1.0, 1.0, 1.0)]);
+        assert!(preds.metrics().is_none());
+    }
+
+    #[test]
+    fn metrics_skips_censored_rows() {
+        let preds = make_predictions(&[("S1", Some(10.0), Censor::BLOQ, 10.0, 10.0, 10.0, 10.0)]);
+        assert!(preds.metrics().is_none());
+    }
+
+    #[test]
+    fn metrics_counts_subjects() {
+        let preds = make_predictions(&[
+            ("S1", Some(1.0), Censor::None, 1.0, 1.0, 1.0, 1.0),
+            ("S1", Some(2.0), Censor::None, 2.0, 2.0, 2.0, 2.0),
+            ("S2", Some(3.0), Censor::None, 3.0, 3.0, 3.0, 3.0),
+        ]);
+        let m = preds.metrics().unwrap();
+        assert_eq!(m.n_subjects, 2);
+        assert_eq!(m.pop_mean.n, 3);
+    }
+
+    #[test]
+    fn metrics_routes_predictions_correctly() {
+        // Use distinct prediction values per type so we can verify routing
+        let preds = make_predictions(&[("S1", Some(10.0), Censor::None, 11.0, 12.0, 13.0, 14.0)]);
+        let m = preds.metrics().unwrap();
+
+        assert!((m.pop_mean.bias - 1.0).abs() < 1e-12);
+        assert!((m.pop_median.bias - 2.0).abs() < 1e-12);
+        assert!((m.post_mean.bias - 3.0).abs() < 1e-12);
+        assert!((m.post_median.bias - 4.0).abs() < 1e-12);
+    }
+
+    #[test]
+    fn display_header_without_exclusions() {
+        let preds = make_predictions(&[("S1", Some(1.0), Censor::None, 1.0, 1.0, 1.0, 1.0)]);
+        let m = preds.metrics().unwrap();
+        let output = format!("{}", m);
+        assert!(output.contains("1 subjects"));
+        assert!(output.contains("1 observations"));
+        assert!(!output.contains("excluded"));
+    }
+
+    #[test]
+    fn display_header_with_exclusions() {
+        let preds = make_predictions(&[
+            ("S1", Some(0.0), Censor::None, 0.5, 0.5, 0.5, 0.5),
+            ("S1", Some(2.0), Censor::None, 2.5, 2.5, 2.5, 2.5),
+        ]);
+        let m = preds.metrics().unwrap();
+        let output = format!("{}", m);
+        assert!(output.contains("2 observations"));
+        assert!(output.contains("1 with obs <= 0 excluded from relative metrics"));
+    }
+}