Skip to main content

scirs2_stats/advi/
transforms.rs

1//! Bijective parameter transforms for ADVI.
2//!
3//! Maps constrained parameters to unconstrained real space so that
4//! the variational Gaussian is defined over the entire real line.
5//! Each transform comes with its inverse and log-absolute-Jacobian
6//! (change-of-variables adjustment to the ELBO).
7
8use crate::error::{StatsError, StatsResult};
9
10use super::types::ConstraintType;
11
12// ============================================================================
13// Standalone transform functions (convenient API)
14// ============================================================================
15
/// Send a strictly positive parameter θ to the whole real line with the
/// natural logarithm: η = ln θ.
///
/// The inverse map is θ = e^η. The input is not validated here; values that
/// are not positive follow `f64::ln` (−∞ at zero, NaN below zero).
#[inline]
pub fn log_transform(x: f64) -> f64 {
    f64::ln(x)
}
23
/// Map a bounded parameter θ ∈ (lo, hi) to the unconstrained real line
/// via the scaled logit: η = logit((θ − lo) / (hi − lo)).
///
/// Inverse: θ = lo + (hi − lo) · sigmoid(η).
///
/// Evaluated as η = ln((θ − lo) / (hi − θ)), which is algebraically the same.
/// This form never builds `1 − s` for the normalised value
/// `s = (θ − lo)/(hi − lo)`. That subtraction cancels catastrophically when θ
/// is close to `hi` and destroys most of the significant digits of η. Here
/// `hi − θ` is instead formed directly from the inputs.
///
/// No validation is performed: θ = lo gives −∞, θ = hi gives +∞, and θ
/// outside [lo, hi] gives NaN.
#[inline]
pub fn logit_transform(x: f64, lo: f64, hi: f64) -> f64 {
    ((x - lo) / (hi - x)).ln()
}
33
/// Project a real vector onto the probability simplex with the softmax map
/// p_i = exp(x_i) / Σ_j exp(x_j).
///
/// For numerical stability the largest entry is subtracted from every
/// component before exponentiating. Softmax is invariant to such a common
/// shift, and it guarantees no term exceeds exp(0) = 1. An empty slice yields
/// an empty vector.
pub fn softmax_transform(x: &[f64]) -> Vec<f64> {
    // For an empty slice `peak` stays -inf and nothing is produced below.
    let peak = x.iter().fold(f64::NEG_INFINITY, |m, &v| m.max(v));
    let mut probs: Vec<f64> = x.iter().map(|&v| (v - peak).exp()).collect();
    let norm: f64 = probs.iter().sum();
    for p in probs.iter_mut() {
        *p /= norm;
    }
    probs
}
47
/// Log-absolute Jacobian of the *forward* log transform η = ln θ, evaluated
/// at the constrained value θ > 0.
///
/// dη/dθ = 1/θ  ⟹  log|dη/dθ| = −ln θ, which is what this returns.
///
/// Sign convention: this is the constrained → unconstrained direction. The
/// term that ADVI adds to log p(θ) is the Jacobian of the *inverse* map,
/// log|dθ/dη| = η = ln θ (see [`TransformSpec::log_jacobian_inverse`]). That
/// term is the negation of this value.
///
/// No validation: θ = 0 gives +∞ and θ < 0 gives NaN.
#[inline]
pub fn log_jacobian_positive(x: f64) -> f64 {
    -x.ln()
}
59
/// Log-absolute Jacobian of the *forward* bounded map
/// η = logit((θ − lo)/(hi − lo)), evaluated at the constrained value θ.
///
/// Since dη/dθ = (hi − lo) / [(θ − lo)(hi − θ)],
///
/// log|dη/dθ| = ln(hi − lo) − ln(θ − lo) − ln(hi − θ).
///
/// This is the constrained → unconstrained direction. The inverse-map term
/// log|dθ/dη| used in the ADVI objective is its negation.
#[inline]
pub fn log_jacobian_bounded(x: f64, lo: f64, hi: f64) -> f64 {
    let log_gap_below = (x - lo).ln();
    let log_gap_above = (hi - x).ln();
    let log_width = (hi - lo).ln();
    -log_gap_below - log_gap_above + log_width
}
70
71// ============================================================================
72// TransformSpec — per-parameter specification
73// ============================================================================
74
75/// Per-parameter transform specification: constraint type + bijective map.
76///
77/// A `TransformSpec` pairs a `ConstraintType` with the forward/inverse
78/// transform functions and log-Jacobian for use inside ADVI.
79#[derive(Debug, Clone, PartialEq)]
80pub struct TransformSpec {
81    /// The constraint type for this parameter
82    pub constraint: ConstraintType,
83}
84
85impl TransformSpec {
86    /// Create a `TransformSpec` from a `ConstraintType`.
87    pub fn new(constraint: ConstraintType) -> Self {
88        Self { constraint }
89    }
90
91    /// Create an unconstrained (identity) spec.
92    pub fn unconstrained() -> Self {
93        Self::new(ConstraintType::Unconstrained)
94    }
95
96    /// Create a positive-valued spec (log transform).
97    pub fn positive() -> Self {
98        Self::new(ConstraintType::Positive)
99    }
100
101    /// Create a bounded spec for θ ∈ (lo, hi).
102    pub fn bounded(lo: f64, hi: f64) -> Self {
103        Self::new(ConstraintType::Bounded { lo, hi })
104    }
105
106    /// Forward transform: constrained θ → unconstrained η.
107    ///
108    /// Returns an error if the value violates the constraint.
109    pub fn to_unconstrained(&self, theta: f64) -> StatsResult<f64> {
110        match &self.constraint {
111            ConstraintType::Unconstrained => Ok(theta),
112            ConstraintType::Positive => {
113                if theta <= 0.0 {
114                    return Err(StatsError::invalid_argument(format!(
115                        "Positive constraint violated: θ = {} must be > 0",
116                        theta
117                    )));
118                }
119                Ok(log_transform(theta))
120            }
121            ConstraintType::Bounded { lo, hi } => {
122                if theta <= *lo || theta >= *hi {
123                    return Err(StatsError::invalid_argument(format!(
124                        "Bounded constraint violated: θ = {} must lie in ({}, {})",
125                        theta, lo, hi
126                    )));
127                }
128                Ok(logit_transform(theta, *lo, *hi))
129            }
130            ConstraintType::Simplex => {
131                // For simplex, we use the additive log-ratio (ALR) of the last
132                // element as reference; but for a single scalar we just return identity.
133                Ok(theta)
134            }
135        }
136    }
137
138    /// Inverse transform: unconstrained η → constrained θ.
139    pub fn to_constrained(&self, eta: f64) -> f64 {
140        match &self.constraint {
141            ConstraintType::Unconstrained => eta,
142            ConstraintType::Positive => eta.exp(),
143            ConstraintType::Bounded { lo, hi } => {
144                let s = sigmoid(eta);
145                lo + (hi - lo) * s
146            }
147            ConstraintType::Simplex => eta,
148        }
149    }
150
151    /// Log-absolute-Jacobian of the inverse transform (η → θ),
152    /// i.e., log|dθ/dη|.  Added to log p(θ) when computing the ELBO.
153    pub fn log_jacobian_inverse(&self, eta: f64) -> f64 {
154        match &self.constraint {
155            ConstraintType::Unconstrained => 0.0,
156            ConstraintType::Positive => {
157                // θ = exp(η) ⟹ dθ/dη = exp(η) ⟹ log|J| = η
158                eta
159            }
160            ConstraintType::Bounded { lo, hi } => {
161                // θ = lo + (hi - lo) σ(η)
162                // dθ/dη = (hi - lo) σ(η)(1 − σ(η))
163                let range = hi - lo;
164                let s = sigmoid(eta);
165                range.ln() + s.ln() + (1.0 - s).ln()
166            }
167            ConstraintType::Simplex => 0.0,
168        }
169    }
170}
171
/// Logistic sigmoid σ(x) = 1 / (1 + e^(−x)), evaluated without overflow.
///
/// The exponential is only ever taken of a non-positive argument. For x < 0
/// the equivalent form e^x / (1 + e^x) is used; otherwise e^(−x) ≤ 1.
#[inline]
pub(crate) fn sigmoid(x: f64) -> f64 {
    if x < 0.0 {
        let e = x.exp();
        return e / (1.0 + e);
    }
    (1.0 + (-x).exp()).recip()
}
182
183// ============================================================================
184// Tests
185// ============================================================================
186
#[cfg(test)]
mod tests {
    use super::*;

    const EPS: f64 = 1e-10;

    #[test]
    fn test_log_transform_roundtrip() {
        for x in [0.001, 0.1, 1.0, 10.0, 1000.0] {
            let eta = log_transform(x);
            let recovered = eta.exp();
            assert!(
                (recovered - x).abs() < EPS * x.max(1.0),
                "Roundtrip failed for x={}: got {}",
                x,
                recovered
            );
        }
    }

    #[test]
    fn test_logit_transform_range() {
        let lo = -2.0;
        let hi = 5.0;
        // For x in (lo, hi), output should be finite real number
        for x in [-1.5, 0.0, 1.0, 3.0, 4.5] {
            let eta = logit_transform(x, lo, hi);
            assert!(
                eta.is_finite(),
                "logit_transform({}, {}, {}) = {} is not finite",
                x,
                lo,
                hi,
                eta
            );
        }
        // Edge: x approaching lo should give -∞, x approaching hi should give +∞
        let near_lo = logit_transform(lo + 1e-10, lo, hi);
        let near_hi = logit_transform(hi - 1e-10, lo, hi);
        assert!(near_lo < -20.0, "Near lo should give large negative value");
        assert!(near_hi > 20.0, "Near hi should give large positive value");
    }

    #[test]
    fn test_softmax_sums_one() {
        let x = [1.0, 2.0, 3.0, -1.0, 0.5];
        let p = softmax_transform(&x);
        let sum: f64 = p.iter().sum();
        assert!((sum - 1.0).abs() < 1e-12, "Softmax sum = {} ≠ 1", sum);
        for &pi in &p {
            assert!((0.0..=1.0).contains(&pi), "Probability {} out of [0,1]", pi);
        }
    }

    #[test]
    fn test_softmax_empty() {
        let p = softmax_transform(&[]);
        assert!(p.is_empty());
    }

    #[test]
    fn test_softmax_single() {
        let p = softmax_transform(&[3.7]);
        assert!((p[0] - 1.0).abs() < 1e-12);
    }

    #[test]
    fn test_softmax_shift_invariant_and_stable() {
        // A common offset must not change the result, and large inputs must
        // not overflow thanks to the max-subtraction.
        let base = softmax_transform(&[0.0, 1.0, 2.0]);
        let shifted = softmax_transform(&[1000.0, 1001.0, 1002.0]);
        assert_eq!(base.len(), shifted.len());
        for (a, b) in base.iter().zip(&shifted) {
            assert!((a - b).abs() < 1e-12, "{a} vs {b}");
        }
    }

    #[test]
    fn test_sigmoid_symmetry_and_saturation() {
        assert_eq!(sigmoid(0.0), 0.5);
        for x in [0.5, 3.0, 20.0] {
            assert!((sigmoid(x) + sigmoid(-x) - 1.0).abs() < 1e-15);
        }
        assert_eq!(sigmoid(1000.0), 1.0);
        assert_eq!(sigmoid(-1000.0), 0.0);
    }

    #[test]
    fn test_log_jacobian_positive() {
        // log|dη/dθ| = -log(θ) at various θ > 0
        for theta in [0.1, 1.0, 5.0] {
            let jac = log_jacobian_positive(theta);
            assert!((jac - (-theta.ln())).abs() < EPS);
        }
    }

    #[test]
    fn test_log_jacobian_bounded() {
        let lo = 0.0;
        let hi = 1.0;
        let theta = 0.3;
        let jac = log_jacobian_bounded(theta, lo, hi);
        // Expected: -ln(θ - lo) - ln(hi - θ) + ln(hi - lo)
        let expected = -(theta - lo).ln() - (hi - theta).ln() + (hi - lo).ln();
        assert!((jac - expected).abs() < EPS);
    }

    #[test]
    fn test_forward_and_inverse_jacobians_are_negatives() {
        // The forward Jacobian log|dη/dθ| must equal −log|dθ/dη| evaluated at
        // the matching η.
        let pos = TransformSpec::positive();
        for theta in [0.1_f64, 1.0, 5.0] {
            let eta = pos.to_unconstrained(theta).expect("positive ok");
            let sum = log_jacobian_positive(theta) + pos.log_jacobian_inverse(eta);
            assert!(sum.abs() < EPS, "positive: θ={theta}, sum={sum}");
        }
        let (lo, hi) = (2.0, 8.0);
        let bnd = TransformSpec::bounded(lo, hi);
        for theta in [2.5_f64, 5.0, 7.9] {
            let eta = bnd.to_unconstrained(theta).expect("bounded ok");
            let sum = log_jacobian_bounded(theta, lo, hi) + bnd.log_jacobian_inverse(eta);
            assert!(sum.abs() < 1e-8, "bounded: θ={theta}, sum={sum}");
        }
    }

    #[test]
    fn test_transform_spec_unconstrained_roundtrip() {
        let spec = TransformSpec::unconstrained();
        for val in [-3.0, 0.0, 7.0] {
            let eta = spec.to_unconstrained(val).expect("unconstrained ok");
            let theta = spec.to_constrained(eta);
            assert!((theta - val).abs() < EPS);
        }
    }

    #[test]
    fn test_transform_spec_positive_roundtrip() {
        let spec = TransformSpec::positive();
        for val in [0.01, 1.0, 100.0] {
            let eta = spec.to_unconstrained(val).expect("positive ok");
            let theta = spec.to_constrained(eta);
            assert!(
                (theta - val).abs() < EPS * val,
                "Roundtrip failed: {val} -> {eta} -> {theta}"
            );
        }
    }

    #[test]
    fn test_transform_spec_positive_error() {
        let spec = TransformSpec::positive();
        assert!(spec.to_unconstrained(0.0).is_err());
        assert!(spec.to_unconstrained(-1.0).is_err());
    }

    #[test]
    fn test_transform_spec_bounded_roundtrip() {
        let spec = TransformSpec::bounded(2.0, 8.0);
        for val in [2.5, 5.0, 7.9] {
            let eta = spec.to_unconstrained(val).expect("bounded ok");
            let theta = spec.to_constrained(eta);
            assert!(
                (theta - val).abs() < 1e-8,
                "Roundtrip failed: {val} -> {eta} -> {theta}"
            );
        }
    }

    #[test]
    fn test_transform_spec_bounded_error() {
        let spec = TransformSpec::bounded(0.0, 1.0);
        assert!(spec.to_unconstrained(0.0).is_err()); // boundary (excluded)
        assert!(spec.to_unconstrained(1.0).is_err()); // boundary (excluded)
        assert!(spec.to_unconstrained(-0.5).is_err()); // outside
    }

    #[test]
    fn test_log_jacobian_inverse_identity() {
        let spec = TransformSpec::unconstrained();
        // Arbitrary non-special value (3.14 trips clippy::approx_constant).
        assert!(spec.log_jacobian_inverse(2.5).abs() < EPS);
    }

    #[test]
    fn test_log_jacobian_inverse_positive() {
        let spec = TransformSpec::positive();
        for eta in [-2.0, 0.0, 1.5] {
            let jac = spec.log_jacobian_inverse(eta);
            // log|dθ/dη| = η  (since θ = exp(η))
            assert!(
                (jac - eta).abs() < EPS,
                "log_jacobian_inverse({eta}) = {jac} ≠ {eta}"
            );
        }
    }

    #[test]
    fn test_log_jacobian_inverse_bounded() {
        let spec = TransformSpec::bounded(0.0, 1.0);
        let eta = 0.0; // sigmoid(0) = 0.5
        let jac = spec.log_jacobian_inverse(eta);
        // dθ/dη = (hi-lo) σ(1-σ) = 1 · 0.25
        let expected = (1.0_f64).ln() + 0.5_f64.ln() + 0.5_f64.ln();
        assert!((jac - expected).abs() < EPS);
    }
}