sklears_multioutput/ranking.rs

//! Label ranking and threshold optimization algorithms

// Use SciRS2-Core for arrays and random number generation (SciRS2 Policy)
use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2, Axis};
use scirs2_core::random::thread_rng;
use sklears_core::{
    error::{Result as SklResult, SklearsError},
    traits::{Estimator, Fit, Predict, Untrained},
    types::Float,
};

/// Independent Label Prediction with threshold optimization
///
/// This approach treats multi-label classification as a set of independent
/// binary classification problems and offers several threshold-optimization
/// strategies.
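///
/// # Examples
///
/// A minimal usage sketch (illustrative, not a doctest; assumes the `Fit` and
/// `Predict` traits are in scope and that `array!` is re-exported by
/// `scirs2_core::ndarray`):
///
/// ```ignore
/// use scirs2_core::ndarray::array;
///
/// let x = array![[0.0, 1.0], [1.0, 0.0], [1.0, 1.0], [0.0, 0.0]];
/// let y = array![[1, 0], [0, 1], [1, 1], [0, 0]];
///
/// let model = IndependentLabelPrediction::new()
///     .threshold_strategy(ThresholdStrategy::FScore)
///     .fit(&x.view(), &y.view())?;
/// let predictions = model.predict(&x.view())?;
/// assert_eq!(predictions.dim(), (4, 2));
/// ```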
#[derive(Debug, Clone)]
pub struct IndependentLabelPrediction<S = Untrained> {
    state: S,
    threshold_strategy: ThresholdStrategy,
    optimize_thresholds: bool,
    class_weight: Option<String>, // "balanced" or None
    random_state: Option<u64>,
}

/// Threshold strategy for label prediction
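///
/// For example, `ThresholdStrategy::PerLabel(vec![0.3, 0.7])` applies a
/// threshold of 0.3 to the first label and 0.7 to the second.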
#[derive(Debug, Clone)]
pub enum ThresholdStrategy {
    /// Use a fixed threshold for all labels
    Fixed(Float),
    /// Use a different threshold for each label
    PerLabel(Vec<Float>),
    /// Learn accuracy-optimal thresholds from the training data
    Optimal,
    /// Optimize the F-score threshold for each label
    FScore,
}

/// Trained state for Independent Label Prediction
#[derive(Debug, Clone)]
pub struct IndependentLabelPredictionTrained {
    binary_classifiers: Vec<BinaryClassifierModel>,
    thresholds: Vec<Float>,
    n_labels: usize,
}

/// Simple binary classifier model
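///
/// A logistic-regression model on z-scored features: inputs are standardized
/// as `x_norm = (x - feature_means) / feature_stds`, and the positive-class
/// probability is `sigmoid(weights . x_norm + bias)`.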
#[derive(Debug, Clone)]
pub struct BinaryClassifierModel {
    weights: Array1<Float>,
    bias: Float,
    feature_means: Array1<Float>,
    feature_stds: Array1<Float>,
}

impl IndependentLabelPrediction<Untrained> {
    /// Create a new IndependentLabelPrediction instance
    pub fn new() -> Self {
        Self {
            state: Untrained,
            threshold_strategy: ThresholdStrategy::Fixed(0.5),
            optimize_thresholds: false,
            class_weight: None,
            random_state: None,
        }
    }

    /// Set the threshold strategy
    pub fn threshold_strategy(mut self, strategy: ThresholdStrategy) -> Self {
        self.threshold_strategy = strategy;
        self
    }

    /// Set whether to optimize thresholds
    pub fn optimize_thresholds(mut self, optimize: bool) -> Self {
        self.optimize_thresholds = optimize;
        self
    }

    /// Set the class-weight strategy ("balanced" or None)
    pub fn class_weight(mut self, weight: Option<String>) -> Self {
        self.class_weight = weight;
        self
    }

    /// Set the random state
    pub fn random_state(mut self, random_state: Option<u64>) -> Self {
        self.random_state = random_state;
        self
    }
}

impl Default for IndependentLabelPrediction<Untrained> {
    fn default() -> Self {
        Self::new()
    }
}

impl Estimator for IndependentLabelPrediction<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}

impl Fit<ArrayView2<'_, Float>, ArrayView2<'_, i32>> for IndependentLabelPrediction<Untrained> {
    type Fitted = IndependentLabelPrediction<IndependentLabelPredictionTrained>;

    fn fit(self, x: &ArrayView2<'_, Float>, y: &ArrayView2<'_, i32>) -> SklResult<Self::Fitted> {
        let (n_samples, _n_features) = x.dim();
        let (y_samples, n_labels) = y.dim();

        if n_samples != y_samples {
            return Err(SklearsError::InvalidInput(
                "Number of samples in X and y must match".to_string(),
            ));
        }

        if n_samples < 2 {
            return Err(SklearsError::InvalidInput(
                "Need at least 2 samples for training".to_string(),
            ));
        }

        // Initialize the random number generator (note: `random_state` is not
        // yet wired into `thread_rng`, so fits are not seeded)
        let mut rng = thread_rng();

        // Train one binary classifier per label
        let mut binary_classifiers = Vec::new();
        for label_idx in 0..n_labels {
            let label_column = y.column(label_idx);
            let classifier = self.train_binary_classifier(x, &label_column, &mut rng)?;
            binary_classifiers.push(classifier);
        }

        // Determine thresholds
        let thresholds = match &self.threshold_strategy {
            ThresholdStrategy::Fixed(threshold) => vec![*threshold; n_labels],
            ThresholdStrategy::PerLabel(thresholds) => {
                if thresholds.len() != n_labels {
                    return Err(SklearsError::InvalidInput(
                        "Number of thresholds must match number of labels".to_string(),
                    ));
                }
                thresholds.clone()
            }
            ThresholdStrategy::Optimal => {
                self.optimize_thresholds_for_accuracy(x, y, &binary_classifiers)?
            }
            ThresholdStrategy::FScore => {
                self.optimize_thresholds_for_fscore(x, y, &binary_classifiers)?
            }
        };

        Ok(IndependentLabelPrediction {
            state: IndependentLabelPredictionTrained {
                binary_classifiers,
                thresholds,
                n_labels,
            },
            threshold_strategy: self.threshold_strategy,
            optimize_thresholds: self.optimize_thresholds,
            class_weight: self.class_weight,
            random_state: self.random_state,
        })
    }
}

impl IndependentLabelPrediction<Untrained> {
    /// Train a single logistic-regression classifier for one label column.
    fn train_binary_classifier(
        &self,
        x: &ArrayView2<'_, Float>,
        y_label: &ArrayView1<'_, i32>,
        _rng: &mut scirs2_core::random::CoreRandom,
    ) -> SklResult<BinaryClassifierModel> {
        let (n_samples, n_features) = x.dim();

        // Compute feature statistics for normalization:
        // Var[X] = E[X^2] - E[X]^2, floored at 1e-10 to avoid division by zero
        let feature_means = x.mean_axis(Axis(0)).unwrap();
        let feature_stds = x.mapv(|val| val * val).mean_axis(Axis(0)).unwrap()
            - &feature_means.mapv(|mean| mean * mean);
        let feature_stds = feature_stds.mapv(|var| (var.max(1e-10)).sqrt());

        // Normalize features (z-score)
        let mut x_normalized = x.to_owned();
        for mut row in x_normalized.rows_mut() {
            row -= &feature_means;
            row /= &feature_stds;
        }

        // Compute class weights if requested: "balanced" weighs each class
        // inversely to its frequency, n_samples / (2 * class_count)
        let class_weights = if self.class_weight.as_deref() == Some("balanced") {
            let pos_count = y_label.iter().filter(|&&y| y == 1).count();
            let neg_count = n_samples - pos_count;

            if pos_count == 0 || neg_count == 0 {
                (1.0, 1.0)
            } else {
                let pos_weight = n_samples as Float / (2.0 * pos_count as Float);
                let neg_weight = n_samples as Float / (2.0 * neg_count as Float);
                (neg_weight, pos_weight)
            }
        } else {
            (1.0, 1.0)
        };

        // Simple logistic regression trained with batch gradient descent
        let mut weights = Array1::<Float>::zeros(n_features);
        let mut bias = 0.0;

        let learning_rate = 0.01;
        let max_iter = 1000;
        let tolerance = 1e-6;
        let mut prev_loss = Float::INFINITY;

        for iteration in 0..max_iter {
            let mut weight_gradient = Array1::<Float>::zeros(n_features);
            let mut bias_gradient = 0.0;
            let mut total_loss = 0.0;

            for sample_idx in 0..n_samples {
                let x_sample = x_normalized.row(sample_idx);
                let y_true = y_label[sample_idx] as Float;

                // Forward pass: sigmoid of the linear score
                let logits = x_sample.dot(&weights) + bias;
                let prediction = 1.0 / (1.0 + (-logits).exp());

                // Weighted cross-entropy loss; clamp the prediction away from
                // 0 and 1 so the logarithms stay finite
                let sample_weight = if y_true > 0.5 {
                    class_weights.1
                } else {
                    class_weights.0
                };
                let p = prediction.clamp(Float::EPSILON, 1.0 - Float::EPSILON);
                let loss =
                    -sample_weight * (y_true * p.ln() + (1.0 - y_true) * (1.0 - p).ln());
                total_loss += loss;

                // Backward pass: gradient of the weighted cross-entropy
                let error = sample_weight * (prediction - y_true);
                weight_gradient += &(x_sample.to_owned() * error);
                bias_gradient += error;
            }

            // Gradient-descent update, averaged over samples
            weights -= &(weight_gradient * (learning_rate / n_samples as Float));
            bias -= bias_gradient * (learning_rate / n_samples as Float);

            // Check convergence: stop once the average loss no longer improves
            let avg_loss = total_loss / n_samples as Float;
            if iteration > 10 && (prev_loss - avg_loss).abs() < tolerance {
                break;
            }
            prev_loss = avg_loss;
        }

        Ok(BinaryClassifierModel {
            weights,
            bias,
            feature_means,
            feature_stds,
        })
    }

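    /// Grid-search each label's threshold over {0.01, 0.02, ..., 0.99},
    /// keeping the value that maximizes per-label accuracy on the training data.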
    fn optimize_thresholds_for_accuracy(
        &self,
        x: &ArrayView2<'_, Float>,
        y: &ArrayView2<'_, i32>,
        classifiers: &[BinaryClassifierModel],
    ) -> SklResult<Vec<Float>> {
        let n_labels = y.ncols();
        let mut thresholds = Vec::new();

        for label_idx in 0..n_labels {
            let y_true = y.column(label_idx);
            let y_scores = self.predict_probabilities_single_label(x, &classifiers[label_idx])?;

            let mut best_threshold = 0.5;
            let mut best_accuracy = 0.0;

            // Grid search over thresholds 0.01..=0.99 for the best accuracy
            for threshold_int in 1..100 {
                let threshold = threshold_int as Float / 100.0;

                let mut correct = 0;
                for sample_idx in 0..x.nrows() {
                    let predicted = if y_scores[sample_idx] >= threshold {
                        1
                    } else {
                        0
                    };
                    if predicted == y_true[sample_idx] {
                        correct += 1;
                    }
                }

                let accuracy = correct as Float / x.nrows() as Float;
                if accuracy > best_accuracy {
                    best_accuracy = accuracy;
                    best_threshold = threshold;
                }
            }

            thresholds.push(best_threshold);
        }

        Ok(thresholds)
    }

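    /// Grid-search each label's threshold, maximizing the per-label F1 score,
    /// F1 = 2 * precision * recall / (precision + recall), on the training data.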
    fn optimize_thresholds_for_fscore(
        &self,
        x: &ArrayView2<'_, Float>,
        y: &ArrayView2<'_, i32>,
        classifiers: &[BinaryClassifierModel],
    ) -> SklResult<Vec<Float>> {
        let n_labels = y.ncols();
        let mut thresholds = Vec::new();

        for label_idx in 0..n_labels {
            let y_true = y.column(label_idx);
            let y_scores = self.predict_probabilities_single_label(x, &classifiers[label_idx])?;

            let mut best_threshold = 0.5;
            let mut best_fscore = 0.0;

            // Grid search over thresholds 0.01..=0.99 for the best F-score
            for threshold_int in 1..100 {
                let threshold = threshold_int as Float / 100.0;

                let mut tp = 0;
                let mut fp = 0;
                let mut fn_count = 0;

                for sample_idx in 0..x.nrows() {
                    let predicted = if y_scores[sample_idx] >= threshold {
                        1
                    } else {
                        0
                    };
                    let actual = y_true[sample_idx];

                    match (actual, predicted) {
                        (1, 1) => tp += 1,
                        (0, 1) => fp += 1,
                        (1, 0) => fn_count += 1,
                        _ => {}
                    }
                }

                let precision = if tp + fp > 0 {
                    tp as Float / (tp + fp) as Float
                } else {
                    0.0
                };
                let recall = if tp + fn_count > 0 {
                    tp as Float / (tp + fn_count) as Float
                } else {
                    0.0
                };
                let fscore = if precision + recall > 0.0 {
                    2.0 * precision * recall / (precision + recall)
                } else {
                    0.0
                };

                if fscore > best_fscore {
                    best_fscore = fscore;
                    best_threshold = threshold;
                }
            }

            thresholds.push(best_threshold);
        }

        Ok(thresholds)
    }

    /// Compute positive-class probabilities for a single label's classifier.
    fn predict_probabilities_single_label(
        &self,
        x: &ArrayView2<'_, Float>,
        classifier: &BinaryClassifierModel,
    ) -> SklResult<Array1<Float>> {
        let n_samples = x.nrows();
        let mut probabilities = Array1::<Float>::zeros(n_samples);

        for sample_idx in 0..n_samples {
            let x_sample = x.row(sample_idx);

            // Normalize features with the stored training statistics
            let x_normalized =
                (&x_sample.to_owned() - &classifier.feature_means) / &classifier.feature_stds;

            // Compute logits and sigmoid probability
            let logits = x_normalized.dot(&classifier.weights) + classifier.bias;
            let probability = 1.0 / (1.0 + (-logits).exp());

            probabilities[sample_idx] = probability;
        }

        Ok(probabilities)
    }
}

impl Predict<ArrayView2<'_, Float>, Array2<i32>>
    for IndependentLabelPrediction<IndependentLabelPredictionTrained>
{
    fn predict(&self, x: &ArrayView2<'_, Float>) -> SklResult<Array2<i32>> {
        // Compute per-label probabilities, then apply the learned thresholds
        let probabilities = self.predict_proba(x)?;
        let n_samples = probabilities.nrows();
        let mut predictions = Array2::<i32>::zeros((n_samples, self.state.n_labels));

        for label_idx in 0..self.state.n_labels {
            let threshold = self.state.thresholds[label_idx];
            for sample_idx in 0..n_samples {
                predictions[[sample_idx, label_idx]] =
                    if probabilities[[sample_idx, label_idx]] >= threshold {
                        1
                    } else {
                        0
                    };
            }
        }

        Ok(predictions)
    }
}

impl IndependentLabelPrediction<IndependentLabelPredictionTrained> {
    /// Predict probabilities for each label
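    ///
    /// Returns an `(n_samples, n_labels)` matrix of positive-class
    /// probabilities; `predict` is this matrix thresholded with the learned
    /// per-label thresholds.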
    pub fn predict_proba(&self, x: &ArrayView2<'_, Float>) -> SklResult<Array2<Float>> {
        let (n_samples, _) = x.dim();
        let mut probabilities = Array2::<Float>::zeros((n_samples, self.state.n_labels));

        for label_idx in 0..self.state.n_labels {
            let classifier = &self.state.binary_classifiers[label_idx];

            for sample_idx in 0..n_samples {
                let x_sample = x.row(sample_idx);

                // Normalize features with the stored training statistics
                let x_normalized =
                    (&x_sample.to_owned() - &classifier.feature_means) / &classifier.feature_stds;

                // Compute sigmoid probability
                let logits = x_normalized.dot(&classifier.weights) + classifier.bias;
                let probability = 1.0 / (1.0 + (-logits).exp());

                probabilities[[sample_idx, label_idx]] = probability;
            }
        }

        Ok(probabilities)
    }

    /// Get the learned thresholds for each label
    pub fn thresholds(&self) -> &[Float] {
        &self.state.thresholds
    }

    /// Get the feature importance scores for each label
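    ///
    /// Scores are the absolute values of each label's logistic-regression
    /// coefficients, one array per label; they are comparable across features
    /// because features are standardized during training.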
    pub fn feature_importances(&self) -> Vec<Array1<Float>> {
        self.state
            .binary_classifiers
            .iter()
            .map(|classifier| classifier.weights.mapv(|w| w.abs()))
            .collect()
    }
}
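
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    // Illustrative smoke test (a sketch, assuming `array!` is re-exported by
    // `scirs2_core::ndarray`): checks output shape and binary values rather
    // than exact predictions.
    #[test]
    fn predicts_binary_labels_with_correct_shape() {
        let x = array![[0.0, 1.0], [1.0, 0.0], [1.0, 1.0], [0.0, 0.0]];
        let y = array![[1, 0], [0, 1], [1, 1], [0, 0]];

        let model = IndependentLabelPrediction::new()
            .fit(&x.view(), &y.view())
            .unwrap();
        let predictions = model.predict(&x.view()).unwrap();

        assert_eq!(predictions.dim(), (4, 2));
        assert!(predictions.iter().all(|&p| p == 0 || p == 1));
    }
}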