rv/misc/
func.rs

1use crate::consts::{LN_2PI, LN_PI};
2use rand::Rng;
3use rand::distr::Open01;
4use special::Gamma;
5use std::cmp::Ordering;
6use std::fmt::Debug;
7use std::ops::AddAssign;
8
/// Convert a slice to a printable string, eliding the middle when it is long.
///
/// If `xs` has at most `max_entries` elements, every element is printed.
/// Otherwise the first `max_entries - 1` elements are printed, followed by
/// `... , ` and the final element. A `max_entries` of `0` behaves like `1`.
///
/// # Example
///
/// ```rust
/// # use rv::misc::vec_to_string;
/// let xs: Vec<u8> = vec![0, 1, 2, 3, 4, 5];
///
/// assert_eq!(vec_to_string(&xs, 6).as_str(), "[0, 1, 2, 3, 4, 5]");
/// assert_eq!(vec_to_string(&xs, 5).as_str(), "[0, 1, 2, 3, ... , 5]");
///
/// let ys: Vec<u8> = (0..10).collect();
/// assert_eq!(vec_to_string(&ys, 5).as_str(), "[0, 1, 2, 3, ... , 9]");
///
/// let empty: Vec<u8> = vec![];
/// assert_eq!(vec_to_string(&empty, 5).as_str(), "[]");
/// ```
pub fn vec_to_string<T: Debug>(xs: &[T], max_entries: usize) -> String {
    // At least one slot is needed to show the final element.
    let max_entries = max_entries.max(1);
    let n = xs.len();
    let mut out = String::from("[");

    if n <= max_entries {
        let items: Vec<String> = xs.iter().map(|x| format!("{x:?}")).collect();
        out.push_str(&items.join(", "));
    } else {
        // `n > max_entries >= 1`, so `head <= n - 2`: the head slice and the
        // final element never overlap, and `n - 1` is a valid index.
        let head = max_entries - 1;
        for x in &xs[..head] {
            out.push_str(&format!("{x:?}, "));
        }
        out.push_str("... , ");
        out.push_str(&format!("{:?}", xs[n - 1]));
    }

    out.push(']');
    out
}
41
42/// Natural logarithm of binomial coefficient, ln nCk
43///
44/// # Example
45///
46/// ```rust
47/// use rv::misc::ln_binom;
48///
49/// assert!((ln_binom(4.0, 2.0) - 6.0_f64.ln()) < 1E-12);
50/// ```
51#[must_use]
52pub fn ln_binom(n: f64, k: f64) -> f64 {
53    ln_gammafn(n + 1.0) - ln_gammafn(k + 1.0) - ln_gammafn(n - k + 1.0)
54}
55
56/// Gamma function, Γ(x)
57///
58/// # Example
59///
60/// ```rust
61/// use rv::misc::gammafn;
62///
63/// assert!((gammafn(4.0) - 6.0) < 1E-12);
64/// ```
65///
66/// # Notes
67///
68/// This function is a wrapper around `special::Gamma::gamma`.. The name `gamma`
69/// is reserved for possible future use in standard libraries. This function is
70/// purely to avoid warnings resulting from this.
71#[must_use]
72pub fn gammafn(x: f64) -> f64 {
73    Gamma::gamma(x)
74}
75
76/// Logarithm of the gamma function, ln Γ(x)
77///
78/// # Example
79///
80/// ```rust
81///
82/// use rv::misc::ln_gammafn;
83///
84/// assert!((ln_gammafn(4.0) - 6.0_f64.ln()) < 1E-12);
85/// ```
86///
87/// # Notes
88///
89/// This function is a wrapper around `special::Gamma::ln_gamma`.. The name
90/// `ln_gamma` is reserved for possible future use in standard libraries. This
91/// function is purely to avoid warnings resulting from this.
92#[must_use]
93pub fn ln_gammafn(x: f64) -> f64 {
94    Gamma::ln_gamma(x).0
95}
96
/// Numerically stable `ln(1 + exp(x))`.
///
/// Uses piecewise evaluation (breakpoints at -37, 18 and 33.3) so the result
/// neither overflows for large `x` nor loses precision for very negative `x`.
/// NaN input yields NaN.
#[must_use]
pub fn log1pexp(x: f64) -> f64 {
    if x <= -37.0 {
        // exp(x) is so small that ln(1 + e) == e to double precision.
        f64::exp(x)
    } else if x <= 18.0 {
        f64::ln_1p(f64::exp(x))
    } else if x <= 33.3 {
        // ln(1 + e^x) = x + ln(1 + e^-x) ≈ x + e^-x
        x + f64::exp(-x)
    } else {
        // e^-x is below half an ulp of x, so the correction vanishes.
        x
    }
}

/// Numerically stable `ln(exp(x) + exp(y))`.
///
/// Equal infinite arguments return that infinity (`logaddexp(-inf, -inf)` is
/// `-inf`, the log of an empty sum). NaN in either argument yields NaN.
#[must_use]
pub fn logaddexp(x: f64, y: f64) -> f64 {
    // `x - y` would be `inf - inf = NaN` here; the answer is simply `x`.
    if x == y && x.is_infinite() {
        return x;
    }
    if x > y {
        x + log1pexp(y - x)
    } else {
        y + log1pexp(x - y)
    }
}
118
/// Streaming log-sum-exp, `ln(Σ exp(xᵢ))`, as described in [Sebastian Nowozin's blog](https://www.nowozin.net/sebastian/blog/streaming-log-sum-exp-computation.html)
pub trait LogSumExp {
    /// Consume `self` and return `ln(Σ exp(xᵢ))` over its items.
    ///
    /// An empty input (or one containing only `-inf`) yields `f64::NEG_INFINITY`,
    /// the log of an empty sum. Any `+inf` item yields `+inf`; NaN propagates.
    fn logsumexp(self) -> f64;
}

use std::borrow::Borrow;

impl<I> LogSumExp for I
where
    I: Iterator,
    I::Item: std::borrow::Borrow<f64>,
{
    fn logsumexp(self) -> f64 {
        // Invariant: Σ exp(x) over the items seen so far == r * exp(alpha),
        // where `alpha` is the running maximum.
        let (alpha, r) =
            self.fold((f64::NEG_INFINITY, 0.0), |(alpha, r), x| {
                let x = *x.borrow();
                if x == f64::NEG_INFINITY {
                    // exp(-inf) contributes nothing.
                    (alpha, r)
                } else if x == alpha {
                    // exp(x - alpha) == 1. Handled explicitly because a repeated
                    // +inf would otherwise compute exp(inf - inf) = NaN.
                    (alpha, r + 1.0)
                } else if x < alpha {
                    (alpha, r + (x - alpha).exp())
                } else {
                    // New maximum (or NaN): rescale the running sum to `x`.
                    (x, (alpha - x).exp().mul_add(r, 1.0))
                }
            });

        alpha + r.ln()
    }
}
147
/// Running totals of `xs`: element `i` of the output is `xs[0] + … + xs[i]`,
/// accumulated from `T::default()`. The output has the same length as `xs`.
///
/// # Example
///
/// ```rust
/// # use rv::misc::cumsum;
/// let xs: Vec<i32> = vec![1, 1, 2, 1];
/// assert_eq!(cumsum(&xs), vec![1, 2, 4, 5]);
/// ```
pub fn cumsum<T>(xs: &[T]) -> Vec<T>
where
    T: AddAssign + Copy + Default,
{
    let mut running = T::default();
    let mut totals = Vec::with_capacity(xs.len());
    for &x in xs {
        running += x;
        totals.push(running);
    }
    totals
}
168
/// Number of leading elements of `cws` that are `<= r`, i.e. the index of the
/// first element strictly greater than `r` (or `cws.len()` if there is none).
///
/// `cws` is expected to be sorted in non-decreasing order, as cumulative
/// weights are. A NaN `r` yields `cws.len()`, matching `catflip_standard`,
/// which never finds an element greater than NaN.
#[inline]
fn binary_search(cws: &[f64], r: f64) -> usize {
    if r.is_nan() {
        return cws.len();
    }
    // `w <= r` is the exact negation of `catflip_standard`'s `w > r`, so both
    // strategies pick the same index even when `r` lands exactly on a
    // cumulative weight (e.g. `r == 0` never selects a leading zero weight).
    cws.partition_point(|&w| w <= r)
}

/// Bisection variant of [`catflip`]: O(log n) lookup in sorted `cws`.
#[inline]
fn catflip_bisection(cws: &[f64], r: f64) -> Option<usize> {
    let ix = binary_search(cws, r);
    (ix < cws.len()).then_some(ix)
}

/// Linear-scan variant of [`catflip`]: first index whose weight exceeds `r`.
#[inline]
fn catflip_standard(cws: &[f64], r: f64) -> Option<usize> {
    cws.iter().position(|&w| w > r)
}

/// Index of the category selected by the draw `r` from cumulative weights
/// `cws`: the first index whose cumulative weight is strictly greater than
/// `r`, or `None` when no such index exists.
///
/// Uses bisection for more than 9 entries and a linear scan otherwise; both
/// paths return identical results.
fn catflip(cws: &[f64], r: f64) -> Option<usize> {
    if cws.len() > 9 {
        catflip_bisection(cws, r)
    } else {
        catflip_standard(cws, r)
    }
}
202
203// Draw a categorical using Gumbel max sampling
204pub fn gumbel_pflip(weights: &[f64], rng: &mut impl Rng) -> usize {
205    assert!(!weights.is_empty(), "Empty container");
206    weights
207        .iter()
208        .map(|w| (w, rng.random::<f64>().ln()))
209        .enumerate()
210        .max_by(|(_, (w1, l1)), (_, (w2, l2))| {
211            (*w2 * l1).partial_cmp(&(*w1 * l2)).unwrap()
212        })
213        .unwrap()
214        .0
215}
216
217pub fn pflip(weights: &[f64], sum: Option<f64>, rng: &mut impl Rng) -> usize {
218    assert!(!weights.is_empty(), "Empty container");
219
220    let sum = sum.unwrap_or_else(|| weights.iter().sum::<f64>());
221
222    let mut cwt = 0.0;
223    let r: f64 = rng.random::<f64>() * sum;
224    for (ix, w) in weights.iter().enumerate() {
225        cwt += w;
226        if cwt > r {
227            return ix;
228        }
229    }
230    panic!("Could not draw from {weights:?}")
231}
232
233/// Draw `n` indices in proportion to their `weights`
234pub fn pflips(weights: &[f64], n: usize, rng: &mut impl Rng) -> Vec<usize> {
235    assert!(!weights.is_empty(), "Empty container");
236
237    let cws: Vec<f64> = cumsum(weights);
238    let scale: f64 = *cws.last().unwrap();
239    let u = rand::distr::StandardUniform;
240
241    (0..n)
242        .map(|_| {
243            let r = rng.sample::<f64, _>(u) * scale;
244            if let Some(ix) = catflip(&cws, r) {
245                ix
246            } else {
247                let wsvec = weights.to_vec();
248                panic!("Could not draw from {wsvec:?}")
249            }
250        })
251        .collect()
252}
253
254/// Draw an index according to log-domain weights
255///
256/// Draw a `usize` from the categorical distribution defined by `ln_weights`.
257/// If `normed` is `true` then exp(`ln_weights`) is assumed to sum to 1.
258///
259/// # Examples
260///
261/// ```rust
262/// use rv::misc::ln_pflips;
263///
264/// let weights: Vec<f64> = vec![0.4, 0.2, 0.3, 0.1];
265/// let ln_weights: Vec<f64> = weights.iter().map(|&w| w.ln()).collect();
266///
267/// let xs = ln_pflips(&ln_weights, 100, true, &mut rand::rng());
268///
269/// assert_eq!(xs.len(), 100);
270/// assert!(xs.iter().all(|&x| x <= 3));
271/// assert!(!xs.iter().any(|&x| x > 3));
272/// ```
273///
274/// Can handle -Inf ln weights
275///
276/// ```rust
277/// # use rv::misc::ln_pflips;
278/// use std::f64::NEG_INFINITY;
279/// use std::f64::consts::LN_2;
280///
281/// let ln_weights: Vec<f64> = vec![-LN_2, NEG_INFINITY, -LN_2];
282///
283/// let xs = ln_pflips(&ln_weights, 100, true, &mut rand::rng());
284///
285/// let zero_count = xs.iter().filter(|&&x| x == 0).count();
286/// let one_count = xs.iter().filter(|&&x| x == 1).count();
287/// let two_count = xs.iter().filter(|&&x| x == 2).count();
288///
289/// assert!(zero_count > 30);
290/// assert_eq!(one_count, 0);
291/// assert!(two_count > 30);
292/// ```
293pub fn ln_pflips<R: Rng>(
294    ln_weights: &[f64],
295    n: usize,
296    normed: bool,
297    rng: &mut R,
298) -> Vec<usize> {
299    let z = if normed {
300        0.0
301    } else {
302        ln_weights.iter().logsumexp()
303    };
304
305    // doing this instead of calling pflips shaves about 30% off the runtime.
306    let cws: Vec<f64> = ln_weights
307        .iter()
308        .scan(0.0, |state, w| {
309            *state += (w - z).exp();
310            Some(*state)
311        })
312        .collect();
313
314    (0..n)
315        .map(|_| {
316            let r = rng.sample(Open01);
317            if let Some(ix) = catflip(&cws, r) {
318                ix
319            } else {
320                let wsvec = ln_weights.to_vec();
321                panic!("Could not draw from {wsvec:?}")
322            }
323        })
324        .collect()
325}
326
327pub fn ln_pflip<R: Rng, I>(ln_weights: I, _normed: bool, rng: &mut R) -> usize
328where
329    I: IntoIterator,
330    I::Item: std::borrow::Borrow<f64>,
331{
332    ln_weights
333        .into_iter()
334        .map(|ln_w| (*ln_w.borrow(), rng.random::<f64>().ln()))
335        .enumerate()
336        .max_by(|(_, (ln_w1, l1)), (_, (ln_w2, l2))| {
337            l1.partial_cmp(&(l2 * (*ln_w1 - *ln_w2).exp())).unwrap()
338        })
339        .unwrap()
340        .0
341}
342
/// Indices of the largest element(s) in xs.
///
/// If several elements tie for the maximum, the indices of all of them are
/// returned in ascending order. An empty slice gives an empty vector.
/// Elements that are unordered relative to the current maximum
/// (`partial_cmp` returns `None`, e.g. NaN) are skipped. If the first element
/// is such a value, `[0]` is returned.
///
/// # Examples
///
/// ```rust
/// use rv::misc::argmax;
///
/// let xs: Vec<u8> = vec![1, 2, 3, 4, 5, 4, 5];
/// let ys: Vec<u8> = vec![1, 2, 3, 4, 5, 4, 0];
///
/// assert_eq!(argmax(&xs), vec![4, 6]);
/// assert_eq!(argmax(&ys), vec![4]);
/// ```
pub fn argmax<T: PartialOrd>(xs: &[T]) -> Vec<usize> {
    let Some((first, rest)) = xs.split_first() else {
        return Vec::new();
    };

    let mut best = first;
    let mut winners = vec![0];
    for (offset, x) in rest.iter().enumerate() {
        let i = offset + 1;
        match x.partial_cmp(best) {
            Some(Ordering::Greater) => {
                best = x;
                winners.clear();
                winners.push(i);
            }
            Some(Ordering::Equal) => winners.push(i),
            Some(Ordering::Less) | None => {}
        }
    }
    winners
}
380
381/// Natural logarithm of the multivariate gamma function, *ln Γ<sub>p</sub>(a)*.
382///
383/// # Arguments
384///
385/// * `p` - Positive integer degrees of freedom
386/// * `a` - The number for which to compute the multivariate gamma
387#[must_use]
388pub fn lnmv_gamma(p: usize, a: f64) -> f64 {
389    let pf = p as f64;
390    let a0 = pf * (pf - 1.0) / 4.0 * LN_PI;
391    (1..=p).fold(a0, |acc, j| acc + ln_gammafn(a + (1.0 - j as f64) / 2.0))
392}
393
394/// Multivariate gamma function, *Γ<sub>p</sub>(a)*.
395///
396/// # Arguments
397///
398/// * `p` - Positive integer degrees of freedom
399/// * `a` - The number for which to compute the multivariate gamma
400#[must_use]
401pub fn mvgamma(p: usize, a: f64) -> f64 {
402    lnmv_gamma(p, a).exp()
403}
404
405/// ln factorial
406///
407/// # Notes
408///
409/// n < 254 are computed via lookup table. n > 254 are computed via Sterling's
410/// approximation. Code based on [C code from John
411/// Cook](https://www.johndcook.com/blog/csharp_log_factorial/)
412///
413///
414#[must_use]
415pub fn ln_fact(n: usize) -> f64 {
416    if n < 254 {
417        LN_FACT[n]
418    } else {
419        let y: f64 = (n as f64) + 1.0;
420        (y - 0.5).mul_add(y.ln(), -y)
421            + 0.5_f64.mul_add(LN_2PI, (12.0 * y).recip())
422    }
423}
424
425/// Generate a vector of sorted uniform random variables.
426///
427/// # Arguments
428///     
429/// * `n` - The number of random variables to generate.
430///
431/// * `rng` - A mutable reference to the random number generator.
432///
433/// # Returns
434///
435/// A vector of sorted uniform random variables.
436///
437/// # Example
438///
439/// ```
440/// use rand::rng;
441/// use rv::misc::sorted_uniforms;
442///    
443/// let mut rng = rng();
444/// let n = 10000;
445/// let xs = sorted_uniforms(n, &mut rng);
446/// assert_eq!(xs.len(), n);
447///
448/// // Result is sorted and in the unit interval
449/// assert!(xs.first().map_or(false, |&first| first > 0.0));
450/// assert!(xs.last().map_or(false, |&last| last < 1.0));
451/// assert!(xs.windows(2).all(|w| w[0] <= w[1]));
452///
453/// // Mean is approximately 1/2
454/// let mean = xs.iter().sum::<f64>() / n as f64;
455/// assert!(mean > 0.49 && mean < 0.51);
456///
457/// // Variance is approximately 1/12
458/// let var = xs.iter().map(|x| (x - 0.5).powi(2)).sum::<f64>() / n as f64;
459/// assert!(var > 0.08 && var < 0.09);
460/// ```
461pub fn sorted_uniforms<R: Rng>(n: usize, rng: &mut R) -> Vec<f64> {
462    let mut xs: Vec<_> = (0..n)
463        .map(|_| -rng.random::<f64>().ln())
464        .scan(0.0, |state, x| {
465            *state += x;
466            Some(*state)
467        })
468        .collect();
469    let max = *xs.last().unwrap() - rng.random::<f64>().ln();
470    (0..n).for_each(|i| xs[i] /= max);
471    xs
472}
473
/// Whether `a` and `b` are equal, both NaN, or within `tol` of each other in
/// either absolute terms (`|a - b|`) or relative terms (`2|a - b| / |a + b|`).
#[allow(dead_code)] // crate-internal helper that may have no caller in some builds
pub(crate) fn eq_or_close(a: f64, b: f64, tol: f64) -> bool {
    // Exact equality also covers matching infinities.
    if a == b || (a.is_nan() && b.is_nan()) {
        return true;
    }
    let gap = (a - b).abs();
    gap < tol || 2.0 * gap / (a + b).abs() < tol
}
481
/// Precomputed `ln(n!)` for `n = 0..=254`, indexed by `n` (entry 2 is ln 2,
/// entry 3 is ln 6, ...). Read by `ln_fact`.
const LN_FACT: [f64; 255] = [
    0.000_000_000_000_000,
    0.000_000_000_000_000,
    std::f64::consts::LN_2,
    1.791_759_469_228_055,
    3.178_053_830_347_946,
    4.787_491_742_782_046,
    6.579_251_212_010_101,
    8.525_161_361_065_415,
    10.604_602_902_745_25,
    12.801_827_480_081_469,
    15.104_412_573_075_516,
    17.502_307_845_873_887,
    19.987_214_495_661_885,
    22.552_163_853_123_42,
    25.191_221_182_738_683,
    27.899_271_383_840_894,
    30.671_860_106_080_675,
    33.505_073_450_136_89,
    36.395_445_208_033_05,
    39.339_884_187_199_495,
    42.335_616_460_753_485,
    45.380_138_898_476_91,
    48.471_181_351_835_23,
    51.606_675_567_764_38,
    54.784_729_398_112_32,
    58.003_605_222_980_52,
    61.261_701_761_002,
    64.557_538_627_006_32,
    67.889_743_137_181_53,
    71.257_038_967_168,
    74.658_236_348_830_16,
    78.092_223_553_315_3,
    81.557_959_456_115_03,
    85.054_467_017_581_52,
    88.580_827_542_197_68,
    92.136_175_603_687_08,
    95.719_694_542_143_2,
    99.330_612_454_787_43,
    102.968_198_614_513_81,
    106.631_760_260_643_45,
    110.320_639_714_757_39,
    114.034_211_781_461_69,
    117.771_881_399_745_06,
    121.533_081_515_438_64,
    125.317_271_149_356_88,
    129.123_933_639_127_24,
    132.952_575_035_616_3,
    136.802_722_637_326_35,
    140.673_923_648_234_25,
    144.565_743_946_344_9,
    148.477_766_951_773_02,
    152.409_592_584_497_35,
    156.360_836_303_078_8,
    160.331_128_216_630_93,
    164.320_112_263_195_17,
    168.327_445_448_427_65,
    172.352_797_139_162_82,
    176.395_848_406_997_37,
    180.456_291_417_543_78,
    184.533_828_861_449_5,
    188.628_173_423_671_6,
    192.739_047_287_844_9,
    196.866_181_672_889_98,
    201.009_316_399_281_57,
    205.168_199_482_641_2,
    209.342_586_752_536_82,
    213.532_241_494_563_27,
    217.736_934_113_954_25,
    221.956_441_819_130_36,
    226.190_548_323_727_57,
    230.439_043_565_776_93,
    234.701_723_442_818_26,
    238.978_389_561_834_35,
    243.268_849_002_982_73,
    247.572_914_096_186_9,
    251.890_402_209_723_2,
    256.221_135_550_009_5,
    260.564_940_971_863_2,
    264.921_649_798_552_8,
    269.291_097_651_019_8,
    273.673_124_285_693_7,
    278.067_573_440_366_1,
    282.474_292_687_630_4,
    286.893_133_295_427,
    291.323_950_094_270_3,
    295.766_601_350_760_6,
    300.220_948_647_014_1,
    304.686_856_765_668_7,
    309.164_193_580_146_9,
    313.652_829_949_879,
    318.152_639_620_209_3,
    322.663_499_126_726_2,
    327.185_287_703_775_2,
    331.717_887_196_928_5,
    336.261_181_979_198_45,
    340.815_058_870_798_96,
    345.379_407_062_266_86,
    349.954_118_040_770_25,
    354.539_085_519_440_8,
    359.134_205_369_575_34,
    363.739_375_555_563_47,
    368.354_496_072_404_7,
    372.979_468_885_689,
    377.614_197_873_918_67,
    382.258_588_773_06,
    386.912_549_123_217_56,
    391.575_988_217_329_6,
    396.248_817_051_791_5,
    400.930_948_278_915_76,
    405.622_296_161_144_9,
    410.322_776_526_937_3,
    415.032_306_728_249_6,
    419.750_805_599_544_8,
    424.478_193_418_257_1,
    429.214_391_866_651_57,
    433.959_323_995_014_87,
    438.712_914_186_121_17,
    443.475_088_120_918_94,
    448.245_772_745_384_6,
    453.024_896_238_496_1,
    457.812_387_981_278_1,
    462.608_178_526_874_9,
    467.412_199_571_608_1,
    472.224_383_926_980_5,
    477.044_665_492_585_6,
    481.872_979_229_887_9,
    486.709_261_136_839_36,
    491.553_448_223_298,
    496.405_478_487_217_6,
    501.265_290_891_579_24,
    506.132_825_342_034_83,
    511.008_022_665_236_07,
    515.890_824_587_822_5,
    520.781_173_716_044_2,
    525.679_013_515_995,
    530.584_288_294_433_6,
    535.496_943_180_169_5,
    540.416_924_105_997_7,
    545.344_177_791_155,
    550.278_651_724_285_6,
    555.220_294_146_895,
    560.169_054_037_273_1,
    565.124_881_094_874_4,
    570.087_725_725_134_2,
    575.057_539_024_710_2,
    580.034_272_767_130_8,
    585.017_879_388_839_2,
    590.008_311_975_617_9,
    595.005_524_249_382,
    600.009_470_555_327_4,
    605.020_105_849_423_8,
    610.037_385_686_238_7,
    615.061_266_207_084_9,
    620.091_704_128_477_4,
    625.128_656_730_891_1,
    630.172_081_847_810_2,
    635.221_937_855_059_8,
    640.278_183_660_408_1,
    645.340_778_693_435,
    650.409_682_895_655_2,
    655.484_856_710_889_1,
    660.566_261_075_873_5,
    665.653_857_411_106,
    670.747_607_611_912_7,
    675.847_474_039_736_9,
    680.953_419_513_637_5,
    686.065_407_301_994,
    691.183_401_114_410_8,
    696.307_365_093_814,
    701.437_263_808_737_2,
    706.573_062_245_787_5,
    711.714_725_802_29,
    716.862_220_279_103_4,
    722.015_511_873_601_3,
    727.174_567_172_815_8,
    732.339_353_146_739_3,
    737.509_837_141_777_4,
    742.685_986_874_351_2,
    747.867_770_424_643_4,
    753.055_156_230_484_2,
    758.248_113_081_374_3,
    763.446_610_112_640_2,
    768.650_616_799_717,
    773.860_102_952_558_5,
    779.075_038_710_167_4,
    784.295_394_535_245_7,
    789.521_141_208_959,
    794.752_249_825_813_5,
    799.988_691_788_643_5,
    805.230_438_803_703_1,
    810.477_462_875_863_6,
    815.729_736_303_910_2,
    820.987_231_675_937_9,
    826.249_921_864_842_8,
    831.517_780_023_906_3,
    836.790_779_582_469_9,
    842.068_894_241_700_5,
    847.352_097_970_438_4,
    852.640_365_001_133_1,
    857.933_669_825_857_5,
    863.231_987_192_405_4,
    868.535_292_100_464_6,
    873.843_559_797_865_7,
    879.156_765_776_907_6,
    884.474_885_770_751_8,
    889.797_895_749_890_2,
    895.125_771_918_679_9,
    900.458_490_711_945_3,
    905.796_028_791_646_3,
    911.138_363_043_611_2,
    916.485_470_574_328_8,
    921.837_328_707_804_9,
    927.193_914_982_476_7,
    932.555_207_148_186_2,
    937.921_183_163_208_1,
    943.291_821_191_335_7,
    948.667_099_599_019_8,
    954.046_996_952_560_4,
    959.431_492_015_349_5,
    964.820_563_745_165_9,
    970.214_191_291_518_3,
    975.612_353_993_036_2,
    981.015_031_374_908_4,
    986.422_203_146_368_6,
    991.833_849_198_223_4,
    997.249_949_600_427_8,
    1_002.670_484_599_700_3,
    1_008.095_434_617_181_7,
    1_013.524_780_246_136_2,
    1_018.958_502_249_690_2,
    1_024.396_581_558_613_4,
    1_029.838_999_269_135_5,
    1_035.285_736_640_801_6,
    1_040.736_775_094_367_4,
    1_046.192_096_209_725,
    1_051.651_681_723_869_2,
    1_057.115_513_528_895,
    1_062.583_573_670_03,
    1_068.055_844_343_701_4,
    1_073.532_307_895_632_8,
    1_079.012_946_818_975,
    1_084.497_743_752_465_6,
    1_089.986_681_478_622_4,
    1_095.479_742_921_962_7,
    1_100.976_911_147_256,
    1_106.478_169_357_800_9,
    1_111.983_500_893_733,
    1_117.492_889_230_361,
    1_123.006_317_976_526_1,
    1_128.523_770_872_990_8,
    1_134.045_231_790_853,
    1_139.570_684_729_984_8,
    1_145.100_113_817_496,
    1_150.633_503_306_223_7,
    1_156.170_837_573_242_4,
];
739
740use num::Zero;
741
742/// Computes the natural logarithm of the product of a sequence of floating-point numbers.
743///
744/// This function calculates ln(x1 * x2 * ... * xn) in a numerically stable way,
745/// avoiding potential overflow or underflow issues that might occur with naive multiplication.
746///
747/// # Arguments
748///
749/// * `data` - An iterator yielding f64 values whose product's logarithm is to be computed.
750///
751/// # Returns
752///
753/// * The natural logarithm of the product of all numbers in the input iterator.
754///
755/// # Examples
756///
757/// ```
758/// # use rv::misc::log_product;
759/// let numbers = vec![2.0, 3.0, 4.0];
760/// let result = log_product(numbers.into_iter());
761/// assert!((result - (2.0f64 * 3.0 * 4.0).ln()).abs() < 1e-10);
762/// ```
763///
764/// # Notes
765///
766/// - If the input iterator is empty, the function returns 0.0 (ln(1) = 0).
767/// - If any input value is 0, the function returns negative infinity.
768/// - This function is particularly useful for computing products of many numbers
769///   or products of very large or very small numbers where direct multiplication
770///   might lead to floating-point overflow or underflow.
771pub fn log_product(data: impl Iterator<Item = f64>) -> f64 {
772    let mut result = 0.0;
773    let mut prod = 1.0;
774    for x in data {
775        let next_prod: f64 = x * prod;
776        if next_prod.is_normal() {
777            prod = next_prod;
778        } else {
779            if x.is_zero() {
780                return f64::NEG_INFINITY;
781            }
782            result += prod.ln();
783            prod = x;
784        }
785    }
786    result + prod.ln()
787}
788
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;

    use crate::prelude::{ChiSquared, Gaussian};
    use crate::traits::{Cdf, Sampleable};
    use rand::{SeedableRng, rng};

    const TOL: f64 = 1E-12;

    proptest! {
        #[test]
        fn test_log1pexp_close_to_ln_1p_exp(x in -100.0..100.0_f64) {
            let expected = x.exp().ln_1p();
            let actual = log1pexp(x);
            prop_assert!((expected - actual).abs() < 1e-10);
        }
    }

    #[test]
    fn test_log_product_empty() {
        let empty: Vec<f64> = vec![];
        assert_eq!(log_product(empty.into_iter()), 0.0);
    }

    #[test]
    fn test_log_product_single_element() {
        let single = vec![2.0];
        assert_eq!(log_product(single.into_iter()), 2.0_f64.ln());
    }

    #[test]
    fn test_log_product_multiple_elements() {
        let multiple = vec![2.0, 3.0, 4.0];
        assert!(
            (log_product(multiple.into_iter())
                - (2.0_f64 * 3.0_f64 * 4.0_f64).ln())
            .abs()
                < 1e-10
        );
    }

    #[test]
    fn test_log_product_overflow() {
        let n = 100;
        let large = vec![1e100; n];
        let result = log_product(large.into_iter());
        let correct = n as f64 * 1e100_f64.ln();
        assert!((result - correct).abs() < 1e-10);
    }

    #[test]
    fn test_log_product_underflow() {
        let n = 100;
        let small = vec![1e-100; n];
        let result = log_product(small.into_iter());
        let correct = n as f64 * 1e-100_f64.ln();
        assert!((result - correct).abs() < 1e-10);
    }

    #[test]
    fn test_log_product_with_zero() {
        let with_zero = vec![2.0, 0.0, 3.0];
        assert_eq!(log_product(with_zero.into_iter()), f64::NEG_INFINITY);
    }

    #[test]
    fn argmax_empty_is_empty() {
        let xs: Vec<f64> = vec![];
        assert_eq!(argmax(&xs), Vec::<usize>::new());
    }

    #[test]
    fn argmax_single_elem_is_0() {
        let xs: Vec<f64> = vec![1.0];
        assert_eq!(argmax(&xs), vec![0]);
    }

    #[test]
    fn argmax_unique_max() {
        let xs: Vec<u8> = vec![1, 2, 3, 4, 5, 4, 3];
        assert_eq!(argmax(&xs), vec![4]);
    }

    #[test]
    fn argmax_repeated_max() {
        let xs: Vec<u8> = vec![1, 2, 3, 4, 5, 4, 5];
        assert_eq!(argmax(&xs), vec![4, 6]);
    }

    /// `-inf` entries contribute nothing to the log-sum-exp.
    #[test]
    fn logsumexp_skips_neg_infinity() {
        let a: f64 = -3.0;
        let b: f64 = -7.0;
        let target: f64 = logaddexp(a, b);
        let xs = [
            -f64::INFINITY,
            a,
            -f64::INFINITY,
            b,
            -f64::INFINITY,
            -f64::INFINITY,
            -f64::INFINITY,
            -f64::INFINITY,
            -f64::INFINITY,
            -f64::INFINITY,
        ];
        let result = xs.iter().logsumexp();
        assert!((result - target).abs() < 1e-12);
    }

    proptest! {
        #[test]
        fn proptest_logsumexp(xs in prop::collection::vec(-1e10_f64..1e10_f64, 0..100)) {
            let result = xs.iter().logsumexp();
            if xs.is_empty() {
                prop_assert!(result == f64::NEG_INFINITY);
            } else {
                // Naive implementation for comparison
                let max_x = xs.iter().copied().max_by(|a, b| a.partial_cmp(b).unwrap()).unwrap();
                let sum_exp = xs.iter().map(|&x| (x - max_x).exp()).sum::<f64>();
                let expected = max_x + sum_exp.ln();

                // Check that the results are close. The tolerance is relative:
                // inputs reach 1e10 in magnitude, where a single ulp (~2e-6)
                // already exceeds any fixed 1e-10 absolute bound.
                prop_assert!((result - expected).abs() <= 1e-10 * expected.abs().max(1.0));

                // Check that the result is greater than or equal to the maximum input
                prop_assert!(result >= max_x);

                // Check that exp(result) is at least the sum of exp(x), allowing
                // for rounding in `exp` and in the naive sum. The MIN_POSITIVE
                // slack covers subnormal sums, which carry few significant bits.
                let sum_exp_inputs: f64 = xs.iter().map(|&x| x.exp()).sum();
                prop_assert!(
                    result.exp() >= sum_exp_inputs * (1.0 - 1e-12) - f64::MIN_POSITIVE
                );
            }
        }

    }

    #[test]
    fn lnmv_gamma_values() {
        assert::close(lnmv_gamma(1, 1.0), 0.0, TOL);
        assert::close(lnmv_gamma(1, 12.0), 17.502_307_845_873_887, TOL);
        assert::close(lnmv_gamma(3, 12.0), 50.615_815_724_290_74, TOL);
        assert::close(lnmv_gamma(3, 8.23), 25.709_195_968_438_628, TOL);
    }

    #[test]
    fn bisection_and_standard_catflip_equivalence() {
        let mut rng = rand::rng();
        for _ in 0..1000 {
            let n: usize = rng.random_range(10..100);
            let cws: Vec<f64> = (1..=n).map(|i| i as f64).collect();
            let u2 = rand::distr::Uniform::new(0.0, n as f64).unwrap();
            let r = rng.sample(u2);

            let ix1 = catflip_standard(&cws, r).unwrap();
            let ix2 = catflip_bisection(&cws, r).unwrap();

            assert_eq!(ix1, ix2);
        }
    }

    #[test]
    fn ln_fact_agrees_with_naive() {
        fn ln_fact_naive(x: usize) -> f64 {
            if x < 2 {
                0.0
            } else {
                (2..=x).map(|y| (y as f64).ln()).sum()
            }
        }

        for x in 0..300 {
            let f1 = ln_fact_naive(x);
            let f2 = ln_fact(x);
            assert::close(f1, f2, 1e-9);
        }
    }

    #[test]
    fn ln_pflips_works_with_zero_weights() {
        use std::f64::consts::LN_2;

        let ln_weights: Vec<f64> = vec![-LN_2, f64::NEG_INFINITY, -LN_2];

        let xs = ln_pflips(&ln_weights, 100, true, &mut rand::rng());

        let zero_count = xs.iter().filter(|&&x| x == 0).count();
        let one_count = xs.iter().filter(|&&x| x == 1).count();
        let two_count = xs.iter().filter(|&&x| x == 2).count();

        assert!(zero_count > 30);
        assert_eq!(one_count, 0);
        assert!(two_count > 30);
    }

    #[test]
    fn test_sorted_uniforms() {
        let mut rng = rng();
        let n = 1000;
        let xs = sorted_uniforms(n, &mut rng);
        assert_eq!(xs.len(), n);

        // Result is sorted and in the unit interval
        assert!(&0.0 < xs.first().unwrap());
        assert!(xs.last().unwrap() < &1.0);
        assert!(xs.windows(2).all(|w| w[0] <= w[1]));

        // Chi-squared goodness of fit against Uniform(0, 1) over equal-width
        // bins. Each value is assigned to the bin that contains it.
        let n_bins: usize = 100;
        let mut counts = vec![0_u32; n_bins];
        for &x in &xs {
            // x lies in (0, 1); `min` guards the top edge against rounding.
            let bin = ((x * n_bins as f64) as usize).min(n_bins - 1);
            counts[bin] += 1;
        }
        let expected = n as f64 / n_bins as f64;
        let t: f64 = counts
            .iter()
            .map(|&obs| (f64::from(obs) - expected).powi(2) / expected)
            .sum();

        let alpha = 0.001;

        // dof = number of bins minus one
        let chi2 = ChiSquared::new((n_bins - 1) as f64).unwrap();
        let p = chi2.sf(&t);
        assert!(p > alpha);
    }

    #[test]
    fn ln_pflip_sampling_distribution() {
        let n_samples = 1_000;
        let mut rng = rand::rngs::StdRng::seed_from_u64(123);

        // Calculate expected probabilities
        let ln_weights =
            Gaussian::new(0.0, 1.0).unwrap().sample(n_samples, &mut rng);
        let log_normalizer: f64 = ln_weights.iter().logsumexp();
        let expected: Vec<f64> = ln_weights
            .iter()
            .map(|w| (w - log_normalizer).exp() * n_samples as f64)
            .collect();

        // Collect samples
        let mut counts = vec![0; ln_weights.len()];
        for _ in 0..n_samples {
            let sample = ln_pflip(&ln_weights, false, &mut rng);
            counts[sample] += 1;
        }
        // Compute chi-squared statistic
        let chi_squared: f64 = counts
            .iter()
            .zip(expected.iter())
            .map(|(obs, exp)| {
                let diff = f64::from(*obs) - exp;
                diff * diff / exp
            })
            .sum();

        // Degrees of freedom is number of categories minus 1
        let dof = ln_weights.len() - 1;
        let chi2 = ChiSquared::new(dof as f64).unwrap();
        let p_value = chi2.sf(&chi_squared);

        assert!(
            p_value > 0.01,
            "Chi-squared test failed: p-value = {p_value}"
        );
    }
}