simd_kernels/kernels/scientific/distributions/shared/
sampler.rs

1// Copyright Peter Bower 2025. All Rights Reserved.
2// Licensed under Mozilla Public License (MPL) 2.0.
3
4//! # Statistical Sampling Module — High-Performance Pseudorandom Distribution Sampling
5//!
6//! Pseudorandom number generation kernels providing sampling from statistical
7//! distributions with strong performance and distributional accuracy.
8
9use minarrow::Vec64;
10use rand::rngs::ThreadRng;
11use rand::{rng, Rng};
12use std::f64::consts::PI;
13
14/// Thread-local statistical distribution sampler backed by a high-quality PRNG.
15pub struct Sampler {
16    rng: ThreadRng,
17}
18
19impl Sampler {
20    /// Creates a new sampler instance with a thread-local pseudorandom number generator.
21    #[inline]
22    pub fn new() -> Self {
23        Sampler { rng: rng() }
24    }
25
26    /// Generates a single sample from the standard normal distribution N(0, 1).
27    #[inline]
28    pub fn sample_standard_normal(&mut self) -> f64 {
29        sample_standard_normal(&mut self.rng)
30    }
31
32    /// Marsaglia–Tsang for Γ(shape, scale). Preconditions: shape > 0, scale > 0.
33    #[inline]
34    fn sample_gamma(&mut self, shape: f64, scale: f64) -> f64 {
35        sample_gamma(&mut self.rng, shape, scale)
36    }
37
38    /// Gamma(shape, scale). Preconditions: shape > 0, scale > 0.
39    #[inline]
40    pub fn gamma(&mut self, shape: f64, scale: f64) -> f64 {
41        self.sample_gamma(shape, scale)
42    }
43
44    /// Chi-square(df) == Gamma(df/2, 2). Preconditions: df > 0.
45    #[inline]
46    pub fn chi2(&mut self, df: f64) -> f64 {
47        assert!(df.is_finite() && df > 0.0, "df must be finite and > 0");
48        self.gamma(df * 0.5, 2.0)
49    }
50
51    /// Vector of iid N(0,1) samples of length `dim`.
52    #[inline]
53    pub fn standard_normal_vec(&mut self, dim: usize) -> Vec64<f64> {
54        let mut v = Vec64::with_capacity(dim);
55        for _ in 0..dim {
56            v.push(self.sample_standard_normal());
57        }
58        v
59    }
60
61    /// Dirichlet-distributed probability vector via normalised gamma sampling.
62    ///
63    /// Preconditions: `alpha` non-empty; all entries finite and > 0.
64    #[inline]
65    pub fn dirichlet(&mut self, alpha: &[f64]) -> Vec64<f64> {
66        assert!(!alpha.is_empty(), "alpha must be non-empty");
67        assert!(
68            alpha.iter().all(|&a| a.is_finite() && a > 0.0),
69            "all alpha entries must be finite and > 0"
70        );
71
72        let mut draw = Vec64::with_capacity(alpha.len());
73        let mut sum = 0.0;
74        for &a in alpha {
75            let x = self.gamma(a, 1.0);
76            sum += x;
77            draw.push(x);
78        }
79        // With the preconditions above, sum > 0 with probability 1
80        draw.iter_mut().for_each(|v| *v /= sum);
81        draw
82    }
83}
84
85// Box–Muller to get one N(0,1)
86/// Generates a single sample from the standard normal distribution N(0,1).
87#[inline]
88pub fn sample_standard_normal<R: Rng + ?Sized>(rng: &mut R) -> f64 {
89    // U1 ∈ (0,1], U2 ∈ [0,1)
90    let u1: f64 = rng.random::<f64>().max(f64::MIN_POSITIVE); // avoid log(0)
91    let u2: f64 = rng.random::<f64>();
92    let r = (-2.0 * u1.ln()).sqrt();
93    r * (2.0 * PI * u2).cos()
94}
95
96/// Generates a single sample from the Gamma distribution using the Marsaglia–Tsang algorithm.
97/// Preconditions: shape > 0, scale > 0.
98#[inline]
99pub fn sample_gamma<R: Rng + ?Sized>(rng: &mut R, shape: f64, scale: f64) -> f64 {
100    assert!(shape.is_finite() && shape > 0.0, "shape must be finite and > 0");
101    assert!(scale.is_finite() && scale > 0.0, "scale must be finite and > 0");
102
103    // Handle 0 < shape < 1 by boosting to shape+1, then apply a power-law correction.
104    if shape < 1.0 {
105        let u: f64 = rng.random::<f64>();
106        return sample_gamma(rng, shape + 1.0, scale) * u.powf(1.0 / shape);
107    }
108
109    let d = shape - 1.0 / 3.0;
110    let c = 1.0 / (9.0 * d).sqrt(); // Correct: c = 1 / sqrt(9d)
111
112    loop {
113        let x = sample_standard_normal(rng);
114        let one_plus_cx = 1.0 + c * x;
115        if one_plus_cx <= 0.0 {
116            continue;
117        }
118        let v = one_plus_cx * one_plus_cx * one_plus_cx; // (1 + c x)^3
119        let u: f64 = rng.random::<f64>();
120
121        // Squeeze step
122        if u < 1.0 - 0.0331 * (x * x) * (x * x) {
123            return d * v * scale;
124        }
125        // Log acceptance step
126        if u.ln() < 0.5 * x * x + d * (1.0 - v + v.ln()) {
127            return d * v * scale;
128        }
129    }
130}