1#[cfg(feature = "serde1")]
3use serde::{Deserialize, Serialize};
4
5use crate::impl_display;
6use crate::traits::{
7 Cdf, ContinuousDistr, Entropy, HasDensity, InverseCdf, Kurtosis, Mean,
8 Median, Parameterized, Sampleable, Scalable, Shiftable, Skewness, Support,
9 Variance,
10};
11use rand::Rng;
12use std::f64;
13use std::fmt;
14use std::sync::OnceLock;
15
/// Plain-data parameter set for [`Uniform`]: the two interval endpoints.
///
/// This is the type emitted by `Uniform::emit_params` and accepted by
/// `Uniform::from_params`.
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(
    feature = "serde1",
    derive(Serialize, Deserialize),
    serde(rename_all = "snake_case")
)]
pub struct UniformParameters {
    /// Lower endpoint of the interval.
    pub a: f64,
    /// Upper endpoint of the interval.
    pub b: f64,
}
45
/// Continuous uniform distribution over the closed interval `[a, b]`.
///
/// The log-density `-ln(b - a)` is computed on first use and cached in
/// `lnf`; the cache is skipped during (de)serialization.
#[derive(Clone, Debug)]
#[cfg_attr(
    feature = "serde1",
    derive(Serialize, Deserialize),
    serde(rename_all = "snake_case")
)]
pub struct Uniform {
    /// Lower endpoint of the interval.
    a: f64,
    /// Upper endpoint of the interval.
    b: f64,
    /// Lazily initialized cache of `-ln(b - a)`.
    #[cfg_attr(feature = "serde1", serde(skip))]
    lnf: OnceLock<f64>,
}
56
57impl Shiftable for Uniform {
58 type Output = Uniform;
59 type Error = UniformError;
60
61 fn shifted(self, shift: f64) -> Result<Self::Output, Self::Error>
62 where
63 Self: Sized,
64 {
65 Uniform::new(self.a() + shift, self.b() + shift)
66 }
67
68 fn shifted_unchecked(self, shift: f64) -> Self::Output
69 where
70 Self: Sized,
71 {
72 Uniform::new_unchecked(self.a() + shift, self.b() + shift)
73 }
74}
75
76impl Scalable for Uniform {
77 type Output = Uniform;
78 type Error = UniformError;
79
80 fn scaled(self, scale: f64) -> Result<Self::Output, Self::Error> {
81 Uniform::new(self.a() * scale, self.b() * scale)
82 }
83
84 fn scaled_unchecked(self, scale: f64) -> Self::Output {
85 Uniform::new_unchecked(self.a() * scale, self.b() * scale)
86 }
87}
88
89impl Parameterized for Uniform {
90 type Parameters = UniformParameters;
91
92 fn emit_params(&self) -> Self::Parameters {
93 Self::Parameters {
94 a: self.a(),
95 b: self.b(),
96 }
97 }
98
99 fn from_params(params: Self::Parameters) -> Self {
100 Self::new_unchecked(params.a, params.b)
101 }
102}
103
104impl PartialEq for Uniform {
105 fn eq(&self, other: &Uniform) -> bool {
106 self.a == other.a && self.b == other.b
107 }
108}
109
/// Reasons a [`Uniform`] cannot be constructed or updated.
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(
    feature = "serde1",
    derive(Serialize, Deserialize),
    serde(rename_all = "snake_case")
)]
pub enum UniformError {
    /// The endpoints do not satisfy `a < b`.
    InvalidInterval { a: f64, b: f64 },
    /// The lower endpoint `a` is infinite or NaN.
    ANotFinite { a: f64 },
    /// The upper endpoint `b` is infinite or NaN.
    BNotFinite { b: f64 },
}
121
122impl Uniform {
123 #[inline]
125 pub fn new(a: f64, b: f64) -> Result<Self, UniformError> {
126 if a >= b {
127 Err(UniformError::InvalidInterval { a, b })
128 } else if !a.is_finite() {
129 Err(UniformError::ANotFinite { a })
130 } else if !b.is_finite() {
131 Err(UniformError::BNotFinite { b })
132 } else {
133 Ok(Uniform::new_unchecked(a, b))
134 }
135 }
136
137 #[inline]
140 #[must_use]
141 pub fn new_unchecked(a: f64, b: f64) -> Self {
142 Uniform {
143 a,
144 b,
145 lnf: OnceLock::new(),
146 }
147 }
148
149 #[inline]
159 pub fn a(&self) -> f64 {
160 self.a
161 }
162
163 pub fn set_a(&mut self, a: f64) -> Result<(), UniformError> {
165 if !a.is_finite() {
166 Err(UniformError::ANotFinite { a })
167 } else if a >= self.b {
168 Err(UniformError::InvalidInterval { a, b: self.b })
169 } else {
170 self.set_a_unchecked(a);
171 Ok(())
172 }
173 }
174
175 pub fn set_a_unchecked(&mut self, a: f64) {
177 self.lnf = OnceLock::new();
178 self.a = a;
179 }
180
181 #[inline]
191 pub fn b(&self) -> f64 {
192 self.b
193 }
194
195 pub fn set_b(&mut self, b: f64) -> Result<(), UniformError> {
197 if !b.is_finite() {
198 Err(UniformError::BNotFinite { b })
199 } else if self.a >= b {
200 Err(UniformError::InvalidInterval { a: self.a, b })
201 } else {
202 self.set_b_unchecked(b);
203 Ok(())
204 }
205 }
206
207 pub fn set_b_unchecked(&mut self, b: f64) {
209 self.lnf = OnceLock::new();
210 self.b = b;
211 }
212
213 #[inline]
214 fn lnf(&self) -> f64 {
215 *self.lnf.get_or_init(|| -(self.b - self.a).ln())
216 }
217}
218
219impl Default for Uniform {
220 fn default() -> Self {
221 Uniform::new_unchecked(0.0, 1.0)
222 }
223}
224
225impl From<&Uniform> for String {
226 fn from(u: &Uniform) -> String {
227 format!("U({}, {})", u.a, u.b)
228 }
229}
230
// NOTE(review): `impl_display!` is defined elsewhere in the crate; presumably
// it implements `fmt::Display` via the `From<&Uniform> for String` conversion
// — confirm against the macro definition.
impl_display!(Uniform);
232
/// Implements the density, sampling, support, moment, and CDF traits of
/// `Uniform` for one floating-point sample type `$kind`.
///
/// Inputs are widened with `f64::from`, arithmetic is carried out in `f64`,
/// and results are cast to `$kind` where the trait returns a sample value.
macro_rules! impl_traits {
    ($kind:ty) => {
        impl HasDensity<$kind> for Uniform {
            /// Log-density: the cached `-ln(b - a)` on `[a, b]` (endpoints
            /// included), negative infinity everywhere else.
            fn ln_f(&self, x: &$kind) -> f64 {
                let xf = f64::from(*x);
                if (self.a..=self.b).contains(&xf) {
                    self.lnf()
                } else {
                    f64::NEG_INFINITY
                }
            }
        }

        impl Sampleable<$kind> for Uniform {
            /// Draws a single value through a `rand_distr::Uniform` built
            /// from the endpoints.
            fn draw<R: Rng>(&self, rng: &mut R) -> $kind {
                let sampler = rand_distr::Uniform::new(self.a, self.b)
                    .expect("By construction, this should be valid.");
                let x: f64 = rng.sample(sampler);
                x as $kind
            }

            /// Draws `n` values, building the underlying sampler only once.
            fn sample<R: Rng>(&self, n: usize, rng: &mut R) -> Vec<$kind> {
                let sampler = rand_distr::Uniform::new(self.a, self.b)
                    .expect("By construction, this should be valid.");
                std::iter::repeat_with(|| rng.sample(sampler) as $kind)
                    .take(n)
                    .collect()
            }
        }

        #[allow(clippy::cmp_owned)]
        impl Support<$kind> for Uniform {
            /// True when `x` is finite and lies in `[a, b]`.
            fn supports(&self, x: &$kind) -> bool {
                if !x.is_finite() {
                    return false;
                }
                let xf = f64::from(*x);
                self.a <= xf && xf <= self.b
            }
        }

        impl ContinuousDistr<$kind> for Uniform {}

        impl Mean<$kind> for Uniform {
            /// The midpoint `(a + b) / 2`.
            fn mean(&self) -> Option<$kind> {
                Some(((self.b + self.a) / 2.0) as $kind)
            }
        }

        impl Median<$kind> for Uniform {
            /// The midpoint `(a + b) / 2`, identical to the mean.
            fn median(&self) -> Option<$kind> {
                Some(((self.b + self.a) / 2.0) as $kind)
            }
        }

        impl Variance<$kind> for Uniform {
            /// `(b - a)^2 / 12`.
            fn variance(&self) -> Option<$kind> {
                let width = self.b - self.a;
                Some((width * width / 12.0) as $kind)
            }
        }

        impl Cdf<$kind> for Uniform {
            /// 0 below `a`, 1 at or above `b`, linear in between.
            fn cdf(&self, x: &$kind) -> f64 {
                let xf = f64::from(*x);
                if xf < self.a {
                    return 0.0;
                }
                if xf >= self.b {
                    return 1.0;
                }
                (xf - self.a) / (self.b - self.a)
            }
        }

        impl InverseCdf<$kind> for Uniform {
            /// Maps `p` to `a + p * (b - a)`; `p` is not range-checked.
            fn invcdf(&self, p: f64) -> $kind {
                let width = self.b - self.a;
                p.mul_add(width, self.a) as $kind
            }
        }
    };
}
315
316impl Skewness for Uniform {
317 fn skewness(&self) -> Option<f64> {
318 Some(0.0)
319 }
320}
321
322impl Kurtosis for Uniform {
323 fn kurtosis(&self) -> Option<f64> {
324 Some(-1.2)
325 }
326}
327
328impl Entropy for Uniform {
329 fn entropy(&self) -> f64 {
330 (self.b - self.a).ln()
331 }
332}
333
// Instantiate the `impl_traits!` trait implementations of `Uniform` for both
// floating-point sample types.
impl_traits!(f64);
impl_traits!(f32);
336
// `UniformError` uses the default `Error` methods; no source error is exposed.
impl std::error::Error for UniformError {}
338
339#[cfg_attr(coverage_nightly, coverage(off))]
340impl fmt::Display for UniformError {
341 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
342 match self {
343 Self::InvalidInterval { a, b } => {
344 write!(f, "invalid interval: (a, b) = ({a}, {b})")
345 }
346 Self::ANotFinite { a } => write!(f, "non-finite a: {a}"),
347 Self::BNotFinite { b } => write!(f, "non-finite b: {b}"),
348 }
349 }
350}
351
#[cfg(test)]
mod tests {
    use super::*;
    use crate::misc::ks_test;
    use crate::test_basic_impls;

    // Absolute tolerance for floating-point comparisons.
    const TOL: f64 = 1E-12;
    // Threshold the p-value returned by `ks_test` must exceed to count as a pass.
    const KS_PVAL: f64 = 0.2;
    // Number of sampling attempts in `draw_test`; one pass is sufficient.
    const N_TRIES: usize = 5;

    test_basic_impls!(f64, Uniform);

    #[test]
    fn new() {
        let unit = Uniform::new(0.0, 1.0).unwrap();
        assert::close(unit.a, 0.0, TOL);
        assert::close(unit.b, 1.0, TOL);
    }

    #[test]
    fn new_rejects_a_equal_to_b() {
        let degenerate = Uniform::new(1.0, 1.0);
        assert!(degenerate.is_err());
    }

    #[test]
    fn new_rejects_a_gt_b() {
        let reversed = Uniform::new(2.0, 1.0);
        assert!(reversed.is_err());
    }

    #[test]
    fn new_rejects_non_finite_a_or_b() {
        let bad_endpoints = [
            (f64::NEG_INFINITY, 1.0),
            (f64::NAN, 1.0),
            (0.0, f64::INFINITY),
            (0.0, f64::NAN),
        ];
        for (a, b) in bad_endpoints {
            assert!(Uniform::new(a, b).is_err());
        }
    }

    #[test]
    fn mean() {
        let u = Uniform::new(2.0, 4.0).unwrap();
        let m: f64 = u.mean().unwrap();
        assert::close(m, 3.0, TOL);
    }

    #[test]
    fn median() {
        let u = Uniform::new(2.0, 4.0).unwrap();
        let m: f64 = u.median().unwrap();
        assert::close(m, 3.0, TOL);
    }

    #[test]
    fn variance() {
        let u = Uniform::new(2.0, 4.0).unwrap();
        let v: f64 = u.variance().unwrap();
        assert::close(v, 2.0 / 6.0, TOL);
    }

    #[test]
    fn entropy() {
        let u = Uniform::new(2.0, 4.0).unwrap();
        let h: f64 = u.entropy();
        assert::close(h, std::f64::consts::LN_2, TOL);
    }

    #[test]
    fn ln_pdf() {
        let u = Uniform::new(2.0, 4.0).unwrap();
        // The density is constant across the support, endpoints included.
        let expected = -std::f64::consts::LN_2;
        for x in [2.0_f64, 2.3_f64, 3.3_f64, 4.0_f64] {
            assert::close(u.ln_pdf(&x), expected, TOL);
        }
    }

    #[test]
    fn cdf() {
        let u = Uniform::new(2.0, 4.0).unwrap();
        let cases = [
            (2.0_f64, 0.0),
            (2.3_f64, 0.149_999_999_999_999_9),
            (3.3_f64, 0.649_999_999_999_999_9),
            (4.0_f64, 1.0),
        ];
        for (x, expected) in cases {
            assert::close(u.cdf(&x), expected, TOL);
        }
    }

    #[test]
    fn cdf_inv_cdf_ident() {
        let mut rng = rand::rng();
        let reference = rand::distr::Uniform::new(1.2, 3.4).unwrap();
        let u = Uniform::new(1.2, 3.4).unwrap();
        for _ in 0..100 {
            let x: f64 = rng.sample(reference);
            let p = u.cdf(&x);
            let roundtrip: f64 = u.invcdf(p);
            assert::close(x, roundtrip, 1E-8);
        }
    }

    #[test]
    fn draw_test() {
        let mut rng = rand::rng();
        let u = Uniform::new(1.2, 3.4).unwrap();
        let cdf = |x: f64| u.cdf(&x);

        // Every attempt is run; the test passes if at least one clears the
        // p-value threshold.
        let passes = (0..N_TRIES)
            .filter(|_| {
                let xs: Vec<f64> = u.sample(1000, &mut rng);
                let (_, p) = ks_test(&xs, cdf);
                p > KS_PVAL
            })
            .count();
        assert!(passes > 0);
    }

    use crate::test_shiftable_cdf;
    use crate::test_shiftable_density;
    use crate::test_shiftable_entropy;
    use crate::test_shiftable_invcdf;
    use crate::test_shiftable_method;

    test_shiftable_method!(Uniform::new(2.0, 4.0).unwrap(), mean);
    test_shiftable_method!(Uniform::new(2.0, 4.0).unwrap(), median);
    test_shiftable_method!(Uniform::new(2.0, 4.0).unwrap(), variance);
    test_shiftable_method!(Uniform::new(2.0, 4.0).unwrap(), skewness);
    test_shiftable_method!(Uniform::new(2.0, 4.0).unwrap(), kurtosis);
    test_shiftable_density!(Uniform::new(2.0, 4.0).unwrap());
    test_shiftable_entropy!(Uniform::new(2.0, 4.0).unwrap());
    test_shiftable_cdf!(Uniform::new(2.0, 4.0).unwrap());
    test_shiftable_invcdf!(Uniform::new(2.0, 4.0).unwrap());

    use crate::test_scalable_cdf;
    use crate::test_scalable_density;
    use crate::test_scalable_entropy;
    use crate::test_scalable_invcdf;
    use crate::test_scalable_method;

    test_scalable_method!(Uniform::new(2.0, 4.0).unwrap(), mean);
    test_scalable_method!(Uniform::new(2.0, 4.0).unwrap(), median);
    test_scalable_method!(Uniform::new(2.0, 4.0).unwrap(), variance);
    test_scalable_method!(Uniform::new(2.0, 4.0).unwrap(), skewness);
    test_scalable_method!(Uniform::new(2.0, 4.0).unwrap(), kurtosis);
    test_scalable_density!(Uniform::new(2.0, 4.0).unwrap());
    test_scalable_entropy!(Uniform::new(2.0, 4.0).unwrap());
    test_scalable_cdf!(Uniform::new(2.0, 4.0).unwrap());
    test_scalable_invcdf!(Uniform::new(2.0, 4.0).unwrap());

    #[test]
    fn emit_and_from_params_are_identity() {
        let original = Uniform::new(0.5, 10.4).unwrap();
        let rebuilt = Uniform::from_params(original.emit_params());
        assert_eq!(original, rebuilt);
    }
}