rv/dist/
uniform.rs

1//! Continuous uniform distribution, U(a, b) on the interval x in [a, b]
2#[cfg(feature = "serde1")]
3use serde::{Deserialize, Serialize};
4
5use crate::impl_display;
6use crate::traits::{
7    Cdf, ContinuousDistr, Entropy, HasDensity, InverseCdf, Kurtosis, Mean,
8    Median, Parameterized, Sampleable, Scalable, Shiftable, Skewness, Support,
9    Variance,
10};
11use rand::Rng;
12use std::f64;
13use std::fmt;
14use std::sync::OnceLock;
15
/// Parameters for the Uniform distribution
///
/// A plain, publicly constructible `(a, b)` pair. No validation is attached
/// to this type itself.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serde1", serde(rename_all = "snake_case"))]
pub struct UniformParameters {
    /// Lower bound
    pub a: f64,
    /// Upper bound
    pub b: f64,
}

/// [Continuous uniform distribution](https://en.wikipedia.org/wiki/Uniform_distribution_(continuous)),
/// U(a, b) on the interval x in [a, b]
///
/// # Example
///
/// The Uniform CDF is a line
///
/// ```
/// use rv::prelude::*;
///
/// let u = Uniform::new(2.0, 4.0).unwrap();
///
/// // A line representing the CDF
/// let y = |x: f64| { 0.5 * x - 1.0 };
///
/// assert!((u.cdf(&3.0_f64) - y(3.0)).abs() < 1E-12);
/// assert!((u.cdf(&3.2_f64) - y(3.2)).abs() < 1E-12);
/// ```
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serde1", serde(rename_all = "snake_case"))]
pub struct Uniform {
    /// Lower bound
    a: f64,
    /// Upper bound
    b: f64,
    /// Cached value of the ln(PDF)
    // Skipped by serde, so a deserialized value starts with an empty cache.
    #[cfg_attr(feature = "serde1", serde(skip))]
    lnf: OnceLock<f64>,
}
56
57impl Shiftable for Uniform {
58    type Output = Uniform;
59    type Error = UniformError;
60
61    fn shifted(self, shift: f64) -> Result<Self::Output, Self::Error>
62    where
63        Self: Sized,
64    {
65        Uniform::new(self.a() + shift, self.b() + shift)
66    }
67
68    fn shifted_unchecked(self, shift: f64) -> Self::Output
69    where
70        Self: Sized,
71    {
72        Uniform::new_unchecked(self.a() + shift, self.b() + shift)
73    }
74}
75
76impl Scalable for Uniform {
77    type Output = Uniform;
78    type Error = UniformError;
79
80    fn scaled(self, scale: f64) -> Result<Self::Output, Self::Error> {
81        Uniform::new(self.a() * scale, self.b() * scale)
82    }
83
84    fn scaled_unchecked(self, scale: f64) -> Self::Output {
85        Uniform::new_unchecked(self.a() * scale, self.b() * scale)
86    }
87}
88
89impl Parameterized for Uniform {
90    type Parameters = UniformParameters;
91
92    fn emit_params(&self) -> Self::Parameters {
93        Self::Parameters {
94            a: self.a(),
95            b: self.b(),
96        }
97    }
98
99    fn from_params(params: Self::Parameters) -> Self {
100        Self::new_unchecked(params.a, params.b)
101    }
102}
103
104impl PartialEq for Uniform {
105    fn eq(&self, other: &Uniform) -> bool {
106        self.a == other.a && self.b == other.b
107    }
108}
109
/// Reasons a pair of bounds can be rejected for a `Uniform` distribution.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serde1", serde(rename_all = "snake_case"))]
pub enum UniformError {
    /// The interval is empty or reversed: `a >= b`
    InvalidInterval { a: f64, b: f64 },
    /// The lower bound `a` was infinite or NaN
    ANotFinite { a: f64 },
    /// The upper bound `b` was infinite or NaN
    BNotFinite { b: f64 },
}
121
122impl Uniform {
123    /// Create a new uniform distribution on [a, b]
124    #[inline]
125    pub fn new(a: f64, b: f64) -> Result<Self, UniformError> {
126        if a >= b {
127            Err(UniformError::InvalidInterval { a, b })
128        } else if !a.is_finite() {
129            Err(UniformError::ANotFinite { a })
130        } else if !b.is_finite() {
131            Err(UniformError::BNotFinite { b })
132        } else {
133            Ok(Uniform::new_unchecked(a, b))
134        }
135    }
136
137    /// Creates a new Uniform without checking whether the parameters are
138    /// valid.
139    #[inline]
140    #[must_use]
141    pub fn new_unchecked(a: f64, b: f64) -> Self {
142        Uniform {
143            a,
144            b,
145            lnf: OnceLock::new(),
146        }
147    }
148
149    /// Get the lower bound, a
150    ///
151    /// # Example
152    ///
153    /// ```
154    /// # use rv::dist::Uniform;
155    /// let u = Uniform::new(0.0, 1.0).unwrap();
156    /// assert_eq!(u.a(), 0.0);
157    /// ```
158    #[inline]
159    pub fn a(&self) -> f64 {
160        self.a
161    }
162
163    /// Set the value of a
164    pub fn set_a(&mut self, a: f64) -> Result<(), UniformError> {
165        if !a.is_finite() {
166            Err(UniformError::ANotFinite { a })
167        } else if a >= self.b {
168            Err(UniformError::InvalidInterval { a, b: self.b })
169        } else {
170            self.set_a_unchecked(a);
171            Ok(())
172        }
173    }
174
175    /// Set the value of a without checking if a is valid
176    pub fn set_a_unchecked(&mut self, a: f64) {
177        self.lnf = OnceLock::new();
178        self.a = a;
179    }
180
181    /// Get the upper bound, b
182    ///
183    /// # Example
184    ///
185    /// ```
186    /// # use rv::dist::Uniform;
187    /// let u = Uniform::new(0.0, 1.0).unwrap();
188    /// assert_eq!(u.b(), 1.0);
189    /// ```
190    #[inline]
191    pub fn b(&self) -> f64 {
192        self.b
193    }
194
195    /// Set the value of b
196    pub fn set_b(&mut self, b: f64) -> Result<(), UniformError> {
197        if !b.is_finite() {
198            Err(UniformError::BNotFinite { b })
199        } else if self.a >= b {
200            Err(UniformError::InvalidInterval { a: self.a, b })
201        } else {
202            self.set_b_unchecked(b);
203            Ok(())
204        }
205    }
206
207    /// Set the value of b without checking if b is valid
208    pub fn set_b_unchecked(&mut self, b: f64) {
209        self.lnf = OnceLock::new();
210        self.b = b;
211    }
212
213    #[inline]
214    fn lnf(&self) -> f64 {
215        *self.lnf.get_or_init(|| -(self.b - self.a).ln())
216    }
217}
218
219impl Default for Uniform {
220    fn default() -> Self {
221        Uniform::new_unchecked(0.0, 1.0)
222    }
223}
224
225impl From<&Uniform> for String {
226    fn from(u: &Uniform) -> String {
227        format!("U({}, {})", u.a, u.b)
228    }
229}
230
231impl_display!(Uniform);
232
// Implements the density, sampling, support, moment, CDF and inverse-CDF
// traits for `Uniform` over one floating-point observation type `$kind`.
// All arithmetic is done in f64; `$kind` values are widened on the way in
// with `f64::from` and results are narrowed on the way out with `as $kind`.
macro_rules! impl_traits {
    ($kind:ty) => {
        impl HasDensity<$kind> for Uniform {
            // ln f(x) = -ln(b - a) on the closed interval [a, b], and
            // -infinity everywhere else (including NaN, which fails the
            // range test).
            fn ln_f(&self, x: &$kind) -> f64 {
                let xf = f64::from(*x);
                if (self.a..=self.b).contains(&xf) {
                    // call the lnf cache field
                    self.lnf()
                } else {
                    f64::NEG_INFINITY
                }
            }
        }

        impl Sampleable<$kind> for Uniform {
            // Draw a single value by delegating to `rand_distr::Uniform`
            // built from the current bounds.
            fn draw<R: Rng>(&self, rng: &mut R) -> $kind {
                let u = rand_distr::Uniform::new(self.a, self.b)
                    .expect("By construction, this should be valid.");
                rng.sample(u) as $kind
            }

            // Draw `n` values, building the underlying sampler only once.
            fn sample<R: Rng>(&self, n: usize, rng: &mut R) -> Vec<$kind> {
                let u = rand_distr::Uniform::new(self.a, self.b)
                    .expect("By construction, this should be valid.");
                (0..n).map(|_| rng.sample(u) as $kind).collect()
            }
        }

        #[allow(clippy::cmp_owned)]
        impl Support<$kind> for Uniform {
            // A value is supported when it is finite and lies in [a, b].
            fn supports(&self, x: &$kind) -> bool {
                x.is_finite() && (self.a..=self.b).contains(&f64::from(*x))
            }
        }

        impl ContinuousDistr<$kind> for Uniform {}

        impl Mean<$kind> for Uniform {
            // Midpoint of [a, b]. Each bound is halved before adding so
            // that two large finite bounds of the same sign cannot
            // overflow to infinity the way `(a + b) / 2` would.
            fn mean(&self) -> Option<$kind> {
                let m = 0.5 * self.a + 0.5 * self.b;
                Some(m as $kind)
            }
        }

        impl Median<$kind> for Uniform {
            // The median coincides with the mean: the midpoint of [a, b],
            // computed in the same overflow-safe form.
            fn median(&self) -> Option<$kind> {
                let m = 0.5 * self.a + 0.5 * self.b;
                Some(m as $kind)
            }
        }

        impl Variance<$kind> for Uniform {
            // Var = (b - a)^2 / 12
            fn variance(&self) -> Option<$kind> {
                let diff = self.b - self.a;
                let v = diff * diff / 12.0;
                Some(v as $kind)
            }
        }

        impl Cdf<$kind> for Uniform {
            // Piecewise-linear CDF: 0 below a, 1 at or above b, and the
            // fraction of the interval covered in between.
            fn cdf(&self, x: &$kind) -> f64 {
                let xf = f64::from(*x);
                if xf < self.a {
                    0.0
                } else if xf >= self.b {
                    1.0
                } else {
                    (xf - self.a) / (self.b - self.a)
                }
            }
        }

        impl InverseCdf<$kind> for Uniform {
            // x = a + p * (b - a), evaluated as a single fused
            // multiply-add. `p` is not range-checked here.
            fn invcdf(&self, p: f64) -> $kind {
                let x = p.mul_add(self.b - self.a, self.a);
                x as $kind
            }
        }
    };
}
315
316impl Skewness for Uniform {
317    fn skewness(&self) -> Option<f64> {
318        Some(0.0)
319    }
320}
321
322impl Kurtosis for Uniform {
323    fn kurtosis(&self) -> Option<f64> {
324        Some(-1.2)
325    }
326}
327
328impl Entropy for Uniform {
329    fn entropy(&self) -> f64 {
330        (self.b - self.a).ln()
331    }
332}
333
334impl_traits!(f64);
335impl_traits!(f32);
336
337impl std::error::Error for UniformError {}
338
339#[cfg_attr(coverage_nightly, coverage(off))]
340impl fmt::Display for UniformError {
341    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
342        match self {
343            Self::InvalidInterval { a, b } => {
344                write!(f, "invalid interval: (a, b) = ({a}, {b})")
345            }
346            Self::ANotFinite { a } => write!(f, "non-finite a: {a}"),
347            Self::BNotFinite { b } => write!(f, "non-finite b: {b}"),
348        }
349    }
350}
351
#[cfg(test)]
mod tests {
    use super::*;
    use crate::misc::ks_test;
    use crate::test_basic_impls;

    /// Absolute tolerance for floating-point comparisons.
    const TOL: f64 = 1E-12;
    /// Minimum KS-test p-value for a sampling run to count as a pass.
    const KS_PVAL: f64 = 0.2;
    /// Number of attempts granted to the stochastic sampling test.
    const N_TRIES: usize = 5;

    test_basic_impls!(f64, Uniform);

    #[test]
    fn new() {
        let u = Uniform::new(0.0, 1.0).unwrap();
        assert::close(u.a, 0.0, TOL);
        assert::close(u.b, 1.0, TOL);
    }

    #[test]
    fn new_rejects_a_equal_to_b() {
        assert!(Uniform::new(1.0, 1.0).is_err());
    }

    #[test]
    fn new_rejects_a_gt_b() {
        assert!(Uniform::new(2.0, 1.0).is_err());
    }

    #[test]
    fn new_rejects_non_finite_a_or_b() {
        assert!(Uniform::new(f64::NEG_INFINITY, 1.0).is_err());
        assert!(Uniform::new(f64::NAN, 1.0).is_err());
        assert!(Uniform::new(0.0, f64::INFINITY).is_err());
        assert!(Uniform::new(0.0, f64::NAN).is_err());
    }

    #[test]
    fn set_a_validates_and_updates() {
        let mut u = Uniform::new(0.0, 1.0).unwrap();

        assert!(u.set_a(0.5).is_ok());
        assert_eq!(u.a(), 0.5);

        assert_eq!(
            u.set_a(1.0),
            Err(UniformError::InvalidInterval { a: 1.0, b: 1.0 })
        );
        assert_eq!(
            u.set_a(f64::INFINITY),
            Err(UniformError::ANotFinite { a: f64::INFINITY })
        );
        assert!(matches!(
            u.set_a(f64::NAN),
            Err(UniformError::ANotFinite { .. })
        ));

        // Rejected values must leave the distribution untouched.
        assert_eq!(u.a(), 0.5);
        assert_eq!(u.b(), 1.0);
    }

    #[test]
    fn set_b_validates_and_updates() {
        let mut u = Uniform::new(0.0, 1.0).unwrap();

        assert!(u.set_b(2.0).is_ok());
        assert_eq!(u.b(), 2.0);

        assert_eq!(
            u.set_b(0.0),
            Err(UniformError::InvalidInterval { a: 0.0, b: 0.0 })
        );
        assert_eq!(
            u.set_b(f64::NEG_INFINITY),
            Err(UniformError::BNotFinite {
                b: f64::NEG_INFINITY
            })
        );
        assert!(matches!(
            u.set_b(f64::NAN),
            Err(UniformError::BNotFinite { .. })
        ));

        // Rejected values must leave the distribution untouched.
        assert_eq!(u.a(), 0.0);
        assert_eq!(u.b(), 2.0);
    }

    #[test]
    fn setters_invalidate_ln_pdf_cache() {
        let mut u = Uniform::new(0.0, 1.0).unwrap();
        // Populate the cache with ln(1 / (1 - 0)) = 0.
        assert::close(u.ln_pdf(&0.5_f64), 0.0, TOL);

        // Widening the interval must be reflected in the density.
        u.set_b(2.0).unwrap();
        assert::close(u.ln_pdf(&0.5_f64), -std::f64::consts::LN_2, TOL);

        // Narrowing it again from the other side must be as well.
        u.set_a(1.0).unwrap();
        assert::close(u.ln_pdf(&1.5_f64), 0.0, TOL);
    }

    #[test]
    fn supports_is_closed_interval() {
        let u = Uniform::new(2.0, 4.0).unwrap();
        assert!(u.supports(&2.0_f64));
        assert!(u.supports(&3.0_f64));
        assert!(u.supports(&4.0_f64));
        assert!(u.supports(&3.0_f32));
        assert!(!u.supports(&1.999_f64));
        assert!(!u.supports(&4.001_f64));
        assert!(!u.supports(&f64::NAN));
        assert!(!u.supports(&f64::INFINITY));
    }

    #[test]
    fn mean() {
        let m: f64 = Uniform::new(2.0, 4.0).unwrap().mean().unwrap();
        assert::close(m, 3.0, TOL);
    }

    #[test]
    fn median() {
        let m: f64 = Uniform::new(2.0, 4.0).unwrap().median().unwrap();
        assert::close(m, 3.0, TOL);
    }

    #[test]
    fn variance() {
        let v: f64 = Uniform::new(2.0, 4.0).unwrap().variance().unwrap();
        assert::close(v, 2.0 / 6.0, TOL);
    }

    #[test]
    fn entropy() {
        let h: f64 = Uniform::new(2.0, 4.0).unwrap().entropy();
        assert::close(h, std::f64::consts::LN_2, TOL);
    }

    #[test]
    fn ln_pdf() {
        let u = Uniform::new(2.0, 4.0).unwrap();
        assert::close(u.ln_pdf(&2.0_f64), -std::f64::consts::LN_2, TOL);
        assert::close(u.ln_pdf(&2.3_f64), -std::f64::consts::LN_2, TOL);
        assert::close(u.ln_pdf(&3.3_f64), -std::f64::consts::LN_2, TOL);
        assert::close(u.ln_pdf(&4.0_f64), -std::f64::consts::LN_2, TOL);
    }

    #[test]
    fn cdf() {
        let u = Uniform::new(2.0, 4.0).unwrap();
        assert::close(u.cdf(&2.0_f64), 0.0, TOL);
        assert::close(u.cdf(&2.3_f64), 0.149_999_999_999_999_9, TOL);
        assert::close(u.cdf(&3.3_f64), 0.649_999_999_999_999_9, TOL);
        assert::close(u.cdf(&4.0_f64), 1.0, TOL);
    }

    #[test]
    fn cdf_inv_cdf_ident() {
        let mut rng = rand::rng();
        let ru = rand::distr::Uniform::new(1.2, 3.4).unwrap();
        let u = Uniform::new(1.2, 3.4).unwrap();
        for _ in 0..100 {
            let x: f64 = rng.sample(ru);
            let cdf = u.cdf(&x);
            let y: f64 = u.invcdf(cdf);
            assert::close(x, y, 1E-8);
        }
    }

    #[test]
    fn draw_test() {
        let mut rng = rand::rng();
        let u = Uniform::new(1.2, 3.4).unwrap();
        let cdf = |x: f64| u.cdf(&x);

        // test is flaky, try a few times
        let passes = (0..N_TRIES).fold(0, |acc, _| {
            let xs: Vec<f64> = u.sample(1000, &mut rng);
            let (_, p) = ks_test(&xs, cdf);
            if p > KS_PVAL { acc + 1 } else { acc }
        });
        assert!(passes > 0);
    }

    use crate::test_shiftable_cdf;
    use crate::test_shiftable_density;
    use crate::test_shiftable_entropy;
    use crate::test_shiftable_invcdf;
    use crate::test_shiftable_method;

    test_shiftable_method!(Uniform::new(2.0, 4.0).unwrap(), mean);
    test_shiftable_method!(Uniform::new(2.0, 4.0).unwrap(), median);
    test_shiftable_method!(Uniform::new(2.0, 4.0).unwrap(), variance);
    test_shiftable_method!(Uniform::new(2.0, 4.0).unwrap(), skewness);
    test_shiftable_method!(Uniform::new(2.0, 4.0).unwrap(), kurtosis);
    test_shiftable_density!(Uniform::new(2.0, 4.0).unwrap());
    test_shiftable_entropy!(Uniform::new(2.0, 4.0).unwrap());
    test_shiftable_cdf!(Uniform::new(2.0, 4.0).unwrap());
    test_shiftable_invcdf!(Uniform::new(2.0, 4.0).unwrap());

    use crate::test_scalable_cdf;
    use crate::test_scalable_density;
    use crate::test_scalable_entropy;
    use crate::test_scalable_invcdf;
    use crate::test_scalable_method;

    test_scalable_method!(Uniform::new(2.0, 4.0).unwrap(), mean);
    test_scalable_method!(Uniform::new(2.0, 4.0).unwrap(), median);
    test_scalable_method!(Uniform::new(2.0, 4.0).unwrap(), variance);
    test_scalable_method!(Uniform::new(2.0, 4.0).unwrap(), skewness);
    test_scalable_method!(Uniform::new(2.0, 4.0).unwrap(), kurtosis);
    test_scalable_density!(Uniform::new(2.0, 4.0).unwrap());
    test_scalable_entropy!(Uniform::new(2.0, 4.0).unwrap());
    test_scalable_cdf!(Uniform::new(2.0, 4.0).unwrap());
    test_scalable_invcdf!(Uniform::new(2.0, 4.0).unwrap());

    #[test]
    fn emit_and_from_params_are_identity() {
        let u = Uniform::new(0.5, 10.4).unwrap();
        let roundtrip = Uniform::from_params(u.emit_params());
        assert_eq!(u, roundtrip);
    }
}