use crate::consts::{LN_2PI, LN_PI};
use rand::distributions::Open01;
use rand::Rng;
use special::Gamma;
use std::cmp::Ordering;
use std::cmp::PartialOrd;
use std::fmt::Debug;
use std::ops::AddAssign;
/// Formats a slice as a bracketed, comma-separated list of the elements'
/// `Debug` representations, truncating long slices.
///
/// If `xs` has at most `max_entries` elements, every element is printed,
/// e.g. `[1, 2, 3]`. Otherwise the first `max_entries - 1` elements are
/// printed, followed by an ellipsis and the final element, e.g.
/// `[1, 2, ... , 9]`. A `max_entries` of zero is treated like one (only the
/// ellipsis and the last element are shown) instead of underflowing. An
/// empty slice is rendered as `[]`.
pub fn vec_to_string<T: Debug>(xs: &[T], max_entries: usize) -> String {
    let n = xs.len();
    let mut out = String::from("[");
    if n <= max_entries {
        let entries: Vec<String> = xs.iter().map(|x| format!("{:?}", x)).collect();
        out.push_str(&entries.join(", "));
    } else {
        // Here n > max_entries >= 0, so the slice is non-empty and head < n.
        let head = max_entries.saturating_sub(1);
        for x in &xs[..head] {
            out.push_str(&format!("{:?}, ", x));
        }
        out.push_str("... , ");
        out.push_str(&format!("{:?}", xs[n - 1]));
    }
    out.push(']');
    out
}
pub fn ln_binom(n: f64, k: f64) -> f64 {
(n + 1.0).ln_gamma().0 - (k + 1.0).ln_gamma().0 - (n - k + 1.0).ln_gamma().0
}
/// Numerically stable `ln(sum(exp(x)))` over a slice.
///
/// The largest element is factored out before exponentiating, so large
/// magnitudes neither overflow nor underflow. If the largest element is
/// infinite, that infinity is returned directly. A slice of all `-inf`
/// (e.g. all-zero weights in log space) therefore yields `-inf`, and any
/// `+inf` yields `+inf`. Computing `x - max` in those cases would give NaN.
///
/// # Panics
///
/// Panics if `xs` is empty, or if it has more than one element and any of
/// them is `NaN` (the maximum is then undefined).
pub fn logsumexp(xs: &[f64]) -> f64 {
    if xs.is_empty() {
        panic!("Empty container");
    } else if xs.len() == 1 {
        xs[0]
    } else {
        let maxval = *xs
            .iter()
            .max_by(|x, y| x.partial_cmp(y).expect("logsumexp: NaN in input"))
            .expect("non-empty slice has a maximum");
        if maxval.is_infinite() {
            // Either +inf dominates the sum, or every term is exp(-inf) = 0.
            maxval
        } else {
            xs.iter().fold(0.0, |acc, x| acc + (x - maxval).exp()).ln() + maxval
        }
    }
}
/// Running (inclusive prefix) sum of `xs`: element `i` of the result is
/// `xs[0] + ... + xs[i]`, accumulated starting from `T::default()`.
/// An empty input yields an empty vector.
pub fn cumsum<T>(xs: &[T]) -> Vec<T>
where
    T: AddAssign + Copy + Default,
{
    let mut running = T::default();
    let mut sums = Vec::with_capacity(xs.len());
    for &x in xs {
        running += x;
        sums.push(running);
    }
    sums
}
/// Returns the index of the first cumulative weight strictly greater than
/// `r`, or `cws.len()` if there is none.
///
/// `cws` must be sorted in non-decreasing order, as a running sum of
/// non-negative weights is. The strict `>` ("upper bound") matches the
/// linear scan in `catflip_standard`. Category `i` owns the half-open
/// interval `[cws[i - 1], cws[i])`, so a draw landing exactly on a boundary
/// goes to the next category. For `r >= 0`, a zero-weight category is never
/// returned.
#[inline]
fn binary_search(cws: &[f64], r: f64) -> usize {
    cws.partition_point(|&w| w <= r)
}
/// Binary-search lookup of the category for draw `r` in the cumulative
/// weights `cws`. Returns `None` when the search runs off the end of the
/// slice.
#[inline]
fn catflip_bisection(cws: &[f64], r: f64) -> Option<usize> {
    match binary_search(cws, r) {
        ix if ix < cws.len() => Some(ix),
        _ => None,
    }
}
/// Linear-scan lookup: the index of the first cumulative weight strictly
/// greater than `r`, or `None` when no entry exceeds `r`.
#[inline]
fn catflip_standard(cws: &[f64], r: f64) -> Option<usize> {
    for (ix, &w) in cws.iter().enumerate() {
        if w > r {
            return Some(ix);
        }
    }
    None
}
/// Maps a draw `r` onto a category index, given cumulative weights `cws`.
///
/// Slices of nine or fewer entries use a linear scan; longer slices use
/// binary search. Returns `None` if no category is found for `r`.
fn catflip(cws: &[f64], r: f64) -> Option<usize> {
    // Below this length a linear scan is used instead of bisection.
    const LINEAR_SCAN_MAX_LEN: usize = 9;
    if cws.len() <= LINEAR_SCAN_MAX_LEN {
        catflip_standard(cws, r)
    } else {
        catflip_bisection(cws, r)
    }
}
/// Draws `n` category indices with probability proportional to `weights`.
///
/// The weights need not be normalized. Each draw takes a uniform sample on
/// `[0, 1)`, scales it by the weights' total, and locates the result within
/// the running sum of the weights.
///
/// # Panics
///
/// Panics if `weights` is empty, or if `catflip` finds no category for a
/// draw.
pub fn pflip(weights: &[f64], n: usize, rng: &mut impl Rng) -> Vec<usize> {
    if weights.is_empty() {
        panic!("Empty container");
    }
    let running: Vec<f64> = cumsum(weights);
    let total: f64 = running[running.len() - 1];
    let unit = rand::distributions::Uniform::new(0.0, 1.0);
    let mut draws = Vec::with_capacity(n);
    for _ in 0..n {
        let target = rng.sample(unit) * total;
        let ix = catflip(&running, target)
            .unwrap_or_else(|| panic!("Could not draw from {:?}", weights));
        draws.push(ix);
    }
    draws
}
/// Draws `n` category indices from log-space weights `ln_weights`.
///
/// If `normed` is `false`, the weights are first normalized with
/// `logsumexp`. If `true`, they are used as given and are expected to be
/// (approximately) normalized log-probabilities. Each draw is an `Open01`
/// sample scaled by the actual total of the exponentiated weights. Rounding
/// error that leaves the total slightly below 1 therefore cannot push a draw
/// past the last category.
///
/// # Panics
///
/// Panics if `normed` is `false` and `ln_weights` is empty (inside
/// `logsumexp`), or if `catflip` finds no category for a draw.
pub fn ln_pflip<R: Rng>(
    ln_weights: &[f64],
    n: usize,
    normed: bool,
    rng: &mut R,
) -> Vec<usize> {
    let z = if normed { 0.0 } else { logsumexp(ln_weights) };
    // Running sum of exp(w - z); the last entry is the (near-1) total.
    let cws: Vec<f64> = ln_weights
        .iter()
        .scan(0.0, |acc, &w| {
            *acc += (w - z).exp();
            Some(*acc)
        })
        .collect();
    // An empty `cws` gives a total of 0 and `catflip` finds nothing, so any
    // draw from an empty weight slice still panics below.
    let total = cws.last().copied().unwrap_or(0.0);
    (0..n)
        .map(|_| {
            let u: f64 = rng.sample(Open01);
            catflip(&cws, u * total)
                .unwrap_or_else(|| panic!("Could not draw from {:?}", ln_weights))
        })
        .collect()
}
/// Returns the indices of every maximal element of `xs`, in ascending order.
///
/// Elements are compared with `PartialOrd`. The first element is always the
/// initial candidate. A later element that is unordered relative to the
/// current maximum (e.g. `NaN`) is skipped, so a leading `NaN` yields `[0]`.
/// An empty slice yields an empty vector.
pub fn argmax<T: PartialOrd>(xs: &[T]) -> Vec<usize> {
    match xs.split_first() {
        None => Vec::new(),
        Some((head, tail)) => {
            let mut best = head;
            let mut winners: Vec<usize> = vec![0];
            for (offset, candidate) in tail.iter().enumerate() {
                let ix = offset + 1;
                match candidate.partial_cmp(best) {
                    Some(Ordering::Greater) => {
                        best = candidate;
                        winners.clear();
                        winners.push(ix);
                    }
                    Some(Ordering::Equal) => winners.push(ix),
                    Some(Ordering::Less) | None => {}
                }
            }
            winners
        }
    }
}
/// Natural logarithm of the multivariate gamma function `Γ_p(a)`:
/// `p (p - 1) / 4 · ln π + Σ_{j=1..p} ln Γ(a + (1 - j) / 2)`.
pub fn lnmv_gamma(p: usize, a: f64) -> f64 {
    let pf = p as f64;
    let mut total = pf * (pf - 1.0) / 4.0 * LN_PI;
    for j in 1..=p {
        let shift = (1.0 - j as f64) / 2.0;
        total += (a + shift).ln_gamma().0;
    }
    total
}
/// Multivariate gamma function `Γ_p(a)`, computed as the exponential of
/// `lnmv_gamma(p, a)`. Large arguments may overflow to infinity.
pub fn mvgamma(p: usize, a: f64) -> f64 {
    f64::exp(lnmv_gamma(p, a))
}
/// Natural logarithm of `n!`.
///
/// Any `n` covered by the precomputed `LN_FACT` table (all of its entries,
/// `0..LN_FACT.len()`) is looked up directly. Larger `n` use Stirling's
/// series for `ln Γ(y)` with `y = n + 1`:
/// `(y - 1/2) ln y - y + ln(2π)/2 + 1/(12y) - 1/(360y³)`.
/// The `1/(360y³)` correction is about 1.7e-10 at the table boundary. The
/// next omitted term is far below f64 resolution at these magnitudes.
pub fn ln_fact(n: usize) -> f64 {
    match LN_FACT.get(n) {
        Some(&value) => value,
        None => {
            let y: f64 = (n as f64) + 1.0;
            (y - 0.5) * y.ln() - y + 0.5 * LN_2PI + (12.0 * y).recip()
                - (360.0 * y.powi(3)).recip()
        }
    }
}
/// Precomputed values of `ln(n!)` for `n = 0..=254`, indexed by `n`
/// (entry 2 is `ln 2`, entry 3 is `ln 6`, and so on). `ln_fact` reads from
/// this table rather than evaluating an approximation for small arguments.
const LN_FACT: [f64; 255] = [
0.000000000000000,
0.000000000000000,
std::f64::consts::LN_2,
1.791759469228055,
3.178053830347946,
4.787491742782046,
6.579251212010101,
8.525161361065415,
10.604_602_902_745_25,
12.801827480081469,
15.104412573075516,
17.502307845873887,
19.987214495661885,
22.552_163_853_123_42,
25.191221182738683,
27.899271383840894,
30.671860106080675,
33.505_073_450_136_89,
36.395_445_208_033_05,
39.339884187199495,
42.335616460753485,
45.380_138_898_476_91,
48.471_181_351_835_23,
51.606_675_567_764_38,
54.784_729_398_112_32,
58.003_605_222_980_52,
61.261_701_761_002,
64.557_538_627_006_32,
67.889_743_137_181_53,
71.257_038_967_168,
74.658_236_348_830_16,
78.092_223_553_315_3,
81.557_959_456_115_03,
85.054_467_017_581_52,
88.580_827_542_197_68,
92.136_175_603_687_08,
95.719_694_542_143_2,
99.330_612_454_787_43,
102.968_198_614_513_81,
106.631_760_260_643_45,
110.320_639_714_757_39,
114.034_211_781_461_69,
117.771_881_399_745_06,
121.533_081_515_438_64,
125.317_271_149_356_88,
129.123_933_639_127_24,
132.952_575_035_616_3,
136.802_722_637_326_35,
140.673_923_648_234_25,
144.565_743_946_344_9,
148.477_766_951_773_02,
152.409_592_584_497_35,
156.360_836_303_078_8,
160.331_128_216_630_93,
164.320_112_263_195_17,
168.327_445_448_427_65,
172.352_797_139_162_82,
176.395_848_406_997_37,
180.456_291_417_543_78,
184.533_828_861_449_5,
188.628_173_423_671_6,
192.739_047_287_844_9,
196.866_181_672_889_98,
201.009_316_399_281_57,
205.168_199_482_641_2,
209.342_586_752_536_82,
213.532_241_494_563_27,
217.736_934_113_954_25,
221.956_441_819_130_36,
226.190_548_323_727_57,
230.439_043_565_776_93,
234.701_723_442_818_26,
238.978_389_561_834_35,
243.268_849_002_982_73,
247.572_914_096_186_9,
251.890_402_209_723_2,
256.221_135_550_009_5,
260.564_940_971_863_2,
264.921_649_798_552_8,
269.291_097_651_019_8,
273.673_124_285_693_7,
278.067_573_440_366_1,
282.474_292_687_630_4,
286.893_133_295_427,
291.323_950_094_270_3,
295.766_601_350_760_6,
300.220_948_647_014_1,
304.686_856_765_668_7,
309.164_193_580_146_9,
313.652_829_949_879,
318.152_639_620_209_3,
322.663_499_126_726_2,
327.185_287_703_775_2,
331.717_887_196_928_5,
336.261_181_979_198_45,
340.815_058_870_798_96,
345.379_407_062_266_86,
349.954_118_040_770_25,
354.539_085_519_440_8,
359.134_205_369_575_34,
363.739_375_555_563_47,
368.354_496_072_404_7,
372.979_468_885_689,
377.614_197_873_918_67,
382.258_588_773_06,
386.912_549_123_217_56,
391.575_988_217_329_6,
396.248_817_051_791_5,
400.930_948_278_915_76,
405.622_296_161_144_9,
410.322_776_526_937_3,
415.032_306_728_249_6,
419.750_805_599_544_8,
424.478_193_418_257_1,
429.214_391_866_651_57,
433.959_323_995_014_87,
438.712_914_186_121_17,
443.475_088_120_918_94,
448.245_772_745_384_6,
453.024_896_238_496_1,
457.812_387_981_278_1,
462.608_178_526_874_9,
467.412_199_571_608_1,
472.224_383_926_980_5,
477.044_665_492_585_6,
481.872_979_229_887_9,
486.709_261_136_839_36,
491.553_448_223_298,
496.405_478_487_217_6,
501.265_290_891_579_24,
506.132_825_342_034_83,
511.008_022_665_236_07,
515.890_824_587_822_5,
520.781_173_716_044_2,
525.679_013_515_995,
530.584_288_294_433_6,
535.496_943_180_169_5,
540.416_924_105_997_7,
545.344_177_791_155,
550.278_651_724_285_6,
555.220_294_146_895,
560.169_054_037_273_1,
565.124_881_094_874_4,
570.087_725_725_134_2,
575.057_539_024_710_2,
580.034_272_767_130_8,
585.017_879_388_839_2,
590.008_311_975_617_9,
595.005_524_249_382,
600.009_470_555_327_4,
605.020_105_849_423_8,
610.037_385_686_238_7,
615.061_266_207_084_9,
620.091_704_128_477_4,
625.128_656_730_891_1,
630.172_081_847_810_2,
635.221_937_855_059_8,
640.278_183_660_408_1,
645.340_778_693_435,
650.409_682_895_655_2,
655.484_856_710_889_1,
660.566_261_075_873_5,
665.653_857_411_106,
670.747_607_611_912_7,
675.847_474_039_736_9,
680.953_419_513_637_5,
686.065_407_301_994,
691.183_401_114_410_8,
696.307_365_093_814,
701.437_263_808_737_2,
706.573_062_245_787_5,
711.714_725_802_29,
716.862_220_279_103_4,
722.015_511_873_601_3,
727.174_567_172_815_8,
732.339_353_146_739_3,
737.509_837_141_777_4,
742.685_986_874_351_2,
747.867_770_424_643_4,
753.055_156_230_484_2,
758.248_113_081_374_3,
763.446_610_112_640_2,
768.650_616_799_717,
773.860_102_952_558_5,
779.075_038_710_167_4,
784.295_394_535_245_7,
789.521_141_208_959,
794.752_249_825_813_5,
799.988_691_788_643_5,
805.230_438_803_703_1,
810.477_462_875_863_6,
815.729_736_303_910_2,
820.987_231_675_937_9,
826.249_921_864_842_8,
831.517_780_023_906_3,
836.790_779_582_469_9,
842.068_894_241_700_5,
847.352_097_970_438_4,
852.640_365_001_133_1,
857.933_669_825_857_5,
863.231_987_192_405_4,
868.535_292_100_464_6,
873.843_559_797_865_7,
879.156_765_776_907_6,
884.474_885_770_751_8,
889.797_895_749_890_2,
895.125_771_918_679_9,
900.458_490_711_945_3,
905.796_028_791_646_3,
911.138_363_043_611_2,
916.485_470_574_328_8,
921.837_328_707_804_9,
927.193_914_982_476_7,
932.555_207_148_186_2,
937.921_183_163_208_1,
943.291_821_191_335_7,
948.667_099_599_019_8,
954.046_996_952_560_4,
959.431_492_015_349_5,
964.820_563_745_165_9,
970.214_191_291_518_3,
975.612_353_993_036_2,
981.015_031_374_908_4,
986.422_203_146_368_6,
991.833_849_198_223_4,
997.249_949_600_427_8,
1_002.670_484_599_700_3,
1_008.095_434_617_181_7,
1_013.524_780_246_136_2,
1_018.958_502_249_690_2,
1_024.396_581_558_613_4,
1_029.838_999_269_135_5,
1_035.285_736_640_801_6,
1_040.736_775_094_367_4,
1_046.192_096_209_725,
1_051.651_681_723_869_2,
1_057.115_513_528_895,
1_062.583_573_670_03,
1_068.055_844_343_701_4,
1_073.532_307_895_632_8,
1_079.012_946_818_975,
1_084.497_743_752_465_6,
1_089.986_681_478_622_4,
1_095.479_742_921_962_7,
1_100.976_911_147_256,
1_106.478_169_357_800_9,
1_111.983_500_893_733,
1_117.492_889_230_361,
1_123.006_317_976_526_1,
1_128.523_770_872_990_8,
1_134.045_231_790_853,
1_139.570_684_729_984_8,
1_145.100_113_817_496,
1_150.633_503_306_223_7,
1_156.170_837_573_242_4,
];
#[cfg(test)]
mod tests {
    use super::*;
    const TOL: f64 = 1E-12;

    #[test]
    fn argmax_empty_is_empty() {
        let xs: Vec<f64> = vec![];
        assert_eq!(argmax(&xs), Vec::<usize>::new());
    }

    #[test]
    fn argmax_single_elem_is_0() {
        let xs: Vec<f64> = vec![1.0];
        assert_eq!(argmax(&xs), vec![0]);
    }

    #[test]
    fn argmax_unique_max() {
        let xs: Vec<u8> = vec![1, 2, 3, 4, 5, 4, 3];
        assert_eq!(argmax(&xs), vec![4]);
    }

    #[test]
    fn argmax_repeated_max() {
        let xs: Vec<u8> = vec![1, 2, 3, 4, 5, 4, 5];
        assert_eq!(argmax(&xs), vec![4, 6]);
    }

    #[test]
    fn logsumexp_on_vector_of_zeros() {
        let xs: Vec<f64> = vec![0.0; 5];
        assert::close(logsumexp(&xs), 1.6094379124341003, TOL);
    }

    #[test]
    fn logsumexp_on_random_values() {
        let xs: Vec<f64> = vec![
            0.30415386,
            -0.07072296,
            -1.04287019,
            0.27855407,
            -0.81896765,
        ];
        assert::close(logsumexp(&xs), 1.4820007894263059, TOL);
    }

    #[test]
    fn logsumexp_returns_only_value_on_one_element_container() {
        let xs: Vec<f64> = vec![0.30415386];
        assert::close(logsumexp(&xs), 0.30415386, TOL);
    }

    #[test]
    #[should_panic]
    fn logsumexp_should_panic_on_empty() {
        let xs: Vec<f64> = Vec::new();
        logsumexp(&xs);
    }

    #[test]
    fn lnmv_gamma_values() {
        assert::close(lnmv_gamma(1, 1.0), 0.0, TOL);
        assert::close(lnmv_gamma(1, 12.0), 17.502307845873887, TOL);
        assert::close(lnmv_gamma(3, 12.0), 50.615_815_724_290_74, TOL);
        assert::close(lnmv_gamma(3, 8.23), 25.709195968438628, TOL);
    }

    #[test]
    fn bisection_and_standard_catflip_equivalence() {
        let mut rng = rand::thread_rng();
        for _ in 0..1000 {
            let n: usize = rng.gen_range(10..100);
            let cws: Vec<f64> = (1..=n).map(|i| i as f64).collect();
            let u2 = rand::distributions::Uniform::new(0.0, n as f64);
            let r = rng.sample(u2);
            let ix1 = catflip_standard(&cws, r).unwrap();
            let ix2 = catflip_bisection(&cws, r).unwrap();
            assert_eq!(ix1, ix2);
        }
    }

    #[test]
    fn catflip_standard_skips_zero_weight_categories() {
        assert_eq!(catflip_standard(&[0.0, 0.0, 1.0], 0.0), Some(2));
    }

    #[test]
    fn cumsum_of_ints() {
        assert_eq!(cumsum(&[1, 2, 3, 4]), vec![1, 3, 6, 10]);
    }

    #[test]
    fn cumsum_of_empty_is_empty() {
        let xs: Vec<f64> = Vec::new();
        assert_eq!(cumsum(&xs), Vec::<f64>::new());
    }

    #[test]
    fn ln_binom_matches_exact_coefficient() {
        // C(5, 2) = 10
        assert::close(ln_binom(5.0, 2.0), 10f64.ln(), 1e-10);
    }

    #[test]
    fn vec_to_string_prints_all_when_exact_fit() {
        assert_eq!(vec_to_string(&[1, 2, 3], 3), "[1, 2, 3]");
    }

    #[test]
    fn vec_to_string_truncates_one_over() {
        assert_eq!(vec_to_string(&[1, 2, 3, 4], 3), "[1, 2, ... , 4]");
    }

    #[test]
    fn ln_fact_small_values() {
        assert_eq!(ln_fact(0), 0.0);
        assert_eq!(ln_fact(1), 0.0);
        assert::close(ln_fact(3), 6f64.ln(), TOL);
    }

    #[test]
    fn ln_fact_agrees_with_naive() {
        fn ln_fact_naive(x: usize) -> f64 {
            if x < 2 {
                0.0
            } else {
                (2..=x).map(|y| (y as f64).ln()).sum()
            }
        }
        for x in 0..300 {
            let f1 = ln_fact_naive(x);
            let f2 = ln_fact(x);
            assert::close(f1, f2, 1e-9);
        }
    }
}