1use crate::dataset::Dataset;
30use crate::error::{Result, ScryLearnError};
31
/// Upper bound on the number of bins per feature, including the NaN bin.
///
/// Bin indices are emitted as `u8`, so this is the number of distinct `u8`
/// values (256).
pub const MAX_BINS: usize = (u8::MAX as usize) + 1;
34
/// Discretizes continuous features into small integer bin indices.
///
/// Bin edges are learned per feature by `fit` and applied by `transform`,
/// which emits one `u8` bin index per sample. Index 0 is reserved for NaN
/// values; non-NaN values receive indices starting at 1.
#[derive(Debug, Clone)]
#[non_exhaustive]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct FeatureBinner {
    /// Learned bin edges, one ascending list per feature. An empty list
    /// means every non-NaN value of that feature falls into a single bin.
    bin_edges: Vec<Vec<f64>>,
    /// Number of non-NaN bins per feature (`edges.len() + 1`).
    n_bins_per_feature: Vec<usize>,
    /// Maximum number of bins per feature, counting the NaN bin.
    max_bins: usize,
    /// Set once `fit` has completed successfully.
    fitted: bool,
    /// Taken from `crate::version::SCHEMA_VERSION` at construction; falls
    /// back to the type default when absent from serialized input.
    // NOTE(review): presumably used to version the serialized format — confirm.
    #[cfg_attr(feature = "serde", serde(default))]
    _schema_version: u32,
}
58
59impl FeatureBinner {
60 pub fn new() -> Self {
68 Self {
69 bin_edges: Vec::new(),
70 n_bins_per_feature: Vec::new(),
71 max_bins: MAX_BINS,
72 fitted: false,
73 _schema_version: crate::version::SCHEMA_VERSION,
74 }
75 }
76
77 pub fn max_bins(mut self, bins: usize) -> Self {
79 self.max_bins = bins.clamp(2, MAX_BINS);
80 self
81 }
82
83 pub fn fit(&mut self, data: &Dataset) -> Result<()> {
89 data.validate_no_inf()?;
90 if data.n_samples() == 0 {
91 return Err(ScryLearnError::EmptyDataset);
92 }
93
94 let n_features = data.n_features();
95 let valid_bins = self.max_bins - 1; self.bin_edges = Vec::with_capacity(n_features);
98 self.n_bins_per_feature = Vec::with_capacity(n_features);
99
100 for f in 0..n_features {
101 let col = &data.features[f];
102
103 let mut valid: Vec<f64> = col.iter().copied().filter(|v| !v.is_nan()).collect();
105 valid.sort_unstable_by(|a, b| a.total_cmp(b));
106
107 if valid.is_empty() {
108 self.bin_edges.push(Vec::new());
110 self.n_bins_per_feature.push(1);
111 continue;
112 }
113
114 valid.dedup();
116
117 let n_unique = valid.len();
118 let actual_bins = n_unique.min(valid_bins);
119
120 if actual_bins <= 1 {
121 self.bin_edges.push(Vec::new());
123 self.n_bins_per_feature.push(1);
124 continue;
125 }
126
127 let mut edges = Vec::with_capacity(actual_bins - 1);
129 for i in 1..actual_bins {
130 let q = i as f64 / actual_bins as f64;
131 let pos = q * (valid.len() - 1) as f64;
132 let lo = pos.floor() as usize;
133 let hi = (lo + 1).min(valid.len() - 1);
134 let frac = pos - lo as f64;
135 let edge = valid[lo] * (1.0 - frac) + valid[hi] * frac;
136 edges.push(edge);
137 }
138
139 edges.dedup_by(|a, b| (*a - *b).abs() < f64::EPSILON);
141
142 let n_valid_bins = edges.len() + 1;
143 self.n_bins_per_feature.push(n_valid_bins);
144 self.bin_edges.push(edges);
145 }
146
147 self.fitted = true;
148 Ok(())
149 }
150
151 pub fn transform(&self, data: &Dataset) -> Result<Vec<Vec<u8>>> {
155 if !self.fitted {
156 return Err(ScryLearnError::NotFitted);
157 }
158 let n_features = data.n_features();
159 if n_features != self.bin_edges.len() {
160 return Err(ScryLearnError::ShapeMismatch {
161 expected: self.bin_edges.len(),
162 got: n_features,
163 });
164 }
165
166 let n_samples = data.n_samples();
167 let mut result = Vec::with_capacity(n_features);
168
169 for f in 0..n_features {
170 let col = &data.features[f];
171 let edges = &self.bin_edges[f];
172 let mut binned = vec![0u8; n_samples];
173
174 for (i, &val) in col.iter().enumerate() {
175 if val.is_nan() {
176 binned[i] = 0; } else {
178 let bin = match edges.binary_search_by(|edge| {
180 edge.partial_cmp(&val).unwrap_or(std::cmp::Ordering::Equal)
181 }) {
182 Ok(pos) => pos + 1, Err(pos) => pos,
184 };
185 binned[i] = (bin + 1).min(255) as u8;
187 }
188 }
189
190 result.push(binned);
191 }
192
193 Ok(result)
194 }
195
196 pub fn fit_transform(&mut self, data: &Dataset) -> Result<Vec<Vec<u8>>> {
198 self.fit(data)?;
199 self.transform(data)
200 }
201
202 pub fn n_bins_per_feature(&self) -> &[usize] {
204 &self.n_bins_per_feature
205 }
206
207 pub fn bin_edges(&self) -> &[Vec<f64>] {
209 &self.bin_edges
210 }
211}
212
213impl Default for FeatureBinner {
214 fn default() -> Self {
215 Self::new()
216 }
217}
218
#[cfg(test)]
mod tests {
    use super::*;

    /// Two strictly increasing features with ten samples each:
    /// `1.0..=10.0` and the same values scaled by 100.
    fn simple_dataset() -> Dataset {
        let small: Vec<f64> = (1..=10).map(f64::from).collect();
        let large: Vec<f64> = small.iter().map(|v| v * 100.0).collect();
        Dataset::new(
            vec![small, large],
            vec![0.0; 10],
            vec!["a".into(), "b".into()],
            "y",
        )
    }

    /// Wraps a single column named `x` into a dataset with an all-zero target.
    fn single_feature(values: Vec<f64>) -> Dataset {
        let n_rows = values.len();
        Dataset::new(vec![values], vec![0.0; n_rows], vec!["x".into()], "y")
    }

    #[test]
    fn test_fit_transform_basic() {
        let mut binner = FeatureBinner::new();
        let binned = binner.fit_transform(&simple_dataset()).unwrap();

        assert_eq!(binned.len(), 2);
        let first_feature = &binned[0];
        assert_eq!(first_feature.len(), 10);
        assert!(
            first_feature.iter().all(|&b| b >= 1),
            "valid values should map to bins >= 1"
        );
        // Sorted input must yield non-decreasing bin indices.
        assert!(first_feature.windows(2).all(|pair| pair[1] >= pair[0]));
    }

    #[test]
    fn test_nan_handling() {
        let ds = single_feature(vec![1.0, f64::NAN, 3.0, f64::NAN, 5.0]);
        let binned = FeatureBinner::new().fit_transform(&ds).unwrap();
        let column = &binned[0];

        assert_eq!(column[1], 0, "NaN should map to bin 0");
        assert_eq!(column[3], 0, "NaN should map to bin 0");
        assert!(column[0] >= 1, "valid value should be >= 1");
    }

    #[test]
    fn test_constant_feature() {
        let ds = single_feature(vec![5.0; 4]);
        let binned = FeatureBinner::new().fit_transform(&ds).unwrap();

        let (first, rest) = binned[0].split_first().unwrap();
        for b in rest {
            assert_eq!(b, first);
        }
    }

    #[test]
    fn test_max_bins_param() {
        let mut binner = FeatureBinner::new().max_bins(4);
        let binned = binner.fit_transform(&simple_dataset()).unwrap();

        for &b in &binned[0] {
            assert!(b <= 3, "with max_bins=4, bin index should be <= 3, got {b}");
        }
    }

    #[test]
    fn test_not_fitted_error() {
        let unfitted = FeatureBinner::new();
        assert!(unfitted.transform(&simple_dataset()).is_err());
    }

    #[test]
    fn test_all_nan_feature() {
        let ds = single_feature(vec![f64::NAN; 3]);
        let binned = FeatureBinner::new().fit_transform(&ds).unwrap();

        for &b in &binned[0] {
            assert_eq!(b, 0, "all-NaN feature should map entirely to bin 0");
        }
    }
}