sklears_multioutput/
classification.rs

//! Multi-label classification algorithms
//!
//! This module provides various multi-label classification approaches including
//! calibrated methods, k-nearest neighbor approaches, cost-sensitive methods,
//! and specialized techniques for handling multiple labels simultaneously.

// Use SciRS2-Core for arrays and random number generation (SciRS2 Policy)
use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2};
use scirs2_core::random::{thread_rng, Rng};
use sklears_core::{
    error::{Result as SklResult, SklearsError},
    traits::{Estimator, Fit, Predict, Untrained},
    types::Float,
};

/// Calibrated Binary Relevance Method
///
/// Enhanced binary relevance that applies probability calibration to improve
/// prediction reliability and provide confidence estimates.
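///
/// A minimal usage sketch (marked `ignore` since the import paths below are an
/// assumption and may differ from the actual crate layout):
///
/// ```ignore
/// use scirs2_core::ndarray::array;
/// use sklears_core::traits::{Fit, Predict};
/// use sklears_multioutput::classification::{CalibratedBinaryRelevance, CalibrationMethod};
///
/// let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 1.0]];
/// let y = array![[1, 0], [0, 1], [1, 1]];
///
/// let model = CalibratedBinaryRelevance::new()
///     .calibration_method(CalibrationMethod::Platt)
///     .fit(&X.view(), &y)
///     .unwrap();
///
/// let labels = model.predict(&X.view()).unwrap(); // 0/1 matrix, one column per label
/// let probabilities = model.predict_proba(&X.view()).unwrap(); // calibrated values in [0, 1]
/// ```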
#[derive(Debug, Clone)]
pub struct CalibratedBinaryRelevance<S = Untrained> {
    state: S,
    calibration_method: CalibrationMethod,
}

/// Calibration methods for probability calibration
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum CalibrationMethod {
    /// Platt scaling (sigmoid calibration)
    Platt,
    /// Isotonic regression calibration
    Isotonic,
}

/// Trained state for CalibratedBinaryRelevance
#[derive(Debug, Clone)]
pub struct CalibratedBinaryRelevanceTrained {
    base_models: Vec<(Array1<Float>, Float)>, // (weights, bias) for each label
    calibration_params: Vec<(Float, Float)>,  // (slope, intercept) for each label
    calibration_method: CalibrationMethod,
    n_features: usize,
    n_labels: usize,
}

impl Default for CalibratedBinaryRelevance<Untrained> {
    fn default() -> Self {
        Self::new()
    }
}

impl Estimator for CalibratedBinaryRelevance<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}

impl Fit<ArrayView2<'_, Float>, Array2<i32>> for CalibratedBinaryRelevance<Untrained> {
    type Fitted = CalibratedBinaryRelevance<CalibratedBinaryRelevanceTrained>;

    fn fit(self, X: &ArrayView2<'_, Float>, y: &Array2<i32>) -> SklResult<Self::Fitted> {
        let (n_samples, n_features) = X.dim();
        let n_labels = y.ncols();

        if n_samples != y.nrows() {
            return Err(SklearsError::InvalidInput(
                "X and y must have the same number of samples".to_string(),
            ));
        }

        let mut base_models = Vec::new();
        let mut calibration_params = Vec::new();

        // Train base classifiers and calibration for each label
        for label_idx in 0..n_labels {
            let y_label = y.column(label_idx);

            // Train base logistic regression
            let mut weights = Array1::<Float>::zeros(n_features);
            let mut bias = 0.0;
            let learning_rate = 0.01;
            let max_iter = 100;

            // Simple logistic regression training
            for _iter in 0..max_iter {
                let mut weight_gradient = Array1::<Float>::zeros(n_features);
                let mut bias_gradient = 0.0;

                for sample_idx in 0..n_samples {
                    let x = X.row(sample_idx);
                    let y_true = y_label[sample_idx] as Float;

                    let logit = x.dot(&weights) + bias;
                    let prob = 1.0 / (1.0 + (-logit).exp());
                    let error = prob - y_true;

                    // Accumulate gradients
                    for feat_idx in 0..n_features {
                        weight_gradient[feat_idx] += error * x[feat_idx];
                    }
                    bias_gradient += error;
                }

                // Update parameters
                for i in 0..n_features {
                    weights[i] -= learning_rate * weight_gradient[i] / n_samples as Float;
                }
                bias -= learning_rate * bias_gradient / n_samples as Float;
            }

            // Collect probabilities for calibration
            let mut probs = Vec::new();
            let mut labels = Vec::new();
            for sample_idx in 0..n_samples {
                let x = X.row(sample_idx);
                let logit = x.dot(&weights) + bias;
                let prob = 1.0 / (1.0 + (-logit).exp());
                probs.push(prob);
                labels.push(y_label[sample_idx] as Float);
            }

            // Fit calibration
            let (slope, intercept) = self.fit_calibration(&probs, &labels)?;

            base_models.push((weights, bias));
            calibration_params.push((slope, intercept));
        }

        Ok(CalibratedBinaryRelevance {
            state: CalibratedBinaryRelevanceTrained {
                base_models,
                calibration_params,
                calibration_method: self.calibration_method,
                n_features,
                n_labels,
            },
            calibration_method: self.calibration_method,
        })
    }
}

impl CalibratedBinaryRelevance<Untrained> {
    /// Create a new CalibratedBinaryRelevance
    pub fn new() -> Self {
        Self {
            state: Untrained,
            calibration_method: CalibrationMethod::Platt,
        }
    }

    /// Set the calibration method
    pub fn calibration_method(mut self, method: CalibrationMethod) -> Self {
        self.calibration_method = method;
        self
    }

    /// Fit calibration parameters
    fn fit_calibration(&self, probs: &[Float], labels: &[Float]) -> SklResult<(Float, Float)> {
        // Simple Platt-style scaling implementation
        match self.calibration_method {
            CalibrationMethod::Platt => {
                // Fit a sigmoid on top of the base probability:
                //   p_cal = sigmoid(a * p + b) = 1 / (1 + exp(-(a * p + b)))
                // by gradient descent on the log loss. Starting from the near-identity
                // mapping (a = 1, b = 0) keeps the calibrated probability increasing
                // in the base probability.
                let mut a = 1.0;
                let mut b = 0.0;
                let learning_rate = 0.01;

                for _iter in 0..100 {
                    let mut grad_a = 0.0;
                    let mut grad_b = 0.0;

                    for (i, &prob) in probs.iter().enumerate() {
                        let y_true = labels[i];
                        let logit = a * prob + b;
                        let cal_prob = 1.0 / (1.0 + (-logit).exp());
                        let error = cal_prob - y_true;

                        grad_a += error * prob;
                        grad_b += error;
                    }

                    a -= learning_rate * grad_a / probs.len() as Float;
                    b -= learning_rate * grad_b / probs.len() as Float;
                }

                Ok((a, b))
            }
            CalibrationMethod::Isotonic => {
                // Placeholder: a proper isotonic regression is not implemented yet,
                // so fall back to the identity-like sigmoid mapping.
                Ok((1.0, 0.0))
            }
        }
    }
}

impl Predict<ArrayView2<'_, Float>, Array2<i32>>
    for CalibratedBinaryRelevance<CalibratedBinaryRelevanceTrained>
{
    fn predict(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array2<i32>> {
        let (n_samples, n_features) = X.dim();

        if n_features != self.state.n_features {
            return Err(SklearsError::InvalidInput(
                "X has different number of features than training data".to_string(),
            ));
        }

        let mut predictions = Array2::<i32>::zeros((n_samples, self.state.n_labels));

        for sample_idx in 0..n_samples {
            let x = X.row(sample_idx);

            for label_idx in 0..self.state.n_labels {
                let (weights, bias) = &self.state.base_models[label_idx];
                let (slope, intercept) = self.state.calibration_params[label_idx];

                // Get base probability
                let logit = x.dot(weights) + bias;
                let base_prob = 1.0 / (1.0 + (-logit).exp());

                // Apply calibration
                let cal_logit = slope * base_prob + intercept;
                let cal_prob = 1.0 / (1.0 + (-cal_logit).exp());

                predictions[[sample_idx, label_idx]] = if cal_prob > 0.5 { 1 } else { 0 };
            }
        }

        Ok(predictions)
    }
}

impl CalibratedBinaryRelevance<CalibratedBinaryRelevanceTrained> {
    /// Get calibrated probabilities
    pub fn predict_proba(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array2<Float>> {
        let (n_samples, n_features) = X.dim();

        if n_features != self.state.n_features {
            return Err(SklearsError::InvalidInput(
                "X has different number of features than training data".to_string(),
            ));
        }

        let mut probabilities = Array2::<Float>::zeros((n_samples, self.state.n_labels));

        for sample_idx in 0..n_samples {
            let x = X.row(sample_idx);

            for label_idx in 0..self.state.n_labels {
                let (weights, bias) = &self.state.base_models[label_idx];
                let (slope, intercept) = self.state.calibration_params[label_idx];

                // Get base probability
                let logit = x.dot(weights) + bias;
                let base_prob = 1.0 / (1.0 + (-logit).exp());

                // Apply calibration
                let cal_logit = slope * base_prob + intercept;
                let cal_prob = 1.0 / (1.0 + (-cal_logit).exp());

                probabilities[[sample_idx, label_idx]] = cal_prob;
            }
        }

        Ok(probabilities)
    }
}

/// Random Label Combinations Method
///
/// Generates random label combinations for evaluation and testing purposes.
/// Useful for creating synthetic multi-label datasets with controlled characteristics.
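///
/// A minimal usage sketch (marked `ignore`; the import path is an assumption):
///
/// ```ignore
/// use sklears_multioutput::classification::RandomLabelCombinations;
///
/// // 50 random combinations over 4 labels, with roughly 30% of entries positive.
/// let combinations = RandomLabelCombinations::new(4)
///     .n_combinations(50)
///     .label_density(0.3)
///     .generate();
/// assert_eq!(combinations.dim(), (50, 4));
/// ```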
pub struct RandomLabelCombinations {
    n_labels: usize,
    n_combinations: usize,
    label_density: Float,
    random_state: Option<u64>,
}

impl RandomLabelCombinations {
    /// Create a new RandomLabelCombinations generator
    pub fn new(n_labels: usize) -> Self {
        Self {
            n_labels,
            n_combinations: 100,
            label_density: 0.3,
            random_state: None,
        }
    }

    /// Set the number of combinations to generate
    pub fn n_combinations(mut self, n_combinations: usize) -> Self {
        self.n_combinations = n_combinations;
        self
    }

    /// Set the label density (proportion of positive labels)
    pub fn label_density(mut self, density: Float) -> Self {
        self.label_density = density;
        self
    }

    /// Set random state for reproducible results
    pub fn random_state(mut self, seed: u64) -> Self {
        self.random_state = Some(seed);
        self
    }

    /// Generate random label combinations
    pub fn generate(&self) -> Array2<i32> {
        let mut rng = if let Some(_seed) = self.random_state {
            // TODO: Implement deterministic seeding with ThreadRng
            thread_rng()
        } else {
            thread_rng()
        };

        let mut combinations = Array2::<i32>::zeros((self.n_combinations, self.n_labels));

        for i in 0..self.n_combinations {
            for j in 0..self.n_labels {
                combinations[[i, j]] = if rng.gen::<Float>() < self.label_density {
                    1
                } else {
                    0
                };
            }
        }

        combinations
    }
}

/// ML-kNN: Multi-Label k-Nearest Neighbors
///
/// ML-kNN is an adaptation of the k-nearest neighbors algorithm for multi-label classification.
/// It uses the maximum a posteriori (MAP) principle to determine the label set for a test instance
/// based on the labels of its k nearest neighbors.
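///
/// Concretely, a test instance receives a label exactly when
/// `P(label present) * P(count | present) > P(label absent) * P(count | absent)`,
/// where `count` is the number of its k nearest neighbors that carry the label.
///
/// A minimal usage sketch (marked `ignore`; the import paths are an assumption):
///
/// ```ignore
/// use scirs2_core::ndarray::array;
/// use sklears_core::traits::{Fit, Predict};
/// use sklears_multioutput::classification::{DistanceMetric, MLkNN};
///
/// let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 1.0], [4.0, 4.0]];
/// let y = array![[1, 0], [0, 1], [1, 1], [0, 0]];
///
/// let model = MLkNN::new()
///     .k(3)
///     .smooth(1.0)
///     .distance_metric(DistanceMetric::Euclidean)
///     .fit(&X.view(), &y)
///     .unwrap();
/// let predictions = model.predict(&X.view()).unwrap();
/// ```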
#[derive(Debug, Clone)]
pub struct MLkNN<S = Untrained> {
    state: S,
    k: usize,
    smooth: Float,
    distance_metric: DistanceMetric,
}

/// Distance metrics for ML-kNN
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum DistanceMetric {
    /// Euclidean distance
    Euclidean,
    /// Manhattan distance
    Manhattan,
    /// Cosine distance
    Cosine,
}

/// Trained state for ML-kNN
#[derive(Debug, Clone)]
pub struct MLkNNTrained {
    training_data: Array2<Float>,
    training_labels: Array2<i32>,
    prior_probs: Array1<Float>,
    conditional_probs: Array2<Float>, // P(neighbor_count | label present)
    conditional_probs_negative: Array2<Float>, // P(neighbor_count | label absent)
    k: usize,
    smooth: Float,
    distance_metric: DistanceMetric,
    n_labels: usize,
}

impl Default for MLkNN<Untrained> {
    fn default() -> Self {
        Self::new()
    }
}

impl Estimator for MLkNN<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}

impl Fit<ArrayView2<'_, Float>, Array2<i32>> for MLkNN<Untrained> {
    type Fitted = MLkNN<MLkNNTrained>;

    fn fit(self, X: &ArrayView2<'_, Float>, y: &Array2<i32>) -> SklResult<Self::Fitted> {
        let (n_samples, _n_features) = X.dim();
        let n_labels = y.ncols();

        if n_samples != y.nrows() {
            return Err(SklearsError::InvalidInput(
                "X and y must have the same number of samples".to_string(),
            ));
        }

        if self.k >= n_samples {
            return Err(SklearsError::InvalidInput(
                "k must be smaller than the number of training samples".to_string(),
            ));
        }

        // Calculate prior probabilities P(label present) with smoothing
        let mut prior_probs = Array1::<Float>::zeros(n_labels);
        for label_idx in 0..n_labels {
            let positive_count = y.column(label_idx).iter().filter(|&&x| x == 1).count();
            prior_probs[label_idx] =
                (positive_count as Float + self.smooth) / (n_samples as Float + 2.0 * self.smooth);
        }

        // Count neighbor statistics for the conditional probabilities
        // P(neighbor_count | label present) and P(neighbor_count | label absent)
        let mut conditional_probs = Array2::<Float>::zeros((n_labels, self.k + 1));
        let mut conditional_probs_negative = Array2::<Float>::zeros((n_labels, self.k + 1));

        for sample_idx in 0..n_samples {
            let neighbors = self.find_k_neighbors(X, sample_idx, &X.view())?;

            for label_idx in 0..n_labels {
                let label_count = neighbors
                    .iter()
                    .filter(|&&neighbor_idx| y[[neighbor_idx, label_idx]] == 1)
                    .count();

                if y[[sample_idx, label_idx]] == 1 {
                    conditional_probs[[label_idx, label_count]] += 1.0;
                } else {
                    conditional_probs_negative[[label_idx, label_count]] += 1.0;
                }
            }
        }

        // Normalize both conditional distributions with smoothing
        for label_idx in 0..n_labels {
            let total_positive = y.column(label_idx).iter().filter(|&&x| x == 1).count() as Float;
            let total_negative = n_samples as Float - total_positive;
            for count in 0..=self.k {
                conditional_probs[[label_idx, count]] = (conditional_probs[[label_idx, count]]
                    + self.smooth)
                    / (total_positive + (self.k + 1) as Float * self.smooth);
                conditional_probs_negative[[label_idx, count]] =
                    (conditional_probs_negative[[label_idx, count]] + self.smooth)
                        / (total_negative + (self.k + 1) as Float * self.smooth);
            }
        }

        Ok(MLkNN {
            state: MLkNNTrained {
                training_data: X.to_owned(),
                training_labels: y.clone(),
                prior_probs,
                conditional_probs,
                conditional_probs_negative,
                k: self.k,
                smooth: self.smooth,
                distance_metric: self.distance_metric,
                n_labels,
            },
            k: self.k,
            smooth: self.smooth,
            distance_metric: self.distance_metric,
        })
    }
}

impl MLkNN<Untrained> {
    /// Create a new ML-kNN classifier
    pub fn new() -> Self {
        Self {
            state: Untrained,
            k: 10,
            smooth: 1.0,
            distance_metric: DistanceMetric::Euclidean,
        }
    }

    /// Set the number of neighbors
    pub fn k(mut self, k: usize) -> Self {
        self.k = k;
        self
    }

    /// Set the smoothing parameter
    pub fn smooth(mut self, smooth: Float) -> Self {
        self.smooth = smooth;
        self
    }

    /// Set the distance metric
    pub fn distance_metric(mut self, metric: DistanceMetric) -> Self {
        self.distance_metric = metric;
        self
    }

    /// Find k nearest neighbors for a sample
    fn find_k_neighbors(
        &self,
        X: &ArrayView2<'_, Float>,
        sample_idx: usize,
        training_data: &ArrayView2<'_, Float>,
    ) -> SklResult<Vec<usize>> {
        let query = X.row(sample_idx);
        let mut distances = Vec::new();

        for (train_idx, train_sample) in training_data.rows().into_iter().enumerate() {
            if train_idx != sample_idx {
                let distance = self.calculate_distance(&query, &train_sample);
                distances.push((distance, train_idx));
            }
        }

        distances.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
        let neighbors = distances
            .into_iter()
            .take(self.k)
            .map(|(_, idx)| idx)
            .collect();

        Ok(neighbors)
    }

    /// Calculate distance between two samples
    fn calculate_distance(&self, a: &ArrayView1<'_, Float>, b: &ArrayView1<'_, Float>) -> Float {
        match self.distance_metric {
            DistanceMetric::Euclidean => a
                .iter()
                .zip(b.iter())
                .map(|(x, y)| (x - y).powi(2))
                .sum::<Float>()
                .sqrt(),
            DistanceMetric::Manhattan => a.iter().zip(b.iter()).map(|(x, y)| (x - y).abs()).sum(),
            DistanceMetric::Cosine => {
                let dot = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum::<Float>();
                let norm_a = a.iter().map(|x| x.powi(2)).sum::<Float>().sqrt();
                let norm_b = b.iter().map(|x| x.powi(2)).sum::<Float>().sqrt();
                if norm_a > 0.0 && norm_b > 0.0 {
                    1.0 - dot / (norm_a * norm_b)
                } else {
                    1.0
                }
            }
        }
    }
}

impl Predict<ArrayView2<'_, Float>, Array2<i32>> for MLkNN<MLkNNTrained> {
    fn predict(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array2<i32>> {
        let (n_samples, n_features) = X.dim();

        if n_features != self.state.training_data.ncols() {
            return Err(SklearsError::InvalidInput(
                "X has different number of features than training data".to_string(),
            ));
        }

        let mut predictions = Array2::<i32>::zeros((n_samples, self.state.n_labels));

        for sample_idx in 0..n_samples {
            let neighbors = self.find_k_neighbors_trained(X, sample_idx)?;

            for label_idx in 0..self.state.n_labels {
                // Count positive neighbors for this label
                let positive_neighbors = neighbors
                    .iter()
                    .filter(|&&neighbor_idx| {
                        self.state.training_labels[[neighbor_idx, label_idx]] == 1
                    })
                    .count();

                // MAP decision: compare prior * likelihood of the observed neighbor
                // count under "label present" versus "label absent"
                let prob_positive = self.state.prior_probs[label_idx]
                    * self.state.conditional_probs[[label_idx, positive_neighbors]];
                let prob_negative = (1.0 - self.state.prior_probs[label_idx])
                    * self.state.conditional_probs_negative[[label_idx, positive_neighbors]];

                predictions[[sample_idx, label_idx]] =
                    if prob_positive > prob_negative { 1 } else { 0 };
            }
        }

        Ok(predictions)
    }
}

impl MLkNN<MLkNNTrained> {
    /// Find k nearest neighbors for a test sample
    fn find_k_neighbors_trained(
        &self,
        X: &ArrayView2<'_, Float>,
        sample_idx: usize,
    ) -> SklResult<Vec<usize>> {
        let query = X.row(sample_idx);
        let mut distances = Vec::new();

        for (train_idx, train_sample) in self.state.training_data.rows().into_iter().enumerate() {
            let distance = self.calculate_distance_trained(&query, &train_sample);
            distances.push((distance, train_idx));
        }

        distances.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
        let neighbors = distances
            .into_iter()
            .take(self.state.k)
            .map(|(_, idx)| idx)
            .collect();

        Ok(neighbors)
    }

    /// Calculate distance between two samples (trained version)
    fn calculate_distance_trained(
        &self,
        a: &ArrayView1<'_, Float>,
        b: &ArrayView1<'_, Float>,
    ) -> Float {
        match self.state.distance_metric {
            DistanceMetric::Euclidean => a
                .iter()
                .zip(b.iter())
                .map(|(x, y)| (x - y).powi(2))
                .sum::<Float>()
                .sqrt(),
            DistanceMetric::Manhattan => a.iter().zip(b.iter()).map(|(x, y)| (x - y).abs()).sum(),
            DistanceMetric::Cosine => {
                let dot = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum::<Float>();
                let norm_a = a.iter().map(|x| x.powi(2)).sum::<Float>().sqrt();
                let norm_b = b.iter().map(|x| x.powi(2)).sum::<Float>().sqrt();
                if norm_a > 0.0 && norm_b > 0.0 {
                    1.0 - dot / (norm_a * norm_b)
                } else {
                    1.0
                }
            }
        }
    }

    /// Get the number of neighbors
    pub fn k(&self) -> usize {
        self.state.k
    }

    /// Get prior probabilities
    pub fn prior_probabilities(&self) -> &Array1<Float> {
        &self.state.prior_probs
    }
}

/// Cost-Sensitive Binary Relevance
///
/// Binary relevance approach that incorporates label-specific misclassification costs
/// to optimize cost-sensitive performance rather than accuracy.
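///
/// A minimal usage sketch (marked `ignore`; the import paths are an assumption):
///
/// ```ignore
/// use scirs2_core::ndarray::array;
/// use sklears_core::traits::{Fit, Predict};
/// use sklears_multioutput::classification::{CostMatrix, CostSensitiveBinaryRelevance};
///
/// let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 1.0], [4.0, 4.0]];
/// let y = array![[1, 0], [0, 1], [1, 1], [0, 0]];
///
/// // False negatives on the second label cost three times as much as false
/// // positives, which lowers that label's decision threshold.
/// let costs = CostMatrix::new(array![1.0, 1.0], array![1.0, 3.0]);
/// let model = CostSensitiveBinaryRelevance::new()
///     .cost_matrix(costs)
///     .fit(&X.view(), &y)
///     .unwrap();
/// let predictions = model.predict(&X.view()).unwrap();
/// ```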
#[derive(Debug, Clone)]
pub struct CostSensitiveBinaryRelevance<S = Untrained> {
    state: S,
    cost_matrix: CostMatrix,
    learning_rate: Float,
    max_iterations: usize,
    regularization: Float,
}

/// Cost matrix for cost-sensitive learning
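///
/// Holds one false-positive cost and one false-negative cost per label.
/// `CostMatrix::uniform` assigns the same pair of costs to every label, while
/// `CostMatrix::new` takes per-label cost vectors. A short sketch (marked
/// `ignore`; the import path is an assumption):
///
/// ```ignore
/// use scirs2_core::ndarray::array;
/// use sklears_multioutput::classification::CostMatrix;
///
/// let costs = CostMatrix::new(array![2.0, 1.0], array![1.0, 3.0]);
/// assert_eq!(costs.fp_cost(0), 2.0);
/// assert_eq!(costs.fn_cost(1), 3.0);
/// ```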
#[derive(Debug, Clone)]
pub struct CostMatrix {
    /// Cost of false positives for each label
    false_positive_costs: Array1<Float>,
    /// Cost of false negatives for each label
    false_negative_costs: Array1<Float>,
}

impl CostMatrix {
    /// Create a new cost matrix
    pub fn new(false_positive_costs: Array1<Float>, false_negative_costs: Array1<Float>) -> Self {
        Self {
            false_positive_costs,
            false_negative_costs,
        }
    }

    /// Create uniform cost matrix
    pub fn uniform(n_labels: usize, fp_cost: Float, fn_cost: Float) -> Self {
        Self {
            false_positive_costs: Array1::from_elem(n_labels, fp_cost),
            false_negative_costs: Array1::from_elem(n_labels, fn_cost),
        }
    }

    /// Get false positive cost for a label
    pub fn fp_cost(&self, label_idx: usize) -> Float {
        self.false_positive_costs
            .get(label_idx)
            .copied()
            .unwrap_or(1.0)
    }

    /// Get false negative cost for a label
    pub fn fn_cost(&self, label_idx: usize) -> Float {
        self.false_negative_costs
            .get(label_idx)
            .copied()
            .unwrap_or(1.0)
    }
}

/// Trained state for cost-sensitive binary relevance
#[derive(Debug, Clone)]
pub struct CostSensitiveBinaryRelevanceTrained {
    models: Vec<SimpleBinaryModel>,
    cost_matrix: CostMatrix,
    n_features: usize,
    n_labels: usize,
}

/// Simple binary model for cost-sensitive learning
#[derive(Debug, Clone)]
pub struct SimpleBinaryModel {
    weights: Array1<Float>,
    bias: Float,
    threshold: Float, // Cost-sensitive threshold
}

impl Default for CostSensitiveBinaryRelevance<Untrained> {
    fn default() -> Self {
        Self::new()
    }
}

impl Estimator for CostSensitiveBinaryRelevance<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}

impl Fit<ArrayView2<'_, Float>, Array2<i32>> for CostSensitiveBinaryRelevance<Untrained> {
    type Fitted = CostSensitiveBinaryRelevance<CostSensitiveBinaryRelevanceTrained>;

    fn fit(self, X: &ArrayView2<'_, Float>, y: &Array2<i32>) -> SklResult<Self::Fitted> {
        let (n_samples, n_features) = X.dim();
        let n_labels = y.ncols();

        if n_samples != y.nrows() {
            return Err(SklearsError::InvalidInput(
                "X and y must have the same number of samples".to_string(),
            ));
        }

        let mut models = Vec::new();

        // Train cost-sensitive binary classifier for each label
        for label_idx in 0..n_labels {
            let y_label = y.column(label_idx);
            let fp_cost = self.cost_matrix.fp_cost(label_idx);
            let fn_cost = self.cost_matrix.fn_cost(label_idx);

            let mut weights = Array1::<Float>::zeros(n_features);
            let mut bias = 0.0;

            // Cost-sensitive training loop
            for _iter in 0..self.max_iterations {
                let mut weight_gradient = Array1::<Float>::zeros(n_features);
                let mut bias_gradient = 0.0;

                for sample_idx in 0..n_samples {
                    let x = X.row(sample_idx);
                    let y_true = y_label[sample_idx] as Float;

                    let logit = x.dot(&weights) + bias;
                    let prob = 1.0 / (1.0 + (-logit).exp());

                    // Cost-sensitive gradient: weight each sample's error by the cost
                    // of misclassifying it (FN cost for positives, FP cost for negatives)
                    let cost_weight = if y_true == 1.0 { fn_cost } else { fp_cost };
                    let error = (prob - y_true) * cost_weight;

                    // Accumulate gradients
                    for feat_idx in 0..n_features {
                        weight_gradient[feat_idx] += error * x[feat_idx];
                    }
                    bias_gradient += error;
                }

                // Add L2 regularization
                for i in 0..n_features {
                    weight_gradient[i] += self.regularization * weights[i];
                }

                // Update parameters
                for i in 0..n_features {
                    weights[i] -= self.learning_rate * weight_gradient[i] / n_samples as Float;
                }
                bias -= self.learning_rate * bias_gradient / n_samples as Float;
            }

            // Calculate cost-sensitive threshold
            let threshold = self.calculate_cost_sensitive_threshold(fp_cost, fn_cost);

            models.push(SimpleBinaryModel {
                weights,
                bias,
                threshold,
            });
        }

        Ok(CostSensitiveBinaryRelevance {
            state: CostSensitiveBinaryRelevanceTrained {
                models,
                cost_matrix: self.cost_matrix.clone(),
                n_features,
                n_labels,
            },
            cost_matrix: self.cost_matrix,
            learning_rate: self.learning_rate,
            max_iterations: self.max_iterations,
            regularization: self.regularization,
        })
    }
}

impl CostSensitiveBinaryRelevance<Untrained> {
    /// Create a new cost-sensitive binary relevance classifier
    pub fn new() -> Self {
        Self {
            state: Untrained,
            cost_matrix: CostMatrix::uniform(1, 1.0, 1.0),
            learning_rate: 0.01,
            max_iterations: 100,
            regularization: 0.01,
        }
    }

    /// Set the cost matrix
    pub fn cost_matrix(mut self, cost_matrix: CostMatrix) -> Self {
        self.cost_matrix = cost_matrix;
        self
    }

    /// Set the learning rate
    pub fn learning_rate(mut self, learning_rate: Float) -> Self {
        self.learning_rate = learning_rate;
        self
    }

    /// Set the maximum number of iterations
    pub fn max_iterations(mut self, max_iterations: usize) -> Self {
        self.max_iterations = max_iterations;
        self
    }

    /// Set the regularization strength
    pub fn regularization(mut self, regularization: Float) -> Self {
        self.regularization = regularization;
        self
    }

    /// Calculate cost-sensitive threshold
    fn calculate_cost_sensitive_threshold(&self, fp_cost: Float, fn_cost: Float) -> Float {
        // Predict positive when p > fp_cost / (fp_cost + fn_cost): the threshold that
        // minimizes expected misclassification cost for a calibrated probability p.
        // Example: fp_cost = 2, fn_cost = 1 gives a threshold of 2/3, so the label is
        // only predicted when the model is fairly confident.
        fp_cost / (fp_cost + fn_cost)
    }
}

impl Predict<ArrayView2<'_, Float>, Array2<i32>>
    for CostSensitiveBinaryRelevance<CostSensitiveBinaryRelevanceTrained>
{
    fn predict(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array2<i32>> {
        let (n_samples, n_features) = X.dim();

        if n_features != self.state.n_features {
            return Err(SklearsError::InvalidInput(
                "X has different number of features than training data".to_string(),
            ));
        }

        let mut predictions = Array2::<i32>::zeros((n_samples, self.state.n_labels));

        for sample_idx in 0..n_samples {
            let x = X.row(sample_idx);

            for (label_idx, model) in self.state.models.iter().enumerate() {
                let logit = x.dot(&model.weights) + model.bias;
                let prob = 1.0 / (1.0 + (-logit).exp());

                predictions[[sample_idx, label_idx]] = if prob > model.threshold { 1 } else { 0 };
            }
        }

        Ok(predictions)
    }
}

impl CostSensitiveBinaryRelevance<CostSensitiveBinaryRelevanceTrained> {
    /// Get the cost matrix
    pub fn cost_matrix(&self) -> &CostMatrix {
        &self.state.cost_matrix
    }

    /// Get model thresholds
    pub fn thresholds(&self) -> Vec<Float> {
        self.state.models.iter().map(|m| m.threshold).collect()
    }
}

#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    // Use SciRS2-Core for arrays and random number generation (SciRS2 Policy)
    use scirs2_core::ndarray::array;

    #[test]
    #[allow(non_snake_case)]
    fn test_calibrated_binary_relevance_basic() {
        let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 1.0], [4.0, 4.0]];
        let y = array![[1, 0], [0, 1], [1, 1], [0, 0]];

        let cbr = CalibratedBinaryRelevance::new().calibration_method(CalibrationMethod::Platt);
        let trained_cbr = cbr.fit(&X.view(), &y).unwrap();
        let predictions = trained_cbr.predict(&X.view()).unwrap();

        assert_eq!(predictions.dim(), (4, 2));
        assert!(predictions.iter().all(|&x| x == 0 || x == 1));
    }

    #[test]
    #[allow(non_snake_case)]
    fn test_calibrated_binary_relevance_probabilities() {
        let X = array![[1.0, 2.0], [2.0, 3.0]];
        let y = array![[1, 0], [0, 1]];

        let cbr = CalibratedBinaryRelevance::new();
        let trained_cbr = cbr.fit(&X.view(), &y).unwrap();
        let probabilities = trained_cbr.predict_proba(&X.view()).unwrap();

        assert_eq!(probabilities.dim(), (2, 2));
        assert!(probabilities.iter().all(|&p| p >= 0.0 && p <= 1.0));
    }

    #[test]
    fn test_random_label_combinations() {
        let generator = RandomLabelCombinations::new(3)
            .n_combinations(5)
            .label_density(0.5)
            .random_state(42);

        let combinations = generator.generate();
        assert_eq!(combinations.dim(), (5, 3));
        assert!(combinations.iter().all(|&x| x == 0 || x == 1));
    }

    #[test]
    #[allow(non_snake_case)]
    fn test_mlknn_basic() {
        let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 1.0], [4.0, 4.0], [1.5, 2.5]];
        let y = array![[1, 0], [0, 1], [1, 1], [0, 0], [1, 0]];

        let mlknn = MLkNN::new().k(3).smooth(1.0);
        let trained_mlknn = mlknn.fit(&X.view(), &y).unwrap();
        let predictions = trained_mlknn.predict(&X.view()).unwrap();

        assert_eq!(predictions.dim(), (5, 2));
        assert!(predictions.iter().all(|&x| x == 0 || x == 1));
        assert_eq!(trained_mlknn.k(), 3);
    }

    #[test]
    #[allow(non_snake_case)]
    fn test_mlknn_distance_metrics() {
        let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 1.0]];
        let y = array![[1, 0], [0, 1], [1, 1]];

        let mlknn_euclidean = MLkNN::new().k(2).distance_metric(DistanceMetric::Euclidean);
        let trained_euclidean = mlknn_euclidean.fit(&X.view(), &y).unwrap();

        let mlknn_manhattan = MLkNN::new().k(2).distance_metric(DistanceMetric::Manhattan);
        let trained_manhattan = mlknn_manhattan.fit(&X.view(), &y).unwrap();

        let pred_euclidean = trained_euclidean.predict(&X.view()).unwrap();
        let pred_manhattan = trained_manhattan.predict(&X.view()).unwrap();

        assert_eq!(pred_euclidean.dim(), (3, 2));
        assert_eq!(pred_manhattan.dim(), (3, 2));
    }

    #[test]
    #[allow(non_snake_case)]
    fn test_cost_sensitive_binary_relevance() {
        let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 1.0], [4.0, 4.0]];
        let y = array![[1, 0], [0, 1], [1, 1], [0, 0]];

        let fp_costs = array![2.0, 1.0]; // Higher cost for FP on first label
        let fn_costs = array![1.0, 3.0]; // Higher cost for FN on second label
        let cost_matrix = CostMatrix::new(fp_costs, fn_costs);

        let csbr = CostSensitiveBinaryRelevance::new()
            .cost_matrix(cost_matrix)
            .learning_rate(0.01)
            .max_iterations(50);

        let trained_csbr = csbr.fit(&X.view(), &y).unwrap();
        let predictions = trained_csbr.predict(&X.view()).unwrap();

        assert_eq!(predictions.dim(), (4, 2));
        assert!(predictions.iter().all(|&x| x == 0 || x == 1));

        let thresholds = trained_csbr.thresholds();
        assert_eq!(thresholds.len(), 2);
    }

    #[test]
    fn test_cost_matrix_creation() {
        let fp_costs = array![1.0, 2.0, 3.0];
        let fn_costs = array![2.0, 1.0, 1.0];
        let cost_matrix = CostMatrix::new(fp_costs, fn_costs);

        assert_eq!(cost_matrix.fp_cost(0), 1.0);
        assert_eq!(cost_matrix.fp_cost(1), 2.0);
        assert_eq!(cost_matrix.fn_cost(0), 2.0);
        assert_eq!(cost_matrix.fn_cost(1), 1.0);

        let uniform_costs = CostMatrix::uniform(3, 1.5, 2.5);
        assert_eq!(uniform_costs.fp_cost(0), 1.5);
        assert_eq!(uniform_costs.fn_cost(2), 2.5);
    }

    #[test]
    fn test_calibration_methods() {
        let cbr_platt =
            CalibratedBinaryRelevance::new().calibration_method(CalibrationMethod::Platt);
        let cbr_isotonic =
            CalibratedBinaryRelevance::new().calibration_method(CalibrationMethod::Isotonic);

        // Just test that they can be created with different methods
        assert_eq!(cbr_platt.calibration_method, CalibrationMethod::Platt);
        assert_eq!(cbr_isotonic.calibration_method, CalibrationMethod::Isotonic);
    }

    #[test]
    #[allow(non_snake_case)]
    fn test_mlknn_prior_probabilities() {
        let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 1.0], [4.0, 4.0]];
        let y = array![[1, 0], [0, 1], [1, 1], [0, 0]]; // 2/4 positive for each label

        let mlknn = MLkNN::new().k(2).smooth(1.0);
        let trained_mlknn = mlknn.fit(&X.view(), &y).unwrap();

        let priors = trained_mlknn.prior_probabilities();
        assert_eq!(priors.len(), 2);

        // With smoothing: (2 + 1) / (4 + 2) = 0.5
        assert!((priors[0] - 0.5).abs() < 1e-6);
        assert!((priors[1] - 0.5).abs() < 1e-6);
    }
}