treeboost/preprocessing/
scaler.rs

//! Scaling transformations for numerical features
//!
//! Scalers put numerical features on a comparable scale:
//! - **StandardScaler**: Zero mean, unit variance (most common)
//! - **MinMaxScaler**: Scale to a fixed output range [a, b] (default [0, 1])
//! - **RobustScaler**: Center on the median and scale by the IQR (robust to outliers)
//!
//! # Incremental Support
//!
//! StandardScaler and MinMaxScaler also support incremental fitting through the
//! `IncrementalScaler` trait, which updates the learned statistics one batch at
//! a time:
//!
//! ```ignore
//! use treeboost::preprocessing::{StandardScaler, Scaler};
//! use treeboost::preprocessing::incremental::IncrementalScaler;
//!
//! let mut scaler = StandardScaler::new();
//! scaler.partial_fit(&batch1, num_features)?;
//! scaler.partial_fit(&batch2, num_features)?;
//! // scaler now has statistics from both batches
//! ```
//!
//! # Example
//!
//! ```rust
//! use treeboost::preprocessing::{StandardScaler, Scaler};
//!
//! let mut data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]; // 2 rows × 3 features
//! let num_features = 3;
//!
//! let mut scaler = StandardScaler::new();
//! scaler.fit(&data, num_features)?;
//! scaler.transform(&mut data, num_features)?;
//!
//! // data is now standardized: (x - mean) / std
//! # Ok::<(), treeboost::TreeBoostError>(())
//! ```
37
38use crate::preprocessing::incremental::{IncrementalScaler, WelfordState};
39use crate::{Result, TreeBoostError};
40
41// =============================================================================
42// Scaler Trait
43// =============================================================================
44
45/// Trait for all scalers (fit-transform pattern)
46///
47/// All scalers must implement:
48/// - `fit()`: Learn parameters from training data
49/// - `transform()`: Apply learned parameters to data
50/// - Serialization for train/test consistency
51pub trait Scaler {
52    /// Fit scaler on training data (row-major: num_rows × num_features)
53    ///
54    /// # Arguments
55    /// - `data`: Row-major flat array (row0_feat0, row0_feat1, ..., row1_feat0, ...)
56    /// - `num_features`: Number of features per row
57    fn fit(&mut self, data: &[f32], num_features: usize) -> Result<()>;
58
59    /// Transform data in-place using fitted parameters
60    ///
61    /// # Arguments
62    /// - `data`: Row-major flat array to transform in-place
63    /// - `num_features`: Number of features per row (must match fit)
64    fn transform(&self, data: &mut [f32], num_features: usize) -> Result<()>;
65
66    /// Fit and transform in one step (convenience)
67    fn fit_transform(&mut self, data: &mut [f32], num_features: usize) -> Result<()> {
68        self.fit(data, num_features)?;
69        self.transform(data, num_features)?;
70        Ok(())
71    }
72
73    /// Check if scaler has been fitted
74    fn is_fitted(&self) -> bool;
75}
76
77// =============================================================================
78// StandardScaler
79// =============================================================================
80
81/// StandardScaler: (x - μ) / σ
82///
83/// Transforms features to have zero mean and unit variance.
84///
85/// # Why it helps GBDTs
86/// Even though trees are scale-invariant, scaling improves:
87/// - **Regularization fairness**: L1/L2 penalties applied uniformly
88/// - **Binning uniformity**: Quantiles distributed evenly
89/// - **Numerical stability**: Gradient/Hessian calculations
90/// - **Mixed ensembles**: Combining linear + tree models
91///
92/// # Example
93///
94/// ```rust
95/// use treeboost::preprocessing::{StandardScaler, Scaler};
96///
97/// let mut train = vec![1.0, 10.0, 2.0, 20.0, 3.0, 30.0]; // 3 rows × 2 features
98/// let mut test = vec![1.5, 15.0, 2.5, 25.0]; // 2 rows × 2 features
99///
100/// let mut scaler = StandardScaler::new();
101/// scaler.fit(&train, 2)?;
102/// scaler.transform(&mut train, 2)?;
103/// scaler.transform(&mut test, 2)?; // Use same mean/std from training
104/// # Ok::<(), treeboost::TreeBoostError>(())
105/// ```
106#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
107pub struct StandardScaler {
108    /// Mean of each feature (learned during fit)
109    pub means: Vec<f32>,
110    /// Standard deviation of each feature (learned during fit)
111    pub stds: Vec<f32>,
112    /// Whether fit() has been called
113    fitted: bool,
114    /// Welford state for incremental updates (one per feature)
115    #[serde(default)]
116    welford_states: Vec<WelfordState>,
117    /// Optional forget factor (alpha) for EMA-based rolling window updates (0.0 to 1.0)
118    ///
119    /// When set, statistics are updated using exponential moving average:
120    /// `new_stat = (1 - alpha) * old_stat + alpha * batch_stat`
121    ///
122    /// - `alpha=0.0`: Ignores new batches entirely (not useful)
123    /// - `alpha=0.1`: 10% blend from new batch per update (slow adaptation)
124    /// - `alpha=0.5`: Equal blend of old and new statistics each update
125    /// - `alpha=1.0`: Completely replace with new batch statistics
126    ///
127    /// **Decay behavior**: After N batches, the first batch's influence is `(1-alpha)^N`.
128    /// Example with alpha=0.1: after 10 batches, first batch retains ~35% influence.
129    ///
130    /// Use small values (0.05-0.2) for gradual adaptation to distribution drift.
131    #[serde(default)]
132    forget_factor: Option<f32>,
133}
134
135impl StandardScaler {
136    /// Create a new unfitted StandardScaler
137    pub fn new() -> Self {
138        Self {
139            means: Vec::new(),
140            stds: Vec::new(),
141            fitted: false,
142            welford_states: Vec::new(),
143            forget_factor: None,
144        }
145    }
146
147    /// Create a StandardScaler with EMA-based rolling window updates
148    ///
149    /// # Arguments
150    /// * `forget_factor` - Alpha value between 0.0 and 1.0 (clamped if out of range)
151    ///
152    /// # Example
153    /// ```ignore
154    /// // Create scaler with alpha=0.1 (10% blend from each new batch)
155    /// let mut scaler = StandardScaler::with_forget_factor(0.1);
156    /// scaler.partial_fit(&batch1, num_features)?;  // 100% batch1
157    /// scaler.partial_fit(&batch2, num_features)?;  // 90% batch1, 10% batch2
158    /// scaler.partial_fit(&batch3, num_features)?;  // 81% batch1, 9% batch2, 10% batch3
159    /// ```
160    pub fn with_forget_factor(forget_factor: f32) -> Self {
161        Self {
162            means: Vec::new(),
163            stds: Vec::new(),
164            fitted: false,
165            welford_states: Vec::new(),
166            forget_factor: Some(forget_factor.clamp(0.0, 1.0)),
167        }
168    }
169
170    /// Set the forget factor for EMA-based updates
171    ///
172    /// # Arguments
173    /// * `factor` - Value between 0.0 and 1.0, or None to disable EMA mode
174    pub fn set_forget_factor(&mut self, factor: Option<f32>) {
175        self.forget_factor = factor.map(|f| f.clamp(0.0, 1.0));
176    }
177
178    /// Get the current forget factor
179    pub fn forget_factor(&self) -> Option<f32> {
180        self.forget_factor
181    }
182
183    /// Get the means (only valid after fit)
184    pub fn means(&self) -> &[f32] {
185        &self.means
186    }
187
188    /// Get the standard deviations (only valid after fit)
189    pub fn stds(&self) -> &[f32] {
190        &self.stds
191    }
192
193    /// Sync means/stds from Welford states (internal helper)
194    fn sync_from_welford(&mut self) {
195        let num_features = self.welford_states.len();
196        self.means.resize(num_features, 0.0);
197        self.stds.resize(num_features, 1.0);
198
199        for (i, state) in self.welford_states.iter().enumerate() {
200            self.means[i] = state.mean as f32;
201            let std = state.std() as f32;
202            // Handle zero-variance features (constant column)
203            self.stds[i] = if std < 1e-8 { 1.0 } else { std };
204        }
205    }
206
207    /// Compute mean and variance for a batch (helper for EMA updates)
208    fn compute_batch_stats(data: &[f32], num_features: usize) -> Vec<(f64, f64)> {
209        let num_rows = data.len() / num_features;
210        let mut stats = vec![(0.0f64, 0.0f64); num_features];
211
212        if num_rows == 0 {
213            return stats;
214        }
215
216        // Compute means
217        for feat in 0..num_features {
218            let mut sum = 0.0f64;
219            for row in 0..num_rows {
220                sum += data[row * num_features + feat] as f64;
221            }
222            stats[feat].0 = sum / num_rows as f64;
223        }
224
225        // Compute variances
226        for feat in 0..num_features {
227            let mean = stats[feat].0;
228            let mut variance = 0.0f64;
229            for row in 0..num_rows {
230                let x = data[row * num_features + feat] as f64;
231                variance += (x - mean).powi(2);
232            }
233            stats[feat].1 = variance / num_rows as f64;
234        }
235
236        stats
237    }
238
239    /// EMA-based partial fit for rolling window updates
240    ///
241    /// Uses exponential moving average to decay old statistics:
242    /// new_mean = (1 - alpha) * old_mean + alpha * batch_mean
243    /// new_var = (1 - alpha) * old_var + alpha * batch_var
244    ///
245    /// This allows the scaler to adapt to distribution drift over time.
246    fn partial_fit_ema(&mut self, data: &[f32], num_features: usize, alpha: f32) -> Result<()> {
247        let num_rows = data.len() / num_features;
248        if num_rows == 0 {
249            return Ok(());
250        }
251
252        // Compute batch statistics
253        let batch_stats = Self::compute_batch_stats(data, num_features);
254
255        // First batch: just use batch stats directly
256        if self.means.is_empty() || !self.fitted {
257            self.means = vec![0.0; num_features];
258            self.stds = vec![1.0; num_features];
259            self.welford_states = vec![WelfordState::new(); num_features];
260
261            for feat in 0..num_features {
262                let (mean, var) = batch_stats[feat];
263                self.means[feat] = mean as f32;
264                let std = var.sqrt() as f32;
265                self.stds[feat] = if std < 1e-8 { 1.0 } else { std };
266
267                // Also initialize Welford state for sample counting
268                self.welford_states[feat].n = num_rows as u64;
269                self.welford_states[feat].mean = mean;
270                self.welford_states[feat].m2 = var * num_rows as f64;
271            }
272            self.fitted = true;
273            return Ok(());
274        }
275
276        // Check feature count consistency
277        if self.means.len() != num_features {
278            return Err(TreeBoostError::Data(format!(
279                "num_features mismatch: initialized with {}, partial_fit with {}",
280                self.means.len(),
281                num_features
282            )));
283        }
284
285        // EMA update: new = (1 - alpha) * old + alpha * batch
286        let alpha_64 = alpha as f64;
287        let decay = 1.0 - alpha_64;
288
289        for feat in 0..num_features {
290            let (batch_mean, batch_var) = batch_stats[feat];
291
292            // Update mean via EMA
293            let old_mean = self.means[feat] as f64;
294            let new_mean = decay * old_mean + alpha_64 * batch_mean;
295            self.means[feat] = new_mean as f32;
296
297            // Update variance via EMA
298            // Note: This is approximate for variance, but works well in practice
299            let old_var = (self.stds[feat] as f64).powi(2);
300            let new_var = decay * old_var + alpha_64 * batch_var;
301            let new_std = new_var.sqrt() as f32;
302            self.stds[feat] = if new_std < 1e-8 { 1.0 } else { new_std };
303
304            // Update sample count (approximate effective samples)
305            self.welford_states[feat].n += num_rows as u64;
306        }
307
308        Ok(())
309    }
310}
311
312impl Default for StandardScaler {
313    fn default() -> Self {
314        Self::new()
315    }
316}
317
318impl Scaler for StandardScaler {
319    fn fit(&mut self, data: &[f32], num_features: usize) -> Result<()> {
320        if num_features == 0 {
321            return Err(TreeBoostError::Data("num_features must be > 0".into()));
322        }
323
324        if !data.len().is_multiple_of(num_features) {
325            return Err(TreeBoostError::Data(format!(
326                "Data length {} not divisible by num_features {}",
327                data.len(),
328                num_features
329            )));
330        }
331
332        let num_rows = data.len() / num_features;
333
334        if num_rows == 0 {
335            return Err(TreeBoostError::Data("No rows to fit".into()));
336        }
337
338        self.means = vec![0.0; num_features];
339        self.stds = vec![0.0; num_features];
340
341        // Compute means
342        for feat in 0..num_features {
343            let mut sum = 0.0;
344            for row in 0..num_rows {
345                sum += data[row * num_features + feat];
346            }
347            self.means[feat] = sum / num_rows as f32;
348        }
349
350        // Compute standard deviations
351        for feat in 0..num_features {
352            let mean = self.means[feat];
353            let mut variance = 0.0;
354            for row in 0..num_rows {
355                let x = data[row * num_features + feat];
356                variance += (x - mean).powi(2);
357            }
358            let std = (variance / num_rows as f32).sqrt();
359
360            // Handle zero-variance features (constant column)
361            self.stds[feat] = if std < 1e-8 { 1.0 } else { std };
362        }
363
364        self.fitted = true;
365        Ok(())
366    }
367
368    fn transform(&self, data: &mut [f32], num_features: usize) -> Result<()> {
369        if !self.fitted {
370            return Err(TreeBoostError::Data(
371                "StandardScaler not fitted. Call fit() first.".into(),
372            ));
373        }
374
375        if num_features != self.means.len() {
376            return Err(TreeBoostError::Data(format!(
377                "num_features mismatch: fit with {}, transform with {}",
378                self.means.len(),
379                num_features
380            )));
381        }
382
383        if !data.len().is_multiple_of(num_features) {
384            return Err(TreeBoostError::Data(format!(
385                "Data length {} not divisible by num_features {}",
386                data.len(),
387                num_features
388            )));
389        }
390
391        let num_rows = data.len() / num_features;
392
393        // Apply standardization: (x - mean) / std
394        for feat in 0..num_features {
395            let mean = self.means[feat];
396            let std = self.stds[feat];
397            for row in 0..num_rows {
398                let idx = row * num_features + feat;
399                data[idx] = (data[idx] - mean) / std;
400            }
401        }
402
403        Ok(())
404    }
405
406    fn is_fitted(&self) -> bool {
407        self.fitted
408    }
409}
410
411impl IncrementalScaler for StandardScaler {
412    fn partial_fit(&mut self, data: &[f32], num_features: usize) -> Result<()> {
413        if num_features == 0 {
414            return Err(TreeBoostError::Data("num_features must be > 0".into()));
415        }
416
417        if !data.len().is_multiple_of(num_features) {
418            return Err(TreeBoostError::Data(format!(
419                "Data length {} not divisible by num_features {}",
420                data.len(),
421                num_features
422            )));
423        }
424
425        let num_rows = data.len() / num_features;
426        if num_rows == 0 {
427            return Ok(()); // Nothing to do
428        }
429
430        // If forget_factor is set, use EMA-based updates
431        if let Some(alpha) = self.forget_factor {
432            return self.partial_fit_ema(data, num_features, alpha);
433        }
434
435        // Standard Welford-based incremental update (cumulative, no decay)
436
437        // Initialize Welford states if this is the first call
438        if self.welford_states.is_empty() {
439            self.welford_states = vec![WelfordState::new(); num_features];
440        } else if self.welford_states.len() != num_features {
441            return Err(TreeBoostError::Data(format!(
442                "num_features mismatch: initialized with {}, partial_fit with {}",
443                self.welford_states.len(),
444                num_features
445            )));
446        }
447
448        // Update Welford states with new data
449        for row in 0..num_rows {
450            for feat in 0..num_features {
451                let x = data[row * num_features + feat] as f64;
452                if x.is_finite() {
453                    self.welford_states[feat].update(x);
454                }
455            }
456        }
457
458        // Sync mean/std from Welford states
459        self.sync_from_welford();
460        self.fitted = true;
461
462        Ok(())
463    }
464
465    fn n_samples(&self) -> u64 {
466        self.welford_states.first().map(|s| s.n).unwrap_or(0)
467    }
468
469    fn merge(&mut self, other: &Self) -> Result<()> {
470        if self.welford_states.is_empty() {
471            // Copy from other
472            self.welford_states = other.welford_states.clone();
473            self.sync_from_welford();
474            self.fitted = other.fitted;
475            return Ok(());
476        }
477
478        if other.welford_states.is_empty() {
479            return Ok(()); // Nothing to merge
480        }
481
482        if self.welford_states.len() != other.welford_states.len() {
483            return Err(TreeBoostError::Data(format!(
484                "Cannot merge scalers with different num_features: {} vs {}",
485                self.welford_states.len(),
486                other.welford_states.len()
487            )));
488        }
489
490        // Merge Welford states using Chan's parallel algorithm
491        for (self_state, other_state) in self.welford_states.iter_mut().zip(&other.welford_states) {
492            self_state.merge(other_state);
493        }
494
495        self.sync_from_welford();
496        Ok(())
497    }
498}
499
500// =============================================================================
501// MinMaxScaler
502// =============================================================================
503
504/// MinMaxScaler: (x - min) / (max - min) * (b - a) + a
505///
506/// Scales features to a fixed range [a, b] (default [0, 1]).
507///
508/// # Use cases
509/// - When you need features in a specific range (e.g., [0, 1] for neural nets)
510/// - When you know the expected min/max bounds
511///
512/// # Warning
513/// - Sensitive to outliers (one extreme value affects entire scale)
514/// - Consider RobustScaler if outliers are present
515///
516/// # Example
517///
518/// ```rust
519/// use treeboost::preprocessing::{MinMaxScaler, Scaler};
520///
521/// let mut data = vec![1.0, 5.0, 2.0, 10.0, 3.0, 15.0]; // 3 rows × 2 features
522///
523/// let mut scaler = MinMaxScaler::new().with_range(0.0, 1.0);
524/// scaler.fit(&data, 2)?;
525/// scaler.transform(&mut data, 2)?;
526///
527/// // data is now in [0, 1] range
528/// # Ok::<(), treeboost::TreeBoostError>(())
529/// ```
530#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
531pub struct MinMaxScaler {
532    /// Minimum of each feature (learned during fit)
533    pub mins: Vec<f32>,
534    /// Maximum of each feature (learned during fit)
535    pub maxs: Vec<f32>,
536    /// Output range (a, b)
537    pub feature_range: (f32, f32),
538    /// Whether fit() has been called
539    fitted: bool,
540    /// Number of samples seen (for incremental fitting)
541    #[serde(default)]
542    n_samples: u64,
543}
544
545impl MinMaxScaler {
546    /// Create a new unfitted MinMaxScaler with default range [0, 1]
547    pub fn new() -> Self {
548        Self {
549            mins: Vec::new(),
550            maxs: Vec::new(),
551            feature_range: (0.0, 1.0),
552            fitted: false,
553            n_samples: 0,
554        }
555    }
556
557    /// Set the output range (default is [0, 1])
558    pub fn with_range(mut self, min: f32, max: f32) -> Self {
559        self.feature_range = (min, max);
560        self
561    }
562}
563
564impl Default for MinMaxScaler {
565    fn default() -> Self {
566        Self::new()
567    }
568}
569
570impl Scaler for MinMaxScaler {
571    fn fit(&mut self, data: &[f32], num_features: usize) -> Result<()> {
572        if num_features == 0 {
573            return Err(TreeBoostError::Data("num_features must be > 0".into()));
574        }
575
576        if !data.len().is_multiple_of(num_features) {
577            return Err(TreeBoostError::Data(format!(
578                "Data length {} not divisible by num_features {}",
579                data.len(),
580                num_features
581            )));
582        }
583
584        let num_rows = data.len() / num_features;
585
586        if num_rows == 0 {
587            return Err(TreeBoostError::Data("No rows to fit".into()));
588        }
589
590        self.mins = vec![f32::INFINITY; num_features];
591        self.maxs = vec![f32::NEG_INFINITY; num_features];
592
593        // Find min and max for each feature
594        for feat in 0..num_features {
595            for row in 0..num_rows {
596                let val = data[row * num_features + feat];
597                self.mins[feat] = self.mins[feat].min(val);
598                self.maxs[feat] = self.maxs[feat].max(val);
599            }
600
601            // Handle constant features (min == max)
602            if (self.maxs[feat] - self.mins[feat]).abs() < 1e-8 {
603                self.maxs[feat] = self.mins[feat] + 1.0;
604            }
605        }
606
607        self.fitted = true;
608        Ok(())
609    }
610
611    fn transform(&self, data: &mut [f32], num_features: usize) -> Result<()> {
612        if !self.fitted {
613            return Err(TreeBoostError::Data(
614                "MinMaxScaler not fitted. Call fit() first.".into(),
615            ));
616        }
617
618        if num_features != self.mins.len() {
619            return Err(TreeBoostError::Data(format!(
620                "num_features mismatch: fit with {}, transform with {}",
621                self.mins.len(),
622                num_features
623            )));
624        }
625
626        if !data.len().is_multiple_of(num_features) {
627            return Err(TreeBoostError::Data(format!(
628                "Data length {} not divisible by num_features {}",
629                data.len(),
630                num_features
631            )));
632        }
633
634        let num_rows = data.len() / num_features;
635        let (a, b) = self.feature_range;
636
637        // Apply scaling: (x - min) / (max - min) * (b - a) + a
638        for feat in 0..num_features {
639            let min = self.mins[feat];
640            let max = self.maxs[feat];
641            let scale = b - a;
642
643            for row in 0..num_rows {
644                let idx = row * num_features + feat;
645                data[idx] = (data[idx] - min) / (max - min) * scale + a;
646
647                // Clip to range (handles out-of-bound values in test set)
648                data[idx] = data[idx].clamp(a, b);
649            }
650        }
651
652        Ok(())
653    }
654
655    fn is_fitted(&self) -> bool {
656        self.fitted
657    }
658}
659
660impl IncrementalScaler for MinMaxScaler {
661    fn partial_fit(&mut self, data: &[f32], num_features: usize) -> Result<()> {
662        if num_features == 0 {
663            return Err(TreeBoostError::Data("num_features must be > 0".into()));
664        }
665
666        if !data.len().is_multiple_of(num_features) {
667            return Err(TreeBoostError::Data(format!(
668                "Data length {} not divisible by num_features {}",
669                data.len(),
670                num_features
671            )));
672        }
673
674        let num_rows = data.len() / num_features;
675        if num_rows == 0 {
676            return Ok(()); // Nothing to do
677        }
678
679        // Initialize mins/maxs if this is the first call
680        if self.mins.is_empty() {
681            self.mins = vec![f32::INFINITY; num_features];
682            self.maxs = vec![f32::NEG_INFINITY; num_features];
683        } else if self.mins.len() != num_features {
684            return Err(TreeBoostError::Data(format!(
685                "num_features mismatch: initialized with {}, partial_fit with {}",
686                self.mins.len(),
687                num_features
688            )));
689        }
690
691        // Update min/max with new data (monotonic expansion)
692        for row in 0..num_rows {
693            for feat in 0..num_features {
694                let val = data[row * num_features + feat];
695                if val.is_finite() {
696                    self.mins[feat] = self.mins[feat].min(val);
697                    self.maxs[feat] = self.maxs[feat].max(val);
698                }
699            }
700        }
701
702        // Handle constant features
703        for feat in 0..num_features {
704            if (self.maxs[feat] - self.mins[feat]).abs() < 1e-8 {
705                self.maxs[feat] = self.mins[feat] + 1.0;
706            }
707        }
708
709        self.n_samples += num_rows as u64;
710        self.fitted = true;
711
712        Ok(())
713    }
714
715    fn n_samples(&self) -> u64 {
716        self.n_samples
717    }
718
719    fn merge(&mut self, other: &Self) -> Result<()> {
720        if self.mins.is_empty() {
721            // Copy from other
722            self.mins = other.mins.clone();
723            self.maxs = other.maxs.clone();
724            self.n_samples = other.n_samples;
725            self.fitted = other.fitted;
726            return Ok(());
727        }
728
729        if other.mins.is_empty() {
730            return Ok(()); // Nothing to merge
731        }
732
733        if self.mins.len() != other.mins.len() {
734            return Err(TreeBoostError::Data(format!(
735                "Cannot merge scalers with different num_features: {} vs {}",
736                self.mins.len(),
737                other.mins.len()
738            )));
739        }
740
741        // Merge min/max (take min of mins, max of maxs)
742        for i in 0..self.mins.len() {
743            self.mins[i] = self.mins[i].min(other.mins[i]);
744            self.maxs[i] = self.maxs[i].max(other.maxs[i]);
745        }
746
747        self.n_samples += other.n_samples;
748        Ok(())
749    }
750}
751
752// =============================================================================
753// RobustScaler
754// =============================================================================
755
756/// RobustScaler: (x - median) / IQR
757///
758/// Scales features using statistics robust to outliers:
759/// - Center: median (instead of mean)
760/// - Scale: IQR = Q3 - Q1 (instead of std)
761///
762/// # Use cases
763/// - Data with outliers or heavy-tailed distributions
764/// - When mean/std are unreliable due to extreme values
765///
766/// # Example
767///
768/// ```rust
769/// use treeboost::preprocessing::{RobustScaler, Scaler};
770///
771/// let mut data = vec![1.0, 2.0, 3.0, 100.0, 5.0, 6.0]; // 3 rows × 2 features (outlier: 100)
772///
773/// let mut scaler = RobustScaler::new();
774/// scaler.fit(&data, 2)?;
775/// scaler.transform(&mut data, 2)?;
776///
777/// // Outlier (100) has less impact than with StandardScaler
778/// # Ok::<(), treeboost::TreeBoostError>(())
779/// ```
780#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
781pub struct RobustScaler {
782    /// Median of each feature (learned during fit)
783    pub medians: Vec<f32>,
784    /// IQR (Q3 - Q1) of each feature (learned during fit)
785    pub iqrs: Vec<f32>,
786    /// Whether fit() has been called
787    fitted: bool,
788}
789
790impl RobustScaler {
791    /// Create a new unfitted RobustScaler
792    pub fn new() -> Self {
793        Self {
794            medians: Vec::new(),
795            iqrs: Vec::new(),
796            fitted: false,
797        }
798    }
799}
800
801impl Default for RobustScaler {
802    fn default() -> Self {
803        Self::new()
804    }
805}
806
807impl Scaler for RobustScaler {
808    fn fit(&mut self, data: &[f32], num_features: usize) -> Result<()> {
809        if num_features == 0 {
810            return Err(TreeBoostError::Data("num_features must be > 0".into()));
811        }
812
813        if !data.len().is_multiple_of(num_features) {
814            return Err(TreeBoostError::Data(format!(
815                "Data length {} not divisible by num_features {}",
816                data.len(),
817                num_features
818            )));
819        }
820
821        let num_rows = data.len() / num_features;
822
823        if num_rows == 0 {
824            return Err(TreeBoostError::Data("No rows to fit".into()));
825        }
826
827        self.medians = vec![0.0; num_features];
828        self.iqrs = vec![0.0; num_features];
829
830        // Use T-Digest for O(n) quantile estimation instead of O(n log n) sorting
831        // This is critical for large datasets (100M+ rows)
832        use tdigest::TDigest;
833
834        // Compute median and IQR for each feature using approximate quantiles
835        for feat in 0..num_features {
836            // Build T-Digest for this feature column
837            let mut digest = TDigest::new_with_size(100); // 100 centroids is accurate enough
838
839            for row in 0..num_rows {
840                let value = data[row * num_features + feat] as f64;
841                if value.is_finite() {
842                    digest = digest.merge_unsorted(vec![value]);
843                }
844            }
845
846            // Get approximate quantiles from T-Digest
847            let q1 = digest.estimate_quantile(0.25) as f32;
848            let median = digest.estimate_quantile(0.50) as f32;
849            let q3 = digest.estimate_quantile(0.75) as f32;
850
851            self.medians[feat] = median;
852
853            // IQR = Q3 - Q1
854            let iqr = q3 - q1;
855
856            // Handle zero IQR (all values in Q1-Q3 range are same)
857            self.iqrs[feat] = if iqr < 1e-8 { 1.0 } else { iqr };
858        }
859
860        self.fitted = true;
861        Ok(())
862    }
863
864    fn transform(&self, data: &mut [f32], num_features: usize) -> Result<()> {
865        if !self.fitted {
866            return Err(TreeBoostError::Data(
867                "RobustScaler not fitted. Call fit() first.".into(),
868            ));
869        }
870
871        if num_features != self.medians.len() {
872            return Err(TreeBoostError::Data(format!(
873                "num_features mismatch: fit with {}, transform with {}",
874                self.medians.len(),
875                num_features
876            )));
877        }
878
879        if !data.len().is_multiple_of(num_features) {
880            return Err(TreeBoostError::Data(format!(
881                "Data length {} not divisible by num_features {}",
882                data.len(),
883                num_features
884            )));
885        }
886
887        let num_rows = data.len() / num_features;
888
889        // Apply robust scaling: (x - median) / IQR
890        for feat in 0..num_features {
891            let median = self.medians[feat];
892            let iqr = self.iqrs[feat];
893            for row in 0..num_rows {
894                let idx = row * num_features + feat;
895                data[idx] = (data[idx] - median) / iqr;
896            }
897        }
898
899        Ok(())
900    }
901
902    fn is_fitted(&self) -> bool {
903        self.fitted
904    }
905}
906
907// =============================================================================
908// Tests
909// =============================================================================
910
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_standard_scaler_basic() {
        let mut data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        // 2 rows × 3 features
        // Row 0: [1.0, 2.0, 3.0]
        // Row 1: [4.0, 5.0, 6.0]
        let num_features = 3;

        let mut scaler = StandardScaler::new();
        assert!(!scaler.is_fitted());

        scaler.fit(&data, num_features).unwrap();
        assert!(scaler.is_fitted());

        // Check means (column averages)
        // Feature 0: (1.0 + 4.0) / 2 = 2.5
        // Feature 1: (2.0 + 5.0) / 2 = 3.5
        // Feature 2: (3.0 + 6.0) / 2 = 4.5
        assert_eq!(scaler.means(), &[2.5, 3.5, 4.5]);

        scaler.transform(&mut data, num_features).unwrap();

        // After standardization every column is centered: the two rows sit
        // symmetrically around the mean, so each column sums to ~0.
        let (row0, row1) = data.split_at(num_features);
        for (feat, (a, b)) in row0.iter().zip(row1).enumerate() {
            assert!(
                (a + b).abs() < 1e-6,
                "Feature {} not centered: {} + {}",
                feat,
                a,
                b
            );
        }

        // All three columns have the same spread (|x - mean| = 1.5), so they
        // must be scaled identically; row 0 lies below the mean everywhere.
        assert!((data[0] - data[1]).abs() < 1e-6);
        assert!((data[1] - data[2]).abs() < 1e-6);
        assert!(data[0] < 0.0);
    }

    #[test]
    fn test_standard_scaler_zero_variance() {
        let mut data = vec![5.0, 1.0, 2.0, 5.0, 3.0, 4.0];
        // 2 rows × 3 features
        // Row 0: [5.0, 1.0, 2.0]
        // Row 1: [5.0, 3.0, 4.0]
        // Feature 0 is constant: [5.0, 5.0]
        let num_features = 3;

        let mut scaler = StandardScaler::new();
        scaler.fit(&data, num_features).unwrap();

        // Zero-variance feature should have std = 1.0 (fallback)
        assert_eq!(scaler.stds[0], 1.0);
        assert_eq!(scaler.means[0], 5.0);

        // Transform should not panic
        scaler.transform(&mut data, num_features).unwrap();
    }

    #[test]
    fn test_minmax_scaler_basic() {
        let mut data = vec![1.0, 10.0, 2.0, 20.0, 3.0, 30.0]; // 3 rows × 2 features
        let num_features = 2;

        let mut scaler = MinMaxScaler::new();
        scaler.fit(&data, num_features).unwrap();

        assert_eq!(scaler.mins, vec![1.0, 10.0]);
        assert_eq!(scaler.maxs, vec![3.0, 30.0]);

        scaler.transform(&mut data, num_features).unwrap();

        // First feature: [1, 2, 3] → [0.0, 0.5, 1.0]
        assert!((data[0] - 0.0).abs() < 1e-6);
        assert!((data[2] - 0.5).abs() < 1e-6);
        assert!((data[4] - 1.0).abs() < 1e-6);

        // Second feature: [10, 20, 30] → [0.0, 0.5, 1.0]
        assert!((data[1] - 0.0).abs() < 1e-6);
        assert!((data[3] - 0.5).abs() < 1e-6);
        assert!((data[5] - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_minmax_scaler_custom_range() {
        let mut data = vec![1.0, 2.0, 3.0]; // 3 rows × 1 feature
        let num_features = 1;

        let mut scaler = MinMaxScaler::new().with_range(-1.0, 1.0);
        scaler.fit(&data, num_features).unwrap();
        scaler.transform(&mut data, num_features).unwrap();

        // [1, 2, 3] → [-1.0, 0.0, 1.0]
        assert!((data[0] - (-1.0)).abs() < 1e-6);
        assert!((data[1] - 0.0).abs() < 1e-6);
        assert!((data[2] - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_robust_scaler_basic() {
        let mut data = vec![1.0, 2.0, 3.0, 100.0]; // 2 rows × 2 features (outlier: 100)
        // Row 0: [1.0, 2.0]
        // Row 1: [3.0, 100.0]
        // Feature 0 = [1, 3], feature 1 = [2, 100]
        let num_features = 2;

        let mut scaler = RobustScaler::new();
        scaler.fit(&data, num_features).unwrap();

        // Medians: [2.0, 51.0] (midpoint of each column's two values)
        assert!((scaler.medians[0] - 2.0).abs() < 1e-6);

        scaler.transform(&mut data, num_features).unwrap();

        // Feature 0 is centered on its median (2.0): its two values map to a
        // pair symmetric around zero, whatever the IQR is.
        assert!((data[0] + data[2]).abs() < 1e-6);
        assert!(data[0] < 0.0 && data[2] > 0.0);

        // Scaling preserves order and stays finite, even in the outlier column.
        assert!(data[1] < data[3]);
        assert!(data.iter().all(|v| v.is_finite()));
    }

    #[test]
    fn test_scaler_not_fitted_error() {
        let mut data = vec![1.0, 2.0, 3.0];
        let scaler = StandardScaler::new();

        let result = scaler.transform(&mut data, 1);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("not fitted"));
    }

    #[test]
    fn test_scaler_feature_mismatch_error() {
        let data = vec![1.0, 2.0, 3.0, 4.0];
        let mut scaler = StandardScaler::new();

        scaler.fit(&data, 2).unwrap(); // Fit with 2 features

        let mut test_data = vec![5.0, 6.0, 7.0];
        let result = scaler.transform(&mut test_data, 3); // Try to transform with 3 features

        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("mismatch"));
    }

    // =========================================================================
    // Incremental Scaler Tests
    // =========================================================================

    #[test]
    fn test_standard_scaler_incremental_equivalence() {
        // Test that partial_fit on chunks equals fit on all data
        let all_data: Vec<f32> = (0..1000).map(|i| i as f32).collect();
        let num_features = 1;

        // Scaler A: fit() on all data
        let mut scaler_a = StandardScaler::new();
        scaler_a.fit(&all_data, num_features).unwrap();

        // Scaler B: partial_fit() on 10 chunks of 100
        let mut scaler_b = StandardScaler::new();
        for chunk in all_data.chunks(100) {
            scaler_b.partial_fit(chunk, num_features).unwrap();
        }

        // Verify equivalence (within an absolute tolerance of 1e-3)
        assert!(
            (scaler_a.means[0] - scaler_b.means[0]).abs() < 1e-3,
            "Means differ: {} vs {}",
            scaler_a.means[0],
            scaler_b.means[0]
        );
        assert!(
            (scaler_a.stds[0] - scaler_b.stds[0]).abs() < 1e-3,
            "Stds differ: {} vs {}",
            scaler_a.stds[0],
            scaler_b.stds[0]
        );

        // Verify n_samples tracked correctly
        assert_eq!(scaler_b.n_samples(), 1000);
    }

    #[test]
    fn test_standard_scaler_welford_stability() {
        // Test numerical stability with large offset
        let offset = 1e8_f32;
        let data: Vec<f32> = (0..100).map(|i| offset + i as f32).collect();
        let num_features = 1;

        let mut scaler = StandardScaler::new();
        scaler.partial_fit(&data, num_features).unwrap();

        // Mean should be offset + 49.5
        let expected_mean = offset + 49.5;
        assert!(
            (scaler.means[0] - expected_mean).abs() < 1.0,
            "Mean with large offset: got {}, expected {}",
            scaler.means[0],
            expected_mean
        );
    }

    #[test]
    fn test_standard_scaler_merge() {
        let num_features = 2;

        // Scaler A: rows [1, 10] and [2, 20]
        let mut scaler_a = StandardScaler::new();
        scaler_a
            .partial_fit(&[1.0, 10.0, 2.0, 20.0], num_features)
            .unwrap();

        // Scaler B: rows [3, 30] and [4, 40]
        let mut scaler_b = StandardScaler::new();
        scaler_b
            .partial_fit(&[3.0, 30.0, 4.0, 40.0], num_features)
            .unwrap();

        // Merge B into A
        scaler_a.merge(&scaler_b).unwrap();

        // Should be equivalent to fitting on all 4 rows
        assert_eq!(scaler_a.n_samples(), 4);

        // Feature 0: [1, 2, 3, 4] → mean = 2.5
        assert!((scaler_a.means[0] - 2.5).abs() < 1e-5);

        // Feature 1: [10, 20, 30, 40] → mean = 25.0
        assert!((scaler_a.means[1] - 25.0).abs() < 1e-4);
    }

    #[test]
    fn test_minmax_scaler_incremental() {
        let num_features = 2;

        let mut scaler = MinMaxScaler::new();

        // Batch 1: data in [0, 50] for feat 0, [0, 100] for feat 1
        scaler
            .partial_fit(&[0.0, 0.0, 50.0, 100.0], num_features)
            .unwrap();
        assert_eq!(scaler.mins, vec![0.0, 0.0]);
        assert_eq!(scaler.maxs, vec![50.0, 100.0]);

        // Batch 2: data in [25, 100] for feat 0, [50, 200] for feat 1
        scaler
            .partial_fit(&[25.0, 50.0, 100.0, 200.0], num_features)
            .unwrap();

        // Min/max should expand monotonically
        assert_eq!(scaler.mins, vec![0.0, 0.0]); // min stays at 0
        assert_eq!(scaler.maxs, vec![100.0, 200.0]); // max expands

        assert_eq!(scaler.n_samples(), 4);
    }

    #[test]
    fn test_minmax_scaler_merge() {
        let num_features = 1;

        let mut scaler_a = MinMaxScaler::new();
        scaler_a.partial_fit(&[10.0, 20.0], num_features).unwrap();

        let mut scaler_b = MinMaxScaler::new();
        scaler_b.partial_fit(&[5.0, 30.0], num_features).unwrap();

        scaler_a.merge(&scaler_b).unwrap();

        // Should have min=5, max=30 (union of ranges)
        assert_eq!(scaler_a.mins, vec![5.0]);
        assert_eq!(scaler_a.maxs, vec![30.0]);
        assert_eq!(scaler_a.n_samples(), 4);
    }

    // =========================================================================
    // Rolling Window / EMA Tests
    // =========================================================================

    #[test]
    fn test_standard_scaler_forget_factor_creation() {
        let scaler = StandardScaler::with_forget_factor(0.1);
        assert_eq!(scaler.forget_factor(), Some(0.1));

        let mut scaler2 = StandardScaler::new();
        assert_eq!(scaler2.forget_factor(), None);

        scaler2.set_forget_factor(Some(0.5));
        assert_eq!(scaler2.forget_factor(), Some(0.5));

        scaler2.set_forget_factor(None);
        assert_eq!(scaler2.forget_factor(), None);
    }

    #[test]
    fn test_standard_scaler_forget_factor_clamping() {
        let scaler = StandardScaler::with_forget_factor(-0.5);
        assert_eq!(scaler.forget_factor(), Some(0.0));

        let scaler2 = StandardScaler::with_forget_factor(1.5);
        assert_eq!(scaler2.forget_factor(), Some(1.0));
    }

    #[test]
    fn test_standard_scaler_ema_single_batch() {
        // First batch should set stats directly
        let num_features = 1;
        let data = vec![10.0, 20.0, 30.0, 40.0];

        let mut scaler = StandardScaler::with_forget_factor(0.1);
        scaler.partial_fit(&data, num_features).unwrap();

        assert!(scaler.is_fitted());
        // Mean should be 25.0
        assert!((scaler.means()[0] - 25.0).abs() < 0.01);
    }

    #[test]
    fn test_standard_scaler_ema_decay() {
        let num_features = 1;

        // Batch 1: mean=10
        let batch1 = vec![8.0, 10.0, 12.0];
        // Batch 2: mean=100 (shifted distribution)
        let batch2 = vec![98.0, 100.0, 102.0];

        let mut scaler = StandardScaler::with_forget_factor(0.3);
        scaler.partial_fit(&batch1, num_features).unwrap();

        let mean_after_batch1 = scaler.means()[0];
        assert!((mean_after_batch1 - 10.0).abs() < 0.01);

        // After batch 2 with alpha=0.3:
        // new_mean = 0.7 * 10 + 0.3 * 100 = 7 + 30 = 37
        scaler.partial_fit(&batch2, num_features).unwrap();

        let mean_after_batch2 = scaler.means()[0];
        assert!(
            (mean_after_batch2 - 37.0).abs() < 0.5,
            "Expected ~37, got {}",
            mean_after_batch2
        );
    }

    #[test]
    fn test_standard_scaler_ema_vs_cumulative() {
        let num_features = 1;

        // Batch 1: mean=10
        let batch1 = vec![8.0, 10.0, 12.0];
        // Batch 2: mean=100 (shifted distribution)
        let batch2 = vec![98.0, 100.0, 102.0];

        // Cumulative (no forget factor)
        let mut cumulative = StandardScaler::new();
        cumulative.partial_fit(&batch1, num_features).unwrap();
        cumulative.partial_fit(&batch2, num_features).unwrap();

        // EMA with high forget factor
        let mut ema = StandardScaler::with_forget_factor(0.5);
        ema.partial_fit(&batch1, num_features).unwrap();
        ema.partial_fit(&batch2, num_features).unwrap();

        // Cumulative mean: (10 + 100) / 2 = 55 (equal weight to all samples)
        let cumulative_mean = cumulative.means()[0];

        // EMA mean: 0.5 * 10 + 0.5 * 100 = 55 (with alpha=0.5)
        let ema_mean = ema.means()[0];

        // With alpha=0.5, both should be similar but EMA weights batch means, not sample means
        assert!((cumulative_mean - 55.0).abs() < 1.0);
        assert!((ema_mean - 55.0).abs() < 1.0);
    }

    #[test]
    fn test_standard_scaler_ema_adapts_to_drift() {
        let num_features = 1;

        // Start with mean=10
        let batch1 = vec![8.0, 10.0, 12.0];

        // Series of batches with drifting mean
        let batch2 = vec![28.0, 30.0, 32.0]; // mean=30
        let batch3 = vec![48.0, 50.0, 52.0]; // mean=50
        let batch4 = vec![68.0, 70.0, 72.0]; // mean=70
        let batch5 = vec![88.0, 90.0, 92.0]; // mean=90

        let mut scaler = StandardScaler::with_forget_factor(0.5); // 50% weight to new batch

        scaler.partial_fit(&batch1, num_features).unwrap();
        assert!((scaler.means()[0] - 10.0).abs() < 1.0);

        scaler.partial_fit(&batch2, num_features).unwrap();
        // 0.5 * 10 + 0.5 * 30 = 20
        assert!(
            (scaler.means()[0] - 20.0).abs() < 1.0,
            "Expected ~20, got {}",
            scaler.means()[0]
        );

        scaler.partial_fit(&batch3, num_features).unwrap();
        // 0.5 * 20 + 0.5 * 50 = 35
        assert!(
            (scaler.means()[0] - 35.0).abs() < 1.0,
            "Expected ~35, got {}",
            scaler.means()[0]
        );

        scaler.partial_fit(&batch4, num_features).unwrap();
        // 0.5 * 35 + 0.5 * 70 = 52.5
        assert!(
            (scaler.means()[0] - 52.5).abs() < 1.0,
            "Expected ~52.5, got {}",
            scaler.means()[0]
        );

        scaler.partial_fit(&batch5, num_features).unwrap();
        // 0.5 * 52.5 + 0.5 * 90 = 71.25
        assert!(
            (scaler.means()[0] - 71.25).abs() < 1.5,
            "Expected ~71.25, got {}",
            scaler.means()[0]
        );
    }

    #[test]
    fn test_standard_scaler_ema_variance_decay() {
        let num_features = 1;

        // Low variance batch
        let batch1 = vec![9.9, 10.0, 10.1]; // std ≈ 0.08

        // High variance batch
        let batch2 = vec![0.0, 10.0, 20.0]; // std ≈ 8.16

        let mut scaler = StandardScaler::with_forget_factor(0.3);

        scaler.partial_fit(&batch1, num_features).unwrap();
        let std_after_batch1 = scaler.stds()[0];
        assert!(
            std_after_batch1 < 1.0,
            "Std should be small after low-variance batch"
        );

        scaler.partial_fit(&batch2, num_features).unwrap();
        let std_after_batch2 = scaler.stds()[0];

        // Std should increase (EMA blend of low and high variance)
        assert!(
            std_after_batch2 > std_after_batch1,
            "Std should increase after high-variance batch"
        );
    }
}