scirs2_stats/advi/
transforms.rs1use crate::error::{StatsError, StatsResult};
9
10use super::types::ConstraintType;
11
/// Maps a strictly positive value onto the real line with the natural log.
///
/// This is the forward transform used for a positive constraint:
/// `η = ln(x)`. No validation is performed here — `x == 0` yields `-inf`
/// and a negative `x` yields `NaN`, exactly as [`f64::ln`] does.
#[inline]
pub fn log_transform(x: f64) -> f64 {
    f64::ln(x)
}
23
/// Maps a value from the open interval `(lo, hi)` onto the real line.
///
/// Forward transform for a bounded constraint. With `s = (x - lo) / (hi - lo)`
/// this returns `η = logit(s) = ln(s / (1 - s))`.
///
/// The value is evaluated as `ln((x - lo) / (hi - x))`, which is
/// algebraically identical (the `hi - lo` factor cancels) but avoids forming
/// `1 - s`. That subtraction loses almost all significant digits when `x` is
/// close to `hi`, whereas `hi - x` is computed exactly there. This also keeps
/// the formula consistent with `log_jacobian_bounded`.
///
/// No validation is performed. `x == lo` gives `-inf`, `x == hi` gives `+inf`,
/// and `x` outside `[lo, hi]` gives `NaN`. Callers are expected to check the
/// bounds first.
#[inline]
pub fn logit_transform(x: f64, lo: f64, hi: f64) -> f64 {
    ((x - lo) / (hi - x)).ln()
}
33
/// Maps an unconstrained vector onto the probability simplex (softmax).
///
/// Each output element is `exp(x_i - m) / Σ_j exp(x_j - m)` with
/// `m = max_i x_i`. Shifting by the maximum makes every exponent `<= 0` for
/// finite inputs, so no individual term can overflow.
///
/// An empty slice yields an empty vector.
pub fn softmax_transform(x: &[f64]) -> Vec<f64> {
    if x.is_empty() {
        return Vec::new();
    }

    let peak = x.iter().copied().fold(f64::NEG_INFINITY, f64::max);

    let mut weights: Vec<f64> = x.iter().map(|&v| (v - peak).exp()).collect();
    let total: f64 = weights.iter().sum();
    for w in weights.iter_mut() {
        *w /= total;
    }
    weights
}
47
/// Log-absolute-Jacobian of the forward positive transform `η = ln(θ)`.
///
/// Because `dη/dθ = 1/θ`, this returns `ln|dη/dθ| = -ln(θ)`. The argument is
/// not validated; non-positive `x` propagates `-inf`/`NaN` from `ln`.
#[inline]
pub fn log_jacobian_positive(x: f64) -> f64 {
    let log_x = x.ln();
    -log_x
}
59
/// Log-absolute-Jacobian of the forward bounded transform
/// `η = ln((θ - lo) / (hi - θ))`.
///
/// Differentiating gives `dη/dθ = (hi - lo) / ((θ - lo)(hi - θ))`, so the
/// result is `ln(hi - lo) - ln(θ - lo) - ln(hi - θ)`. No bounds check is
/// performed on `x`.
#[inline]
pub fn log_jacobian_bounded(x: f64, lo: f64, hi: f64) -> f64 {
    let below = (x - lo).ln();
    let above = (hi - x).ln();
    -below - above + (hi - lo).ln()
}
70
71#[derive(Debug, Clone, PartialEq)]
80pub struct TransformSpec {
81 pub constraint: ConstraintType,
83}
84
85impl TransformSpec {
86 pub fn new(constraint: ConstraintType) -> Self {
88 Self { constraint }
89 }
90
91 pub fn unconstrained() -> Self {
93 Self::new(ConstraintType::Unconstrained)
94 }
95
96 pub fn positive() -> Self {
98 Self::new(ConstraintType::Positive)
99 }
100
101 pub fn bounded(lo: f64, hi: f64) -> Self {
103 Self::new(ConstraintType::Bounded { lo, hi })
104 }
105
106 pub fn to_unconstrained(&self, theta: f64) -> StatsResult<f64> {
110 match &self.constraint {
111 ConstraintType::Unconstrained => Ok(theta),
112 ConstraintType::Positive => {
113 if theta <= 0.0 {
114 return Err(StatsError::invalid_argument(format!(
115 "Positive constraint violated: θ = {} must be > 0",
116 theta
117 )));
118 }
119 Ok(log_transform(theta))
120 }
121 ConstraintType::Bounded { lo, hi } => {
122 if theta <= *lo || theta >= *hi {
123 return Err(StatsError::invalid_argument(format!(
124 "Bounded constraint violated: θ = {} must lie in ({}, {})",
125 theta, lo, hi
126 )));
127 }
128 Ok(logit_transform(theta, *lo, *hi))
129 }
130 ConstraintType::Simplex => {
131 Ok(theta)
134 }
135 }
136 }
137
138 pub fn to_constrained(&self, eta: f64) -> f64 {
140 match &self.constraint {
141 ConstraintType::Unconstrained => eta,
142 ConstraintType::Positive => eta.exp(),
143 ConstraintType::Bounded { lo, hi } => {
144 let s = sigmoid(eta);
145 lo + (hi - lo) * s
146 }
147 ConstraintType::Simplex => eta,
148 }
149 }
150
151 pub fn log_jacobian_inverse(&self, eta: f64) -> f64 {
154 match &self.constraint {
155 ConstraintType::Unconstrained => 0.0,
156 ConstraintType::Positive => {
157 eta
159 }
160 ConstraintType::Bounded { lo, hi } => {
161 let range = hi - lo;
164 let s = sigmoid(eta);
165 range.ln() + s.ln() + (1.0 - s).ln()
166 }
167 ConstraintType::Simplex => 0.0,
168 }
169 }
170}
171
/// Numerically stable logistic function `σ(x) = 1 / (1 + e^{-x})`.
///
/// Only `e^{-|x|}` is evaluated, which lies in `[0, 1]`, so the exponential
/// never overflows for large-magnitude inputs of either sign.
#[inline]
pub(crate) fn sigmoid(x: f64) -> f64 {
    let decay = (-x.abs()).exp();
    if x >= 0.0 {
        1.0 / (1.0 + decay)
    } else {
        decay / (1.0 + decay)
    }
}
182
#[cfg(test)]
mod tests {
    use super::*;

    /// Default absolute tolerance for floating-point comparisons.
    const EPS: f64 = 1e-10;

    #[test]
    fn test_log_transform_roundtrip() {
        for x in [0.001, 0.1, 1.0, 10.0, 1000.0] {
            let eta = log_transform(x);
            let recovered = eta.exp();
            assert!(
                (recovered - x).abs() < EPS * x.max(1.0),
                "Roundtrip failed for x={}: got {}",
                x,
                recovered
            );
        }
    }

    #[test]
    fn test_logit_transform_range() {
        let lo = -2.0;
        let hi = 5.0;
        // Interior points must map to finite values.
        for x in [-1.5, 0.0, 1.0, 3.0, 4.5] {
            let eta = logit_transform(x, lo, hi);
            assert!(
                eta.is_finite(),
                "logit_transform({}, {}, {}) = {} is not finite",
                x,
                lo,
                hi,
                eta
            );
        }
        // Points hugging either bound must map far out along the real line.
        let near_lo = logit_transform(lo + 1e-10, lo, hi);
        let near_hi = logit_transform(hi - 1e-10, lo, hi);
        assert!(near_lo < -20.0, "Near lo should give large negative value");
        assert!(near_hi > 20.0, "Near hi should give large positive value");
    }

    #[test]
    fn test_softmax_sums_one() {
        let x = [1.0, 2.0, 3.0, -1.0, 0.5];
        let p = softmax_transform(&x);
        let sum: f64 = p.iter().sum();
        assert!((sum - 1.0).abs() < 1e-12, "Softmax sum = {} ≠ 1", sum);
        for &pi in &p {
            assert!((0.0..=1.0).contains(&pi), "Probability {} out of [0,1]", pi);
        }
    }

    #[test]
    fn test_softmax_empty() {
        let p = softmax_transform(&[]);
        assert!(p.is_empty());
    }

    #[test]
    fn test_softmax_single() {
        let p = softmax_transform(&[3.7]);
        assert!((p[0] - 1.0).abs() < 1e-12);
    }

    #[test]
    fn test_log_jacobian_positive() {
        for theta in [0.1_f64, 1.0, 5.0] {
            let jac = log_jacobian_positive(theta);
            assert!((jac - (-theta.ln())).abs() < EPS);
        }
    }

    #[test]
    fn test_log_jacobian_bounded() {
        let lo = 0.0_f64;
        let hi = 1.0_f64;
        let theta = 0.3_f64;
        let jac = log_jacobian_bounded(theta, lo, hi);
        let expected = -(theta - lo).ln() - (hi - theta).ln() + (hi - lo).ln();
        assert!((jac - expected).abs() < EPS);
    }

    #[test]
    fn test_transform_spec_unconstrained_roundtrip() {
        let spec = TransformSpec::unconstrained();
        for val in [-3.0, 0.0, 7.0] {
            let eta = spec.to_unconstrained(val).expect("unconstrained ok");
            let theta = spec.to_constrained(eta);
            assert!((theta - val).abs() < EPS);
        }
    }

    #[test]
    fn test_transform_spec_positive_roundtrip() {
        let spec = TransformSpec::positive();
        for val in [0.01, 1.0, 100.0] {
            let eta = spec.to_unconstrained(val).expect("positive ok");
            let theta = spec.to_constrained(eta);
            assert!(
                (theta - val).abs() < EPS * val,
                "Roundtrip failed: {val} -> {eta} -> {theta}"
            );
        }
    }

    #[test]
    fn test_transform_spec_positive_error() {
        let spec = TransformSpec::positive();
        assert!(spec.to_unconstrained(0.0).is_err());
        assert!(spec.to_unconstrained(-1.0).is_err());
    }

    #[test]
    fn test_transform_spec_bounded_roundtrip() {
        let spec = TransformSpec::bounded(2.0, 8.0);
        for val in [2.5, 5.0, 7.9] {
            let eta = spec.to_unconstrained(val).expect("bounded ok");
            let theta = spec.to_constrained(eta);
            assert!(
                (theta - val).abs() < 1e-8,
                "Roundtrip failed: {val} -> {eta} -> {theta}"
            );
        }
    }

    #[test]
    fn test_transform_spec_bounded_error() {
        let spec = TransformSpec::bounded(0.0, 1.0);
        // Both endpoints are excluded, as is anything outside the interval.
        assert!(spec.to_unconstrained(0.0).is_err());
        assert!(spec.to_unconstrained(1.0).is_err());
        assert!(spec.to_unconstrained(-0.5).is_err());
    }

    #[test]
    fn test_log_jacobian_inverse_identity() {
        let spec = TransformSpec::unconstrained();
        // `2.5` rather than `3.14`: clippy's deny-by-default `approx_constant`
        // lint rejects literals that approximate π.
        for eta in [-2.5, 0.0, 2.5] {
            assert!(spec.log_jacobian_inverse(eta).abs() < EPS);
        }
    }

    #[test]
    fn test_log_jacobian_inverse_positive() {
        let spec = TransformSpec::positive();
        for eta in [-2.0, 0.0, 1.5] {
            let jac = spec.log_jacobian_inverse(eta);
            assert!(
                (jac - eta).abs() < EPS,
                "log_jacobian_inverse({eta}) = {jac} ≠ {eta}"
            );
        }
    }

    #[test]
    fn test_log_jacobian_inverse_bounded() {
        let spec = TransformSpec::bounded(0.0, 1.0);
        let eta = 0.0;
        let jac = spec.log_jacobian_inverse(eta);
        // At η = 0, σ(η) = 0.5, so the expected value is ln(1) + 2·ln(0.5).
        let expected = (1.0_f64).ln() + 0.5_f64.ln() + 0.5_f64.ln();
        assert!((jac - expected).abs() < EPS);
    }
}