1#[cfg(feature = "serde1")]
3use serde::{Deserialize, Serialize};
4
5use rand::Rng;
6use rand_distr::Normal;
7use std::fmt;
8use std::sync::OnceLock;
9
10use crate::consts::{HALF_LN_2PI, LN_2PI};
11use crate::data::InvGaussianSuffStat;
12use crate::impl_display;
13use crate::traits::{
14 Cdf, ContinuousDistr, HasDensity, HasSuffStat, Kurtosis, Mean, Mode,
15 Parameterized, Sampleable, Scalable, Shiftable, Skewness, Support,
16 Variance,
17};
18
/// Inverse Gaussian distribution, written N⁻¹(μ, λ), with mean `mu` (μ) and
/// shape `lambda` (λ).
///
/// The natural log of λ is needed by every density evaluation, so it is
/// computed lazily on first use and cached in `ln_lambda`.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serde1", serde(rename_all = "snake_case"))]
pub struct InvGaussian {
    /// Mean parameter, μ.
    mu: f64,
    /// Shape parameter, λ.
    lambda: f64,
    /// Lazily populated cache of `lambda.ln()`. It is not serialized; a
    /// deserialized value starts with an empty cache. It must be reset
    /// whenever `lambda` changes.
    #[cfg_attr(feature = "serde1", serde(skip))]
    ln_lambda: OnceLock<f64>,
}
33
/// Plain-data snapshot of an [`InvGaussian`]'s parameters, used by
/// `Parameterized::emit_params` / `Parameterized::from_params`.
///
/// Both fields are `f64`, so the type is cheap to copy and compare.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct InvGaussianParameters {
    /// Mean parameter, μ.
    pub mu: f64,
    /// Shape parameter, λ.
    pub lambda: f64,
}
38
39crate::impl_shiftable!(InvGaussian);
40crate::impl_scalable!(InvGaussian);
41
42impl Parameterized for InvGaussian {
43 type Parameters = InvGaussianParameters;
44
45 fn emit_params(&self) -> Self::Parameters {
46 Self::Parameters {
47 mu: self.mu(),
48 lambda: self.lambda(),
49 }
50 }
51
52 fn from_params(params: Self::Parameters) -> Self {
53 Self::new_unchecked(params.mu, params.lambda)
54 }
55}
56
57impl PartialEq for InvGaussian {
58 fn eq(&self, other: &InvGaussian) -> bool {
59 self.mu == other.mu && self.lambda == other.lambda
60 }
61}
62
/// Reasons why constructing an `InvGaussian` or updating its parameters can
/// fail. Each variant carries the rejected value.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serde1", serde(rename_all = "snake_case"))]
pub enum InvGaussianError {
    /// μ was NaN or infinite.
    MuNotFinite {
        /// The rejected μ.
        mu: f64,
    },
    /// μ was less than or equal to zero.
    MuTooLow {
        /// The rejected μ.
        mu: f64,
    },
    /// λ was less than or equal to zero (the constructor's check order means
    /// −∞ is reported here).
    LambdaTooLow {
        /// The rejected λ.
        lambda: f64,
    },
    /// λ was NaN or +∞.
    LambdaNotFinite {
        /// The rejected λ.
        lambda: f64,
    },
}
76
77impl InvGaussian {
78 pub fn new(mu: f64, lambda: f64) -> Result<Self, InvGaussianError> {
102 if !mu.is_finite() {
103 Err(InvGaussianError::MuNotFinite { mu })
104 } else if mu <= 0.0 {
105 Err(InvGaussianError::MuTooLow { mu })
106 } else if lambda <= 0.0 {
107 Err(InvGaussianError::LambdaTooLow { lambda })
108 } else if !lambda.is_finite() {
109 Err(InvGaussianError::LambdaNotFinite { lambda })
110 } else {
111 Ok(InvGaussian {
112 mu,
113 lambda,
114 ln_lambda: OnceLock::new(),
115 })
116 }
117 }
118
119 #[inline]
122 #[must_use]
123 pub fn new_unchecked(mu: f64, lambda: f64) -> Self {
124 InvGaussian {
125 mu,
126 lambda,
127 ln_lambda: OnceLock::new(),
128 }
129 }
130
131 #[inline]
142 pub fn mu(&self) -> f64 {
143 self.mu
144 }
145
146 #[inline]
172 pub fn set_mu(&mut self, mu: f64) -> Result<(), InvGaussianError> {
173 if !mu.is_finite() {
174 Err(InvGaussianError::MuNotFinite { mu })
175 } else if mu <= 0.0 {
176 Err(InvGaussianError::MuTooLow { mu })
177 } else {
178 self.set_mu_unchecked(mu);
179 Ok(())
180 }
181 }
182
183 #[inline]
185 pub fn set_mu_unchecked(&mut self, mu: f64) {
186 self.mu = mu;
187 }
188
189 #[inline]
200 pub fn lambda(&self) -> f64 {
201 self.lambda
202 }
203
204 #[inline]
230 pub fn set_lambda(&mut self, lambda: f64) -> Result<(), InvGaussianError> {
231 if lambda <= 0.0 {
232 Err(InvGaussianError::LambdaTooLow { lambda })
233 } else if !lambda.is_finite() {
234 Err(InvGaussianError::LambdaNotFinite { lambda })
235 } else {
236 self.set_lambda_unchecked(lambda);
237 Ok(())
238 }
239 }
240
241 #[inline]
243 pub fn set_lambda_unchecked(&mut self, lambda: f64) {
244 self.ln_lambda = OnceLock::new();
245 self.lambda = lambda;
246 }
247
248 #[inline]
249 fn ln_lambda(&self) -> f64 {
250 *self.ln_lambda.get_or_init(|| self.lambda.ln())
251 }
252}
253
254impl From<&InvGaussian> for String {
255 fn from(ig: &InvGaussian) -> String {
256 format!("N⁻¹(μ: {}, λ: {})", ig.mu, ig.lambda)
257 }
258}
259
260impl_display!(InvGaussian);
261
// Implements the density, sampling, support, CDF, mean, mode, and
// sufficient-statistic traits of `InvGaussian` for the float type `$kind`.
macro_rules! impl_traits {
    ($kind:ty) => {
        impl HasDensity<$kind> for InvGaussian {
            // ln f(x) = ½[ln λ − ln 2π − 3 ln x] − λ(x − μ)² / (2μ²x)
            fn ln_f(&self, x: &$kind) -> f64 {
                let InvGaussianParameters { mu, lambda } = self.emit_params();
                let xf = f64::from(*x);
                let z = self.ln_lambda() - xf.ln().mul_add(3.0, LN_2PI);
                let err = xf - mu;
                let term = lambda * err * err / (2.0 * mu * mu * xf);
                z.mul_add(0.5, -term)
            }
        }

        impl Sampleable<$kind> for InvGaussian {
            // Michael–Schucany–Haas transformation: draw y = v² with
            // v ~ N(0, 1), take the smaller root x of the IG quadratic, then
            // return x with probability μ/(μ + x), otherwise μ²/x.
            fn draw<R: Rng>(&self, rng: &mut R) -> $kind {
                let InvGaussianParameters { mu, lambda } = self.emit_params();
                let g = Normal::new(0.0, 1.0).unwrap();
                let v: f64 = rng.sample(g);
                let y = v * v;
                let mu2 = mu * mu;
                // x = μ + μ²y/(2λ) − (μ/(2λ))·√(4μλy + μ²y²)
                let x = 0.5_f64.mul_add(
                    (mu / lambda).mul_add(
                        -(4.0 * mu * lambda).mul_add(y, mu2 * y * y).sqrt(),
                        mu2 * y / lambda,
                    ),
                    mu,
                );
                let z: f64 = rng.random();

                if z <= mu / (mu + x) {
                    x as $kind
                } else {
                    (mu2 / x) as $kind
                }
            }
        }

        impl ContinuousDistr<$kind> for InvGaussian {}

        impl Support<$kind> for InvGaussian {
            // The density involves ln x and λ/x, so it is only defined on the
            // open positive half-line; zero and negatives are outside support.
            fn supports(&self, x: &$kind) -> bool {
                x.is_finite() && *x > 0.0
            }
        }

        impl Cdf<$kind> for InvGaussian {
            // F(x) = Φ(√(λ/x)(x/μ − 1)) + exp(2λ/μ)·Φ(−√(λ/x)(x/μ + 1))
            //
            // NOTE(review): exp(2λ/μ) overflows to +∞ once 2λ/μ exceeds ~709,
            // which can turn the product into NaN; a log-space formulation
            // would be more robust for very large λ/μ.
            fn cdf(&self, x: &$kind) -> f64 {
                let xf = f64::from(*x);
                // No probability mass at or below zero; without this guard a
                // negative x yields NaN from sqrt(λ/x).
                if xf <= 0.0 {
                    return 0.0;
                }
                let InvGaussianParameters { mu, lambda } = self.emit_params();
                let gauss = crate::dist::Gaussian::standard();
                let z = (lambda / xf).sqrt();
                let a = z * (xf / mu - 1.0);
                let b = -z * (xf / mu + 1.0);
                (2.0 * lambda / mu)
                    .exp()
                    .mul_add(gauss.cdf(&b), gauss.cdf(&a))
            }
        }

        impl Mean<$kind> for InvGaussian {
            fn mean(&self) -> Option<$kind> {
                Some(self.mu as $kind)
            }
        }

        impl Mode<$kind> for InvGaussian {
            // mode = μ·[√(1 + 9μ²/(4λ²)) − 3μ/(2λ)]
            fn mode(&self) -> Option<$kind> {
                let InvGaussianParameters { mu, lambda } = self.emit_params();
                let a = (1.0 + 0.25 * 9.0 * mu * mu / (lambda * lambda)).sqrt();
                let b = 0.5 * 3.0 * mu / lambda;
                let mode = mu * (a - b);
                Some(mode as $kind)
            }
        }

        impl HasSuffStat<$kind> for InvGaussian {
            type Stat = InvGaussianSuffStat;

            fn empty_suffstat(&self) -> Self::Stat {
                InvGaussianSuffStat::new()
            }

            // Σ ln f(xᵢ) over the n observations summarized by `stat`:
            //   n(½ ln λ − ½ ln 2π) − (3/2)Σ ln x
            //   − λ/(2μ²)·(Σ x − 2nμ + μ² Σ 1/x)
            fn ln_f_stat(&self, stat: &Self::Stat) -> f64 {
                let n = stat.n() as f64;
                let mu2 = self.mu * self.mu;
                let t1 = n.mul_add(
                    0.5_f64.mul_add(self.ln_lambda(), -HALF_LN_2PI),
                    -3.0 / 2.0 * stat.sum_ln_x(),
                );
                let t2 = self.lambda() / (2.0 * mu2);
                let t3 = (2.0 * n).mul_add(-self.mu, stat.sum_x());
                let t4 = stat.sum_inv_x().mul_add(mu2, t3);
                t2.mul_add(-t4, t1)
            }
        }
    };
}
359
360impl Variance<f64> for InvGaussian {
361 fn variance(&self) -> Option<f64> {
362 Some(self.mu.powi(3) / self.lambda)
363 }
364}
365
366impl Skewness for InvGaussian {
367 fn skewness(&self) -> Option<f64> {
368 Some(2.0 * (self.mu / self.lambda).sqrt())
369 }
370}
371
372impl Kurtosis for InvGaussian {
373 fn kurtosis(&self) -> Option<f64> {
374 Some(15.0 * self.mu / self.lambda)
375 }
376}
377
378impl_traits!(f32);
379impl_traits!(f64);
380
381impl std::error::Error for InvGaussianError {}
382
383#[cfg_attr(coverage_nightly, coverage(off))]
384impl fmt::Display for InvGaussianError {
385 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
386 match self {
387 Self::MuNotFinite { mu } => write!(f, "non-finite mu: {mu}"),
388 Self::MuTooLow { mu } => {
389 write!(f, "mu ({mu}) must be greater than zero")
390 }
391 Self::LambdaTooLow { lambda } => {
392 write!(f, "lambda ({lambda}) must be greater than zero")
393 }
394 Self::LambdaNotFinite { lambda } => {
395 write!(f, "non-finite lambda: {lambda}")
396 }
397 }
398 }
399}
400
#[cfg(test)]
mod tests {
    use super::*;
    use crate::misc::ks_test;

    /// Number of independent Kolmogorov–Smirnov attempts in `draw_vs_ks`.
    const N_TRIES: usize = 10;
    /// p-value above which a single KS attempt counts as a pass.
    const KS_PVAL: f64 = 0.2;

    crate::test_basic_impls!(
        f64,
        InvGaussian,
        InvGaussian::new(1.0, 2.3).unwrap()
    );

    /// The analytic mode should beat nearby points on the log density for
    /// randomly drawn parameters.
    #[test]
    fn mode_is_highest_point() {
        let mut rng = rand::rng();
        let mu_prior = crate::dist::InvGamma::new_unchecked(2.0, 2.0);
        let lambda_prior = crate::dist::InvGamma::new_unchecked(2.0, 2.0);
        for _ in 0..100 {
            let mu: f64 = mu_prior.draw(&mut rng);
            let lambda: f64 = lambda_prior.draw(&mut rng);
            let ig = InvGaussian::new(mu, lambda).unwrap();
            let mode: f64 = ig.mode().unwrap();
            let ln_f_mode = ig.ln_f(&mode);
            let ln_f_plus = ig.ln_f(&(mode + 1e-4));
            let ln_f_minus = ig.ln_f(&(mode - 1e-4));

            assert!(ln_f_mode > ln_f_plus);
            assert!(ln_f_mode > ln_f_minus);
        }
    }

    /// Numerically integrating the pdf from ~0 to x should reproduce cdf(x).
    #[test]
    fn quad_on_pdf_agrees_with_cdf_x() {
        use peroxide::numerical::integral::{
            Integral, gauss_kronrod_quadrature,
        };
        let ig = InvGaussian::new(1.1, 2.5).unwrap();
        let pdf = |x: f64| ig.pdf(&x);

        let mut rng = rand::rng();
        for _ in 0..100 {
            let x: f64 = ig.draw(&mut rng);
            let res = gauss_kronrod_quadrature(
                pdf,
                (1e-16, x),
                Integral::G7K15(1e-10, 100),
            );
            let cdf = ig.cdf(&x);
            assert::close(res, cdf, 1e-7);
        }
    }

    /// Samples should be consistent with the analytic CDF under a
    /// Kolmogorov–Smirnov test; at least one of `N_TRIES` attempts must pass.
    #[test]
    fn draw_vs_ks() {
        let mut rng = rand::rng();
        let ig = InvGaussian::new(1.2, 3.4).unwrap();
        let cdf = |x: f64| ig.cdf(&x);

        let passes = (0..N_TRIES).fold(0, |acc, _| {
            let xs: Vec<f64> = ig.sample(1000, &mut rng);
            let (_, p) = ks_test(&xs, cdf);
            if p > KS_PVAL { acc + 1 } else { acc }
        });

        assert!(passes > 0);
    }

    /// The sufficient-statistic log likelihood must match summing `ln_f`.
    #[test]
    fn ln_f_stat() {
        use crate::traits::SuffStat;

        let data: Vec<f64> = vec![0.1, 0.23, 1.4, 0.65, 0.22, 3.1];
        let mut stat = InvGaussianSuffStat::new();
        stat.observe_many(&data);

        let igauss = InvGaussian::new(0.3, 2.33).unwrap();

        let ln_f_base: f64 = data.iter().map(|x| igauss.ln_f(x)).sum();
        let ln_f_stat: f64 =
            <InvGaussian as HasSuffStat<f64>>::ln_f_stat(&igauss, &stat);

        assert::close(ln_f_base, ln_f_stat, 1e-12);
    }

    #[test]
    fn emit_and_from_params_are_identity() {
        let dist_a = InvGaussian::new(1.5, 3.5).unwrap();
        let dist_b = InvGaussian::from_params(dist_a.emit_params());
        assert_eq!(dist_a, dist_b);
    }

    /// Each invalid-parameter class maps to its dedicated error variant.
    #[test]
    fn new_rejects_invalid_params() {
        assert!(matches!(
            InvGaussian::new(f64::NAN, 1.0),
            Err(InvGaussianError::MuNotFinite { .. })
        ));
        assert!(matches!(
            InvGaussian::new(0.0, 1.0),
            Err(InvGaussianError::MuTooLow { .. })
        ));
        assert!(matches!(
            InvGaussian::new(1.0, -2.0),
            Err(InvGaussianError::LambdaTooLow { .. })
        ));
        assert!(matches!(
            InvGaussian::new(1.0, f64::INFINITY),
            Err(InvGaussianError::LambdaNotFinite { .. })
        ));
    }

    /// Changing λ after ln(λ) has been cached must not leave a stale value.
    #[test]
    fn set_lambda_invalidates_cached_ln_lambda() {
        let x = 0.7_f64;
        let mut ig = InvGaussian::new(1.0, 2.0).unwrap();
        // Evaluate once so the ln(λ) cache is populated before the update.
        let _: f64 = ig.ln_f(&x);
        ig.set_lambda(3.0).unwrap();
        let fresh = InvGaussian::new(1.0, 3.0).unwrap();
        assert::close(ig.ln_f(&x), fresh.ln_f(&x), 1e-14);
    }
}