Skip to main content

salmon_core/
math.rs

1//! Log-space arithmetic helpers.
2//!
3//! Salmon performs nearly all of its probability accumulation in log space to
4//! avoid underflow. These helpers mirror the conventions in the C++ code where
5//! `LOG_0` is negative infinity and `LOG_1` is zero.
6
/// Log-space representation of zero mass: `ln(0) = -inf`.
pub const LOG_0: f64 = f64::NEG_INFINITY;
/// Log-space representation of unit mass: `ln(1) = 0`.
pub const LOG_1: f64 = 0.0_f64;
/// Large-magnitude negative but finite log value, usable as an effective
/// "zero" mass in places where a strict `-inf` would propagate NaNs
/// (e.g. `-inf - -inf` is NaN).
pub const LOG_EPSILON: f64 = -1e10;
14
15/// Numerically stable `log(exp(x) + exp(y))`.
16///
17/// Handles the `-inf` identities so that `log_add(LOG_0, y) == y`.
18#[inline]
19pub fn log_add(x: f64, y: f64) -> f64 {
20    if x == LOG_0 {
21        return y;
22    }
23    if y == LOG_0 {
24        return x;
25    }
26    let (hi, lo) = if x > y { (x, y) } else { (y, x) };
27    // hi + log1p(exp(lo - hi)); lo - hi <= 0 so exp is in (0, 1].
28    hi + (lo - hi).exp().ln_1p()
29}
30
31/// Numerically stable `log(exp(x) - exp(y))`, requires `x >= y`.
32///
33/// Returns `LOG_0` when `x == y`.
34#[inline]
35pub fn log_sub(x: f64, y: f64) -> f64 {
36    debug_assert!(x >= y, "log_sub requires x >= y (x={x}, y={y})");
37    if y == LOG_0 {
38        return x;
39    }
40    if x == y {
41        return LOG_0;
42    }
43    // x + log(1 - exp(y - x)); y - x < 0.
44    x + (-((y - x).exp())).ln_1p()
45}
46
47/// Log-sum-exp over an iterator of log-space values.
48pub fn log_sum<I: IntoIterator<Item = f64>>(values: I) -> f64 {
49    values.into_iter().fold(LOG_0, log_add)
50}
51
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used when comparing floating-point results.
    const TOLERANCE: f64 = 1e-9;

    /// Asserts that `got` matches `want`, either exactly or within `TOLERANCE`.
    fn assert_close(got: f64, want: f64) {
        // Exact equality is checked first so that `-inf` vs `-inf` passes: their
        // difference is NaN and would fail the tolerance comparison.
        let matches = got == want || (got - want).abs() < TOLERANCE;
        assert!(matches, "expected {b}, got {a}", a = got, b = want);
    }

    #[test]
    fn add_identities() {
        // Adding zero mass on either side leaves the other operand unchanged.
        for (x, y) in [(LOG_0, 0.5), (0.5, LOG_0)] {
            assert_close(log_add(x, y), 0.5);
        }
        // log(e^0 + e^0) = log(2)
        assert_close(log_add(LOG_1, LOG_1), 2.0_f64.ln());
    }

    #[test]
    fn add_matches_naive() {
        let x = -3.2_f64;
        let y = -1.1_f64;
        let direct = f64::ln(x.exp() + y.exp());
        assert_close(log_add(x, y), direct);
    }

    #[test]
    fn sub_works() {
        let minuend = 2.0_f64;
        let subtrahend = 1.0_f64;
        let direct = f64::ln(minuend.exp() - subtrahend.exp());
        assert_close(log_sub(minuend, subtrahend), direct);
        // Subtracting zero mass is a no-op.
        assert_close(log_sub(minuend, LOG_0), minuend);
        // Subtracting a value from itself leaves zero mass.
        assert_close(log_sub(minuend, minuend), LOG_0);
    }

    #[test]
    fn sum_works() {
        let three_ones = [LOG_1; 3];
        assert_close(log_sum(three_ones), 3.0_f64.ln());
        // An empty input sums to zero mass.
        let nothing = std::iter::empty::<f64>();
        assert_close(log_sum(nothing), LOG_0);
    }
}