/// Log-domain representation of probability zero: `ln(0) = -∞`.
///
/// [`log_add`] returns its other argument unchanged when one argument is this
/// value, and [`log_sum`] uses it as the result for an empty input.
pub const LOG_0: f64 = -f64::INFINITY;

/// Log-domain representation of probability one: `ln(1) = 0`.
pub const LOG_1: f64 = 0.0_f64;

/// A finite negative log value of very large magnitude (`-1e10`).
///
/// NOTE(review): not referenced in the visible part of this file; presumably a
/// finite stand-in for [`LOG_0`] where `-∞` must be avoided — confirm against
/// its callers.
pub const LOG_EPSILON: f64 = -1e10;
14
15#[inline]
19pub fn log_add(x: f64, y: f64) -> f64 {
20 if x == LOG_0 {
21 return y;
22 }
23 if y == LOG_0 {
24 return x;
25 }
26 let (hi, lo) = if x > y { (x, y) } else { (y, x) };
27 hi + (lo - hi).exp().ln_1p()
29}
30
31#[inline]
35pub fn log_sub(x: f64, y: f64) -> f64 {
36 debug_assert!(x >= y, "log_sub requires x >= y (x={x}, y={y})");
37 if y == LOG_0 {
38 return x;
39 }
40 if x == y {
41 return LOG_0;
42 }
43 x + (-((y - x).exp())).ln_1p()
45}
46
47pub fn log_sum<I: IntoIterator<Item = f64>>(values: I) -> f64 {
49 values.into_iter().fold(LOG_0, log_add)
50}
51
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `a` is within an absolute tolerance of `1e-9` of `b`.
    ///
    /// Exactly equal values pass immediately. That lets infinite expectations
    /// such as `LOG_0` succeed, because `-inf - -inf` is NaN and would fail
    /// the tolerance check.
    fn approx(a: f64, b: f64) {
        if a == b {
            return;
        }
        assert!((a - b).abs() < 1e-9, "expected {b}, got {a}");
    }

    #[test]
    fn add_identities() {
        approx(log_add(LOG_0, 0.5), 0.5);
        approx(log_add(0.5, LOG_0), 0.5);
        approx(log_add(LOG_1, LOG_1), 2.0_f64.ln());
    }

    #[test]
    fn add_matches_naive() {
        let (x, y) = (-3.2_f64, -1.1_f64);
        let naive = (x.exp() + y.exp()).ln();
        approx(log_add(x, y), naive);
    }

    #[test]
    fn add_is_stable_for_large_magnitudes() {
        // e^1000 overflows to +inf and e^-1000 underflows to 0, so the naive
        // ln(e^x + e^y) would give +inf / -inf here; log_add must not.
        approx(log_add(1000.0, 1000.0), 1000.0 + 2.0_f64.ln());
        approx(log_add(-1000.0, -1000.0), -1000.0 + 2.0_f64.ln());
        approx(log_add(-1000.0, 0.0), 0.0);
    }

    #[test]
    fn sub_works() {
        let (x, y) = (2.0_f64, 1.0_f64);
        let naive = (x.exp() - y.exp()).ln();
        approx(log_sub(x, y), naive);
        approx(log_sub(x, LOG_0), x);
        approx(log_sub(x, x), LOG_0);
    }

    #[test]
    fn sub_is_stable_for_large_magnitudes() {
        // e^1000 - e^999 overflows naively; the log-domain result is finite.
        approx(log_sub(1000.0, 999.0), 1000.0 + (1.0 - (-1.0_f64).exp()).ln());
    }

    #[test]
    fn sum_works() {
        let vals = [LOG_1, LOG_1, LOG_1];
        approx(log_sum(vals), 3.0_f64.ln());
        approx(log_sum(std::iter::empty::<f64>()), LOG_0);
    }

    #[test]
    fn sum_of_many_terms() {
        approx(log_sum([LOG_1; 100]), 100.0_f64.ln());
    }
}