rv/dist/
invgaussian.rs

1//! Inverse Gaussian distribution over x in (0, ∞)
2#[cfg(feature = "serde1")]
3use serde::{Deserialize, Serialize};
4
5use rand::Rng;
6use rand_distr::Normal;
7use std::fmt;
8use std::sync::OnceLock;
9
10use crate::consts::{HALF_LN_2PI, LN_2PI};
11use crate::data::InvGaussianSuffStat;
12use crate::impl_display;
13use crate::traits::{
14    Cdf, ContinuousDistr, HasDensity, HasSuffStat, Kurtosis, Mean, Mode,
15    Parameterized, Sampleable, Scalable, Shiftable, Skewness, Support,
16    Variance,
17};
18
/// [Inverse Gaussian distribution](https://en.wikipedia.org/wiki/Inverse_Gaussian_distribution),
/// N<sup>-1</sup>(μ, λ) over positive real values x ∈ (0, ∞).
///
/// Parameterized by its mean `mu` (μ) and shape `lambda` (λ). Use
/// `InvGaussian::new` for validated construction or `new_unchecked` to skip
/// validation.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serde1", serde(rename_all = "snake_case"))]
pub struct InvGaussian {
    /// Mean, μ
    mu: f64,
    /// Shape, λ
    lambda: f64,
    /// Lazily computed ln(lambda). It is filled on first use by `ln_lambda()`
    /// and reset by `set_lambda_unchecked`. It is not serialized, so a
    /// deserialized value starts with an empty cache.
    #[cfg_attr(feature = "serde1", serde(skip))]
    ln_lambda: OnceLock<f64>,
}
33
/// Plain parameter bundle for [`InvGaussian`], produced by
/// [`Parameterized::emit_params`] and consumed by [`Parameterized::from_params`].
///
/// No validation is attached to this type: values are copied in and out as-is.
#[derive(Debug, Clone, Copy, PartialEq)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serde1", serde(rename_all = "snake_case"))]
pub struct InvGaussianParameters {
    /// Mean, μ
    pub mu: f64,
    /// Shape, λ
    pub lambda: f64,
}
38
// Crate-level macros (defined elsewhere) that presumably implement the
// `Shiftable` and `Scalable` traits imported above for `InvGaussian`.
crate::impl_shiftable!(InvGaussian);
crate::impl_scalable!(InvGaussian);
41
42impl Parameterized for InvGaussian {
43    type Parameters = InvGaussianParameters;
44
45    fn emit_params(&self) -> Self::Parameters {
46        Self::Parameters {
47            mu: self.mu(),
48            lambda: self.lambda(),
49        }
50    }
51
52    fn from_params(params: Self::Parameters) -> Self {
53        Self::new_unchecked(params.mu, params.lambda)
54    }
55}
56
57impl PartialEq for InvGaussian {
58    fn eq(&self, other: &InvGaussian) -> bool {
59        self.mu == other.mu && self.lambda == other.lambda
60    }
61}
62
/// Error returned when an [`InvGaussian`] is constructed or updated with an
/// invalid `mu` or `lambda`. Each variant carries the offending value.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serde1", serde(rename_all = "snake_case"))]
pub enum InvGaussianError {
    /// The mu parameter is infinite or NaN
    MuNotFinite { mu: f64 },
    /// The mu parameter is less than or equal to zero
    MuTooLow { mu: f64 },
    /// The lambda parameter is less than or equal to zero
    LambdaTooLow { lambda: f64 },
    /// The lambda parameter is infinite or NaN
    LambdaNotFinite { lambda: f64 },
}
76
77impl InvGaussian {
78    /// Create a new Inverse Gaussian distribution
79    ///
80    /// # Arguments
81    /// - mu: mean > 0
82    /// - lambda: shape > 0
83    ///
84    /// ```
85    /// use rv::dist::InvGaussian;
86    /// let invgauss = InvGaussian::new(1.0, 3.0).unwrap();
87    /// ```
88    ///
89    /// Mu and lambda must be finite and greater than 0.
90    /// ```
91    /// # use rv::dist::InvGaussian;
92    /// use std::f64::{NAN, INFINITY};
93    /// assert!(InvGaussian::new(0.0, 3.0).is_err());
94    /// assert!(InvGaussian::new(NAN, 3.0).is_err());
95    /// assert!(InvGaussian::new(INFINITY, 3.0).is_err());
96    ///
97    /// assert!(InvGaussian::new(1.0, 0.0).is_err());
98    /// assert!(InvGaussian::new(1.0, NAN).is_err());
99    /// assert!(InvGaussian::new(1.0, INFINITY).is_err());
100    /// ```
101    pub fn new(mu: f64, lambda: f64) -> Result<Self, InvGaussianError> {
102        if !mu.is_finite() {
103            Err(InvGaussianError::MuNotFinite { mu })
104        } else if mu <= 0.0 {
105            Err(InvGaussianError::MuTooLow { mu })
106        } else if lambda <= 0.0 {
107            Err(InvGaussianError::LambdaTooLow { lambda })
108        } else if !lambda.is_finite() {
109            Err(InvGaussianError::LambdaNotFinite { lambda })
110        } else {
111            Ok(InvGaussian {
112                mu,
113                lambda,
114                ln_lambda: OnceLock::new(),
115            })
116        }
117    }
118
119    /// Creates a new `InvGaussian` without checking whether the parameters are
120    /// valid.
121    #[inline]
122    #[must_use]
123    pub fn new_unchecked(mu: f64, lambda: f64) -> Self {
124        InvGaussian {
125            mu,
126            lambda,
127            ln_lambda: OnceLock::new(),
128        }
129    }
130
131    /// Get mu parameter
132    ///
133    /// # Example
134    ///
135    /// ```rust
136    /// # use rv::dist::InvGaussian;
137    /// let ig = InvGaussian::new(2.0, 1.5).unwrap();
138    ///
139    /// assert_eq!(ig.mu(), 2.0);
140    /// ```
141    #[inline]
142    pub fn mu(&self) -> f64 {
143        self.mu
144    }
145
146    /// Set the value of mu
147    ///
148    /// # Example
149    ///
150    /// ```rust
151    /// # use rv::dist::InvGaussian;
152    /// let mut ig = InvGaussian::new(2.0, 1.5).unwrap();
153    /// assert_eq!(ig.mu(), 2.0);
154    ///
155    /// ig.set_mu(1.3).unwrap();
156    /// assert_eq!(ig.mu(), 1.3);
157    /// ```
158    ///
159    /// Will error for invalid values
160    ///
161    /// ```rust
162    /// # use rv::dist::InvGaussian;
163    /// # let mut ig = InvGaussian::new(2.0, 1.5).unwrap();
164    /// assert!(ig.set_mu(1.3).is_ok());
165    /// assert!(ig.set_mu(0.0).is_err());
166    /// assert!(ig.set_mu(-1.0).is_err());
167    /// assert!(ig.set_mu(f64::NEG_INFINITY).is_err());
168    /// assert!(ig.set_mu(f64::INFINITY).is_err());
169    /// assert!(ig.set_mu(f64::NAN).is_err());
170    /// ```
171    #[inline]
172    pub fn set_mu(&mut self, mu: f64) -> Result<(), InvGaussianError> {
173        if !mu.is_finite() {
174            Err(InvGaussianError::MuNotFinite { mu })
175        } else if mu <= 0.0 {
176            Err(InvGaussianError::MuTooLow { mu })
177        } else {
178            self.set_mu_unchecked(mu);
179            Ok(())
180        }
181    }
182
183    /// Set the value of mu without input validation
184    #[inline]
185    pub fn set_mu_unchecked(&mut self, mu: f64) {
186        self.mu = mu;
187    }
188
189    /// Get lambda parameter
190    ///
191    /// # Example
192    ///
193    /// ```rust
194    /// # use rv::dist::InvGaussian;
195    /// let ig = InvGaussian::new(2.0, 1.5).unwrap();
196    ///
197    /// assert_eq!(ig.lambda(), 1.5);
198    /// ```
199    #[inline]
200    pub fn lambda(&self) -> f64 {
201        self.lambda
202    }
203
204    /// Set the value of lambda
205    ///
206    /// # Example
207    ///
208    /// ```rust
209    /// # use rv::dist::InvGaussian;
210    /// let mut ig = InvGaussian::new(1.0, 2.0).unwrap();
211    /// assert_eq!(ig.lambda(), 2.0);
212    ///
213    /// ig.set_lambda(2.3).unwrap();
214    /// assert_eq!(ig.lambda(), 2.3);
215    /// ```
216    ///
217    /// Will error for invalid values
218    ///
219    /// ```rust
220    /// # use rv::dist::InvGaussian;
221    /// # let mut ig = InvGaussian::new(1.0, 2.0).unwrap();
222    /// assert!(ig.set_lambda(2.3).is_ok());
223    /// assert!(ig.set_lambda(0.0).is_err());
224    /// assert!(ig.set_lambda(-1.0).is_err());
225    /// assert!(ig.set_lambda(f64::INFINITY).is_err());
226    /// assert!(ig.set_lambda(f64::NEG_INFINITY).is_err());
227    /// assert!(ig.set_lambda(f64::NAN).is_err());
228    /// ```
229    #[inline]
230    pub fn set_lambda(&mut self, lambda: f64) -> Result<(), InvGaussianError> {
231        if lambda <= 0.0 {
232            Err(InvGaussianError::LambdaTooLow { lambda })
233        } else if !lambda.is_finite() {
234            Err(InvGaussianError::LambdaNotFinite { lambda })
235        } else {
236            self.set_lambda_unchecked(lambda);
237            Ok(())
238        }
239    }
240
241    /// Set the value of lambda without input validation
242    #[inline]
243    pub fn set_lambda_unchecked(&mut self, lambda: f64) {
244        self.ln_lambda = OnceLock::new();
245        self.lambda = lambda;
246    }
247
248    #[inline]
249    fn ln_lambda(&self) -> f64 {
250        *self.ln_lambda.get_or_init(|| self.lambda.ln())
251    }
252}
253
254impl From<&InvGaussian> for String {
255    fn from(ig: &InvGaussian) -> String {
256        format!("N⁻¹(μ: {}, λ: {})", ig.mu, ig.lambda)
257    }
258}
259
// Crate macro (defined elsewhere). It presumably implements `fmt::Display` for
// `InvGaussian` via the `From<&InvGaussian> for String` conversion above.
impl_display!(InvGaussian);
261
// Implements the per-value-type traits of `InvGaussian` for `$kind`
// (instantiated for `f32` and `f64`). All arithmetic is carried out in `f64`.
macro_rules! impl_traits {
    ($kind:ty) => {
        impl HasDensity<$kind> for InvGaussian {
            // ln f(x) = ½[ln λ − ln(2π) − 3 ln x] − λ(x − μ)² / (2μ²x)
            fn ln_f(&self, x: &$kind) -> f64 {
                let InvGaussianParameters { mu, lambda } = self.emit_params();
                let xf = f64::from(*x);
                let z = self.ln_lambda() - xf.ln().mul_add(3.0, LN_2PI);
                let err = xf - mu;
                let term = lambda * err * err / (2.0 * mu * mu * xf);
                z.mul_add(0.5, -term)
            }
        }

        impl Sampleable<$kind> for InvGaussian {
            // https://en.wikipedia.org/wiki/Inverse_Gaussian_distribution#Sampling_from_an_inverse-Gaussian_distribution
            fn draw<R: Rng>(&self, rng: &mut R) -> $kind {
                let InvGaussianParameters { mu, lambda } = self.emit_params();
                let g = Normal::new(0.0, 1.0).unwrap();
                let v: f64 = rng.sample(g);
                let y = v * v;
                // With a = μy/(2λ), the two candidate roots are
                //   x₁ = μ(1 + a − √(a(a + 2))) = μ / d  and  x₂ = μ²/x₁ = μ·d,
                // where d = 1 + a + √(a(a + 2)). This is algebraically the
                // textbook x₁ = μ + μ²y/(2λ) − (μ/2λ)√(4μλy + μ²y²), but
                // evaluating it as μ/d avoids catastrophic cancellation, which
                // can yield x₁ = 0 (outside the support) when μy/λ is large.
                // √a·√(a + 2) is used so the product a(a + 2) cannot overflow.
                let a = 0.5 * mu * y / lambda;
                let d = 1.0 + a + a.sqrt() * (a + 2.0).sqrt();
                let z: f64 = rng.random();

                // Accept x₁ with probability μ / (μ + x₁) = d / (d + 1).
                if z <= d / (d + 1.0) {
                    (mu / d) as $kind
                } else {
                    (mu * d) as $kind
                }
            }
        }

        impl ContinuousDistr<$kind> for InvGaussian {}

        impl Support<$kind> for InvGaussian {
            // The inverse Gaussian lives on the open half-line (0, ∞).
            fn supports(&self, x: &$kind) -> bool {
                x.is_finite() && *x > 0.0
            }
        }

        impl Cdf<$kind> for InvGaussian {
            // F(x) = Φ(√(λ/x)(x/μ − 1)) + exp(2λ/μ) Φ(−√(λ/x)(x/μ + 1))
            fn cdf(&self, x: &$kind) -> f64 {
                let xf = f64::from(*x);
                // Outside the support the CDF is 0 (left) or 1 (right). The
                // closed form below would produce NaN at these points.
                if xf <= 0.0 {
                    return 0.0;
                }
                if xf.is_infinite() {
                    return 1.0;
                }
                let InvGaussianParameters { mu, lambda } = self.emit_params();
                let gauss = crate::dist::Gaussian::standard();
                let z = (lambda / xf).sqrt();
                let a = z * (xf / mu - 1.0);
                let b = -z * (xf / mu + 1.0);
                // exp(2λ/μ)·Φ(b) is evaluated in log space. exp(2λ/μ) alone
                // overflows for 2λ/μ ≳ 709 even though the product is bounded,
                // which would otherwise give inf or NaN.
                let tail = (2.0 * lambda / mu + gauss.cdf(&b).ln()).exp();
                gauss.cdf(&a) + tail
            }
        }

        impl Mean<$kind> for InvGaussian {
            fn mean(&self) -> Option<$kind> {
                Some(self.mu as $kind)
            }
        }

        impl Mode<$kind> for InvGaussian {
            fn mode(&self) -> Option<$kind> {
                let InvGaussianParameters { mu, lambda } = self.emit_params();
                // mode = μ[√(1 + b²) − b] with b = 3μ/(2λ). Multiplying by the
                // conjugate gives μ / (√(1 + b²) + b), which stays accurate
                // (and positive) when μ ≫ λ. hypot avoids overflow of b².
                let b = 1.5 * mu / lambda;
                let mode = mu / (1.0_f64.hypot(b) + b);
                Some(mode as $kind)
            }
        }

        impl HasSuffStat<$kind> for InvGaussian {
            type Stat = InvGaussianSuffStat;

            fn empty_suffstat(&self) -> Self::Stat {
                InvGaussianSuffStat::new()
            }

            // Σ ln f(xᵢ) = n(½ ln λ − ½ ln 2π) − (3/2) Σ ln xᵢ
            //              − λ/(2μ²) · (Σ xᵢ − 2nμ + μ² Σ 1/xᵢ)
            fn ln_f_stat(&self, stat: &Self::Stat) -> f64 {
                let n = stat.n() as f64;
                let mu2 = self.mu * self.mu;
                let t1 = n.mul_add(
                    0.5_f64.mul_add(self.ln_lambda(), -HALF_LN_2PI),
                    -3.0 / 2.0 * stat.sum_ln_x(),
                );
                let t2 = self.lambda() / (2.0 * mu2);
                let t3 = (2.0 * n).mul_add(-self.mu, stat.sum_x());
                let t4 = stat.sum_inv_x().mul_add(mu2, t3);
                t2.mul_add(-t4, t1)
            }
        }
    };
}
359
360impl Variance<f64> for InvGaussian {
361    fn variance(&self) -> Option<f64> {
362        Some(self.mu.powi(3) / self.lambda)
363    }
364}
365
366impl Skewness for InvGaussian {
367    fn skewness(&self) -> Option<f64> {
368        Some(2.0 * (self.mu / self.lambda).sqrt())
369    }
370}
371
372impl Kurtosis for InvGaussian {
373    fn kurtosis(&self) -> Option<f64> {
374        Some(15.0 * self.mu / self.lambda)
375    }
376}
377
// Instantiate the `impl_traits!` impls (density, sampling, support, CDF, mean,
// mode, sufficient statistic) for both `f32` and `f64` values.
impl_traits!(f32);
impl_traits!(f64);
380
// Marker impl: `InvGaussianError` already provides `Debug` (derived) and
// `Display` (implemented below), which is all `std::error::Error` needs here.
impl std::error::Error for InvGaussianError {}
382
383#[cfg_attr(coverage_nightly, coverage(off))]
384impl fmt::Display for InvGaussianError {
385    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
386        match self {
387            Self::MuNotFinite { mu } => write!(f, "non-finite mu: {mu}"),
388            Self::MuTooLow { mu } => {
389                write!(f, "mu ({mu}) must be greater than zero")
390            }
391            Self::LambdaTooLow { lambda } => {
392                write!(f, "lambda ({lambda}) must be greater than zero")
393            }
394            Self::LambdaNotFinite { lambda } => {
395                write!(f, "non-finite lambda: {lambda}")
396            }
397        }
398    }
399}
400
#[cfg(test)]
mod tests {
    use super::*;
    use crate::misc::ks_test;

    /// Number of independent KS attempts in `draw_vs_ks`; the test passes if
    /// any single attempt passes.
    const N_TRIES: usize = 10;
    /// Minimum KS p-value for an attempt to count as a pass.
    const KS_PVAL: f64 = 0.2;

    crate::test_basic_impls!(
        f64,
        InvGaussian,
        InvGaussian::new(1.0, 2.3).unwrap()
    );

    #[test]
    fn mode_is_highest_point() {
        let mut rng = rand::rng();
        let mu_prior = crate::dist::InvGamma::new_unchecked(2.0, 2.0);
        let lambda_prior = crate::dist::InvGamma::new_unchecked(2.0, 2.0);
        for _ in 0..100 {
            let mu: f64 = mu_prior.draw(&mut rng);
            let lambda: f64 = lambda_prior.draw(&mut rng);
            let ig = InvGaussian::new(mu, lambda).unwrap();
            let mode: f64 = ig.mode().unwrap();
            // Step relative to the mode. A fixed step would leave the support
            // when the mode is tiny, and could shrink the ln-density gap toward
            // rounding error when the mode is large.
            let h = 1e-4 * mode;
            let ln_f_mode = ig.ln_f(&mode);
            let ln_f_plus = ig.ln_f(&(mode + h));
            let ln_f_minus = ig.ln_f(&(mode - h));

            assert!(ln_f_mode > ln_f_plus);
            assert!(ln_f_mode > ln_f_minus);
        }
    }

    #[test]
    fn quad_on_pdf_agrees_with_cdf_x() {
        use peroxide::numerical::integral::{
            Integral, gauss_kronrod_quadrature,
        };
        let ig = InvGaussian::new(1.1, 2.5).unwrap();
        // use pdf to hit `supports(x)` first
        let pdf = |x: f64| ig.pdf(&x);
        let mut rng = rand::rng();
        for _ in 0..100 {
            let x: f64 = ig.draw(&mut rng);
            let res = gauss_kronrod_quadrature(
                pdf,
                (1e-16, x),
                Integral::G7K15(1e-10, 100),
            );
            let cdf = ig.cdf(&x);
            assert::close(res, cdf, 1e-7);
        }
    }

    /// Kolmogorov–Smirnov check of `draw` against `cdf`.
    #[test]
    fn draw_vs_ks() {
        let mut rng = rand::rng();
        let ig = InvGaussian::new(1.2, 3.4).unwrap();
        let cdf = |x: f64| ig.cdf(&x);

        // test is flaky, try a few times
        let passes = (0..N_TRIES).fold(0, |acc, _| {
            let xs: Vec<f64> = ig.sample(1000, &mut rng);
            let (_, p) = ks_test(&xs, cdf);
            if p > KS_PVAL { acc + 1 } else { acc }
        });

        assert!(passes > 0);
    }

    #[test]
    fn ln_f_stat() {
        use crate::traits::SuffStat;

        let data: Vec<f64> = vec![0.1, 0.23, 1.4, 0.65, 0.22, 3.1];
        let mut stat = InvGaussianSuffStat::new();
        stat.observe_many(&data);

        let igauss = InvGaussian::new(0.3, 2.33).unwrap();

        let ln_f_base: f64 = data.iter().map(|x| igauss.ln_f(x)).sum();
        let ln_f_stat: f64 =
            <InvGaussian as HasSuffStat<f64>>::ln_f_stat(&igauss, &stat);

        assert::close(ln_f_base, ln_f_stat, 1e-12);
    }

    #[test]
    fn emit_and_from_params_are_identity() {
        let dist_a = InvGaussian::new(1.5, 3.5).unwrap();
        let dist_b = InvGaussian::from_params(dist_a.emit_params());
        assert_eq!(dist_a, dist_b);
    }
}