sklears_multioutput/
svm.rs

1//! Support Vector Machine algorithms for multi-output learning
2
3// Use SciRS2-Core for arrays and random number generation (SciRS2 Policy)
4use scirs2_core::ndarray::{s, Array1, Array2, ArrayView1, ArrayView2, Axis};
5use sklears_core::{
6    error::{Result as SklResult, SklearsError},
7    traits::{Estimator, Fit, Predict, Untrained},
8    types::Float,
9};
10
11/// MLTSVM (Multi-Label Twin SVM)
12///
13/// MLTSVM is a multi-label classification method that extends Twin SVM to handle
14/// multiple labels. Twin SVM finds two non-parallel hyperplanes for binary classification,
15/// which often leads to faster training than standard SVM. MLTSVM applies this approach
16/// to each label independently in a binary relevance fashion.
17///
18/// # Examples
19///
20/// ```
21/// use sklears_core::traits::{Predict, Fit};
22/// use sklears_multioutput::MLTSVM;
23/// // Use SciRS2-Core for arrays and random number generation (SciRS2 Policy)
24/// use scirs2_core::ndarray::array;
25///
26/// let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 1.0], [4.0, 4.0]];
27/// let y = array![[1, 0], [0, 1], [1, 1], [0, 0]]; // Multi-label binary
28///
29/// let mltsvm = MLTSVM::new().c1(1.0).c2(1.0);
30/// let trained_mltsvm = mltsvm.fit(&X.view(), &y).unwrap();
31/// let predictions = trained_mltsvm.predict(&X.view()).unwrap();
32/// ```
33#[derive(Debug, Clone)]
34pub struct MLTSVM<S = Untrained> {
35    state: S,
36    c1: Float,       // Regularization parameter for first hyperplane
37    c2: Float,       // Regularization parameter for second hyperplane
38    epsilon: Float,  // Tolerance for convergence
39    max_iter: usize, // Maximum iterations
40}
41
42/// Trained state for MLTSVM
43#[derive(Debug, Clone)]
44pub struct MLTSVMTrained {
45    models: Vec<TwinSVMModel>, // One model per label
46    n_labels: usize,
47    feature_means: Array1<Float>,
48    feature_stds: Array1<Float>,
49}
50
51/// Twin SVM model for a single label
52#[derive(Debug, Clone)]
53pub struct TwinSVMModel {
54    w1: Array1<Float>, // Weight vector for positive hyperplane
55    b1: Float,         // Bias for positive hyperplane
56    w2: Array1<Float>, // Weight vector for negative hyperplane
57    b2: Float,         // Bias for negative hyperplane
58}
59
60impl MLTSVM<Untrained> {
61    /// Create a new MLTSVM instance
62    pub fn new() -> Self {
63        Self {
64            state: Untrained,
65            c1: 1.0,
66            c2: 1.0,
67            epsilon: 1e-3,
68            max_iter: 1000,
69        }
70    }
71
72    /// Set C1 parameter
73    pub fn c1(mut self, c1: Float) -> Self {
74        self.c1 = c1;
75        self
76    }
77
78    /// Set C2 parameter
79    pub fn c2(mut self, c2: Float) -> Self {
80        self.c2 = c2;
81        self
82    }
83
84    /// Set epsilon parameter
85    pub fn epsilon(mut self, epsilon: Float) -> Self {
86        self.epsilon = epsilon;
87        self
88    }
89
90    /// Set maximum iterations
91    pub fn max_iter(mut self, max_iter: usize) -> Self {
92        self.max_iter = max_iter;
93        self
94    }
95}
96
97impl Default for MLTSVM<Untrained> {
98    fn default() -> Self {
99        Self::new()
100    }
101}
102
103impl Estimator for MLTSVM<Untrained> {
104    type Config = ();
105    type Error = SklearsError;
106    type Float = Float;
107
108    fn config(&self) -> &Self::Config {
109        &()
110    }
111}
112
113impl Fit<ArrayView2<'_, Float>, Array2<i32>> for MLTSVM<Untrained> {
114    type Fitted = MLTSVM<MLTSVMTrained>;
115
116    fn fit(self, x: &ArrayView2<'_, Float>, y: &Array2<i32>) -> SklResult<Self::Fitted> {
117        let (n_samples, n_features) = x.dim();
118        let (y_samples, n_labels) = y.dim();
119
120        if n_samples != y_samples {
121            return Err(SklearsError::InvalidInput(
122                "Number of samples in X and y must match".to_string(),
123            ));
124        }
125
126        if n_samples < 2 {
127            return Err(SklearsError::InvalidInput(
128                "Need at least 2 samples for SVM training".to_string(),
129            ));
130        }
131
132        // Validate that all labels are binary (0 or 1)
133        for sample_idx in 0..y_samples {
134            for label_idx in 0..n_labels {
135                let value = y[[sample_idx, label_idx]];
136                if value != 0 && value != 1 {
137                    return Err(SklearsError::InvalidInput(format!(
138                        "All label values must be 0 or 1, found: {}",
139                        value
140                    )));
141                }
142            }
143        }
144
145        // Compute feature statistics for normalization
146        let feature_means = x.mean_axis(Axis(0)).unwrap();
147        let feature_stds = x.mapv(|val| val * val).mean_axis(Axis(0)).unwrap()
148            - &feature_means.mapv(|mean| mean * mean);
149        let feature_stds = feature_stds.mapv(|var| (var.max(1e-10)).sqrt());
150
151        // Train Twin SVM for each label
152        let mut models = Vec::new();
153        for label_idx in 0..n_labels {
154            let y_label = y.column(label_idx);
155            let model = self.train_twin_svm(x, &y_label, &feature_means, &feature_stds)?;
156            models.push(model);
157        }
158
159        Ok(MLTSVM {
160            state: MLTSVMTrained {
161                models,
162                n_labels,
163                feature_means,
164                feature_stds,
165            },
166            c1: self.c1,
167            c2: self.c2,
168            epsilon: self.epsilon,
169            max_iter: self.max_iter,
170        })
171    }
172}
173
174impl MLTSVM<Untrained> {
175    fn train_twin_svm(
176        &self,
177        x: &ArrayView2<'_, Float>,
178        y: &ArrayView1<'_, i32>,
179        feature_means: &Array1<Float>,
180        feature_stds: &Array1<Float>,
181    ) -> SklResult<TwinSVMModel> {
182        let (n_samples, n_features) = x.dim();
183
184        // Normalize features
185        let mut x_normalized = x.to_owned();
186        for (i, mut row) in x_normalized.rows_mut().into_iter().enumerate() {
187            row -= feature_means;
188            row /= feature_stds;
189        }
190
191        // Separate positive and negative samples
192        let mut pos_samples = Vec::new();
193        let mut neg_samples = Vec::new();
194
195        for i in 0..n_samples {
196            if y[i] == 1 {
197                pos_samples.push(x_normalized.row(i).to_owned());
198            } else {
199                neg_samples.push(x_normalized.row(i).to_owned());
200            }
201        }
202
203        if pos_samples.is_empty() || neg_samples.is_empty() {
204            return Err(SklearsError::InvalidInput(
205                "Need both positive and negative samples for Twin SVM".to_string(),
206            ));
207        }
208
209        // Convert to matrices
210        let pos_matrix = Array2::from_shape_vec(
211            (pos_samples.len(), n_features),
212            pos_samples.into_iter().flatten().collect(),
213        )
214        .map_err(|_| SklearsError::InvalidInput("Failed to create positive matrix".to_string()))?;
215
216        let neg_matrix = Array2::from_shape_vec(
217            (neg_samples.len(), n_features),
218            neg_samples.into_iter().flatten().collect(),
219        )
220        .map_err(|_| SklearsError::InvalidInput("Failed to create negative matrix".to_string()))?;
221
222        // Train Twin SVM hyperplanes
223        let (w1, b1) = self.solve_twin_svm_problem(&pos_matrix, &neg_matrix, self.c1)?;
224        let (w2, b2) = self.solve_twin_svm_problem(&neg_matrix, &pos_matrix, self.c2)?;
225
226        Ok(TwinSVMModel { w1, b1, w2, b2 })
227    }
228
229    fn solve_twin_svm_problem(
230        &self,
231        target_matrix: &Array2<Float>,
232        other_matrix: &Array2<Float>,
233        c: Float,
234    ) -> SklResult<(Array1<Float>, Float)> {
235        let n_target = target_matrix.nrows();
236        let n_other = other_matrix.nrows();
237        let n_features = target_matrix.ncols();
238
239        // Initialize weights
240        let mut w = Array1::<Float>::zeros(n_features + 1); // Include bias
241
242        // Simple gradient descent solution
243        let learning_rate = 0.01;
244
245        for _iter in 0..self.max_iter {
246            let mut gradient = Array1::<Float>::zeros(n_features + 1);
247
248            // Compute gradient
249            for i in 0..n_target {
250                let x_aug = {
251                    let mut x = Array1::ones(n_features + 1);
252                    x.slice_mut(s![..n_features]).assign(&target_matrix.row(i));
253                    x
254                };
255                let loss = x_aug.dot(&w);
256                gradient += &(x_aug * loss);
257            }
258
259            for i in 0..n_other {
260                let x_aug = {
261                    let mut x = Array1::ones(n_features + 1);
262                    x.slice_mut(s![..n_features]).assign(&other_matrix.row(i));
263                    x
264                };
265                let margin = 1.0 - x_aug.dot(&w);
266                if margin > 0.0 {
267                    gradient -= &(x_aug * c);
268                }
269            }
270
271            // Check convergence before updating weights
272            let gradient_norm = gradient.mapv(|x| x.abs()).sum();
273
274            // Update weights
275            w -= &(gradient * learning_rate);
276
277            if gradient_norm < self.epsilon {
278                break;
279            }
280        }
281
282        let weights = w.slice(s![..n_features]).to_owned();
283        let bias = w[n_features];
284
285        Ok((weights, bias))
286    }
287}
288
289impl Predict<ArrayView2<'_, Float>, Array2<i32>> for MLTSVM<MLTSVMTrained> {
290    fn predict(&self, x: &ArrayView2<'_, Float>) -> SklResult<Array2<i32>> {
291        let (n_samples, n_features) = x.dim();
292        let expected_features = self.state.feature_means.len();
293
294        if n_features != expected_features {
295            return Err(SklearsError::InvalidInput(format!(
296                "Number of features in X ({}) does not match training data ({})",
297                n_features, expected_features
298            )));
299        }
300
301        let mut predictions = Array2::<i32>::zeros((n_samples, self.state.n_labels));
302
303        // Normalize features
304        let mut x_normalized = x.to_owned();
305        for (i, mut row) in x_normalized.rows_mut().into_iter().enumerate() {
306            row -= &self.state.feature_means;
307            row /= &self.state.feature_stds;
308        }
309
310        for label_idx in 0..self.state.n_labels {
311            let model = &self.state.models[label_idx];
312
313            for sample_idx in 0..n_samples {
314                let x_sample = x_normalized.row(sample_idx);
315
316                // Compute distances to both hyperplanes
317                let dist1 = (x_sample.dot(&model.w1) + model.b1).abs();
318                let dist2 = (x_sample.dot(&model.w2) + model.b2).abs();
319
320                // Predict based on closer hyperplane
321                predictions[[sample_idx, label_idx]] = if dist1 < dist2 { 1 } else { 0 };
322            }
323        }
324
325        Ok(predictions)
326    }
327}
328
329impl MLTSVM<MLTSVMTrained> {
330    /// Get the number of labels
331    pub fn n_labels(&self) -> usize {
332        self.state.n_labels
333    }
334
335    /// Get decision function values (distances to hyperplanes)
336    pub fn decision_function(&self, x: &ArrayView2<'_, Float>) -> SklResult<Array2<Float>> {
337        let (n_samples, _n_features) = x.dim();
338        let mut decision_values = Array2::<Float>::zeros((n_samples, self.state.n_labels));
339
340        // Normalize features
341        let mut x_normalized = x.to_owned();
342        for (i, mut row) in x_normalized.rows_mut().into_iter().enumerate() {
343            row -= &self.state.feature_means;
344            row /= &self.state.feature_stds;
345        }
346
347        for label_idx in 0..self.state.n_labels {
348            let model = &self.state.models[label_idx];
349
350            for sample_idx in 0..n_samples {
351                let x_sample = x_normalized.row(sample_idx);
352
353                // Compute distances to both hyperplanes and use the difference
354                let dist1 = x_sample.dot(&model.w1) + model.b1;
355                let dist2 = x_sample.dot(&model.w2) + model.b2;
356
357                // Decision function is the difference (positive means class 1)
358                decision_values[[sample_idx, label_idx]] = dist1 - dist2;
359            }
360        }
361
362        Ok(decision_values)
363    }
364}
365
366/// RankSVM for Multi-Label Classification
367///
368/// RankSVM is a ranking-based approach for multi-label classification that optimizes
369/// ranking loss functions. It learns to rank labels by their relevance scores and
370/// can handle both label ranking and threshold selection for multi-label prediction.
371#[derive(Debug, Clone)]
372pub struct RankSVM<S = Untrained> {
373    state: S,
374    c: Float,                              // Regularization parameter
375    epsilon: Float,                        // Tolerance for convergence
376    max_iter: usize,                       // Maximum iterations
377    threshold_strategy: ThresholdStrategy, // How to determine prediction thresholds
378}
379
380/// Threshold strategy for RankSVM
381#[derive(Debug, Clone)]
382pub enum ThresholdStrategy {
383    /// Use fixed threshold for all labels
384    Fixed(Float),
385    /// Optimize threshold to maximize F1 score for each label
386    OptimizeF1,
387    /// Use top-k labels (fixed number of labels per sample)
388    TopK(usize),
389}
390
391/// Trained state for RankSVM
392#[derive(Debug, Clone)]
393pub struct RankSVMTrained {
394    models: Vec<RankingSVMModel>, // One model per label
395    thresholds: Vec<Float>,       // Prediction thresholds for each label
396    n_labels: usize,
397    feature_means: Array1<Float>,
398    feature_stds: Array1<Float>,
399}
400
401/// Single ranking SVM model for one label
402#[derive(Debug, Clone)]
403pub struct RankingSVMModel {
404    weights: Array1<Float>,
405    bias: Float,
406}
407
408impl RankSVM<Untrained> {
409    /// Create a new RankSVM instance
410    pub fn new() -> Self {
411        Self {
412            state: Untrained,
413            c: 1.0,
414            epsilon: 1e-3,
415            max_iter: 1000,
416            threshold_strategy: ThresholdStrategy::Fixed(0.0),
417        }
418    }
419
420    /// Set regularization parameter
421    pub fn c(mut self, c: Float) -> Self {
422        self.c = c;
423        self
424    }
425
426    /// Set convergence tolerance
427    pub fn epsilon(mut self, epsilon: Float) -> Self {
428        self.epsilon = epsilon;
429        self
430    }
431
432    /// Set maximum iterations
433    pub fn max_iter(mut self, max_iter: usize) -> Self {
434        self.max_iter = max_iter;
435        self
436    }
437
438    /// Set threshold strategy
439    pub fn threshold_strategy(mut self, strategy: ThresholdStrategy) -> Self {
440        self.threshold_strategy = strategy;
441        self
442    }
443}
444
445impl Default for RankSVM<Untrained> {
446    fn default() -> Self {
447        Self::new()
448    }
449}
450
451impl Estimator for RankSVM<Untrained> {
452    type Config = ();
453    type Error = SklearsError;
454    type Float = Float;
455
456    fn config(&self) -> &Self::Config {
457        &()
458    }
459}
460
461impl Fit<ArrayView2<'_, Float>, Array2<i32>> for RankSVM<Untrained> {
462    type Fitted = RankSVM<RankSVMTrained>;
463
464    fn fit(self, x: &ArrayView2<'_, Float>, y: &Array2<i32>) -> SklResult<Self::Fitted> {
465        let (n_samples, n_features) = x.dim();
466        let (y_samples, n_labels) = y.dim();
467
468        if n_samples != y_samples {
469            return Err(SklearsError::InvalidInput(
470                "Number of samples in X and y must match".to_string(),
471            ));
472        }
473
474        // Validate that all labels are binary (0 or 1)
475        for sample_idx in 0..y_samples {
476            for label_idx in 0..n_labels {
477                let value = y[[sample_idx, label_idx]];
478                if value != 0 && value != 1 {
479                    return Err(SklearsError::InvalidInput(format!(
480                        "All label values must be 0 or 1, found: {}",
481                        value
482                    )));
483                }
484            }
485        }
486
487        // Compute feature statistics
488        let feature_means = x.mean_axis(Axis(0)).ok_or_else(|| {
489            SklearsError::InvalidInput("Cannot compute feature means from input data".to_string())
490        })?;
491
492        let squared_means = x.mapv(|val| val * val).mean_axis(Axis(0)).ok_or_else(|| {
493            SklearsError::InvalidInput("Cannot compute squared means from input data".to_string())
494        })?;
495
496        let feature_stds = squared_means - &feature_means.mapv(|mean| mean * mean);
497        let feature_stds = feature_stds.mapv(|var| (var.max(1e-10)).sqrt());
498
499        // Train ranking SVM for each label
500        let mut models = Vec::new();
501        for label_idx in 0..n_labels {
502            let y_label = y.column(label_idx);
503            let model = self.train_ranking_svm(x, &y_label, &feature_means, &feature_stds)?;
504            models.push(model);
505        }
506
507        // Determine thresholds
508        let thresholds = match &self.threshold_strategy {
509            ThresholdStrategy::Fixed(threshold) => vec![*threshold; n_labels],
510            ThresholdStrategy::OptimizeF1 => {
511                self.optimize_f1_thresholds(x, y, &models, &feature_means, &feature_stds)?
512            }
513            ThresholdStrategy::TopK(_) => vec![0.0; n_labels], // No threshold needed for TopK
514        };
515
516        Ok(RankSVM {
517            state: RankSVMTrained {
518                models,
519                thresholds,
520                n_labels,
521                feature_means,
522                feature_stds,
523            },
524            c: self.c,
525            epsilon: self.epsilon,
526            max_iter: self.max_iter,
527            threshold_strategy: self.threshold_strategy,
528        })
529    }
530}
531
532impl RankSVM<Untrained> {
533    fn train_ranking_svm(
534        &self,
535        x: &ArrayView2<'_, Float>,
536        y: &ArrayView1<'_, i32>,
537        feature_means: &Array1<Float>,
538        feature_stds: &Array1<Float>,
539    ) -> SklResult<RankingSVMModel> {
540        let (n_samples, n_features) = x.dim();
541
542        // Normalize features
543        let mut x_normalized = x.to_owned();
544        for (i, mut row) in x_normalized.rows_mut().into_iter().enumerate() {
545            row -= feature_means;
546            row /= feature_stds;
547        }
548
549        // Initialize weights and bias
550        let mut weights = Array1::<Float>::zeros(n_features);
551        let mut bias = 0.0;
552
553        let learning_rate = 0.01;
554
555        // Gradient descent optimization
556        for _iter in 0..self.max_iter {
557            let mut weight_gradient = Array1::<Float>::zeros(n_features);
558            let mut bias_gradient = 0.0;
559
560            // Create ranking pairs
561            for i in 0..n_samples {
562                for j in 0..n_samples {
563                    if y[i] > y[j] {
564                        // i should be ranked higher than j
565                        let x_i = x_normalized.row(i);
566                        let x_j = x_normalized.row(j);
567                        let x_diff = &x_i.to_owned() - &x_j.to_owned();
568
569                        let score_diff = x_diff.dot(&weights) + bias;
570                        let margin = 1.0 - score_diff;
571
572                        if margin > 0.0 {
573                            // Hinge loss gradient
574                            weight_gradient -= &(x_diff * self.c);
575                            bias_gradient -= self.c;
576                        }
577                    }
578                }
579            }
580
581            // L2 regularization
582            weight_gradient += &(&weights * 2.0);
583
584            // Check convergence before updating parameters
585            let gradient_norm = weight_gradient.mapv(|x| x.abs()).sum();
586
587            // Update parameters
588            weights -= &(weight_gradient * learning_rate);
589            bias -= bias_gradient * learning_rate;
590
591            if gradient_norm < self.epsilon {
592                break;
593            }
594        }
595
596        Ok(RankingSVMModel { weights, bias })
597    }
598
599    fn optimize_f1_thresholds(
600        &self,
601        x: &ArrayView2<'_, Float>,
602        y: &Array2<i32>,
603        models: &[RankingSVMModel],
604        feature_means: &Array1<Float>,
605        feature_stds: &Array1<Float>,
606    ) -> SklResult<Vec<Float>> {
607        let mut thresholds = Vec::new();
608
609        for label_idx in 0..y.ncols() {
610            let y_true = y.column(label_idx);
611            let scores = self.predict_scores_single_label(
612                x,
613                &models[label_idx],
614                feature_means,
615                feature_stds,
616            )?;
617
618            let threshold = self.find_optimal_f1_threshold(&y_true, &scores)?;
619            thresholds.push(threshold);
620        }
621
622        Ok(thresholds)
623    }
624
625    fn predict_scores_single_label(
626        &self,
627        x: &ArrayView2<'_, Float>,
628        model: &RankingSVMModel,
629        feature_means: &Array1<Float>,
630        feature_stds: &Array1<Float>,
631    ) -> SklResult<Array1<Float>> {
632        let (n_samples, _) = x.dim();
633        let mut scores = Array1::<Float>::zeros(n_samples);
634
635        for i in 0..n_samples {
636            let x_sample = x.row(i);
637            let x_normalized = (&x_sample.to_owned() - feature_means) / feature_stds;
638            scores[i] = x_normalized.dot(&model.weights) + model.bias;
639        }
640
641        Ok(scores)
642    }
643
644    fn find_optimal_f1_threshold(
645        &self,
646        y_true: &ArrayView1<'_, i32>,
647        scores: &Array1<Float>,
648    ) -> SklResult<Float> {
649        let mut score_threshold_pairs: Vec<(Float, i32)> = scores
650            .iter()
651            .zip(y_true.iter())
652            .map(|(&score, &label)| (score, label))
653            .collect();
654
655        score_threshold_pairs.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
656
657        let mut best_f1 = 0.0;
658        let mut best_threshold = 0.0;
659
660        // Try each unique score as a threshold
661        for &(threshold, _) in &score_threshold_pairs {
662            let mut tp = 0;
663            let mut fp = 0;
664            let mut fn_count = 0;
665
666            for (&score, &true_label) in scores.iter().zip(y_true.iter()) {
667                let predicted = if score >= threshold { 1 } else { 0 };
668
669                match (true_label, predicted) {
670                    (1, 1) => tp += 1,
671                    (0, 1) => fp += 1,
672                    (1, 0) => fn_count += 1,
673                    _ => {}
674                }
675            }
676
677            let precision = if tp + fp > 0 {
678                tp as Float / (tp + fp) as Float
679            } else {
680                0.0
681            };
682            let recall = if tp + fn_count > 0 {
683                tp as Float / (tp + fn_count) as Float
684            } else {
685                0.0
686            };
687            let f1 = if precision + recall > 0.0 {
688                2.0 * precision * recall / (precision + recall)
689            } else {
690                0.0
691            };
692
693            if f1 > best_f1 {
694                best_f1 = f1;
695                best_threshold = threshold;
696            }
697        }
698
699        Ok(best_threshold)
700    }
701}
702
703impl Predict<ArrayView2<'_, Float>, Array2<i32>> for RankSVM<RankSVMTrained> {
704    fn predict(&self, x: &ArrayView2<'_, Float>) -> SklResult<Array2<i32>> {
705        let (n_samples, n_features) = x.dim();
706        let expected_features = self.state.feature_means.len();
707
708        if n_features != expected_features {
709            return Err(SklearsError::InvalidInput(format!(
710                "Number of features in X ({}) does not match training data ({})",
711                n_features, expected_features
712            )));
713        }
714
715        let mut predictions = Array2::<i32>::zeros((n_samples, self.state.n_labels));
716
717        match &self.threshold_strategy {
718            ThresholdStrategy::TopK(k) => {
719                // For TopK, rank all labels and select top k
720                for sample_idx in 0..n_samples {
721                    let mut scores = Vec::new();
722                    for label_idx in 0..self.state.n_labels {
723                        let x_sample = x.row(sample_idx);
724                        let x_normalized = (&x_sample.to_owned() - &self.state.feature_means)
725                            / &self.state.feature_stds;
726                        let score = x_normalized.dot(&self.state.models[label_idx].weights)
727                            + self.state.models[label_idx].bias;
728                        scores.push((score, label_idx));
729                    }
730
731                    scores.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap());
732
733                    for (i, &(_, label_idx)) in scores.iter().take(*k).enumerate() {
734                        predictions[[sample_idx, label_idx]] = 1;
735                    }
736                }
737            }
738            _ => {
739                // For fixed or optimized thresholds
740                for label_idx in 0..self.state.n_labels {
741                    let threshold = self.state.thresholds[label_idx];
742
743                    for sample_idx in 0..n_samples {
744                        let x_sample = x.row(sample_idx);
745                        let x_normalized = (&x_sample.to_owned() - &self.state.feature_means)
746                            / &self.state.feature_stds;
747                        let score = x_normalized.dot(&self.state.models[label_idx].weights)
748                            + self.state.models[label_idx].bias;
749
750                        predictions[[sample_idx, label_idx]] =
751                            if score >= threshold { 1 } else { 0 };
752                    }
753                }
754            }
755        }
756
757        Ok(predictions)
758    }
759}
760
761impl RankSVM<RankSVMTrained> {
762    /// Get the number of labels
763    pub fn n_labels(&self) -> usize {
764        self.state.n_labels
765    }
766
767    /// Get decision function values (ranking scores)
768    pub fn decision_function(&self, x: &ArrayView2<'_, Float>) -> SklResult<Array2<Float>> {
769        let (n_samples, n_features) = x.dim();
770        let expected_features = self.state.feature_means.len();
771
772        if n_features != expected_features {
773            return Err(SklearsError::InvalidInput(format!(
774                "Number of features in X ({}) does not match training data ({})",
775                n_features, expected_features
776            )));
777        }
778
779        let mut decision_values = Array2::<Float>::zeros((n_samples, self.state.n_labels));
780
781        for sample_idx in 0..n_samples {
782            for label_idx in 0..self.state.n_labels {
783                let x_sample = x.row(sample_idx);
784                let x_normalized =
785                    (&x_sample.to_owned() - &self.state.feature_means) / &self.state.feature_stds;
786                let score = x_normalized.dot(&self.state.models[label_idx].weights)
787                    + self.state.models[label_idx].bias;
788                decision_values[[sample_idx, label_idx]] = score;
789            }
790        }
791
792        Ok(decision_values)
793    }
794
795    /// Get ranking predictions (label indices ordered by relevance)
796    pub fn predict_ranking(&self, x: &ArrayView2<'_, Float>) -> SklResult<Array2<usize>> {
797        let (n_samples, n_features) = x.dim();
798        let expected_features = self.state.feature_means.len();
799
800        if n_features != expected_features {
801            return Err(SklearsError::InvalidInput(format!(
802                "Number of features in X ({}) does not match training data ({})",
803                n_features, expected_features
804            )));
805        }
806
807        let mut rankings = Array2::<usize>::zeros((n_samples, self.state.n_labels));
808
809        for sample_idx in 0..n_samples {
810            let mut scores = Vec::new();
811            for label_idx in 0..self.state.n_labels {
812                let x_sample = x.row(sample_idx);
813                let x_normalized =
814                    (&x_sample.to_owned() - &self.state.feature_means) / &self.state.feature_stds;
815                let score = x_normalized.dot(&self.state.models[label_idx].weights)
816                    + self.state.models[label_idx].bias;
817                scores.push((score, label_idx));
818            }
819
820            // Sort by score descending
821            scores.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap());
822
823            // Assign rankings
824            for (rank, &(_score, label_idx)) in scores.iter().enumerate() {
825                rankings[[sample_idx, rank]] = label_idx;
826            }
827        }
828
829        Ok(rankings)
830    }
831
832    /// Get the thresholds used for prediction
833    pub fn thresholds(&self) -> &Vec<Float> {
834        &self.state.thresholds
835    }
836}
837
838/// Multi-Output Support Vector Machine
839///
840/// A multi-output support vector machine that handles multiple regression or
841/// classification targets simultaneously by training separate SVM models for each output.
842#[derive(Debug, Clone)]
843pub struct MultiOutputSVM<S = Untrained> {
844    state: S,
845    kernel: SVMKernel,
846    c: Float,
847    epsilon: Float,
848    gamma: Option<Float>,
849}
850
851/// SVM Kernel types
852#[derive(Debug, Clone, Copy, PartialEq)]
853pub enum SVMKernel {
854    /// Linear kernel: K(x, y) = x^T y
855    Linear,
856    /// Polynomial kernel: K(x, y) = (gamma * x^T y + coef0)^degree
857    Polynomial {
858        degree: i32,
859        gamma: Float,
860        coef0: Float,
861    },
862    /// Radial Basis Function kernel: K(x, y) = exp(-gamma * ||x - y||^2)
863    Rbf { gamma: Float },
864    /// Sigmoid kernel: K(x, y) = tanh(gamma * x^T y + coef0)
865    Sigmoid { gamma: Float, coef0: Float },
866}
867
868/// Trained state for MultiOutputSVM
869#[derive(Debug, Clone)]
870pub struct MultiOutputSVMTrained {
871    models: Vec<SVMModel>,
872    n_outputs: usize,
873    feature_means: Array1<Float>,
874    feature_stds: Array1<Float>,
875}
876
877/// Single SVM model for one output
878#[derive(Debug, Clone)]
879pub struct SVMModel {
880    support_vectors: Array2<Float>,
881    support_coefficients: Array1<Float>,
882    bias: Float,
883    kernel: SVMKernel,
884}
885
886impl MultiOutputSVM<Untrained> {
887    /// Create a new MultiOutputSVM instance
888    pub fn new() -> Self {
889        Self {
890            state: Untrained,
891            kernel: SVMKernel::Rbf { gamma: 1.0 },
892            c: 1.0,
893            epsilon: 1e-3,
894            gamma: None,
895        }
896    }
897
898    /// Set the kernel
899    pub fn kernel(mut self, kernel: SVMKernel) -> Self {
900        self.kernel = kernel;
901        self
902    }
903
904    /// Set regularization parameter
905    pub fn c(mut self, c: Float) -> Self {
906        self.c = c;
907        self
908    }
909
910    /// Set tolerance for stopping criterion
911    pub fn epsilon(mut self, epsilon: Float) -> Self {
912        self.epsilon = epsilon;
913        self
914    }
915
916    /// Set gamma parameter (will override kernel-specific gamma)
917    pub fn gamma(mut self, gamma: Float) -> Self {
918        self.gamma = Some(gamma);
919        self
920    }
921}
922
923impl Default for MultiOutputSVM<Untrained> {
924    fn default() -> Self {
925        Self::new()
926    }
927}
928
929impl Estimator for MultiOutputSVM<Untrained> {
930    type Config = ();
931    type Error = SklearsError;
932    type Float = Float;
933
934    fn config(&self) -> &Self::Config {
935        &()
936    }
937}
938
939impl Fit<ArrayView2<'_, Float>, ArrayView2<'_, Float>> for MultiOutputSVM<Untrained> {
940    type Fitted = MultiOutputSVM<MultiOutputSVMTrained>;
941
942    fn fit(self, x: &ArrayView2<'_, Float>, y: &ArrayView2<'_, Float>) -> SklResult<Self::Fitted> {
943        let (n_samples, n_features) = x.dim();
944        let (y_samples, n_outputs) = y.dim();
945
946        if n_samples != y_samples {
947            return Err(SklearsError::InvalidInput(
948                "Number of samples in X and y must match".to_string(),
949            ));
950        }
951
952        // Compute feature statistics
953        let feature_means = x.mean_axis(Axis(0)).unwrap();
954        let feature_stds = x.mapv(|val| val * val).mean_axis(Axis(0)).unwrap()
955            - &feature_means.mapv(|mean| mean * mean);
956        let feature_stds = feature_stds.mapv(|var| (var.max(1e-10)).sqrt());
957
958        // Update kernel gamma if specified
959        let kernel = if let Some(gamma) = self.gamma {
960            match self.kernel {
961                SVMKernel::Rbf { .. } => SVMKernel::Rbf { gamma },
962                SVMKernel::Polynomial { degree, coef0, .. } => SVMKernel::Polynomial {
963                    degree,
964                    gamma,
965                    coef0,
966                },
967                SVMKernel::Sigmoid { coef0, .. } => SVMKernel::Sigmoid { gamma, coef0 },
968                other => other,
969            }
970        } else {
971            self.kernel
972        };
973
974        // Train one SVM for each output
975        let mut models = Vec::new();
976        for output_idx in 0..n_outputs {
977            let y_output = y.column(output_idx);
978            let model =
979                self.train_single_svm(x, &y_output, &feature_means, &feature_stds, kernel)?;
980            models.push(model);
981        }
982
983        Ok(MultiOutputSVM {
984            state: MultiOutputSVMTrained {
985                models,
986                n_outputs,
987                feature_means,
988                feature_stds,
989            },
990            kernel,
991            c: self.c,
992            epsilon: self.epsilon,
993            gamma: self.gamma,
994        })
995    }
996}
997
998impl MultiOutputSVM<Untrained> {
999    fn train_single_svm(
1000        &self,
1001        x: &ArrayView2<'_, Float>,
1002        y: &ArrayView1<'_, Float>,
1003        feature_means: &Array1<Float>,
1004        feature_stds: &Array1<Float>,
1005        kernel: SVMKernel,
1006    ) -> SklResult<SVMModel> {
1007        let (n_samples, n_features) = x.dim();
1008
1009        // Normalize features
1010        let mut x_normalized = x.to_owned();
1011        for (i, mut row) in x_normalized.rows_mut().into_iter().enumerate() {
1012            row -= feature_means;
1013            row /= feature_stds;
1014        }
1015
1016        // For simplicity, we'll implement a basic SVM using all samples as support vectors
1017        // In a real implementation, you'd use SMO or other optimization algorithms
1018        let support_vectors = x_normalized.clone();
1019        let mut support_coefficients = Array1::<Float>::zeros(n_samples);
1020
1021        // Simple heuristic: coefficients proportional to target values
1022        let y_mean = y.mean().unwrap();
1023        for i in 0..n_samples {
1024            support_coefficients[i] = (y[i] - y_mean) / self.c;
1025        }
1026
1027        let bias = y_mean;
1028
1029        Ok(SVMModel {
1030            support_vectors,
1031            support_coefficients,
1032            bias,
1033            kernel,
1034        })
1035    }
1036}
1037
1038impl Predict<ArrayView2<'_, Float>, Array2<Float>> for MultiOutputSVM<MultiOutputSVMTrained> {
1039    fn predict(&self, x: &ArrayView2<'_, Float>) -> SklResult<Array2<Float>> {
1040        let (n_samples, _) = x.dim();
1041        let mut predictions = Array2::<Float>::zeros((n_samples, self.state.n_outputs));
1042
1043        // Normalize input features
1044        let mut x_normalized = x.to_owned();
1045        for (i, mut row) in x_normalized.rows_mut().into_iter().enumerate() {
1046            row -= &self.state.feature_means;
1047            row /= &self.state.feature_stds;
1048        }
1049
1050        for output_idx in 0..self.state.n_outputs {
1051            let model = &self.state.models[output_idx];
1052
1053            for sample_idx in 0..n_samples {
1054                let x_sample = x_normalized.row(sample_idx);
1055                let mut prediction = model.bias;
1056
1057                // Compute kernel sum
1058                for (sv_idx, support_vector) in model.support_vectors.rows().into_iter().enumerate()
1059                {
1060                    let kernel_value =
1061                        compute_kernel_value(&x_sample, &support_vector, model.kernel);
1062                    prediction += model.support_coefficients[sv_idx] * kernel_value;
1063                }
1064
1065                predictions[[sample_idx, output_idx]] = prediction;
1066            }
1067        }
1068
1069        Ok(predictions)
1070    }
1071}
1072
1073/// Compute kernel value between two vectors
1074fn compute_kernel_value(
1075    x1: &ArrayView1<Float>,
1076    x2: &ArrayView1<Float>,
1077    kernel: SVMKernel,
1078) -> Float {
1079    match kernel {
1080        SVMKernel::Linear => x1.dot(x2),
1081        SVMKernel::Polynomial {
1082            degree,
1083            gamma,
1084            coef0,
1085        } => (gamma * x1.dot(x2) + coef0).powi(degree),
1086        SVMKernel::Rbf { gamma } => {
1087            let dist_sq = x1
1088                .iter()
1089                .zip(x2.iter())
1090                .map(|(a, b)| (a - b).powi(2))
1091                .sum::<Float>();
1092            (-gamma * dist_sq).exp()
1093        }
1094        SVMKernel::Sigmoid { gamma, coef0 } => (gamma * x1.dot(x2) + coef0).tanh(),
1095    }
1096}