// scry_learn/tree/binning.rs
1// SPDX-License-Identifier: MIT OR Apache-2.0
2//! Feature binning for histogram-based gradient boosting.
3//!
4//! [`FeatureBinner`] quantile-bins continuous `f64` features into `u8`
5//! indices (0..=255). Bin 0 is reserved for missing values (`NaN`);
6//! valid data maps to bins 1..=255.
7//!
8//! # Example
9//! ```
10//! use scry_learn::dataset::Dataset;
11//! use scry_learn::tree::FeatureBinner;
12//!
13//! let features = vec![
14//!     vec![1.0, 2.0, f64::NAN, 4.0, 5.0],
15//!     vec![10.0, 20.0, 30.0, 40.0, 50.0],
16//! ];
17//! let target = vec![0.0; 5];
18//! let data = Dataset::new(features, target, vec!["a".into(), "b".into()], "y");
19//!
20//! let mut binner = FeatureBinner::new();
21//! binner.fit(&data).unwrap();
22//! let binned = binner.transform(&data).unwrap();
23//!
24//! // NaN → bin 0, valid values → bins 1..=255
25//! assert_eq!(binned[0][2], 0);
26//! assert!(binned[1][0] >= 1);
27//! ```
28
29use crate::dataset::Dataset;
30use crate::error::{Result, ScryLearnError};
31
/// Maximum number of bins (including the missing-value bin 0) — the full
/// range of a `u8` bin index.
pub const MAX_BINS: usize = 256;
34
/// Quantile-based feature binner.
///
/// Transforms each feature column into `u8` bin indices using quantile
/// boundaries computed during `fit()`. Missing values (`NaN`) are
/// mapped to bin 0; valid values to bins 1–255.
///
/// The binning is designed for histogram-based gradient boosting where
/// the O(256) histogram scan replaces the O(n log n) sorted-split search.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub struct FeatureBinner {
    /// Bin edges per feature: `bin_edges[feature][edge_idx]`, sorted ascending.
    /// For K valid bins there are K−1 edges (upper-exclusive boundaries:
    /// a value exactly equal to an edge falls into the bin above it).
    bin_edges: Vec<Vec<f64>>,
    /// Number of actual valid bins per feature (may be < 255 for
    /// low-cardinality features); 1 for constant or all-NaN columns.
    n_bins_per_feature: Vec<usize>,
    /// Upper limit on total bins, clamped to 2..=256 by `max_bins()`.
    max_bins: usize,
    /// Set by a successful `fit()`; `transform()` fails with `NotFitted`
    /// until then.
    fitted: bool,
    /// Serialization schema tag; `serde(default)` lets older payloads
    /// deserialize with a zero version.
    #[cfg_attr(feature = "serde", serde(default))]
    _schema_version: u32,
}
58
59impl FeatureBinner {
60    /// Create a new binner with the default 256 max bins.
61    ///
62    /// # Example
63    /// ```
64    /// use scry_learn::tree::FeatureBinner;
65    /// let binner = FeatureBinner::new();
66    /// ```
67    pub fn new() -> Self {
68        Self {
69            bin_edges: Vec::new(),
70            n_bins_per_feature: Vec::new(),
71            max_bins: MAX_BINS,
72            fitted: false,
73            _schema_version: crate::version::SCHEMA_VERSION,
74        }
75    }
76
77    /// Set the maximum number of bins (2..=256, default 256).
78    pub fn max_bins(mut self, bins: usize) -> Self {
79        self.max_bins = bins.clamp(2, MAX_BINS);
80        self
81    }
82
83    /// Compute bin edges from training data.
84    ///
85    /// For each feature column, sorts the non-NaN values and picks
86    /// equally-spaced quantile boundaries to create up to `max_bins - 1`
87    /// valid bins (bin 0 is reserved for missing).
88    pub fn fit(&mut self, data: &Dataset) -> Result<()> {
89        data.validate_no_inf()?;
90        if data.n_samples() == 0 {
91            return Err(ScryLearnError::EmptyDataset);
92        }
93
94        let n_features = data.n_features();
95        let valid_bins = self.max_bins - 1; // reserve bin 0 for NaN
96
97        self.bin_edges = Vec::with_capacity(n_features);
98        self.n_bins_per_feature = Vec::with_capacity(n_features);
99
100        for f in 0..n_features {
101            let col = &data.features[f];
102
103            // Collect and sort non-NaN values.
104            let mut valid: Vec<f64> = col.iter().copied().filter(|v| !v.is_nan()).collect();
105            valid.sort_unstable_by(|a, b| a.total_cmp(b));
106
107            if valid.is_empty() {
108                // All NaN — single bin (the missing bin).
109                self.bin_edges.push(Vec::new());
110                self.n_bins_per_feature.push(1);
111                continue;
112            }
113
114            // Deduplicate to find unique values.
115            valid.dedup();
116
117            let n_unique = valid.len();
118            let actual_bins = n_unique.min(valid_bins);
119
120            if actual_bins <= 1 {
121                // Constant feature — one valid bin.
122                self.bin_edges.push(Vec::new());
123                self.n_bins_per_feature.push(1);
124                continue;
125            }
126
127            // Compute quantile-based bin edges: `actual_bins - 1` thresholds.
128            let mut edges = Vec::with_capacity(actual_bins - 1);
129            for i in 1..actual_bins {
130                let q = i as f64 / actual_bins as f64;
131                let pos = q * (valid.len() - 1) as f64;
132                let lo = pos.floor() as usize;
133                let hi = (lo + 1).min(valid.len() - 1);
134                let frac = pos - lo as f64;
135                let edge = valid[lo] * (1.0 - frac) + valid[hi] * frac;
136                edges.push(edge);
137            }
138
139            // Remove duplicate edges (low-cardinality features).
140            edges.dedup_by(|a, b| (*a - *b).abs() < f64::EPSILON);
141
142            let n_valid_bins = edges.len() + 1;
143            self.n_bins_per_feature.push(n_valid_bins);
144            self.bin_edges.push(edges);
145        }
146
147        self.fitted = true;
148        Ok(())
149    }
150
151    /// Map features to `u8` bin indices.
152    ///
153    /// Returns `binned[feature_idx][sample_idx]`. NaN → 0, valid → 1..=255.
154    pub fn transform(&self, data: &Dataset) -> Result<Vec<Vec<u8>>> {
155        if !self.fitted {
156            return Err(ScryLearnError::NotFitted);
157        }
158        let n_features = data.n_features();
159        if n_features != self.bin_edges.len() {
160            return Err(ScryLearnError::ShapeMismatch {
161                expected: self.bin_edges.len(),
162                got: n_features,
163            });
164        }
165
166        let n_samples = data.n_samples();
167        let mut result = Vec::with_capacity(n_features);
168
169        for f in 0..n_features {
170            let col = &data.features[f];
171            let edges = &self.bin_edges[f];
172            let mut binned = vec![0u8; n_samples];
173
174            for (i, &val) in col.iter().enumerate() {
175                if val.is_nan() {
176                    binned[i] = 0; // missing-value bin
177                } else {
178                    // Binary search for the correct bin.
179                    let bin = match edges.binary_search_by(|edge| {
180                        edge.partial_cmp(&val).unwrap_or(std::cmp::Ordering::Equal)
181                    }) {
182                        Ok(pos) => pos + 1, // on edge → next bin
183                        Err(pos) => pos,
184                    };
185                    // Shift by 1 because bin 0 is reserved for NaN.
186                    binned[i] = (bin + 1).min(255) as u8;
187                }
188            }
189
190            result.push(binned);
191        }
192
193        Ok(result)
194    }
195
196    /// Combined fit + transform.
197    pub fn fit_transform(&mut self, data: &Dataset) -> Result<Vec<Vec<u8>>> {
198        self.fit(data)?;
199        self.transform(data)
200    }
201
202    /// Number of bins per feature (including the missing-value bin).
203    pub fn n_bins_per_feature(&self) -> &[usize] {
204        &self.n_bins_per_feature
205    }
206
207    /// Bin edges per feature.
208    pub fn bin_edges(&self) -> &[Vec<f64>] {
209        &self.bin_edges
210    }
211}
212
impl Default for FeatureBinner {
    /// Equivalent to [`FeatureBinner::new`]: 256 max bins, not yet fitted.
    fn default() -> Self {
        Self::new()
    }
}
218
#[cfg(test)]
mod tests {
    use super::*;

    /// Two monotone feature columns over ten samples; the second column is
    /// the first scaled by 100.
    fn simple_dataset() -> Dataset {
        let col_a: Vec<f64> = (1..=10).map(|v| v as f64).collect();
        let col_b: Vec<f64> = col_a.iter().map(|v| v * 100.0).collect();
        Dataset::new(
            vec![col_a, col_b],
            vec![0.0; 10],
            vec!["a".into(), "b".into()],
            "y",
        )
    }

    #[test]
    fn test_fit_transform_basic() {
        let data = simple_dataset();
        let mut binner = FeatureBinner::new();
        let binned = binner.fit_transform(&data).unwrap();

        // Shape: one Vec<u8> per feature, one entry per sample.
        assert_eq!(binned.len(), 2);
        assert_eq!(binned[0].len(), 10);

        // No NaN in the data, so nothing may land in the missing bin 0.
        assert!(
            binned[0].iter().all(|&b| b >= 1),
            "valid values should map to bins >= 1"
        );

        // Sorted input must produce non-decreasing bin indices.
        assert!(binned[0].windows(2).all(|w| w[0] <= w[1]));
    }

    #[test]
    fn test_nan_handling() {
        let data = Dataset::new(
            vec![vec![1.0, f64::NAN, 3.0, f64::NAN, 5.0]],
            vec![0.0; 5],
            vec!["x".into()],
            "y",
        );
        let mut binner = FeatureBinner::new();
        let binned = binner.fit_transform(&data).unwrap();

        for missing_idx in [1usize, 3] {
            assert_eq!(binned[0][missing_idx], 0, "NaN should map to bin 0");
        }
        assert!(binned[0][0] >= 1, "valid value should be >= 1");
    }

    #[test]
    fn test_constant_feature() {
        let data = Dataset::new(
            vec![vec![5.0; 4]],
            vec![0.0; 4],
            vec!["x".into()],
            "y",
        );
        let mut binner = FeatureBinner::new();
        let binned = binner.fit_transform(&data).unwrap();

        // A constant column has a single valid bin, so every sample shares it.
        let first = binned[0][0];
        assert!(binned[0].iter().all(|&b| b == first));
    }

    #[test]
    fn test_max_bins_param() {
        let mut binner = FeatureBinner::new().max_bins(4);
        let binned = binner.fit_transform(&simple_dataset()).unwrap();

        // max_bins = 4 leaves bins 1..=3 for valid data (bin 0 is missing).
        for &b in &binned[0] {
            assert!(b <= 3, "with max_bins=4, bin index should be <= 3, got {b}");
        }
    }

    #[test]
    fn test_not_fitted_error() {
        let binner = FeatureBinner::new();
        assert!(binner.transform(&simple_dataset()).is_err());
    }

    #[test]
    fn test_all_nan_feature() {
        let data = Dataset::new(
            vec![vec![f64::NAN; 3]],
            vec![0.0; 3],
            vec!["x".into()],
            "y",
        );
        let mut binner = FeatureBinner::new();
        let binned = binner.fit_transform(&data).unwrap();
        assert!(
            binned[0].iter().all(|&b| b == 0),
            "all-NaN feature should map entirely to bin 0"
        );
    }
}