//! SIMD-accelerated batch sampling for continuous distributions.
//!
//! This module provides high-throughput batch sampling for the Normal, Uniform, and
//! Exponential distributions, with vectorised mathematical kernels drawn from
//! `scirs2_core`'s `SimdUnifiedOps` trait.
//!
//! # Design
//!
//! Each sampler generates a block of uniform random numbers first (the only serial
//! bottleneck), then applies the appropriate mathematical transform entirely through
//! SIMD operations:
//!
//! - **Normal** (`sample_normal_batch`): Box-Muller transform — two uniform variates
//!   per pair produce two independent standard-normal variates via `ln`, `sqrt`,
//!   `sin`, `cos`.
//! - **Uniform** (`sample_uniform_batch`): linear rescaling `a + u * (b - a)` using
//!   vectorised FMA.
//! - **Exponential** (`sample_exponential_batch`): inverse-CDF transform
//!   `-ln(u) / λ` using vectorised `ln` and `mul`.
//!
//! Vectorised CDF and PDF evaluation functions are also provided for all three
//! distributions.
//!
//! # Performance
//!
//! On arrays larger than the SIMD scalar threshold (~64 elements), all mathematical
//! transforms are executed as SIMD kernels.  Random number generation itself uses
//! `scirs2_core::random` (pure-Rust Xoshiro256++ under the hood).
use crate::error::{StatsError, StatsResult};
use scirs2_core::ndarray::{Array1, ArrayView1};
use scirs2_core::numeric::{Float, NumCast};
use scirs2_core::random::uniform::SampleUniform;
use scirs2_core::random::Rng;
use scirs2_core::simd_ops::{AutoOptimizer, SimdUnifiedOps};
36
37// ============================================================================
38// Internal macro / helper for RNG construction
39// ============================================================================
40
/// Resolve a seed: if `seed` is `None`, draw a u64 from `thread_rng`;
/// then always call `seeded_rng` so that the return type is uniform.
///
/// Expands to an expression yielding the RNG, so it can be used as
/// `let mut rng = build_rng!(seed);`.
macro_rules! build_rng {
    ($seed:expr) => {{
        // Fall back to a fresh random seed when the caller passed `None`.
        let s: u64 = $seed.unwrap_or_else(|| {
            use scirs2_core::random::Rng;
            scirs2_core::random::thread_rng().random()
        });
        // Always go through `seeded_rng` so both branches produce the same RNG type.
        scirs2_core::random::seeded_rng(s)
    }};
}
52
53// ============================================================================
54// Normal distribution — batch sampling
55// ============================================================================
56
57/// Generate `n` independent samples from N(`mean`, `std_dev`²) using a
58/// SIMD-accelerated Box-Muller transform.
59///
60/// The Box-Muller transform converts pairs of independent U(0, 1) random
61/// variables (u₁, u₂) into pairs of independent standard-normal variates:
62///
63/// ```text
64/// z₀ = √(−2 ln u₁) · cos(2π u₂)
65/// z₁ = √(−2 ln u₁) · sin(2π u₂)
66/// ```
67///
68/// Both `ln`, `sqrt`, `cos`, and `sin` are evaluated via SIMD intrinsics from
69/// `scirs2_core` when the batch is large enough.  The result is then scaled by
70/// `std_dev` and shifted by `mean` using a fused-multiply-add pass.
71///
72/// # Arguments
73///
74/// * `n`       — Number of samples to generate.
75/// * `mean`    — Location parameter of the Normal distribution.
76/// * `std_dev` — Scale parameter (must be positive).
77/// * `seed`    — Optional RNG seed for reproducibility.
78///
79/// # Errors
80///
81/// Returns [`StatsError::InvalidArgument`] when `n == 0` or `std_dev <= 0`.
82///
83/// # Examples
84///
85/// ```
86/// use scirs2_stats::sample_normal_batch;
87///
88/// let samples = sample_normal_batch::<f64>(1_000, 0.0, 1.0, Some(42))
89///     .expect("sampling failed");
90/// assert_eq!(samples.len(), 1_000);
91///
92/// // Empirical mean should be close to 0
93/// let mean: f64 = samples.sum() / samples.len() as f64;
94/// assert!(mean.abs() < 0.15);
95/// ```
96pub fn sample_normal_batch<F>(
97    n: usize,
98    mean: F,
99    std_dev: F,
100    seed: Option<u64>,
101) -> StatsResult<Array1<F>>
102where
103    F: Float + NumCast + SimdUnifiedOps + SampleUniform,
104{
105    if n == 0 {
106        return Err(StatsError::InvalidArgument(
107            "n must be at least 1".to_string(),
108        ));
109    }
110    if std_dev <= F::zero() {
111        return Err(StatsError::InvalidArgument(
112            "std_dev must be positive for the Normal distribution".to_string(),
113        ));
114    }
115
116    // We need ⌈n/2⌉ pairs for Box-Muller.
117    let n_pairs = (n + 1) / 2;
118    let n_total = n_pairs * 2;
119
120    let mut rng = build_rng!(seed);
121
122    // ── Phase 1: generate uniform pairs ──────────────────────────────────────
123    // u1 ∈ (ε, 1) so that ln(u1) is finite; u2 ∈ [0, 1).
124    let eps = F::epsilon().to_f64().unwrap_or(1e-15_f64).max(1e-15_f64);
125
126    let u1: Array1<F> = Array1::from_shape_fn(n_pairs, |_| {
127        F::from(rng.gen_range(eps..1.0_f64)).unwrap_or_else(|| F::one())
128    });
129    let u2: Array1<F> = Array1::from_shape_fn(n_pairs, |_| {
130        F::from(rng.gen_range(0.0_f64..1.0_f64)).unwrap_or_else(|| F::zero())
131    });
132
133    // ── Phase 2: Box-Muller transform ─────────────────────────────────────────
134    let optimizer = AutoOptimizer::new();
135
136    let (z0, z1) = if optimizer.should_use_simd(n_pairs) {
137        // SIMD path ─────────────────────────────────────────────────────────
138
139        // √(−2 · ln(u1)):  apply simd_ln, negate×2, simd_sqrt
140        let ln_u1 = F::simd_ln(&u1.view());
141        let neg_two = F::from(-2.0_f64).unwrap_or_else(|| -F::one() - F::one());
142        let neg_two_arr = Array1::from_elem(n_pairs, neg_two);
143        let neg2_ln_u1 = F::simd_mul(&neg_two_arr.view(), &ln_u1.view());
144        let r = F::simd_sqrt(&neg2_ln_u1.view());
145
146        // 2π · u2:  use simd_mul then cos/sin
147        let two_pi = F::from(2.0_f64 * std::f64::consts::PI).unwrap_or_else(|| F::one());
148        let two_pi_arr = Array1::from_elem(n_pairs, two_pi);
149        let theta = F::simd_mul(&two_pi_arr.view(), &u2.view());
150
151        let cos_theta = F::simd_cos(&theta.view());
152        let sin_theta = F::simd_sin(&theta.view());
153
154        let raw_z0 = F::simd_mul(&r.view(), &cos_theta.view());
155        let raw_z1 = F::simd_mul(&r.view(), &sin_theta.view());
156
157        // Scale: z * std_dev + mean  (FMA)
158        let std_arr = Array1::from_elem(n_pairs, std_dev);
159        let mean_arr = Array1::from_elem(n_pairs, mean);
160
161        let z0 = F::simd_fma(&raw_z0.view(), &std_arr.view(), &mean_arr.view());
162        let z1 = F::simd_fma(&raw_z1.view(), &std_arr.view(), &mean_arr.view());
163
164        (z0, z1)
165    } else {
166        // Scalar fallback ──────────────────────────────────────────────────
167        let two = F::from(2.0_f64).unwrap_or_else(|| F::one() + F::one());
168        let two_pi = F::from(2.0_f64 * std::f64::consts::PI).unwrap_or_else(|| F::one());
169
170        let mut z0 = Array1::zeros(n_pairs);
171        let mut z1 = Array1::zeros(n_pairs);
172        for i in 0..n_pairs {
173            let r = (-two * u1[i].ln()).sqrt();
174            let theta = two_pi * u2[i];
175            z0[i] = mean + std_dev * r * theta.cos();
176            z1[i] = mean + std_dev * r * theta.sin();
177        }
178        (z0, z1)
179    };
180
181    // ── Phase 3: interleave z0/z1 and trim to n ───────────────────────────────
182    let mut result: Array1<F> = Array1::zeros(n_total);
183    for i in 0..n_pairs {
184        result[2 * i] = z0[i];
185        if 2 * i + 1 < n_total {
186            result[2 * i + 1] = z1[i];
187        }
188    }
189
190    Ok(result.slice(scirs2_core::ndarray::s![..n]).to_owned())
191}
192
193// ============================================================================
194// Uniform distribution — batch sampling
195// ============================================================================
196
197/// Generate `n` independent samples from U(`low`, `high`) using SIMD rescaling.
198///
199/// Each sample is drawn from the standard uniform U(0, 1) and then linearly
200/// rescaled:
201///
202/// ```text
203/// x = low + u · (high − low)
204/// ```
205///
206/// The rescaling is performed as a SIMD FMA pass over the full batch, giving
207/// vectorised throughput equal to the underlying RNG bottleneck.
208///
209/// # Arguments
210///
211/// * `n`    — Number of samples.
212/// * `low`  — Lower bound of the interval (inclusive).
213/// * `high` — Upper bound of the interval (exclusive).
214/// * `seed` — Optional RNG seed.
215///
216/// # Errors
217///
218/// Returns [`StatsError::InvalidArgument`] when `n == 0` or `low >= high`.
219///
220/// # Examples
221///
222/// ```
223/// use scirs2_stats::sample_uniform_batch;
224///
225/// let samples = sample_uniform_batch::<f64>(500, 2.0, 5.0, Some(7))
226///     .expect("sampling failed");
227/// assert_eq!(samples.len(), 500);
228/// assert!(samples.iter().all(|&x| x >= 2.0 && x < 5.0));
229/// ```
230pub fn sample_uniform_batch<F>(
231    n: usize,
232    low: F,
233    high: F,
234    seed: Option<u64>,
235) -> StatsResult<Array1<F>>
236where
237    F: Float + NumCast + SimdUnifiedOps + SampleUniform,
238{
239    if n == 0 {
240        return Err(StatsError::InvalidArgument(
241            "n must be at least 1".to_string(),
242        ));
243    }
244    if low >= high {
245        return Err(StatsError::InvalidArgument(
246            "low must be strictly less than high for the Uniform distribution".to_string(),
247        ));
248    }
249
250    let mut rng = build_rng!(seed);
251
252    // ── Phase 1: generate raw U(0, 1) values ──────────────────────────────────
253    let u: Array1<F> = Array1::from_shape_fn(n, |_| {
254        F::from(rng.gen_range(0.0_f64..1.0_f64)).unwrap_or_else(|| F::zero())
255    });
256
257    // ── Phase 2: linear rescaling via SIMD ────────────────────────────────────
258    let optimizer = AutoOptimizer::new();
259    let width = high - low;
260
261    if optimizer.should_use_simd(n) {
262        // x = low + u * width  — computed as FMA: u * width + low
263        let width_arr = Array1::from_elem(n, width);
264        let low_arr = Array1::from_elem(n, low);
265        Ok(F::simd_fma(&u.view(), &width_arr.view(), &low_arr.view()))
266    } else {
267        Ok(u.mapv(|ui| low + ui * width))
268    }
269}
270
271// ============================================================================
272// Exponential distribution — batch sampling
273// ============================================================================
274
275/// Generate `n` independent samples from Exp(`rate`) using a SIMD-accelerated
276/// inverse-CDF transform.
277///
278/// The CDF of Exp(λ) is F(x) = 1 − e^(−λx), so the inverse is:
279///
280/// ```text
281/// x = −ln(u) / λ
282/// ```
283///
284/// where u ~ U(0, 1).  The vectorised `ln` from `scirs2_core` is applied to
285/// the entire batch at once.
286///
287/// # Arguments
288///
289/// * `n`    — Number of samples.
290/// * `rate` — Rate parameter λ (must be positive; mean = 1/λ).
291/// * `seed` — Optional RNG seed.
292///
293/// # Errors
294///
295/// Returns [`StatsError::InvalidArgument`] when `n == 0` or `rate <= 0`.
296///
297/// # Examples
298///
299/// ```
300/// use scirs2_stats::sample_exponential_batch;
301///
302/// let rate = 2.0_f64;
303/// let samples = sample_exponential_batch::<f64>(1_000, rate, Some(99))
304///     .expect("sampling failed");
305/// assert_eq!(samples.len(), 1_000);
306///
307/// // Empirical mean should be near 1/rate = 0.5
308/// let mean: f64 = samples.sum() / samples.len() as f64;
309/// assert!((mean - 0.5).abs() < 0.08);
310/// ```
311pub fn sample_exponential_batch<F>(n: usize, rate: F, seed: Option<u64>) -> StatsResult<Array1<F>>
312where
313    F: Float + NumCast + SimdUnifiedOps + SampleUniform,
314{
315    if n == 0 {
316        return Err(StatsError::InvalidArgument(
317            "n must be at least 1".to_string(),
318        ));
319    }
320    if rate <= F::zero() {
321        return Err(StatsError::InvalidArgument(
322            "rate must be positive for the Exponential distribution".to_string(),
323        ));
324    }
325
326    let eps = F::epsilon().to_f64().unwrap_or(1e-15_f64).max(1e-15_f64);
327
328    let mut rng = build_rng!(seed);
329
330    // ── Phase 1: generate U(ε, 1) to avoid ln(0) ─────────────────────────────
331    let u: Array1<F> = Array1::from_shape_fn(n, |_| {
332        F::from(rng.gen_range(eps..1.0_f64)).unwrap_or_else(|| F::one())
333    });
334
335    // ── Phase 2: inverse-CDF transform via SIMD ───────────────────────────────
336    let optimizer = AutoOptimizer::new();
337
338    if optimizer.should_use_simd(n) {
339        // ln(u)
340        let ln_u = F::simd_ln(&u.view());
341
342        // −ln(u):  multiply by −1
343        let neg_one = F::from(-1.0_f64).unwrap_or_else(|| -F::one());
344        let neg_one_arr = Array1::from_elem(n, neg_one);
345        let neg_ln_u = F::simd_mul(&neg_one_arr.view(), &ln_u.view());
346
347        // divide by rate:  (1/rate) scalar multiplication
348        let inv_rate = F::one() / rate;
349        let inv_rate_arr = Array1::from_elem(n, inv_rate);
350        Ok(F::simd_mul(&neg_ln_u.view(), &inv_rate_arr.view()))
351    } else {
352        Ok(u.mapv(|ui| -ui.ln() / rate))
353    }
354}
355
356// ============================================================================
357// Vectorised PDF / CDF evaluation
358// ============================================================================
359
360/// Evaluate the Normal PDF at each point in `x` using SIMD.
361///
362/// Computes the probability density
363/// ```text
364/// f(x) = exp(−(x − μ)² / (2σ²)) / (σ √(2π))
365/// ```
366/// over an entire array in a single vectorised pass.
367///
368/// # Arguments
369///
370/// * `x`       — Points at which to evaluate.
371/// * `mean`    — Distribution mean µ.
372/// * `std_dev` — Distribution standard deviation σ (must be positive).
373///
374/// # Errors
375///
376/// Returns [`StatsError::InvalidArgument`] when `std_dev <= 0` or `x` is empty.
377///
378/// # Examples
379///
380/// ```
381/// use scirs2_core::ndarray::array;
382/// use scirs2_stats::normal_pdf_batch;
383///
384/// let x = array![0.0_f64, 1.0, -1.0];
385/// let pdf = normal_pdf_batch(&x.view(), 0.0, 1.0).expect("failed");
386/// // pdf(0) ≈ 0.3989
387/// assert!((pdf[0] - 0.3989422804_f64).abs() < 1e-8);
388/// ```
389pub fn normal_pdf_batch<F>(x: &ArrayView1<F>, mean: F, std_dev: F) -> StatsResult<Array1<F>>
390where
391    F: Float + NumCast + SimdUnifiedOps,
392{
393    if x.is_empty() {
394        return Err(StatsError::InvalidArgument("x is empty".to_string()));
395    }
396    if std_dev <= F::zero() {
397        return Err(StatsError::InvalidArgument(
398            "std_dev must be positive".to_string(),
399        ));
400    }
401
402    let n = x.len();
403    let two = F::from(2.0_f64).unwrap_or_else(|| F::one() + F::one());
404    let two_pi = F::from(2.0_f64 * std::f64::consts::PI).unwrap_or_else(|| F::one());
405    let norm_const = F::one() / (std_dev * two_pi.sqrt());
406    let two_var = two * std_dev * std_dev;
407
408    let optimizer = AutoOptimizer::new();
409
410    if optimizer.should_use_simd(n) {
411        // (x − μ)
412        let mean_arr = Array1::from_elem(n, mean);
413        let diff = F::simd_sub(x, &mean_arr.view());
414
415        // (x − μ)²
416        let diff_sq = F::simd_mul(&diff.view(), &diff.view());
417
418        // −(x − μ)² / (2σ²)
419        let inv_two_var = F::one() / two_var;
420        let neg_inv = F::from(-1.0_f64).unwrap_or_else(|| -F::one()) * inv_two_var;
421        let neg_inv_arr = Array1::from_elem(n, neg_inv);
422        let exponent = F::simd_mul(&neg_inv_arr.view(), &diff_sq.view());
423
424        // exp(exponent)
425        let exp_vals = F::simd_exp(&exponent.view());
426
427        // multiply by normalisation constant
428        let norm_arr = Array1::from_elem(n, norm_const);
429        Ok(F::simd_mul(&norm_arr.view(), &exp_vals.view()))
430    } else {
431        Ok(x.mapv(|xi| {
432            let z = (xi - mean) / std_dev;
433            norm_const * (-(z * z) / two).exp()
434        }))
435    }
436}
437
438/// Evaluate the Normal CDF at each point in `x` using a SIMD-friendly
439/// rational approximation (Abramowitz & Stegun 26.2.17, max error 7.5 × 10⁻⁸).
440///
441/// The standard-normal CDF Φ(z) is approximated element-wise; for non-standard
442/// parameters the inputs are standardised: z = (x − µ) / σ.
443///
444/// # Arguments
445///
446/// * `x`       — Points at which to evaluate.
447/// * `mean`    — Distribution mean µ.
448/// * `std_dev` — Distribution standard deviation σ (must be positive).
449///
450/// # Errors
451///
452/// Returns [`StatsError::InvalidArgument`] when `std_dev <= 0` or `x` is empty.
453///
454/// # Examples
455///
456/// ```
457/// use scirs2_core::ndarray::array;
458/// use scirs2_stats::normal_cdf_batch;
459///
460/// let x = array![0.0_f64, 1.959964_f64];
461/// let cdf = normal_cdf_batch(&x.view(), 0.0, 1.0).expect("failed");
462/// // CDF(0) ≈ 0.5,  CDF(1.96) ≈ 0.975
463/// assert!((cdf[0] - 0.5).abs() < 1e-6);
464/// assert!((cdf[1] - 0.975).abs() < 1e-4);
465/// ```
466pub fn normal_cdf_batch<F>(x: &ArrayView1<F>, mean: F, std_dev: F) -> StatsResult<Array1<F>>
467where
468    F: Float + NumCast + SimdUnifiedOps,
469{
470    if x.is_empty() {
471        return Err(StatsError::InvalidArgument("x is empty".to_string()));
472    }
473    if std_dev <= F::zero() {
474        return Err(StatsError::InvalidArgument(
475            "std_dev must be positive".to_string(),
476        ));
477    }
478
479    // Standardise: z = (x − µ) / σ
480    let n = x.len();
481    let mean_arr = Array1::from_elem(n, mean);
482    let std_arr = Array1::from_elem(n, std_dev);
483
484    let optimizer = AutoOptimizer::new();
485
486    let z_owned = if optimizer.should_use_simd(n) {
487        let diff = F::simd_sub(x, &mean_arr.view());
488        F::simd_div(&diff.view(), &std_arr.view())
489    } else {
490        x.mapv(|xi| (xi - mean) / std_dev)
491    };
492
493    // Scalar A&S polynomial approximation per element (vectorisable but intrinsic-
494    // heavy; let the compiler auto-vectorise the loop).
495    let half = F::from(0.5_f64).unwrap_or_else(|| F::one() / (F::one() + F::one()));
496    let a1 = F::from(0.319_381_530_f64).unwrap_or_else(|| F::zero());
497    let a2 = F::from(-0.356_563_782_f64).unwrap_or_else(|| F::zero());
498    let a3 = F::from(1.781_477_937_f64).unwrap_or_else(|| F::zero());
499    let a4 = F::from(-1.821_255_978_f64).unwrap_or_else(|| F::zero());
500    let a5 = F::from(1.330_274_429_f64).unwrap_or_else(|| F::zero());
501    let p = F::from(0.231_641_9_f64).unwrap_or_else(|| F::zero());
502    let two_pi = F::from(2.0_f64 * std::f64::consts::PI).unwrap_or_else(|| F::one());
503
504    let result: Array1<F> = z_owned.mapv(|z| {
505        let abs_z = z.abs();
506        let t = F::one() / (F::one() + p * abs_z);
507        let poly = t * (a1 + t * (a2 + t * (a3 + t * (a4 + t * a5))));
508        let pdf_z = (-(z * z) / (F::one() + F::one())).exp() / two_pi.sqrt();
509        let cdf_abs = F::one() - pdf_z * poly;
510        if z >= F::zero() {
511            cdf_abs
512        } else {
513            F::one() - cdf_abs
514        }
515        // Clamp to [0, 1]
516        .max(F::zero())
517        .min(F::one())
518    });
519
520    Ok(result)
521}
522
/// Evaluate the Uniform PDF at each point in `x`.
///
/// Returns 1/(high − low) for x ∈ [low, high) and 0 elsewhere.  The evaluation
/// is a plain element-wise map (the per-element branch is cheap and the
/// compiler is free to auto-vectorise it; no explicit SIMD kernel is used).
///
/// # Arguments
///
/// * `x`    — Points at which to evaluate.
/// * `low`  — Lower bound.
/// * `high` — Upper bound (must be strictly greater than `low`).
///
/// # Errors
///
/// Returns [`StatsError::InvalidArgument`] when `low >= high` or `x` is empty.
///
/// # Examples
///
/// ```
/// use scirs2_core::ndarray::array;
/// use scirs2_stats::uniform_pdf_batch;
///
/// let x = array![0.5_f64, 1.5, 2.5];
/// let pdf = uniform_pdf_batch(&x.view(), 1.0, 2.0).expect("failed");
/// assert!((pdf[0] - 0.0).abs() < 1e-10);
/// assert!((pdf[1] - 1.0).abs() < 1e-10);
/// assert!((pdf[2] - 0.0).abs() < 1e-10);
/// ```
pub fn uniform_pdf_batch<F>(x: &ArrayView1<F>, low: F, high: F) -> StatsResult<Array1<F>>
where
    F: Float + NumCast + SimdUnifiedOps,
{
    if x.is_empty() {
        return Err(StatsError::InvalidArgument("x is empty".to_string()));
    }
    if low >= high {
        return Err(StatsError::InvalidArgument(
            "low must be strictly less than high".to_string(),
        ));
    }

    // Constant density inside the support, zero outside.
    let density = F::one() / (high - low);
    let result = x.mapv(|xi| {
        if xi >= low && xi < high {
            density
        } else {
            F::zero()
        }
    });
    Ok(result)
}
572
573/// Evaluate the Uniform CDF at each point in `x`.
574///
575/// # Arguments
576///
577/// * `x`    — Points at which to evaluate.
578/// * `low`  — Lower bound.
579/// * `high` — Upper bound (must be strictly greater than `low`).
580///
581/// # Errors
582///
583/// Returns [`StatsError::InvalidArgument`] when `low >= high` or `x` is empty.
584///
585/// # Examples
586///
587/// ```
588/// use scirs2_core::ndarray::array;
589/// use scirs2_stats::uniform_cdf_batch;
590///
591/// let x = array![0.5_f64, 1.5, 2.5];
592/// let cdf = uniform_cdf_batch(&x.view(), 1.0, 2.0).expect("failed");
593/// assert!((cdf[0] - 0.0).abs() < 1e-10);
594/// assert!((cdf[1] - 0.5).abs() < 1e-10);
595/// assert!((cdf[2] - 1.0).abs() < 1e-10);
596/// ```
597pub fn uniform_cdf_batch<F>(x: &ArrayView1<F>, low: F, high: F) -> StatsResult<Array1<F>>
598where
599    F: Float + NumCast + SimdUnifiedOps,
600{
601    if x.is_empty() {
602        return Err(StatsError::InvalidArgument("x is empty".to_string()));
603    }
604    if low >= high {
605        return Err(StatsError::InvalidArgument(
606            "low must be strictly less than high".to_string(),
607        ));
608    }
609
610    let width = high - low;
611    let result = x.mapv(|xi| {
612        if xi < low {
613            F::zero()
614        } else if xi >= high {
615            F::one()
616        } else {
617            (xi - low) / width
618        }
619    });
620    Ok(result)
621}
622
623/// Evaluate the Exponential PDF at each point in `x` using SIMD.
624///
625/// Computes `f(x) = λ · exp(−λx)` for x ≥ 0, else 0.
626///
627/// # Arguments
628///
629/// * `x`    — Points at which to evaluate.
630/// * `rate` — Rate parameter λ (must be positive).
631///
632/// # Errors
633///
634/// Returns [`StatsError::InvalidArgument`] when `rate <= 0` or `x` is empty.
635///
636/// # Examples
637///
638/// ```
639/// use scirs2_core::ndarray::array;
640/// use scirs2_stats::exponential_pdf_batch;
641///
642/// let x = array![0.0_f64, 1.0, 2.0];
643/// let pdf = exponential_pdf_batch(&x.view(), 1.0).expect("failed");
644/// // pdf(0; λ=1) = 1.0,  pdf(1) = e^-1 ≈ 0.368,  pdf(2) = e^-2 ≈ 0.135
645/// assert!((pdf[0] - 1.0).abs() < 1e-10);
646/// assert!((pdf[1] - (-1.0_f64).exp()).abs() < 1e-10);
647/// ```
648pub fn exponential_pdf_batch<F>(x: &ArrayView1<F>, rate: F) -> StatsResult<Array1<F>>
649where
650    F: Float + NumCast + SimdUnifiedOps,
651{
652    if x.is_empty() {
653        return Err(StatsError::InvalidArgument("x is empty".to_string()));
654    }
655    if rate <= F::zero() {
656        return Err(StatsError::InvalidArgument(
657            "rate must be positive".to_string(),
658        ));
659    }
660
661    let n = x.len();
662    let optimizer = AutoOptimizer::new();
663
664    if optimizer.should_use_simd(n) {
665        // Mask: 1 for x >= 0, 0 otherwise
666        let mask: Array1<F> = x.mapv(|xi| if xi >= F::zero() { F::one() } else { F::zero() });
667
668        // −rate · x
669        let neg_rate = F::from(-1.0_f64).unwrap_or_else(|| -F::one()) * rate;
670        let neg_rate_arr = Array1::from_elem(n, neg_rate);
671        let exponent = F::simd_mul(&neg_rate_arr.view(), x);
672
673        // exp(−rate · x)
674        let exp_vals = F::simd_exp(&exponent.view());
675
676        // rate · exp(−rate · x)
677        let rate_arr = Array1::from_elem(n, rate);
678        let pdf = F::simd_mul(&rate_arr.view(), &exp_vals.view());
679
680        // Apply mask for non-negative domain
681        let result = F::simd_mul(&mask.view(), &pdf.view());
682        Ok(result)
683    } else {
684        Ok(x.mapv(|xi| {
685            if xi >= F::zero() {
686                rate * (-rate * xi).exp()
687            } else {
688                F::zero()
689            }
690        }))
691    }
692}
693
694/// Evaluate the Exponential CDF at each point in `x` using SIMD.
695///
696/// Computes `F(x) = 1 − exp(−λx)` for x ≥ 0, else 0.
697///
698/// # Arguments
699///
700/// * `x`    — Points at which to evaluate.
701/// * `rate` — Rate parameter λ (must be positive).
702///
703/// # Errors
704///
705/// Returns [`StatsError::InvalidArgument`] when `rate <= 0` or `x` is empty.
706///
707/// # Examples
708///
709/// ```
710/// use scirs2_core::ndarray::array;
711/// use scirs2_stats::exponential_cdf_batch;
712///
713/// let x = array![0.0_f64, 1.0];
714/// let cdf = exponential_cdf_batch(&x.view(), 1.0).expect("failed");
715/// assert!((cdf[0] - 0.0).abs() < 1e-10);
716/// assert!((cdf[1] - (1.0 - (-1.0_f64).exp())).abs() < 1e-10);
717/// ```
718pub fn exponential_cdf_batch<F>(x: &ArrayView1<F>, rate: F) -> StatsResult<Array1<F>>
719where
720    F: Float + NumCast + SimdUnifiedOps,
721{
722    if x.is_empty() {
723        return Err(StatsError::InvalidArgument("x is empty".to_string()));
724    }
725    if rate <= F::zero() {
726        return Err(StatsError::InvalidArgument(
727            "rate must be positive".to_string(),
728        ));
729    }
730
731    let n = x.len();
732    let optimizer = AutoOptimizer::new();
733
734    if optimizer.should_use_simd(n) {
735        // Mask: 1 for x >= 0
736        let mask: Array1<F> = x.mapv(|xi| if xi >= F::zero() { F::one() } else { F::zero() });
737
738        // −rate · x
739        let neg_rate = F::from(-1.0_f64).unwrap_or_else(|| -F::one()) * rate;
740        let neg_rate_arr = Array1::from_elem(n, neg_rate);
741        let exponent = F::simd_mul(&neg_rate_arr.view(), x);
742
743        // exp(−rate · x)
744        let exp_vals = F::simd_exp(&exponent.view());
745
746        // 1 − exp(−rate · x)
747        let ones = Array1::from_elem(n, F::one());
748        let cdf_positive = F::simd_sub(&ones.view(), &exp_vals.view());
749
750        // Apply mask
751        Ok(F::simd_mul(&mask.view(), &cdf_positive.view()))
752    } else {
753        Ok(x.mapv(|xi| {
754            if xi >= F::zero() {
755                F::one() - (-rate * xi).exp()
756            } else {
757                F::zero()
758            }
759        }))
760    }
761}
762
763// ============================================================================
764// Parallel batch sampling — Normal distribution
765// ============================================================================
766
767/// Generate `n` independent samples from N(`mean`, `std_dev`²) using parallel
768/// threads (via Rayon) and per-thread SIMD Box-Muller kernels.
769///
770/// The work is split into `num_cpus` chunks.  Each chunk receives a
771/// deterministic seed derived from the user-supplied `seed` and the chunk
772/// index, making the output fully reproducible when `seed` is provided.
773///
774/// # Arguments
775///
776/// * `n`       — Total number of samples to generate.
777/// * `mean`    — Location parameter of the Normal distribution.
778/// * `std_dev` — Scale parameter (must be positive).
779/// * `seed`    — Optional base RNG seed for reproducibility.
780///
781/// # Errors
782///
783/// Returns [`StatsError::InvalidArgument`] when `n == 0` or `std_dev <= 0`.
784///
785/// # Examples
786///
787/// ```
788/// use scirs2_stats::simd_sampling::parallel_normal_sample;
789///
790/// let samples = parallel_normal_sample(10_000, 5.0_f64, 2.0_f64, Some(42))
791///     .expect("parallel sampling failed");
792/// assert_eq!(samples.len(), 10_000);
793/// let mean: f64 = samples.iter().sum::<f64>() / 10_000.0;
794/// assert!((mean - 5.0).abs() < 0.2, "empirical mean {mean} too far from 5.0");
795/// ```
796pub fn parallel_normal_sample(
797    n: usize,
798    mean: f64,
799    std_dev: f64,
800    seed: Option<u64>,
801) -> StatsResult<Vec<f64>> {
802    if n == 0 {
803        return Err(StatsError::InvalidArgument(
804            "n must be at least 1".to_string(),
805        ));
806    }
807    if std_dev <= 0.0 {
808        return Err(StatsError::InvalidArgument(
809            "std_dev must be positive for the Normal distribution".to_string(),
810        ));
811    }
812
813    use scirs2_core::parallel_ops::*;
814
815    // Derive a base seed — deterministic when the caller provides one.
816    let base_seed: u64 = seed.unwrap_or_else(|| {
817        use scirs2_core::random::Rng;
818        scirs2_core::random::thread_rng().random()
819    });
820
821    // Choose a sensible chunk count (at least 1, at most n).
822    let n_threads = num_cpus::get().max(1).min(n);
823    let chunk_size = (n + n_threads - 1) / n_threads;
824
825    // Build per-chunk seed offsets so outputs are reproducible.
826    // Use wrapping_mul to avoid debug-mode overflow panics on the large
827    // Fibonacci hashing constant.
828    let seeds: Vec<u64> = (0..n_threads)
829        .map(|i| base_seed.wrapping_add((i as u64).wrapping_mul(0x9e37_79b9_7f4a_7c15_u64)))
830        .collect();
831
832    // Each thread independently generates its slice using the SIMD Box-Muller
833    // kernel from `sample_normal_batch`.
834    let chunks: Result<Vec<Vec<f64>>, StatsError> = seeds
835        .into_par_iter()
836        .enumerate()
837        .map(|(i, s)| {
838            let start = i * chunk_size;
839            if start >= n {
840                return Ok(vec![]);
841            }
842            let end = (start + chunk_size).min(n);
843            let count = end - start;
844            let arr = sample_normal_batch::<f64>(count, mean, std_dev, Some(s))?;
845            Ok(arr.to_vec())
846        })
847        .collect();
848
849    let chunks = chunks?;
850    let mut out = Vec::with_capacity(n);
851    for chunk in chunks {
852        out.extend_from_slice(&chunk);
853    }
854    Ok(out)
855}
856
857// ============================================================================
858// Parallel batch sampling — Uniform distribution
859// ============================================================================
860
861/// Generate `n` independent samples from U(`low`, `high`) using parallel
862/// threads and per-thread SIMD linear-rescaling kernels.
863///
864/// Each chunk receives a deterministic seed derived from the base seed and
865/// the chunk index for full reproducibility.
866///
867/// # Arguments
868///
869/// * `n`    — Total number of samples.
870/// * `low`  — Lower bound (inclusive).
871/// * `high` — Upper bound (exclusive; must be strictly greater than `low`).
872/// * `seed` — Optional base RNG seed.
873///
874/// # Errors
875///
876/// Returns [`StatsError::InvalidArgument`] when `n == 0` or `low >= high`.
877///
878/// # Examples
879///
880/// ```
881/// use scirs2_stats::simd_sampling::parallel_uniform_sample;
882///
883/// let samples = parallel_uniform_sample(10_000, 0.0_f64, 1.0_f64, Some(7))
884///     .expect("parallel sampling failed");
885/// assert_eq!(samples.len(), 10_000);
886/// assert!(samples.iter().all(|&x| x >= 0.0 && x < 1.0));
887/// ```
888pub fn parallel_uniform_sample(
889    n: usize,
890    low: f64,
891    high: f64,
892    seed: Option<u64>,
893) -> StatsResult<Vec<f64>> {
894    if n == 0 {
895        return Err(StatsError::InvalidArgument(
896            "n must be at least 1".to_string(),
897        ));
898    }
899    if low >= high {
900        return Err(StatsError::InvalidArgument(
901            "low must be strictly less than high for the Uniform distribution".to_string(),
902        ));
903    }
904
905    use scirs2_core::parallel_ops::*;
906
907    let base_seed: u64 = seed.unwrap_or_else(|| {
908        use scirs2_core::random::Rng;
909        scirs2_core::random::thread_rng().random()
910    });
911
912    let n_threads = num_cpus::get().max(1).min(n);
913    let chunk_size = (n + n_threads - 1) / n_threads;
914
915    let seeds: Vec<u64> = (0..n_threads)
916        .map(|i| base_seed.wrapping_add((i as u64).wrapping_mul(0x6c62_272e_07bb_0142_u64)))
917        .collect();
918
919    let chunks: Result<Vec<Vec<f64>>, StatsError> = seeds
920        .into_par_iter()
921        .enumerate()
922        .map(|(i, s)| {
923            let start = i * chunk_size;
924            if start >= n {
925                return Ok(vec![]);
926            }
927            let end = (start + chunk_size).min(n);
928            let count = end - start;
929            let arr = sample_uniform_batch::<f64>(count, low, high, Some(s))?;
930            Ok(arr.to_vec())
931        })
932        .collect();
933
934    let chunks = chunks?;
935    let mut out = Vec::with_capacity(n);
936    for chunk in chunks {
937        out.extend_from_slice(&chunk);
938    }
939    Ok(out)
940}
941
942// ============================================================================
943// Tests
944// ============================================================================
945
#[cfg(test)]
mod tests {
    //! Unit tests for the SIMD batch samplers and the vectorised PDF/CDF
    //! evaluation kernels.  All sampling tests use fixed seeds so they are
    //! deterministic; statistical tolerances are deliberately generous to
    //! avoid flakiness in CI.
    use super::*;
    use scirs2_core::ndarray::array;

    // ── Normal batch sampling ─────────────────────────────────────────────────

    #[test]
    fn test_sample_normal_batch_length() {
        // The sampler must return exactly the requested number of variates.
        let samples = sample_normal_batch::<f64>(500, 0.0, 1.0, Some(1)).expect("should succeed");
        assert_eq!(samples.len(), 500);
    }

    #[test]
    fn test_sample_normal_batch_empirical_moments() {
        // With 10 000 samples the standard error of the mean is σ/√n ≈ 0.02,
        // so the absolute tolerance of 0.2 used below is many standard errors
        // wide — deterministic-in-practice for a fixed seed.
        let n = 10_000_usize;
        let mu = 3.5_f64;
        let sigma = 2.0_f64;
        let samples = sample_normal_batch::<f64>(n, mu, sigma, Some(42)).expect("should succeed");
        assert_eq!(samples.len(), n);

        let emp_mean: f64 = samples.iter().sum::<f64>() / n as f64;
        let emp_var: f64 = samples.iter().map(|&x| (x - emp_mean).powi(2)).sum::<f64>() / n as f64;
        let emp_std = emp_var.sqrt();

        // Absolute tolerance 0.2 (~6 % of μ, 10 % of σ) — generous to avoid flakiness.
        assert!(
            (emp_mean - mu).abs() < 0.2,
            "empirical mean {emp_mean} too far from {mu}"
        );
        assert!(
            (emp_std - sigma).abs() < 0.2,
            "empirical std {emp_std} too far from {sigma}"
        );
    }

    #[test]
    fn test_sample_normal_batch_rejects_bad_args() {
        // n == 0, negative σ, and zero σ must all be rejected.
        assert!(sample_normal_batch::<f64>(0, 0.0, 1.0, None).is_err());
        assert!(sample_normal_batch::<f64>(10, 0.0, -1.0, None).is_err());
        assert!(sample_normal_batch::<f64>(10, 0.0, 0.0, None).is_err());
    }

    #[test]
    fn test_sample_normal_batch_f32() {
        // The sampler is generic over the float type; smoke-test the f32 path.
        let samples = sample_normal_batch::<f32>(200, 0.0, 1.0, Some(5)).expect("should succeed");
        assert_eq!(samples.len(), 200);
    }

    // ── Uniform batch sampling ─────────────────────────────────────────────────

    #[test]
    fn test_sample_uniform_batch_length() {
        let samples = sample_uniform_batch::<f64>(300, -1.0, 1.0, Some(2)).expect("should succeed");
        assert_eq!(samples.len(), 300);
    }

    #[test]
    fn test_sample_uniform_batch_range() {
        // Every sample must fall in the half-open interval [low, high).
        let samples =
            sample_uniform_batch::<f64>(2_000, 3.0, 7.0, Some(11)).expect("should succeed");
        for &s in samples.iter() {
            assert!(s >= 3.0 && s < 7.0, "sample {s} outside [3.0, 7.0)");
        }
    }

    #[test]
    fn test_sample_uniform_batch_empirical_mean() {
        // U(low, high) has mean (low + high) / 2.
        let (low, high) = (2.0_f64, 6.0_f64);
        let expected_mean = (low + high) / 2.0;
        let n = 10_000_usize;
        let samples = sample_uniform_batch::<f64>(n, low, high, Some(99)).expect("should succeed");
        let emp_mean: f64 = samples.iter().sum::<f64>() / n as f64;
        assert!(
            (emp_mean - expected_mean).abs() < 0.1,
            "empirical mean {emp_mean} too far from {expected_mean}"
        );
    }

    #[test]
    fn test_sample_uniform_batch_rejects_bad_args() {
        // n == 0, degenerate interval (low == high), and inverted bounds.
        assert!(sample_uniform_batch::<f64>(0, 0.0, 1.0, None).is_err());
        assert!(sample_uniform_batch::<f64>(10, 1.0, 1.0, None).is_err());
        assert!(sample_uniform_batch::<f64>(10, 2.0, 1.0, None).is_err());
    }

    // ── Exponential batch sampling ────────────────────────────────────────────

    #[test]
    fn test_sample_exponential_batch_length() {
        let samples = sample_exponential_batch::<f64>(400, 1.0, Some(3)).expect("should succeed");
        assert_eq!(samples.len(), 400);
    }

    #[test]
    fn test_sample_exponential_batch_non_negative() {
        // Exponential support is [0, ∞); the inverse-CDF transform -ln(u)/λ
        // can never produce a negative value.
        let samples =
            sample_exponential_batch::<f64>(5_000, 0.5, Some(77)).expect("should succeed");
        for &s in samples.iter() {
            assert!(s >= 0.0, "exponential sample {s} is negative");
        }
    }

    #[test]
    fn test_sample_exponential_batch_empirical_mean() {
        // Exp(λ) has mean 1/λ.
        let rate = 2.5_f64;
        let expected_mean = 1.0 / rate;
        let n = 10_000_usize;
        let samples = sample_exponential_batch::<f64>(n, rate, Some(13)).expect("should succeed");
        let emp_mean: f64 = samples.iter().sum::<f64>() / n as f64;
        assert!(
            (emp_mean - expected_mean).abs() < 0.05,
            "empirical mean {emp_mean} too far from {expected_mean}"
        );
    }

    #[test]
    fn test_sample_exponential_batch_rejects_bad_args() {
        // n == 0 and non-positive rate λ must both be rejected.
        assert!(sample_exponential_batch::<f64>(0, 1.0, None).is_err());
        assert!(sample_exponential_batch::<f64>(10, 0.0, None).is_err());
        assert!(sample_exponential_batch::<f64>(10, -1.0, None).is_err());
    }

    // ── Normal PDF / CDF ─────────────────────────────────────────────────────

    #[test]
    fn test_normal_pdf_batch_standard() {
        let x = array![0.0_f64, 1.0, -1.0];
        let pdf = normal_pdf_batch(&x.view(), 0.0, 1.0).expect("should succeed");
        // pdf(0) = 1/√(2π) ≈ 0.39894
        let expected_0 = 1.0_f64 / (2.0 * std::f64::consts::PI).sqrt();
        assert!((pdf[0] - expected_0).abs() < 1e-7, "pdf[0]={}", pdf[0]);
        // Standard normal is symmetric
        assert!((pdf[1] - pdf[2]).abs() < 1e-10, "symmetry failed");
    }

    #[test]
    fn test_normal_cdf_batch_standard() {
        // ±1.959964 is the standard-normal 97.5 % quantile, so the CDF should
        // evaluate to ≈0.975 / ≈0.025 at those points and exactly 0.5 at 0.
        let x = array![0.0_f64, 1.959_964_f64, -1.959_964_f64];
        let cdf = normal_cdf_batch(&x.view(), 0.0, 1.0).expect("should succeed");
        assert!((cdf[0] - 0.5).abs() < 1e-5, "cdf[0]={}", cdf[0]);
        assert!((cdf[1] - 0.975).abs() < 1e-4, "cdf[1]={}", cdf[1]);
        assert!((cdf[2] - 0.025).abs() < 1e-4, "cdf[2]={}", cdf[2]);
    }

    // ── Uniform PDF / CDF ────────────────────────────────────────────────────

    #[test]
    fn test_uniform_pdf_batch() {
        // U(1, 2): pdf is 1/(2−1) = 1 inside the support, 0 outside.
        let x = array![0.5_f64, 1.5, 2.5];
        let pdf = uniform_pdf_batch(&x.view(), 1.0, 2.0).expect("should succeed");
        assert!((pdf[0] - 0.0).abs() < 1e-10);
        assert!((pdf[1] - 1.0).abs() < 1e-10);
        assert!((pdf[2] - 0.0).abs() < 1e-10);
    }

    #[test]
    fn test_uniform_cdf_batch() {
        // U(1, 2): cdf is 0 below the support, linear inside, 1 above.
        let x = array![0.5_f64, 1.5, 2.5];
        let cdf = uniform_cdf_batch(&x.view(), 1.0, 2.0).expect("should succeed");
        assert!((cdf[0] - 0.0).abs() < 1e-10);
        assert!((cdf[1] - 0.5).abs() < 1e-10);
        assert!((cdf[2] - 1.0).abs() < 1e-10);
    }

    // ── Exponential PDF / CDF ────────────────────────────────────────────────

    #[test]
    fn test_exponential_pdf_batch() {
        // Exp(1): pdf(x) = e^{-x} for x ≥ 0.
        let x = array![0.0_f64, 1.0, 2.0, -0.5];
        let pdf = exponential_pdf_batch(&x.view(), 1.0).expect("should succeed");
        assert!((pdf[0] - 1.0).abs() < 1e-10, "pdf(0)={}", pdf[0]);
        assert!(
            (pdf[1] - (-1.0_f64).exp()).abs() < 1e-10,
            "pdf(1)={}",
            pdf[1]
        );
        assert!(
            (pdf[2] - (-2.0_f64).exp()).abs() < 1e-10,
            "pdf(2)={}",
            pdf[2]
        );
        // Negative input: PDF should be 0
        assert!((pdf[3] - 0.0).abs() < 1e-10, "pdf(-0.5)={}", pdf[3]);
    }

    #[test]
    fn test_exponential_cdf_batch() {
        // Exp(1): cdf(x) = 1 − e^{-x} for x ≥ 0, and 0 for x < 0.
        let x = array![0.0_f64, 1.0, 2.0, -1.0];
        let cdf = exponential_cdf_batch(&x.view(), 1.0).expect("should succeed");
        assert!((cdf[0] - 0.0).abs() < 1e-10);
        assert!((cdf[1] - (1.0 - (-1.0_f64).exp())).abs() < 1e-10);
        assert!((cdf[2] - (1.0 - (-2.0_f64).exp())).abs() < 1e-10);
        assert!((cdf[3] - 0.0).abs() < 1e-10);
    }

    // ── Reproducibility (seeding) ────────────────────────────────────────────

    #[test]
    fn test_seeded_reproducibility_normal() {
        // Identical seeds must yield bit-identical sample streams.
        let s1 = sample_normal_batch::<f64>(100, 0.0, 1.0, Some(42)).expect("ok");
        let s2 = sample_normal_batch::<f64>(100, 0.0, 1.0, Some(42)).expect("ok");
        for (a, b) in s1.iter().zip(s2.iter()) {
            assert_eq!(a, b, "seeded normal samples should be identical");
        }
    }

    #[test]
    fn test_seeded_reproducibility_uniform() {
        // Identical seeds must yield bit-identical sample streams.
        let s1 = sample_uniform_batch::<f64>(100, 0.0, 1.0, Some(17)).expect("ok");
        let s2 = sample_uniform_batch::<f64>(100, 0.0, 1.0, Some(17)).expect("ok");
        for (a, b) in s1.iter().zip(s2.iter()) {
            assert_eq!(a, b, "seeded uniform samples should be identical");
        }
    }

    #[test]
    fn test_seeded_reproducibility_exponential() {
        // Identical seeds must yield bit-identical sample streams.
        let s1 = sample_exponential_batch::<f64>(100, 0.5, Some(7)).expect("ok");
        let s2 = sample_exponential_batch::<f64>(100, 0.5, Some(7)).expect("ok");
        for (a, b) in s1.iter().zip(s2.iter()) {
            assert_eq!(a, b, "seeded exponential samples should be identical");
        }
    }

    // ── Parallel normal sampling ──────────────────────────────────────────────

    #[test]
    fn test_parallel_normal_sample_length() {
        // Concatenated per-thread chunks must total exactly n samples.
        let samples =
            super::parallel_normal_sample(5_000, 0.0, 1.0, Some(42)).expect("should succeed");
        assert_eq!(samples.len(), 5_000);
    }

    #[test]
    fn test_parallel_normal_sample_empirical_moments() {
        // Same moment check as the serial sampler: absolute tolerance 0.2 on
        // both mean and std, generous to avoid flakiness.
        let n = 20_000_usize;
        let mu = 2.5_f64;
        let sigma = 1.5_f64;
        let samples =
            super::parallel_normal_sample(n, mu, sigma, Some(1234)).expect("should succeed");
        assert_eq!(samples.len(), n);

        let emp_mean: f64 = samples.iter().sum::<f64>() / n as f64;
        let emp_var: f64 = samples.iter().map(|&x| (x - emp_mean).powi(2)).sum::<f64>() / n as f64;
        let emp_std = emp_var.sqrt();

        assert!(
            (emp_mean - mu).abs() < 0.2,
            "parallel normal: empirical mean {emp_mean} too far from {mu}"
        );
        assert!(
            (emp_std - sigma).abs() < 0.2,
            "parallel normal: empirical std {emp_std} too far from {sigma}"
        );
    }

    #[test]
    fn test_parallel_normal_sample_rejects_bad_args() {
        // Parallel path must apply the same validation as the serial sampler.
        assert!(super::parallel_normal_sample(0, 0.0, 1.0, None).is_err());
        assert!(super::parallel_normal_sample(10, 0.0, 0.0, None).is_err());
        assert!(super::parallel_normal_sample(10, 0.0, -1.5, None).is_err());
    }

    #[test]
    fn test_parallel_normal_sample_reproducibility() {
        // Per-chunk seeds are derived deterministically from the base seed,
        // so two runs on the same machine must agree bit-for-bit.
        let s1 = super::parallel_normal_sample(1_000, 0.0, 1.0, Some(77)).expect("ok");
        let s2 = super::parallel_normal_sample(1_000, 0.0, 1.0, Some(77)).expect("ok");
        assert_eq!(s1.len(), s2.len());
        for (a, b) in s1.iter().zip(s2.iter()) {
            assert_eq!(a, b, "seeded parallel normal samples should be identical");
        }
    }

    // ── Parallel uniform sampling ─────────────────────────────────────────────

    #[test]
    fn test_parallel_uniform_sample_length() {
        let samples =
            super::parallel_uniform_sample(4_000, 1.0, 3.0, Some(55)).expect("should succeed");
        assert_eq!(samples.len(), 4_000);
    }

    #[test]
    fn test_parallel_uniform_sample_range() {
        // Every sample from every chunk must fall in [low, high).
        let samples =
            super::parallel_uniform_sample(8_000, 2.0, 5.0, Some(99)).expect("should succeed");
        for &s in samples.iter() {
            assert!(
                s >= 2.0 && s < 5.0,
                "parallel uniform: sample {s} outside [2.0, 5.0)"
            );
        }
    }

    #[test]
    fn test_parallel_uniform_sample_empirical_mean() {
        // U(low, high) has mean (low + high) / 2.
        let (low, high) = (3.0_f64, 7.0_f64);
        let expected_mean = (low + high) / 2.0;
        let n = 10_000_usize;
        let samples =
            super::parallel_uniform_sample(n, low, high, Some(11)).expect("should succeed");
        let emp_mean: f64 = samples.iter().sum::<f64>() / n as f64;
        assert!(
            (emp_mean - expected_mean).abs() < 0.1,
            "parallel uniform: empirical mean {emp_mean} too far from {expected_mean}"
        );
    }

    #[test]
    fn test_parallel_uniform_sample_rejects_bad_args() {
        // Parallel path must apply the same validation as the serial sampler.
        assert!(super::parallel_uniform_sample(0, 0.0, 1.0, None).is_err());
        assert!(super::parallel_uniform_sample(10, 1.0, 1.0, None).is_err());
        assert!(super::parallel_uniform_sample(10, 2.0, 1.0, None).is_err());
    }

    #[test]
    fn test_parallel_uniform_sample_reproducibility() {
        // Per-chunk seeds are derived deterministically from the base seed,
        // so two runs on the same machine must agree bit-for-bit.
        let s1 = super::parallel_uniform_sample(500, 0.0, 1.0, Some(33)).expect("ok");
        let s2 = super::parallel_uniform_sample(500, 0.0, 1.0, Some(33)).expect("ok");
        assert_eq!(s1.len(), s2.len());
        for (a, b) in s1.iter().zip(s2.iter()) {
            assert_eq!(a, b, "seeded parallel uniform samples should be identical");
        }
    }
}