scirs2_stats/distributions/
tweedie.rs1use crate::error::{StatsError, StatsResult};
37use crate::sampling::SampleableDistribution;
38use scirs2_core::numeric::{Float, NumCast};
39use scirs2_core::random::prelude::*;
40use scirs2_core::random::rand_distributions::Distribution as _;
41use scirs2_core::random::Uniform as RandUniform;
42use std::f64::consts::PI;
43
/// Natural logarithm of the absolute value of the gamma function, `ln|Γ(x)|`.
///
/// For `x >= 0.5` the Lanczos approximation (g = 7, nine coefficients) is used.
/// Smaller arguments go through the reflection formula
/// `Γ(x) Γ(1 - x) = π / sin(πx)`. The absolute value of `sin(πx)` is taken, so
/// arguments where `Γ(x)` is negative (e.g. `x ∈ (-1, 0)`) yield `ln|Γ(x)|`
/// instead of `NaN`. Non-positive integers are poles and produce `+inf` or a
/// very large value, depending on how `sin(πx)` rounds.
fn ln_gamma(x: f64) -> f64 {
    const LANCZOS: [f64; 9] = [
        0.99999999999980993_f64,
        676.5203681218851,
        -1259.1392167224028,
        771.32342877765313,
        -176.61502916214059,
        12.507343278686905,
        -0.13857109526572012,
        9.9843695780195716e-6,
        1.5056327351493116e-7,
    ];

    if x < 0.5 {
        // Reflection: ln|Γ(x)| = ln π - ln|sin(πx)| - ln Γ(1 - x), with 1 - x > 0.5.
        return PI.ln() - (PI * x).sin().abs().ln() - ln_gamma(1.0 - x);
    }

    let z = x - 1.0;
    let mut series = LANCZOS[0];
    for (i, &c) in (1_u32..).zip(LANCZOS[1..].iter()) {
        series += c / (z + f64::from(i));
    }
    // g + 0.5 with g = 7.
    let t = z + 7.5;
    0.5 * (2.0 * PI).ln() + (z + 0.5) * t.ln() - t + series.ln()
}
71
72fn gamma_fn(x: f64) -> f64 {
73 ln_gamma(x).exp()
74}
75
76pub struct Tweedie<F: Float> {
93 pub mu: F,
95 pub phi: F,
97 pub p: F,
99 uniform_distr: RandUniform<f64>,
100}
101
102impl<F: Float + NumCast + std::fmt::Display> Tweedie<F> {
103 pub fn new(mu: F, phi: F, p: F) -> StatsResult<Self> {
111 let mu_f64: f64 = NumCast::from(mu).unwrap_or(0.0);
112 let phi_f64: f64 = NumCast::from(phi).unwrap_or(0.0);
113 let p_f64: f64 = NumCast::from(p).unwrap_or(0.0);
114
115 if mu_f64 <= 0.0 {
116 return Err(StatsError::DomainError(
117 "Mean mu must be positive".to_string(),
118 ));
119 }
120 if phi_f64 <= 0.0 {
121 return Err(StatsError::DomainError(
122 "Dispersion phi must be positive".to_string(),
123 ));
124 }
125 if p_f64 > 0.0 && p_f64 < 1.0 {
126 return Err(StatsError::DomainError(
127 "Tweedie power p cannot be in the open interval (0, 1)".to_string(),
128 ));
129 }
130
131 let uniform_distr = RandUniform::new(0.0_f64, 1.0_f64).map_err(|_| {
132 StatsError::ComputationError(
133 "Failed to create uniform distribution for Tweedie sampling".to_string(),
134 )
135 })?;
136
137 Ok(Self {
138 mu,
139 phi,
140 p,
141 uniform_distr,
142 })
143 }
144
145 pub fn variance(&self) -> F {
147 self.phi * self.mu.powf(self.p)
148 }
149
150 pub fn prob_zero(&self) -> f64 {
154 let mu_f64: f64 = NumCast::from(self.mu).unwrap_or(1.0);
155 let phi_f64: f64 = NumCast::from(self.phi).unwrap_or(1.0);
156 let p_f64: f64 = NumCast::from(self.p).unwrap_or(1.5);
157
158 if p_f64 <= 1.0 || p_f64 >= 2.0 {
159 return 0.0; }
161
162 let lambda = mu_f64.powf(2.0 - p_f64) / (phi_f64 * (2.0 - p_f64));
163 (-lambda).exp()
164 }
165
166 pub fn log_pdf(&self, x: F, max_terms: usize) -> f64 {
170 let x_f64: f64 = NumCast::from(x).unwrap_or(0.0);
171 let mu_f64: f64 = NumCast::from(self.mu).unwrap_or(1.0);
172 let phi_f64: f64 = NumCast::from(self.phi).unwrap_or(1.0);
173 let p_f64: f64 = NumCast::from(self.p).unwrap_or(1.5);
174
175 if (p_f64 - 0.0).abs() < 1e-10 {
177 return self.log_pdf_normal(x_f64, mu_f64, phi_f64);
178 }
179 if (p_f64 - 1.0).abs() < 1e-10 {
180 return self.log_pdf_poisson(x_f64, mu_f64, phi_f64);
181 }
182 if (p_f64 - 2.0).abs() < 1e-10 {
183 return self.log_pdf_gamma(x_f64, mu_f64, phi_f64);
184 }
185 if (p_f64 - 3.0).abs() < 1e-10 {
186 return self.log_pdf_inverse_gaussian(x_f64, mu_f64, phi_f64);
187 }
188
189 if p_f64 > 1.0 && p_f64 < 2.0 {
191 if x_f64 <= 0.0 {
192 let lambda = mu_f64.powf(2.0 - p_f64) / (phi_f64 * (2.0 - p_f64));
193 return -lambda;
194 }
195 return self.log_pdf_cpg(x_f64, mu_f64, phi_f64, p_f64, max_terms);
196 }
197
198 self.log_pdf_saddlepoint(x_f64, mu_f64, phi_f64, p_f64)
200 }
201
202 pub fn pdf(&self, x: F) -> f64 {
204 self.log_pdf(x, 100).exp()
205 }
206
207 fn log_pdf_normal(&self, x: f64, mu: f64, phi: f64) -> f64 {
210 let z = (x - mu) / phi.sqrt();
211 -0.5 * (z * z + (2.0 * PI * phi).ln())
212 }
213
214 fn log_pdf_poisson(&self, x: f64, mu: f64, phi: f64) -> f64 {
215 let k = x.round() as u64;
217 let lambda = mu / phi;
218 -(lambda) + k as f64 * lambda.ln() - ln_gamma(k as f64 + 1.0)
219 }
220
221 fn log_pdf_gamma(&self, x: f64, mu: f64, phi: f64) -> f64 {
222 if x <= 0.0 {
223 return f64::NEG_INFINITY;
224 }
225 let shape = 1.0 / phi;
227 let scale = mu * phi;
228 (shape - 1.0) * x.ln() - x / scale - shape * scale.ln() - ln_gamma(shape)
229 }
230
231 fn log_pdf_inverse_gaussian(&self, x: f64, mu: f64, phi: f64) -> f64 {
232 if x <= 0.0 {
233 return f64::NEG_INFINITY;
234 }
235 let lambda = 1.0 / phi;
237 0.5 * (lambda.ln() - (2.0 * PI * x * x * x).ln())
238 - lambda * (x - mu) * (x - mu) / (2.0 * mu * mu * x)
239 }
240
241 fn log_pdf_cpg(&self, x: f64, mu: f64, phi: f64, p: f64, max_terms: usize) -> f64 {
244 let alpha = (2.0 - p) / (p - 1.0); let lambda = mu.powf(2.0 - p) / (phi * (2.0 - p));
246 let theta = -mu.powf(1.0 - p) / (1.0 - p);
247
248 let log_const = x * theta - mu.powf(2.0 - p) / (phi * (2.0 - p)) - x.ln();
250
251 let mut log_w_vec: Vec<f64> = Vec::with_capacity(max_terms);
253
254 for j in 1..=max_terms {
255 let jf = j as f64;
256 let log_wj = jf * (alpha * x.ln() - (alpha * phi).ln() + lambda.ln())
257 - ln_gamma(jf * alpha + 1.0)
258 - ln_gamma(jf + 1.0);
259 log_w_vec.push(log_wj);
260
261 if j > 5 && log_wj < log_w_vec[0] - 50.0 {
263 break;
264 }
265 }
266
267 let max_lw = log_w_vec.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
269 let sum_exp: f64 = log_w_vec.iter().map(|&lw| (lw - max_lw).exp()).sum();
270 let log_series = max_lw + sum_exp.ln();
271
272 log_const - (phi * x).ln() + log_series
273 }
274
275 fn log_pdf_saddlepoint(&self, x: f64, mu: f64, phi: f64, p: f64) -> f64 {
278 if x <= 0.0 {
279 return f64::NEG_INFINITY;
280 }
281 let vx = x.powf(p);
284 let deviance = if (p - 2.0).abs() < 1e-10 {
285 2.0 * (x / mu).ln() + 2.0 * (mu - x) / mu
286 } else {
287 2.0 * (x.powf(2.0 - p) / (2.0 - p) - x * mu.powf(1.0 - p) / (1.0 - p)
288 + mu.powf(2.0 - p) / (2.0 - p))
289 };
290 -0.5 * (2.0 * PI * phi * vx).ln() - deviance / (2.0 * phi)
291 }
292
293 fn sample_one<R: Rng + ?Sized>(&self, rng: &mut R) -> f64 {
297 let mu_f64: f64 = NumCast::from(self.mu).unwrap_or(1.0);
298 let phi_f64: f64 = NumCast::from(self.phi).unwrap_or(1.0);
299 let p_f64: f64 = NumCast::from(self.p).unwrap_or(1.5);
300
301 if p_f64 > 1.0 && p_f64 < 2.0 {
302 self.sample_cpg(mu_f64, phi_f64, p_f64, rng)
303 } else if (p_f64 - 2.0).abs() < 1e-10 {
304 self.sample_gamma_dist(mu_f64, phi_f64, rng)
305 } else if (p_f64 - 0.0).abs() < 1e-10 {
306 self.sample_normal(mu_f64, phi_f64, rng)
307 } else {
308 self.sample_approximate(mu_f64, phi_f64, p_f64, rng)
310 }
311 }
312
313 fn sample_normal<R: Rng + ?Sized>(&self, mu: f64, phi: f64, rng: &mut R) -> f64 {
314 let u1: f64 = self.uniform_distr.sample(rng).max(f64::EPSILON);
315 let u2: f64 = self.uniform_distr.sample(rng);
316 let z = (-2.0 * u1.ln()).sqrt() * (2.0 * PI * u2).cos();
317 mu + phi.sqrt() * z
318 }
319
320 fn sample_gamma_dist<R: Rng + ?Sized>(&self, mu: f64, phi: f64, rng: &mut R) -> f64 {
321 let shape = 1.0 / phi;
322 let scale = mu * phi;
323 self.sample_gamma_raw(shape, scale, rng)
324 }
325
326 fn sample_cpg<R: Rng + ?Sized>(&self, mu: f64, phi: f64, p: f64, rng: &mut R) -> f64 {
328 let lambda = mu.powf(2.0 - p) / (phi * (2.0 - p));
329 let alpha = (2.0 - p) / (p - 1.0);
330 let beta = phi * (p - 1.0) * mu.powf(p - 1.0);
331
332 let n = self.sample_poisson(lambda, rng);
334
335 if n == 0 {
336 return 0.0;
337 }
338
339 let mut total = 0.0;
341 for _ in 0..n {
342 total += self.sample_gamma_raw(alpha, beta, rng);
343 }
344 total
345 }
346
347 fn sample_poisson<R: Rng + ?Sized>(&self, lambda: f64, rng: &mut R) -> usize {
348 if lambda < 30.0 {
350 let target = (-lambda).exp();
351 let mut k = 0_usize;
352 let mut p = 1.0_f64;
353 loop {
354 p *= self.uniform_distr.sample(rng);
355 if p <= target {
356 break;
357 }
358 k += 1;
359 if k > 10_000 {
360 break;
361 }
362 }
363 k
364 } else {
365 let u1: f64 = self.uniform_distr.sample(rng).max(f64::EPSILON);
367 let u2: f64 = self.uniform_distr.sample(rng);
368 let z = (-2.0 * u1.ln()).sqrt() * (2.0 * PI * u2).cos();
369 let sample = lambda + lambda.sqrt() * z;
370 sample.round().max(0.0) as usize
371 }
372 }
373
374 fn sample_gamma_raw<R: Rng + ?Sized>(&self, shape: f64, scale: f64, rng: &mut R) -> f64 {
375 if shape < 1.0 {
377 let u: f64 = self.uniform_distr.sample(rng).max(f64::EPSILON);
378 return self.sample_gamma_raw(1.0 + shape, scale, rng) * u.powf(1.0 / shape);
379 }
380 let d = shape - 1.0 / 3.0;
381 let c = 1.0 / (9.0 * d).sqrt();
382 loop {
383 let u1: f64 = self.uniform_distr.sample(rng).max(f64::EPSILON);
384 let u2: f64 = self.uniform_distr.sample(rng);
385 let z = (-2.0 * u1.ln()).sqrt() * (2.0 * PI * u2).cos();
386 let v = (1.0 + c * z).powi(3);
387 if v <= 0.0 {
388 continue;
389 }
390 let u3: f64 = self.uniform_distr.sample(rng);
391 if u3 < 1.0 - 0.0331 * z.powi(4) || u3.ln() < 0.5 * z * z + d * (1.0 - v + v.ln()) {
392 return d * v * scale;
393 }
394 }
395 }
396
397 fn sample_approximate<R: Rng + ?Sized>(&self, mu: f64, phi: f64, p: f64, rng: &mut R) -> f64 {
399 let variance = phi * mu.powf(p);
400 let shape = mu * mu / variance;
401 let scale = variance / mu;
402 self.sample_gamma_raw(shape, scale, rng)
403 }
404
405 pub fn rvs<R: Rng + ?Sized>(&self, n: usize, rng: &mut R) -> StatsResult<Vec<F>> {
407 let mut samples = Vec::with_capacity(n);
408 for _ in 0..n {
409 let s = self.sample_one(rng);
410 let f_s = F::from(s).ok_or_else(|| {
411 StatsError::ComputationError("Failed to convert sample to F".to_string())
412 })?;
413 samples.push(f_s);
414 }
415 Ok(samples)
416 }
417}
418
419impl<F: Float + NumCast + std::fmt::Display> SampleableDistribution<F> for Tweedie<F> {
420 fn rvs(&self, size: usize) -> StatsResult<Vec<F>> {
421 use scirs2_core::random::rngs::SmallRng;
422 use scirs2_core::random::SeedableRng;
423 let seed = std::time::SystemTime::now()
424 .duration_since(std::time::UNIX_EPOCH)
425 .map(|d| d.as_nanos() as u64)
426 .unwrap_or(0x9e3779b97f4a7c15);
427 let mut rng = SmallRng::seed_from_u64(seed);
428 self.rvs(size, &mut rng)
429 }
430}
431
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::random::{rngs::SmallRng, SeedableRng};

    #[test]
    fn test_normal_special_case_p0() {
        let tw = Tweedie::new(3.0f64, 1.0, 0.0).expect("valid params");
        let log_p = tw.log_pdf(3.0f64, 50);
        // At x == mu with phi == 1 the normal log-density is -ln(2π)/2.
        let expected = -0.5 * (2.0 * std::f64::consts::PI).ln();
        assert!((log_p - expected).abs() < 1e-10, "log_p={}", log_p);
    }

    #[test]
    fn test_poisson_special_case_p1_unit_dispersion() {
        // phi = 1: plain Poisson(2) evaluated at 3.
        let tw = Tweedie::new(2.0f64, 1.0, 1.0).expect("valid params");
        let log_p = tw.log_pdf(3.0f64, 50);
        let expected = -2.0 + 3.0 * 2.0_f64.ln() - 6.0_f64.ln();
        assert!(
            (log_p - expected).abs() < 1e-10,
            "log_p={} expected={}",
            log_p,
            expected
        );
    }

    #[test]
    fn test_gamma_special_case_p2() {
        // mu = 2, phi = 0.5 -> Gamma(shape = 2, scale = 1), whose density at 1 is e^-1.
        let tw = Tweedie::new(2.0f64, 0.5, 2.0).expect("valid params");
        let p = tw.pdf(1.0f64);
        assert!(p.is_finite());
        assert!((p - (-1.0f64).exp()).abs() < 1e-10, "p={}", p);
    }

    #[test]
    fn test_inverse_gaussian_special_case_p3() {
        // mu = 1, phi = 1 -> IG(mean 1, shape 1); the quadratic term vanishes at x = mu.
        let tw = Tweedie::new(1.0f64, 1.0, 3.0).expect("valid params");
        let log_p = tw.log_pdf(1.0f64, 50);
        let expected = -0.5 * (2.0 * std::f64::consts::PI).ln();
        assert!((log_p - expected).abs() < 1e-10, "log_p={}", log_p);
    }

    #[test]
    fn test_prob_zero_cpg() {
        let tw = Tweedie::new(2.0f64, 1.0, 1.5).expect("valid params");
        let p0 = tw.prob_zero();
        let lambda = 2.0_f64.powf(0.5) / (1.0 * 0.5);
        let expected = (-lambda).exp();
        assert!(
            (p0 - expected).abs() < 1e-12,
            "p0={} expected={}",
            p0,
            expected
        );
    }

    #[test]
    fn test_variance() {
        let tw = Tweedie::new(3.0f64, 2.0, 1.5).expect("valid params");
        let expected = 2.0 * 3.0_f64.powf(1.5);
        let var: f64 = NumCast::from(tw.variance()).unwrap_or(0.0);
        assert!((var - expected).abs() < 1e-10);
    }

    #[test]
    fn test_cpg_sampling_mean() {
        let mut rng = SmallRng::seed_from_u64(42);
        let mu = 3.0_f64;
        let tw = Tweedie::new(mu, 0.5, 1.5).expect("valid params");
        let n = 5000_usize;
        let samples = tw.rvs(n, &mut rng).expect("sampling should succeed");
        assert_eq!(samples.len(), n);

        let sum: f64 = samples.iter().sum();
        let empirical_mean = sum / n as f64;
        assert!(
            (empirical_mean - mu).abs() < 0.5,
            "empirical mean {} far from {}",
            empirical_mean,
            mu
        );
    }

    #[test]
    fn test_invalid_power() {
        assert!(Tweedie::new(1.0f64, 1.0, 0.5).is_err());
        assert!(Tweedie::new(1.0f64, 1.0, 0.99).is_err());
    }

    #[test]
    fn test_invalid_mean_and_dispersion() {
        assert!(Tweedie::new(0.0f64, 1.0, 1.5).is_err());
        assert!(Tweedie::new(-1.0f64, 1.0, 1.5).is_err());
        assert!(Tweedie::new(1.0f64, 0.0, 1.5).is_err());
        assert!(Tweedie::new(1.0f64, -2.0, 1.5).is_err());
    }

    #[test]
    fn test_log_pdf_cpg() {
        let tw = Tweedie::new(2.0f64, 1.0, 1.5).expect("valid params");
        let log_p0 = tw.log_pdf(0.0f64, 100);
        let expected = tw.prob_zero().ln();
        assert!(
            (log_p0 - expected).abs() < 1e-10,
            "log_p0={} expected={}",
            log_p0,
            expected
        );

        let log_p1 = tw.log_pdf(1.0f64, 100);
        assert!(log_p1.is_finite(), "log_p1={}", log_p1);
    }
}