scirs2_stats/distributions/
normal.rs1use crate::error::{StatsError, StatsResult};
6use crate::error_messages::{helpers, validation};
7use crate::sampling::SampleableDistribution;
8use crate::traits::{ContinuousDistribution, Distribution};
9use scirs2_core::ndarray::Array1;
10use scirs2_core::numeric::{Float, NumCast};
11use scirs2_core::random::{Distribution as RandDistribution, Normal as RandNormal};
12
13pub struct Normal<F: Float> {
15 pub loc: F,
17 pub scale: F,
19 rand_distr: RandNormal<f64>,
21}
22
23impl<F: Float + NumCast + std::fmt::Display> Normal<F> {
24 pub fn new(loc: F, scale: F) -> StatsResult<Self> {
43 validation::ensure_positive(scale, "scale")?;
45
46 let loc_f64 = <f64 as NumCast>::from(loc)
48 .ok_or_else(|| helpers::numerical_error("failed to convert loc to f64"))?;
49 let scale_f64 = <f64 as NumCast>::from(scale)
50 .ok_or_else(|| helpers::numerical_error("failed to convert scale to f64"))?;
51
52 match RandNormal::new(loc_f64, scale_f64) {
53 Ok(rand_distr) => Ok(Normal {
54 loc,
55 scale,
56 rand_distr,
57 }),
58 Err(_) => Err(helpers::numerical_error("normal distribution creation")),
59 }
60 }
61
62 pub fn pdf(&self, x: F) -> F {
82 let pi = F::from(std::f64::consts::PI).unwrap_or_else(|| F::zero());
84 let two = F::from(2.0).unwrap_or_else(|| F::zero());
85
86 let z = (x - self.loc) / self.scale;
87 let exponent = -z * z / two;
88
89 F::from(1.0).unwrap_or_else(|| F::zero()) / (self.scale * (two * pi).sqrt())
90 * exponent.exp()
91 }
92
93 pub fn cdf(&self, x: F) -> F {
113 let z = (x - self.loc) / self.scale;
115
116 if z == F::zero() {
118 return F::from(0.5).unwrap_or_else(|| F::zero());
119 }
120
121 let two = F::from(2.0).unwrap_or_else(|| F::zero());
124 let one = F::one();
125 let half = F::from(0.5).unwrap_or_else(|| F::zero());
126
127 half * (one + erf(z / two.sqrt()))
128 }
129
130 pub fn ppf(&self, p: F) -> StatsResult<F> {
150 if p < F::zero() || p > F::one() {
151 return Err(StatsError::DomainError(
152 "Probability must be between 0 and 1".to_string(),
153 ));
154 }
155
156 if p == F::zero() {
158 return Ok(F::neg_infinity());
159 }
160 if p == F::one() {
161 return Ok(F::infinity());
162 }
163
164 let half = F::from(0.5).unwrap_or_else(|| F::zero());
167
168 let a1 = F::from(-3.969683028665376e+01).unwrap_or_else(|| F::zero());
170 let a2 = F::from(2.209460984245205e+02).unwrap_or_else(|| F::zero());
171 let a3 = F::from(-2.759285104469687e+02).unwrap_or_else(|| F::zero());
172 let a4 = F::from(1.383577518672690e+02).unwrap_or_else(|| F::zero());
173 let a5 = F::from(-3.066479806614716e+01).unwrap_or_else(|| F::zero());
174 let a6 = F::from(2.506628277459239e+00).unwrap_or_else(|| F::zero());
175
176 let b1 = F::from(-5.447609879822406e+01).unwrap_or_else(|| F::zero());
177 let b2 = F::from(1.615858368580409e+02).unwrap_or_else(|| F::zero());
178 let b3 = F::from(-1.556989798598866e+02).unwrap_or_else(|| F::zero());
179 let b4 = F::from(6.680131188771972e+01).unwrap_or_else(|| F::zero());
180 let b5 = F::from(-1.328068155288572e+01).unwrap_or_else(|| F::zero());
181
182 let c1 = F::from(-7.784894002430293e-03).unwrap_or_else(|| F::zero());
183 let c2 = F::from(-3.223964580411365e-01).unwrap_or_else(|| F::zero());
184 let c3 = F::from(-2.400758277161838e+00).unwrap_or_else(|| F::zero());
185 let c4 = F::from(-2.549732539343734e+00).unwrap_or_else(|| F::zero());
186 let c5 = F::from(4.374664141464968e+00).unwrap_or_else(|| F::zero());
187 let c6 = F::from(2.938163982698783e+00).unwrap_or_else(|| F::zero());
188
189 let d1c = F::from(7.784695709041462e-03).unwrap_or_else(|| F::zero());
190 let d2c = F::from(3.224671290700398e-01).unwrap_or_else(|| F::zero());
191 let d3c = F::from(2.445134137142996e+00).unwrap_or_else(|| F::zero());
192 let d4c = F::from(3.754408661907416e+00).unwrap_or_else(|| F::zero());
193
194 let p_low = F::from(0.02425).unwrap_or_else(|| F::zero());
195 let p_high = F::one() - p_low;
196
197 let z = if p < p_low {
198 let q = (-F::from(2.0).unwrap_or_else(|| F::zero()) * p.ln()).sqrt();
200 (((((c1 * q + c2) * q + c3) * q + c4) * q + c5) * q + c6)
201 / ((((d1c * q + d2c) * q + d3c) * q + d4c) * q + F::one())
202 } else if p <= p_high {
203 let q = p - half;
205 let r = q * q;
206 (((((a1 * r + a2) * r + a3) * r + a4) * r + a5) * r + a6) * q
207 / (((((b1 * r + b2) * r + b3) * r + b4) * r + b5) * r + F::one())
208 } else {
209 let q = (-F::from(2.0).unwrap_or_else(|| F::zero()) * (F::one() - p).ln()).sqrt();
211 -(((((c1 * q + c2) * q + c3) * q + c4) * q + c5) * q + c6)
212 / ((((d1c * q + d2c) * q + d3c) * q + d4c) * q + F::one())
213 };
214
215 Ok(z * self.scale + self.loc)
217 }
218
219 pub fn rvs(&self, size: usize) -> StatsResult<Array1<F>> {
239 let mut rng = scirs2_core::random::thread_rng();
240 let mut samples = Vec::with_capacity(size);
241
242 for _ in 0..size {
243 let sample = self.rand_distr.sample(&mut rng);
244 samples.push(F::from(sample).expect("Failed to convert to float"));
245 }
246
247 Ok(Array1::from(samples))
248 }
249}
250
251#[allow(dead_code)]
253fn erf<F: Float>(x: F) -> F {
254 let zero = F::zero();
256 let one = F::one();
257
258 if x < zero {
260 return -erf(-x);
261 }
262
263 let a1 = F::from(0.254829592).expect("Failed to convert constant to float");
265 let a2 = F::from(-0.284496736).expect("Failed to convert constant to float");
266 let a3 = F::from(1.421413741).expect("Failed to convert constant to float");
267 let a4 = F::from(-1.453152027).expect("Failed to convert constant to float");
268 let a5 = F::from(1.061405429).expect("Failed to convert constant to float");
269 let p = F::from(0.3275911).expect("Failed to convert constant to float");
270
271 let t = one / (one + p * x);
273 one - (((((a5 * t + a4) * t) + a3) * t + a2) * t + a1) * t * (-x * x).exp()
274}
275
276impl<F: Float + NumCast + std::fmt::Display> Distribution<F> for Normal<F> {
281 fn mean(&self) -> F {
282 self.loc
283 }
284
285 fn var(&self) -> F {
286 self.scale * self.scale
287 }
288
289 fn std(&self) -> F {
290 self.scale
291 }
292
293 fn rvs(&self, size: usize) -> StatsResult<Array1<F>> {
294 Normal::rvs(self, size)
295 }
296
297 fn entropy(&self) -> F {
298 let half = F::from(0.5).expect("Failed to convert constant to float");
299 let two = F::from(2.0).expect("Failed to convert constant to float");
300 let pi = F::from(std::f64::consts::PI).expect("Failed to convert to float");
301 let e = F::from(std::f64::consts::E).expect("Failed to convert to float");
302
303 half + half * (two * pi * e * self.scale * self.scale).ln()
304 }
305}
306
307impl<F: Float + NumCast + std::fmt::Display> ContinuousDistribution<F> for Normal<F> {
309 fn pdf(&self, x: F) -> F {
310 Normal::pdf(self, x)
311 }
312
313 fn cdf(&self, x: F) -> F {
314 Normal::cdf(self, x)
315 }
316
317 fn ppf(&self, p: F) -> StatsResult<F> {
318 Normal::ppf(self, p)
319 }
320}
321
322impl<F: Float + NumCast + std::fmt::Display> SampleableDistribution<F> for Normal<F> {
324 fn rvs(&self, size: usize) -> StatsResult<Vec<F>> {
325 let array = Normal::rvs(self, size)?;
326 Ok(array.to_vec())
327 }
328}
329
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    /// Builds an `f64` normal distribution, panicking on invalid parameters.
    fn normal(loc: f64, scale: f64) -> Normal<f64> {
        Normal::new(loc, scale).expect("Operation failed")
    }

    #[test]
    fn test_normal_creation() {
        for &(loc, scale) in &[(0.0, 1.0), (5.0, 2.0)] {
            let dist = normal(loc, scale);
            assert_eq!(dist.loc, loc);
            assert_eq!(dist.scale, scale);
        }

        // The scale must be strictly positive.
        for &bad_scale in &[0.0, -1.0] {
            assert!(Normal::<f64>::new(0.0, bad_scale).is_err());
        }
    }

    #[test]
    fn test_normal_pdf() {
        let standard = normal(0.0, 1.0);
        let cases = [(0.0, 0.3989423), (1.0, 0.2419707), (-1.0, 0.2419707)];
        for &(x, expected) in &cases {
            assert_relative_eq!(standard.pdf(x), expected, epsilon = 1e-7);
        }

        // Peak of N(5, 2²) sits at its mean.
        assert_relative_eq!(normal(5.0, 2.0).pdf(5.0), 0.19947114, epsilon = 1e-7);
    }

    #[test]
    fn test_normal_cdf() {
        let standard = normal(0.0, 1.0);
        assert_relative_eq!(standard.cdf(0.0), 0.5, epsilon = 1e-7);

        for &(x, expected) in &[(1.0, 0.8413447), (-1.0, 0.1586553)] {
            assert_relative_eq!(standard.cdf(x), expected, epsilon = 1e-5);
        }
    }

    #[test]
    fn test_normal_ppf() {
        let standard = normal(0.0, 1.0);

        let median = standard.ppf(0.5).expect("Operation failed");
        assert_relative_eq!(median, 0.0, epsilon = 1e-5);

        for &(p, expected) in &[(0.975, 1.96), (0.025, -1.96)] {
            let quantile = standard.ppf(p).expect("Operation failed");
            assert_relative_eq!(quantile, expected, epsilon = 1e-2);
        }

        // Probabilities outside [0, 1] are rejected.
        for &p in &[-0.1, 1.1] {
            assert!(standard.ppf(p).is_err());
        }
    }

    #[test]
    fn test_normal_rvs() {
        const N: usize = 1000;
        let draws = normal(0.0, 1.0).rvs(N).expect("Operation failed");

        assert_eq!(draws.len(), N);

        let mean = draws.iter().sum::<f64>() / N as f64;
        assert!(
            mean.abs() < 0.15,
            "Sample mean {} is outside expected range",
            mean
        );

        let variance = draws
            .iter()
            .map(|&x| {
                let deviation = x - mean;
                deviation * deviation
            })
            .sum::<f64>()
            / N as f64;
        let std_dev = variance.sqrt();
        assert!(
            (std_dev - 1.0).abs() < 0.15,
            "Sample std dev {} is outside expected range",
            std_dev
        );
    }
}