1#![allow(non_snake_case)] use rand::Rng;
9use rand::seq::SliceRandom;
10use statrs::distribution::{
11 Bernoulli, Binomial, Geometric, Hypergeometric, NegativeBinomial, Poisson,
12};
13use statrs::distribution::{
14 Beta, Cauchy, ChiSquared, Continuous, ContinuousCDF, Discrete, DiscreteCDF, FisherSnedecor,
15 Gamma, LogNormal, Normal, StudentsT, Weibull,
16};
17use statrs::function::gamma::gamma;
18use thiserror::Error;
19
20#[derive(Error, Debug)]
22pub enum DistributionError {
23 #[error("Invalid parameter: {0}")]
24 InvalidParameter(String),
25
26 #[error("Numerical error: {0}")]
27 NumericalError(String),
28
29 #[error("Distribution not supported: {0}")]
30 NotSupported(String),
31}
32
33pub type Result<T> = std::result::Result<T, DistributionError>;
35
/// A continuous probability distribution together with its parameters.
///
/// The variants are public, so values can also be built directly without
/// going through the validating constructors.
#[derive(Clone, Debug)]
pub enum ContinuousDistribution {
    /// Normal distribution with mean `mean` and standard deviation `std_dev`.
    Normal { mean: f64, std_dev: f64 },
    /// Standard (location 0, scale 1) Student's t with `df` degrees of freedom.
    StudentsT { df: f64 },
    /// Chi-squared distribution with `df` degrees of freedom.
    ChiSquared { df: f64 },
    /// F distribution with numerator `d1` and denominator `d2` degrees of freedom.
    FisherSnedecor { d1: f64, d2: f64 },
    /// Exponential distribution with density `rate * exp(-rate * x)` for `x >= 0`.
    Exponential { rate: f64 },
    /// Gamma distribution given by `shape` and `rate`.
    Gamma { shape: f64, rate: f64 },
    /// Beta distribution with shape parameters `alpha` and `beta`.
    Beta { alpha: f64, beta: f64 },
    /// Log-normal distribution whose logarithm has location `mu` and scale `sigma`.
    LogNormal { mu: f64, sigma: f64 },
    /// Cauchy distribution with the given `location` and `scale`.
    Cauchy { location: f64, scale: f64 },
    /// Weibull distribution with the given `shape` and `scale`.
    Weibull { shape: f64, scale: f64 },
    /// Continuous uniform distribution on `[lower, upper]`.
    Uniform { lower: f64, upper: f64 },
}
51
52impl ContinuousDistribution {
53 pub fn standard_normal() -> Self {
55 Self::Normal {
56 mean: 0.0,
57 std_dev: 1.0,
58 }
59 }
60
61 pub fn normal(mean: f64, std_dev: f64) -> Result<Self> {
63 if std_dev <= 0.0 {
64 return Err(DistributionError::InvalidParameter(
65 "Standard deviation must be positive".to_string(),
66 ));
67 }
68 Ok(Self::Normal { mean, std_dev })
69 }
70
71 pub fn students_t(df: f64) -> Result<Self> {
73 if df <= 0.0 {
74 return Err(DistributionError::InvalidParameter(
75 "Degrees of freedom must be positive".to_string(),
76 ));
77 }
78 Ok(Self::StudentsT { df })
79 }
80
81 pub fn chi_squared(df: f64) -> Result<Self> {
83 if df <= 0.0 {
84 return Err(DistributionError::InvalidParameter(
85 "Degrees of freedom must be positive".to_string(),
86 ));
87 }
88 Ok(Self::ChiSquared { df })
89 }
90
91 pub fn fisher_snedecor(d1: f64, d2: f64) -> Result<Self> {
93 if d1 <= 0.0 || d2 <= 0.0 {
94 return Err(DistributionError::InvalidParameter(
95 "Degrees of freedom must be positive".to_string(),
96 ));
97 }
98 Ok(Self::FisherSnedecor { d1, d2 })
99 }
100
101 pub fn pdf(&self, x: f64) -> f64 {
103 match self {
104 Self::Normal { mean, std_dev } => {
105 let dist = Normal::new(*mean, *std_dev).unwrap();
106 dist.pdf(x)
107 }
108 Self::StudentsT { df } => {
109 let dist = StudentsT::new(0.0, 1.0, *df).unwrap();
110 dist.pdf(x)
111 }
112 Self::ChiSquared { df } => {
113 let dist = ChiSquared::new(*df).unwrap();
114 dist.pdf(x)
115 }
116 Self::FisherSnedecor { d1, d2 } => {
117 let dist = FisherSnedecor::new(*d1, *d2).unwrap();
118 dist.pdf(x)
119 }
120 Self::Exponential { rate } => {
121 if x < 0.0 {
122 0.0
123 } else {
124 rate * (-rate * x).exp()
125 }
126 }
127 Self::Gamma { shape, rate } => {
128 let dist = Gamma::new(*shape, *rate).unwrap();
129 dist.pdf(x)
130 }
131 Self::Beta { alpha, beta } => {
132 let dist = Beta::new(*alpha, *beta).unwrap();
133 dist.pdf(x)
134 }
135 Self::LogNormal { mu, sigma } => {
136 let dist = LogNormal::new(*mu, *sigma).unwrap();
137 dist.pdf(x)
138 }
139 Self::Cauchy { location, scale } => {
140 let dist = Cauchy::new(*location, *scale).unwrap();
141 dist.pdf(x)
142 }
143 Self::Weibull { shape, scale } => {
144 let dist = Weibull::new(*shape, *scale).unwrap();
145 dist.pdf(x)
146 }
147 Self::Uniform { lower, upper } => {
148 if x < *lower || x > *upper {
149 0.0
150 } else {
151 1.0 / (upper - lower)
152 }
153 }
154 }
155 }
156
157 pub fn cdf(&self, x: f64) -> f64 {
159 match self {
160 Self::Normal { mean, std_dev } => {
161 let dist = Normal::new(*mean, *std_dev).unwrap();
162 dist.cdf(x)
163 }
164 Self::StudentsT { df } => {
165 let dist = StudentsT::new(0.0, 1.0, *df).unwrap();
166 dist.cdf(x)
167 }
168 Self::ChiSquared { df } => {
169 let dist = ChiSquared::new(*df).unwrap();
170 dist.cdf(x)
171 }
172 Self::FisherSnedecor { d1, d2 } => {
173 let dist = FisherSnedecor::new(*d1, *d2).unwrap();
174 dist.cdf(x)
175 }
176 Self::Exponential { rate } => {
177 if x < 0.0 {
178 0.0
179 } else {
180 1.0 - (-rate * x).exp()
181 }
182 }
183 Self::Gamma { shape, rate } => {
184 let dist = Gamma::new(*shape, *rate).unwrap();
185 dist.cdf(x)
186 }
187 Self::Beta { alpha, beta } => {
188 let dist = Beta::new(*alpha, *beta).unwrap();
189 dist.cdf(x)
190 }
191 Self::LogNormal { mu, sigma } => {
192 let dist = LogNormal::new(*mu, *sigma).unwrap();
193 dist.cdf(x)
194 }
195 Self::Cauchy { location, scale } => {
196 let dist = Cauchy::new(*location, *scale).unwrap();
197 dist.cdf(x)
198 }
199 Self::Weibull { shape, scale } => {
200 let dist = Weibull::new(*shape, *scale).unwrap();
201 dist.cdf(x)
202 }
203 Self::Uniform { lower, upper } => {
204 if x < *lower {
205 0.0
206 } else if x > *upper {
207 1.0
208 } else {
209 (x - lower) / (upper - lower)
210 }
211 }
212 }
213 }
214
215 pub fn quantile(&self, p: f64) -> Option<f64> {
217 if !(0.0..=1.0).contains(&p) {
218 return None;
219 }
220
221 match self {
222 Self::Normal { mean, std_dev } => {
223 let dist = Normal::new(*mean, *std_dev).unwrap();
224 Some(dist.inverse_cdf(p))
225 }
226 Self::StudentsT { df } => {
227 let dist = StudentsT::new(0.0, 1.0, *df).unwrap();
228 Some(dist.inverse_cdf(p))
229 }
230 Self::ChiSquared { df } => {
231 let dist = ChiSquared::new(*df).unwrap();
232 Some(dist.inverse_cdf(p))
233 }
234 Self::FisherSnedecor { d1, d2 } => {
235 let dist = FisherSnedecor::new(*d1, *d2).unwrap();
236 Some(dist.inverse_cdf(p))
237 }
238 Self::Exponential { rate } => {
239 if p <= 0.0 {
240 Some(0.0)
241 } else if p >= 1.0 {
242 Some(f64::INFINITY)
243 } else {
244 Some(-(1.0 - p).ln() / rate)
245 }
246 }
247 Self::Gamma { shape, rate } => {
248 let dist = Gamma::new(*shape, *rate).unwrap();
249 Some(dist.inverse_cdf(p))
250 }
251 Self::Beta { alpha, beta } => {
252 let dist = Beta::new(*alpha, *beta).unwrap();
253 Some(dist.inverse_cdf(p))
254 }
255 Self::LogNormal { mu, sigma } => {
256 let dist = LogNormal::new(*mu, *sigma).unwrap();
257 Some(dist.inverse_cdf(p))
258 }
259 Self::Cauchy { location, scale } => {
260 let dist = Cauchy::new(*location, *scale).unwrap();
261 Some(dist.inverse_cdf(p))
262 }
263 Self::Weibull { shape, scale } => {
264 let dist = Weibull::new(*shape, *scale).unwrap();
265 Some(dist.inverse_cdf(p))
266 }
267 Self::Uniform { lower, upper } => Some(lower + p * (upper - lower)),
268 }
269 }
270
271 pub fn sample<R: Rng>(&self, _rng: &mut R) -> f64 {
273 match self {
274 Self::Normal { mean, .. } => *mean,
275 Self::StudentsT { df } => {
276 if *df > 1.0 {
277 0.0
278 } else {
279 f64::NAN
280 }
281 }
282 Self::ChiSquared { df } => *df,
283 Self::FisherSnedecor { d1: _, d2 } => {
284 if *d2 > 2.0 {
285 *d2 / (*d2 - 2.0)
286 } else {
287 f64::NAN
288 }
289 }
290 Self::Exponential { rate } => 1.0 / rate,
291 Self::Gamma { shape, rate } => shape / rate,
292 Self::Beta { alpha, beta } => alpha / (alpha + beta),
293 Self::LogNormal { mu, sigma } => (mu + sigma.powi(2) / 2.0).exp(),
294 Self::Cauchy { location, .. } => *location,
295 Self::Weibull { shape, scale } => scale * gamma(1.0 + 1.0 / shape),
296 Self::Uniform { lower, upper } => (lower + upper) / 2.0,
297 }
298 }
299}
300
/// A discrete probability distribution together with its parameters.
///
/// The variants are public, so values can also be built directly without
/// going through the validating constructors.
#[derive(Clone, Debug)]
// `N` and `K` follow the usual textbook notation for the hypergeometric
// distribution, hence the non-snake-case field names.
#[allow(non_snake_case)]
pub enum DiscreteDistribution {
    /// Single trial that succeeds with probability `p`.
    Bernoulli { p: f64 },
    /// Number of successes in `n` independent trials with success probability `p`.
    Binomial { n: u64, p: f64 },
    /// Poisson distribution with rate `lambda`.
    Poisson { lambda: f64 },
    /// Geometric distribution with per-trial success probability `p`.
    Geometric { p: f64 },
    /// Negative binomial distribution with `r` successes and success probability `p`.
    NegativeBinomial { r: f64, p: f64 },
    /// Successes among `n` draws without replacement from a population of
    /// `N` items, `K` of which are successes.
    Hypergeometric { N: u64, K: u64, n: u64 },
}
311
312impl DiscreteDistribution {
313 pub fn bernoulli(p: f64) -> Result<Self> {
315 if !(0.0..=1.0).contains(&p) {
316 return Err(DistributionError::InvalidParameter(
317 "Probability must be between 0 and 1".to_string(),
318 ));
319 }
320 Ok(Self::Bernoulli { p })
321 }
322
323 pub fn binomial(n: u64, p: f64) -> Result<Self> {
325 if !(0.0..=1.0).contains(&p) {
326 return Err(DistributionError::InvalidParameter(
327 "Probability must be between 0 and 1".to_string(),
328 ));
329 }
330 Ok(Self::Binomial { n, p })
331 }
332
333 pub fn poisson(lambda: f64) -> Result<Self> {
335 if lambda <= 0.0 {
336 return Err(DistributionError::InvalidParameter(
337 "Lambda must be positive".to_string(),
338 ));
339 }
340 Ok(Self::Poisson { lambda })
341 }
342
343 pub fn geometric(p: f64) -> Result<Self> {
345 if !(0.0..=1.0).contains(&p) {
346 return Err(DistributionError::InvalidParameter(
347 "Probability must be between 0 and 1".to_string(),
348 ));
349 }
350 Ok(Self::Geometric { p })
351 }
352
353 pub fn negative_binomial(r: f64, p: f64) -> Result<Self> {
355 if r <= 0.0 || !(0.0..=1.0).contains(&p) {
356 return Err(DistributionError::InvalidParameter(
357 "r must be positive and p must be between 0 and 1".to_string(),
358 ));
359 }
360 Ok(Self::NegativeBinomial { r, p })
361 }
362
363 pub fn hypergeometric(N: u64, K: u64, n: u64) -> Result<Self> {
365 if n > N || K > N {
366 return Err(DistributionError::InvalidParameter(
367 "Invalid parameters for hypergeometric distribution".to_string(),
368 ));
369 }
370 Ok(Self::Hypergeometric { N, K, n })
371 }
372
373 pub fn pmf(&self, k: u64) -> f64 {
375 match self {
376 Self::Bernoulli { p } => {
377 let dist = Bernoulli::new(*p).unwrap();
378 dist.pmf(k as u64)
379 }
380 Self::Binomial { n, p } => {
381 let dist = Binomial::new(*p, *n).unwrap();
382 dist.pmf(k as u64)
383 }
384 Self::Poisson { lambda } => {
385 let dist = Poisson::new(*lambda).unwrap();
386 dist.pmf(k as u64)
387 }
388 Self::Geometric { p } => {
389 let dist = Geometric::new(*p).unwrap();
390 dist.pmf(k as u64)
391 }
392 Self::NegativeBinomial { r, p } => {
393 let dist = NegativeBinomial::new(*r, *p).unwrap();
394 dist.pmf(k as u64)
395 }
396 Self::Hypergeometric { N, K, n } => {
397 let dist = Hypergeometric::new(*N, *K, *n).unwrap();
398 dist.pmf(k as u64)
399 }
400 }
401 }
402
403 pub fn cdf(&self, k: u64) -> f64 {
405 match self {
406 Self::Bernoulli { p } => {
407 let dist = Bernoulli::new(*p).unwrap();
408 dist.cdf(k as u64)
409 }
410 Self::Binomial { n, p } => {
411 let dist = Binomial::new(*p, *n).unwrap();
412 dist.cdf(k as u64)
413 }
414 Self::Poisson { lambda } => {
415 let dist = Poisson::new(*lambda).unwrap();
416 dist.cdf(k as u64)
417 }
418 Self::Geometric { p } => {
419 let dist = Geometric::new(*p).unwrap();
420 dist.cdf(k as u64)
421 }
422 Self::NegativeBinomial { r, p } => {
423 let dist = NegativeBinomial::new(*r, *p).unwrap();
424 dist.cdf(k as u64)
425 }
426 Self::Hypergeometric { N, K, n } => {
427 let dist = Hypergeometric::new(*N, *K, *n).unwrap();
428 dist.cdf(k as u64)
429 }
430 }
431 }
432
433 pub fn sample<R: Rng>(&self, rng: &mut R) -> u64 {
435 match self {
436 Self::Bernoulli { p } => {
437 if rng.random::<f64>() < *p { 1 } else { 0 }
439 }
440 Self::Binomial { n, p } => {
441 let mut successes = 0;
443 for _ in 0..*n {
444 if rng.random::<f64>() < *p {
445 successes += 1;
446 }
447 }
448 successes
449 }
450 Self::Poisson { lambda } => {
451 let l = (-*lambda).exp();
453 let mut k = 0;
454 let mut p = 1.0;
455 loop {
456 k += 1;
457 p *= rng.random::<f64>();
458 if p <= l {
459 break;
460 }
461 }
462 (k - 1) as u64
463 }
464 Self::Geometric { p } => {
465 ((rng.random::<f64>().ln() / (1.0 - p).ln()).floor() as u64) + 1
467 }
468 Self::NegativeBinomial { r, p } => {
469 let mut successes = 0;
471 let mut trials = 0;
472 while successes < *r as u64 {
473 trials += 1;
474 if rng.random::<f64>() < *p {
475 successes += 1;
476 }
477 }
478 trials
479 }
480 Self::Hypergeometric { N, K, n } => {
481 let mut population = vec![true; *K as usize]
484 .into_iter()
485 .chain(vec![false; (*N - *K) as usize])
486 .collect::<Vec<_>>();
487 population.shuffle(rng);
488 population[..*n as usize].iter().filter(|&&x| x).count() as u64
489 }
490 }
491 }
492}