sklears_ensemble/
multi_label.rs

1//! Multi-Label Ensemble Methods
2//!
3//! This module provides ensemble methods specifically designed for multi-label classification
4//! where each instance can belong to multiple classes simultaneously. It includes various
5//! label transformation strategies, ensemble voting mechanisms, and performance optimizations.
6
7use crate::bagging::BaggingClassifier;
8// ❌ REMOVED: rand_chacha::rand_core - use scirs2_core::random instead
9// ❌ REMOVED: rand_chacha::scirs2_core::random::rngs::StdRng - use scirs2_core::random instead
10use scirs2_core::ndarray::{Array1, Array2};
11use sklears_core::{
12    error::Result as SklResult,
13    prelude::{Predict, SklearsError},
14    traits::{Estimator, Fit, Trained, Untrained},
15};
16use std::collections::HashMap;
17
18/// Helper function to generate random f64 from scirs2_core::random::RngCore
19fn gen_f64(rng: &mut impl scirs2_core::random::RngCore) -> f64 {
20    let mut bytes = [0u8; 8];
21    rng.fill_bytes(&mut bytes);
22    f64::from_le_bytes(bytes) / f64::from_le_bytes([255u8; 8])
23}
24
25/// Helper function to generate random value in range from scirs2_core::random::RngCore
26fn gen_range_usize(
27    rng: &mut impl scirs2_core::random::RngCore,
28    range: std::ops::Range<usize>,
29) -> usize {
30    let mut bytes = [0u8; 8];
31    rng.fill_bytes(&mut bytes);
32    let val = u64::from_le_bytes(bytes);
33    range.start + (val as usize % (range.end - range.start))
34}
35
36/// Configuration for multi-label ensemble methods
37#[derive(Debug, Clone)]
38pub struct MultiLabelEnsembleConfig {
39    /// Number of base estimators
40    pub n_estimators: usize,
41    /// Multi-label transformation strategy
42    pub transformation_strategy: LabelTransformationStrategy,
43    /// Ensemble aggregation method for multi-label predictions
44    pub aggregation_method: MultiLabelAggregationMethod,
45    /// Label correlation handling approach
46    pub correlation_method: LabelCorrelationMethod,
47    /// Threshold for binary relevance predictions
48    pub threshold: f64,
49    /// Random state for reproducibility
50    pub random_state: Option<u64>,
51    /// Whether to use label powerset pruning
52    pub prune_labelsets: bool,
53    /// Maximum number of labelsets to consider
54    pub max_labelsets: Option<usize>,
55    /// Label dependency order for classifier chains
56    pub chain_order: Option<Vec<usize>>,
57    /// Whether to use ensemble of chains
58    pub ensemble_chains: bool,
59    /// Number of chains in ensemble chains
60    pub n_chains: usize,
61}
62
63impl Default for MultiLabelEnsembleConfig {
64    fn default() -> Self {
65        Self {
66            n_estimators: 10,
67            transformation_strategy: LabelTransformationStrategy::BinaryRelevance,
68            aggregation_method: MultiLabelAggregationMethod::Voting,
69            correlation_method: LabelCorrelationMethod::Independent,
70            threshold: 0.5,
71            random_state: None,
72            prune_labelsets: true,
73            max_labelsets: Some(100),
74            chain_order: None,
75            ensemble_chains: false,
76            n_chains: 3,
77        }
78    }
79}
80
/// Strategy used to turn the multi-label problem into problems a base classifier can learn.
///
/// NOTE(review): in this module's `fit`, only `BinaryRelevance`, `LabelPowerset` and
/// `EnsembleOfClassifierChains` have dedicated code paths; the remaining variants are
/// currently trained as binary relevance.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum LabelTransformationStrategy {
    /// Binary relevance: one independent binary classifier per label.
    BinaryRelevance,
    /// Label powerset: each distinct label combination becomes one class of a single
    /// multi-class classifier.
    LabelPowerset,
    /// Classifier chains: classifiers chained along a label order to model dependencies.
    ClassifierChains,
    /// Ensemble of classifier chains built from several random label orders.
    EnsembleOfClassifierChains,
    /// Adapted algorithm: a base algorithm modified to emit multiple labels directly.
    AdaptedAlgorithm,
    /// Random k-labelsets: label powerset models over random subsets of labels.
    RandomKLabelsets,
}
97
/// How the outputs of several base estimators are combined into one multi-label answer.
///
/// NOTE(review): the `aggregation_method` config field is not read by `fit` or `predict`
/// in this module.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MultiLabelAggregationMethod {
    /// Simple (unweighted) voting across base estimators.
    Voting,
    /// Voting weighted by each estimator's performance.
    WeightedVoting,
    /// Maximum probability across estimators.
    MaxProbability,
    /// Mean probability across estimators.
    MeanProbability,
    /// Median probability across estimators.
    MedianProbability,
    /// Threshold-based aggregation.
    ThresholdAggregation,
    /// Rank-based aggregation.
    RankAggregation,
}
116
/// How dependencies between labels are modelled.
///
/// NOTE(review): the `correlation_method` config field is not read by `fit` or `predict`
/// in this module; a Jaccard label-similarity matrix is computed during `fit` regardless.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum LabelCorrelationMethod {
    /// Treat labels as independent.
    Independent,
    /// Model pairwise label correlations.
    Pairwise,
    /// Model higher-order label correlations.
    HigherOrder,
    /// Use conditional-independence assumptions.
    ConditionalIndependence,
    /// Learn the label correlation structure from data.
    LearnedCorrelation,
}
131
132/// Multi-label ensemble classifier
133pub struct MultiLabelEnsembleClassifier<State = Untrained> {
134    config: MultiLabelEnsembleConfig,
135    state: std::marker::PhantomData<State>,
136    // Fitted attributes - only populated after training
137    base_classifiers: Option<Vec<BaggingClassifier<Trained>>>,
138    label_indices: Option<Vec<usize>>,
139    labelset_mapping: Option<HashMap<Vec<usize>, usize>>,
140    inverse_labelset_mapping: Option<HashMap<usize, Vec<usize>>>,
141    label_correlations: Option<Array2<f64>>,
142    chain_orders: Option<Vec<Vec<usize>>>,
143    threshold_per_label: Option<Vec<f64>>,
144    n_labels: Option<usize>,
145}
146
147/// Results from multi-label ensemble training
148#[derive(Debug, Clone)]
149pub struct MultiLabelTrainingResults {
150    /// Number of unique labelsets found
151    pub n_labelsets: usize,
152    /// Label frequency distribution
153    pub label_frequencies: HashMap<usize, usize>,
154    /// Labelset frequency distribution
155    pub labelset_frequencies: HashMap<Vec<usize>, usize>,
156    /// Label correlation matrix
157    pub label_correlations: Array2<f64>,
158    /// Training time metrics
159    pub training_time_ms: u64,
160}
161
162/// Multi-label prediction results
163#[derive(Debug, Clone)]
164pub struct MultiLabelPredictionResults {
165    /// Binary predictions for each label
166    pub predictions: Array2<usize>,
167    /// Probability scores for each label
168    pub probabilities: Array2<f64>,
169    /// Confidence scores for predictions
170    pub confidence_scores: Vec<f64>,
171    /// Label ranking scores
172    pub ranking_scores: Array2<f64>,
173}
174
175impl MultiLabelEnsembleClassifier<Untrained> {
176    /// Create a new multi-label ensemble classifier
177    pub fn new(config: MultiLabelEnsembleConfig) -> Self {
178        Self {
179            config,
180            state: std::marker::PhantomData,
181            base_classifiers: None,
182            label_indices: None,
183            labelset_mapping: None,
184            inverse_labelset_mapping: None,
185            label_correlations: None,
186            chain_orders: None,
187            threshold_per_label: None,
188            n_labels: None,
189        }
190    }
191
192    /// Create a new multi-label ensemble classifier with binary relevance
193    pub fn binary_relevance() -> Self {
194        let config = MultiLabelEnsembleConfig {
195            transformation_strategy: LabelTransformationStrategy::BinaryRelevance,
196            ..Default::default()
197        };
198        Self::new(config)
199    }
200
201    /// Create a new multi-label ensemble classifier with label powerset
202    pub fn label_powerset() -> Self {
203        let config = MultiLabelEnsembleConfig {
204            transformation_strategy: LabelTransformationStrategy::LabelPowerset,
205            ..Default::default()
206        };
207        Self::new(config)
208    }
209
210    /// Create a new multi-label ensemble classifier with classifier chains
211    pub fn classifier_chains() -> Self {
212        let config = MultiLabelEnsembleConfig {
213            transformation_strategy: LabelTransformationStrategy::ClassifierChains,
214            ..Default::default()
215        };
216        Self::new(config)
217    }
218
219    /// Create a new multi-label ensemble classifier with ensemble of classifier chains
220    pub fn ensemble_classifier_chains() -> Self {
221        let config = MultiLabelEnsembleConfig {
222            transformation_strategy: LabelTransformationStrategy::EnsembleOfClassifierChains,
223            ensemble_chains: true,
224            n_chains: 5,
225            ..Default::default()
226        };
227        Self::new(config)
228    }
229
230    /// Builder method to configure the number of estimators
231    pub fn n_estimators(mut self, n_estimators: usize) -> Self {
232        self.config.n_estimators = n_estimators;
233        self
234    }
235
236    /// Builder method to configure the aggregation method
237    pub fn aggregation_method(mut self, method: MultiLabelAggregationMethod) -> Self {
238        self.config.aggregation_method = method;
239        self
240    }
241
242    /// Builder method to configure the correlation method
243    pub fn correlation_method(mut self, method: LabelCorrelationMethod) -> Self {
244        self.config.correlation_method = method;
245        self
246    }
247
248    /// Builder method to configure the threshold
249    pub fn threshold(mut self, threshold: f64) -> Self {
250        self.config.threshold = threshold;
251        self
252    }
253
254    /// Builder method to configure random state
255    pub fn random_state(mut self, seed: u64) -> Self {
256        self.config.random_state = Some(seed);
257        self
258    }
259
260    /// Builder method to configure label powerset pruning
261    pub fn prune_labelsets(mut self, prune: bool) -> Self {
262        self.config.prune_labelsets = prune;
263        self
264    }
265
266    /// Extract unique labelsets from multi-label target matrix
267    fn extract_labelsets(
268        &self,
269        y: &Array2<usize>,
270    ) -> SklResult<(HashMap<Vec<usize>, usize>, HashMap<usize, Vec<usize>>)> {
271        let mut labelset_mapping = HashMap::new();
272        let mut inverse_mapping = HashMap::new();
273        let mut labelset_id = 0;
274
275        for row in y.outer_iter() {
276            let labelset: Vec<usize> = row
277                .iter()
278                .enumerate()
279                .filter(|(_, &label)| label == 1)
280                .map(|(idx, _)| idx)
281                .collect();
282
283            if !labelset_mapping.contains_key(&labelset) {
284                labelset_mapping.insert(labelset.clone(), labelset_id);
285                inverse_mapping.insert(labelset_id, labelset);
286                labelset_id += 1;
287            }
288        }
289
290        // Prune rare labelsets if configured
291        if self.config.prune_labelsets {
292            if let Some(max_labelsets) = self.config.max_labelsets {
293                if labelset_mapping.len() > max_labelsets {
294                    // Keep only the most frequent labelsets
295                    let mut labelset_counts: Vec<_> = labelset_mapping.iter().collect();
296                    labelset_counts.sort_by_key(|(labelset, _)| labelset.len());
297                    labelset_counts.truncate(max_labelsets);
298
299                    labelset_mapping = labelset_counts
300                        .into_iter()
301                        .enumerate()
302                        .map(|(new_id, (labelset, _))| (labelset.clone(), new_id))
303                        .collect();
304
305                    inverse_mapping = labelset_mapping
306                        .iter()
307                        .map(|(labelset, &id)| (id, labelset.clone()))
308                        .collect();
309                }
310            }
311        }
312
313        Ok((labelset_mapping, inverse_mapping))
314    }
315
316    /// Compute label correlations
317    fn compute_label_correlations(&self, y: &Array2<usize>) -> SklResult<Array2<f64>> {
318        let n_labels = y.ncols();
319        let mut correlations = Array2::zeros((n_labels, n_labels));
320
321        for i in 0..n_labels {
322            for j in i..n_labels {
323                if i == j {
324                    correlations[[i, j]] = 1.0;
325                } else {
326                    // Compute Jaccard similarity
327                    let mut intersection = 0;
328                    let mut union = 0;
329
330                    for k in 0..y.nrows() {
331                        let label_i = y[[k, i]];
332                        let label_j = y[[k, j]];
333
334                        if label_i == 1 && label_j == 1 {
335                            intersection += 1;
336                        }
337                        if label_i == 1 || label_j == 1 {
338                            union += 1;
339                        }
340                    }
341
342                    let correlation = if union > 0 {
343                        intersection as f64 / union as f64
344                    } else {
345                        0.0
346                    };
347
348                    correlations[[i, j]] = correlation;
349                    correlations[[j, i]] = correlation;
350                }
351            }
352        }
353
354        Ok(correlations)
355    }
356
357    /// Generate chain orders for classifier chains
358    fn generate_chain_orders(&self, n_labels: usize, n_chains: usize) -> Vec<Vec<usize>> {
359        let mut chains = Vec::new();
360        let mut rng = if let Some(seed) = self.config.random_state {
361            scirs2_core::random::seeded_rng(seed)
362        } else {
363            scirs2_core::random::seeded_rng(42)
364        };
365
366        for _ in 0..n_chains {
367            let mut order: Vec<usize> = (0..n_labels).collect();
368
369            // Shuffle the order
370            for i in (1..order.len()).rev() {
371                let j = gen_range_usize(&mut rng, 0..(i + 1));
372                order.swap(i, j);
373            }
374
375            chains.push(order);
376        }
377
378        chains
379    }
380}
381
382impl Estimator for MultiLabelEnsembleClassifier<Untrained> {
383    type Config = MultiLabelEnsembleConfig;
384    type Error = SklearsError;
385    type Float = f64;
386
387    fn config(&self) -> &Self::Config {
388        &self.config
389    }
390}
391
392impl Fit<Array2<f64>, Array2<usize>> for MultiLabelEnsembleClassifier<Untrained> {
393    type Fitted = MultiLabelEnsembleClassifier<Trained>;
394
395    fn fit(self, X: &Array2<f64>, y: &Array2<usize>) -> SklResult<Self::Fitted> {
396        if X.nrows() != y.nrows() {
397            return Err(SklearsError::ShapeMismatch {
398                expected: format!("{} samples", X.nrows()),
399                actual: format!("{} samples", y.nrows()),
400            });
401        }
402
403        let n_labels = y.ncols();
404        let mut base_classifiers = Vec::new();
405        let mut chain_orders = Vec::new();
406
407        // Compute label correlations
408        let label_correlations = self.compute_label_correlations(y)?;
409
410        match self.config.transformation_strategy {
411            LabelTransformationStrategy::BinaryRelevance => {
412                // Train one classifier per label
413                for label_idx in 0..n_labels {
414                    let y_binary: Vec<usize> = y.column(label_idx).to_vec();
415
416                    let y_binary_array =
417                        Array1::from_vec(y_binary.iter().map(|&x| x as i32).collect());
418                    let classifier = BaggingClassifier::new()
419                        .n_estimators(self.config.n_estimators)
420                        .fit(X, &y_binary_array)?;
421
422                    base_classifiers.push(classifier);
423                }
424            }
425
426            LabelTransformationStrategy::LabelPowerset => {
427                // Extract labelsets and train single multi-class classifier
428                let (labelset_mapping, _) = self.extract_labelsets(y)?;
429
430                // Convert multi-label matrix to single-label vector
431                let mut y_labelsets = Vec::new();
432                for row in y.outer_iter() {
433                    let labelset: Vec<usize> = row
434                        .iter()
435                        .enumerate()
436                        .filter(|(_, &label)| label == 1)
437                        .map(|(idx, _)| idx)
438                        .collect();
439
440                    if let Some(&labelset_id) = labelset_mapping.get(&labelset) {
441                        y_labelsets.push(labelset_id);
442                    } else {
443                        // Handle unseen labelsets (assign to empty set)
444                        y_labelsets.push(0);
445                    }
446                }
447
448                let y_labelsets_array =
449                    Array1::from_vec(y_labelsets.iter().map(|&x| x as i32).collect());
450                let classifier = BaggingClassifier::new()
451                    .n_estimators(self.config.n_estimators)
452                    .fit(X, &y_labelsets_array)?;
453
454                base_classifiers.push(classifier);
455            }
456
457            LabelTransformationStrategy::EnsembleOfClassifierChains => {
458                // Generate multiple chain orders
459                chain_orders = self.generate_chain_orders(n_labels, self.config.n_chains);
460
461                for chain_order in &chain_orders {
462                    // Train chain of classifiers for this order
463                    for &label_idx in chain_order {
464                        let y_binary: Vec<usize> = y.column(label_idx).to_vec();
465
466                        let y_binary_array =
467                            Array1::from_vec(y_binary.iter().map(|&x| x as i32).collect());
468                        let classifier = BaggingClassifier::new()
469                            .n_estimators(5) // Fewer estimators per chain classifier
470                            .fit(X, &y_binary_array)?;
471
472                        base_classifiers.push(classifier);
473                    }
474                }
475            }
476
477            _ => {
478                // Default to binary relevance for other strategies
479                for label_idx in 0..n_labels {
480                    let y_binary: Vec<usize> = y.column(label_idx).to_vec();
481
482                    let y_binary_array =
483                        Array1::from_vec(y_binary.iter().map(|&x| x as i32).collect());
484                    let classifier = BaggingClassifier::new()
485                        .n_estimators(self.config.n_estimators)
486                        .fit(X, &y_binary_array)?;
487
488                    base_classifiers.push(classifier);
489                }
490            }
491        }
492
493        // Create label indices
494        let label_indices: Vec<usize> = (0..n_labels).collect();
495
496        // Compute per-label thresholds (simplified for now)
497        let threshold_per_label = vec![self.config.threshold; n_labels];
498
499        // Extract labelsets if using label powerset
500        let (labelset_mapping, inverse_labelset_mapping) = if matches!(
501            self.config.transformation_strategy,
502            LabelTransformationStrategy::LabelPowerset
503        ) {
504            let (forward, inverse) = self.extract_labelsets(y)?;
505            (Some(forward), Some(inverse))
506        } else {
507            (None, None)
508        };
509
510        Ok(MultiLabelEnsembleClassifier {
511            config: self.config,
512            state: std::marker::PhantomData,
513            base_classifiers: Some(base_classifiers),
514            label_indices: Some(label_indices),
515            labelset_mapping,
516            inverse_labelset_mapping,
517            label_correlations: Some(label_correlations),
518            chain_orders: if chain_orders.is_empty() {
519                None
520            } else {
521                Some(chain_orders)
522            },
523            threshold_per_label: Some(threshold_per_label),
524            n_labels: Some(n_labels),
525        })
526    }
527}
528
529impl Predict<Array2<f64>, MultiLabelPredictionResults> for MultiLabelEnsembleClassifier<Trained> {
530    fn predict(&self, X: &Array2<f64>) -> SklResult<MultiLabelPredictionResults> {
531        let base_classifiers = self.base_classifiers.as_ref().expect("Model is trained");
532        let n_labels = self.n_labels.expect("Model is trained");
533        let threshold_per_label = self.threshold_per_label.as_ref().expect("Model is trained");
534
535        let n_samples = X.nrows();
536        let mut predictions = Array2::zeros((n_samples, n_labels));
537        let mut probabilities = Array2::zeros((n_samples, n_labels));
538        let mut ranking_scores = Array2::zeros((n_samples, n_labels));
539
540        match self.config.transformation_strategy {
541            LabelTransformationStrategy::BinaryRelevance => {
542                // Get predictions from each binary classifier
543                for (label_idx, classifier) in base_classifiers.iter().enumerate().take(n_labels) {
544                    let label_predictions = classifier.predict(X)?;
545
546                    // Convert predictions to the right format for each sample
547                    for (sample_idx, &pred) in label_predictions.iter().enumerate() {
548                        predictions[[sample_idx, label_idx]] = pred as usize;
549                        probabilities[[sample_idx, label_idx]] = pred as f64;
550                        ranking_scores[[sample_idx, label_idx]] = pred as f64;
551                    }
552                }
553            }
554
555            LabelTransformationStrategy::LabelPowerset => {
556                if let (Some(labelset_mapping), Some(inverse_labelset_mapping)) =
557                    (&self.labelset_mapping, &self.inverse_labelset_mapping)
558                {
559                    let labelset_predictions = base_classifiers[0].predict(X)?;
560
561                    for (sample_idx, &labelset_id) in labelset_predictions.iter().enumerate() {
562                        if let Some(labelset) =
563                            inverse_labelset_mapping.get(&(labelset_id as usize))
564                        {
565                            for &label_idx in labelset {
566                                if label_idx < n_labels {
567                                    predictions[[sample_idx, label_idx]] = 1;
568                                    probabilities[[sample_idx, label_idx]] = 1.0;
569                                    ranking_scores[[sample_idx, label_idx]] = 1.0;
570                                }
571                            }
572                        }
573                    }
574                }
575            }
576
577            _ => {
578                // Default binary relevance behavior
579                for (label_idx, classifier) in base_classifiers.iter().enumerate().take(n_labels) {
580                    let label_predictions = classifier.predict(X)?;
581
582                    for (sample_idx, &pred) in label_predictions.iter().enumerate() {
583                        predictions[[sample_idx, label_idx]] = pred as usize;
584                        probabilities[[sample_idx, label_idx]] = pred as f64;
585                        ranking_scores[[sample_idx, label_idx]] = pred as f64;
586                    }
587                }
588            }
589        }
590
591        // Apply thresholds
592        for i in 0..n_samples {
593            for j in 0..n_labels {
594                if probabilities[[i, j]] >= threshold_per_label[j] {
595                    predictions[[i, j]] = 1;
596                } else {
597                    predictions[[i, j]] = 0;
598                }
599            }
600        }
601
602        // Compute confidence scores (simplified)
603        let confidence_scores: Vec<f64> = (0..n_samples)
604            .map(|i| {
605                let row_probs: Vec<f64> = (0..n_labels).map(|j| probabilities[[i, j]]).collect();
606                row_probs.iter().sum::<f64>() / n_labels as f64
607            })
608            .collect();
609
610        Ok(MultiLabelPredictionResults {
611            predictions,
612            probabilities,
613            confidence_scores,
614            ranking_scores,
615        })
616    }
617}
618
619impl MultiLabelEnsembleClassifier<Trained> {
620    /// Get the number of labels
621    pub fn n_labels(&self) -> usize {
622        self.n_labels.expect("Model is trained")
623    }
624
625    /// Get label correlations
626    pub fn label_correlations(&self) -> &Array2<f64> {
627        self.label_correlations.as_ref().expect("Model is trained")
628    }
629
630    /// Get the transformation strategy used
631    pub fn transformation_strategy(&self) -> &LabelTransformationStrategy {
632        &self.config.transformation_strategy
633    }
634
635    /// Predict binary labels only
636    pub fn predict_binary(&self, X: &Array2<f64>) -> SklResult<Array2<usize>> {
637        let results = self.predict(X)?;
638        Ok(results.predictions)
639    }
640
641    /// Predict probabilities only
642    pub fn predict_proba(&self, X: &Array2<f64>) -> SklResult<Array2<f64>> {
643        let results = self.predict(X)?;
644        Ok(results.probabilities)
645    }
646
647    /// Get label rankings
648    pub fn predict_rankings(&self, X: &Array2<f64>) -> SklResult<Array2<f64>> {
649        let results = self.predict(X)?;
650        Ok(results.ranking_scores)
651    }
652}
653
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    /// 4 samples x 2 features shared by the end-to-end tests.
    fn sample_features() -> Array2<f64> {
        array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0]]
    }

    /// 4 samples x 3 labels shared by the end-to-end tests.
    fn sample_labels() -> Array2<usize> {
        array![[1, 0, 1], [0, 1, 1], [1, 1, 0], [0, 0, 1]]
    }

    /// True when every entry of a prediction matrix is 0 or 1.
    fn is_binary(m: &Array2<usize>) -> bool {
        m.iter().all(|&v| v <= 1)
    }

    #[test]
    fn test_multi_label_binary_relevance() {
        let X = sample_features();
        let y = sample_labels();

        let trained = MultiLabelEnsembleClassifier::binary_relevance()
            .n_estimators(3)
            .random_state(42)
            .fit(&X, &y)
            .expect("Training should succeed");
        let results = trained.predict(&X).expect("Prediction should succeed");

        assert_eq!(results.predictions.dim(), (4, 3));
        assert_eq!(results.probabilities.dim(), (4, 3));
        assert_eq!(results.ranking_scores.dim(), (4, 3));
        assert_eq!(results.confidence_scores.len(), 4);
        assert!(is_binary(&results.predictions));
    }

    #[test]
    fn test_multi_label_label_powerset() {
        let X = sample_features();
        let y = sample_labels();

        let trained = MultiLabelEnsembleClassifier::label_powerset()
            .n_estimators(5)
            .random_state(42)
            .fit(&X, &y)
            .expect("Training should succeed");
        let results = trained.predict(&X).expect("Prediction should succeed");

        assert_eq!(results.predictions.dim(), (4, 3));
        assert!(is_binary(&results.predictions));
        assert_eq!(trained.n_labels(), 3);
    }

    #[test]
    fn test_label_correlation_computation() {
        let y = array![[1, 1, 0], [1, 0, 1], [0, 1, 1], [1, 1, 1]];

        let correlations = MultiLabelEnsembleClassifier::binary_relevance()
            .compute_label_correlations(&y)
            .expect("Should compute correlations");

        assert_eq!(correlations.dim(), (3, 3));
        for i in 0..3 {
            assert_eq!(correlations[[i, i]], 1.0);
            for j in 0..3 {
                assert_eq!(correlations[[i, j]], correlations[[j, i]]);
            }
        }
        // Every pair of labels co-occurs in 2 of the 4 rows where either one is present.
        for &(i, j) in &[(0, 1), (0, 2), (1, 2)] {
            assert!((correlations[[i, j]] - 0.5).abs() < 1e-12);
        }
    }

    #[test]
    fn test_label_correlation_disjoint_and_absent_labels() {
        // Labels 0 and 1 never co-occur; label 2 never occurs at all.
        let y = array![[1, 0, 0], [0, 1, 0]];

        let correlations = MultiLabelEnsembleClassifier::binary_relevance()
            .compute_label_correlations(&y)
            .expect("Should compute correlations");

        assert_eq!(correlations[[0, 1]], 0.0);
        assert_eq!(correlations[[0, 2]], 0.0);
        assert_eq!(correlations[[1, 2]], 0.0);
        assert_eq!(correlations[[2, 2]], 1.0);
    }

    #[test]
    fn test_labelset_extraction() {
        let y = array![
            [1, 0, 1],
            [0, 1, 1],
            [1, 1, 0],
            [1, 0, 1] // Duplicate labelset
        ];

        let (labelset_mapping, inverse_mapping) = MultiLabelEnsembleClassifier::label_powerset()
            .extract_labelsets(&y)
            .expect("Should extract labelsets");

        assert_eq!(labelset_mapping.len(), 3);
        assert_eq!(inverse_mapping.len(), 3);

        // Ids follow first-seen row order when no pruning happens.
        assert_eq!(labelset_mapping.get(&vec![0usize, 2]), Some(&0));
        assert_eq!(labelset_mapping.get(&vec![1usize, 2]), Some(&1));
        assert_eq!(labelset_mapping.get(&vec![0usize, 1]), Some(&2));

        // Forward and inverse mappings must agree.
        for (labelset, &id) in &labelset_mapping {
            assert_eq!(inverse_mapping.get(&id), Some(labelset));
        }
    }

    #[test]
    fn test_labelset_extraction_with_empty_rows() {
        let y = array![[0, 0], [1, 0], [0, 0]];

        let (labelset_mapping, inverse_mapping) = MultiLabelEnsembleClassifier::label_powerset()
            .extract_labelsets(&y)
            .expect("Should extract labelsets");

        assert_eq!(labelset_mapping.len(), 2);
        assert_eq!(labelset_mapping.get(&Vec::<usize>::new()), Some(&0));
        assert_eq!(labelset_mapping.get(&vec![0usize]), Some(&1));
        assert_eq!(inverse_mapping.get(&0usize), Some(&Vec::<usize>::new()));
    }

    #[test]
    fn test_ensemble_classifier_chains() {
        let X = sample_features();
        let y = sample_labels();

        let trained = MultiLabelEnsembleClassifier::ensemble_classifier_chains()
            .random_state(42)
            .fit(&X, &y)
            .expect("Training should succeed");
        let results = trained.predict(&X).expect("Prediction should succeed");

        assert_eq!(results.predictions.dim(), (4, 3));
        assert!(is_binary(&results.predictions));

        // One chain order per configured chain, each a permutation of the labels.
        let chain_orders = trained
            .chain_orders
            .as_ref()
            .expect("ECC should record chain orders");
        assert_eq!(chain_orders.len(), 5);
        for order in chain_orders {
            let mut sorted = order.clone();
            sorted.sort_unstable();
            assert_eq!(sorted, vec![0, 1, 2]);
        }
    }

    #[test]
    fn test_chain_orders_are_reproducible() {
        let classifier = MultiLabelEnsembleClassifier::binary_relevance().random_state(7);
        let first = classifier.generate_chain_orders(4, 3);
        let second = classifier.generate_chain_orders(4, 3);
        assert_eq!(first, second);
        assert_eq!(first.len(), 3);
    }

    #[test]
    fn test_fit_rejects_mismatched_rows() {
        let X = sample_features();
        let y = array![[1, 0], [0, 1], [1, 1]];

        let result = MultiLabelEnsembleClassifier::binary_relevance().fit(&X, &y);
        assert!(result.is_err());
    }
}