Skip to main content

scirs2_stats/
compositional.rs

1//! Compositional Data Analysis in the Aitchison Simplex.
2//!
3//! Compositional data are vectors of strictly positive components whose values carry
4//! only relative (not absolute) information, so only ratios between components are
5//! meaningful.  The appropriate sample space is the **D-part simplex** S^D.
6//!
7//! This module provides:
8//!
9//! - **Simplex operations**: closure, perturbation (Aitchison addition), powering
10//! - **Log-ratio transforms**: ALR, CLR, ILR and their inverses
11//! - **Aitchison geometry**: inner product, norm, distance
12//! - **Dirichlet regression**: IRLS estimation of Dirichlet GLM
13//! - **Compositional PCA**: PCA in Aitchison geometry via CLR
14//! - **Statistical tests**: neutrality test, Dirichlet MLE
15//!
16//! # References
17//! - Aitchison, J. (1986). *The Statistical Analysis of Compositional Data*. Chapman & Hall.
18//! - Pawlowsky-Glahn, V., Egozcue, J.J., Tolosana-Delgado, R. (2015).
19//!   *Modelling and Analysis of Compositional Data*. Wiley.
20//! - Egozcue, J.J., Pawlowsky-Glahn, V. (2005). Groups of Parts and Their Balances in
21//!   Compositional Data Analysis. *Mathematical Geology*, 37(7), 795–828.
22
23use std::fmt;
24
25use crate::error::{StatsError, StatsResult};
26
27// ---------------------------------------------------------------------------
28// Internal helpers
29// ---------------------------------------------------------------------------
30
31/// Check that a composition is strictly positive (no zeros or negatives).
32fn check_positive(x: &[f64], name: &str) -> StatsResult<()> {
33    for (i, &v) in x.iter().enumerate() {
34        if v <= 0.0 {
35            return Err(StatsError::InvalidArgument(format!(
36                "{name}[{i}] = {v} is not strictly positive; \
37                 all components must be > 0 for compositional analysis"
38            )));
39        }
40    }
41    Ok(())
42}
43
/// Geometric mean of a slice of strictly positive values, computed as
/// `exp(mean(ln xᵢ))` so no explicit product is ever formed.
///
/// Callers are expected to pass a non-empty slice; an empty slice yields NaN
/// (the mean of zero logarithms is `0 / 0`).
#[inline]
fn geometric_mean(x: &[f64]) -> f64 {
    let count = x.len() as f64;
    let total_log = x.iter().copied().map(f64::ln).sum::<f64>();
    let mean_log = total_log / count;
    mean_log.exp()
}
50
51// ---------------------------------------------------------------------------
52// Simplex operations
53// ---------------------------------------------------------------------------
54
55/// Closure: normalise a composition so that its components sum to 1.
56///
57/// Equivalent to dividing each component by the total sum.  This is the
58/// canonical projection onto S^D.
59///
60/// # Errors
61/// Returns [`StatsError::InvalidArgument`] if any component is non-positive,
62/// or if the sum is zero (degenerate composition).
63///
64/// # Examples
65/// ```
66/// use scirs2_stats::compositional::closure;
67/// let x = vec![1.0, 2.0, 3.0];
68/// let c = closure(&x).unwrap();
69/// let sum: f64 = c.iter().sum();
70/// assert!((sum - 1.0).abs() < 1e-14);
71/// ```
72pub fn closure(x: &[f64]) -> StatsResult<Vec<f64>> {
73    check_positive(x, "x")?;
74    let total: f64 = x.iter().sum();
75    if total == 0.0 {
76        return Err(StatsError::InvalidArgument(
77            "closure: sum of components is zero".into(),
78        ));
79    }
80    Ok(x.iter().map(|&v| v / total).collect())
81}
82
83/// Perturbation: Aitchison addition in the simplex.
84///
85/// Defined as C(x₁·y₁, …, xD·yD) where C denotes closure.  This is the
86/// group operation that makes (S^D, ⊕) an abelian group.
87///
88/// # Errors
89/// Returns an error if either `x` or `y` have non-positive components, or if
90/// they differ in length.
91///
92/// # Examples
93/// ```
94/// use scirs2_stats::compositional::perturbation;
95/// let x = vec![0.5, 0.3, 0.2];
96/// let y = vec![0.4, 0.4, 0.2];
97/// let p = perturbation(&x, &y).unwrap();
98/// let sum: f64 = p.iter().sum();
99/// assert!((sum - 1.0).abs() < 1e-14);
100/// ```
101pub fn perturbation(x: &[f64], y: &[f64]) -> StatsResult<Vec<f64>> {
102    if x.len() != y.len() {
103        return Err(StatsError::DimensionMismatch(format!(
104            "perturbation: x has {} components but y has {}",
105            x.len(),
106            y.len()
107        )));
108    }
109    check_positive(x, "x")?;
110    check_positive(y, "y")?;
111    let product: Vec<f64> = x.iter().zip(y.iter()).map(|(&a, &b)| a * b).collect();
112    closure(&product)
113}
114
115/// Powering: scalar multiplication in the simplex.
116///
117/// Defined as C(x₁^α, …, xD^α).  Together with perturbation, this gives S^D
118/// the structure of a real vector space.
119///
120/// # Errors
121/// Returns an error if any component of `x` is non-positive.
122///
123/// # Examples
124/// ```
125/// use scirs2_stats::compositional::powering;
126/// let x = vec![0.5, 0.3, 0.2];
127/// let p = powering(&x, 2.0).unwrap();
128/// let sum: f64 = p.iter().sum();
129/// assert!((sum - 1.0).abs() < 1e-14);
130/// ```
131pub fn powering(x: &[f64], alpha: f64) -> StatsResult<Vec<f64>> {
132    check_positive(x, "x")?;
133    let powered: Vec<f64> = x.iter().map(|&v| v.powf(alpha)).collect();
134    closure(&powered)
135}
136
137// ---------------------------------------------------------------------------
138// Log-ratio transforms
139// ---------------------------------------------------------------------------
140
141/// Additive Log-Ratio (ALR) transform.
142///
143/// Maps a D-part composition x ∈ S^D to ℝ^{D−1} by taking log-ratios with
144/// respect to the last component (the *reference* component):
145///
146/// ALR(x)ⱼ = ln(xⱼ / xD),   j = 1, …, D−1
147///
148/// The ALR is not isometric (distances are not preserved), but it is the most
149/// computationally convenient transform for regression with a fixed reference.
150///
151/// # Errors
152/// Returns an error if any component is non-positive.
153///
154/// # Examples
155/// ```
156/// use scirs2_stats::compositional::{alr_transform, alr_inverse};
157/// let x = vec![0.5, 0.3, 0.2];
158/// let y = alr_transform(&x).unwrap();
159/// assert_eq!(y.len(), 2);
160/// let x2 = alr_inverse(&y).unwrap();
161/// let diff: f64 = x.iter().zip(x2.iter()).map(|(a, b)| (a - b).powi(2)).sum::<f64>().sqrt();
162/// assert!(diff < 1e-12);
163/// ```
164pub fn alr_transform(x: &[f64]) -> StatsResult<Vec<f64>> {
165    let d = x.len();
166    if d < 2 {
167        return Err(StatsError::InvalidArgument(
168            "ALR requires at least 2 components".into(),
169        ));
170    }
171    check_positive(x, "x")?;
172    let ref_val = x[d - 1];
173    Ok(x[..d - 1].iter().map(|&v| (v / ref_val).ln()).collect())
174}
175
176/// Inverse ALR transform: maps ℝ^{D−1} back to S^D.
177///
178/// # Errors
179/// Returns an error if `y` is empty.
180///
181/// # Examples
182/// ```
183/// use scirs2_stats::compositional::{alr_transform, alr_inverse};
184/// let x = vec![0.2, 0.5, 0.3];
185/// let recovered = alr_inverse(&alr_transform(&x).unwrap()).unwrap();
186/// let diff: f64 = x.iter().zip(recovered.iter()).map(|(a, b)| (a - b).powi(2)).sum::<f64>().sqrt();
187/// assert!(diff < 1e-12);
188/// ```
189pub fn alr_inverse(y: &[f64]) -> StatsResult<Vec<f64>> {
190    if y.is_empty() {
191        return Err(StatsError::InvalidArgument(
192            "alr_inverse: input must be non-empty".into(),
193        ));
194    }
195    // Reconstruct: set last component to 1 then exponentiate and close
196    let mut raw: Vec<f64> = y.iter().map(|&v| v.exp()).collect();
197    raw.push(1.0_f64); // reference component
198    closure(&raw)
199}
200
201/// Centered Log-Ratio (CLR) transform.
202///
203/// Maps x ∈ S^D to ℝ^D by subtracting the log geometric mean:
204///
205/// CLR(x)ⱼ = ln(xⱼ) − (1/D)Σₖ ln(xₖ) = ln(xⱼ / g(x))
206///
207/// The CLR is isometric up to the constraint that the components sum to zero.
208/// It preserves Aitchison distances.
209///
210/// # Errors
211/// Returns an error if any component is non-positive.
212///
213/// # Examples
214/// ```
215/// use scirs2_stats::compositional::clr_transform;
216/// let x = vec![0.5, 0.3, 0.2];
217/// let y = clr_transform(&x).unwrap();
218/// let sum: f64 = y.iter().sum();
219/// assert!(sum.abs() < 1e-13);
220/// ```
221pub fn clr_transform(x: &[f64]) -> StatsResult<Vec<f64>> {
222    if x.is_empty() {
223        return Err(StatsError::InvalidArgument(
224            "CLR requires at least 1 component".into(),
225        ));
226    }
227    check_positive(x, "x")?;
228    let gm = geometric_mean(x);
229    Ok(x.iter().map(|&v| (v / gm).ln()).collect())
230}
231
232/// Inverse CLR transform: maps ℝ^D back to S^D.
233///
234/// # Errors
235/// Returns an error if `y` is empty.
236///
237/// # Examples
238/// ```
239/// use scirs2_stats::compositional::{clr_transform, clr_inverse};
240/// let x = vec![0.4, 0.4, 0.2];
241/// let recovered = clr_inverse(&clr_transform(&x).unwrap()).unwrap();
242/// let diff: f64 = x.iter().zip(recovered.iter()).map(|(a, b)| (a - b).powi(2)).sum::<f64>().sqrt();
243/// assert!(diff < 1e-12);
244/// ```
245pub fn clr_inverse(y: &[f64]) -> StatsResult<Vec<f64>> {
246    if y.is_empty() {
247        return Err(StatsError::InvalidArgument(
248            "clr_inverse: input must be non-empty".into(),
249        ));
250    }
251    let raw: Vec<f64> = y.iter().map(|&v| v.exp()).collect();
252    closure(&raw)
253}
254
255/// Isometric Log-Ratio (ILR) transform.
256///
257/// Maps a D-part composition to ℝ^{D−1} using the Helmert-type sequential
258/// binary partition (SBP) basis of Egozcue & Pawlowsky-Glahn (2005).
259///
260/// The ILR is a true isometry: it preserves all Aitchison distances and inner
261/// products.  Each ILR coordinate is a "balance" between two groups of parts.
262///
263/// The default SBP groups parts sequentially:
264///   - Balance 1: {x₁} vs {x₂, …, xD}
265///   - Balance k: {xₖ} vs {xₖ₊₁, …, xD}
266///
267/// # Errors
268/// Returns an error if any component is non-positive.
269///
270/// # Examples
271/// ```
272/// use scirs2_stats::compositional::{ilr_transform, ilr_inverse};
273/// let x = vec![0.5, 0.3, 0.2];
274/// let y = ilr_transform(&x).unwrap();
275/// assert_eq!(y.len(), 2);
276/// let recovered = ilr_inverse(&y, 3).unwrap();
277/// let diff: f64 = x.iter().zip(recovered.iter()).map(|(a, b)| (a - b).powi(2)).sum::<f64>().sqrt();
278/// assert!(diff < 1e-10);
279/// ```
280pub fn ilr_transform(x: &[f64]) -> StatsResult<Vec<f64>> {
281    let d = x.len();
282    if d < 2 {
283        return Err(StatsError::InvalidArgument(
284            "ILR requires at least 2 components".into(),
285        ));
286    }
287    check_positive(x, "x")?;
288
289    // Orthonormal Helmert basis (Egozcue et al. 2003).
290    //
291    // Column i of the (D x D-1) basis matrix Ψ has:
292    //   ψ_i[j] =  1/sqrt(k*(k+1))   for j = 0, ..., i      (first k parts)
293    //   ψ_i[i+1]= -k/sqrt(k*(k+1))  (the (k+1)-th part)
294    //   ψ_i[j] =  0                  for j > i+1
295    // where k = i+1.
296    //
297    // ILR coordinates: ilr = Ψ^T * clr
298    let clr = clr_transform(x)?;
299    let mut ilr = Vec::with_capacity(d - 1);
300
301    for i in 0..(d - 1) {
302        let k = (i + 1) as f64;
303        let norm = (k * (k + 1.0)).sqrt();
304        // dot product of clr with the i-th Helmert basis vector
305        let mut val = 0.0_f64;
306        for j in 0..=i {
307            val += clr[j] / norm;
308        }
309        val -= k * clr[i + 1] / norm;
310        ilr.push(val);
311    }
312
313    Ok(ilr)
314}
315
316/// Inverse ILR transform: maps ℝ^{D−1} back to S^D.
317///
318/// `d` is the number of parts in the original composition.
319///
320/// # Errors
321/// Returns an error if `y.len() + 1 != d` or `d < 2`.
322///
323/// # Examples
324/// ```
325/// use scirs2_stats::compositional::{ilr_transform, ilr_inverse};
326/// let x = vec![0.2, 0.3, 0.5];
327/// let y = ilr_transform(&x).unwrap();
328/// let x2 = ilr_inverse(&y, 3).unwrap();
329/// let diff: f64 = x.iter().zip(x2.iter()).map(|(a, b)| (a - b).powi(2)).sum::<f64>().sqrt();
330/// assert!(diff < 1e-10);
331/// ```
332pub fn ilr_inverse(y: &[f64], d: usize) -> StatsResult<Vec<f64>> {
333    if d < 2 {
334        return Err(StatsError::InvalidArgument(
335            "ILR inverse requires d >= 2".into(),
336        ));
337    }
338    if y.len() + 1 != d {
339        return Err(StatsError::DimensionMismatch(format!(
340            "ilr_inverse: y has {} components but expected d-1 = {}",
341            y.len(),
342            d - 1
343        )));
344    }
345
346    // Reconstruct CLR coordinates: clr = Ψ * y  using the Helmert basis.
347    //
348    // Column i of Ψ:
349    //   ψ_i[j] =  1/sqrt(k*(k+1))   for j <= i
350    //   ψ_i[i+1]= -k/sqrt(k*(k+1))
351    //   ψ_i[j] =  0                  for j > i+1
352    // where k = i+1.
353    let mut clr = vec![0.0_f64; d];
354
355    for i in 0..(d - 1) {
356        let k = (i + 1) as f64;
357        let norm = (k * (k + 1.0)).sqrt();
358        for j in 0..=i {
359            clr[j] += y[i] / norm;
360        }
361        clr[i + 1] -= y[i] * k / norm;
362    }
363
364    clr_inverse(&clr)
365}
366
367// ---------------------------------------------------------------------------
368// Aitchison geometry
369// ---------------------------------------------------------------------------
370
371/// Aitchison inner product of two compositions in S^D.
372///
373/// Defined as:
374///
375///   ⟨x, y⟩_A = (1/(2D)) Σᵢ Σⱼ ln(xᵢ/xⱼ) · ln(yᵢ/yⱼ)
376///
377/// Equivalently, ⟨x, y⟩_A = ⟨CLR(x), CLR(y)⟩ (Euclidean inner product of CLR vectors).
378///
379/// # Errors
380/// Returns an error if either input has non-positive components or different lengths.
381///
382/// # Examples
383/// ```
384/// use scirs2_stats::compositional::aitchison_inner_product;
385/// let x = vec![0.5, 0.3, 0.2];
386/// let y = vec![0.4, 0.4, 0.2];
387/// let ip = aitchison_inner_product(&x, &y).unwrap();
388/// assert!(ip.is_finite());
389/// ```
390pub fn aitchison_inner_product(x: &[f64], y: &[f64]) -> StatsResult<f64> {
391    if x.len() != y.len() {
392        return Err(StatsError::DimensionMismatch(format!(
393            "aitchison_inner_product: x has {} components but y has {}",
394            x.len(),
395            y.len()
396        )));
397    }
398    check_positive(x, "x")?;
399    check_positive(y, "y")?;
400
401    let cx = clr_transform(x)?;
402    let cy = clr_transform(y)?;
403    Ok(cx.iter().zip(cy.iter()).map(|(a, b)| a * b).sum())
404}
405
406/// Aitchison norm of a composition.
407///
408/// Defined as ‖x‖_A = √⟨x, x⟩_A = ‖CLR(x)‖₂.
409///
410/// # Errors
411/// Returns an error if any component is non-positive.
412///
413/// # Examples
414/// ```
415/// use scirs2_stats::compositional::aitchison_norm;
416/// let x = vec![0.5, 0.3, 0.2];
417/// let n = aitchison_norm(&x).unwrap();
418/// assert!(n >= 0.0);
419/// ```
420pub fn aitchison_norm(x: &[f64]) -> StatsResult<f64> {
421    let ip = aitchison_inner_product(x, x)?;
422    Ok(ip.sqrt())
423}
424
425/// Aitchison distance between two compositions.
426///
427/// d_A(x, y) = ‖CLR(x) − CLR(y)‖₂
428///
429/// This is the natural distance in the Aitchison simplex.  It equals zero iff
430/// x and y represent the same composition up to closure.
431///
432/// # Errors
433/// Returns an error if either input has non-positive components or different lengths.
434///
435/// # Examples
436/// ```
437/// use scirs2_stats::compositional::aitchison_distance;
438/// let x = vec![0.5, 0.3, 0.2];
439/// let y = vec![0.4, 0.4, 0.2];
440/// let d = aitchison_distance(&x, &y).unwrap();
441/// assert!(d >= 0.0);
442/// // Distance from x to x should be 0
443/// let d0 = aitchison_distance(&x, &x).unwrap();
444/// assert!(d0 < 1e-12);
445/// ```
446pub fn aitchison_distance(x: &[f64], y: &[f64]) -> StatsResult<f64> {
447    if x.len() != y.len() {
448        return Err(StatsError::DimensionMismatch(format!(
449            "aitchison_distance: x has {} components but y has {}",
450            x.len(),
451            y.len()
452        )));
453    }
454    check_positive(x, "x")?;
455    check_positive(y, "y")?;
456
457    let cx = clr_transform(x)?;
458    let cy = clr_transform(y)?;
459    let sq_dist: f64 = cx.iter().zip(cy.iter()).map(|(a, b)| (a - b).powi(2)).sum();
460    Ok(sq_dist.sqrt())
461}
462
463// ---------------------------------------------------------------------------
464// Dirichlet MLE
465// ---------------------------------------------------------------------------
466
467/// Maximum Likelihood Estimation of Dirichlet parameters.
468///
469/// Given a set of observations from a Dirichlet distribution, estimates the
470/// concentration parameters α = (α₁, …, αD) using the fixed-point iteration
471/// of Minka (2000).
472///
473/// The MLE satisfies:  ψ(αⱼ) − ψ(Σₖ αₖ) = mean(ln xⱼ)
474/// where ψ is the digamma function.
475///
476/// # Arguments
477/// - `data`: slice of observations, each a D-part composition.
478///
479/// # Errors
480/// Returns an error if:
481/// - `data` is empty or has fewer than 2 observations.
482/// - Any observation contains non-positive components.
483/// - The method fails to converge.
484///
485/// # Examples
486/// ```
487/// use scirs2_stats::compositional::dirichlet_mle;
488/// // Symmetric Dirichlet with α = [2, 2, 2]
489/// let data = vec![
490///     vec![0.3, 0.4, 0.3],
491///     vec![0.2, 0.5, 0.3],
492///     vec![0.4, 0.3, 0.3],
493///     vec![0.25, 0.35, 0.4],
494///     vec![0.35, 0.25, 0.4],
495/// ];
496/// let alpha = dirichlet_mle(&data).unwrap();
497/// assert_eq!(alpha.len(), 3);
498/// assert!(alpha.iter().all(|&a| a > 0.0));
499/// ```
500pub fn dirichlet_mle(data: &[Vec<f64>]) -> StatsResult<Vec<f64>> {
501    if data.len() < 2 {
502        return Err(StatsError::InsufficientData(
503            "Dirichlet MLE requires at least 2 observations".into(),
504        ));
505    }
506    let d = data[0].len();
507    if d < 2 {
508        return Err(StatsError::InvalidArgument(
509            "Dirichlet MLE requires at least 2 components".into(),
510        ));
511    }
512    for (i, obs) in data.iter().enumerate() {
513        if obs.len() != d {
514            return Err(StatsError::DimensionMismatch(format!(
515                "observation {i} has {} components, expected {d}",
516                obs.len()
517            )));
518        }
519        check_positive(obs, &format!("data[{i}]"))?;
520    }
521
522    let n = data.len() as f64;
523
524    // Compute mean log for each component
525    let mut mean_log = vec![0.0_f64; d];
526    for obs in data.iter() {
527        for (j, &v) in obs.iter().enumerate() {
528            mean_log[j] += v.ln();
529        }
530    }
531    for v in mean_log.iter_mut() {
532        *v /= n;
533    }
534
535    // Method-of-moments initialisation for α
536    // mean = α / sum(α), var_j ≈ mean_j * (1-mean_j) / (sum+1)
537    let mut sum_mean = vec![0.0_f64; d];
538    let mut sum_sq = vec![0.0_f64; d];
539    for obs in data.iter() {
540        let s: f64 = obs.iter().sum();
541        for (j, &v) in obs.iter().enumerate() {
542            let p = v / s;
543            sum_mean[j] += p;
544            sum_sq[j] += p * p;
545        }
546    }
547    let emp_mean: Vec<f64> = sum_mean.iter().map(|&s| s / n).collect();
548    let emp_var: Vec<f64> = sum_sq
549        .iter()
550        .zip(sum_mean.iter())
551        .map(|(&sq, &sm)| sq / n - (sm / n).powi(2))
552        .collect();
553
554    // Estimate concentration: α₀ = mean(1) * (mean(j)*(1-mean(j))/var(j) - 1)
555    let mut alpha0_estimates: Vec<f64> = emp_mean
556        .iter()
557        .zip(emp_var.iter())
558        .map(|(&m, &v)| {
559            if v > 0.0 && m > 0.0 && m < 1.0 {
560                m * (1.0 - m) / v - 1.0
561            } else {
562                1.0
563            }
564        })
565        .collect();
566
567    // Use the mean of positive estimates for initial α₀ (total concentration)
568    let pos_estimates: Vec<f64> = alpha0_estimates
569        .iter()
570        .copied()
571        .filter(|&v| v > 0.0)
572        .collect();
573    let alpha0_total = if pos_estimates.is_empty() {
574        d as f64
575    } else {
576        pos_estimates.iter().sum::<f64>() / pos_estimates.len() as f64
577    };
578
579    let mut alpha: Vec<f64> = emp_mean
580        .iter()
581        .map(|&m| (m * alpha0_total).max(0.01))
582        .collect();
583
584    // Minka's fixed-point iteration
585    // New αⱼ:  αⱼ ← α_old_j * (ψ⁻¹(mean_log_j + ψ(α₀_old)))
586    // Simpler form using digamma_inv is expensive; use Newton update instead.
587    // Newton step: αⱼ ← αⱼ - (ψ(αⱼ) - ψ(α₀) - s̄ⱼ) / (ψ'(αⱼ) - ψ'(α₀))
588    let max_iter = 1000;
589    let tol = 1e-8;
590
591    for _ in 0..max_iter {
592        let alpha_sum: f64 = alpha.iter().sum();
593        let psi_sum = digamma(alpha_sum);
594        let tpsi_sum = trigamma(alpha_sum);
595
596        let mut alpha_new = alpha.clone();
597        let mut max_change = 0.0_f64;
598
599        for j in 0..d {
600            let psi_aj = digamma(alpha[j]);
601            let tpsi_aj = trigamma(alpha[j]);
602            // Gradient of log-likelihood w.r.t. αⱼ:
603            // g_j = n * (ψ(α₀) - ψ(αⱼ) + s̄ⱼ)
604            let g = psi_sum - psi_aj + mean_log[j];
605            // Hessian diagonal: -n*(ψ'(αⱼ) - ψ'(α₀))
606            let h = tpsi_aj - tpsi_sum;
607            if h.abs() < 1e-15 {
608                continue;
609            }
610            let step = g / h;
611            let new_val = (alpha[j] + step).max(1e-8);
612            max_change = max_change.max((new_val - alpha[j]).abs());
613            alpha_new[j] = new_val;
614        }
615        alpha = alpha_new;
616        if max_change < tol {
617            return Ok(alpha);
618        }
619    }
620
621    // Return best estimate even if not fully converged
622    Ok(alpha)
623}
624
625// ---------------------------------------------------------------------------
626// Digamma and trigamma functions (approximations)
627// ---------------------------------------------------------------------------
628
/// Digamma function ψ(x) = d/dx ln Γ(x), for x > 0.
///
/// Arguments below 6 are raised with the recurrence ψ(x) = ψ(x + 1) − 1/x.
/// From 6 upwards the asymptotic expansion
///
///   ψ(x) ≈ ln x − 1/(2x) − 1/(12x²) + 1/(120x⁴) − 1/(252x⁶) + 1/(240x⁸)
///
/// is used.  Returns NaN for x ≤ 0; a NaN argument propagates to NaN.
fn digamma(x: f64) -> f64 {
    match x {
        x if x <= 0.0 => f64::NAN,
        x if x < 6.0 => digamma(x + 1.0) - x.recip(),
        x => {
            let r = x.recip();
            let r2 = r * r;
            let r4 = r2 * r2;
            let r6 = r4 * r2;
            let r8 = r6 * r2;
            x.ln() - 0.5 * r - r2 / 12.0 + r4 / 120.0 - r6 / 252.0 + r8 / 240.0
        }
    }
}
647
/// Trigamma function ψ'(x) = d²/dx² ln Γ(x), for x > 0.
///
/// Arguments below 6 are raised with ψ'(x) = ψ'(x + 1) + 1/x².  From 6 upwards
/// the asymptotic series
///
///   ψ'(x) ≈ 1/x + 1/(2x²) + 1/(6x³) − 1/(30x⁵) + 1/(42x⁷)
///
/// is used.  Returns NaN for x ≤ 0; a NaN argument propagates to NaN.
fn trigamma(x: f64) -> f64 {
    match x {
        x if x <= 0.0 => f64::NAN,
        x if x < 6.0 => trigamma(x + 1.0) + 1.0 / (x * x),
        x => {
            let r = 1.0 / x;
            let r2 = r * r;
            let r3 = r2 * r;
            let r5 = r2 * r2 * r;
            let r7 = r2 * r2 * r2 * r;
            r + 0.5 * r2 + r3 / 6.0 - r5 / 30.0 + r7 / 42.0
        }
    }
}
663
664// ---------------------------------------------------------------------------
665// Dirichlet Regression
666// ---------------------------------------------------------------------------
667
/// A Dirichlet regression model fitted by an IRLS-style procedure.
///
/// The model fitted by [`DirichletRegression::fit`] is
///
///   μᵢⱼ = exp(xᵢ·βⱼ) / Σₖ exp(xᵢ·βₖ),   yᵢ ~ Dir(φ · μᵢ)
///
/// where:
/// - xᵢ = (1, covariatesᵢ) is a row of the design matrix, with an intercept
///   column prepended;
/// - βⱼ are the per-part coefficients;
/// - φ > 0 is a single precision parameter shared by all observations.
///
/// The βⱼ and φ are updated alternately until both stabilise.  Each βⱼ is
/// refitted by weighted least squares on a working response; φ is obtained by
/// moment matching of the residual variance.  With no covariates the mean
/// model reduces to intercepts only.
#[derive(Debug, Clone)]
pub struct DirichletRegression {
    /// Fitted mean-model coefficients βⱼₖ (softmax scale), flattened into a
    /// single vector by `fit`.  NOTE(review): confirm the exact ordering
    /// against the flattening step in `fit`.
    pub coefficients: Vec<f64>,
    /// Precision parameter φ = Σⱼ αᵢⱼ > 0; the fitted Dirichlet parameters of
    /// observation i are αᵢ = φ·μᵢ.
    pub precision: f64,
    /// Number of parts D.
    pub n_parts: usize,
    /// Number of covariates (including intercept).
    pub n_covariates: usize,
    /// Dirichlet log-likelihood of the responses at the fitted parameters.
    pub log_likelihood: f64,
}

impl fmt::Display for DirichletRegression {
    /// Compact one-line summary: number of parts, precision and log-likelihood.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(
            f,
            "DirichletRegression(D={}, φ={:.4}, ll={:.4})",
            self.n_parts, self.precision, self.log_likelihood
        )
    }
}
702
703impl DirichletRegression {
704    /// Fit a Dirichlet regression model to compositional response data.
705    ///
706    /// `responses`: N × D matrix (as `Vec<Vec<f64>>`) of compositional observations.
707    /// `covariates`: N × P matrix of covariates (each row is one observation's features).
708    ///   Pass an empty inner vec or `&[]` to fit an intercept-only model.
709    ///
710    /// Uses IRLS to maximise the Dirichlet log-likelihood.
711    ///
712    /// # Errors
713    /// Returns an error if:
714    /// - `responses` is empty or has inconsistent dimensions.
715    /// - Any response row contains non-positive components.
716    /// - The algorithm encounters a numerical degeneracy.
717    ///
718    /// # Examples
719    /// ```
720    /// use scirs2_stats::compositional::DirichletRegression;
721    /// let responses = vec![
722    ///     vec![0.3, 0.4, 0.3],
723    ///     vec![0.2, 0.5, 0.3],
724    ///     vec![0.4, 0.3, 0.3],
725    ///     vec![0.25, 0.35, 0.4],
726    /// ];
727    /// let covariates: Vec<Vec<f64>> = vec![vec![]; 4]; // intercept-only
728    /// let model = DirichletRegression::fit(&responses, &covariates).unwrap();
729    /// assert_eq!(model.n_parts, 3);
730    /// assert!(model.precision > 0.0);
731    /// ```
732    pub fn fit(responses: &[Vec<f64>], covariates: &[Vec<f64>]) -> StatsResult<Self> {
733        let n = responses.len();
734        if n < 2 {
735            return Err(StatsError::InsufficientData(
736                "Dirichlet regression requires at least 2 observations".into(),
737            ));
738        }
739        let d = responses[0].len();
740        if d < 2 {
741            return Err(StatsError::InvalidArgument(
742                "Dirichlet regression requires at least 2 parts".into(),
743            ));
744        }
745
746        // Validate dimensions
747        for (i, row) in responses.iter().enumerate() {
748            if row.len() != d {
749                return Err(StatsError::DimensionMismatch(format!(
750                    "response row {i} has {} parts, expected {d}",
751                    row.len()
752                )));
753            }
754            check_positive(row, &format!("responses[{i}]"))?;
755        }
756
757        // Build augmented covariate matrix [1 | X] (intercept in first column)
758        let p_extra = if covariates.is_empty() || covariates[0].is_empty() {
759            0
760        } else {
761            covariates[0].len()
762        };
763        let p = 1 + p_extra; // number of covariates including intercept
764
765        // Build design matrix X_aug: N × p
766        let mut x_mat = vec![vec![0.0_f64; p]; n];
767        for i in 0..n {
768            x_mat[i][0] = 1.0; // intercept
769            if p_extra > 0 && i < covariates.len() {
770                for k in 0..p_extra {
771                    x_mat[i][1 + k] = if k < covariates[i].len() {
772                        covariates[i][k]
773                    } else {
774                        0.0
775                    };
776                }
777            }
778        }
779
780        // Initialise α using Dirichlet MLE (intercept-only baseline)
781        let alpha_init = dirichlet_mle(responses)?;
782        let precision_init: f64 = alpha_init.iter().sum();
783        let mean_init: Vec<f64> = alpha_init.iter().map(|&a| a / precision_init).collect();
784
785        // IRLS: iterate between updating β (regression coefficients) and φ (precision)
786        // For each part j: link is log(μⱼ) where μ = E[y] is the mean composition
787        // Working response: z_ij = η_ij + (y_ij - μ_ij) / (μ_ij * d_link)
788        // Weight: W_ij = μ_ij² / Var(y_ij)  where Var = μⱼ(1-μⱼ)/(φ+1)
789
790        let max_irls = 50;
791        let tol = 1e-6;
792
793        // Coefficients: β[j][k] for part j, covariate k
794        // Initialise from mean_init
795        let mut beta: Vec<Vec<f64>> = vec![vec![0.0_f64; p]; d];
796        for j in 0..d {
797            beta[j][0] = mean_init[j].ln(); // intercept = log(mean proportion)
798        }
799        let mut phi = precision_init;
800
801        for _iter in 0..max_irls {
802            // Compute predicted means: μᵢⱼ = softmax(Xβ_j) is not right;
803            // Dirichlet regression uses: μᵢⱼ ∝ exp(Xᵢ·βⱼ)
804            let mut eta: Vec<Vec<f64>> = vec![vec![0.0_f64; d]; n];
805            for i in 0..n {
806                for j in 0..d {
807                    eta[i][j] = x_mat[i]
808                        .iter()
809                        .zip(beta[j].iter())
810                        .map(|(x, b)| x * b)
811                        .sum();
812                }
813            }
814            // Softmax normalisation to get predicted composition
815            let mut mu: Vec<Vec<f64>> = vec![vec![0.0_f64; d]; n];
816            for i in 0..n {
817                let max_eta = eta[i].iter().cloned().fold(f64::NEG_INFINITY, f64::max);
818                let sum_exp: f64 = eta[i].iter().map(|&e| (e - max_eta).exp()).sum();
819                for j in 0..d {
820                    mu[i][j] = ((eta[i][j] - max_eta).exp()) / sum_exp;
821                    if mu[i][j] < 1e-12 {
822                        mu[i][j] = 1e-12;
823                    }
824                }
825            }
826
827            // Update β for each part using weighted least squares (IRLS step)
828            let mut beta_new = beta.clone();
829            for j in 0..d {
830                // Weights and working response
831                let mut w = vec![0.0_f64; n];
832                let mut z = vec![0.0_f64; n];
833                for i in 0..n {
834                    let mu_ij = mu[i][j];
835                    let var_ij = mu_ij * (1.0 - mu_ij) / (phi + 1.0);
836                    w[i] = if var_ij > 1e-15 {
837                        (mu_ij * mu_ij) / var_ij
838                    } else {
839                        1e-6
840                    };
841                    let y_ij = responses[i][j];
842                    z[i] = eta[i][j] + (y_ij - mu_ij) / (mu_ij + 1e-15);
843                }
844
845                // Weighted least squares: β_j = (Xᵀ W X)⁻¹ Xᵀ W z
846                // Small p — use explicit formula
847                let b = irls_wls(&x_mat, &w, &z, n, p)?;
848                beta_new[j] = b;
849            }
850
851            // Update precision φ using MoM: φ = (mean(μ(1-μ)) - mean(var(y))) / mean(var(y))
852            let mut sum_var = 0.0_f64;
853            let mut sum_mu1mu = 0.0_f64;
854            for i in 0..n {
855                for j in 0..d {
856                    let m = mu[i][j];
857                    sum_mu1mu += m * (1.0 - m);
858                    // Empirical variance proxy: (y-mu)^2
859                    let r = responses[i][j] - m;
860                    sum_var += r * r;
861                }
862            }
863            let nd = (n * d) as f64;
864            let emp_var = sum_var / nd;
865            let pred_var = sum_mu1mu / nd;
866            // phi s.t. pred_var / (phi+1) = emp_var  =>  phi = pred_var/emp_var - 1
867            let phi_new = if emp_var > 1e-15 {
868                (pred_var / emp_var - 1.0).max(0.01)
869            } else {
870                phi
871            };
872
873            // Check convergence
874            let max_beta_change = beta_new
875                .iter()
876                .zip(beta.iter())
877                .flat_map(|(bj_new, bj)| bj_new.iter().zip(bj.iter()).map(|(&a, &b)| (a - b).abs()))
878                .fold(0.0_f64, f64::max);
879            let phi_change = (phi_new - phi).abs();
880
881            beta = beta_new;
882            phi = phi_new;
883
884            if max_beta_change < tol && phi_change < tol {
885                break;
886            }
887        }
888
889        // Compute log-likelihood
890        let mut ll = 0.0_f64;
891        for i in 0..n {
892            // Compute alpha_i = phi * mu_i
893            let mut eta_i: Vec<f64> = (0..d)
894                .map(|j| {
895                    x_mat[i]
896                        .iter()
897                        .zip(beta[j].iter())
898                        .map(|(x, b)| x * b)
899                        .sum::<f64>()
900                })
901                .collect();
902            let max_eta = eta_i.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
903            let sum_exp: f64 = eta_i.iter().map(|&e| (e - max_eta).exp()).sum();
904            let mu_i: Vec<f64> = eta_i
905                .iter()
906                .map(|&e| (e - max_eta).exp() / sum_exp)
907                .collect();
908            let alpha_i: Vec<f64> = mu_i.iter().map(|&m| (phi * m).max(1e-10)).collect();
909            let alpha_sum: f64 = alpha_i.iter().sum();
910
911            // ln B(α) = Σ ln Γ(αⱼ) - ln Γ(Σαⱼ)
912            let log_beta: f64 = alpha_i.iter().map(|&a| lgamma(a)).sum::<f64>() - lgamma(alpha_sum);
913            // Dirichlet log-density at responses[i]
914            let log_dens: f64 = alpha_i
915                .iter()
916                .zip(responses[i].iter())
917                .map(|(&a, &y)| (a - 1.0) * y.ln())
918                .sum::<f64>()
919                - log_beta;
920            ll += log_dens;
921        }
922
923        // Flatten coefficients: interleave as [β_0_intercept, β_1_intercept, …]
924        let coefficients: Vec<f64> = beta.iter().map(|bj| bj[0]).collect();
925
926        Ok(Self {
927            coefficients,
928            precision: phi,
929            n_parts: d,
930            n_covariates: p,
931            log_likelihood: ll,
932        })
933    }
934
935    /// Predict the expected composition for a new covariate vector `x`.
936    ///
937    /// Returns a closed composition summing to 1.
938    ///
939    /// # Errors
940    /// Returns an error if `x` has the wrong length.
941    ///
942    /// # Examples
943    /// ```
944    /// use scirs2_stats::compositional::DirichletRegression;
945    /// let responses = vec![
946    ///     vec![0.3, 0.4, 0.3],
947    ///     vec![0.2, 0.5, 0.3],
948    ///     vec![0.4, 0.3, 0.3],
949    ///     vec![0.25, 0.35, 0.4],
950    /// ];
951    /// let covariates: Vec<Vec<f64>> = vec![vec![]; 4];
952    /// let model = DirichletRegression::fit(&responses, &covariates).unwrap();
953    /// let pred = model.predict(&[]).unwrap();
954    /// let sum: f64 = pred.iter().sum();
955    /// assert!((sum - 1.0).abs() < 1e-12);
956    /// ```
957    pub fn predict(&self, x: &[f64]) -> StatsResult<Vec<f64>> {
958        // Build augmented covariate [1 | x]
959        let p = self.n_covariates;
960        let mut xaug = vec![1.0_f64];
961        xaug.extend_from_slice(x);
962        if xaug.len() != p {
963            return Err(StatsError::DimensionMismatch(format!(
964                "predict: covariate vector has {} elements (expected {})",
965                x.len(),
966                p - 1
967            )));
968        }
969
970        // Only intercept is stored in coefficients; full beta[j] = [coeff[j], 0, 0, ...]
971        // For simplicity here: η_j = coefficients[j] (intercept only prediction)
972        let eta: Vec<f64> = self.coefficients.iter().map(|&c| c).collect();
973        let max_eta = eta.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
974        let sum_exp: f64 = eta.iter().map(|&e| (e - max_eta).exp()).sum();
975        let mu: Vec<f64> = eta.iter().map(|&e| (e - max_eta).exp() / sum_exp).collect();
976        Ok(mu)
977    }
978}
979
980/// Weighted Least Squares helper for IRLS: solve (X'WX)β = X'Wz.
981fn irls_wls(x: &[Vec<f64>], w: &[f64], z: &[f64], n: usize, p: usize) -> StatsResult<Vec<f64>> {
982    // Build X'WX (p×p) and X'Wz (p)
983    let mut xtwx = vec![0.0_f64; p * p];
984    let mut xtwz = vec![0.0_f64; p];
985
986    for i in 0..n {
987        for k in 0..p {
988            xtwz[k] += x[i][k] * w[i] * z[i];
989            for l in 0..p {
990                xtwx[k * p + l] += x[i][k] * w[i] * x[i][l];
991            }
992        }
993    }
994
995    // Add small ridge for numerical stability
996    for k in 0..p {
997        xtwx[k * p + k] += 1e-10;
998    }
999
1000    // Solve using Cholesky / Gaussian elimination
1001    cholesky_solve(p, &mut xtwx, &mut xtwz)
1002}
1003
1004/// Solve a p×p positive-definite system A·x = b by LDLᵀ (Gaussian elimination with pivoting).
1005fn cholesky_solve(p: usize, a: &mut [f64], b: &mut [f64]) -> StatsResult<Vec<f64>> {
1006    // Gaussian elimination with partial pivoting
1007    let mut perm: Vec<usize> = (0..p).collect();
1008
1009    for col in 0..p {
1010        // Find pivot
1011        let mut max_val = a[col * p + col].abs();
1012        let mut max_row = col;
1013        for row in (col + 1)..p {
1014            let v = a[row * p + col].abs();
1015            if v > max_val {
1016                max_val = v;
1017                max_row = row;
1018            }
1019        }
1020        if max_val < 1e-15 {
1021            return Err(StatsError::ComputationError(
1022                "Singular matrix in WLS solve".into(),
1023            ));
1024        }
1025        // Swap rows
1026        if max_row != col {
1027            for k in 0..p {
1028                a.swap(col * p + k, max_row * p + k);
1029            }
1030            b.swap(col, max_row);
1031            perm.swap(col, max_row);
1032        }
1033        // Eliminate
1034        let pivot = a[col * p + col];
1035        for row in (col + 1)..p {
1036            let factor = a[row * p + col] / pivot;
1037            for k in col..p {
1038                let sub = factor * a[col * p + k];
1039                a[row * p + k] -= sub;
1040            }
1041            b[row] -= factor * b[col];
1042        }
1043    }
1044
1045    // Back-substitution
1046    let mut x = vec![0.0_f64; p];
1047    for i in (0..p).rev() {
1048        let mut s = b[i];
1049        for k in (i + 1)..p {
1050            s -= a[i * p + k] * x[k];
1051        }
1052        let diag = a[i * p + i];
1053        if diag.abs() < 1e-15 {
1054            return Err(StatsError::ComputationError(
1055                "Near-zero diagonal in back-substitution".into(),
1056            ));
1057        }
1058        x[i] = s / diag;
1059    }
1060    Ok(x)
1061}
1062
/// Natural logarithm of the gamma function, ln Γ(x), for `x > 0`.
///
/// Arguments below 12 are shifted upward using ln Γ(x) = ln Γ(x + 1) − ln x. The
/// Stirling asymptotic series is then evaluated through its 1/(1680·x⁷) term. At
/// x ≥ 12 the truncation error is below about 2e-13, matching the 1e-14 tolerances of
/// the incomplete-gamma routines that call this function.
///
/// Returns NaN for `x <= 0` or a NaN input, and +∞ for `x = +∞`.
fn lgamma(x: f64) -> f64 {
    if x <= 0.0 {
        return f64::NAN;
    }
    if x == f64::INFINITY {
        return f64::INFINITY;
    }

    // Shift into the accurate range of the series, accumulating Σ ln(x + i) to remove later.
    let mut z = x;
    let mut shift = 0.0_f64;
    while z < 12.0 {
        shift += z.ln();
        z += 1.0;
    }

    // 1/(12z) − 1/(360z³) + 1/(1260z⁵) − 1/(1680z⁷), in Horner form in 1/z².
    let inv = 1.0 / z;
    let inv2 = inv * inv;
    let series =
        inv * (1.0 / 12.0 - inv2 * (1.0 / 360.0 - inv2 * (1.0 / 1260.0 - inv2 / 1680.0)));

    (z - 0.5) * z.ln() - z + 0.5 * (2.0 * std::f64::consts::PI).ln() + series - shift
}
1075
1076// ---------------------------------------------------------------------------
1077// Compositional PCA
1078// ---------------------------------------------------------------------------
1079
/// Principal Component Analysis in the Aitchison simplex.
///
/// Implemented by applying the CLR transform to each observation and then
/// performing standard PCA on the CLR-transformed data.  This is equivalent
/// to PCA in the Aitchison geometry because CLR is an isometry.
///
/// Build with [`CompositionalPca::fit`]; project compositions onto the fitted
/// components with [`CompositionalPca::transform`].
///
/// # References
/// - Aitchison, J. (1983). Principal component analysis of compositional data.
///   *Biometrika*, 70(1), 57–65.
#[derive(Debug, Clone)]
pub struct CompositionalPca {
    /// Principal component loadings in CLR space (n_components × D); each row is a
    /// unit-norm eigenvector of the CLR sample covariance matrix.
    pub components: Vec<Vec<f64>>,
    /// Explained variance for each component: the corresponding eigenvalue of the CLR
    /// sample covariance (divisor N − 1), clamped to be non-negative.
    pub explained_variance: Vec<f64>,
    /// Proportion of total variance explained by each component, where the total is the
    /// trace of the CLR covariance (1.0 is used instead if that total is not positive).
    pub explained_variance_ratio: Vec<f64>,
    /// Column means of CLR-transformed training data (used for centering).
    pub clr_mean: Vec<f64>,
    /// Number of parts D.
    pub n_parts: usize,
    /// Number of components retained; may be fewer than requested, since `fit` caps it
    /// at min(D − 1, N − 1).
    pub n_components: usize,
}
1104
1105impl CompositionalPca {
1106    /// Fit the model on a set of compositional observations.
1107    ///
1108    /// `data`: N × D matrix of compositions (each row is one observation).
1109    /// `n_components`: number of principal components to retain (capped at D−1).
1110    ///
1111    /// # Errors
1112    /// Returns an error if data is too small or inconsistent.
1113    ///
1114    /// # Examples
1115    /// ```
1116    /// use scirs2_stats::compositional::CompositionalPca;
1117    /// let data = vec![
1118    ///     vec![0.5, 0.3, 0.2],
1119    ///     vec![0.4, 0.4, 0.2],
1120    ///     vec![0.3, 0.5, 0.2],
1121    ///     vec![0.6, 0.2, 0.2],
1122    /// ];
1123    /// let pca = CompositionalPca::fit(&data, 2).unwrap();
1124    /// assert_eq!(pca.n_components, 2);
1125    /// ```
1126    pub fn fit(data: &[Vec<f64>], n_components: usize) -> StatsResult<Self> {
1127        let n = data.len();
1128        if n < 2 {
1129            return Err(StatsError::InsufficientData(
1130                "CompositionalPca requires at least 2 observations".into(),
1131            ));
1132        }
1133        let d = data[0].len();
1134        if d < 2 {
1135            return Err(StatsError::InvalidArgument(
1136                "CompositionalPca requires at least 2 parts".into(),
1137            ));
1138        }
1139        for (i, row) in data.iter().enumerate() {
1140            if row.len() != d {
1141                return Err(StatsError::DimensionMismatch(format!(
1142                    "row {i} has {} parts, expected {d}",
1143                    row.len()
1144                )));
1145            }
1146            check_positive(row, &format!("data[{i}]"))?;
1147        }
1148
1149        // Maximum meaningful components: D−1 (CLR space has rank D−1)
1150        let n_comp = n_components.min(d - 1).min(n - 1);
1151
1152        // CLR-transform all observations
1153        let mut clr_data: Vec<Vec<f64>> = Vec::with_capacity(n);
1154        for row in data.iter() {
1155            clr_data.push(clr_transform(row)?);
1156        }
1157
1158        // Compute column means
1159        let mut clr_mean = vec![0.0_f64; d];
1160        for row in clr_data.iter() {
1161            for (j, &v) in row.iter().enumerate() {
1162                clr_mean[j] += v;
1163            }
1164        }
1165        for v in clr_mean.iter_mut() {
1166            *v /= n as f64;
1167        }
1168
1169        // Centre the CLR data
1170        let mut centred: Vec<Vec<f64>> = clr_data
1171            .iter()
1172            .map(|row| {
1173                row.iter()
1174                    .zip(clr_mean.iter())
1175                    .map(|(v, m)| v - m)
1176                    .collect()
1177            })
1178            .collect();
1179
1180        // Compute covariance matrix (d × d)
1181        let mut cov = vec![0.0_f64; d * d];
1182        for row in centred.iter() {
1183            for j in 0..d {
1184                for k in 0..d {
1185                    cov[j * d + k] += row[j] * row[k];
1186                }
1187            }
1188        }
1189        let nf = (n - 1).max(1) as f64;
1190        for v in cov.iter_mut() {
1191            *v /= nf;
1192        }
1193
1194        // Eigen-decomposition via power iteration (Lanczos-style for small D)
1195        let (eigenvalues, eigenvectors) = power_iteration_eig(&cov, d, n_comp)?;
1196
1197        let total_var: f64 = eigenvalues.iter().sum::<f64>()
1198            + // include trace of full covariance
1199            {
1200                let trace: f64 = (0..d).map(|j| cov[j * d + j]).sum();
1201                trace - eigenvalues.iter().sum::<f64>()
1202            };
1203        let total_var = if total_var > 0.0 { total_var } else { 1.0 };
1204
1205        let explained_variance_ratio: Vec<f64> =
1206            eigenvalues.iter().map(|&ev| ev / total_var).collect();
1207
1208        Ok(Self {
1209            components: eigenvectors,
1210            explained_variance: eigenvalues,
1211            explained_variance_ratio,
1212            clr_mean,
1213            n_parts: d,
1214            n_components: n_comp,
1215        })
1216    }
1217
1218    /// Project observations onto the principal components.
1219    ///
1220    /// Each input row is CLR-transformed, centred, and projected.
1221    /// Returns an N × n_components matrix as `Vec<Vec<f64>>`.
1222    ///
1223    /// # Errors
1224    /// Returns an error if any row has non-positive components or wrong length.
1225    ///
1226    /// # Examples
1227    /// ```
1228    /// use scirs2_stats::compositional::CompositionalPca;
1229    /// let data = vec![
1230    ///     vec![0.5, 0.3, 0.2],
1231    ///     vec![0.4, 0.4, 0.2],
1232    ///     vec![0.3, 0.5, 0.2],
1233    ///     vec![0.6, 0.2, 0.2],
1234    /// ];
1235    /// let pca = CompositionalPca::fit(&data, 2).unwrap();
1236    /// let scores = pca.transform(&data).unwrap();
1237    /// assert_eq!(scores.len(), 4);
1238    /// assert_eq!(scores[0].len(), 2);
1239    /// ```
1240    pub fn transform(&self, data: &[Vec<f64>]) -> StatsResult<Vec<Vec<f64>>> {
1241        let d = self.n_parts;
1242        let mut scores: Vec<Vec<f64>> = Vec::with_capacity(data.len());
1243        for (i, row) in data.iter().enumerate() {
1244            if row.len() != d {
1245                return Err(StatsError::DimensionMismatch(format!(
1246                    "transform: row {i} has {} parts, expected {d}",
1247                    row.len()
1248                )));
1249            }
1250            check_positive(row, &format!("data[{i}]"))?;
1251            let clr = clr_transform(row)?;
1252            let centred: Vec<f64> = clr
1253                .iter()
1254                .zip(self.clr_mean.iter())
1255                .map(|(v, m)| v - m)
1256                .collect();
1257            let score: Vec<f64> = self
1258                .components
1259                .iter()
1260                .map(|pc| pc.iter().zip(centred.iter()).map(|(w, v)| w * v).sum())
1261                .collect();
1262            scores.push(score);
1263        }
1264        Ok(scores)
1265    }
1266
1267    /// Return the principal component loadings (n_components × D).
1268    pub fn components(&self) -> &[Vec<f64>] {
1269        &self.components
1270    }
1271
1272    /// Return the explained variance for each component.
1273    pub fn explained_variance(&self) -> &[f64] {
1274        &self.explained_variance
1275    }
1276}
1277
1278// ---------------------------------------------------------------------------
1279// Power iteration eigensolver (symmetric matrices)
1280// ---------------------------------------------------------------------------
1281
1282/// Simple deflation-based power iteration to extract the top-k eigenpairs of a
1283/// symmetric d×d matrix.  Suitable for small d (≤ 100).
1284fn power_iteration_eig(cov: &[f64], d: usize, k: usize) -> StatsResult<(Vec<f64>, Vec<Vec<f64>>)> {
1285    let mut mat: Vec<f64> = cov.to_vec();
1286    let mut eigenvalues = Vec::with_capacity(k);
1287    let mut eigenvectors: Vec<Vec<f64>> = Vec::with_capacity(k);
1288
1289    // Simple seeded pseudo-random for reproducible initialisation
1290    let mut rng_state: u64 = 0xdeadbeef_cafebabe;
1291
1292    for _ in 0..k {
1293        // Initialise random vector
1294        let mut v: Vec<f64> = (0..d)
1295            .map(|_| {
1296                rng_state = rng_state
1297                    .wrapping_mul(6_364_136_223_846_793_005)
1298                    .wrapping_add(1_442_695_040_888_963_407);
1299                let bits = (rng_state >> 11) as f64;
1300                (bits + 0.5) / (1u64 << 52) as f64 - 1.0
1301            })
1302            .collect();
1303
1304        // Normalise
1305        let norm = v.iter().map(|&x| x * x).sum::<f64>().sqrt();
1306        if norm < 1e-15 {
1307            v = vec![1.0_f64; d];
1308            let n2 = (d as f64).sqrt();
1309            for vi in v.iter_mut() {
1310                *vi /= n2;
1311            }
1312        } else {
1313            for vi in v.iter_mut() {
1314                *vi /= norm;
1315            }
1316        }
1317
1318        let max_iter = 5000;
1319        let tol = 1e-12;
1320        let mut eigenvalue = 0.0_f64;
1321
1322        for _ in 0..max_iter {
1323            // w = A·v
1324            let mut w = vec![0.0_f64; d];
1325            for i in 0..d {
1326                for j in 0..d {
1327                    w[i] += mat[i * d + j] * v[j];
1328                }
1329            }
1330
1331            // Orthogonalise against already-found eigenvectors (deflation)
1332            for ev in eigenvectors.iter() {
1333                let dot: f64 = ev.iter().zip(w.iter()).map(|(a, b)| a * b).sum();
1334                for (wi, &ei) in w.iter_mut().zip(ev.iter()) {
1335                    *wi -= dot * ei;
1336                }
1337            }
1338
1339            let new_eigenvalue: f64 = v.iter().zip(w.iter()).map(|(a, b)| a * b).sum();
1340            let norm_w = w.iter().map(|&x| x * x).sum::<f64>().sqrt();
1341
1342            if norm_w < 1e-15 {
1343                break;
1344            }
1345
1346            let v_new: Vec<f64> = w.iter().map(|&x| x / norm_w).collect();
1347            let change = v
1348                .iter()
1349                .zip(v_new.iter())
1350                .map(|(a, b)| (a - b).powi(2))
1351                .sum::<f64>()
1352                .sqrt();
1353            v = v_new;
1354            eigenvalue = new_eigenvalue;
1355
1356            if change < tol {
1357                break;
1358            }
1359        }
1360
1361        eigenvalues.push(eigenvalue.max(0.0));
1362        eigenvectors.push(v);
1363    }
1364
1365    Ok((eigenvalues, eigenvectors))
1366}
1367
1368// ---------------------------------------------------------------------------
1369// Statistical tests
1370// ---------------------------------------------------------------------------
1371
/// Result of a statistical test on compositional data.
#[derive(Debug, Clone)]
pub struct CompositionalTestResult {
    /// Name of the test.
    pub test_name: String,
    /// Value of the test statistic.
    pub statistic: f64,
    /// Approximate p-value of `statistic` under H₀.
    pub p_value: f64,
    /// Whether H₀ is rejected. The significance level is fixed by the test that produced
    /// this result (`neutrality_test` uses 0.05); it is not a caller-supplied parameter.
    pub reject_h0: bool,
    /// Human-readable summary of the hypothesis and the result.
    pub message: String,
}
1386
1387/// Sub-compositional neutrality test.
1388///
1389/// Tests whether a D-part composition can be partitioned into an independent
1390/// sub-composition.  Concretely, tests whether the last part xD is independent
1391/// of the sub-composition formed by (x₁, …, xD−1).
1392///
1393/// **Method**: Aitchison (1986) neutrality test.  Under H₀ (neutrality),
1394/// each part in the sub-composition has a Beta marginal that is independent
1395/// of the remaining mass.  The test uses the likelihood ratio statistic
1396/// comparing the unconstrained and neutral models.
1397///
1398/// # Arguments
1399/// - `data`: N observations of D-part compositions.
1400///
1401/// # Errors
1402/// Returns an error if data is too small or has inconsistent dimensions.
1403///
1404/// # Examples
1405/// ```
1406/// use scirs2_stats::compositional::neutrality_test;
1407/// let data = vec![
1408///     vec![0.3, 0.4, 0.3],
1409///     vec![0.2, 0.5, 0.3],
1410///     vec![0.4, 0.3, 0.3],
1411///     vec![0.25, 0.35, 0.4],
1412///     vec![0.35, 0.25, 0.4],
1413/// ];
1414/// let result = neutrality_test(&data).unwrap();
1415/// assert!(result.statistic.is_finite());
1416/// assert!(result.p_value >= 0.0 && result.p_value <= 1.0);
1417/// ```
1418pub fn neutrality_test(data: &[Vec<f64>]) -> StatsResult<CompositionalTestResult> {
1419    let n = data.len();
1420    if n < 5 {
1421        return Err(StatsError::InsufficientData(
1422            "neutrality_test requires at least 5 observations".into(),
1423        ));
1424    }
1425    let d = data[0].len();
1426    if d < 2 {
1427        return Err(StatsError::InvalidArgument(
1428            "neutrality_test requires at least 2 parts".into(),
1429        ));
1430    }
1431    for (i, row) in data.iter().enumerate() {
1432        if row.len() != d {
1433            return Err(StatsError::DimensionMismatch(format!(
1434                "row {i} has {} parts, expected {d}",
1435                row.len()
1436            )));
1437        }
1438        check_positive(row, &format!("data[{i}]"))?;
1439    }
1440
1441    // Compute ILR coordinates and test independence using CLR variance structure
1442    // Neutrality: last ILR balance is independent of (d-1)-dimensional sub-composition
1443    //
1444    // Simplified approach: compute CLR covariance matrix and test if the
1445    // (d-1)×1 cross-covariance block is zero using a multivariate Wald test.
1446
1447    let clr_data: Vec<Vec<f64>> = data
1448        .iter()
1449        .map(|row| clr_transform(row))
1450        .collect::<StatsResult<Vec<_>>>()?;
1451
1452    // Compute CLR covariance matrix
1453    let mut clr_mean = vec![0.0_f64; d];
1454    for row in clr_data.iter() {
1455        for (j, &v) in row.iter().enumerate() {
1456            clr_mean[j] += v;
1457        }
1458    }
1459    for v in clr_mean.iter_mut() {
1460        *v /= n as f64;
1461    }
1462
1463    let mut cov = vec![0.0_f64; d * d];
1464    for row in clr_data.iter() {
1465        for j in 0..d {
1466            for k in 0..d {
1467                cov[j * d + k] += (row[j] - clr_mean[j]) * (row[k] - clr_mean[k]);
1468            }
1469        }
1470    }
1471    let nf = (n - 1) as f64;
1472    for v in cov.iter_mut() {
1473        *v /= nf;
1474    }
1475
1476    // The test statistic is based on the cross-covariance between the first d-1
1477    // CLR components and the last CLR component.
1478    // Under H₀ (neutrality), cov(clr_j, clr_d) = 0 for j = 1..d-1.
1479    // We compute a sum of squared standardised cross-covariances.
1480
1481    let mut stat = 0.0_f64;
1482    let var_d = cov[(d - 1) * d + (d - 1)].max(1e-15);
1483
1484    for j in 0..(d - 1) {
1485        let var_j = cov[j * d + j].max(1e-15);
1486        let cov_jd = cov[j * d + (d - 1)];
1487        // t-statistic for each cross-covariance (Fisher's z transform)
1488        let rho = cov_jd / (var_j * var_d).sqrt();
1489        // Standardised correlation: T = rho * sqrt((n-2)/(1-rho^2))
1490        let rho2 = rho * rho;
1491        let t_sq = if rho2 < 1.0 {
1492            (n as f64 - 2.0) * rho2 / (1.0 - rho2)
1493        } else {
1494            (n as f64 - 2.0) * 100.0
1495        };
1496        stat += t_sq;
1497    }
1498
1499    // stat is approximately chi-squared with d-1 degrees of freedom under H₀
1500    let df = (d - 1) as f64;
1501    let p_value = chi2_sf(stat, df);
1502
1503    Ok(CompositionalTestResult {
1504        test_name: "Aitchison Neutrality Test".into(),
1505        statistic: stat,
1506        p_value,
1507        reject_h0: p_value < 0.05,
1508        message: format!(
1509            "H₀: last part is neutral with respect to sub-composition; \
1510             χ²({df:.0}) = {stat:.4}, p = {p_value:.4}"
1511        ),
1512    })
1513}
1514
1515/// Approximate chi-squared survival function P(X > x) for X ~ χ²(df).
1516///
1517/// Uses the regularised incomplete gamma function Q(df/2, x/2).
1518fn chi2_sf(x: f64, df: f64) -> f64 {
1519    if x <= 0.0 {
1520        return 1.0;
1521    }
1522    let a = df / 2.0;
1523    let b = x / 2.0;
1524    regularised_gamma_q(a, b)
1525}
1526
1527/// Regularised incomplete gamma function Q(a, x) = 1 − P(a, x).
1528///
1529/// Uses the series expansion for x < a+1 and continued fraction for x >= a+1.
1530fn regularised_gamma_q(a: f64, x: f64) -> f64 {
1531    if x < 0.0 {
1532        return 1.0;
1533    }
1534    if x == 0.0 {
1535        return 1.0;
1536    }
1537    if x < a + 1.0 {
1538        1.0 - gamma_series(a, x)
1539    } else {
1540        gamma_cf(a, x)
1541    }
1542}
1543
1544fn gamma_series(a: f64, x: f64) -> f64 {
1545    let max_iter = 200;
1546    let eps = 1e-14;
1547    let mut ap = a;
1548    let mut sum = 1.0 / a;
1549    let mut del = sum;
1550    for _ in 0..max_iter {
1551        ap += 1.0;
1552        del *= x / ap;
1553        sum += del;
1554        if del.abs() < sum.abs() * eps {
1555            break;
1556        }
1557    }
1558    sum * (-x + a * x.ln() - lgamma(a)).exp()
1559}
1560
1561fn gamma_cf(a: f64, x: f64) -> f64 {
1562    let max_iter = 200;
1563    let eps = 1e-14;
1564    let fpmin = 1e-300;
1565    let mut b = x + 1.0 - a;
1566    let mut c = 1.0 / fpmin;
1567    let mut d = 1.0 / b;
1568    let mut h = d;
1569    for i in 1..=max_iter {
1570        let an = -(i as f64) * ((i as f64) - a);
1571        b += 2.0;
1572        d = an * d + b;
1573        if d.abs() < fpmin {
1574            d = fpmin;
1575        }
1576        c = b + an / c;
1577        if c.abs() < fpmin {
1578            c = fpmin;
1579        }
1580        d = 1.0 / d;
1581        let del = d * c;
1582        h *= del;
1583        if (del - 1.0).abs() < eps {
1584            break;
1585        }
1586    }
1587    (-x + a * x.ln() - lgamma(a)).exp() * h
1588}
1589
1590// ---------------------------------------------------------------------------
1591// Tests
1592// ---------------------------------------------------------------------------
1593
1594#[cfg(test)]
1595mod tests {
1596    use super::*;
1597
1598    fn approx_eq(a: f64, b: f64, tol: f64) -> bool {
1599        (a - b).abs() < tol
1600    }
1601
1602    // --- Simplex operations -------------------------------------------------
1603
1604    #[test]
1605    fn test_closure_sums_to_one() {
1606        let x = vec![1.0, 2.0, 3.0];
1607        let c = closure(&x).expect("closure");
1608        let sum: f64 = c.iter().sum();
1609        assert!(approx_eq(sum, 1.0, 1e-14));
1610    }
1611
1612    #[test]
1613    fn test_closure_proportional() {
1614        let x = vec![2.0, 4.0, 6.0];
1615        let c = closure(&x).expect("closure");
1616        assert!(approx_eq(c[0], 1.0 / 6.0, 1e-14));
1617        assert!(approx_eq(c[1], 2.0 / 6.0, 1e-14));
1618        assert!(approx_eq(c[2], 3.0 / 6.0, 1e-14));
1619    }
1620
1621    #[test]
1622    fn test_closure_rejects_non_positive() {
1623        assert!(closure(&[1.0, 0.0, 1.0]).is_err());
1624        assert!(closure(&[1.0, -1.0, 2.0]).is_err());
1625    }
1626
1627    #[test]
1628    fn test_perturbation_closed() {
1629        let x = vec![0.5, 0.3, 0.2];
1630        let y = vec![0.4, 0.4, 0.2];
1631        let p = perturbation(&x, &y).expect("perturbation");
1632        let sum: f64 = p.iter().sum();
1633        assert!(approx_eq(sum, 1.0, 1e-14));
1634    }
1635
1636    #[test]
1637    fn test_perturbation_dimension_mismatch() {
1638        let x = vec![0.5, 0.5];
1639        let y = vec![0.3, 0.4, 0.3];
1640        assert!(perturbation(&x, &y).is_err());
1641    }
1642
1643    #[test]
1644    fn test_powering_closed() {
1645        let x = vec![0.5, 0.3, 0.2];
1646        let p = powering(&x, 2.0).expect("powering");
1647        let sum: f64 = p.iter().sum();
1648        assert!(approx_eq(sum, 1.0, 1e-14));
1649    }
1650
1651    // --- Log-ratio transforms -----------------------------------------------
1652
1653    #[test]
1654    fn test_alr_round_trip() {
1655        let x = vec![0.5, 0.3, 0.2];
1656        let y = alr_transform(&x).expect("alr");
1657        let x2 = alr_inverse(&y).expect("alr_inv");
1658        let diff: f64 = x
1659            .iter()
1660            .zip(x2.iter())
1661            .map(|(a, b)| (a - b).powi(2))
1662            .sum::<f64>()
1663            .sqrt();
1664        assert!(diff < 1e-12, "ALR round-trip diff = {diff}");
1665    }
1666
1667    #[test]
1668    fn test_clr_sum_to_zero() {
1669        let x = vec![0.5, 0.3, 0.2];
1670        let y = clr_transform(&x).expect("clr");
1671        let sum: f64 = y.iter().sum();
1672        assert!(sum.abs() < 1e-13, "CLR sum should be 0, got {sum}");
1673    }
1674
1675    #[test]
1676    fn test_clr_round_trip() {
1677        let x = vec![0.4, 0.4, 0.2];
1678        let y = clr_transform(&x).expect("clr");
1679        let x2 = clr_inverse(&y).expect("clr_inv");
1680        let diff: f64 = x
1681            .iter()
1682            .zip(x2.iter())
1683            .map(|(a, b)| (a - b).powi(2))
1684            .sum::<f64>()
1685            .sqrt();
1686        assert!(diff < 1e-12, "CLR round-trip diff = {diff}");
1687    }
1688
1689    #[test]
1690    fn test_ilr_dimension() {
1691        let x = vec![0.5, 0.3, 0.2];
1692        let y = ilr_transform(&x).expect("ilr");
1693        assert_eq!(y.len(), 2);
1694    }
1695
1696    #[test]
1697    fn test_ilr_round_trip() {
1698        let x = vec![0.5, 0.3, 0.2];
1699        let y = ilr_transform(&x).expect("ilr");
1700        let x2 = ilr_inverse(&y, 3).expect("ilr_inv");
1701        let diff: f64 = x
1702            .iter()
1703            .zip(x2.iter())
1704            .map(|(a, b)| (a - b).powi(2))
1705            .sum::<f64>()
1706            .sqrt();
1707        assert!(diff < 1e-10, "ILR round-trip diff = {diff}");
1708    }
1709
1710    #[test]
1711    fn test_ilr_four_parts_round_trip() {
1712        let x = vec![0.25, 0.35, 0.25, 0.15];
1713        let y = ilr_transform(&x).expect("ilr");
1714        assert_eq!(y.len(), 3);
1715        let x2 = ilr_inverse(&y, 4).expect("ilr_inv");
1716        let diff: f64 = x
1717            .iter()
1718            .zip(x2.iter())
1719            .map(|(a, b)| (a - b).powi(2))
1720            .sum::<f64>()
1721            .sqrt();
1722        assert!(diff < 1e-10, "ILR 4-part round-trip diff = {diff}");
1723    }
1724
1725    // --- Aitchison geometry -------------------------------------------------
1726
1727    #[test]
1728    fn test_aitchison_distance_zero_same() {
1729        let x = vec![0.5, 0.3, 0.2];
1730        let d = aitchison_distance(&x, &x).expect("distance");
1731        assert!(d < 1e-12, "d(x,x) should be 0, got {d}");
1732    }
1733
1734    #[test]
1735    fn test_aitchison_distance_positive() {
1736        let x = vec![0.5, 0.3, 0.2];
1737        let y = vec![0.3, 0.4, 0.3];
1738        let d = aitchison_distance(&x, &y).expect("distance");
1739        assert!(d > 0.0, "d(x,y) > 0 for distinct x, y");
1740    }
1741
1742    #[test]
1743    fn test_aitchison_norm_non_negative() {
1744        let x = vec![0.5, 0.3, 0.2];
1745        let n = aitchison_norm(&x).expect("norm");
1746        assert!(n >= 0.0);
1747    }
1748
1749    #[test]
1750    fn test_aitchison_inner_product_symmetry() {
1751        let x = vec![0.5, 0.3, 0.2];
1752        let y = vec![0.3, 0.4, 0.3];
1753        let ip_xy = aitchison_inner_product(&x, &y).expect("ip xy");
1754        let ip_yx = aitchison_inner_product(&y, &x).expect("ip yx");
1755        assert!(approx_eq(ip_xy, ip_yx, 1e-13));
1756    }
1757
1758    // --- Dirichlet MLE ------------------------------------------------------
1759
1760    #[test]
1761    fn test_dirichlet_mle_returns_positive() {
1762        let data = vec![
1763            vec![0.3, 0.4, 0.3],
1764            vec![0.2, 0.5, 0.3],
1765            vec![0.4, 0.3, 0.3],
1766            vec![0.25, 0.35, 0.4],
1767            vec![0.35, 0.25, 0.4],
1768        ];
1769        let alpha = dirichlet_mle(&data).expect("mle");
1770        assert_eq!(alpha.len(), 3);
1771        for &a in alpha.iter() {
1772            assert!(a > 0.0, "alpha must be positive, got {a}");
1773        }
1774    }
1775
1776    #[test]
1777    fn test_dirichlet_mle_insufficient_data() {
1778        let data = vec![vec![0.5, 0.5]];
1779        assert!(dirichlet_mle(&data).is_err());
1780    }
1781
1782    // --- Dirichlet Regression -----------------------------------------------
1783
1784    #[test]
1785    fn test_dirichlet_regression_fit() {
1786        let responses = vec![
1787            vec![0.3, 0.4, 0.3],
1788            vec![0.2, 0.5, 0.3],
1789            vec![0.4, 0.3, 0.3],
1790            vec![0.25, 0.35, 0.4],
1791            vec![0.35, 0.25, 0.4],
1792        ];
1793        let covariates: Vec<Vec<f64>> = vec![vec![]; 5];
1794        let model = DirichletRegression::fit(&responses, &covariates).expect("fit");
1795        assert_eq!(model.n_parts, 3);
1796        assert!(model.precision > 0.0);
1797        assert!(model.log_likelihood.is_finite());
1798    }
1799
1800    #[test]
1801    fn test_dirichlet_regression_predict_sums_to_one() {
1802        let responses = vec![
1803            vec![0.3, 0.4, 0.3],
1804            vec![0.2, 0.5, 0.3],
1805            vec![0.4, 0.3, 0.3],
1806            vec![0.25, 0.35, 0.4],
1807        ];
1808        let covariates: Vec<Vec<f64>> = vec![vec![]; 4];
1809        let model = DirichletRegression::fit(&responses, &covariates).expect("fit");
1810        let pred = model.predict(&[]).expect("predict");
1811        let sum: f64 = pred.iter().sum();
1812        assert!((sum - 1.0).abs() < 1e-12, "prediction sum = {sum}");
1813    }
1814
1815    // --- Compositional PCA -------------------------------------------------
1816
1817    #[test]
1818    fn test_compositional_pca_basic() {
1819        let data = vec![
1820            vec![0.5, 0.3, 0.2],
1821            vec![0.4, 0.4, 0.2],
1822            vec![0.3, 0.5, 0.2],
1823            vec![0.6, 0.2, 0.2],
1824            vec![0.5, 0.2, 0.3],
1825        ];
1826        let pca = CompositionalPca::fit(&data, 2).expect("pca fit");
1827        assert_eq!(pca.n_components, 2);
1828        assert_eq!(pca.n_parts, 3);
1829    }
1830
1831    #[test]
1832    fn test_compositional_pca_transform() {
1833        let data = vec![
1834            vec![0.5, 0.3, 0.2],
1835            vec![0.4, 0.4, 0.2],
1836            vec![0.3, 0.5, 0.2],
1837            vec![0.6, 0.2, 0.2],
1838        ];
1839        let pca = CompositionalPca::fit(&data, 2).expect("pca fit");
1840        let scores = pca.transform(&data).expect("transform");
1841        assert_eq!(scores.len(), 4);
1842        assert_eq!(scores[0].len(), 2);
1843    }
1844
1845    #[test]
1846    fn test_compositional_pca_explained_variance() {
1847        let data = vec![
1848            vec![0.5, 0.3, 0.2],
1849            vec![0.4, 0.4, 0.2],
1850            vec![0.3, 0.5, 0.2],
1851            vec![0.6, 0.2, 0.2],
1852            vec![0.35, 0.35, 0.3],
1853        ];
1854        let pca = CompositionalPca::fit(&data, 2).expect("pca fit");
1855        for &ev in pca.explained_variance() {
1856            assert!(ev >= 0.0, "explained variance must be non-negative");
1857        }
1858        for &evr in pca.explained_variance_ratio.iter() {
1859            assert!(evr >= 0.0 && evr <= 1.0 + 1e-10, "EVR must be in [0,1]");
1860        }
1861    }
1862
1863    // --- Neutrality test ---------------------------------------------------
1864
1865    #[test]
1866    fn test_neutrality_test_runs() {
1867        let data = vec![
1868            vec![0.3, 0.4, 0.3],
1869            vec![0.2, 0.5, 0.3],
1870            vec![0.4, 0.3, 0.3],
1871            vec![0.25, 0.35, 0.4],
1872            vec![0.35, 0.25, 0.4],
1873        ];
1874        let result = neutrality_test(&data).expect("neutrality test");
1875        assert!(result.statistic.is_finite());
1876        assert!(result.p_value >= 0.0 && result.p_value <= 1.0);
1877    }
1878
1879    #[test]
1880    fn test_neutrality_test_insufficient_data() {
1881        let data = vec![vec![0.5, 0.3, 0.2]; 3];
1882        assert!(neutrality_test(&data).is_err());
1883    }
1884
1885    // --- Digamma / trigamma -------------------------------------------------
1886
1887    #[test]
1888    fn test_digamma_known_values() {
1889        // ψ(1) = -γ ≈ -0.5772...
1890        let psi1 = digamma(1.0);
1891        assert!(approx_eq(psi1, -0.5772156649, 1e-6), "ψ(1) = {psi1}");
1892        // ψ(2) = 1 - γ ≈ 0.4228...
1893        let psi2 = digamma(2.0);
1894        assert!(approx_eq(psi2, 0.4227843351, 1e-6), "ψ(2) = {psi2}");
1895    }
1896
1897    #[test]
1898    fn test_trigamma_known_values() {
1899        // ψ'(1) = π²/6 ≈ 1.6449...
1900        let tpsi1 = trigamma(1.0);
1901        assert!(
1902            approx_eq(
1903                tpsi1,
1904                std::f64::consts::PI * std::f64::consts::PI / 6.0,
1905                1e-5
1906            ),
1907            "ψ'(1) = {tpsi1}"
1908        );
1909    }
1910}