1use crate::consts::{LN_2PI, LN_PI};
2use rand::Rng;
3use rand::distr::Open01;
4use special::Gamma;
5use std::cmp::Ordering;
6use std::fmt::Debug;
7use std::ops::AddAssign;
8
/// Formats a slice as a bracketed, comma-separated string, eliding the middle
/// of long slices.
///
/// Every element is rendered with its `Debug` implementation. The last
/// element is always shown. Before it, at most `max_entries - 1` leading
/// elements are shown; if any elements had to be skipped, a single `... , `
/// marker takes their place.
///
/// ```text
/// xs = [0, 1, 2, 3, 4, 5]
/// max_entries = 6  ->  "[0, 1, 2, 3, 4, 5]"
/// max_entries = 5  ->  "[0, 1, 2, 3, ... , 5]"
/// max_entries = 3  ->  "[0, 1, ... , 5]"
/// ```
///
/// An empty slice is rendered as `"[]"`. A `max_entries` of zero behaves like
/// one (no leading elements are shown) instead of underflowing.
pub fn vec_to_string<T: Debug>(xs: &[T], max_entries: usize) -> String {
    let (last, init) = match xs.split_last() {
        Some(parts) => parts,
        None => return String::from("[]"),
    };

    // Number of leading elements printed verbatim; never reaches into `last`.
    let n_head = max_entries.saturating_sub(1).min(init.len());

    let mut out = String::from("[");
    for x in &init[..n_head] {
        out.push_str(&format!("{x:?}, "));
    }
    // One marker stands in for all skipped elements, however many there are.
    if n_head < init.len() {
        out.push_str("... , ");
    }
    out.push_str(&format!("{last:?}]"));
    out
}
41
42#[must_use]
52pub fn ln_binom(n: f64, k: f64) -> f64 {
53 ln_gammafn(n + 1.0) - ln_gammafn(k + 1.0) - ln_gammafn(n - k + 1.0)
54}
55
56#[must_use]
72pub fn gammafn(x: f64) -> f64 {
73 Gamma::gamma(x)
74}
75
76#[must_use]
93pub fn ln_gammafn(x: f64) -> f64 {
94 Gamma::ln_gamma(x).0
95}
96
/// Computes `ln(1 + exp(x))` without overflow or needless precision loss.
///
/// The input range is split into four regimes, each using the cheapest
/// expression that is accurate there:
///
/// | range               | expression          |
/// |---------------------|---------------------|
/// | `x <= -37`          | `exp(x)`            |
/// | `-37 < x <= 18`     | `ln_1p(exp(x))`     |
/// | `18 < x <= 33.3`    | `x + exp(-x)`       |
/// | otherwise           | `x`                 |
///
/// A NaN input fails every comparison and is returned unchanged.
#[must_use]
pub fn log1pexp(x: f64) -> f64 {
    match x {
        t if t <= -37.0 => t.exp(),
        t if t <= 18.0 => t.exp().ln_1p(),
        t if t <= 33.3 => t + (-t).exp(),
        t => t,
    }
}
109
110#[must_use]
111pub fn logaddexp(x: f64, y: f64) -> f64 {
112 if x > y {
113 x + log1pexp(y - x)
114 } else {
115 y + log1pexp(x - y)
116 }
117}
118
/// Numerically stable `ln(Σ exp(xᵢ))` over a sequence of `f64` values.
pub trait LogSumExp {
    /// Consumes `self` and returns the log of the sum of the exponentials of
    /// its items.
    fn logsumexp(self) -> f64;
}

use std::borrow::Borrow;

/// Blanket implementation for any iterator whose items can be borrowed as
/// `f64` (for example `f64` or `&f64`).
impl<I> LogSumExp for I
where
    I: Iterator,
    I::Item: std::borrow::Borrow<f64>,
{
    /// Single-pass log-sum-exp.
    ///
    /// Keeps the running maximum and the sum of `exp(x - max)` seen so far,
    /// rescaling the sum whenever a new maximum appears. Items equal to
    /// `-inf` contribute nothing and are skipped. An empty iterator yields
    /// `-inf`.
    fn logsumexp(self) -> f64 {
        let mut running_max = f64::NEG_INFINITY;
        let mut scaled_sum: f64 = 0.0;

        for item in self {
            let x = *item.borrow();
            if x == f64::NEG_INFINITY {
                continue;
            }

            if x == running_max {
                // exp(0) is exactly 1. Handling the tie explicitly also keeps
                // a repeated `+inf` from evaluating `exp(inf - inf)` = NaN.
                scaled_sum += 1.0;
            } else if x < running_max {
                scaled_sum += (x - running_max).exp();
            } else {
                // New maximum (or NaN, which propagates): rescale the sum
                // accumulated so far and count the new item as exp(0) = 1.
                scaled_sum = (running_max - x).exp().mul_add(scaled_sum, 1.0);
                running_max = x;
            }
        }

        running_max + scaled_sum.ln()
    }
}
147
/// Returns the running (cumulative) sum of `xs`.
///
/// Element `i` of the output is `xs[0] + xs[1] + ... + xs[i]`, accumulated
/// starting from `T::default()`. The output has the same length as `xs`; an
/// empty input gives an empty output.
pub fn cumsum<T>(xs: &[T]) -> Vec<T>
where
    T: AddAssign + Copy + Default,
{
    let mut running = T::default();
    let mut sums = Vec::with_capacity(xs.len());
    for &x in xs {
        running += x;
        sums.push(running);
    }
    sums
}
168
/// Bisection over cumulative weights.
///
/// Returns the smallest index `i` for which `cws[i] < r` is false, or
/// `cws.len()` when every entry is strictly less than `r`. The result is
/// only meaningful when `cws` is sorted in non-decreasing order, which the
/// callers are expected to guarantee.
#[inline]
fn binary_search(cws: &[f64], r: f64) -> usize {
    let (mut lo, mut hi) = (0_usize, cws.len());
    while lo < hi {
        let mid = lo + (hi - lo) / 2;
        if cws[mid] < r {
            // Everything up to and including `mid` is too small.
            lo = mid + 1;
        } else {
            hi = mid;
        }
    }
    lo
}
183
184#[inline]
185fn catflip_bisection(cws: &[f64], r: f64) -> Option<usize> {
186 let ix = binary_search(cws, r);
187 if ix < cws.len() { Some(ix) } else { None }
188}
189
/// Categorical lookup by linear scan.
///
/// Returns the index of the first cumulative weight strictly greater than
/// `r`, or `None` if there is none.
#[inline]
fn catflip_standard(cws: &[f64], r: f64) -> Option<usize> {
    for (ix, &cw) in cws.iter().enumerate() {
        if cw > r {
            return Some(ix);
        }
    }
    None
}
194
195fn catflip(cws: &[f64], r: f64) -> Option<usize> {
196 if cws.len() > 9 {
197 catflip_bisection(cws, r)
198 } else {
199 catflip_standard(cws, r)
200 }
201}
202
203pub fn gumbel_pflip(weights: &[f64], rng: &mut impl Rng) -> usize {
205 assert!(!weights.is_empty(), "Empty container");
206 weights
207 .iter()
208 .map(|w| (w, rng.random::<f64>().ln()))
209 .enumerate()
210 .max_by(|(_, (w1, l1)), (_, (w2, l2))| {
211 (*w2 * l1).partial_cmp(&(*w1 * l2)).unwrap()
212 })
213 .unwrap()
214 .0
215}
216
217pub fn pflip(weights: &[f64], sum: Option<f64>, rng: &mut impl Rng) -> usize {
218 assert!(!weights.is_empty(), "Empty container");
219
220 let sum = sum.unwrap_or_else(|| weights.iter().sum::<f64>());
221
222 let mut cwt = 0.0;
223 let r: f64 = rng.random::<f64>() * sum;
224 for (ix, w) in weights.iter().enumerate() {
225 cwt += w;
226 if cwt > r {
227 return ix;
228 }
229 }
230 panic!("Could not draw from {weights:?}")
231}
232
233pub fn pflips(weights: &[f64], n: usize, rng: &mut impl Rng) -> Vec<usize> {
235 assert!(!weights.is_empty(), "Empty container");
236
237 let cws: Vec<f64> = cumsum(weights);
238 let scale: f64 = *cws.last().unwrap();
239 let u = rand::distr::StandardUniform;
240
241 (0..n)
242 .map(|_| {
243 let r = rng.sample::<f64, _>(u) * scale;
244 if let Some(ix) = catflip(&cws, r) {
245 ix
246 } else {
247 let wsvec = weights.to_vec();
248 panic!("Could not draw from {wsvec:?}")
249 }
250 })
251 .collect()
252}
253
254pub fn ln_pflips<R: Rng>(
294 ln_weights: &[f64],
295 n: usize,
296 normed: bool,
297 rng: &mut R,
298) -> Vec<usize> {
299 let z = if normed {
300 0.0
301 } else {
302 ln_weights.iter().logsumexp()
303 };
304
305 let cws: Vec<f64> = ln_weights
307 .iter()
308 .scan(0.0, |state, w| {
309 *state += (w - z).exp();
310 Some(*state)
311 })
312 .collect();
313
314 (0..n)
315 .map(|_| {
316 let r = rng.sample(Open01);
317 if let Some(ix) = catflip(&cws, r) {
318 ix
319 } else {
320 let wsvec = ln_weights.to_vec();
321 panic!("Could not draw from {wsvec:?}")
322 }
323 })
324 .collect()
325}
326
327pub fn ln_pflip<R: Rng, I>(ln_weights: I, _normed: bool, rng: &mut R) -> usize
328where
329 I: IntoIterator,
330 I::Item: std::borrow::Borrow<f64>,
331{
332 ln_weights
333 .into_iter()
334 .map(|ln_w| (*ln_w.borrow(), rng.random::<f64>().ln()))
335 .enumerate()
336 .max_by(|(_, (ln_w1, l1)), (_, (ln_w2, l2))| {
337 l1.partial_cmp(&(l2 * (*ln_w1 - *ln_w2).exp())).unwrap()
338 })
339 .unwrap()
340 .0
341}
342
/// Returns the indices of every maximal element of `xs`, in ascending order.
///
/// An empty slice gives an empty vector. Elements that cannot be compared
/// with the current maximum (`partial_cmp` returns `None`, e.g. NaN) are
/// ignored; the first element always seeds the search.
pub fn argmax<T: PartialOrd>(xs: &[T]) -> Vec<usize> {
    let (first, rest) = match xs.split_first() {
        Some(parts) => parts,
        None => return Vec::new(),
    };

    let mut best = first;
    let mut winners: Vec<usize> = vec![0];

    for (offset, candidate) in rest.iter().enumerate() {
        let ix = offset + 1;
        match candidate.partial_cmp(best) {
            Some(Ordering::Greater) => {
                best = candidate;
                winners = vec![ix];
            }
            Some(Ordering::Equal) => winners.push(ix),
            Some(Ordering::Less) | None => {}
        }
    }

    winners
}
380
381#[must_use]
388pub fn lnmv_gamma(p: usize, a: f64) -> f64 {
389 let pf = p as f64;
390 let a0 = pf * (pf - 1.0) / 4.0 * LN_PI;
391 (1..=p).fold(a0, |acc, j| acc + ln_gammafn(a + (1.0 - j as f64) / 2.0))
392}
393
394#[must_use]
401pub fn mvgamma(p: usize, a: f64) -> f64 {
402 lnmv_gamma(p, a).exp()
403}
404
405#[must_use]
415pub fn ln_fact(n: usize) -> f64 {
416 if n < 254 {
417 LN_FACT[n]
418 } else {
419 let y: f64 = (n as f64) + 1.0;
420 (y - 0.5).mul_add(y.ln(), -y)
421 + 0.5_f64.mul_add(LN_2PI, (12.0 * y).recip())
422 }
423}
424
425pub fn sorted_uniforms<R: Rng>(n: usize, rng: &mut R) -> Vec<f64> {
462 let mut xs: Vec<_> = (0..n)
463 .map(|_| -rng.random::<f64>().ln())
464 .scan(0.0, |state, x| {
465 *state += x;
466 Some(*state)
467 })
468 .collect();
469 let max = *xs.last().unwrap() - rng.random::<f64>().ln();
470 (0..n).for_each(|i| xs[i] /= max);
471 xs
472}
473
/// Approximate equality for floats.
///
/// Returns `true` when any of the following holds:
/// - `a` and `b` compare equal,
/// - both are NaN,
/// - their absolute difference is below `tol`,
/// - their difference relative to the magnitude of `(a + b) / 2` is below
///   `tol`.
#[allow(dead_code)]
pub(crate) fn eq_or_close(a: f64, b: f64, tol: f64) -> bool {
    if a == b || (a.is_nan() && b.is_nan()) {
        return true;
    }

    let gap = (a - b).abs();
    gap < tol || 2.0 * gap / (a + b).abs() < tol
}
481
/// Precomputed values of `ln(n!)`, indexed by `n`.
///
/// The table holds 255 entries, i.e. `n = 0..=254`. It starts with
/// `ln(0!) = ln(1!) = 0`, `ln(2!) = ln 2`, `ln(3!) = ln 6`, and so on.
const LN_FACT: [f64; 255] = [
    0.000_000_000_000_000,
    0.000_000_000_000_000,
    std::f64::consts::LN_2,
    1.791_759_469_228_055,
    3.178_053_830_347_946,
    4.787_491_742_782_046,
    6.579_251_212_010_101,
    8.525_161_361_065_415,
    10.604_602_902_745_25,
    12.801_827_480_081_469,
    15.104_412_573_075_516,
    17.502_307_845_873_887,
    19.987_214_495_661_885,
    22.552_163_853_123_42,
    25.191_221_182_738_683,
    27.899_271_383_840_894,
    30.671_860_106_080_675,
    33.505_073_450_136_89,
    36.395_445_208_033_05,
    39.339_884_187_199_495,
    42.335_616_460_753_485,
    45.380_138_898_476_91,
    48.471_181_351_835_23,
    51.606_675_567_764_38,
    54.784_729_398_112_32,
    58.003_605_222_980_52,
    61.261_701_761_002,
    64.557_538_627_006_32,
    67.889_743_137_181_53,
    71.257_038_967_168,
    74.658_236_348_830_16,
    78.092_223_553_315_3,
    81.557_959_456_115_03,
    85.054_467_017_581_52,
    88.580_827_542_197_68,
    92.136_175_603_687_08,
    95.719_694_542_143_2,
    99.330_612_454_787_43,
    102.968_198_614_513_81,
    106.631_760_260_643_45,
    110.320_639_714_757_39,
    114.034_211_781_461_69,
    117.771_881_399_745_06,
    121.533_081_515_438_64,
    125.317_271_149_356_88,
    129.123_933_639_127_24,
    132.952_575_035_616_3,
    136.802_722_637_326_35,
    140.673_923_648_234_25,
    144.565_743_946_344_9,
    148.477_766_951_773_02,
    152.409_592_584_497_35,
    156.360_836_303_078_8,
    160.331_128_216_630_93,
    164.320_112_263_195_17,
    168.327_445_448_427_65,
    172.352_797_139_162_82,
    176.395_848_406_997_37,
    180.456_291_417_543_78,
    184.533_828_861_449_5,
    188.628_173_423_671_6,
    192.739_047_287_844_9,
    196.866_181_672_889_98,
    201.009_316_399_281_57,
    205.168_199_482_641_2,
    209.342_586_752_536_82,
    213.532_241_494_563_27,
    217.736_934_113_954_25,
    221.956_441_819_130_36,
    226.190_548_323_727_57,
    230.439_043_565_776_93,
    234.701_723_442_818_26,
    238.978_389_561_834_35,
    243.268_849_002_982_73,
    247.572_914_096_186_9,
    251.890_402_209_723_2,
    256.221_135_550_009_5,
    260.564_940_971_863_2,
    264.921_649_798_552_8,
    269.291_097_651_019_8,
    273.673_124_285_693_7,
    278.067_573_440_366_1,
    282.474_292_687_630_4,
    286.893_133_295_427,
    291.323_950_094_270_3,
    295.766_601_350_760_6,
    300.220_948_647_014_1,
    304.686_856_765_668_7,
    309.164_193_580_146_9,
    313.652_829_949_879,
    318.152_639_620_209_3,
    322.663_499_126_726_2,
    327.185_287_703_775_2,
    331.717_887_196_928_5,
    336.261_181_979_198_45,
    340.815_058_870_798_96,
    345.379_407_062_266_86,
    349.954_118_040_770_25,
    354.539_085_519_440_8,
    359.134_205_369_575_34,
    363.739_375_555_563_47,
    368.354_496_072_404_7,
    372.979_468_885_689,
    377.614_197_873_918_67,
    382.258_588_773_06,
    386.912_549_123_217_56,
    391.575_988_217_329_6,
    396.248_817_051_791_5,
    400.930_948_278_915_76,
    405.622_296_161_144_9,
    410.322_776_526_937_3,
    415.032_306_728_249_6,
    419.750_805_599_544_8,
    424.478_193_418_257_1,
    429.214_391_866_651_57,
    433.959_323_995_014_87,
    438.712_914_186_121_17,
    443.475_088_120_918_94,
    448.245_772_745_384_6,
    453.024_896_238_496_1,
    457.812_387_981_278_1,
    462.608_178_526_874_9,
    467.412_199_571_608_1,
    472.224_383_926_980_5,
    477.044_665_492_585_6,
    481.872_979_229_887_9,
    486.709_261_136_839_36,
    491.553_448_223_298,
    496.405_478_487_217_6,
    501.265_290_891_579_24,
    506.132_825_342_034_83,
    511.008_022_665_236_07,
    515.890_824_587_822_5,
    520.781_173_716_044_2,
    525.679_013_515_995,
    530.584_288_294_433_6,
    535.496_943_180_169_5,
    540.416_924_105_997_7,
    545.344_177_791_155,
    550.278_651_724_285_6,
    555.220_294_146_895,
    560.169_054_037_273_1,
    565.124_881_094_874_4,
    570.087_725_725_134_2,
    575.057_539_024_710_2,
    580.034_272_767_130_8,
    585.017_879_388_839_2,
    590.008_311_975_617_9,
    595.005_524_249_382,
    600.009_470_555_327_4,
    605.020_105_849_423_8,
    610.037_385_686_238_7,
    615.061_266_207_084_9,
    620.091_704_128_477_4,
    625.128_656_730_891_1,
    630.172_081_847_810_2,
    635.221_937_855_059_8,
    640.278_183_660_408_1,
    645.340_778_693_435,
    650.409_682_895_655_2,
    655.484_856_710_889_1,
    660.566_261_075_873_5,
    665.653_857_411_106,
    670.747_607_611_912_7,
    675.847_474_039_736_9,
    680.953_419_513_637_5,
    686.065_407_301_994,
    691.183_401_114_410_8,
    696.307_365_093_814,
    701.437_263_808_737_2,
    706.573_062_245_787_5,
    711.714_725_802_29,
    716.862_220_279_103_4,
    722.015_511_873_601_3,
    727.174_567_172_815_8,
    732.339_353_146_739_3,
    737.509_837_141_777_4,
    742.685_986_874_351_2,
    747.867_770_424_643_4,
    753.055_156_230_484_2,
    758.248_113_081_374_3,
    763.446_610_112_640_2,
    768.650_616_799_717,
    773.860_102_952_558_5,
    779.075_038_710_167_4,
    784.295_394_535_245_7,
    789.521_141_208_959,
    794.752_249_825_813_5,
    799.988_691_788_643_5,
    805.230_438_803_703_1,
    810.477_462_875_863_6,
    815.729_736_303_910_2,
    820.987_231_675_937_9,
    826.249_921_864_842_8,
    831.517_780_023_906_3,
    836.790_779_582_469_9,
    842.068_894_241_700_5,
    847.352_097_970_438_4,
    852.640_365_001_133_1,
    857.933_669_825_857_5,
    863.231_987_192_405_4,
    868.535_292_100_464_6,
    873.843_559_797_865_7,
    879.156_765_776_907_6,
    884.474_885_770_751_8,
    889.797_895_749_890_2,
    895.125_771_918_679_9,
    900.458_490_711_945_3,
    905.796_028_791_646_3,
    911.138_363_043_611_2,
    916.485_470_574_328_8,
    921.837_328_707_804_9,
    927.193_914_982_476_7,
    932.555_207_148_186_2,
    937.921_183_163_208_1,
    943.291_821_191_335_7,
    948.667_099_599_019_8,
    954.046_996_952_560_4,
    959.431_492_015_349_5,
    964.820_563_745_165_9,
    970.214_191_291_518_3,
    975.612_353_993_036_2,
    981.015_031_374_908_4,
    986.422_203_146_368_6,
    991.833_849_198_223_4,
    997.249_949_600_427_8,
    1_002.670_484_599_700_3,
    1_008.095_434_617_181_7,
    1_013.524_780_246_136_2,
    1_018.958_502_249_690_2,
    1_024.396_581_558_613_4,
    1_029.838_999_269_135_5,
    1_035.285_736_640_801_6,
    1_040.736_775_094_367_4,
    1_046.192_096_209_725,
    1_051.651_681_723_869_2,
    1_057.115_513_528_895,
    1_062.583_573_670_03,
    1_068.055_844_343_701_4,
    1_073.532_307_895_632_8,
    1_079.012_946_818_975,
    1_084.497_743_752_465_6,
    1_089.986_681_478_622_4,
    1_095.479_742_921_962_7,
    1_100.976_911_147_256,
    1_106.478_169_357_800_9,
    1_111.983_500_893_733,
    1_117.492_889_230_361,
    1_123.006_317_976_526_1,
    1_128.523_770_872_990_8,
    1_134.045_231_790_853,
    1_139.570_684_729_984_8,
    1_145.100_113_817_496,
    1_150.633_503_306_223_7,
    1_156.170_837_573_242_4,
];
739
740use num::Zero;
741
742pub fn log_product(data: impl Iterator<Item = f64>) -> f64 {
772 let mut result = 0.0;
773 let mut prod = 1.0;
774 for x in data {
775 let next_prod: f64 = x * prod;
776 if next_prod.is_normal() {
777 prod = next_prod;
778 } else {
779 if x.is_zero() {
780 return f64::NEG_INFINITY;
781 }
782 result += prod.ln();
783 prod = x;
784 }
785 }
786 result + prod.ln()
787}
788
// Unit, property, and statistical tests for the helpers in this module.
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;

    proptest! {
        // `log1pexp` must agree with the direct formula `ln_1p(exp(x))`
        // to within 1e-10 over [-100, 100).
        #[test]
        fn test_log1pexp_close_to_ln_1p_exp(x in -100.0..100.0_f64) {
            let expected = x.exp().ln_1p();
            let actual = log1pexp(x);
            prop_assert!((expected - actual).abs() < 1e-10);
        }
    }
    // The empty product is 1, so its log is 0.
    #[test]
    fn test_log_product_empty() {
        let empty: Vec<f64> = vec![];
        assert_eq!(log_product(empty.into_iter()), 0.0);
    }

    // A single factor yields exactly its own log.
    #[test]
    fn test_log_product_single_element() {
        let single = vec![2.0];
        assert_eq!(log_product(single.into_iter()), 2.0_f64.ln());
    }

    // Several small factors: compare against the log of the plain product.
    #[test]
    fn test_log_product_multiple_elements() {
        let multiple = vec![2.0, 3.0, 4.0];
        assert!(
            (log_product(multiple.into_iter())
                - (2.0_f64 * 3.0_f64 * 4.0_f64).ln())
            .abs()
                < 1e-10
        );
    }

    // 1e100 multiplied 100 times overflows a plain f64 product; the log
    // must still equal 100 * ln(1e100).
    #[test]
    fn test_log_product_overflow() {
        let n = 100;
        let large = vec![1e100; n];
        let result = log_product(large.into_iter());
        let correct = n as f64 * 1e100_f64.ln();
        assert!((result - correct).abs() < 1e-10);
    }

    // 1e-100 multiplied 100 times underflows a plain f64 product; the log
    // must still equal 100 * ln(1e-100).
    #[test]
    fn test_log_product_underflow() {
        let n = 100;
        let large = vec![1e-100; n];
        let result = log_product(large.into_iter());
        let correct = n as f64 * 1e-100_f64.ln();
        assert!((result - correct).abs() < 1e-10);
    }

    // Any zero factor makes the whole product zero, i.e. log = -inf.
    #[test]
    fn test_log_product_with_zero() {
        let with_zero = vec![2.0, 0.0, 3.0];
        assert_eq!(log_product(with_zero.into_iter()), f64::NEG_INFINITY);
    }

    use crate::prelude::ChiSquared;
    use crate::traits::Cdf;
    use rand::{SeedableRng, rng};

    // Tolerance used by the `lnmv_gamma` value checks below.
    const TOL: f64 = 1E-12;

    #[test]
    fn argmax_empty_is_empty() {
        let xs: Vec<f64> = vec![];
        assert_eq!(argmax(&xs), Vec::<usize>::new());
    }

    #[test]
    fn argmax_single_elem_is_0() {
        let xs: Vec<f64> = vec![1.0];
        assert_eq!(argmax(&xs), vec![0]);
    }

    #[test]
    fn argmax_unique_max() {
        let xs: Vec<u8> = vec![1, 2, 3, 4, 5, 4, 3];
        assert_eq!(argmax(&xs), vec![4]);
    }

    // Every index holding the maximum is reported, in ascending order.
    #[test]
    fn argmax_repeated_max() {
        let xs: Vec<u8> = vec![1, 2, 3, 4, 5, 4, 5];
        assert_eq!(argmax(&xs), vec![4, 6]);
    }

    // Entries equal to -inf must not disturb the result: the log-sum-exp of
    // the list has to match `logaddexp` of its two finite entries.
    #[test]
    fn logsumexp_nan_handling() {
        let a: f64 = -3.0;
        let b: f64 = -7.0;
        let target: f64 = logaddexp(a, b);
        let xs = [
            -f64::INFINITY,
            a,
            -f64::INFINITY,
            b,
            -f64::INFINITY,
            -f64::INFINITY,
            -f64::INFINITY,
            -f64::INFINITY,
            -f64::INFINITY,
            -f64::INFINITY,
        ];
        let result = xs.iter().logsumexp();
        assert!((result - target).abs() < 1e-12);
    }

    proptest! {
        // Compares the streaming implementation against the two-pass
        // "subtract the max" formulation on random vectors.
        #[test]
        fn proptest_logsumexp(xs in prop::collection::vec(-1e10_f64..1e10_f64, 0..100)) {
            let result = xs.iter().logsumexp();
            if xs.is_empty() {
                prop_assert!(result == f64::NEG_INFINITY);
            } else {
                // Reference value: max + ln(sum(exp(x - max))).
                let max_x = xs.iter().copied().max_by(|a, b| a.partial_cmp(b).unwrap()).unwrap();
                let sum_exp = xs.iter().map(|&x| (x - max_x).exp()).sum::<f64>();
                let expected = max_x + sum_exp.ln();

                prop_assert!((result - expected).abs() < 1e-10);

                // The result can never be smaller than the largest input.
                prop_assert!(result >= *xs.iter().max_by(|a, b| a.partial_cmp(b).unwrap()).unwrap());

                // exp(result) must be at least the directly computed sum.
                let sum_exp_inputs: f64 = xs.iter().map(|&x| x.exp()).sum();
                prop_assert!(result.exp() >= sum_exp_inputs);
            }
        }

    }

    // Spot checks of `lnmv_gamma` against fixed reference values.
    #[test]
    fn lnmv_gamma_values() {
        assert::close(lnmv_gamma(1, 1.0), 0.0, TOL);
        assert::close(lnmv_gamma(1, 12.0), 17.502_307_845_873_887, TOL);
        assert::close(lnmv_gamma(3, 12.0), 50.615_815_724_290_74, TOL);
        assert::close(lnmv_gamma(3, 8.23), 25.709_195_968_438_628, TOL);
    }

    // The linear-scan and bisection lookups must pick the same index for
    // random draws over cumulative weights 1, 2, ..., n.
    #[test]
    fn bisection_and_standard_catflip_equivalence() {
        let mut rng = rand::rng();
        for _ in 0..1000 {
            let n: usize = rng.random_range(10..100);
            let cws: Vec<f64> = (1..=n).map(|i| i as f64).collect();
            let u2 = rand::distr::Uniform::new(0.0, n as f64).unwrap();
            let r = rng.sample(u2);

            let ix1 = catflip_standard(&cws, r).unwrap();
            let ix2 = catflip_bisection(&cws, r).unwrap();

            assert_eq!(ix1, ix2);
        }
    }

    // `ln_fact` must match a direct sum of logs for 0..300, which covers
    // both the table lookup and the approximation branch.
    #[test]
    fn ln_fact_agrees_with_naive() {
        fn ln_fact_naive(x: usize) -> f64 {
            if x < 2 {
                0.0
            } else {
                (2..=x).map(|y| (y as f64).ln()).sum()
            }
        }

        for x in 0..300 {
            let f1 = ln_fact_naive(x);
            let f2 = ln_fact(x);
            assert::close(f1, f2, 1e-9);
        }
    }

    // A category with log-weight -inf must never be drawn, while the two
    // equally weighted categories should each appear reasonably often.
    #[test]
    fn ln_pflips_works_with_zero_weights() {
        use std::f64::consts::LN_2;

        let ln_weights: Vec<f64> = vec![-LN_2, f64::NEG_INFINITY, -LN_2];

        let xs = ln_pflips(&ln_weights, 100, true, &mut rand::rng());

        let zero_count = xs.iter().filter(|&&x| x == 0).count();
        let one_count = xs.iter().filter(|&&x| x == 1).count();
        let two_count = xs.iter().filter(|&&x| x == 2).count();

        assert!(zero_count > 30);
        assert_eq!(one_count, 0);
        assert!(two_count > 30);
    }

    // Checks length, range and ordering of `sorted_uniforms`, then runs a
    // chi-squared test of the draws against equal-width bins.
    #[test]
    fn test_sorted_uniforms() {
        let mut rng = rng();
        let n = 1000;
        let xs = sorted_uniforms(n, &mut rng);
        assert_eq!(xs.len(), n);

        // All values lie strictly inside (0, 1) and are non-decreasing.
        assert!(&0.0 < xs.first().unwrap());
        assert!(xs.last().unwrap() < &1.0);
        assert!(xs.windows(2).all(|w| w[0] <= w[1]));

        // Chi-squared statistic accumulated over the bins.
        let mut t = 0.0;

        {
            // Upper edge of the current bin; bins are 0.01 wide.
            let mut next_bin = 0.01;
            let mut bin_pop = 0;

            for x in &xs {
                bin_pop += 1;
                if *x > next_bin {
                    // Bin edge crossed: score the bin and start the next.
                    let obs = f64::from(bin_pop);
                    let exp = n as f64 / 100.0;
                    t += (obs - exp).powi(2) / exp;
                    bin_pop = 0;
                    next_bin += 0.01;
                }
            }

            // Score whatever is left in the final bin.
            let obs = f64::from(bin_pop);
            let exp = n as f64 / 100.0;
            t += (obs - exp).powi(2) / exp;
        }

        let alpha = 0.001;

        // The statistic is referred to a chi-squared with 99 degrees of
        // freedom; the test fails when the survival probability is <= alpha.
        let chi2 = ChiSquared::new(99.0).unwrap();
        let p = chi2.sf(&t);
        assert!(p > alpha);
    }

    use crate::prelude::Gaussian;
    use crate::traits::Sampleable;
    // Chi-squared goodness-of-fit test of `ln_pflip` draws against the
    // normalized weights, using a fixed seed.
    #[test]
    fn ln_pflip_sampling_distribution() {
        let n_samples = 1_000;
        let mut rng = rand::rngs::StdRng::seed_from_u64(123);

        // Random log-weights drawn from a Gaussian with parameters (0, 1).
        let ln_weights =
            Gaussian::new(0.0, 1.0).unwrap().sample(n_samples, &mut rng);
        let log_normalizer: f64 = ln_weights.iter().logsumexp();
        // Expected count per category: normalized probability * n_samples.
        let expected: Vec<f64> = ln_weights
            .iter()
            .map(|w| (w - log_normalizer).exp() * n_samples as f64)
            .collect();

        // Observed count per category.
        let mut counts = vec![0; ln_weights.len()];
        for _ in 0..n_samples {
            let sample = ln_pflip(&ln_weights, false, &mut rng);
            counts[sample] += 1;
        }
        // Pearson statistic: sum of (observed - expected)^2 / expected.
        let chi_squared: f64 = counts
            .iter()
            .zip(expected.iter())
            .map(|(obs, exp)| {
                let diff = f64::from(*obs) - exp;
                diff * diff / exp
            })
            .sum();

        // Degrees of freedom: number of categories minus one.
        let dof = ln_weights.len() - 1;
        let chi2 = ChiSquared::new(dof as f64).unwrap();
        let p_value = chi2.sf(&chi_squared);

        assert!(
            p_value > 0.01,
            "Chi-squared test failed: p-value = {p_value}"
        );
    }
}