Skip to main content

sklears_simd/
distributions.rs

1//! SIMD-optimized probability distributions and sampling algorithms
2//!
3//! This module provides high-performance implementations of common probability
4//! distributions and sampling algorithms using SIMD instructions for maximum
5//! performance in machine learning applications.
6
7use scirs2_autograd::ndarray::{Array1, Array2};
8
9#[cfg(feature = "no-std")]
10use core::f32::consts::{SQRT_2, TAU};
11#[cfg(not(feature = "no-std"))]
12use std::f32::consts::{SQRT_2, TAU};
13
14#[cfg(feature = "no-std")]
15use alloc::vec;
16
17/// SIMD-optimized random number generator using LCG (Linear Congruential Generator)
18pub struct SimdRng {
19    state: u64,
20    multiplier: u64,
21    increment: u64,
22}
23
24impl SimdRng {
25    /// Create a new SIMD random number generator
26    pub fn new(seed: u64) -> Self {
27        Self {
28            state: seed,
29            multiplier: 1103515245,
30            increment: 12345,
31        }
32    }
33
34    /// Generate a single random u32
35    pub fn next_u32(&mut self) -> u32 {
36        self.state = self
37            .state
38            .wrapping_mul(self.multiplier)
39            .wrapping_add(self.increment);
40        (self.state >> 16) as u32
41    }
42
43    /// Generate multiple random u32 values using SIMD
44    pub fn fill_u32(&mut self, output: &mut [u32]) {
45        #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
46        {
47            if crate::simd_feature_detected!("avx2") {
48                unsafe { self.fill_u32_avx2(output) };
49                return;
50            } else if crate::simd_feature_detected!("sse2") {
51                unsafe { self.fill_u32_sse2(output) };
52                return;
53            }
54        }
55
56        // Scalar fallback
57        for val in output.iter_mut() {
58            *val = self.next_u32();
59        }
60    }
61
62    /// Generate uniform random floats in [0, 1)
63    pub fn uniform_f32(&mut self, output: &mut [f32]) {
64        let mut u32_buffer = vec![0u32; output.len()];
65        self.fill_u32(&mut u32_buffer);
66
67        #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
68        {
69            if crate::simd_feature_detected!("avx2") {
70                unsafe { convert_u32_to_f32_avx2(&u32_buffer, output) };
71                return;
72            } else if crate::simd_feature_detected!("sse2") {
73                unsafe { convert_u32_to_f32_sse2(&u32_buffer, output) };
74                return;
75            }
76        }
77
78        // Scalar fallback
79        for (i, &val) in u32_buffer.iter().enumerate() {
80            output[i] = (val as f32) / (u32::MAX as f32);
81        }
82    }
83}
84
85/// SIMD-optimized normal (Gaussian) distribution
86pub struct Normal {
87    mean: f32,
88    std_dev: f32,
89}
90
91impl Normal {
92    /// Create a new normal distribution
93    pub fn new(mean: f32, std_dev: f32) -> Self {
94        assert!(std_dev > 0.0, "Standard deviation must be positive");
95        Self { mean, std_dev }
96    }
97
98    /// Generate samples using Box-Muller transform
99    pub fn sample(&self, rng: &mut SimdRng, output: &mut [f32]) {
100        let mut uniform_samples = vec![0.0f32; output.len() * 2];
101        rng.uniform_f32(&mut uniform_samples);
102
103        #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
104        {
105            if crate::simd_feature_detected!("avx2") {
106                unsafe { self.box_muller_avx2(&uniform_samples, output) };
107                return;
108            } else if crate::simd_feature_detected!("sse2") {
109                unsafe { self.box_muller_sse2(&uniform_samples, output) };
110                return;
111            }
112        }
113
114        // Scalar fallback
115        self.box_muller_scalar(&uniform_samples, output);
116    }
117
118    /// Compute probability density function
119    pub fn pdf(&self, values: &[f32], output: &mut [f32]) {
120        assert_eq!(values.len(), output.len());
121
122        #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
123        {
124            if crate::simd_feature_detected!("avx2") {
125                unsafe { self.pdf_avx2(values, output) };
126                return;
127            } else if crate::simd_feature_detected!("sse2") {
128                unsafe { self.pdf_sse2(values, output) };
129                return;
130            }
131        }
132
133        // Scalar fallback
134        self.pdf_scalar(values, output);
135    }
136
137    /// Compute cumulative distribution function using error function approximation
138    pub fn cdf(&self, values: &[f32], output: &mut [f32]) {
139        assert_eq!(values.len(), output.len());
140
141        #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
142        {
143            if crate::simd_feature_detected!("avx2") {
144                unsafe { self.cdf_avx2(values, output) };
145                return;
146            } else if crate::simd_feature_detected!("sse2") {
147                unsafe { self.cdf_sse2(values, output) };
148                return;
149            }
150        }
151
152        // Scalar fallback
153        self.cdf_scalar(values, output);
154    }
155
156    fn box_muller_scalar(&self, uniform: &[f32], output: &mut [f32]) {
157        let mut i = 0;
158        let mut out_idx = 0;
159
160        while out_idx < output.len() && i + 1 < uniform.len() {
161            let u1 = uniform[i].max(1e-10); // Avoid log(0)
162            let u2 = uniform[i + 1];
163
164            let magnitude = (-2.0 * u1.ln()).sqrt() * self.std_dev;
165            let angle = TAU * u2;
166
167            let z0 = magnitude * angle.cos() + self.mean;
168            let z1 = magnitude * angle.sin() + self.mean;
169
170            output[out_idx] = z0;
171            if out_idx + 1 < output.len() {
172                output[out_idx + 1] = z1;
173            }
174
175            i += 2;
176            out_idx += 2;
177        }
178    }
179
180    fn pdf_scalar(&self, values: &[f32], output: &mut [f32]) {
181        let inv_sqrt_2pi = 1.0 / (TAU).sqrt();
182        let inv_std = 1.0 / self.std_dev;
183        let inv_var_2 = 1.0 / (2.0 * self.std_dev * self.std_dev);
184
185        for (i, &x) in values.iter().enumerate() {
186            let z = (x - self.mean) * inv_std;
187            output[i] = inv_sqrt_2pi * inv_std * (-z * z * inv_var_2).exp();
188        }
189    }
190
191    fn cdf_scalar(&self, values: &[f32], output: &mut [f32]) {
192        for (i, &x) in values.iter().enumerate() {
193            let z = (x - self.mean) / (self.std_dev * SQRT_2);
194            output[i] = 0.5 * (1.0 + erf_approximation(z));
195        }
196    }
197}
198
199/// SIMD-optimized exponential distribution
200pub struct Exponential {
201    rate: f32,
202}
203
204impl Exponential {
205    /// Create a new exponential distribution
206    pub fn new(rate: f32) -> Self {
207        assert!(rate > 0.0, "Rate parameter must be positive");
208        Self { rate }
209    }
210
211    /// Generate samples using inverse transform sampling
212    pub fn sample(&self, rng: &mut SimdRng, output: &mut [f32]) {
213        let mut uniform_samples = vec![0.0f32; output.len()];
214        rng.uniform_f32(&mut uniform_samples);
215
216        #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
217        {
218            if crate::simd_feature_detected!("avx2") {
219                unsafe { self.inverse_transform_avx2(&uniform_samples, output) };
220                return;
221            } else if crate::simd_feature_detected!("sse2") {
222                unsafe { self.inverse_transform_sse2(&uniform_samples, output) };
223                return;
224            }
225        }
226
227        // Scalar fallback
228        for (i, &u) in uniform_samples.iter().enumerate() {
229            output[i] = -(1.0 - u).ln() / self.rate;
230        }
231    }
232
233    /// Compute probability density function
234    pub fn pdf(&self, values: &[f32], output: &mut [f32]) {
235        for (i, &x) in values.iter().enumerate() {
236            if x >= 0.0 {
237                output[i] = self.rate * (-self.rate * x).exp();
238            } else {
239                output[i] = 0.0;
240            }
241        }
242    }
243}
244
245/// SIMD-optimized beta distribution (simplified using rejection sampling)
246pub struct Beta {
247    alpha: f32,
248    beta: f32,
249}
250
251impl Beta {
252    /// Create a new beta distribution
253    pub fn new(alpha: f32, beta: f32) -> Self {
254        assert!(alpha > 0.0 && beta > 0.0, "Alpha and beta must be positive");
255        Self { alpha, beta }
256    }
257
258    /// Generate samples using rejection sampling (simplified)
259    pub fn sample(&self, rng: &mut SimdRng, output: &mut [f32]) {
260        // For simplicity, we'll use a basic approach
261        // A more sophisticated implementation would use more efficient algorithms
262        let mut uniform_samples = vec![0.0f32; output.len() * 2];
263        rng.uniform_f32(&mut uniform_samples);
264
265        for i in 0..output.len() {
266            let u1 = uniform_samples[i * 2];
267            let u2 = uniform_samples[i * 2 + 1];
268
269            // Simple transformation (not the most efficient for all parameter values)
270            let x = u1.powf(1.0 / self.alpha);
271            let y = u2.powf(1.0 / self.beta);
272
273            output[i] = x / (x + y);
274        }
275    }
276}
277
/// Approximate the error function `erf(x)`.
///
/// Evaluates the Abramowitz & Stegun rational approximation on `|x|` and then
/// restores the sign, since erf is an odd function.
fn erf_approximation(x: f32) -> f32 {
    // Horner coefficients, highest order first: a5, a4, a3, a2, a1.
    const COEFFS: [f32; 5] = [1.061_405_4, -1.453_152_1, 1.421_413_8, -0.284_496_72, 0.254_829_6];
    const P: f32 = 0.3275911;

    let magnitude = x.abs();
    let t = 1.0 / (1.0 + P * magnitude);

    let poly = COEFFS.iter().fold(0.0f32, |acc, &c| acc * t + c) * t;
    let y = 1.0 - poly * (-magnitude * magnitude).exp();

    if x < 0.0 {
        -y
    } else {
        y
    }
}
296
297// SIMD implementations for x86/x86_64
298
299#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
300impl SimdRng {
301    #[target_feature(enable = "sse2")]
302    unsafe fn fill_u32_sse2(&mut self, output: &mut [u32]) {
303        // SSE2 doesn't have 32-bit multiply (_mm_mullo_epi32 is SSE4.1)
304        // Fall back to scalar implementation
305        for val in output.iter_mut() {
306            *val = self.next_u32();
307        }
308    }
309
310    #[target_feature(enable = "avx2")]
311    unsafe fn fill_u32_avx2(&mut self, output: &mut [u32]) {
312        // The previous SIMD implementation had a fundamental flaw:
313        // using _mm256_set1_epi32 broadcasts the same state to all lanes,
314        // causing all lanes to generate identical values.
315        // A proper SIMD RNG requires different states per lane or
316        // a counter-based approach. For correctness, use scalar code.
317        for val in output.iter_mut() {
318            *val = self.next_u32();
319        }
320    }
321}
322
/// Convert raw 32-bit random words to uniform floats in `[0, 1)` (SSE2 tier).
///
/// Each word keeps its top 24 bits (exactly an `f32` mantissa) and is scaled
/// by `2^-24`. Results are exact, and the largest is `1 - 2^-24`, so `1.0` is
/// never produced. Dividing the full word by `u32::MAX as f32` rounds words
/// near `u32::MAX` up to exactly `1.0`. Because the shifted value fits in 24
/// bits, the signed `_mm_cvtepi32_ps` converts it exactly. Only the first
/// `input.len()` outputs are written.
///
/// # Safety
///
/// Must only be called on a CPU that supports SSE2.
///
/// # Panics
///
/// Panics if `output` is shorter than `input`.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "sse2")]
unsafe fn convert_u32_to_f32_sse2(input: &[u32], output: &mut [f32]) {
    #[cfg(all(target_arch = "x86", feature = "no-std"))]
    use core::arch::x86::*;
    #[cfg(all(target_arch = "x86", not(feature = "no-std")))]
    use std::arch::x86::*;
    #[cfg(all(target_arch = "x86_64", feature = "no-std"))]
    use core::arch::x86_64::*;
    #[cfg(all(target_arch = "x86_64", not(feature = "no-std")))]
    use std::arch::x86_64::*;

    const SCALE: f32 = 1.0 / 16_777_216.0; // 2^-24

    let len = input.len();
    let output = &mut output[..len];
    let split = len - len % 4;
    let (in_head, in_tail) = input.split_at(split);
    let (out_head, out_tail) = output.split_at_mut(split);

    let scale = _mm_set1_ps(SCALE);
    for (src, dst) in in_head.chunks_exact(4).zip(out_head.chunks_exact_mut(4)) {
        // Unaligned 16-byte load of exactly the 4 words in `src`.
        let words = _mm_loadu_si128(src.as_ptr() as *const __m128i);
        let mantissa = _mm_srli_epi32(words, 8);
        let floats = _mm_mul_ps(_mm_cvtepi32_ps(mantissa), scale);
        _mm_storeu_ps(dst.as_mut_ptr(), floats);
    }

    for (dst, &word) in out_tail.iter_mut().zip(in_tail) {
        *dst = (word >> 8) as f32 * SCALE;
    }
}
333
/// Convert raw 32-bit random words to uniform floats in `[0, 1)` (AVX2 tier).
///
/// Each word keeps its top 24 bits (exactly an `f32` mantissa) and is scaled
/// by `2^-24`. Results are exact, and the largest is `1 - 2^-24`, so `1.0` is
/// never produced. Dividing the full word by `u32::MAX as f32` rounds words
/// near `u32::MAX` up to exactly `1.0`. Because the shifted value fits in 24
/// bits, the signed `_mm256_cvtepi32_ps` converts it exactly. Only the first
/// `input.len()` outputs are written.
///
/// # Safety
///
/// Must only be called on a CPU that supports AVX2.
///
/// # Panics
///
/// Panics if `output` is shorter than `input`.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
unsafe fn convert_u32_to_f32_avx2(input: &[u32], output: &mut [f32]) {
    #[cfg(all(target_arch = "x86", feature = "no-std"))]
    use core::arch::x86::*;
    #[cfg(all(target_arch = "x86", not(feature = "no-std")))]
    use std::arch::x86::*;
    #[cfg(all(target_arch = "x86_64", feature = "no-std"))]
    use core::arch::x86_64::*;
    #[cfg(all(target_arch = "x86_64", not(feature = "no-std")))]
    use std::arch::x86_64::*;

    const SCALE: f32 = 1.0 / 16_777_216.0; // 2^-24

    let len = input.len();
    let output = &mut output[..len];
    let split = len - len % 8;
    let (in_head, in_tail) = input.split_at(split);
    let (out_head, out_tail) = output.split_at_mut(split);

    let scale = _mm256_set1_ps(SCALE);
    for (src, dst) in in_head.chunks_exact(8).zip(out_head.chunks_exact_mut(8)) {
        // Unaligned 32-byte load of exactly the 8 words in `src`.
        let words = _mm256_loadu_si256(src.as_ptr() as *const __m256i);
        let mantissa = _mm256_srli_epi32(words, 8);
        let floats = _mm256_mul_ps(_mm256_cvtepi32_ps(mantissa), scale);
        _mm256_storeu_ps(dst.as_mut_ptr(), floats);
    }

    for (dst, &word) in out_tail.iter_mut().zip(in_tail) {
        *dst = (word >> 8) as f32 * SCALE;
    }
}
344
345#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
346impl Normal {
347    #[target_feature(enable = "sse2")]
348    unsafe fn box_muller_sse2(&self, uniform: &[f32], output: &mut [f32]) {
349        #[cfg(feature = "no-std")]
350        use core::arch::x86_64::*;
351        #[cfg(not(feature = "no-std"))]
352        use core::arch::x86_64::*;
353
354        let mut i = 0;
355        let mut out_idx = 0;
356
357        while out_idx + 4 <= output.len() && i + 8 <= uniform.len() {
358            let u1 = _mm_loadu_ps(&uniform[i]);
359            let u2 = _mm_loadu_ps(&uniform[i + 4]);
360
361            // Simplified - use scalar math for trigonometric functions
362            // magnitude = sqrt(-2 * ln(u1)) * std_dev
363            let mut u1_vals = [0.0f32; 4];
364            let mut u2_vals = [0.0f32; 4];
365            _mm_storeu_ps(u1_vals.as_mut_ptr(), u1);
366            _mm_storeu_ps(u2_vals.as_mut_ptr(), u2);
367
368            let mut z0_vals = [0.0f32; 4];
369            for k in 0..4 {
370                let magnitude = (-2.0 * u1_vals[k].ln()).sqrt() * self.std_dev;
371                let angle = TAU * u2_vals[k];
372                z0_vals[k] = magnitude * angle.cos() + self.mean;
373            }
374
375            let z0 = _mm_loadu_ps(z0_vals.as_ptr());
376
377            _mm_storeu_ps(&mut output[out_idx], z0);
378
379            i += 8;
380            out_idx += 4;
381        }
382
383        // Handle remaining elements with scalar code
384        while out_idx < output.len() && i + 1 < uniform.len() {
385            let u1 = uniform[i].max(1e-10);
386            let u2 = uniform[i + 1];
387
388            let magnitude = (-2.0 * u1.ln()).sqrt() * self.std_dev;
389            let angle = TAU * u2;
390
391            output[out_idx] = magnitude * angle.cos() + self.mean;
392
393            i += 2;
394            out_idx += 1;
395        }
396    }
397
398    #[target_feature(enable = "avx2")]
399    unsafe fn box_muller_avx2(&self, uniform: &[f32], output: &mut [f32]) {
400        #[cfg(feature = "no-std")]
401        use core::arch::x86_64::*;
402        #[cfg(not(feature = "no-std"))]
403        use core::arch::x86_64::*;
404
405        let mut i = 0;
406        let mut out_idx = 0;
407
408        while out_idx + 8 <= output.len() && i + 16 <= uniform.len() {
409            let u1 = _mm256_loadu_ps(&uniform[i]);
410            let u2 = _mm256_loadu_ps(&uniform[i + 8]);
411
412            // Simplified - use scalar math for trigonometric functions
413            // magnitude = sqrt(-2 * ln(u1)) * std_dev
414            let mut u1_vals = [0.0f32; 8];
415            let mut u2_vals = [0.0f32; 8];
416            _mm256_storeu_ps(u1_vals.as_mut_ptr(), u1);
417            _mm256_storeu_ps(u2_vals.as_mut_ptr(), u2);
418
419            let mut z0_vals = [0.0f32; 8];
420            for k in 0..8 {
421                let magnitude = (-2.0 * u1_vals[k].ln()).sqrt() * self.std_dev;
422                let angle = TAU * u2_vals[k];
423                z0_vals[k] = magnitude * angle.cos() + self.mean;
424            }
425
426            let z0 = _mm256_loadu_ps(z0_vals.as_ptr());
427
428            _mm256_storeu_ps(&mut output[out_idx], z0);
429
430            i += 16;
431            out_idx += 8;
432        }
433
434        // Handle remaining elements with scalar code
435        while out_idx < output.len() && i + 1 < uniform.len() {
436            let u1 = uniform[i].max(1e-10);
437            let u2 = uniform[i + 1];
438
439            let magnitude = (-2.0 * u1.ln()).sqrt() * self.std_dev;
440            let angle = TAU * u2;
441
442            output[out_idx] = magnitude * angle.cos() + self.mean;
443
444            i += 2;
445            out_idx += 1;
446        }
447    }
448
449    #[target_feature(enable = "sse2")]
450    unsafe fn pdf_sse2(&self, values: &[f32], output: &mut [f32]) {
451        #[cfg(feature = "no-std")]
452        use core::arch::x86_64::*;
453        #[cfg(not(feature = "no-std"))]
454        use core::arch::x86_64::*;
455
456        let inv_sqrt_2pi = _mm_set1_ps(1.0 / (TAU).sqrt());
457        let mean_vec = _mm_set1_ps(self.mean);
458        let inv_std = _mm_set1_ps(1.0 / self.std_dev);
459        let inv_var_2 = _mm_set1_ps(1.0 / (2.0 * self.std_dev * self.std_dev));
460
461        let mut i = 0;
462        while i + 4 <= values.len() {
463            let x = _mm_loadu_ps(&values[i]);
464            let z = _mm_mul_ps(_mm_sub_ps(x, mean_vec), inv_std);
465            let exp_arg = _mm_mul_ps(_mm_mul_ps(z, z), inv_var_2);
466            let mut exp_arg_vals = [0.0f32; 4];
467            _mm_storeu_ps(exp_arg_vals.as_mut_ptr(), exp_arg);
468            let mut exp_vals = [0.0f32; 4];
469            for k in 0..4 {
470                exp_vals[k] = (-exp_arg_vals[k]).exp();
471            }
472            let exp_result = _mm_loadu_ps(exp_vals.as_ptr());
473            let result = _mm_mul_ps(_mm_mul_ps(inv_sqrt_2pi, inv_std), exp_result);
474            _mm_storeu_ps(&mut output[i], result);
475            i += 4;
476        }
477
478        // Handle remaining elements
479        while i < values.len() {
480            let z = (values[i] - self.mean) / self.std_dev;
481            output[i] = (1.0 / (TAU).sqrt()) / self.std_dev * (-z * z / 2.0).exp();
482            i += 1;
483        }
484    }
485
486    #[target_feature(enable = "avx2")]
487    unsafe fn pdf_avx2(&self, values: &[f32], output: &mut [f32]) {
488        #[cfg(feature = "no-std")]
489        use core::arch::x86_64::*;
490        #[cfg(not(feature = "no-std"))]
491        use core::arch::x86_64::*;
492
493        let inv_sqrt_2pi = _mm256_set1_ps(1.0 / (TAU).sqrt());
494        let mean_vec = _mm256_set1_ps(self.mean);
495        let inv_std = _mm256_set1_ps(1.0 / self.std_dev);
496        let inv_var_2 = _mm256_set1_ps(1.0 / (2.0 * self.std_dev * self.std_dev));
497
498        let mut i = 0;
499        while i + 8 <= values.len() {
500            let x = _mm256_loadu_ps(&values[i]);
501            let z = _mm256_mul_ps(_mm256_sub_ps(x, mean_vec), inv_std);
502            let exp_arg = _mm256_mul_ps(_mm256_mul_ps(z, z), inv_var_2);
503            let mut exp_arg_vals = [0.0f32; 8];
504            _mm256_storeu_ps(exp_arg_vals.as_mut_ptr(), exp_arg);
505            let mut exp_vals = [0.0f32; 8];
506            for k in 0..8 {
507                exp_vals[k] = (-exp_arg_vals[k]).exp();
508            }
509            let exp_result = _mm256_loadu_ps(exp_vals.as_ptr());
510            let result = _mm256_mul_ps(_mm256_mul_ps(inv_sqrt_2pi, inv_std), exp_result);
511            _mm256_storeu_ps(&mut output[i], result);
512            i += 8;
513        }
514
515        // Handle remaining elements
516        while i < values.len() {
517            let z = (values[i] - self.mean) / self.std_dev;
518            output[i] = (1.0 / (TAU).sqrt()) / self.std_dev * (-z * z / 2.0).exp();
519            i += 1;
520        }
521    }
522
523    #[target_feature(enable = "sse2")]
524    unsafe fn cdf_sse2(&self, values: &[f32], output: &mut [f32]) {
525        // Implementation would use SIMD error function approximation
526        self.cdf_scalar(values, output);
527    }
528
529    #[target_feature(enable = "avx2")]
530    unsafe fn cdf_avx2(&self, values: &[f32], output: &mut [f32]) {
531        // Implementation would use SIMD error function approximation
532        self.cdf_scalar(values, output);
533    }
534}
535
536#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
537impl Exponential {
538    #[target_feature(enable = "sse2")]
539    unsafe fn inverse_transform_sse2(&self, uniform: &[f32], output: &mut [f32]) {
540        #[cfg(feature = "no-std")]
541        use core::arch::x86_64::*;
542        #[cfg(not(feature = "no-std"))]
543        use core::arch::x86_64::*;
544
545        let one = _mm_set1_ps(1.0);
546        let rate_vec = _mm_set1_ps(self.rate);
547
548        let mut i = 0;
549        while i + 4 <= uniform.len() {
550            let u = _mm_loadu_ps(&uniform[i]);
551            let one_minus_u = _mm_sub_ps(one, u);
552            let mut one_minus_u_vals = [0.0f32; 4];
553            _mm_storeu_ps(one_minus_u_vals.as_mut_ptr(), one_minus_u);
554            let mut ln_vals = [0.0f32; 4];
555            for k in 0..4 {
556                ln_vals[k] = one_minus_u_vals[k].ln();
557            }
558            let ln_result = _mm_loadu_ps(ln_vals.as_ptr());
559            let neg_ln = _mm_sub_ps(_mm_setzero_ps(), ln_result);
560            let result = _mm_div_ps(neg_ln, rate_vec);
561            _mm_storeu_ps(&mut output[i], result);
562            i += 4;
563        }
564
565        // Handle remaining elements
566        while i < uniform.len() {
567            output[i] = -(1.0 - uniform[i]).ln() / self.rate;
568            i += 1;
569        }
570    }
571
572    #[target_feature(enable = "avx2")]
573    unsafe fn inverse_transform_avx2(&self, uniform: &[f32], output: &mut [f32]) {
574        #[cfg(feature = "no-std")]
575        use core::arch::x86_64::*;
576        #[cfg(not(feature = "no-std"))]
577        use core::arch::x86_64::*;
578
579        let one = _mm256_set1_ps(1.0);
580        let rate_vec = _mm256_set1_ps(self.rate);
581
582        let mut i = 0;
583        while i + 8 <= uniform.len() {
584            let u = _mm256_loadu_ps(&uniform[i]);
585            let one_minus_u = _mm256_sub_ps(one, u);
586            let mut one_minus_u_vals = [0.0f32; 8];
587            _mm256_storeu_ps(one_minus_u_vals.as_mut_ptr(), one_minus_u);
588            let mut ln_vals = [0.0f32; 8];
589            for k in 0..8 {
590                ln_vals[k] = one_minus_u_vals[k].ln();
591            }
592            let ln_result = _mm256_loadu_ps(ln_vals.as_ptr());
593            let neg_ln = _mm256_sub_ps(_mm256_setzero_ps(), ln_result);
594            let result = _mm256_div_ps(neg_ln, rate_vec);
595            _mm256_storeu_ps(&mut output[i], result);
596            i += 8;
597        }
598
599        // Handle remaining elements
600        while i < uniform.len() {
601            output[i] = -(1.0 - uniform[i]).ln() / self.rate;
602            i += 1;
603        }
604    }
605}
606
607/// Multivariate normal distribution sampling
608pub fn multivariate_normal_sample(
609    mean: &Array1<f32>,
610    covariance: &Array2<f32>,
611    rng: &mut SimdRng,
612    num_samples: usize,
613) -> Array2<f32> {
614    let dim = mean.len();
615    assert_eq!(covariance.shape(), &[dim, dim]);
616
617    // Cholesky decomposition of covariance matrix (simplified)
618    let chol = cholesky_decomposition(covariance);
619
620    let mut samples = Array2::zeros((num_samples, dim));
621    let normal = Normal::new(0.0, 1.0);
622
623    for i in 0..num_samples {
624        let mut standard_normal = vec![0.0f32; dim];
625        normal.sample(rng, &mut standard_normal);
626
627        // Transform standard normal to desired distribution
628        let z = Array1::from_vec(standard_normal);
629        let transformed = crate::matrix::matrix_vector_multiply_f32(&chol, &z);
630
631        for j in 0..dim {
632            samples[[i, j]] = transformed[j] + mean[j];
633        }
634    }
635
636    samples
637}
638
639/// Simplified Cholesky decomposition
640fn cholesky_decomposition(matrix: &Array2<f32>) -> Array2<f32> {
641    let n = matrix.nrows();
642    let mut chol = Array2::zeros((n, n));
643
644    for i in 0..n {
645        for j in 0..=i {
646            if i == j {
647                let mut sum = 0.0;
648                for k in 0..j {
649                    sum += chol[[j, k]] * chol[[j, k]];
650                }
651                chol[[j, j]] = (matrix[[j, j]] - sum).sqrt();
652            } else {
653                let mut sum = 0.0;
654                for k in 0..j {
655                    sum += chol[[i, k]] * chol[[j, k]];
656                }
657                chol[[i, j]] = (matrix[[i, j]] - sum) / chol[[j, j]];
658            }
659        }
660    }
661
662    chol
663}
664
#[cfg(all(test, not(feature = "no-std")))]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    #[test]
    fn test_simd_rng() {
        let mut rng = SimdRng::new(12345);
        let mut output = vec![0u32; 16];
        rng.fill_u32(&mut output);

        // Check that we get different values
        assert!(output.iter().any(|&x| x != output[0]));
    }

    #[test]
    fn test_uniform_f32() {
        let mut rng = SimdRng::new(12345);
        let mut output = vec![0.0f32; 100];
        rng.uniform_f32(&mut output);

        // Check range [0, 1)
        for &val in &output {
            assert!((0.0..1.0).contains(&val));
        }

        // Check some variability
        let mean: f32 = output.iter().sum::<f32>() / output.len() as f32;
        assert!(mean > 0.4 && mean < 0.6); // Should be around 0.5
    }

    #[test]
    fn test_normal_distribution() {
        let mut rng = SimdRng::new(42);
        let normal = Normal::new(5.0, 2.0);
        let mut samples = vec![0.0f32; 1000];
        normal.sample(&mut rng, &mut samples);

        let mean: f32 = samples.iter().sum::<f32>() / samples.len() as f32;
        assert_relative_eq!(mean, 5.0, epsilon = 0.2);
    }

    #[test]
    fn test_normal_pdf() {
        let normal = Normal::new(0.0, 1.0);
        let values = vec![0.0, 1.0, -1.0];
        let mut output = vec![0.0f32; 3];
        normal.pdf(&values, &mut output);

        // At x=0, PDF should be 1/sqrt(2π) ≈ 0.3989
        assert_relative_eq!(output[0], 0.3989, epsilon = 0.01);

        // At x=1 and x=-1, should be equal (symmetric)
        assert_relative_eq!(output[1], output[2], epsilon = 1e-6);
    }

    #[test]
    fn test_exponential_distribution() {
        let mut rng = SimdRng::new(123);
        let exp_dist = Exponential::new(2.0);
        let mut samples = vec![0.0f32; 1000];
        exp_dist.sample(&mut rng, &mut samples);

        // All samples should be non-negative
        for &sample in &samples {
            assert!(sample >= 0.0);
        }

        // Mean should be approximately 1/rate = 0.5
        let mean: f32 = samples.iter().sum::<f32>() / samples.len() as f32;
        assert_relative_eq!(mean, 0.5, epsilon = 0.1);
    }

    #[test]
    fn test_beta_distribution() {
        let mut rng = SimdRng::new(456);
        let beta = Beta::new(2.0, 3.0);
        let mut samples = vec![0.0f32; 100];
        beta.sample(&mut rng, &mut samples);

        // All samples should be in [0, 1]
        for &sample in &samples {
            assert!((0.0..=1.0).contains(&sample));
        }
    }

    #[test]
    fn test_erf_approximation() {
        assert_relative_eq!(erf_approximation(0.0), 0.0, epsilon = 1e-4);
        assert_relative_eq!(erf_approximation(1.0), 0.8427, epsilon = 1e-3);
        assert_relative_eq!(erf_approximation(-1.0), -0.8427, epsilon = 1e-3);
    }

    #[test]
    fn test_rng_uniform() {
        let mut rng = SimdRng::new(123);
        let mut samples = vec![0.0f32; 10];
        rng.uniform_f32(&mut samples);

        let sum: f32 = samples.iter().sum();

        // At least some variance
        assert!(sum > 0.1);
    }

    #[test]
    fn test_multivariate_normal() {
        let mut rng = SimdRng::new(789);
        let mean = Array1::from_vec(vec![1.0, 2.0]);
        let cov = Array2::from_shape_vec((2, 2), vec![1.0, 0.5, 0.5, 1.0])
            .expect("shape and data length should match");

        let samples = multivariate_normal_sample(&mean, &cov, &mut rng, 10);
        assert_eq!(samples.shape(), &[10, 2]);
    }
}