1use scirs2_core::ndarray::{Array1, Array2};
8use std::marker::PhantomData;
9
10use sklears_core::{
11 error::{Result, SklearsError},
12 traits::{Estimator, Fit, Trained, Transform, Untrained},
13 types::Float,
14};
15
16#[derive(Debug, Clone)]
18pub struct BinarizerConfig {
19 pub threshold: Float,
21 pub copy: bool,
23}
24
25impl Default for BinarizerConfig {
26 fn default() -> Self {
27 Self {
28 threshold: 0.0,
29 copy: true,
30 }
31 }
32}
33
34pub struct Binarizer<State = Untrained> {
36 config: BinarizerConfig,
37 state: PhantomData<State>,
38}
39
40impl Binarizer<Untrained> {
41 pub fn new() -> Self {
43 Self {
44 config: BinarizerConfig::default(),
45 state: PhantomData,
46 }
47 }
48
49 pub fn with_threshold(threshold: Float) -> Self {
51 Self {
52 config: BinarizerConfig {
53 threshold,
54 copy: true,
55 },
56 state: PhantomData,
57 }
58 }
59
60 pub fn threshold(mut self, threshold: Float) -> Self {
62 self.config.threshold = threshold;
63 self
64 }
65
66 pub fn copy(mut self, copy: bool) -> Self {
68 self.config.copy = copy;
69 self
70 }
71}
72
73impl Default for Binarizer<Untrained> {
74 fn default() -> Self {
75 Self::new()
76 }
77}
78
79impl Estimator for Binarizer<Untrained> {
80 type Config = BinarizerConfig;
81 type Error = SklearsError;
82 type Float = Float;
83
84 fn config(&self) -> &Self::Config {
85 &self.config
86 }
87}
88
89impl Estimator for Binarizer<Trained> {
90 type Config = BinarizerConfig;
91 type Error = SklearsError;
92 type Float = Float;
93
94 fn config(&self) -> &Self::Config {
95 &self.config
96 }
97}
98
99impl Fit<Array2<Float>, ()> for Binarizer<Untrained> {
100 type Fitted = Binarizer<Trained>;
101
102 fn fit(self, _x: &Array2<Float>, _y: &()) -> Result<Self::Fitted> {
103 Ok(Binarizer {
105 config: self.config,
106 state: PhantomData,
107 })
108 }
109}
110
111impl Transform<Array2<Float>, Array2<Float>> for Binarizer<Trained> {
112 fn transform(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
113 let result = if self.config.copy {
114 x.clone()
115 } else {
116 x.to_owned()
117 };
118
119 Ok(result.mapv(|v| if v > self.config.threshold { 1.0 } else { 0.0 }))
120 }
121}
122
123impl Transform<Array1<Float>, Array1<Float>> for Binarizer<Trained> {
124 fn transform(&self, x: &Array1<Float>) -> Result<Array1<Float>> {
125 let result = if self.config.copy {
126 x.clone()
127 } else {
128 x.to_owned()
129 };
130
131 Ok(result.mapv(|v| if v > self.config.threshold { 1.0 } else { 0.0 }))
132 }
133}
134
/// Strategy used by [`KBinsDiscretizer`] to place bin edges.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DiscretizationStrategy {
    /// Equal-width bins spanning each feature's observed minimum to maximum.
    Uniform,
    /// Edges at sample quantiles, so bins hold roughly equal numbers of samples.
    Quantile,
    /// Intended to place edges with 1-D k-means clustering.
    ///
    /// NOTE: not implemented yet; fitting currently uses the `Quantile` edges instead.
    KMeans,
}

/// Output encoding produced by [`KBinsDiscretizer`]'s `transform`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DiscretizerEncoding {
    /// One indicator column per bin; feature `j` contributes one column per fitted bin.
    OneHot,
    /// A single column per feature holding the bin index (as a float).
    Ordinal,
}
154
155#[derive(Debug, Clone)]
157pub struct KBinsDiscretizerConfig {
158 pub n_bins: usize,
160 pub encode: DiscretizerEncoding,
162 pub strategy: DiscretizationStrategy,
164 pub subsample: Option<usize>,
166 pub random_state: Option<u64>,
168}
169
170impl Default for KBinsDiscretizerConfig {
171 fn default() -> Self {
172 Self {
173 n_bins: 5,
174 encode: DiscretizerEncoding::OneHot,
175 strategy: DiscretizationStrategy::Quantile,
176 subsample: Some(200_000),
177 random_state: None,
178 }
179 }
180}
181
182pub struct KBinsDiscretizer<State = Untrained> {
184 config: KBinsDiscretizerConfig,
185 state: PhantomData<State>,
186 bin_edges_: Option<Vec<Array1<Float>>>,
188 n_bins_: Option<Vec<usize>>,
190}
191
192impl KBinsDiscretizer<Untrained> {
193 pub fn new() -> Self {
195 Self {
196 config: KBinsDiscretizerConfig::default(),
197 state: PhantomData,
198 bin_edges_: None,
199 n_bins_: None,
200 }
201 }
202
203 pub fn n_bins(mut self, n_bins: usize) -> Self {
205 if n_bins < 2 {
206 panic!("n_bins must be at least 2");
207 }
208 self.config.n_bins = n_bins;
209 self
210 }
211
212 pub fn encode(mut self, encode: DiscretizerEncoding) -> Self {
214 self.config.encode = encode;
215 self
216 }
217
218 pub fn strategy(mut self, strategy: DiscretizationStrategy) -> Self {
220 self.config.strategy = strategy;
221 self
222 }
223}
224
225impl Default for KBinsDiscretizer<Untrained> {
226 fn default() -> Self {
227 Self::new()
228 }
229}
230
231impl Estimator for KBinsDiscretizer<Untrained> {
232 type Config = KBinsDiscretizerConfig;
233 type Error = SklearsError;
234 type Float = Float;
235
236 fn config(&self) -> &Self::Config {
237 &self.config
238 }
239}
240
241impl Estimator for KBinsDiscretizer<Trained> {
242 type Config = KBinsDiscretizerConfig;
243 type Error = SklearsError;
244 type Float = Float;
245
246 fn config(&self) -> &Self::Config {
247 &self.config
248 }
249}
250
251fn compute_uniform_bins(data: &Array1<Float>, n_bins: usize) -> Array1<Float> {
253 let min_val = data.iter().cloned().fold(Float::INFINITY, Float::min);
254 let max_val = data.iter().cloned().fold(Float::NEG_INFINITY, Float::max);
255
256 if (max_val - min_val).abs() < Float::EPSILON {
257 return Array1::from_vec(vec![min_val - 0.5, max_val + 0.5]);
259 }
260
261 let width = (max_val - min_val) / n_bins as Float;
262 let mut edges = Vec::with_capacity(n_bins + 1);
263
264 for i in 0..=n_bins {
265 edges.push(min_val + i as Float * width);
266 }
267
268 edges[n_bins] = max_val + Float::EPSILON;
270
271 Array1::from_vec(edges)
272}
273
274fn compute_quantile_bins(data: &Array1<Float>, n_bins: usize) -> Array1<Float> {
276 let mut sorted_data = data.to_vec();
277 sorted_data.sort_by(|a, b| a.partial_cmp(b).unwrap());
278
279 let n_samples = sorted_data.len();
280 let mut edges = Vec::with_capacity(n_bins + 1);
281
282 edges.push(sorted_data[0]);
284
285 for i in 1..n_bins {
287 let idx = (i * n_samples) / n_bins;
288 let value = sorted_data[idx.min(n_samples - 1)];
289
290 if value > edges.last().unwrap() + Float::EPSILON {
292 edges.push(value);
293 }
294 }
295
296 edges.push(sorted_data[n_samples - 1] + Float::EPSILON);
298
299 if edges.len() < 3 {
301 edges.clear();
302 edges.push(sorted_data[0]);
303 edges.push(sorted_data[n_samples - 1] + Float::EPSILON);
304 }
305
306 Array1::from_vec(edges)
307}
308
309impl Fit<Array2<Float>, ()> for KBinsDiscretizer<Untrained> {
310 type Fitted = KBinsDiscretizer<Trained>;
311
312 fn fit(self, x: &Array2<Float>, _y: &()) -> Result<Self::Fitted> {
313 let n_features = x.ncols();
314 let mut bin_edges = Vec::with_capacity(n_features);
315 let mut n_bins = Vec::with_capacity(n_features);
316
317 for j in 0..n_features {
319 let feature_data = x.column(j).to_owned();
320
321 let edges = match self.config.strategy {
322 DiscretizationStrategy::Uniform => {
323 compute_uniform_bins(&feature_data, self.config.n_bins)
324 }
325 DiscretizationStrategy::Quantile => {
326 compute_quantile_bins(&feature_data, self.config.n_bins)
327 }
328 DiscretizationStrategy::KMeans => {
329 compute_quantile_bins(&feature_data, self.config.n_bins)
331 }
332 };
333
334 n_bins.push(edges.len() - 1);
335 bin_edges.push(edges);
336 }
337
338 Ok(KBinsDiscretizer {
339 config: self.config,
340 state: PhantomData,
341 bin_edges_: Some(bin_edges),
342 n_bins_: Some(n_bins),
343 })
344 }
345}
346
347fn find_bin(value: Float, edges: &Array1<Float>) -> usize {
349 let n_edges = edges.len();
351
352 if value <= edges[0] {
353 return 0;
354 }
355 if value >= edges[n_edges - 1] {
356 return n_edges - 2;
357 }
358
359 let mut left = 0;
360 let mut right = n_edges - 1;
361
362 while left < right - 1 {
363 let mid = (left + right) / 2;
364 if value < edges[mid] {
365 right = mid;
366 } else {
367 left = mid;
368 }
369 }
370
371 left
372}
373
374impl Transform<Array2<Float>, Array2<Float>> for KBinsDiscretizer<Trained> {
375 fn transform(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
376 let n_samples = x.nrows();
377 let n_features = x.ncols();
378 let bin_edges = self.bin_edges_.as_ref().unwrap();
379 let n_bins = self.n_bins_.as_ref().unwrap();
380
381 match self.config.encode {
382 DiscretizerEncoding::Ordinal => {
383 let mut result = Array2::zeros((n_samples, n_features));
384
385 for i in 0..n_samples {
386 for j in 0..n_features {
387 let bin_idx = find_bin(x[[i, j]], &bin_edges[j]);
388 result[[i, j]] = bin_idx as Float;
389 }
390 }
391
392 Ok(result)
393 }
394 DiscretizerEncoding::OneHot => {
395 let total_bins: usize = n_bins.iter().sum();
397 let mut result = Array2::zeros((n_samples, total_bins));
398
399 for i in 0..n_samples {
400 let mut col_offset = 0;
401 for j in 0..n_features {
402 let bin_idx = find_bin(x[[i, j]], &bin_edges[j]);
403 result[[i, col_offset + bin_idx]] = 1.0;
404 col_offset += n_bins[j];
405 }
406 }
407
408 Ok(result)
409 }
410 }
411 }
412}
413
414impl KBinsDiscretizer<Trained> {
415 pub fn bin_edges(&self) -> &Vec<Array1<Float>> {
417 self.bin_edges_.as_ref().unwrap()
418 }
419
420 pub fn n_bins(&self) -> &Vec<usize> {
422 self.n_bins_.as_ref().unwrap()
423 }
424
425 pub fn inverse_transform(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
427 let bin_edges = self.bin_edges_.as_ref().unwrap();
428 let n_features = bin_edges.len();
429
430 match self.config.encode {
431 DiscretizerEncoding::Ordinal => {
432 if x.ncols() != n_features {
433 return Err(SklearsError::InvalidInput(
434 "Input must have the same number of features as during fit".to_string(),
435 ));
436 }
437
438 let mut result = Array2::zeros(x.dim());
439
440 for i in 0..x.nrows() {
441 for j in 0..n_features {
442 let bin_idx = x[[i, j]] as usize;
443 let edges = &bin_edges[j];
444
445 if bin_idx >= edges.len() - 1 {
446 return Err(SklearsError::InvalidInput(format!(
447 "Invalid bin index {bin_idx} for feature {j}"
448 )));
449 }
450
451 result[[i, j]] = (edges[bin_idx] + edges[bin_idx + 1]) / 2.0;
453 }
454 }
455
456 Ok(result)
457 }
458 DiscretizerEncoding::OneHot => Err(SklearsError::InvalidInput(
459 "Inverse transform not supported for one-hot encoding".to_string(),
460 )),
461 }
462 }
463}
464
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_binarizer() {
        let input = array![[1.0, -1.0, 2.0], [2.0, 0.0, 0.0], [0.0, 1.0, -1.0]];
        let fitted = Binarizer::with_threshold(0.0).fit(&input, &()).unwrap();
        let output = fitted.transform(&input).unwrap();

        // Only values strictly above the threshold become 1; zero itself maps to 0.
        let expected = array![[1.0, 0.0, 1.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]];
        assert_eq!(output, expected);
    }

    #[test]
    fn test_binarizer_custom_threshold() {
        let input = array![[1.0, 2.0, 3.0, 4.0]];
        let fitted = Binarizer::new().threshold(2.5).fit(&input, &()).unwrap();
        let output = fitted.transform(&input).unwrap();

        let expected = array![[0.0, 0.0, 1.0, 1.0]];
        assert_eq!(output, expected);
    }

    #[test]
    fn test_binarizer_1d() {
        let input = array![1.0, -1.0, 2.0, 0.0];
        // Fitting is stateless, so any 2-D placeholder will do.
        let fitted = Binarizer::new().fit(&array![[0.0]], &()).unwrap();
        let output = fitted.transform(&input).unwrap();

        let expected = array![1.0, 0.0, 1.0, 0.0];
        assert_eq!(output, expected);
    }

    #[test]
    fn test_kbins_discretizer_uniform() {
        let input = array![[0.0], [1.0], [2.0], [3.0], [4.0], [5.0]];

        let fitted = KBinsDiscretizer::new()
            .n_bins(3)
            .strategy(DiscretizationStrategy::Uniform)
            .encode(DiscretizerEncoding::Ordinal)
            .fit(&input, &())
            .unwrap();
        let codes = fitted.transform(&input).unwrap();

        // Edges are [0, 5/3, 10/3, 5]: two samples land in each bin.
        assert_eq!(
            codes.column(0).to_vec(),
            vec![0.0, 0.0, 1.0, 1.0, 2.0, 2.0]
        );
    }

    #[test]
    fn test_kbins_discretizer_quantile() {
        let input = array![[0.0], [1.0], [1.0], [2.0], [3.0], [10.0]];

        let fitted = KBinsDiscretizer::new()
            .n_bins(3)
            .strategy(DiscretizationStrategy::Quantile)
            .encode(DiscretizerEncoding::Ordinal)
            .fit(&input, &())
            .unwrap();
        let codes = fitted.transform(&input).unwrap();

        // Quantile bins should spread samples roughly evenly despite the outlier.
        for bin in 0..3 {
            let count = codes.iter().filter(|&&v| v == bin as Float).count();
            assert!((1..=3).contains(&count));
        }
    }

    #[test]
    fn test_kbins_discretizer_onehot() {
        let input = array![[0.0, 0.0], [1.0, 1.0], [2.0, 2.0]];

        let fitted = KBinsDiscretizer::new()
            .n_bins(2)
            .encode(DiscretizerEncoding::OneHot)
            .fit(&input, &())
            .unwrap();
        let encoded = fitted.transform(&input).unwrap();

        // Two features with two bins each.
        assert_eq!(encoded.ncols(), 4);

        // Exactly one active indicator per feature in every row.
        for row in encoded.outer_iter() {
            let row_sum: Float = row.sum();
            assert_eq!(row_sum, 2.0);
        }
    }

    #[test]
    fn test_kbins_discretizer_inverse_transform() {
        let input = array![[0.0], [1.0], [2.0], [3.0], [4.0], [5.0]];

        let fitted = KBinsDiscretizer::new()
            .n_bins(3)
            .strategy(DiscretizationStrategy::Uniform)
            .encode(DiscretizerEncoding::Ordinal)
            .fit(&input, &())
            .unwrap();
        let codes = fitted.transform(&input).unwrap();
        let restored = fitted.inverse_transform(&codes).unwrap();

        // Each sample is restored to the midpoint of its bin.
        assert!(restored[[0, 0]] < 2.0);
        assert!(restored[[2, 0]] > 2.0 && restored[[2, 0]] < 4.0);
        assert!(restored[[4, 0]] > 4.0);
    }

    #[test]
    fn test_find_bin() {
        let edges = array![0.0, 2.0, 4.0, 6.0];

        let cases = [
            (-1.0, 0),
            (0.0, 0),
            (1.0, 0),
            (2.0, 1),
            (3.0, 1),
            (4.0, 2),
            (5.0, 2),
            (6.0, 2),
            (7.0, 2),
        ];
        for &(value, expected) in cases.iter() {
            assert_eq!(find_bin(value, &edges), expected, "value {value}");
        }
    }
}