1use scirs2_core::ndarray::{Array1, Array2, ArrayView2, Axis};
8use sklears_core::{
9 error::{Result as SklResult, SklearsError},
10 traits::{Estimator, Fit, Predict, Untrained},
11 types::Float,
12};
13use std::collections::HashMap;
14
/// Binary Relevance multi-label classifier.
///
/// Decomposes a multi-label problem into one independent binary problem per
/// label column; a simple linear model is trained for each (see `fit`).
/// The type parameter `S` encodes fitted state: `Untrained` before `fit`,
/// `BinaryRelevanceTrained` after.
#[derive(Debug, Clone)]
pub struct BinaryRelevance<S = Untrained> {
    /// Estimator state (`Untrained` or `BinaryRelevanceTrained`).
    pub state: S,
    // Parallelism hint; NOTE(review): stored and propagated but not read
    // during fitting anywhere in this file.
    n_jobs: Option<i32>,
}
38
39impl BinaryRelevance<Untrained> {
40 pub fn new() -> Self {
42 Self {
43 state: Untrained,
44 n_jobs: None,
45 }
46 }
47
48 pub fn n_jobs(mut self, n_jobs: Option<i32>) -> Self {
50 self.n_jobs = n_jobs;
51 self
52 }
53}
54
55impl Default for BinaryRelevance<Untrained> {
56 fn default() -> Self {
57 Self::new()
58 }
59}
60
impl Estimator for BinaryRelevance<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    /// This estimator has no tunable configuration; always returns the unit
    /// config.
    fn config(&self) -> &Self::Config {
        &()
    }
}
70
71impl Fit<ArrayView2<'_, Float>, Array2<i32>> for BinaryRelevance<Untrained> {
72 type Fitted = BinaryRelevance<BinaryRelevanceTrained>;
73
74 #[allow(non_snake_case)]
75 fn fit(self, X: &ArrayView2<'_, Float>, y: &Array2<i32>) -> SklResult<Self::Fitted> {
76 let X = X.to_owned();
77 let (n_samples, n_features) = X.dim();
78
79 if n_samples != y.nrows() {
80 return Err(SklearsError::InvalidInput(
81 "X and y must have the same number of samples".to_string(),
82 ));
83 }
84
85 let n_labels = y.ncols();
86 if n_labels == 0 {
87 return Err(SklearsError::InvalidInput(
88 "y must have at least one label".to_string(),
89 ));
90 }
91
92 let mut binary_classifiers = HashMap::new();
93 let mut classes_per_label = Vec::new();
94
95 for label_idx in 0..n_labels {
97 let y_label = y.column(label_idx);
98
99 let mut label_classes: Vec<i32> = y_label
101 .iter()
102 .cloned()
103 .collect::<std::collections::HashSet<_>>()
104 .into_iter()
105 .collect();
106 label_classes.sort();
107
108 if label_classes.len() > 2 {
110 return Err(SklearsError::InvalidInput(format!(
111 "Label {} has {} classes, but BinaryRelevance expects binary labels",
112 label_idx,
113 label_classes.len()
114 )));
115 }
116
117 let has_positive = label_classes.contains(&1);
119 let has_negative = label_classes.contains(&0);
120
121 if !has_positive && !has_negative {
122 return Err(SklearsError::InvalidInput(format!(
123 "Label {} has no training examples",
124 label_idx
125 )));
126 }
127
128 let weights = train_binary_classifier(&X, &y_label)?;
130 binary_classifiers.insert(label_idx, weights);
131 classes_per_label.push(label_classes);
132 }
133
134 Ok(BinaryRelevance {
135 state: BinaryRelevanceTrained {
136 binary_classifiers,
137 classes_per_label,
138 n_labels,
139 n_features,
140 },
141 n_jobs: self.n_jobs,
142 })
143 }
144}
145
146fn train_binary_classifier(
148 X: &Array2<Float>,
149 y: &scirs2_core::ndarray::ArrayView1<i32>,
150) -> SklResult<(Array1<f64>, f64)> {
151 let (n_samples, n_features) = X.dim();
152
153 let mut weights = Array1::<Float>::zeros(n_features);
155 let mut bias = 0.0;
156
157 let y_mean: f64 = y.iter().map(|&label| label as f64).sum::<f64>() / n_samples as f64;
159
160 bias = if y_mean > 0.0 && y_mean < 1.0 {
162 (y_mean / (1.0 - y_mean)).ln()
163 } else if y_mean >= 1.0 {
164 2.0 } else {
166 -2.0 };
168
169 for feature_idx in 0..n_features {
171 let mut x_mean = 0.0;
172 for sample_idx in 0..n_samples {
173 x_mean += X[[sample_idx, feature_idx]];
174 }
175 x_mean /= n_samples as f64;
176
177 let mut numerator: f64 = 0.0;
179 let mut x_var: f64 = 0.0;
180 let mut y_var: f64 = 0.0;
181
182 for sample_idx in 0..n_samples {
183 let x_diff = X[[sample_idx, feature_idx]] - x_mean;
184 let y_diff = y[sample_idx] as f64 - y_mean;
185 numerator += x_diff * y_diff;
186 x_var += x_diff * x_diff;
187 y_var += y_diff * y_diff;
188 }
189
190 if x_var > 1e-10 && y_var > 1e-10 {
191 let correlation = numerator / (x_var.sqrt() * y_var.sqrt());
192 weights[feature_idx] = correlation; }
194 }
195
196 Ok((weights, bias))
197}
198
impl BinaryRelevance<BinaryRelevanceTrained> {
    /// Sorted distinct class values observed per label column during `fit`.
    pub fn classes(&self) -> &[Vec<i32>] {
        &self.state.classes_per_label
    }

    /// Number of label columns seen during `fit`.
    pub fn n_labels(&self) -> usize {
        self.state.n_labels
    }

    /// Number of feature columns seen during `fit`.
    pub fn n_features(&self) -> usize {
        self.state.n_features
    }
}
215
216impl Predict<ArrayView2<'_, Float>, Array2<i32>> for BinaryRelevance<BinaryRelevanceTrained> {
217 #[allow(non_snake_case)]
218 fn predict(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array2<i32>> {
219 let X = X.to_owned();
220 let (n_samples, n_features) = X.dim();
221
222 if n_features != self.state.n_features {
223 return Err(SklearsError::InvalidInput(
224 "Number of features doesn't match training data".to_string(),
225 ));
226 }
227
228 let mut predictions = Array2::<i32>::zeros((n_samples, self.state.n_labels));
229
230 for label_idx in 0..self.state.n_labels {
232 if let Some((weights, bias)) = self.state.binary_classifiers.get(&label_idx) {
233 for (sample_idx, sample) in X.axis_iter(Axis(0)).enumerate() {
234 let score: f64 = sample
236 .iter()
237 .zip(weights.iter())
238 .map(|(&x, &w)| x * w)
239 .sum::<f64>()
240 + bias;
241
242 let prob = 1.0 / (1.0 + (-score).exp());
244 let prediction = if prob > 0.5 { 1 } else { 0 };
245
246 predictions[[sample_idx, label_idx]] = prediction;
247 }
248 }
249 }
250
251 Ok(predictions)
252 }
253}
254
255impl BinaryRelevance<BinaryRelevanceTrained> {
257 #[allow(non_snake_case)]
259 pub fn predict_proba(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array2<f64>> {
260 let X = X.to_owned();
261 let (n_samples, n_features) = X.dim();
262
263 if n_features != self.state.n_features {
264 return Err(SklearsError::InvalidInput(
265 "Number of features doesn't match training data".to_string(),
266 ));
267 }
268
269 let mut probabilities = Array2::<Float>::zeros((n_samples, self.state.n_labels));
270
271 for label_idx in 0..self.state.n_labels {
273 if let Some((weights, bias)) = self.state.binary_classifiers.get(&label_idx) {
274 for (sample_idx, sample) in X.axis_iter(Axis(0)).enumerate() {
275 let score: f64 = sample
277 .iter()
278 .zip(weights.iter())
279 .map(|(&x, &w)| x * w)
280 .sum::<f64>()
281 + bias;
282
283 let prob = 1.0 / (1.0 + (-score).exp());
285 probabilities[[sample_idx, label_idx]] = prob;
286 }
287 }
288 }
289
290 Ok(probabilities)
291 }
292}
293
/// Fitted state of [`BinaryRelevance`].
#[derive(Debug, Clone)]
pub struct BinaryRelevanceTrained {
    /// One `(weights, bias)` linear model per label index.
    pub binary_classifiers: HashMap<usize, (Array1<f64>, f64)>,
    /// Sorted distinct label values per column, as observed during `fit`.
    pub classes_per_label: Vec<Vec<i32>>,
    /// Number of label columns.
    pub n_labels: usize,
    /// Number of feature columns.
    pub n_features: usize,
}
306
/// Label Powerset multi-label classifier.
///
/// Treats each distinct combination of labels as a single class of one
/// multi-class problem, then classifies by nearest class centroid (see
/// `fit`/`predict`). `S` encodes fitted state.
#[derive(Debug, Clone)]
pub struct LabelPowerset<S = Untrained> {
    // Estimator state (`Untrained` or `LabelPowersetTrained`).
    state: S,
}
329
330impl LabelPowerset<Untrained> {
331 pub fn new() -> Self {
333 Self { state: Untrained }
334 }
335}
336
337impl Default for LabelPowerset<Untrained> {
338 fn default() -> Self {
339 Self::new()
340 }
341}
342
impl Estimator for LabelPowerset<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    /// This estimator has no tunable configuration; always returns the unit
    /// config.
    fn config(&self) -> &Self::Config {
        &()
    }
}
352
353impl Fit<ArrayView2<'_, Float>, Array2<i32>> for LabelPowerset<Untrained> {
354 type Fitted = LabelPowerset<LabelPowersetTrained>;
355
356 #[allow(non_snake_case)]
357 fn fit(self, X: &ArrayView2<'_, Float>, y: &Array2<i32>) -> SklResult<Self::Fitted> {
358 let X = X.to_owned();
359 let (n_samples, n_features) = X.dim();
360
361 if n_samples != y.nrows() {
362 return Err(SklearsError::InvalidInput(
363 "X and y must have the same number of samples".to_string(),
364 ));
365 }
366
367 let n_labels = y.ncols();
368 if n_labels == 0 {
369 return Err(SklearsError::InvalidInput(
370 "y must have at least one label".to_string(),
371 ));
372 }
373
374 for sample_idx in 0..n_samples {
376 for label_idx in 0..n_labels {
377 let label_value = y[[sample_idx, label_idx]];
378 if label_value != 0 && label_value != 1 {
379 return Err(SklearsError::InvalidInput(format!(
380 "LabelPowerset expects binary labels, but found {} at position ({}, {})",
381 label_value, sample_idx, label_idx
382 )));
383 }
384 }
385 }
386
387 let mut class_to_combination: HashMap<usize, Vec<i32>> = HashMap::new();
389 let mut combination_to_class: HashMap<Vec<i32>, usize> = HashMap::new();
390 let mut transformed_labels = Vec::new();
391 let mut next_class_id = 0;
392
393 for sample_idx in 0..n_samples {
394 let combination: Vec<i32> = (0..n_labels)
396 .map(|label_idx| y[[sample_idx, label_idx]])
397 .collect();
398
399 let class_id = if let Some(&existing_class_id) = combination_to_class.get(&combination)
401 {
402 existing_class_id
403 } else {
404 let class_id = next_class_id;
406 combination_to_class.insert(combination.clone(), class_id);
407 class_to_combination.insert(class_id, combination);
408 next_class_id += 1;
409 class_id
410 };
411
412 transformed_labels.push(class_id);
413 }
414
415 let mut class_centroids: HashMap<usize, Array1<f64>> = HashMap::new();
418
419 for &class_id in class_to_combination.keys() {
420 let mut centroid = Array1::<Float>::zeros(n_features);
421 let mut count = 0;
422
423 for (sample_idx, &sample_class) in transformed_labels.iter().enumerate() {
425 if sample_class == class_id {
426 for feature_idx in 0..n_features {
427 centroid[feature_idx] += X[[sample_idx, feature_idx]];
428 }
429 count += 1;
430 }
431 }
432
433 if count > 0 {
434 centroid /= count as f64;
435 }
436 class_centroids.insert(class_id, centroid);
437 }
438
439 let unique_classes: Vec<usize> = class_to_combination.keys().cloned().collect();
440
441 Ok(LabelPowerset {
442 state: LabelPowersetTrained {
443 class_to_combination,
444 combination_to_class,
445 class_centroids,
446 unique_classes,
447 n_labels,
448 n_features,
449 },
450 })
451 }
452}
453
impl LabelPowerset<LabelPowersetTrained> {
    /// Mapping from class id to the label combination it represents.
    pub fn classes(&self) -> &HashMap<usize, Vec<i32>> {
        &self.state.class_to_combination
    }

    /// Number of distinct label combinations (classes) seen during `fit`.
    pub fn n_classes(&self) -> usize {
        self.state.unique_classes.len()
    }

    /// Number of label columns seen during `fit`.
    pub fn n_labels(&self) -> usize {
        self.state.n_labels
    }
}
470
471impl Predict<ArrayView2<'_, Float>, Array2<i32>> for LabelPowerset<LabelPowersetTrained> {
472 #[allow(non_snake_case)]
473 fn predict(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array2<i32>> {
474 let X = X.to_owned();
475 let (n_samples, n_features) = X.dim();
476
477 if n_features != self.state.n_features {
478 return Err(SklearsError::InvalidInput(
479 "Number of features doesn't match training data".to_string(),
480 ));
481 }
482
483 let mut predictions = Array2::<i32>::zeros((n_samples, self.state.n_labels));
484
485 for (sample_idx, sample) in X.axis_iter(Axis(0)).enumerate() {
487 let mut min_distance = f64::INFINITY;
488 let mut best_class_id = 0;
489
490 for (&class_id, centroid) in &self.state.class_centroids {
492 let mut distance = 0.0;
493 for feature_idx in 0..n_features {
494 let diff = sample[feature_idx] - centroid[feature_idx];
495 distance += diff * diff;
496 }
497 distance = distance.sqrt();
498
499 if distance < min_distance {
500 min_distance = distance;
501 best_class_id = class_id;
502 }
503 }
504
505 if let Some(label_combination) = self.state.class_to_combination.get(&best_class_id) {
507 for label_idx in 0..self.state.n_labels {
508 predictions[[sample_idx, label_idx]] = label_combination[label_idx];
509 }
510 }
511 }
512
513 Ok(predictions)
514 }
515}
516
517impl LabelPowerset<LabelPowersetTrained> {
519 #[allow(non_snake_case)]
521 pub fn decision_function(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array2<f64>> {
522 let X = X.to_owned();
523 let (n_samples, n_features) = X.dim();
524
525 if n_features != self.state.n_features {
526 return Err(SklearsError::InvalidInput(
527 "Number of features doesn't match training data".to_string(),
528 ));
529 }
530
531 let n_classes = self.state.unique_classes.len();
532 let mut scores = Array2::<Float>::zeros((n_samples, n_classes));
533
534 for (sample_idx, sample) in X.axis_iter(Axis(0)).enumerate() {
536 for (class_idx, &class_id) in self.state.unique_classes.iter().enumerate() {
537 if let Some(centroid) = self.state.class_centroids.get(&class_id) {
538 let mut distance = 0.0;
539 for feature_idx in 0..n_features {
540 let diff = sample[feature_idx] - centroid[feature_idx];
541 distance += diff * diff;
542 }
543 distance = distance.sqrt();
544
545 scores[[sample_idx, class_idx]] = -distance;
547 }
548 }
549 }
550
551 Ok(scores)
552 }
553}
554
/// Fitted state of [`LabelPowerset`].
#[derive(Debug, Clone)]
pub struct LabelPowersetTrained {
    /// Class id → label combination it represents.
    pub class_to_combination: HashMap<usize, Vec<i32>>,
    /// Label combination → class id (inverse of `class_to_combination`).
    pub combination_to_class: HashMap<Vec<i32>, usize>,
    /// Class id → mean feature vector of its training samples.
    pub class_centroids: HashMap<usize, Array1<f64>>,
    /// All class ids; fixes the column order of `decision_function`.
    pub unique_classes: Vec<usize>,
    /// Number of label columns.
    pub n_labels: usize,
    /// Number of feature columns.
    pub n_features: usize,
}
571
/// Label Powerset variant that prunes rare label combinations.
///
/// Combinations occurring fewer than `min_frequency` times are remapped to a
/// frequent combination according to `strategy` before the powerset
/// transform. `S` encodes fitted state.
#[derive(Debug, Clone)]
pub struct PrunedLabelPowerset<S = Untrained> {
    // Estimator state (`Untrained` or `PrunedLabelPowersetTrained`).
    state: S,
    // Minimum number of occurrences for a combination to be kept.
    min_frequency: usize,
    // How infrequent combinations are remapped (see `PruningStrategy`).
    strategy: PruningStrategy,
}
596
/// How [`PrunedLabelPowerset`] remaps label combinations that occur fewer
/// than `min_frequency` times in the training data.
#[derive(Debug, Clone)]
pub enum PruningStrategy {
    /// Map each infrequent combination to the frequent combination with the
    /// highest Jaccard similarity.
    SimilarityMapping,
    /// Map every infrequent combination to this fixed combination; an empty
    /// vector means "all zeros".
    DefaultMapping(Vec<i32>),
}
605
606impl PrunedLabelPowerset<Untrained> {
607 pub fn new() -> Self {
609 Self {
610 state: Untrained,
611 min_frequency: 2,
612 strategy: PruningStrategy::DefaultMapping(vec![]),
613 }
614 }
615
616 pub fn min_frequency(mut self, min_frequency: usize) -> Self {
618 self.min_frequency = min_frequency;
619 self
620 }
621
622 pub fn strategy(mut self, strategy: PruningStrategy) -> Self {
624 self.strategy = strategy;
625 self
626 }
627
628 pub fn get_min_frequency(&self) -> usize {
630 self.min_frequency
631 }
632
633 pub fn get_strategy(&self) -> &PruningStrategy {
635 &self.strategy
636 }
637}
638
639impl Default for PrunedLabelPowerset<Untrained> {
640 fn default() -> Self {
641 Self::new()
642 }
643}
644
impl Estimator for PrunedLabelPowerset<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    /// This estimator has no tunable configuration; always returns the unit
    /// config.
    fn config(&self) -> &Self::Config {
        &()
    }
}
654
impl Fit<ArrayView2<'_, Float>, Array2<i32>> for PrunedLabelPowerset<Untrained> {
    type Fitted = PrunedLabelPowerset<PrunedLabelPowersetTrained>;

    /// Fit the pruned label-powerset model.
    ///
    /// Pipeline: (1) validate inputs, (2) count label-combination
    /// frequencies, (3) keep combinations meeting `min_frequency`, (4) build
    /// a mapping from every observed combination to a kept combination per
    /// `strategy`, (5) relabel samples through that mapping into dense class
    /// ids, (6) compute one feature centroid per class.
    ///
    /// # Errors
    /// `InvalidInput` when shapes disagree, `y` has no columns or non-binary
    /// values, no combination meets the frequency threshold, or a
    /// `DefaultMapping` combination has the wrong length.
    #[allow(non_snake_case)]
    fn fit(self, X: &ArrayView2<'_, Float>, y: &Array2<i32>) -> SklResult<Self::Fitted> {
        let X = X.to_owned();
        let (n_samples, n_features) = X.dim();

        if n_samples != y.nrows() {
            return Err(SklearsError::InvalidInput(
                "X and y must have the same number of samples".to_string(),
            ));
        }

        let n_labels = y.ncols();
        if n_labels == 0 {
            return Err(SklearsError::InvalidInput(
                "y must have at least one label".to_string(),
            ));
        }

        // Every label entry must be strictly binary.
        for sample_idx in 0..n_samples {
            for label_idx in 0..n_labels {
                let label_value = y[[sample_idx, label_idx]];
                if label_value != 0 && label_value != 1 {
                    return Err(SklearsError::InvalidInput(format!(
                        "PrunedLabelPowerset expects binary labels, but found {} at position ({}, {})",
                        label_value, sample_idx, label_idx
                    )));
                }
            }
        }

        // Count how often each label combination occurs.
        let mut combination_counts: HashMap<Vec<i32>, usize> = HashMap::new();
        for sample_idx in 0..n_samples {
            let combination: Vec<i32> = (0..n_labels)
                .map(|label_idx| y[[sample_idx, label_idx]])
                .collect();
            *combination_counts.entry(combination).or_insert(0) += 1;
        }

        // Combinations that survive pruning.
        let frequent_combinations: Vec<Vec<i32>> = combination_counts
            .iter()
            .filter(|(_, &count)| count >= self.min_frequency)
            .map(|(combination, _)| combination.clone())
            .collect();

        if frequent_combinations.is_empty() {
            return Err(SklearsError::InvalidInput(
                "No label combinations meet the minimum frequency threshold".to_string(),
            ));
        }

        // Resolve the default target combination. Only meaningful for
        // DefaultMapping; SimilarityMapping never consults it, so an empty
        // placeholder is used there.
        let default_combination = match &self.strategy {
            PruningStrategy::DefaultMapping(ref default) => {
                if default.is_empty() {
                    // Empty vector means "all labels off".
                    vec![0; n_labels]
                } else if default.len() != n_labels {
                    return Err(SklearsError::InvalidInput(
                        "Default combination length must match number of labels".to_string(),
                    ));
                } else {
                    default.clone()
                }
            }
            PruningStrategy::SimilarityMapping => vec![],
        };

        // Under DefaultMapping the default target must itself be a class, so
        // append it if it was not already frequent.
        let mut final_frequent_combinations = frequent_combinations.clone();
        if let PruningStrategy::DefaultMapping(_) = &self.strategy {
            if !final_frequent_combinations.contains(&default_combination) {
                final_frequent_combinations.push(default_combination.clone());
            }
        }

        // Map every observed combination to the (frequent) combination it is
        // trained/predicted as.
        let mut combination_mapping: HashMap<Vec<i32>, Vec<i32>> = HashMap::new();

        for (combination, &count) in &combination_counts {
            if count >= self.min_frequency {
                // Frequent combinations map to themselves.
                combination_mapping.insert(combination.clone(), combination.clone());
            } else {
                let mapped_combination = match &self.strategy {
                    PruningStrategy::SimilarityMapping => {
                        // Pick the frequent combination with the highest
                        // Jaccard similarity (intersection / union of the
                        // active labels).
                        let mut best_similarity = -1.0;
                        let mut best_combination = &final_frequent_combinations[0];

                        for freq_combo in &final_frequent_combinations {
                            let intersection: i32 = combination
                                .iter()
                                .zip(freq_combo.iter())
                                .map(|(&a, &b)| if a == 1 && b == 1 { 1 } else { 0 })
                                .sum();
                            let union: i32 = combination
                                .iter()
                                .zip(freq_combo.iter())
                                .map(|(&a, &b)| if a == 1 || b == 1 { 1 } else { 0 })
                                .sum();

                            // Two all-zero combinations count as identical.
                            let similarity = if union > 0 {
                                intersection as f64 / union as f64
                            } else {
                                1.0
                            };

                            if similarity > best_similarity {
                                best_similarity = similarity;
                                best_combination = freq_combo;
                            }
                        }
                        best_combination.clone()
                    }
                    PruningStrategy::DefaultMapping(_) => default_combination.clone(),
                };
                combination_mapping.insert(combination.clone(), mapped_combination);
            }
        }

        // Assign dense class ids to the kept combinations.
        let mut class_to_combination: HashMap<usize, Vec<i32>> = HashMap::new();
        let mut combination_to_class: HashMap<Vec<i32>, usize> = HashMap::new();

        for (next_class_id, combo) in final_frequent_combinations.iter().enumerate() {
            class_to_combination.insert(next_class_id, combo.clone());
            combination_to_class.insert(combo.clone(), next_class_id);
        }

        // Relabel each sample with the class id of its (mapped) combination.
        // Both unwraps are safe: every observed combination has an entry in
        // combination_mapping, and every mapped target is a kept combination.
        let mut transformed_labels = Vec::new();
        for sample_idx in 0..n_samples {
            let original_combination: Vec<i32> = (0..n_labels)
                .map(|label_idx| y[[sample_idx, label_idx]])
                .collect();

            let mapped_combination = combination_mapping
                .get(&original_combination)
                .unwrap()
                .clone();

            let class_id = *combination_to_class.get(&mapped_combination).unwrap();
            transformed_labels.push(class_id);
        }

        // One mean feature vector per class. A class with no samples (e.g. a
        // DefaultMapping target never hit) keeps a zero centroid.
        let mut class_centroids: HashMap<usize, Array1<f64>> = HashMap::new();

        for &class_id in class_to_combination.keys() {
            let mut centroid = Array1::<Float>::zeros(n_features);
            let mut count = 0;

            for (sample_idx, &sample_class) in transformed_labels.iter().enumerate() {
                if sample_class == class_id {
                    for feature_idx in 0..n_features {
                        centroid[feature_idx] += X[[sample_idx, feature_idx]];
                    }
                    count += 1;
                }
            }

            if count > 0 {
                centroid /= count as f64;
            }
            class_centroids.insert(class_id, centroid);
        }

        // NOTE(review): collected from HashMap keys, so the order of
        // unique_classes is unspecified across runs.
        let unique_classes: Vec<usize> = class_to_combination.keys().cloned().collect();

        Ok(PrunedLabelPowerset {
            state: PrunedLabelPowersetTrained {
                class_to_combination,
                combination_to_class,
                combination_mapping,
                class_centroids,
                unique_classes,
                frequent_combinations: final_frequent_combinations,
                n_labels,
                n_features,
                min_frequency: self.min_frequency,
                strategy: self.strategy.clone(),
            },
            min_frequency: self.min_frequency,
            strategy: self.strategy.clone(),
        })
    }
}
850
impl PrunedLabelPowerset<PrunedLabelPowersetTrained> {
    /// Label combinations kept after pruning (including a `DefaultMapping`
    /// target appended during `fit`, if any).
    pub fn frequent_combinations(&self) -> &[Vec<i32>] {
        &self.state.frequent_combinations
    }

    /// Number of classes after pruning.
    pub fn n_frequent_classes(&self) -> usize {
        self.state.unique_classes.len()
    }

    /// Mapping from every observed combination to the kept combination it
    /// was trained as.
    pub fn combination_mapping(&self) -> &HashMap<Vec<i32>, Vec<i32>> {
        &self.state.combination_mapping
    }

    /// Minimum frequency threshold used during `fit`.
    pub fn min_frequency(&self) -> usize {
        self.state.min_frequency
    }

    /// Number of feature columns seen during `fit`.
    pub fn n_features(&self) -> usize {
        self.state.n_features
    }

    /// Number of label columns seen during `fit`.
    pub fn n_labels(&self) -> usize {
        self.state.n_labels
    }

    /// Class id → mean feature vector of its training samples.
    pub fn class_centroids(&self) -> &HashMap<usize, Array1<f64>> {
        &self.state.class_centroids
    }

    /// Class id → label combination it represents.
    pub fn class_to_combination(&self) -> &HashMap<usize, Vec<i32>> {
        &self.state.class_to_combination
    }
}
892
893impl Predict<ArrayView2<'_, Float>, Array2<i32>>
894 for PrunedLabelPowerset<PrunedLabelPowersetTrained>
895{
896 #[allow(non_snake_case)]
897 fn predict(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array2<i32>> {
898 let X = X.to_owned();
899 let (n_samples, n_features) = X.dim();
900
901 if n_features != self.state.n_features {
902 return Err(SklearsError::InvalidInput(
903 "Number of features doesn't match training data".to_string(),
904 ));
905 }
906
907 let mut predictions = Array2::<i32>::zeros((n_samples, self.state.n_labels));
908
909 for (sample_idx, sample) in X.axis_iter(Axis(0)).enumerate() {
911 let mut min_distance = f64::INFINITY;
912 let mut best_class_id = 0;
913
914 for (&class_id, centroid) in &self.state.class_centroids {
916 let mut distance = 0.0;
917 for feature_idx in 0..n_features {
918 let diff = sample[feature_idx] - centroid[feature_idx];
919 distance += diff * diff;
920 }
921 distance = distance.sqrt();
922
923 if distance < min_distance {
924 min_distance = distance;
925 best_class_id = class_id;
926 }
927 }
928
929 if let Some(label_combination) = self.state.class_to_combination.get(&best_class_id) {
931 for label_idx in 0..self.state.n_labels {
932 predictions[[sample_idx, label_idx]] = label_combination[label_idx];
933 }
934 }
935 }
936
937 Ok(predictions)
938 }
939}
940
/// Fitted state of [`PrunedLabelPowerset`].
#[derive(Debug, Clone)]
pub struct PrunedLabelPowersetTrained {
    /// Class id → label combination it represents.
    pub class_to_combination: HashMap<usize, Vec<i32>>,
    /// Label combination → class id (inverse of `class_to_combination`).
    pub combination_to_class: HashMap<Vec<i32>, usize>,
    /// Every observed combination → the kept combination it was mapped to.
    pub combination_mapping: HashMap<Vec<i32>, Vec<i32>>,
    /// Class id → mean feature vector of its training samples.
    pub class_centroids: HashMap<usize, Array1<f64>>,
    /// All class ids.
    pub unique_classes: Vec<usize>,
    /// Combinations kept after pruning (plus any appended default target).
    pub frequent_combinations: Vec<Vec<i32>>,
    /// Number of label columns.
    pub n_labels: usize,
    /// Number of feature columns.
    pub n_features: usize,
    /// Frequency threshold used during `fit`.
    pub min_frequency: usize,
    /// Pruning strategy used during `fit`.
    pub strategy: PruningStrategy,
}
965
/// One-vs-rest multi-label classifier.
///
/// Thin wrapper that delegates all training and prediction to
/// [`BinaryRelevance`] (see the `Fit` impl below). `S` encodes fitted state.
#[derive(Debug, Clone)]
pub struct OneVsRestClassifier<S = Untrained> {
    // Estimator state (`Untrained` or `OneVsRestClassifierTrained`).
    state: S,
    // Parallelism hint, forwarded to the inner BinaryRelevance.
    n_jobs: Option<i32>,
}
988
989impl OneVsRestClassifier<Untrained> {
990 pub fn new() -> Self {
992 Self {
993 state: Untrained,
994 n_jobs: None,
995 }
996 }
997
998 pub fn n_jobs(mut self, n_jobs: Option<i32>) -> Self {
1000 self.n_jobs = n_jobs;
1001 self
1002 }
1003}
1004
1005impl Default for OneVsRestClassifier<Untrained> {
1006 fn default() -> Self {
1007 Self::new()
1008 }
1009}
1010
impl Estimator for OneVsRestClassifier<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    /// This estimator has no tunable configuration; always returns the unit
    /// config.
    fn config(&self) -> &Self::Config {
        &()
    }
}
1020
1021impl Fit<ArrayView2<'_, Float>, Array2<i32>> for OneVsRestClassifier<Untrained> {
1022 type Fitted = OneVsRestClassifier<OneVsRestClassifierTrained>;
1023
1024 fn fit(self, X: &ArrayView2<'_, Float>, y: &Array2<i32>) -> SklResult<Self::Fitted> {
1025 let br = BinaryRelevance::new().n_jobs(self.n_jobs);
1027 let fitted_br = br.fit(X, y)?;
1028
1029 Ok(OneVsRestClassifier {
1030 state: OneVsRestClassifierTrained {
1031 binary_relevance: fitted_br,
1032 },
1033 n_jobs: self.n_jobs,
1034 })
1035 }
1036}
1037
impl OneVsRestClassifier<OneVsRestClassifierTrained> {
    /// Sorted distinct class values per label column, from the inner model.
    pub fn classes(&self) -> &[Vec<i32>] {
        self.state.binary_relevance.classes()
    }

    /// Number of label columns seen during `fit`, from the inner model.
    pub fn n_labels(&self) -> usize {
        self.state.binary_relevance.n_labels()
    }
}
1049
impl Predict<ArrayView2<'_, Float>, Array2<i32>>
    for OneVsRestClassifier<OneVsRestClassifierTrained>
{
    /// Delegates prediction to the wrapped `BinaryRelevance` model.
    fn predict(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array2<i32>> {
        self.state.binary_relevance.predict(X)
    }
}
1057
1058impl OneVsRestClassifier<OneVsRestClassifierTrained> {
1059 pub fn predict_proba(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array2<f64>> {
1061 self.state.binary_relevance.predict_proba(X)
1062 }
1063
1064 #[allow(non_snake_case)]
1066 pub fn decision_function(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array2<f64>> {
1067 let X = X.to_owned();
1068 let (n_samples, n_features) = X.dim();
1069
1070 if n_features != self.state.binary_relevance.n_features() {
1071 return Err(SklearsError::InvalidInput(
1072 "Number of features doesn't match training data".to_string(),
1073 ));
1074 }
1075
1076 let mut scores =
1077 Array2::<Float>::zeros((n_samples, self.state.binary_relevance.n_labels()));
1078
1079 for label_idx in 0..self.state.binary_relevance.n_labels() {
1081 if let Some((weights, bias)) = self
1082 .state
1083 .binary_relevance
1084 .state
1085 .binary_classifiers
1086 .get(&label_idx)
1087 {
1088 for (sample_idx, sample) in X.axis_iter(Axis(0)).enumerate() {
1089 let score: f64 = sample
1090 .iter()
1091 .zip(weights.iter())
1092 .map(|(&x, &w)| x * w)
1093 .sum::<f64>()
1094 + bias;
1095
1096 scores[[sample_idx, label_idx]] = score;
1097 }
1098 }
1099 }
1100
1101 Ok(scores)
1102 }
1103}
1104
/// Fitted state of [`OneVsRestClassifier`].
#[derive(Debug, Clone)]
pub struct OneVsRestClassifierTrained {
    /// The fitted `BinaryRelevance` model that all methods delegate to.
    pub binary_relevance: BinaryRelevance<BinaryRelevanceTrained>,
}