Skip to main content

shape_runtime/intrinsics/
distributions.rs

1//! Statistical distribution intrinsics — full migration to typed marshal layer.
2//!
3//! Per the intrinsics-typed-CC migration's per-file table, all 5 distribution
4//! intrinsics (`dist_uniform`, `dist_lognormal`, `dist_exponential`,
5//! `dist_poisson`, `dist_sample_n`) migrate to `register_typed_fn_N` typed
6//! entries via [`create_distributions_intrinsics_module`].
7//!
8//! `dist_sample_n` previously called sibling `intrinsic_dist_*` legacy bodies
9//! via `&[ValueWord]`; bodies are now refactored to delegate to pure-helper
10//! functions (`sample_uniform`, `sample_lognormal`, etc.) so both the typed
11//! entries and `dist_sample_n` share the same sampling math without going
12//! through the marshal layer twice.
13//!
14//! Sampling uses the thread-local RNG from the `random` module via
15//! `random::with_rng`.
16
17use super::random;
18use crate::marshal::{register_typed_fn_1, register_typed_fn_2, register_typed_fn_3};
19use crate::module_exports::ModuleExports;
20use crate::typed_module_exports::{ConcreteReturn, ConcreteType, TypedReturn};
21use rand::Rng;
22use rand_chacha::ChaCha8Rng;
23use std::sync::Arc;
24
25// ───────────────────── Module factory (5 typed entries) ─────────────────────
26
27/// Create the distributions intrinsics module with all 5 typed-marshal entry points.
28pub fn create_distributions_intrinsics_module() -> ModuleExports {
29    let mut module = ModuleExports::new("std::core::intrinsics::distributions");
30    module.description =
31        "Statistical distribution sampling intrinsics (uniform, lognormal, exponential, poisson, sample_n)"
32            .to_string();
33
34    register_typed_fn_2::<_, f64, f64>(
35        &mut module,
36        "__intrinsic_dist_uniform",
37        "Sample from a uniform distribution [lo, hi)",
38        [("lo", "number"), ("hi", "number")],
39        ConcreteType::Number,
40        |lo, hi, _ctx| {
41            if lo >= hi {
42                return Err(format!(
43                    "__intrinsic_dist_uniform: lo ({}) must be < hi ({})",
44                    lo, hi
45                ));
46            }
47            let value = random::with_rng(|rng| sample_uniform(rng, lo, hi));
48            Ok(TypedReturn::Concrete(ConcreteReturn::F64(value)))
49        },
50    );
51
52    register_typed_fn_2::<_, f64, f64>(
53        &mut module,
54        "__intrinsic_dist_lognormal",
55        "Sample from a lognormal distribution",
56        [("mean", "number"), ("std", "number")],
57        ConcreteType::Number,
58        |mean, std, _ctx| {
59            if std < 0.0 {
60                return Err("__intrinsic_dist_lognormal: std must be non-negative".to_string());
61            }
62            let value = random::with_rng(|rng| sample_lognormal(rng, mean, std));
63            Ok(TypedReturn::Concrete(ConcreteReturn::F64(value)))
64        },
65    );
66
67    register_typed_fn_1::<_, f64>(
68        &mut module,
69        "__intrinsic_dist_exponential",
70        "Sample from an exponential distribution",
71        "lambda",
72        "number",
73        ConcreteType::Number,
74        |lambda, _ctx| {
75            if lambda <= 0.0 {
76                return Err("__intrinsic_dist_exponential: lambda must be positive".to_string());
77            }
78            let value = random::with_rng(|rng| sample_exponential(rng, lambda));
79            Ok(TypedReturn::Concrete(ConcreteReturn::F64(value)))
80        },
81    );
82
83    register_typed_fn_1::<_, f64>(
84        &mut module,
85        "__intrinsic_dist_poisson",
86        "Sample from a Poisson distribution",
87        "lambda",
88        "number",
89        ConcreteType::Number,
90        |lambda, _ctx| {
91            if lambda < 0.0 {
92                return Err("__intrinsic_dist_poisson: lambda must be non-negative".to_string());
93            }
94            let value = random::with_rng(|rng| sample_poisson(rng, lambda));
95            Ok(TypedReturn::Concrete(ConcreteReturn::F64(value)))
96        },
97    );
98
99    register_typed_fn_3::<_, Arc<String>, Arc<Vec<f64>>, i64>(
100        &mut module,
101        "__intrinsic_dist_sample_n",
102        "Sample n values from a named distribution (uniform / lognormal / exponential / poisson)",
103        [
104            ("dist_name", "string"),
105            ("params", "Array<number>"),
106            ("n", "int"),
107        ],
108        ConcreteType::ArrayNumber,
109        |dist_name, params, n, _ctx| {
110            if n < 0 {
111                return Err("__intrinsic_dist_sample_n: n must be non-negative".to_string());
112            }
113            let n = n as usize;
114            let p = params.as_slice();
115            let dist = dist_name.as_str();
116
117            // Validate parameter shape for the named distribution.
118            match dist {
119                "uniform" | "lognormal" => {
120                    if p.len() != 2 {
121                        return Err(format!(
122                            "__intrinsic_dist_sample_n: '{}' requires 2 params, got {}",
123                            dist,
124                            p.len()
125                        ));
126                    }
127                }
128                "exponential" | "poisson" => {
129                    if p.len() != 1 {
130                        return Err(format!(
131                            "__intrinsic_dist_sample_n: '{}' requires 1 param, got {}",
132                            dist,
133                            p.len()
134                        ));
135                    }
136                }
137                _ => return Err(format!("Unknown distribution: {}", dist)),
138            }
139
140            // Per-distribution validity preconditions (mirror per-call typed entries).
141            match dist {
142                "uniform" if p[0] >= p[1] => {
143                    return Err(format!(
144                        "__intrinsic_dist_sample_n: uniform lo ({}) must be < hi ({})",
145                        p[0], p[1]
146                    ));
147                }
148                "lognormal" if p[1] < 0.0 => {
149                    return Err(
150                        "__intrinsic_dist_sample_n: lognormal std must be non-negative".to_string(),
151                    );
152                }
153                "exponential" if p[0] <= 0.0 => {
154                    return Err(
155                        "__intrinsic_dist_sample_n: exponential lambda must be positive"
156                            .to_string(),
157                    );
158                }
159                "poisson" if p[0] < 0.0 => {
160                    return Err(
161                        "__intrinsic_dist_sample_n: poisson lambda must be non-negative"
162                            .to_string(),
163                    );
164                }
165                _ => {}
166            }
167
168            let samples: Vec<f64> = random::with_rng(|rng| {
169                (0..n)
170                    .map(|_| match dist {
171                        "uniform" => sample_uniform(rng, p[0], p[1]),
172                        "lognormal" => sample_lognormal(rng, p[0], p[1]),
173                        "exponential" => sample_exponential(rng, p[0]),
174                        "poisson" => sample_poisson(rng, p[0]),
175                        _ => unreachable!(),
176                    })
177                    .collect()
178            });
179
180            Ok(TypedReturn::Concrete(ConcreteReturn::ArrayF64(samples)))
181        },
182    );
183
184    module
185}
186
187// ───────────────────── Sampling helpers (used by typed entries) ─────────────────────
188
189/// Sample from a uniform distribution [lo, hi).
190fn sample_uniform(rng: &mut ChaCha8Rng, lo: f64, hi: f64) -> f64 {
191    let u: f64 = rng.r#gen();
192    lo + (hi - lo) * u
193}
194
195/// Sample from a lognormal distribution via Box-Muller normal sampling.
196fn sample_lognormal(rng: &mut ChaCha8Rng, mean: f64, std: f64) -> f64 {
197    let u1: f64 = rng.r#gen();
198    let u2: f64 = rng.r#gen();
199    let z = (-2.0_f64 * u1.ln()).sqrt() * (2.0_f64 * std::f64::consts::PI * u2).cos();
200    (mean + std * z).exp()
201}
202
203/// Sample from an exponential distribution (inverse-CDF).
204fn sample_exponential(rng: &mut ChaCha8Rng, lambda: f64) -> f64 {
205    let u: f64 = rng.r#gen();
206    -u.ln() / lambda
207}
208
209/// Sample from a Poisson distribution.
210///
211/// For lambda < 30 uses Knuth's multiplicative method; for lambda >= 30 uses
212/// a normal approximation rounded to the nearest non-negative integer.
213fn sample_poisson(rng: &mut ChaCha8Rng, lambda: f64) -> f64 {
214    if lambda < 30.0 {
215        let l = (-lambda).exp();
216        let mut k = 0;
217        let mut p = 1.0;
218        loop {
219            k += 1;
220            let u: f64 = rng.r#gen();
221            p *= u;
222            if p <= l {
223                break;
224            }
225        }
226        (k - 1) as f64
227    } else {
228        let u1: f64 = rng.r#gen();
229        let u2: f64 = rng.r#gen();
230        let z = (-2.0_f64 * u1.ln()).sqrt() * (2.0_f64 * std::f64::consts::PI * u2).cos();
231        let value = lambda + lambda.sqrt() * z;
232        value.max(0.0).round()
233    }
234}
235
#[cfg(test)]
mod tests {
    use super::*;
    use crate::intrinsics::random as random_intrinsics;
    use rand::SeedableRng;

    /// Number of draws taken for each statistical check.
    const DRAWS: usize = 20000;

    /// Reseed the thread-local RNG with `seed`, then collect `DRAWS` values
    /// produced by calling `draw` on it.
    fn seeded_draws(seed: u64, draw: impl Fn(&mut rand_chacha::ChaCha8Rng) -> f64) -> Vec<f64> {
        random_intrinsics::with_rng(|rng| {
            *rng = rand_chacha::ChaCha8Rng::seed_from_u64(seed);
        });
        random_intrinsics::with_rng(|rng| (0..DRAWS).map(|_| draw(rng)).collect())
    }

    /// Mean and divide-by-n variance of `samples`.
    fn mean_variance(samples: &[f64]) -> (f64, f64) {
        let count = samples.len() as f64;
        let mean = samples.iter().sum::<f64>() / count;
        let var = samples.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / count;
        (mean, var)
    }

    #[test]
    fn test_uniform_mean_variance() {
        // Uniform on [0, 2): mean 1, variance (2 - 0)^2 / 12 = 1/3.
        let samples = seeded_draws(42, |rng| sample_uniform(rng, 0.0, 2.0));

        let (mean, var) = mean_variance(&samples);
        assert!((mean - 1.0).abs() < 0.05);
        assert!((var - 1.0 / 3.0).abs() < 0.05);
    }

    #[test]
    fn test_exponential_mean_variance() {
        // Exponential with rate lambda: mean 1/lambda, variance 1/lambda^2.
        let lambda = 2.0;
        let samples = seeded_draws(7, |rng| sample_exponential(rng, lambda));

        let (mean, var) = mean_variance(&samples);
        assert!((mean - 1.0 / lambda).abs() < 0.05);
        assert!((var - 1.0 / (lambda * lambda)).abs() < 0.1);
    }

    #[test]
    fn test_poisson_mean_variance() {
        // Poisson: mean and variance both equal lambda (Knuth branch, lambda < 30).
        let lambda = 12.0;
        let samples = seeded_draws(123, |rng| sample_poisson(rng, lambda));

        let (mean, var) = mean_variance(&samples);
        assert!((mean - lambda).abs() < 0.3);
        assert!((var - lambda).abs() < 0.5);
    }
}
291}