1use scirs2_core::ndarray::{Array1, Array2};
8use std::marker::PhantomData;
9
10use sklears_core::{
11 error::{Result, SklearsError},
12 traits::{Estimator, Fit, Trained, Transform, Untrained},
13 types::Float,
14};
15
16#[derive(Debug, Clone)]
18pub struct BinarizerConfig {
19 pub threshold: Float,
21 pub copy: bool,
23}
24
25impl Default for BinarizerConfig {
26 fn default() -> Self {
27 Self {
28 threshold: 0.0,
29 copy: true,
30 }
31 }
32}
33
34pub struct Binarizer<State = Untrained> {
36 config: BinarizerConfig,
37 state: PhantomData<State>,
38}
39
40impl Binarizer<Untrained> {
41 pub fn new() -> Self {
43 Self {
44 config: BinarizerConfig::default(),
45 state: PhantomData,
46 }
47 }
48
49 pub fn with_threshold(threshold: Float) -> Self {
51 Self {
52 config: BinarizerConfig {
53 threshold,
54 copy: true,
55 },
56 state: PhantomData,
57 }
58 }
59
60 pub fn threshold(mut self, threshold: Float) -> Self {
62 self.config.threshold = threshold;
63 self
64 }
65
66 pub fn copy(mut self, copy: bool) -> Self {
68 self.config.copy = copy;
69 self
70 }
71}
72
73impl Default for Binarizer<Untrained> {
74 fn default() -> Self {
75 Self::new()
76 }
77}
78
79impl Estimator for Binarizer<Untrained> {
80 type Config = BinarizerConfig;
81 type Error = SklearsError;
82 type Float = Float;
83
84 fn config(&self) -> &Self::Config {
85 &self.config
86 }
87}
88
89impl Estimator for Binarizer<Trained> {
90 type Config = BinarizerConfig;
91 type Error = SklearsError;
92 type Float = Float;
93
94 fn config(&self) -> &Self::Config {
95 &self.config
96 }
97}
98
99impl Fit<Array2<Float>, ()> for Binarizer<Untrained> {
100 type Fitted = Binarizer<Trained>;
101
102 fn fit(self, _x: &Array2<Float>, _y: &()) -> Result<Self::Fitted> {
103 Ok(Binarizer {
105 config: self.config,
106 state: PhantomData,
107 })
108 }
109}
110
111impl Transform<Array2<Float>, Array2<Float>> for Binarizer<Trained> {
112 fn transform(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
113 let result = if self.config.copy {
114 x.clone()
115 } else {
116 x.to_owned()
117 };
118
119 Ok(result.mapv(|v| if v > self.config.threshold { 1.0 } else { 0.0 }))
120 }
121}
122
123impl Transform<Array1<Float>, Array1<Float>> for Binarizer<Trained> {
124 fn transform(&self, x: &Array1<Float>) -> Result<Array1<Float>> {
125 let result = if self.config.copy {
126 x.clone()
127 } else {
128 x.to_owned()
129 };
130
131 Ok(result.mapv(|v| if v > self.config.threshold { 1.0 } else { 0.0 }))
132 }
133}
134
/// How bin edges are derived from the training data of each feature.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DiscretizationStrategy {
    /// Equal-width bins spanning the observed minimum and maximum.
    Uniform,
    /// Edges taken from the sorted sample at evenly spaced ranks.
    Quantile,
    /// K-means based binning. The fitting code in this file currently routes
    /// this variant through the same routine as [`Self::Quantile`].
    KMeans,
}
145
/// Output layout produced when discretized data is transformed.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DiscretizerEncoding {
    /// One output column per bin of every feature; the selected bin holds `1.0`.
    OneHot,
    /// One output column per feature holding the bin index as a float.
    Ordinal,
}
154
155#[derive(Debug, Clone)]
157pub struct KBinsDiscretizerConfig {
158 pub n_bins: usize,
160 pub encode: DiscretizerEncoding,
162 pub strategy: DiscretizationStrategy,
164 pub subsample: Option<usize>,
166 pub random_state: Option<u64>,
168}
169
170impl Default for KBinsDiscretizerConfig {
171 fn default() -> Self {
172 Self {
173 n_bins: 5,
174 encode: DiscretizerEncoding::OneHot,
175 strategy: DiscretizationStrategy::Quantile,
176 subsample: Some(200_000),
177 random_state: None,
178 }
179 }
180}
181
182pub struct KBinsDiscretizer<State = Untrained> {
184 config: KBinsDiscretizerConfig,
185 state: PhantomData<State>,
186 bin_edges_: Option<Vec<Array1<Float>>>,
188 n_bins_: Option<Vec<usize>>,
190}
191
192impl KBinsDiscretizer<Untrained> {
193 pub fn new() -> Self {
195 Self {
196 config: KBinsDiscretizerConfig::default(),
197 state: PhantomData,
198 bin_edges_: None,
199 n_bins_: None,
200 }
201 }
202
203 pub fn n_bins(mut self, n_bins: usize) -> Result<Self> {
205 if n_bins < 2 {
206 return Err(SklearsError::InvalidParameter {
207 name: "n_bins".to_string(),
208 reason: "must be at least 2".to_string(),
209 });
210 }
211 self.config.n_bins = n_bins;
212 Ok(self)
213 }
214
215 pub fn encode(mut self, encode: DiscretizerEncoding) -> Self {
217 self.config.encode = encode;
218 self
219 }
220
221 pub fn strategy(mut self, strategy: DiscretizationStrategy) -> Self {
223 self.config.strategy = strategy;
224 self
225 }
226}
227
228impl Default for KBinsDiscretizer<Untrained> {
229 fn default() -> Self {
230 Self::new()
231 }
232}
233
234impl Estimator for KBinsDiscretizer<Untrained> {
235 type Config = KBinsDiscretizerConfig;
236 type Error = SklearsError;
237 type Float = Float;
238
239 fn config(&self) -> &Self::Config {
240 &self.config
241 }
242}
243
244impl Estimator for KBinsDiscretizer<Trained> {
245 type Config = KBinsDiscretizerConfig;
246 type Error = SklearsError;
247 type Float = Float;
248
249 fn config(&self) -> &Self::Config {
250 &self.config
251 }
252}
253
254fn compute_uniform_bins(data: &Array1<Float>, n_bins: usize) -> Array1<Float> {
256 let min_val = data.iter().cloned().fold(Float::INFINITY, Float::min);
257 let max_val = data.iter().cloned().fold(Float::NEG_INFINITY, Float::max);
258
259 if (max_val - min_val).abs() < Float::EPSILON {
260 return Array1::from_vec(vec![min_val - 0.5, max_val + 0.5]);
262 }
263
264 let width = (max_val - min_val) / n_bins as Float;
265 let mut edges = Vec::with_capacity(n_bins + 1);
266
267 for i in 0..=n_bins {
268 edges.push(min_val + i as Float * width);
269 }
270
271 edges[n_bins] = max_val + Float::EPSILON;
273
274 Array1::from_vec(edges)
275}
276
277fn compute_quantile_bins(data: &Array1<Float>, n_bins: usize) -> Array1<Float> {
279 let mut sorted_data = data.to_vec();
280 sorted_data.sort_by(|a, b| a.partial_cmp(b).expect("operation should succeed"));
281
282 let n_samples = sorted_data.len();
283 let mut edges = Vec::with_capacity(n_bins + 1);
284
285 edges.push(sorted_data[0]);
287
288 for i in 1..n_bins {
290 let idx = (i * n_samples) / n_bins;
291 let value = sorted_data[idx.min(n_samples - 1)];
292
293 if value > edges.last().expect("collection should not be empty") + Float::EPSILON {
295 edges.push(value);
296 }
297 }
298
299 edges.push(sorted_data[n_samples - 1] + Float::EPSILON);
301
302 if edges.len() < 3 {
304 edges.clear();
305 edges.push(sorted_data[0]);
306 edges.push(sorted_data[n_samples - 1] + Float::EPSILON);
307 }
308
309 Array1::from_vec(edges)
310}
311
312impl Fit<Array2<Float>, ()> for KBinsDiscretizer<Untrained> {
313 type Fitted = KBinsDiscretizer<Trained>;
314
315 fn fit(self, x: &Array2<Float>, _y: &()) -> Result<Self::Fitted> {
316 let n_features = x.ncols();
317 let mut bin_edges = Vec::with_capacity(n_features);
318 let mut n_bins = Vec::with_capacity(n_features);
319
320 for j in 0..n_features {
322 let feature_data = x.column(j).to_owned();
323
324 let edges = match self.config.strategy {
325 DiscretizationStrategy::Uniform => {
326 compute_uniform_bins(&feature_data, self.config.n_bins)
327 }
328 DiscretizationStrategy::Quantile => {
329 compute_quantile_bins(&feature_data, self.config.n_bins)
330 }
331 DiscretizationStrategy::KMeans => {
332 compute_quantile_bins(&feature_data, self.config.n_bins)
334 }
335 };
336
337 n_bins.push(edges.len() - 1);
338 bin_edges.push(edges);
339 }
340
341 Ok(KBinsDiscretizer {
342 config: self.config,
343 state: PhantomData,
344 bin_edges_: Some(bin_edges),
345 n_bins_: Some(n_bins),
346 })
347 }
348}
349
350fn find_bin(value: Float, edges: &Array1<Float>) -> usize {
352 let n_edges = edges.len();
354
355 if value <= edges[0] {
356 return 0;
357 }
358 if value >= edges[n_edges - 1] {
359 return n_edges - 2;
360 }
361
362 let mut left = 0;
363 let mut right = n_edges - 1;
364
365 while left < right - 1 {
366 let mid = (left + right) / 2;
367 if value < edges[mid] {
368 right = mid;
369 } else {
370 left = mid;
371 }
372 }
373
374 left
375}
376
377impl Transform<Array2<Float>, Array2<Float>> for KBinsDiscretizer<Trained> {
378 fn transform(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
379 let n_samples = x.nrows();
380 let n_features = x.ncols();
381 let bin_edges = self.bin_edges_.as_ref().expect("operation should succeed");
382 let n_bins = self.n_bins_.as_ref().expect("operation should succeed");
383
384 match self.config.encode {
385 DiscretizerEncoding::Ordinal => {
386 let mut result = Array2::zeros((n_samples, n_features));
387
388 for i in 0..n_samples {
389 for j in 0..n_features {
390 let bin_idx = find_bin(x[[i, j]], &bin_edges[j]);
391 result[[i, j]] = bin_idx as Float;
392 }
393 }
394
395 Ok(result)
396 }
397 DiscretizerEncoding::OneHot => {
398 let total_bins: usize = n_bins.iter().sum();
400 let mut result = Array2::zeros((n_samples, total_bins));
401
402 for i in 0..n_samples {
403 let mut col_offset = 0;
404 for j in 0..n_features {
405 let bin_idx = find_bin(x[[i, j]], &bin_edges[j]);
406 result[[i, col_offset + bin_idx]] = 1.0;
407 col_offset += n_bins[j];
408 }
409 }
410
411 Ok(result)
412 }
413 }
414 }
415}
416
417impl KBinsDiscretizer<Trained> {
418 pub fn bin_edges(&self) -> &Vec<Array1<Float>> {
420 self.bin_edges_.as_ref().expect("operation should succeed")
421 }
422
423 pub fn n_bins(&self) -> &Vec<usize> {
425 self.n_bins_.as_ref().expect("operation should succeed")
426 }
427
428 pub fn inverse_transform(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
430 let bin_edges = self.bin_edges_.as_ref().expect("operation should succeed");
431 let n_features = bin_edges.len();
432
433 match self.config.encode {
434 DiscretizerEncoding::Ordinal => {
435 if x.ncols() != n_features {
436 return Err(SklearsError::InvalidInput(
437 "Input must have the same number of features as during fit".to_string(),
438 ));
439 }
440
441 let mut result = Array2::zeros(x.dim());
442
443 for i in 0..x.nrows() {
444 for j in 0..n_features {
445 let bin_idx = x[[i, j]] as usize;
446 let edges = &bin_edges[j];
447
448 if bin_idx >= edges.len() - 1 {
449 return Err(SklearsError::InvalidInput(format!(
450 "Invalid bin index {bin_idx} for feature {j}"
451 )));
452 }
453
454 result[[i, j]] = (edges[bin_idx] + edges[bin_idx + 1]) / 2.0;
456 }
457 }
458
459 Ok(result)
460 }
461 DiscretizerEncoding::OneHot => Err(SklearsError::InvalidInput(
462 "Inverse transform not supported for one-hot encoding".to_string(),
463 )),
464 }
465 }
466}
467
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    /// Values strictly above a zero threshold become 1.0; the rest become 0.0.
    #[test]
    fn test_binarizer() {
        let input = array![[1.0, -1.0, 2.0], [2.0, 0.0, 0.0], [0.0, 1.0, -1.0]];
        let want = array![[1.0, 0.0, 1.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]];

        let fitted = Binarizer::with_threshold(0.0)
            .fit(&input, &())
            .expect("model fitting should succeed");
        let got = fitted
            .transform(&input)
            .expect("transformation should succeed");

        assert_eq!(got, want);
    }

    /// A threshold set through the builder is honoured.
    #[test]
    fn test_binarizer_custom_threshold() {
        let input = array![[1.0, 2.0, 3.0, 4.0]];
        let want = array![[0.0, 0.0, 1.0, 1.0]];

        let fitted = Binarizer::new()
            .threshold(2.5)
            .fit(&input, &())
            .expect("model fitting should succeed");
        let got = fitted
            .transform(&input)
            .expect("transformation should succeed");

        assert_eq!(got, want);
    }

    /// The 1-D `Transform` implementation applies the same rule.
    #[test]
    fn test_binarizer_1d() {
        let input = array![1.0, -1.0, 2.0, 0.0];
        let want = array![1.0, 0.0, 1.0, 0.0];

        let fitted = Binarizer::new()
            .fit(&array![[0.0]], &())
            .expect("model fitting should succeed");
        let got = fitted
            .transform(&input)
            .expect("transformation should succeed");

        assert_eq!(got, want);
    }

    /// Three equal-width bins over 0..=5 hold two samples each.
    #[test]
    fn test_kbins_discretizer_uniform() {
        let input = array![[0.0], [1.0], [2.0], [3.0], [4.0], [5.0]];

        let fitted = KBinsDiscretizer::new()
            .n_bins(3)
            .expect("valid parameter")
            .strategy(DiscretizationStrategy::Uniform)
            .encode(DiscretizerEncoding::Ordinal)
            .fit(&input, &())
            .expect("operation should succeed");

        let binned = fitted
            .transform(&input)
            .expect("transformation should succeed");

        let want = vec![0.0, 0.0, 1.0, 1.0, 2.0, 2.0];
        assert_eq!(binned.column(0).to_vec(), want);
    }

    /// Quantile bins each receive between one and three of the six samples.
    #[test]
    fn test_kbins_discretizer_quantile() {
        let input = array![[0.0], [1.0], [1.0], [2.0], [3.0], [10.0]];

        let fitted = KBinsDiscretizer::new()
            .n_bins(3)
            .expect("valid parameter")
            .strategy(DiscretizationStrategy::Quantile)
            .encode(DiscretizerEncoding::Ordinal)
            .fit(&input, &())
            .expect("operation should succeed");

        let binned = fitted
            .transform(&input)
            .expect("transformation should succeed");

        let members_of = |bin: Float| binned.iter().filter(|&&v| v == bin).count();

        for count in [members_of(0.0), members_of(1.0), members_of(2.0)] {
            assert!((1..=3).contains(&count));
        }
    }

    /// One-hot output has one column per bin and exactly one hit per feature.
    #[test]
    fn test_kbins_discretizer_onehot() {
        let input = array![[0.0, 0.0], [1.0, 1.0], [2.0, 2.0]];

        let fitted = KBinsDiscretizer::new()
            .n_bins(2)
            .expect("valid parameter")
            .encode(DiscretizerEncoding::OneHot)
            .fit(&input, &())
            .expect("operation should succeed");

        let encoded = fitted
            .transform(&input)
            .expect("transformation should succeed");

        // Two features with two bins each.
        assert_eq!(encoded.ncols(), 4);

        for row in 0..encoded.nrows() {
            let hits: Float = encoded.row(row).sum();
            assert_eq!(hits, 2.0);
        }
    }

    /// Inverting ordinal codes yields values inside the originating bins.
    #[test]
    fn test_kbins_discretizer_inverse_transform() {
        let input = array![[0.0], [1.0], [2.0], [3.0], [4.0], [5.0]];

        let fitted = KBinsDiscretizer::new()
            .n_bins(3)
            .expect("valid parameter")
            .strategy(DiscretizationStrategy::Uniform)
            .encode(DiscretizerEncoding::Ordinal)
            .fit(&input, &())
            .expect("operation should succeed");

        let binned = fitted
            .transform(&input)
            .expect("transformation should succeed");
        let restored = fitted
            .inverse_transform(&binned)
            .expect("operation should succeed");

        assert!(restored[[0, 0]] < 2.0);
        assert!(restored[[2, 0]] > 2.0 && restored[[2, 0]] < 4.0);
        assert!(restored[[4, 0]] > 4.0);
    }

    /// Out-of-range values clamp to the outer bins; a value on an edge opens the next bin.
    #[test]
    fn test_find_bin() {
        let edges = array![0.0, 2.0, 4.0, 6.0];

        let cases = [
            (-1.0, 0),
            (0.0, 0),
            (1.0, 0),
            (2.0, 1),
            (3.0, 1),
            (4.0, 2),
            (5.0, 2),
            (6.0, 2),
            (7.0, 2),
        ];

        for (value, want) in cases {
            assert_eq!(find_bin(value, &edges), want);
        }
    }
}