Skip to main content

scirs2_stats/distributions/
normal.rs

1//! Normal distribution functions
2//!
3//! This module provides functionality for the Normal (Gaussian) distribution.
4
5use crate::error::{StatsError, StatsResult};
6use crate::error_messages::{helpers, validation};
7use crate::sampling::SampleableDistribution;
8use crate::traits::{ContinuousDistribution, Distribution};
9use scirs2_core::ndarray::Array1;
10use scirs2_core::numeric::{Float, NumCast};
11use scirs2_core::random::{Distribution as RandDistribution, Normal as RandNormal};
12
13/// Normal distribution structure
14pub struct Normal<F: Float> {
15    /// Mean (location) parameter
16    pub loc: F,
17    /// Standard deviation (scale) parameter
18    pub scale: F,
19    /// Random number generator for this distribution
20    rand_distr: RandNormal<f64>,
21}
22
23impl<F: Float + NumCast + std::fmt::Display> Normal<F> {
24    /// Create a new normal distribution with given mean and standard deviation
25    ///
26    /// # Arguments
27    ///
28    /// * `loc` - Mean (location) parameter
29    /// * `scale` - Standard deviation (scale) parameter
30    ///
31    /// # Returns
32    ///
33    /// * A new Normal distribution instance
34    ///
35    /// # Examples
36    ///
37    /// ```
38    /// use scirs2_stats::distributions::normal::Normal;
39    ///
40    /// let norm = Normal::new(0.0f64, 1.0).expect("Operation failed");
41    /// ```
42    pub fn new(loc: F, scale: F) -> StatsResult<Self> {
43        // Validate scale parameter
44        validation::ensure_positive(scale, "scale")?;
45
46        // Convert to f64 for rand_distr
47        let loc_f64 = <f64 as NumCast>::from(loc)
48            .ok_or_else(|| helpers::numerical_error("failed to convert loc to f64"))?;
49        let scale_f64 = <f64 as NumCast>::from(scale)
50            .ok_or_else(|| helpers::numerical_error("failed to convert scale to f64"))?;
51
52        match RandNormal::new(loc_f64, scale_f64) {
53            Ok(rand_distr) => Ok(Normal {
54                loc,
55                scale,
56                rand_distr,
57            }),
58            Err(_) => Err(helpers::numerical_error("normal distribution creation")),
59        }
60    }
61
62    /// Calculate the probability density function (PDF) at a given point
63    ///
64    /// # Arguments
65    ///
66    /// * `x` - The point at which to evaluate the PDF
67    ///
68    /// # Returns
69    ///
70    /// * The value of the PDF at the given point
71    ///
72    /// # Examples
73    ///
74    /// ```
75    /// use scirs2_stats::distributions::normal::Normal;
76    ///
77    /// let norm = Normal::new(0.0f64, 1.0).expect("Operation failed");
78    /// let pdf_at_zero = norm.pdf(0.0);
79    /// assert!((pdf_at_zero - 0.3989423).abs() < 1e-7);
80    /// ```
81    pub fn pdf(&self, x: F) -> F {
82        // PDF = (1 / (scale * sqrt(2*pi))) * exp(-0.5 * ((x-loc)/scale)^2)
83        let pi = F::from(std::f64::consts::PI).unwrap_or_else(|| F::zero());
84        let two = F::from(2.0).unwrap_or_else(|| F::zero());
85
86        let z = (x - self.loc) / self.scale;
87        let exponent = -z * z / two;
88
89        F::from(1.0).unwrap_or_else(|| F::zero()) / (self.scale * (two * pi).sqrt())
90            * exponent.exp()
91    }
92
93    /// Calculate the cumulative distribution function (CDF) at a given point
94    ///
95    /// # Arguments
96    ///
97    /// * `x` - The point at which to evaluate the CDF
98    ///
99    /// # Returns
100    ///
101    /// * The value of the CDF at the given point
102    ///
103    /// # Examples
104    ///
105    /// ```
106    /// use scirs2_stats::distributions::normal::Normal;
107    ///
108    /// let norm = Normal::new(0.0f64, 1.0).expect("Operation failed");
109    /// let cdf_at_zero = norm.cdf(0.0);
110    /// assert!((cdf_at_zero - 0.5).abs() < 1e-10);
111    /// ```
112    pub fn cdf(&self, x: F) -> F {
113        // Standardize the variable
114        let z = (x - self.loc) / self.scale;
115
116        // For standard normal CDF at 0, the result should be exactly 0.5
117        if z == F::zero() {
118            return F::from(0.5).unwrap_or_else(|| F::zero());
119        }
120
121        // Use a standard implementation of the error function
122        // CDF = 0.5 * (1 + erf(z / sqrt(2)))
123        let two = F::from(2.0).unwrap_or_else(|| F::zero());
124        let one = F::one();
125        let half = F::from(0.5).unwrap_or_else(|| F::zero());
126
127        half * (one + erf(z / two.sqrt()))
128    }
129
130    /// Inverse of the cumulative distribution function (quantile function)
131    ///
132    /// # Arguments
133    ///
134    /// * `p` - Probability value (between 0 and 1)
135    ///
136    /// # Returns
137    ///
138    /// * The value x such that CDF(x) = p
139    ///
140    /// # Examples
141    ///
142    /// ```
143    /// use scirs2_stats::distributions::normal::Normal;
144    ///
145    /// let norm = Normal::new(0.0f64, 1.0).expect("Operation failed");
146    /// let x = norm.ppf(0.975).expect("Operation failed");
147    /// assert!((x - 1.96).abs() < 1e-2);
148    /// ```
149    pub fn ppf(&self, p: F) -> StatsResult<F> {
150        if p < F::zero() || p > F::one() {
151            return Err(StatsError::DomainError(
152                "Probability must be between 0 and 1".to_string(),
153            ));
154        }
155
156        // Special cases
157        if p == F::zero() {
158            return Ok(F::neg_infinity());
159        }
160        if p == F::one() {
161            return Ok(F::infinity());
162        }
163
164        // Acklam's rational approximation for the inverse standard normal CDF.
165        // Maximum relative error ~1.15e-9 across the full range.
166        let half = F::from(0.5).unwrap_or_else(|| F::zero());
167
168        // Coefficients for the rational approximation
169        let a1 = F::from(-3.969683028665376e+01).unwrap_or_else(|| F::zero());
170        let a2 = F::from(2.209460984245205e+02).unwrap_or_else(|| F::zero());
171        let a3 = F::from(-2.759285104469687e+02).unwrap_or_else(|| F::zero());
172        let a4 = F::from(1.383577518672690e+02).unwrap_or_else(|| F::zero());
173        let a5 = F::from(-3.066479806614716e+01).unwrap_or_else(|| F::zero());
174        let a6 = F::from(2.506628277459239e+00).unwrap_or_else(|| F::zero());
175
176        let b1 = F::from(-5.447609879822406e+01).unwrap_or_else(|| F::zero());
177        let b2 = F::from(1.615858368580409e+02).unwrap_or_else(|| F::zero());
178        let b3 = F::from(-1.556989798598866e+02).unwrap_or_else(|| F::zero());
179        let b4 = F::from(6.680131188771972e+01).unwrap_or_else(|| F::zero());
180        let b5 = F::from(-1.328068155288572e+01).unwrap_or_else(|| F::zero());
181
182        let c1 = F::from(-7.784894002430293e-03).unwrap_or_else(|| F::zero());
183        let c2 = F::from(-3.223964580411365e-01).unwrap_or_else(|| F::zero());
184        let c3 = F::from(-2.400758277161838e+00).unwrap_or_else(|| F::zero());
185        let c4 = F::from(-2.549732539343734e+00).unwrap_or_else(|| F::zero());
186        let c5 = F::from(4.374664141464968e+00).unwrap_or_else(|| F::zero());
187        let c6 = F::from(2.938163982698783e+00).unwrap_or_else(|| F::zero());
188
189        let d1c = F::from(7.784695709041462e-03).unwrap_or_else(|| F::zero());
190        let d2c = F::from(3.224671290700398e-01).unwrap_or_else(|| F::zero());
191        let d3c = F::from(2.445134137142996e+00).unwrap_or_else(|| F::zero());
192        let d4c = F::from(3.754408661907416e+00).unwrap_or_else(|| F::zero());
193
194        let p_low = F::from(0.02425).unwrap_or_else(|| F::zero());
195        let p_high = F::one() - p_low;
196
197        let z = if p < p_low {
198            // Lower tail: rational approximation
199            let q = (-F::from(2.0).unwrap_or_else(|| F::zero()) * p.ln()).sqrt();
200            (((((c1 * q + c2) * q + c3) * q + c4) * q + c5) * q + c6)
201                / ((((d1c * q + d2c) * q + d3c) * q + d4c) * q + F::one())
202        } else if p <= p_high {
203            // Central region: rational approximation
204            let q = p - half;
205            let r = q * q;
206            (((((a1 * r + a2) * r + a3) * r + a4) * r + a5) * r + a6) * q
207                / (((((b1 * r + b2) * r + b3) * r + b4) * r + b5) * r + F::one())
208        } else {
209            // Upper tail: rational approximation (by symmetry)
210            let q = (-F::from(2.0).unwrap_or_else(|| F::zero()) * (F::one() - p).ln()).sqrt();
211            -(((((c1 * q + c2) * q + c3) * q + c4) * q + c5) * q + c6)
212                / ((((d1c * q + d2c) * q + d3c) * q + d4c) * q + F::one())
213        };
214
215        // Scale and shift to get the quantile for the given parameters
216        Ok(z * self.scale + self.loc)
217    }
218
219    /// Generate random samples from the distribution
220    ///
221    /// # Arguments
222    ///
223    /// * `size` - Number of samples to generate
224    ///
225    /// # Returns
226    ///
227    /// * Vector of random samples
228    ///
229    /// # Examples
230    ///
231    /// ```
232    /// use scirs2_stats::distributions::normal::Normal;
233    ///
234    /// let norm = Normal::new(0.0f64, 1.0).expect("Operation failed");
235    /// let samples = norm.rvs(1000).expect("Operation failed");
236    /// assert_eq!(samples.len(), 1000);
237    /// ```
238    pub fn rvs(&self, size: usize) -> StatsResult<Array1<F>> {
239        let mut rng = scirs2_core::random::thread_rng();
240        let mut samples = Vec::with_capacity(size);
241
242        for _ in 0..size {
243            let sample = self.rand_distr.sample(&mut rng);
244            samples.push(F::from(sample).expect("Failed to convert to float"));
245        }
246
247        Ok(Array1::from(samples))
248    }
249}
250
251/// Calculate the error function (erf)
252#[allow(dead_code)]
253fn erf<F: Float>(x: F) -> F {
254    // Approximation based on Abramowitz and Stegun
255    let zero = F::zero();
256    let one = F::one();
257
258    // Handle negative values using erf(-x) = -erf(x)
259    if x < zero {
260        return -erf(-x);
261    }
262
263    // Constants for the approximation
264    let a1 = F::from(0.254829592).expect("Failed to convert constant to float");
265    let a2 = F::from(-0.284496736).expect("Failed to convert constant to float");
266    let a3 = F::from(1.421413741).expect("Failed to convert constant to float");
267    let a4 = F::from(-1.453152027).expect("Failed to convert constant to float");
268    let a5 = F::from(1.061405429).expect("Failed to convert constant to float");
269    let p = F::from(0.3275911).expect("Failed to convert constant to float");
270
271    // Calculate the approximation
272    let t = one / (one + p * x);
273    one - (((((a5 * t + a4) * t) + a3) * t + a2) * t + a1) * t * (-x * x).exp()
274}
275
// There is no separate inverse-erf helper: the standard normal quantile is
// computed directly in `Normal::ppf` using Acklam's rational approximation.
278
279// Implement the Distribution trait for Normal
280impl<F: Float + NumCast + std::fmt::Display> Distribution<F> for Normal<F> {
281    fn mean(&self) -> F {
282        self.loc
283    }
284
285    fn var(&self) -> F {
286        self.scale * self.scale
287    }
288
289    fn std(&self) -> F {
290        self.scale
291    }
292
293    fn rvs(&self, size: usize) -> StatsResult<Array1<F>> {
294        Normal::rvs(self, size)
295    }
296
297    fn entropy(&self) -> F {
298        let half = F::from(0.5).expect("Failed to convert constant to float");
299        let two = F::from(2.0).expect("Failed to convert constant to float");
300        let pi = F::from(std::f64::consts::PI).expect("Failed to convert to float");
301        let e = F::from(std::f64::consts::E).expect("Failed to convert to float");
302
303        half + half * (two * pi * e * self.scale * self.scale).ln()
304    }
305}
306
307// Implement the ContinuousDistribution trait for Normal
308impl<F: Float + NumCast + std::fmt::Display> ContinuousDistribution<F> for Normal<F> {
309    fn pdf(&self, x: F) -> F {
310        Normal::pdf(self, x)
311    }
312
313    fn cdf(&self, x: F) -> F {
314        Normal::cdf(self, x)
315    }
316
317    fn ppf(&self, p: F) -> StatsResult<F> {
318        Normal::ppf(self, p)
319    }
320}
321
322/// Implementation of SampleableDistribution for Normal
323impl<F: Float + NumCast + std::fmt::Display> SampleableDistribution<F> for Normal<F> {
324    fn rvs(&self, size: usize) -> StatsResult<Vec<F>> {
325        let array = Normal::rvs(self, size)?;
326        Ok(array.to_vec())
327    }
328}
329
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    /// The standard normal N(0, 1), shared by most checks below.
    fn standard_normal() -> Normal<f64> {
        Normal::new(0.0, 1.0).expect("Operation failed")
    }

    #[test]
    fn test_normal_creation() {
        // Valid parameters are stored verbatim.
        for &(loc, scale) in &[(0.0, 1.0), (5.0, 2.0)] {
            let dist = Normal::new(loc, scale).expect("Operation failed");
            assert_eq!(dist.loc, loc);
            assert_eq!(dist.scale, scale);
        }

        // A zero or negative scale must be rejected.
        for &scale in &[0.0, -1.0] {
            assert!(Normal::<f64>::new(0.0, scale).is_err());
        }
    }

    #[test]
    fn test_normal_pdf() {
        let norm = standard_normal();
        let expected = [(0.0, 0.3989423), (1.0, 0.2419707), (-1.0, 0.2419707)];
        for &(x, want) in &expected {
            assert_relative_eq!(norm.pdf(x), want, epsilon = 1e-7);
        }

        // Peak of N(5, 2²) sits at its mean.
        let shifted = Normal::new(5.0, 2.0).expect("Operation failed");
        assert_relative_eq!(shifted.pdf(5.0), 0.19947114, epsilon = 1e-7);
    }

    #[test]
    fn test_normal_cdf() {
        let norm = standard_normal();

        assert_relative_eq!(norm.cdf(0.0), 0.5, epsilon = 1e-7);

        for &(x, want) in &[(1.0, 0.8413447), (-1.0, 0.1586553)] {
            assert_relative_eq!(norm.cdf(x), want, epsilon = 1e-5);
        }
    }

    #[test]
    fn test_normal_ppf() {
        let norm = standard_normal();

        let median = norm.ppf(0.5).expect("Operation failed");
        assert_relative_eq!(median, 0.0, epsilon = 1e-5);

        // The classic two-sided 95% critical values.
        for &(p, want) in &[(0.975, 1.96), (0.025, -1.96)] {
            let quantile = norm.ppf(p).expect("Operation failed");
            assert_relative_eq!(quantile, want, epsilon = 1e-2);
        }

        // Probabilities outside [0, 1] are errors.
        for &p in &[-0.1, 1.1] {
            assert!(norm.ppf(p).is_err());
        }
    }

    #[test]
    fn test_normal_rvs() {
        const COUNT: usize = 1000;

        let samples = standard_normal().rvs(COUNT).expect("Operation failed");
        assert_eq!(samples.len(), COUNT);

        let n = COUNT as f64;
        let mean = samples.iter().sum::<f64>() / n;

        // With 1000 draws the standard error of the mean is ~0.03, so 0.15 is a
        // generous bound against random variation.
        assert!(
            mean.abs() < 0.15,
            "Sample mean {} is outside expected range",
            mean
        );

        let variance = samples
            .iter()
            .map(|&x| {
                let d = x - mean;
                d * d
            })
            .sum::<f64>()
            / n;
        let std_dev = variance.sqrt();

        assert!(
            (std_dev - 1.0).abs() < 0.15,
            "Sample std dev {} is outside expected range",
            std_dev
        );
    }
}