rv/dist/
unit_powerlaw.rs

1//! `UnitPowerLaw` distribution over x in (0, 1)
2#[cfg(feature = "serde1")]
3use serde::{Deserialize, Serialize};
4
5use crate::data::UnitPowerLawSuffStat;
6use crate::impl_display;
7use crate::prelude::Beta;
8use crate::traits::{
9    Cdf, ContinuousDistr, Entropy, HasDensity, HasSuffStat, InverseCdf,
10    Kurtosis, Mean, Mode, Parameterized, Sampleable, Scalable, Shiftable,
11    Skewness, Support, Variance,
12};
13use rand::Rng;
14use special::Gamma as _;
15use std::f64;
16use std::fmt;
17use std::sync::OnceLock;
18
19pub mod bernoulli_prior;
20
/// Plain-data view of a [`UnitPowerLaw`]'s parameters, produced by
/// `Parameterized::emit_params` and consumed by `Parameterized::from_params`.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serde1", serde(rename_all = "snake_case"))]
pub struct UnitPowerLawParameters {
    /// Shape parameter α. `UnitPowerLaw::new` requires it to be finite and
    /// strictly positive; `from_params` does not re-check it.
    pub alpha: f64,
}
29
/// UnitPowerLaw(α) over x in (0, 1).
///
/// The density is `f(x) = α x^(α - 1)`, the same as `Beta(α, 1)`. For
/// `α > 1` the mass concentrates near 1, for `α < 1` it concentrates near 0,
/// and `α = 1` is the uniform distribution.
///
/// # Examples
///
/// `UnitPowerLaw` as a prior on the success probability of a Bernoulli
///
/// ```
/// use rv::prelude::*;
///
/// // A prior that encodes a belief that coins are biased toward success:
/// // its mean is α / (α + 1) = 5/6.
/// let powlaw = UnitPowerLaw::new(5.0).unwrap();
/// ```
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serde1", serde(rename_all = "snake_case"))]
pub struct UnitPowerLaw {
    alpha: f64,

    // Cached `alpha.recip()`, filled lazily by `alpha_inv()`. Skipped by
    // serde, so it is recomputed on first use after deserialization.
    #[cfg_attr(feature = "serde1", serde(skip))]
    alpha_inv: OnceLock<f64>,

    // Cached `alpha.ln()`, filled lazily by `alpha_ln()`. Skipped by serde,
    // so it is recomputed on first use after deserialization.
    #[cfg_attr(feature = "serde1", serde(skip))]
    alpha_ln: OnceLock<f64>,
}
56
57impl Parameterized for UnitPowerLaw {
58    type Parameters = UnitPowerLawParameters;
59
60    fn emit_params(&self) -> Self::Parameters {
61        Self::Parameters {
62            alpha: self.alpha(),
63        }
64    }
65
66    fn from_params(params: Self::Parameters) -> Self {
67        Self::new_unchecked(params.alpha)
68    }
69}
70
// NOTE(review): these crate-level macros presumably implement the
// `Shiftable` / `Scalable` traits (imported above) for `UnitPowerLaw`;
// confirm against the macro definitions.
crate::impl_shiftable!(UnitPowerLaw);
crate::impl_scalable!(UnitPowerLaw);
73
74impl PartialEq for UnitPowerLaw {
75    fn eq(&self, other: &UnitPowerLaw) -> bool {
76        self.alpha == other.alpha
77    }
78}
79
/// Reasons a shape parameter is rejected by `UnitPowerLaw::new` and
/// `UnitPowerLaw::set_alpha`.
///
/// Variant order matters: the derived `PartialOrd` compares variants in
/// declaration order.
#[derive(Debug, Clone, PartialEq, PartialOrd)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serde1", serde(rename_all = "snake_case"))]
pub enum UnitPowerLawError {
    /// `alpha` was zero or negative (negative infinity also lands here)
    AlphaTooLow {
        /// The rejected value
        alpha: f64,
    },
    /// `alpha` was positive infinity or NaN
    AlphaNotFinite {
        /// The rejected value
        alpha: f64,
    },
}
89
90impl UnitPowerLaw {
91    /// Create a `UnitPowerLaw` distribution with even density over (0, 1).
92    ///
93    /// # Example
94    ///
95    /// ```rust
96    /// # use rv::dist::UnitPowerLaw;
97    /// // Uniform
98    /// let powlaw_unif = UnitPowerLaw::new(1.0);
99    /// assert!(powlaw_unif.is_ok());
100    ///
101    /// // Invalid negative parameter
102    /// let powlaw_nope  = UnitPowerLaw::new(-5.0);
103    /// assert!(powlaw_nope.is_err());
104    /// ```
105    pub fn new(alpha: f64) -> Result<Self, UnitPowerLawError> {
106        if alpha <= 0.0 {
107            Err(UnitPowerLawError::AlphaTooLow { alpha })
108        } else if !alpha.is_finite() {
109            Err(UnitPowerLawError::AlphaNotFinite { alpha })
110        } else {
111            Ok(UnitPowerLaw {
112                alpha,
113                alpha_inv: OnceLock::new(),
114                alpha_ln: OnceLock::new(),
115            })
116        }
117    }
118
119    /// Creates a new `UnitPowerLaw` without checking whether the parameters are valid.
120    #[inline]
121    #[must_use]
122    pub fn new_unchecked(alpha: f64) -> Self {
123        UnitPowerLaw {
124            alpha,
125            alpha_inv: OnceLock::new(),
126            alpha_ln: OnceLock::new(),
127        }
128    }
129
130    /// Create a `UnitPowerLaw` distribution with even density over (0, 1).
131    ///
132    /// # Example
133    ///
134    /// ```rust
135    /// # use rv::dist::UnitPowerLaw;
136    /// let powlaw = UnitPowerLaw::uniform();
137    /// assert_eq!(powlaw, UnitPowerLaw::new(1.0).unwrap());
138    /// ```
139    #[inline]
140    #[must_use]
141    pub fn uniform() -> Self {
142        UnitPowerLaw::new_unchecked(1.0)
143    }
144
145    /// Get the alpha parameter
146    ///
147    /// # Example
148    ///
149    /// ```rust
150    /// # use rv::dist::UnitPowerLaw;
151    /// let powlaw = UnitPowerLaw::new(5.0).unwrap();
152    /// assert_eq!(powlaw.alpha(), 5.0);
153    /// ```
154    #[inline]
155    pub fn alpha(&self) -> f64 {
156        self.alpha
157    }
158
159    /// Set the alpha parameter
160    ///
161    /// # Example
162    ///
163    /// ```rust
164    /// # use rv::dist::UnitPowerLaw;
165    /// let mut powlaw = UnitPowerLaw::new(5.0).unwrap();
166    ///
167    /// powlaw.set_alpha(2.0).unwrap();
168    /// assert_eq!(powlaw.alpha(), 2.0);
169    /// ```
170    ///
171    /// Will error for invalid values
172    ///
173    /// ```rust
174    /// # use rv::dist::UnitPowerLaw;
175    /// # let mut powlaw = UnitPowerLaw::new(5.0).unwrap();
176    /// assert!(powlaw.set_alpha(0.1).is_ok());
177    /// assert!(powlaw.set_alpha(0.0).is_err());
178    /// assert!(powlaw.set_alpha(-1.0).is_err());
179    /// assert!(powlaw.set_alpha(f64::INFINITY).is_err());
180    /// assert!(powlaw.set_alpha(f64::NAN).is_err());
181    /// ```
182    #[inline]
183    pub fn set_alpha(&mut self, alpha: f64) -> Result<(), UnitPowerLawError> {
184        if alpha <= 0.0 {
185            Err(UnitPowerLawError::AlphaTooLow { alpha })
186        } else if !alpha.is_finite() {
187            Err(UnitPowerLawError::AlphaNotFinite { alpha })
188        } else {
189            self.set_alpha_unchecked(alpha);
190            Ok(())
191        }
192    }
193
194    /// Set alpha without input validation
195    #[inline]
196    pub fn set_alpha_unchecked(&mut self, alpha: f64) {
197        self.alpha = alpha;
198        self.alpha_inv = OnceLock::new();
199        self.alpha_ln = OnceLock::new();
200    }
201
202    /// Evaluate or fetch cached ln(a*b)
203    #[inline]
204    pub fn alpha_inv(&self) -> f64 {
205        *self.alpha_inv.get_or_init(|| self.alpha.recip())
206    }
207
208    /// Evaluate or fetch cached ln(a*b)
209    #[inline]
210    pub fn alpha_ln(&self) -> f64 {
211        *self.alpha_ln.get_or_init(|| self.alpha.ln())
212    }
213}
214
215impl From<&UnitPowerLaw> for Beta {
216    fn from(powlaw: &UnitPowerLaw) -> Beta {
217        Beta::new(powlaw.alpha, 1.0).unwrap()
218    }
219}
220
221impl Default for UnitPowerLaw {
222    fn default() -> Self {
223        todo!()
224    }
225}
226
227impl From<&UnitPowerLaw> for String {
228    fn from(powlaw: &UnitPowerLaw) -> String {
229        format!("UnitPowerLaw(α: {})", powlaw.alpha)
230    }
231}
232
// NOTE(review): `impl_display!` presumably implements `fmt::Display` in terms
// of the `From<&UnitPowerLaw> for String` conversion above; confirm against
// the macro definition.
impl_display!(UnitPowerLaw);
234
// Implements the traits that are generic over the observation type for
// `UnitPowerLaw`; instantiated for `f32` and `f64` below. Throughout, the
// density is f(x) = α x^(α - 1) and the CDF is F(x) = x^α on (0, 1).
macro_rules! impl_traits {
    ($kind:ty) => {
        impl HasDensity<$kind> for UnitPowerLaw {
            fn ln_f(&self, x: &$kind) -> f64 {
                // ln f(x) = (α - 1) ln(x) + ln(α), fused into one mul_add
                (*x as f64).ln().mul_add(self.alpha - 1.0, self.alpha_ln())
            }
        }

        impl Sampleable<$kind> for UnitPowerLaw {
            fn draw<R: Rng>(&self, rng: &mut R) -> $kind {
                // Inverse-transform sampling from a uniform variate
                self.invcdf(rng.random::<f64>())
            }

            fn sample<R: Rng>(&self, n: usize, rng: &mut R) -> Vec<$kind> {
                // Same inverse transform as `draw` (U^(1/α)), with 1/α hoisted
                // out of the loop and U drawn directly at `$kind` precision.
                let alpha_inv = self.alpha_inv() as $kind;
                (0..n)
                    .map(|_| rng.random::<$kind>().powf(alpha_inv))
                    .collect()
            }
        }

        impl Support<$kind> for UnitPowerLaw {
            fn supports(&self, x: &$kind) -> bool {
                // Open interval: both endpoints (and NaN) are excluded
                let xf = f64::from(*x);
                0.0 < xf && xf < 1.0
            }
        }

        impl ContinuousDistr<$kind> for UnitPowerLaw {}

        impl Cdf<$kind> for UnitPowerLaw {
            fn cdf(&self, x: &$kind) -> f64 {
                let xf = f64::from(*x);
                // Clamp outside the support. `x.powf(α)` alone is NaN for
                // negative x (non-integer α) and exceeds 1 for x > 1. NaN input
                // fails both comparisons and stays NaN through `powf`.
                if xf <= 0.0 {
                    0.0
                } else if xf >= 1.0 {
                    1.0
                } else {
                    xf.powf(self.alpha)
                }
            }
        }

        impl InverseCdf<$kind> for UnitPowerLaw {
            fn invcdf(&self, p: f64) -> $kind {
                // F^{-1}(p) = p^(1/α)
                p.powf(self.alpha_inv()) as $kind
            }
        }

        impl Mean<$kind> for UnitPowerLaw {
            fn mean(&self) -> Option<$kind> {
                Some((self.alpha / (self.alpha + 1.0)) as $kind)
            }
        }

        impl Mode<$kind> for UnitPowerLaw {
            fn mode(&self) -> Option<$kind> {
                // For α > 1 the density increases over (0, 1), so it peaks at
                // the upper boundary; otherwise no mode is reported.
                if self.alpha > 1.0 { Some(1.0) } else { None }
            }
        }

        impl HasSuffStat<$kind> for UnitPowerLaw {
            type Stat = UnitPowerLawSuffStat;

            fn empty_suffstat(&self) -> Self::Stat {
                Self::Stat::new()
            }

            fn ln_f_stat(&self, stat: &Self::Stat) -> f64 {
                // Σ ln f(x_i) = n ln(α) + (α - 1) Σ ln(x_i)
                let n = stat.n() as f64;
                let t1 = n * self.alpha_ln();
                let t2 = (self.alpha - 1.0) * stat.sum_ln_x();
                t2 + t1
            }
        }
    };
}
305
306impl Variance<f64> for UnitPowerLaw {
307    fn variance(&self) -> Option<f64> {
308        let apb = self.alpha + 1.0;
309        Some(self.alpha / (apb * apb * (apb + 1.0)))
310    }
311}
312
313impl Entropy for UnitPowerLaw {
314    fn entropy(&self) -> f64 {
315        let apb = self.alpha + 1.0;
316        (apb - 2.0).mul_add(
317            apb.digamma(),
318            (self.alpha - 1.0).mul_add(-self.alpha.digamma(), -self.alpha_ln()),
319        )
320    }
321}
322
323impl Skewness for UnitPowerLaw {
324    fn skewness(&self) -> Option<f64> {
325        let apb = self.alpha + 1.0;
326        let numer = 2.0 * (1.0 - self.alpha) * (apb + 1.0).sqrt();
327        let denom = (apb + 2.0) * (self.alpha * 1.0).sqrt();
328        Some(numer / denom)
329    }
330}
331
332impl Kurtosis for UnitPowerLaw {
333    fn kurtosis(&self) -> Option<f64> {
334        let apb = self.alpha + 1.0;
335        let amb = self.alpha - 1.0;
336        let atb = self.alpha * 1.0;
337        let numer = 6.0 * (amb * amb).mul_add(apb + 1.0, -atb * (apb + 2.0));
338        let denom = atb * (apb + 2.0) * (apb + 3.0);
339        Some(numer / denom)
340    }
341}
342
// Instantiate the density, sampling, support, CDF, inverse-CDF, mean, mode
// and sufficient-statistic traits for both single- and double-precision
// observations.
impl_traits!(f32);
impl_traits!(f64);
345
346impl std::error::Error for UnitPowerLawError {}
347
348#[cfg_attr(coverage_nightly, coverage(off))]
349impl fmt::Display for UnitPowerLawError {
350    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
351        match self {
352            Self::AlphaTooLow { alpha } => {
353                write!(f, "alpha ({alpha}) must be greater than zero")
354            }
355            Self::AlphaNotFinite { alpha } => {
356                write!(f, "alpha ({alpha}) was non finite")
357            }
358        }
359    }
360}
361
#[cfg(test)]
mod tests {

    use super::*;
    use crate::misc::ks_test;
    use crate::test_basic_impls;

    const TOL: f64 = 1E-12;
    const KS_PVAL: f64 = 0.2;
    const N_TRIES: usize = 5;

    test_basic_impls!(f64, UnitPowerLaw, UnitPowerLaw::new(1.5).unwrap());

    #[test]
    fn new() {
        let powlaw = UnitPowerLaw::new(2.0).unwrap();
        assert::close(powlaw.alpha, 2.0, TOL);
    }

    #[test]
    fn new_rejects_invalid_alpha() {
        assert_eq!(
            UnitPowerLaw::new(0.0).unwrap_err(),
            UnitPowerLawError::AlphaTooLow { alpha: 0.0 }
        );
        assert_eq!(
            UnitPowerLaw::new(-1.0).unwrap_err(),
            UnitPowerLawError::AlphaTooLow { alpha: -1.0 }
        );
        // -inf fails the `<= 0` check before the finiteness check.
        assert_eq!(
            UnitPowerLaw::new(f64::NEG_INFINITY).unwrap_err(),
            UnitPowerLawError::AlphaTooLow {
                alpha: f64::NEG_INFINITY
            }
        );
        assert_eq!(
            UnitPowerLaw::new(f64::INFINITY).unwrap_err(),
            UnitPowerLawError::AlphaNotFinite {
                alpha: f64::INFINITY
            }
        );
        // NaN != NaN, so match on the variant instead of comparing values.
        assert!(matches!(
            UnitPowerLaw::new(f64::NAN),
            Err(UnitPowerLawError::AlphaNotFinite { .. })
        ));
    }

    #[test]
    fn uniform() {
        let powlaw = UnitPowerLaw::uniform();
        assert::close(powlaw.alpha, 1.0, TOL);
    }

    #[test]
    fn cached_alpha_values_track_set_alpha() {
        let mut powlaw = UnitPowerLaw::new(4.0).unwrap();
        assert::close(powlaw.alpha_inv(), 0.25, TOL);
        assert::close(powlaw.alpha_ln(), 4.0_f64.ln(), TOL);

        // The caches must be invalidated, not reused.
        powlaw.set_alpha_unchecked(0.5);
        assert::close(powlaw.alpha_inv(), 2.0, TOL);
        assert::close(powlaw.alpha_ln(), 0.5_f64.ln(), TOL);
    }

    #[test]
    fn ln_f_matches_closed_form() {
        let alpha: f64 = 3.0;
        let powlaw = UnitPowerLaw::new(alpha).unwrap();
        for &x in &[0.1_f64, 0.5, 0.9] {
            let expected = alpha.ln() + (alpha - 1.0) * x.ln();
            assert::close(powlaw.ln_f(&x), expected, TOL);
        }
    }

    #[test]
    fn ln_pdf_center_value() {
        let powlaw = UnitPowerLaw::new(1.5).unwrap();
        let beta: Beta = (&powlaw).into();
        assert::close(powlaw.ln_pdf(&0.5), beta.ln_pdf(&0.5), TOL);
    }

    #[test]
    fn ln_pdf_low_value() {
        let powlaw = UnitPowerLaw::new(1.5).unwrap();
        let beta: Beta = (&powlaw).into();
        assert::close(powlaw.ln_pdf(&0.01), beta.ln_pdf(&0.01), TOL);
    }

    #[test]
    fn ln_pdf_high_value() {
        let powlaw = UnitPowerLaw::new(1.5).unwrap();
        let beta: Beta = (&powlaw).into();
        assert::close(powlaw.ln_pdf(&0.99), beta.ln_pdf(&0.99), TOL);
    }

    #[test]
    fn pdf_preserved_after_set_reset_alpha() {
        let x: f64 = 0.6;
        let alpha = 1.5;

        let mut powlaw = UnitPowerLaw::new(alpha).unwrap();

        let f_1 = powlaw.f(&x);
        let ln_f_1 = powlaw.ln_f(&x);

        powlaw.set_alpha(3.4).unwrap();

        assert_ne!(f_1, powlaw.f(&x));
        assert_ne!(ln_f_1, powlaw.ln_f(&x));

        powlaw.set_alpha(alpha).unwrap();

        assert_eq!(f_1, powlaw.f(&x));
        assert_eq!(ln_f_1, powlaw.ln_f(&x));
    }

    #[test]
    fn cdf_hump_shaped() {
        let powlaw = UnitPowerLaw::new(1.5).unwrap();
        let beta: Beta = (&powlaw).into();
        let xs: Vec<f64> = vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9];
        for x in &xs {
            assert::close(powlaw.cdf(x), beta.cdf(x), TOL);
        }
    }

    #[test]
    fn cdf_bowl_shaped() {
        let powlaw = UnitPowerLaw::new(0.5).unwrap();
        let beta: Beta = (&powlaw).into();
        let xs: Vec<f64> = vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9];
        for x in &xs {
            assert::close(powlaw.cdf(x), beta.cdf(x), TOL);
        }
    }

    #[test]
    fn invcdf_inverts_cdf() {
        let powlaw = UnitPowerLaw::new(2.5).unwrap();
        for &p in &[0.05, 0.25, 0.5, 0.75, 0.95] {
            let x: f64 = powlaw.invcdf(p);
            assert::close(powlaw.cdf(&x), p, TOL);
        }
    }

    #[test]
    fn draw_should_return_values_within_0_to_1() {
        let mut rng = rand::rng();
        let powlaw = UnitPowerLaw::new(2.0).unwrap();
        for _ in 0..100 {
            let x = powlaw.draw(&mut rng);
            assert!(0.0 < x && x < 1.0);
        }
    }

    #[test]
    fn sample_returns_the_correct_number_draws() {
        let mut rng = rand::rng();
        let powlaw = UnitPowerLaw::new(2.0).unwrap();
        let xs: Vec<f32> = powlaw.sample(103, &mut rng);
        assert_eq!(xs.len(), 103);
    }

    #[test]
    fn uniform_mean() {
        let mean: f64 = UnitPowerLaw::uniform().mean().unwrap();
        assert::close(mean, 0.5, TOL);
    }

    #[test]
    fn mean() {
        let mean: f64 = UnitPowerLaw::new(5.0).unwrap().mean().unwrap();
        assert::close(mean, 5.0 / 6.0, TOL);
    }

    #[test]
    fn variance() {
        let powlaw = UnitPowerLaw::new(1.5).unwrap();
        let beta: Beta = (&powlaw).into();
        assert::close(
            powlaw.variance().unwrap(),
            beta.variance().unwrap(),
            TOL,
        );
    }

    #[test]
    fn mode_for_alpha_greater_than_one_is_one() {
        let mode: f64 = UnitPowerLaw::new(2.0).unwrap().mode().unwrap();
        assert::close(mode, 1.0, TOL);
    }

    #[test]
    fn mode_for_alpha_less_than_one_is_none() {
        let mode_opt: Option<f64> = UnitPowerLaw::new(0.99).unwrap().mode();
        assert!(mode_opt.is_none());
    }

    #[test]
    fn entropy() {
        let powlaw = UnitPowerLaw::new(1.5).unwrap();
        let beta: Beta = (&powlaw).into();
        assert::close(powlaw.entropy(), beta.entropy(), TOL);
    }

    #[test]
    fn uniform_skewness_should_be_zero() {
        assert::close(UnitPowerLaw::uniform().skewness().unwrap(), 0.0, TOL);
    }

    #[test]
    fn skewness() {
        let powlaw = UnitPowerLaw::new(1.5).unwrap();
        let beta: Beta = (&powlaw).into();
        assert::close(
            powlaw.skewness().unwrap(),
            beta.skewness().unwrap(),
            TOL,
        );
    }

    #[test]
    fn kurtosis() {
        let powlaw = UnitPowerLaw::new(1.5).unwrap();
        let beta: Beta = (&powlaw).into();
        assert::close(
            powlaw.kurtosis().unwrap(),
            beta.kurtosis().unwrap(),
            TOL,
        );
    }

    #[test]
    fn draw_test_alpha_powlaw_gt_one() {
        let mut rng = rand::rng();
        let powlaw = UnitPowerLaw::new(1.2).unwrap();
        let cdf = |x: f64| powlaw.cdf(&x);

        // test is flaky, try a few times
        let passes = (0..N_TRIES).fold(0, |acc, _| {
            let xs: Vec<f64> = powlaw.sample(1000, &mut rng);
            let (_, p) = ks_test(&xs, cdf);
            if p > KS_PVAL { acc + 1 } else { acc }
        });

        assert!(passes > 0);
    }

    #[test]
    fn draw_test_alpha_powlaw_lt_one() {
        let mut rng = rand::rng();
        let powlaw = UnitPowerLaw::new(0.2).unwrap();
        let cdf = |x: f64| powlaw.cdf(&x);

        // test is flaky, try a few times
        let passes = (0..N_TRIES).fold(0, |acc, _| {
            let xs: Vec<f64> = powlaw.sample(1000, &mut rng);
            let (_, p) = ks_test(&xs, cdf);
            if p > KS_PVAL { acc + 1 } else { acc }
        });

        assert!(passes > 0);
    }

    #[test]
    fn ln_f_stat() {
        use crate::traits::SuffStat;

        let data: Vec<f64> = vec![0.1, 0.23, 0.4, 0.65, 0.22, 0.31];
        let mut stat = UnitPowerLawSuffStat::new();
        stat.observe_many(&data);

        let powlaw = UnitPowerLaw::new(0.3).unwrap();

        let ln_f_base: f64 = data.iter().map(|x| powlaw.ln_f(x)).sum();
        let ln_f_stat: f64 =
            <UnitPowerLaw as HasSuffStat<f64>>::ln_f_stat(&powlaw, &stat);

        assert::close(ln_f_base, ln_f_stat, TOL);
    }

    #[test]
    fn set_alpha() {
        let mut rng = rand::rng();

        for _ in 0..100 {
            // `random::<f64>()` lies in [0, 1); flip it to (0, 1] so `new`
            // and `set_alpha` never receive an invalid alpha of exactly 0.
            let a1 = 1.0 - rng.random::<f64>();
            let mut powlaw1 = UnitPowerLaw::new(a1).unwrap();

            // Any value in (0, 1]; excluding 0 avoids ln(0) in the density
            let x: f64 = 1.0 - rng.random::<f64>();

            // Evaluate the pdf to force computation of the cached `alpha_ln`
            let _ = powlaw1.pdf(&x);

            // Next we'll `set_alpha` to a2, and compare with a fresh UnitPowerLaw
            let a2 = 1.0 - rng.random::<f64>();

            // Setting the new values
            powlaw1.set_alpha(a2).unwrap();

            // ... and here's the fresh version
            let powlaw2 = UnitPowerLaw::new(a2).unwrap();

            let pdf_1 = powlaw1.ln_f(&x);
            let pdf_2 = powlaw2.ln_f(&x);

            assert::close(pdf_1, pdf_2, 1e-14);
        }
    }

    #[test]
    fn emit_and_from_params_are_identity() {
        let vm = UnitPowerLaw::new(0.5).unwrap();
        let vm_b = UnitPowerLaw::from_params(vm.emit_params());
        assert_eq!(vm, vm_b);
    }
}