shape_runtime/intrinsics/
distributions.rs1use super::random;
18use crate::marshal::{register_typed_fn_1, register_typed_fn_2, register_typed_fn_3};
19use crate::module_exports::ModuleExports;
20use crate::typed_module_exports::{ConcreteReturn, ConcreteType, TypedReturn};
21use rand::Rng;
22use rand_chacha::ChaCha8Rng;
23use std::sync::Arc;
24
25pub fn create_distributions_intrinsics_module() -> ModuleExports {
29 let mut module = ModuleExports::new("std::core::intrinsics::distributions");
30 module.description =
31 "Statistical distribution sampling intrinsics (uniform, lognormal, exponential, poisson, sample_n)"
32 .to_string();
33
34 register_typed_fn_2::<_, f64, f64>(
35 &mut module,
36 "__intrinsic_dist_uniform",
37 "Sample from a uniform distribution [lo, hi)",
38 [("lo", "number"), ("hi", "number")],
39 ConcreteType::Number,
40 |lo, hi, _ctx| {
41 if lo >= hi {
42 return Err(format!(
43 "__intrinsic_dist_uniform: lo ({}) must be < hi ({})",
44 lo, hi
45 ));
46 }
47 let value = random::with_rng(|rng| sample_uniform(rng, lo, hi));
48 Ok(TypedReturn::Concrete(ConcreteReturn::F64(value)))
49 },
50 );
51
52 register_typed_fn_2::<_, f64, f64>(
53 &mut module,
54 "__intrinsic_dist_lognormal",
55 "Sample from a lognormal distribution",
56 [("mean", "number"), ("std", "number")],
57 ConcreteType::Number,
58 |mean, std, _ctx| {
59 if std < 0.0 {
60 return Err("__intrinsic_dist_lognormal: std must be non-negative".to_string());
61 }
62 let value = random::with_rng(|rng| sample_lognormal(rng, mean, std));
63 Ok(TypedReturn::Concrete(ConcreteReturn::F64(value)))
64 },
65 );
66
67 register_typed_fn_1::<_, f64>(
68 &mut module,
69 "__intrinsic_dist_exponential",
70 "Sample from an exponential distribution",
71 "lambda",
72 "number",
73 ConcreteType::Number,
74 |lambda, _ctx| {
75 if lambda <= 0.0 {
76 return Err("__intrinsic_dist_exponential: lambda must be positive".to_string());
77 }
78 let value = random::with_rng(|rng| sample_exponential(rng, lambda));
79 Ok(TypedReturn::Concrete(ConcreteReturn::F64(value)))
80 },
81 );
82
83 register_typed_fn_1::<_, f64>(
84 &mut module,
85 "__intrinsic_dist_poisson",
86 "Sample from a Poisson distribution",
87 "lambda",
88 "number",
89 ConcreteType::Number,
90 |lambda, _ctx| {
91 if lambda < 0.0 {
92 return Err("__intrinsic_dist_poisson: lambda must be non-negative".to_string());
93 }
94 let value = random::with_rng(|rng| sample_poisson(rng, lambda));
95 Ok(TypedReturn::Concrete(ConcreteReturn::F64(value)))
96 },
97 );
98
99 register_typed_fn_3::<_, Arc<String>, Arc<Vec<f64>>, i64>(
100 &mut module,
101 "__intrinsic_dist_sample_n",
102 "Sample n values from a named distribution (uniform / lognormal / exponential / poisson)",
103 [
104 ("dist_name", "string"),
105 ("params", "Array<number>"),
106 ("n", "int"),
107 ],
108 ConcreteType::ArrayNumber,
109 |dist_name, params, n, _ctx| {
110 if n < 0 {
111 return Err("__intrinsic_dist_sample_n: n must be non-negative".to_string());
112 }
113 let n = n as usize;
114 let p = params.as_slice();
115 let dist = dist_name.as_str();
116
117 match dist {
119 "uniform" | "lognormal" => {
120 if p.len() != 2 {
121 return Err(format!(
122 "__intrinsic_dist_sample_n: '{}' requires 2 params, got {}",
123 dist,
124 p.len()
125 ));
126 }
127 }
128 "exponential" | "poisson" => {
129 if p.len() != 1 {
130 return Err(format!(
131 "__intrinsic_dist_sample_n: '{}' requires 1 param, got {}",
132 dist,
133 p.len()
134 ));
135 }
136 }
137 _ => return Err(format!("Unknown distribution: {}", dist)),
138 }
139
140 match dist {
142 "uniform" if p[0] >= p[1] => {
143 return Err(format!(
144 "__intrinsic_dist_sample_n: uniform lo ({}) must be < hi ({})",
145 p[0], p[1]
146 ));
147 }
148 "lognormal" if p[1] < 0.0 => {
149 return Err(
150 "__intrinsic_dist_sample_n: lognormal std must be non-negative".to_string(),
151 );
152 }
153 "exponential" if p[0] <= 0.0 => {
154 return Err(
155 "__intrinsic_dist_sample_n: exponential lambda must be positive"
156 .to_string(),
157 );
158 }
159 "poisson" if p[0] < 0.0 => {
160 return Err(
161 "__intrinsic_dist_sample_n: poisson lambda must be non-negative"
162 .to_string(),
163 );
164 }
165 _ => {}
166 }
167
168 let samples: Vec<f64> = random::with_rng(|rng| {
169 (0..n)
170 .map(|_| match dist {
171 "uniform" => sample_uniform(rng, p[0], p[1]),
172 "lognormal" => sample_lognormal(rng, p[0], p[1]),
173 "exponential" => sample_exponential(rng, p[0]),
174 "poisson" => sample_poisson(rng, p[0]),
175 _ => unreachable!(),
176 })
177 .collect()
178 });
179
180 Ok(TypedReturn::Concrete(ConcreteReturn::ArrayF64(samples)))
181 },
182 );
183
184 module
185}
186
187fn sample_uniform(rng: &mut ChaCha8Rng, lo: f64, hi: f64) -> f64 {
191 let u: f64 = rng.r#gen();
192 lo + (hi - lo) * u
193}
194
195fn sample_lognormal(rng: &mut ChaCha8Rng, mean: f64, std: f64) -> f64 {
197 let u1: f64 = rng.r#gen();
198 let u2: f64 = rng.r#gen();
199 let z = (-2.0_f64 * u1.ln()).sqrt() * (2.0_f64 * std::f64::consts::PI * u2).cos();
200 (mean + std * z).exp()
201}
202
203fn sample_exponential(rng: &mut ChaCha8Rng, lambda: f64) -> f64 {
205 let u: f64 = rng.r#gen();
206 -u.ln() / lambda
207}
208
209fn sample_poisson(rng: &mut ChaCha8Rng, lambda: f64) -> f64 {
214 if lambda < 30.0 {
215 let l = (-lambda).exp();
216 let mut k = 0;
217 let mut p = 1.0;
218 loop {
219 k += 1;
220 let u: f64 = rng.r#gen();
221 p *= u;
222 if p <= l {
223 break;
224 }
225 }
226 (k - 1) as f64
227 } else {
228 let u1: f64 = rng.r#gen();
229 let u2: f64 = rng.r#gen();
230 let z = (-2.0_f64 * u1.ln()).sqrt() * (2.0_f64 * std::f64::consts::PI * u2).cos();
231 let value = lambda + lambda.sqrt() * z;
232 value.max(0.0).round()
233 }
234}
235
#[cfg(test)]
mod tests {
    use super::*;
    use crate::intrinsics::random as random_intrinsics;
    use rand::SeedableRng;

    /// Number of draws collected by each statistical test.
    const DRAWS: usize = 20000;

    /// Returns the arithmetic mean and the divide-by-n variance of `samples`.
    fn mean_variance(samples: &[f64]) -> (f64, f64) {
        let count = samples.len() as f64;
        let mean = samples.iter().sum::<f64>() / count;
        let squared_deviations = samples.iter().map(|x| (x - mean).powi(2)).sum::<f64>();
        (mean, squared_deviations / count)
    }

    /// Replaces the shared generator with one seeded from `seed`, then collects `DRAWS`
    /// values by calling `draw` repeatedly on it.
    fn seeded_samples(seed: u64, draw: impl Fn(&mut rand_chacha::ChaCha8Rng) -> f64) -> Vec<f64> {
        random_intrinsics::with_rng(|rng| {
            *rng = rand_chacha::ChaCha8Rng::seed_from_u64(seed);
        });
        random_intrinsics::with_rng(|rng| (0..DRAWS).map(|_| draw(&mut *rng)).collect())
    }

    #[test]
    fn test_uniform_mean_variance() {
        // Uniform on [0, 2): mean 1, variance 1/3.
        let samples = seeded_samples(42, |rng| sample_uniform(rng, 0.0, 2.0));

        let (mean, var) = mean_variance(&samples);
        assert!((mean - 1.0).abs() < 0.05);
        assert!((var - 1.0 / 3.0).abs() < 0.05);
    }

    #[test]
    fn test_exponential_mean_variance() {
        // Exponential with rate lambda: mean 1/lambda, variance 1/lambda^2.
        let lambda = 2.0;
        let samples = seeded_samples(7, |rng| sample_exponential(rng, lambda));

        let (mean, var) = mean_variance(&samples);
        assert!((mean - 1.0 / lambda).abs() < 0.05);
        assert!((var - 1.0 / (lambda * lambda)).abs() < 0.1);
    }

    #[test]
    fn test_poisson_mean_variance() {
        // Poisson: mean and variance both equal lambda.
        let lambda = 12.0;
        let samples = seeded_samples(123, |rng| sample_poisson(rng, lambda));

        let (mean, var) = mean_variance(&samples);
        assert!((mean - lambda).abs() < 0.3);
        assert!((var - lambda).abs() < 0.5);
    }
}