1use std::fmt;
12
13use crate::stats::population_variance;
14
/// Errors reported by the Box-Cox transformation routines in this module.
///
/// The enum carries no payload, so it derives `Eq` in addition to
/// `PartialEq`; this lets values be compared exactly and used in contexts
/// that require full equivalence.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TransformError {
    /// The input contained a value that is not strictly positive.
    NonPositiveData,
    /// Fewer than two data points were supplied. `estimate_lambda` also
    /// returns this variant when the search interval is not a valid
    /// `lambda_min < lambda_max` range.
    InsufficientData,
    /// The inverse transformation could not produce a finite value.
    InvalidInverse,
}
27
28impl fmt::Display for TransformError {
29 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
30 match self {
31 TransformError::NonPositiveData => {
32 write!(f, "Box-Cox requires all y > 0")
33 }
34 TransformError::InsufficientData => {
35 write!(f, "need at least 2 data points")
36 }
37 TransformError::InvalidInverse => {
38 write!(f, "inverse transformation produced non-finite values")
39 }
40 }
41 }
42}
43
// Empty impl: no `Error` methods are overridden. It lets `TransformError` be
// used wherever a `std::error::Error` is expected (e.g. `Box<dyn Error>`).
impl std::error::Error for TransformError {}
45
46fn validate_positive_slice(y: &[f64]) -> Result<(), TransformError> {
49 if y.len() < 2 {
50 return Err(TransformError::InsufficientData);
51 }
52 if y.iter().any(|&v| v <= 0.0) {
53 return Err(TransformError::NonPositiveData);
54 }
55 Ok(())
56}
57
58pub fn box_cox(y: &[f64], lambda: f64) -> Result<Vec<f64>, TransformError> {
83 validate_positive_slice(y)?;
84 let result = if lambda.abs() < 1e-10 {
85 y.iter().map(|&v| v.ln()).collect()
86 } else {
87 y.iter().map(|&v| (v.powf(lambda) - 1.0) / lambda).collect()
88 };
89 Ok(result)
90}
91
92pub fn inverse_box_cox(y_t: &[f64], lambda: f64) -> Result<Vec<f64>, TransformError> {
119 let result: Vec<f64> = if lambda.abs() < 1e-10 {
120 y_t.iter().map(|&v| v.exp()).collect()
121 } else {
122 y_t.iter()
123 .map(|&v| (v * lambda + 1.0).powf(1.0 / lambda))
124 .collect()
125 };
126 if result.iter().any(|v| !v.is_finite()) {
127 return Err(TransformError::InvalidInverse);
128 }
129 Ok(result)
130}
131
132pub fn estimate_lambda(y: &[f64], lambda_min: f64, lambda_max: f64) -> Result<f64, TransformError> {
157 if lambda_min >= lambda_max {
158 return Err(TransformError::InsufficientData);
159 }
160 validate_positive_slice(y)?;
161
162 let n = y.len() as f64;
163 let log_sum: f64 = y.iter().map(|&v| v.ln()).sum::<f64>();
164
165 let profile_ll = |lambda: f64| -> f64 {
167 let y_t: Vec<f64> = if lambda.abs() < 1e-10 {
168 y.iter().map(|&v| v.ln()).collect()
169 } else {
170 y.iter().map(|&v| (v.powf(lambda) - 1.0) / lambda).collect()
171 };
172 let var = population_variance(&y_t)
173 .expect("slice has >= 2 elements — variance is defined");
174 if var <= 0.0 {
175 return f64::NEG_INFINITY;
176 }
177 -(n / 2.0) * var.ln() + (lambda - 1.0) * log_sum
178 };
179
180 const PHI: f64 = 0.618_033_988_749_895; let mut a = lambda_min;
183 let mut b = lambda_max;
184
185 let mut x1 = b - PHI * (b - a);
186 let mut x2 = a + PHI * (b - a);
187 let mut f1 = profile_ll(x1);
188 let mut f2 = profile_ll(x2);
189
190 for _ in 0..100 {
191 if (b - a).abs() < 1e-6 {
192 break;
193 }
194 if f1 < f2 {
195 a = x1;
196 x1 = x2;
197 f1 = f2;
198 x2 = a + PHI * (b - a);
199 f2 = profile_ll(x2);
200 } else {
201 b = x2;
202 x2 = x1;
203 f2 = f1;
204 x1 = b - PHI * (b - a);
205 f1 = profile_ll(x1);
206 }
207 }
208
209 Ok((a + b) / 2.0)
210}
211
// Unit tests for the Box-Cox helpers: forward transform at specific lambdas,
// forward/inverse round-trips, lambda estimation, and the error paths.
#[cfg(test)]
mod tests {
    use super::*;

    // lambda = 0 selects the natural-log branch: ln(1) = 0, ln(e) = 1, ln(e^2) = 2.
    #[test]
    fn box_cox_log_transform() {
        let y = vec![1.0, std::f64::consts::E, std::f64::consts::E.powi(2)];
        let y_t = box_cox(&y, 0.0).unwrap();
        assert!((y_t[0] - 0.0).abs() < 1e-10);
        assert!((y_t[1] - 1.0).abs() < 1e-9);
        assert!((y_t[2] - 2.0).abs() < 1e-9);
    }

    // lambda = 1 reduces to a shift: (y^1 - 1) / 1 = y - 1.
    #[test]
    fn box_cox_identity_lambda_1() {
        let y = vec![2.0, 5.0, 10.0];
        let y_t = box_cox(&y, 1.0).unwrap();
        assert!((y_t[0] - 1.0).abs() < 1e-10);
        assert!((y_t[1] - 4.0).abs() < 1e-10);
    }

    // lambda = 0.5: (sqrt(4) - 1) / 0.5 = 2 and (sqrt(9) - 1) / 0.5 = 4.
    #[test]
    fn box_cox_sqrt_lambda_half() {
        let y = vec![4.0, 9.0];
        let y_t = box_cox(&y, 0.5).unwrap();
        assert!((y_t[0] - 2.0).abs() < 1e-10);
        assert!((y_t[1] - 4.0).abs() < 1e-10);
    }

    // The inverse must undo the forward transform for negative, zero and
    // positive lambdas alike.
    #[test]
    fn inverse_roundtrip_multiple_lambdas() {
        let y = vec![1.5, 2.3, 4.7, 8.1, 15.2];
        for &lambda in &[-0.5_f64, 0.0, 0.5, 1.0, 2.0] {
            let y_t = box_cox(&y, lambda).unwrap();
            let y_rec = inverse_box_cox(&y_t, lambda).unwrap();
            for (orig, rec) in y.iter().zip(y_rec.iter()) {
                assert!(
                    (orig - rec).abs() < 1e-9,
                    "lambda={lambda} orig={orig} rec={rec}"
                );
            }
        }
    }

    // Exponentially growing data: the estimate is expected to land near 0.
    #[test]
    fn estimate_lambda_near_zero_for_exponential() {
        let y: Vec<f64> = (1..=30).map(|i| (i as f64 * 0.2).exp()).collect();
        let lambda = estimate_lambda(&y, -2.0, 2.0).unwrap();
        assert!(lambda.abs() < 0.3, "Expected lambda near 0, got {lambda}");
    }

    // Quadratically growing data: the estimate is expected to land near 0.5.
    #[test]
    fn estimate_lambda_near_half_for_quadratic() {
        let y: Vec<f64> = (1..=20).map(|i| (i as f64).powi(2)).collect();
        let lambda = estimate_lambda(&y, -2.0, 2.0).unwrap();
        assert!(
            lambda > 0.2 && lambda < 0.8,
            "Expected lambda ~0.5, got {lambda}"
        );
    }

    // Negative and zero observations are both rejected.
    #[test]
    fn non_positive_returns_error() {
        assert!(box_cox(&[1.0, -1.0, 2.0], 0.5).is_err());
        assert!(box_cox(&[0.0, 1.0, 2.0], 0.5).is_err());
    }

    // A single observation is not enough for either entry point.
    #[test]
    fn insufficient_data_returns_error() {
        assert!(box_cox(&[1.0], 0.5).is_err());
        assert!(estimate_lambda(&[1.0], -2.0, 2.0).is_err());
    }

    // With lambda = 2 the bases are 2 * -1 + 1 = -1 and 2 * -0.8 + 1 = -0.6;
    // raising a negative base to the power 1/2 has no real result.
    #[test]
    fn inverse_invalid_returns_error() {
        let y_t = vec![-1.0, -0.8];
        assert!(inverse_box_cox(&y_t, 2.0).is_err());
    }

    // Both a reversed interval and an empty (min == max) interval are errors.
    #[test]
    fn estimate_lambda_invalid_range() {
        let y = vec![1.0, 2.0, 3.0, 4.0];
        assert!(estimate_lambda(&y, 1.0, 0.0).is_err());
        assert!(estimate_lambda(&y, 0.5, 0.5).is_err());
    }

    // lambda = -1: (1/2 - 1) / -1 = 0.5 and (1/4 - 1) / -1 = 0.75.
    #[test]
    fn box_cox_negative_lambda() {
        let y = vec![2.0, 4.0];
        let y_t = box_cox(&y, -1.0).unwrap();
        assert!((y_t[0] - 0.5).abs() < 1e-10);
        assert!((y_t[1] - 0.75).abs() < 1e-10);
    }
}