1use scirs2_core::ndarray::{Array1, Array2};
8use std::collections::{HashMap, HashSet};
9use std::hash::Hash;
10use std::marker::PhantomData;
11
12use sklears_core::{
13 error::{Result, SklearsError},
14 traits::{Estimator, Fit, Trained, Transform, Untrained},
15 types::Float,
16};
17
/// Options controlling how `LabelBinarizer` encodes labels.
#[derive(Debug, Clone)]
pub struct LabelBinarizerConfig {
    /// Value written for "not this class" entries (default `0`).
    pub neg_label: i32,
    /// Value written for "this class" entries (default `1`).
    pub pos_label: i32,
    /// Requests sparse output (default `false`).
    ///
    /// NOTE(review): nothing in this module reads this flag; the transform always
    /// produces a dense `Array2` — confirm before relying on it.
    pub sparse_output: bool,
}

impl Default for LabelBinarizerConfig {
    /// Dense encoding with `0` for negatives and `1` for positives.
    fn default() -> Self {
        let (neg_label, pos_label) = (0, 1);
        LabelBinarizerConfig {
            pos_label,
            neg_label,
            sparse_output: false,
        }
    }
}
38
39pub struct LabelBinarizer<T: Eq + Hash + Clone = i32, State = Untrained> {
41 config: LabelBinarizerConfig,
42 state: PhantomData<State>,
43 classes_: Option<Vec<T>>,
44 class_to_index_: Option<HashMap<T, usize>>,
45}
46
47impl<T: Eq + Hash + Clone> LabelBinarizer<T, Untrained> {
48 pub fn new() -> Self {
50 Self {
51 config: LabelBinarizerConfig::default(),
52 state: PhantomData,
53 classes_: None,
54 class_to_index_: None,
55 }
56 }
57
58 pub fn neg_label(mut self, neg_label: i32) -> Self {
60 self.config.neg_label = neg_label;
61 self
62 }
63
64 pub fn pos_label(mut self, pos_label: i32) -> Self {
66 self.config.pos_label = pos_label;
67 self
68 }
69}
70
71impl<T: Eq + Hash + Clone> Default for LabelBinarizer<T, Untrained> {
72 fn default() -> Self {
73 Self::new()
74 }
75}
76
77impl<T: Eq + Hash + Clone> Estimator for LabelBinarizer<T, Untrained> {
78 type Config = LabelBinarizerConfig;
79 type Error = SklearsError;
80 type Float = Float;
81
82 fn config(&self) -> &Self::Config {
83 &self.config
84 }
85}
86
87impl<T: Eq + Hash + Clone> Estimator for LabelBinarizer<T, Trained> {
88 type Config = LabelBinarizerConfig;
89 type Error = SklearsError;
90 type Float = Float;
91
92 fn config(&self) -> &Self::Config {
93 &self.config
94 }
95}
96
97impl<T: Eq + Hash + Clone + Ord + Send + Sync> Fit<Array1<T>, ()> for LabelBinarizer<T, Untrained> {
98 type Fitted = LabelBinarizer<T, Trained>;
99
100 fn fit(self, y: &Array1<T>, _x: &()) -> Result<Self::Fitted> {
101 let mut classes = HashSet::new();
103 for label in y.iter() {
104 classes.insert(label.clone());
105 }
106
107 let mut sorted_classes: Vec<T> = classes.into_iter().collect();
109 sorted_classes.sort();
110
111 let class_to_index: HashMap<T, usize> = sorted_classes
113 .iter()
114 .enumerate()
115 .map(|(i, c)| (c.clone(), i))
116 .collect();
117
118 Ok(LabelBinarizer {
119 config: self.config,
120 state: PhantomData,
121 classes_: Some(sorted_classes),
122 class_to_index_: Some(class_to_index),
123 })
124 }
125}
126
127impl<T: Eq + Hash + Clone> Transform<Array1<T>, Array2<Float>> for LabelBinarizer<T, Trained> {
128 fn transform(&self, y: &Array1<T>) -> Result<Array2<Float>> {
129 let classes = self.classes_.as_ref().expect("operation should succeed");
130 let class_to_index = self
131 .class_to_index_
132 .as_ref()
133 .expect("operation should succeed");
134 let n_samples = y.len();
135 let n_classes = classes.len();
136
137 if n_classes == 0 {
138 return Err(SklearsError::InvalidInput(
139 "No classes found during fit".to_string(),
140 ));
141 }
142
143 if n_classes == 2 {
145 let mut result = Array2::zeros((n_samples, 1));
146 for (i, label) in y.iter().enumerate() {
147 if let Some(&class_idx) = class_to_index.get(label) {
148 result[[i, 0]] = if class_idx == 1 {
149 self.config.pos_label as Float
150 } else {
151 self.config.neg_label as Float
152 };
153 } else {
154 return Err(SklearsError::InvalidInput(
155 "Unknown label encountered during transform".to_string(),
156 ));
157 }
158 }
159 Ok(result)
160 } else {
161 let mut result =
163 Array2::from_elem((n_samples, n_classes), self.config.neg_label as Float);
164 for (i, label) in y.iter().enumerate() {
165 if let Some(&class_idx) = class_to_index.get(label) {
166 result[[i, class_idx]] = self.config.pos_label as Float;
167 } else {
168 return Err(SklearsError::InvalidInput(
169 "Unknown label encountered during transform".to_string(),
170 ));
171 }
172 }
173 Ok(result)
174 }
175 }
176}
177
178impl<T: Eq + Hash + Clone> LabelBinarizer<T, Trained> {
179 pub fn classes(&self) -> &Vec<T> {
181 self.classes_.as_ref().expect("operation should succeed")
182 }
183
184 pub fn inverse_transform(&self, y: &Array2<Float>) -> Result<Array1<T>> {
186 let classes = self.classes_.as_ref().expect("operation should succeed");
187 let n_samples = y.nrows();
188 let n_classes = classes.len();
189
190 if n_classes == 2 && y.ncols() == 1 {
191 let mut result = Vec::with_capacity(n_samples);
193 let threshold = (self.config.neg_label + self.config.pos_label) as Float / 2.0;
194
195 for i in 0..n_samples {
196 let class_idx = if y[[i, 0]] > threshold { 1 } else { 0 };
197 result.push(classes[class_idx].clone());
198 }
199 Ok(Array1::from_vec(result))
200 } else if y.ncols() == n_classes {
201 let mut result = Vec::with_capacity(n_samples);
203
204 for i in 0..n_samples {
205 let row = y.row(i);
207 let mut max_idx = 0;
208 let mut max_val = row[0];
209
210 for j in 1..n_classes {
211 if row[j] > max_val {
212 max_val = row[j];
213 max_idx = j;
214 }
215 }
216
217 result.push(classes[max_idx].clone());
218 }
219 Ok(Array1::from_vec(result))
220 } else {
221 Err(SklearsError::InvalidInput(format!(
222 "Shape mismatch: y has {} columns but {} classes were expected",
223 y.ncols(),
224 n_classes
225 )))
226 }
227 }
228}
229
/// Options controlling how `MultiLabelBinarizer` builds its class list.
#[derive(Debug, Clone, Default)]
pub struct MultiLabelBinarizerConfig {
    /// Explicit class list and column order.
    ///
    /// `None` (the default) means the classes are derived from the training data.
    pub classes: Option<Vec<String>>,
    /// Requests sparse output (default `false`).
    ///
    /// NOTE(review): nothing in this module reads this flag; the transform always
    /// produces a dense `Array2` — confirm before relying on it.
    pub sparse_output: bool,
}
238
239pub struct MultiLabelBinarizer<State = Untrained> {
241 config: MultiLabelBinarizerConfig,
242 state: PhantomData<State>,
243 classes_: Option<Vec<String>>,
244 class_to_index_: Option<HashMap<String, usize>>,
245}
246
247impl MultiLabelBinarizer<Untrained> {
248 pub fn new() -> Self {
250 Self {
251 config: MultiLabelBinarizerConfig::default(),
252 state: PhantomData,
253 classes_: None,
254 class_to_index_: None,
255 }
256 }
257
258 pub fn classes(mut self, classes: Vec<String>) -> Self {
260 self.config.classes = Some(classes);
261 self
262 }
263}
264
265impl Default for MultiLabelBinarizer<Untrained> {
266 fn default() -> Self {
267 Self::new()
268 }
269}
270
271impl Estimator for MultiLabelBinarizer<Untrained> {
272 type Config = MultiLabelBinarizerConfig;
273 type Error = SklearsError;
274 type Float = Float;
275
276 fn config(&self) -> &Self::Config {
277 &self.config
278 }
279}
280
281impl Estimator for MultiLabelBinarizer<Trained> {
282 type Config = MultiLabelBinarizerConfig;
283 type Error = SklearsError;
284 type Float = Float;
285
286 fn config(&self) -> &Self::Config {
287 &self.config
288 }
289}
290
291impl Fit<Vec<Vec<String>>, ()> for MultiLabelBinarizer<Untrained> {
292 type Fitted = MultiLabelBinarizer<Trained>;
293
294 fn fit(self, y: &Vec<Vec<String>>, _x: &()) -> Result<Self::Fitted> {
295 let classes = if let Some(ref classes) = self.config.classes {
296 classes.clone()
297 } else {
298 let mut unique_classes = HashSet::new();
300 for labels in y.iter() {
301 for label in labels.iter() {
302 unique_classes.insert(label.clone());
303 }
304 }
305
306 let mut sorted_classes: Vec<String> = unique_classes.into_iter().collect();
307 sorted_classes.sort();
308 sorted_classes
309 };
310
311 let class_to_index: HashMap<String, usize> = classes
313 .iter()
314 .enumerate()
315 .map(|(i, c)| (c.clone(), i))
316 .collect();
317
318 Ok(MultiLabelBinarizer {
319 config: self.config,
320 state: PhantomData,
321 classes_: Some(classes),
322 class_to_index_: Some(class_to_index),
323 })
324 }
325}
326
327impl Transform<Vec<Vec<String>>, Array2<Float>> for MultiLabelBinarizer<Trained> {
328 fn transform(&self, y: &Vec<Vec<String>>) -> Result<Array2<Float>> {
329 let classes = self.classes_.as_ref().expect("operation should succeed");
330 let class_to_index = self
331 .class_to_index_
332 .as_ref()
333 .expect("operation should succeed");
334 let n_samples = y.len();
335 let n_classes = classes.len();
336
337 let mut result = Array2::zeros((n_samples, n_classes));
338
339 for (i, labels) in y.iter().enumerate() {
340 for label in labels.iter() {
341 if let Some(&class_idx) = class_to_index.get(label) {
342 result[[i, class_idx]] = 1.0;
343 }
344 }
346 }
347
348 Ok(result)
349 }
350}
351
352impl MultiLabelBinarizer<Trained> {
353 pub fn classes(&self) -> &Vec<String> {
355 self.classes_.as_ref().expect("operation should succeed")
356 }
357
358 pub fn inverse_transform(&self, y: &Array2<Float>) -> Result<Vec<Vec<String>>> {
360 let classes = self.classes_.as_ref().expect("operation should succeed");
361 let n_samples = y.nrows();
362 let n_classes = classes.len();
363
364 if y.ncols() != n_classes {
365 return Err(SklearsError::InvalidInput(format!(
366 "Shape mismatch: y has {} columns but {} classes were expected",
367 y.ncols(),
368 n_classes
369 )));
370 }
371
372 let mut result = Vec::with_capacity(n_samples);
373
374 for i in 0..n_samples {
375 let mut labels = Vec::new();
376 for j in 0..n_classes {
377 if y[[i, j]] > 0.5 {
378 labels.push(classes[j].clone());
379 }
380 }
381 result.push(labels);
382 }
383
384 Ok(result)
385 }
386}
387
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    /// Turns nested string literals into the owned `Vec<Vec<String>>` samples the
    /// multi-label binarizer consumes.
    fn owned_samples(groups: Vec<Vec<&str>>) -> Vec<Vec<String>> {
        groups
            .into_iter()
            .map(|group| group.into_iter().map(String::from).collect())
            .collect()
    }

    #[test]
    fn test_label_binarizer_binary() {
        let labels = array![1, 0, 1, 0, 1];

        let fitted = LabelBinarizer::new()
            .fit(&labels, &())
            .expect("model fitting should succeed");
        let encoded = fitted
            .transform(&labels)
            .expect("transformation should succeed");

        // Two classes collapse into one column: 1 for class `1`, 0 for class `0`.
        assert_eq!(encoded.shape(), &[5, 1]);
        for (row, expected) in [1.0, 0.0, 1.0].iter().enumerate() {
            assert_eq!(encoded[[row, 0]], *expected);
        }
    }

    #[test]
    fn test_label_binarizer_multiclass() {
        let labels = array![0, 1, 2, 1, 0];

        let fitted = LabelBinarizer::new()
            .fit(&labels, &())
            .expect("model fitting should succeed");
        let encoded = fitted
            .transform(&labels)
            .expect("transformation should succeed");

        assert_eq!(encoded.shape(), &[5, 3]);
        let expected_rows = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]];
        for (i, expected) in expected_rows.iter().enumerate() {
            assert_eq!(encoded.row(i).to_vec(), expected.to_vec());
        }
    }

    #[test]
    fn test_label_binarizer_inverse_transform() {
        let labels = array!["cat", "dog", "cat", "bird", "dog"];

        let fitted = LabelBinarizer::new()
            .fit(&labels, &())
            .expect("model fitting should succeed");
        let encoded = fitted
            .transform(&labels)
            .expect("transformation should succeed");
        let decoded = fitted
            .inverse_transform(&encoded)
            .expect("operation should succeed");

        assert_eq!(labels, decoded);
    }

    #[test]
    fn test_label_binarizer_custom_labels() {
        let labels = array![1, 0, 1, 0];

        let fitted = LabelBinarizer::new()
            .neg_label(-1)
            .pos_label(1)
            .fit(&labels, &())
            .expect("operation should succeed");
        let encoded = fitted
            .transform(&labels)
            .expect("transformation should succeed");

        assert_eq!(encoded[[0, 0]], 1.0);
        assert_eq!(encoded[[1, 0]], -1.0);
    }

    #[test]
    fn test_multilabel_binarizer() {
        let samples = owned_samples(vec![
            vec!["sci-fi", "thriller"],
            vec!["comedy"],
            vec!["sci-fi", "comedy"],
        ]);

        let fitted = MultiLabelBinarizer::new()
            .fit(&samples, &())
            .expect("model fitting should succeed");
        let encoded = fitted
            .transform(&samples)
            .expect("transformation should succeed");

        assert_eq!(encoded.shape(), &[3, 3]);
        assert_eq!(fitted.classes().len(), 3);

        let first_row_total: Float = encoded.row(0).sum();
        assert_eq!(first_row_total, 2.0);

        let second_row_total: Float = encoded.row(1).sum();
        assert_eq!(second_row_total, 1.0);
    }

    #[test]
    fn test_multilabel_binarizer_inverse() {
        let samples = owned_samples(vec![
            vec!["red", "blue"],
            vec!["green"],
            vec!["red", "green"],
        ]);

        let fitted = MultiLabelBinarizer::new()
            .fit(&samples, &())
            .expect("model fitting should succeed");
        let encoded = fitted
            .transform(&samples)
            .expect("transformation should succeed");
        let decoded = fitted
            .inverse_transform(&encoded)
            .expect("operation should succeed");

        // Label order within a sample may change, so compare as sets.
        for (original, reconstructed) in samples.iter().zip(decoded.iter()) {
            let expected: HashSet<_> = original.iter().collect();
            let actual: HashSet<_> = reconstructed.iter().collect();
            assert_eq!(expected, actual);
        }
    }

    #[test]
    fn test_multilabel_binarizer_with_classes() {
        let samples = owned_samples(vec![vec!["a", "b"], vec!["c"]]);
        let classes: Vec<String> = ["a", "b", "c", "d"]
            .iter()
            .map(|s| s.to_string())
            .collect();

        let fitted = MultiLabelBinarizer::new()
            .classes(classes.clone())
            .fit(&samples, &())
            .expect("operation should succeed");
        let encoded = fitted
            .transform(&samples)
            .expect("transformation should succeed");

        // The explicit class list fixes both the width and the column order.
        assert_eq!(encoded.shape(), &[2, 4]);
        assert_eq!(fitted.classes(), &classes);
    }
}