1use crate::bagging::BaggingClassifier;
8use scirs2_core::ndarray::{Array1, Array2};
11use sklears_core::{
12 error::Result as SklResult,
13 prelude::{Predict, SklearsError},
14 traits::{Estimator, Fit, Trained, Untrained},
15};
16use std::collections::HashMap;
17
18fn gen_f64(rng: &mut impl scirs2_core::random::RngCore) -> f64 {
20 let mut bytes = [0u8; 8];
21 rng.fill_bytes(&mut bytes);
22 f64::from_le_bytes(bytes) / f64::from_le_bytes([255u8; 8])
23}
24
25fn gen_range_usize(
27 rng: &mut impl scirs2_core::random::RngCore,
28 range: std::ops::Range<usize>,
29) -> usize {
30 let mut bytes = [0u8; 8];
31 rng.fill_bytes(&mut bytes);
32 let val = u64::from_le_bytes(bytes);
33 range.start + (val as usize % (range.end - range.start))
34}
35
36#[derive(Debug, Clone)]
38pub struct MultiLabelEnsembleConfig {
39 pub n_estimators: usize,
41 pub transformation_strategy: LabelTransformationStrategy,
43 pub aggregation_method: MultiLabelAggregationMethod,
45 pub correlation_method: LabelCorrelationMethod,
47 pub threshold: f64,
49 pub random_state: Option<u64>,
51 pub prune_labelsets: bool,
53 pub max_labelsets: Option<usize>,
55 pub chain_order: Option<Vec<usize>>,
57 pub ensemble_chains: bool,
59 pub n_chains: usize,
61}
62
63impl Default for MultiLabelEnsembleConfig {
64 fn default() -> Self {
65 Self {
66 n_estimators: 10,
67 transformation_strategy: LabelTransformationStrategy::BinaryRelevance,
68 aggregation_method: MultiLabelAggregationMethod::Voting,
69 correlation_method: LabelCorrelationMethod::Independent,
70 threshold: 0.5,
71 random_state: None,
72 prune_labelsets: true,
73 max_labelsets: Some(100),
74 chain_order: None,
75 ensemble_chains: false,
76 n_chains: 3,
77 }
78 }
79}
80
/// Problem-transformation strategy used to turn multi-label targets into
/// single-output classification problems.
#[derive(Clone, Debug, PartialEq)]
pub enum LabelTransformationStrategy {
    /// One independent classifier per label column.
    BinaryRelevance,
    /// Each distinct set of active labels becomes one class of a single
    /// multi-class classifier.
    LabelPowerset,
    /// Classifier chains.
    /// NOTE(review): `fit` has no dedicated branch for this variant and trains
    /// it like `BinaryRelevance`.
    ClassifierChains,
    /// Several per-label model sets, one per randomly shuffled label order.
    EnsembleOfClassifierChains,
    /// NOTE(review): trained like `BinaryRelevance` by `fit` in this file.
    AdaptedAlgorithm,
    /// NOTE(review): trained like `BinaryRelevance` by `fit` in this file.
    RandomKLabelsets,
}
97
/// How per-model outputs would be combined into per-label scores.
///
/// NOTE(review): stored in the config but not read by `fit`/`predict` in this
/// file.
#[derive(Clone, Debug, PartialEq)]
pub enum MultiLabelAggregationMethod {
    /// Majority voting.
    Voting,
    /// Voting with per-model weights.
    WeightedVoting,
    /// Take the maximum probability.
    MaxProbability,
    /// Take the mean probability.
    MeanProbability,
    /// Take the median probability.
    MedianProbability,
    /// Aggregate after thresholding.
    ThresholdAggregation,
    /// Aggregate label rankings.
    RankAggregation,
}
116
/// How dependencies between labels would be modelled.
///
/// NOTE(review): stored in the config but not read in this file. `fit` always
/// computes a pairwise Jaccard co-occurrence matrix regardless of this value.
#[derive(Clone, Debug, PartialEq)]
pub enum LabelCorrelationMethod {
    /// Treat labels as independent.
    Independent,
    /// Pairwise label dependencies.
    Pairwise,
    /// Dependencies among more than two labels.
    HigherOrder,
    /// Conditional-independence assumptions.
    ConditionalIndependence,
    /// Dependencies learned from data.
    LearnedCorrelation,
}
131
132pub struct MultiLabelEnsembleClassifier<State = Untrained> {
134 config: MultiLabelEnsembleConfig,
135 state: std::marker::PhantomData<State>,
136 base_classifiers: Option<Vec<BaggingClassifier<Trained>>>,
138 label_indices: Option<Vec<usize>>,
139 labelset_mapping: Option<HashMap<Vec<usize>, usize>>,
140 inverse_labelset_mapping: Option<HashMap<usize, Vec<usize>>>,
141 label_correlations: Option<Array2<f64>>,
142 chain_orders: Option<Vec<Vec<usize>>>,
143 threshold_per_label: Option<Vec<f64>>,
144 n_labels: Option<usize>,
145}
146
147#[derive(Debug, Clone)]
149pub struct MultiLabelTrainingResults {
150 pub n_labelsets: usize,
152 pub label_frequencies: HashMap<usize, usize>,
154 pub labelset_frequencies: HashMap<Vec<usize>, usize>,
156 pub label_correlations: Array2<f64>,
158 pub training_time_ms: u64,
160}
161
162#[derive(Debug, Clone)]
164pub struct MultiLabelPredictionResults {
165 pub predictions: Array2<usize>,
167 pub probabilities: Array2<f64>,
169 pub confidence_scores: Vec<f64>,
171 pub ranking_scores: Array2<f64>,
173}
174
175impl MultiLabelEnsembleClassifier<Untrained> {
176 pub fn new(config: MultiLabelEnsembleConfig) -> Self {
178 Self {
179 config,
180 state: std::marker::PhantomData,
181 base_classifiers: None,
182 label_indices: None,
183 labelset_mapping: None,
184 inverse_labelset_mapping: None,
185 label_correlations: None,
186 chain_orders: None,
187 threshold_per_label: None,
188 n_labels: None,
189 }
190 }
191
192 pub fn binary_relevance() -> Self {
194 let config = MultiLabelEnsembleConfig {
195 transformation_strategy: LabelTransformationStrategy::BinaryRelevance,
196 ..Default::default()
197 };
198 Self::new(config)
199 }
200
201 pub fn label_powerset() -> Self {
203 let config = MultiLabelEnsembleConfig {
204 transformation_strategy: LabelTransformationStrategy::LabelPowerset,
205 ..Default::default()
206 };
207 Self::new(config)
208 }
209
210 pub fn classifier_chains() -> Self {
212 let config = MultiLabelEnsembleConfig {
213 transformation_strategy: LabelTransformationStrategy::ClassifierChains,
214 ..Default::default()
215 };
216 Self::new(config)
217 }
218
219 pub fn ensemble_classifier_chains() -> Self {
221 let config = MultiLabelEnsembleConfig {
222 transformation_strategy: LabelTransformationStrategy::EnsembleOfClassifierChains,
223 ensemble_chains: true,
224 n_chains: 5,
225 ..Default::default()
226 };
227 Self::new(config)
228 }
229
230 pub fn n_estimators(mut self, n_estimators: usize) -> Self {
232 self.config.n_estimators = n_estimators;
233 self
234 }
235
236 pub fn aggregation_method(mut self, method: MultiLabelAggregationMethod) -> Self {
238 self.config.aggregation_method = method;
239 self
240 }
241
242 pub fn correlation_method(mut self, method: LabelCorrelationMethod) -> Self {
244 self.config.correlation_method = method;
245 self
246 }
247
248 pub fn threshold(mut self, threshold: f64) -> Self {
250 self.config.threshold = threshold;
251 self
252 }
253
254 pub fn random_state(mut self, seed: u64) -> Self {
256 self.config.random_state = Some(seed);
257 self
258 }
259
260 pub fn prune_labelsets(mut self, prune: bool) -> Self {
262 self.config.prune_labelsets = prune;
263 self
264 }
265
266 fn extract_labelsets(
268 &self,
269 y: &Array2<usize>,
270 ) -> SklResult<(HashMap<Vec<usize>, usize>, HashMap<usize, Vec<usize>>)> {
271 let mut labelset_mapping = HashMap::new();
272 let mut inverse_mapping = HashMap::new();
273 let mut labelset_id = 0;
274
275 for row in y.outer_iter() {
276 let labelset: Vec<usize> = row
277 .iter()
278 .enumerate()
279 .filter(|(_, &label)| label == 1)
280 .map(|(idx, _)| idx)
281 .collect();
282
283 if !labelset_mapping.contains_key(&labelset) {
284 labelset_mapping.insert(labelset.clone(), labelset_id);
285 inverse_mapping.insert(labelset_id, labelset);
286 labelset_id += 1;
287 }
288 }
289
290 if self.config.prune_labelsets {
292 if let Some(max_labelsets) = self.config.max_labelsets {
293 if labelset_mapping.len() > max_labelsets {
294 let mut labelset_counts: Vec<_> = labelset_mapping.iter().collect();
296 labelset_counts.sort_by_key(|(labelset, _)| labelset.len());
297 labelset_counts.truncate(max_labelsets);
298
299 labelset_mapping = labelset_counts
300 .into_iter()
301 .enumerate()
302 .map(|(new_id, (labelset, _))| (labelset.clone(), new_id))
303 .collect();
304
305 inverse_mapping = labelset_mapping
306 .iter()
307 .map(|(labelset, &id)| (id, labelset.clone()))
308 .collect();
309 }
310 }
311 }
312
313 Ok((labelset_mapping, inverse_mapping))
314 }
315
316 fn compute_label_correlations(&self, y: &Array2<usize>) -> SklResult<Array2<f64>> {
318 let n_labels = y.ncols();
319 let mut correlations = Array2::zeros((n_labels, n_labels));
320
321 for i in 0..n_labels {
322 for j in i..n_labels {
323 if i == j {
324 correlations[[i, j]] = 1.0;
325 } else {
326 let mut intersection = 0;
328 let mut union = 0;
329
330 for k in 0..y.nrows() {
331 let label_i = y[[k, i]];
332 let label_j = y[[k, j]];
333
334 if label_i == 1 && label_j == 1 {
335 intersection += 1;
336 }
337 if label_i == 1 || label_j == 1 {
338 union += 1;
339 }
340 }
341
342 let correlation = if union > 0 {
343 intersection as f64 / union as f64
344 } else {
345 0.0
346 };
347
348 correlations[[i, j]] = correlation;
349 correlations[[j, i]] = correlation;
350 }
351 }
352 }
353
354 Ok(correlations)
355 }
356
357 fn generate_chain_orders(&self, n_labels: usize, n_chains: usize) -> Vec<Vec<usize>> {
359 let mut chains = Vec::new();
360 let mut rng = if let Some(seed) = self.config.random_state {
361 scirs2_core::random::seeded_rng(seed)
362 } else {
363 scirs2_core::random::seeded_rng(42)
364 };
365
366 for _ in 0..n_chains {
367 let mut order: Vec<usize> = (0..n_labels).collect();
368
369 for i in (1..order.len()).rev() {
371 let j = gen_range_usize(&mut rng, 0..(i + 1));
372 order.swap(i, j);
373 }
374
375 chains.push(order);
376 }
377
378 chains
379 }
380}
381
382impl Estimator for MultiLabelEnsembleClassifier<Untrained> {
383 type Config = MultiLabelEnsembleConfig;
384 type Error = SklearsError;
385 type Float = f64;
386
387 fn config(&self) -> &Self::Config {
388 &self.config
389 }
390}
391
392impl Fit<Array2<f64>, Array2<usize>> for MultiLabelEnsembleClassifier<Untrained> {
393 type Fitted = MultiLabelEnsembleClassifier<Trained>;
394
395 fn fit(self, X: &Array2<f64>, y: &Array2<usize>) -> SklResult<Self::Fitted> {
396 if X.nrows() != y.nrows() {
397 return Err(SklearsError::ShapeMismatch {
398 expected: format!("{} samples", X.nrows()),
399 actual: format!("{} samples", y.nrows()),
400 });
401 }
402
403 let n_labels = y.ncols();
404 let mut base_classifiers = Vec::new();
405 let mut chain_orders = Vec::new();
406
407 let label_correlations = self.compute_label_correlations(y)?;
409
410 match self.config.transformation_strategy {
411 LabelTransformationStrategy::BinaryRelevance => {
412 for label_idx in 0..n_labels {
414 let y_binary: Vec<usize> = y.column(label_idx).to_vec();
415
416 let y_binary_array =
417 Array1::from_vec(y_binary.iter().map(|&x| x as i32).collect());
418 let classifier = BaggingClassifier::new()
419 .n_estimators(self.config.n_estimators)
420 .fit(X, &y_binary_array)?;
421
422 base_classifiers.push(classifier);
423 }
424 }
425
426 LabelTransformationStrategy::LabelPowerset => {
427 let (labelset_mapping, _) = self.extract_labelsets(y)?;
429
430 let mut y_labelsets = Vec::new();
432 for row in y.outer_iter() {
433 let labelset: Vec<usize> = row
434 .iter()
435 .enumerate()
436 .filter(|(_, &label)| label == 1)
437 .map(|(idx, _)| idx)
438 .collect();
439
440 if let Some(&labelset_id) = labelset_mapping.get(&labelset) {
441 y_labelsets.push(labelset_id);
442 } else {
443 y_labelsets.push(0);
445 }
446 }
447
448 let y_labelsets_array =
449 Array1::from_vec(y_labelsets.iter().map(|&x| x as i32).collect());
450 let classifier = BaggingClassifier::new()
451 .n_estimators(self.config.n_estimators)
452 .fit(X, &y_labelsets_array)?;
453
454 base_classifiers.push(classifier);
455 }
456
457 LabelTransformationStrategy::EnsembleOfClassifierChains => {
458 chain_orders = self.generate_chain_orders(n_labels, self.config.n_chains);
460
461 for chain_order in &chain_orders {
462 for &label_idx in chain_order {
464 let y_binary: Vec<usize> = y.column(label_idx).to_vec();
465
466 let y_binary_array =
467 Array1::from_vec(y_binary.iter().map(|&x| x as i32).collect());
468 let classifier = BaggingClassifier::new()
469 .n_estimators(5) .fit(X, &y_binary_array)?;
471
472 base_classifiers.push(classifier);
473 }
474 }
475 }
476
477 _ => {
478 for label_idx in 0..n_labels {
480 let y_binary: Vec<usize> = y.column(label_idx).to_vec();
481
482 let y_binary_array =
483 Array1::from_vec(y_binary.iter().map(|&x| x as i32).collect());
484 let classifier = BaggingClassifier::new()
485 .n_estimators(self.config.n_estimators)
486 .fit(X, &y_binary_array)?;
487
488 base_classifiers.push(classifier);
489 }
490 }
491 }
492
493 let label_indices: Vec<usize> = (0..n_labels).collect();
495
496 let threshold_per_label = vec![self.config.threshold; n_labels];
498
499 let (labelset_mapping, inverse_labelset_mapping) = if matches!(
501 self.config.transformation_strategy,
502 LabelTransformationStrategy::LabelPowerset
503 ) {
504 let (forward, inverse) = self.extract_labelsets(y)?;
505 (Some(forward), Some(inverse))
506 } else {
507 (None, None)
508 };
509
510 Ok(MultiLabelEnsembleClassifier {
511 config: self.config,
512 state: std::marker::PhantomData,
513 base_classifiers: Some(base_classifiers),
514 label_indices: Some(label_indices),
515 labelset_mapping,
516 inverse_labelset_mapping,
517 label_correlations: Some(label_correlations),
518 chain_orders: if chain_orders.is_empty() {
519 None
520 } else {
521 Some(chain_orders)
522 },
523 threshold_per_label: Some(threshold_per_label),
524 n_labels: Some(n_labels),
525 })
526 }
527}
528
529impl Predict<Array2<f64>, MultiLabelPredictionResults> for MultiLabelEnsembleClassifier<Trained> {
530 fn predict(&self, X: &Array2<f64>) -> SklResult<MultiLabelPredictionResults> {
531 let base_classifiers = self.base_classifiers.as_ref().expect("Model is trained");
532 let n_labels = self.n_labels.expect("Model is trained");
533 let threshold_per_label = self.threshold_per_label.as_ref().expect("Model is trained");
534
535 let n_samples = X.nrows();
536 let mut predictions = Array2::zeros((n_samples, n_labels));
537 let mut probabilities = Array2::zeros((n_samples, n_labels));
538 let mut ranking_scores = Array2::zeros((n_samples, n_labels));
539
540 match self.config.transformation_strategy {
541 LabelTransformationStrategy::BinaryRelevance => {
542 for (label_idx, classifier) in base_classifiers.iter().enumerate().take(n_labels) {
544 let label_predictions = classifier.predict(X)?;
545
546 for (sample_idx, &pred) in label_predictions.iter().enumerate() {
548 predictions[[sample_idx, label_idx]] = pred as usize;
549 probabilities[[sample_idx, label_idx]] = pred as f64;
550 ranking_scores[[sample_idx, label_idx]] = pred as f64;
551 }
552 }
553 }
554
555 LabelTransformationStrategy::LabelPowerset => {
556 if let (Some(labelset_mapping), Some(inverse_labelset_mapping)) =
557 (&self.labelset_mapping, &self.inverse_labelset_mapping)
558 {
559 let labelset_predictions = base_classifiers[0].predict(X)?;
560
561 for (sample_idx, &labelset_id) in labelset_predictions.iter().enumerate() {
562 if let Some(labelset) =
563 inverse_labelset_mapping.get(&(labelset_id as usize))
564 {
565 for &label_idx in labelset {
566 if label_idx < n_labels {
567 predictions[[sample_idx, label_idx]] = 1;
568 probabilities[[sample_idx, label_idx]] = 1.0;
569 ranking_scores[[sample_idx, label_idx]] = 1.0;
570 }
571 }
572 }
573 }
574 }
575 }
576
577 _ => {
578 for (label_idx, classifier) in base_classifiers.iter().enumerate().take(n_labels) {
580 let label_predictions = classifier.predict(X)?;
581
582 for (sample_idx, &pred) in label_predictions.iter().enumerate() {
583 predictions[[sample_idx, label_idx]] = pred as usize;
584 probabilities[[sample_idx, label_idx]] = pred as f64;
585 ranking_scores[[sample_idx, label_idx]] = pred as f64;
586 }
587 }
588 }
589 }
590
591 for i in 0..n_samples {
593 for j in 0..n_labels {
594 if probabilities[[i, j]] >= threshold_per_label[j] {
595 predictions[[i, j]] = 1;
596 } else {
597 predictions[[i, j]] = 0;
598 }
599 }
600 }
601
602 let confidence_scores: Vec<f64> = (0..n_samples)
604 .map(|i| {
605 let row_probs: Vec<f64> = (0..n_labels).map(|j| probabilities[[i, j]]).collect();
606 row_probs.iter().sum::<f64>() / n_labels as f64
607 })
608 .collect();
609
610 Ok(MultiLabelPredictionResults {
611 predictions,
612 probabilities,
613 confidence_scores,
614 ranking_scores,
615 })
616 }
617}
618
619impl MultiLabelEnsembleClassifier<Trained> {
620 pub fn n_labels(&self) -> usize {
622 self.n_labels.expect("Model is trained")
623 }
624
625 pub fn label_correlations(&self) -> &Array2<f64> {
627 self.label_correlations.as_ref().expect("Model is trained")
628 }
629
630 pub fn transformation_strategy(&self) -> &LabelTransformationStrategy {
632 &self.config.transformation_strategy
633 }
634
635 pub fn predict_binary(&self, X: &Array2<f64>) -> SklResult<Array2<usize>> {
637 let results = self.predict(X)?;
638 Ok(results.predictions)
639 }
640
641 pub fn predict_proba(&self, X: &Array2<f64>) -> SklResult<Array2<f64>> {
643 let results = self.predict(X)?;
644 Ok(results.probabilities)
645 }
646
647 pub fn predict_rankings(&self, X: &Array2<f64>) -> SklResult<Array2<f64>> {
649 let results = self.predict(X)?;
650 Ok(results.ranking_scores)
651 }
652}
653
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    // Shared toy problem: 4 samples x 2 features, 3 labels.
    fn toy_problem() -> (Array2<f64>, Array2<usize>) {
        let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0]];
        let y = array![[1, 0, 1], [0, 1, 1], [1, 1, 0], [0, 0, 1]];
        (X, y)
    }

    #[test]
    fn test_multi_label_binary_relevance() {
        let (X, y) = toy_problem();
        let model = MultiLabelEnsembleClassifier::binary_relevance()
            .n_estimators(3)
            .random_state(42)
            .fit(&X, &y)
            .expect("Training should succeed");
        let outcome = model.predict(&X).expect("Prediction should succeed");

        assert_eq!(outcome.predictions.dim(), (4, 3));
        assert_eq!(outcome.probabilities.dim(), (4, 3));
    }

    #[test]
    fn test_multi_label_label_powerset() {
        let (X, y) = toy_problem();
        let model = MultiLabelEnsembleClassifier::label_powerset()
            .n_estimators(5)
            .random_state(42)
            .fit(&X, &y)
            .expect("Training should succeed");
        let outcome = model.predict(&X).expect("Prediction should succeed");

        assert_eq!(outcome.predictions.dim(), (4, 3));
        assert_eq!(model.n_labels(), 3);
    }

    #[test]
    fn test_label_correlation_computation() {
        let y = array![[1, 1, 0], [1, 0, 1], [0, 1, 1], [1, 1, 1]];
        let correlations = MultiLabelEnsembleClassifier::binary_relevance()
            .compute_label_correlations(&y)
            .expect("Should compute correlations");

        assert_eq!(correlations.dim(), (3, 3));
        for i in 0..3 {
            assert_eq!(correlations[[i, i]], 1.0, "diagonal entry {}", i);
        }
    }

    #[test]
    fn test_labelset_extraction() {
        // The last row repeats the first, so only three distinct labelsets exist.
        let y = array![[1, 0, 1], [0, 1, 1], [1, 1, 0], [1, 0, 1]];
        let (forward, inverse) = MultiLabelEnsembleClassifier::label_powerset()
            .extract_labelsets(&y)
            .expect("Should extract labelsets");

        assert_eq!(forward.len(), 3);
        assert_eq!(inverse.len(), 3);
        for expected in [vec![0, 2], vec![1, 2], vec![0, 1]] {
            assert!(forward.contains_key(&expected));
        }
    }

    #[test]
    fn test_ensemble_classifier_chains() {
        let (X, y) = toy_problem();
        let model = MultiLabelEnsembleClassifier::ensemble_classifier_chains()
            .random_state(42)
            .fit(&X, &y)
            .expect("Training should succeed");
        let outcome = model.predict(&X).expect("Prediction should succeed");

        assert_eq!(outcome.predictions.dim(), (4, 3));
        assert!(model.chain_orders.is_some());
    }
}