1use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2, Axis};
5use scirs2_core::random::thread_rng;
6use sklears_core::{
7 error::{Result as SklResult, SklearsError},
8 traits::{Estimator, Fit, Predict, Untrained},
9 types::Float,
10};
11
12#[derive(Debug, Clone)]
17pub struct IndependentLabelPrediction<S = Untrained> {
18 state: S,
19 threshold_strategy: ThresholdStrategy,
20 optimize_thresholds: bool,
21 class_weight: Option<String>, random_state: Option<u64>,
23}
24
25#[derive(Debug, Clone)]
27pub enum ThresholdStrategy {
28 Fixed(Float), PerLabel(Vec<Float>), Optimal, FScore, }
37
/// Fitted state for [`IndependentLabelPrediction`]: one binary model and one
/// decision threshold per label column seen during training.
#[derive(Debug, Clone)]
pub struct IndependentLabelPredictionTrained {
    // One trained logistic model per label, in label-column order.
    binary_classifiers: Vec<BinaryClassifierModel>,
    // Decision threshold applied to each label's probability.
    thresholds: Vec<Float>,
    // Number of label columns seen at fit time.
    n_labels: usize,
}
45
/// Parameters of a single logistic-regression label model, together with the
/// feature-standardization statistics captured at training time.
#[derive(Debug, Clone)]
pub struct BinaryClassifierModel {
    // Coefficients learned in standardized feature space.
    weights: Array1<Float>,
    // Intercept term.
    bias: Float,
    // Per-feature means used to center inputs before scoring.
    feature_means: Array1<Float>,
    // Per-feature standard deviations used to scale inputs before scoring.
    feature_stds: Array1<Float>,
}
54
55impl IndependentLabelPrediction<Untrained> {
56 pub fn new() -> Self {
58 Self {
59 state: Untrained,
60 threshold_strategy: ThresholdStrategy::Fixed(0.5),
61 optimize_thresholds: false,
62 class_weight: None,
63 random_state: None,
64 }
65 }
66
67 pub fn threshold_strategy(mut self, strategy: ThresholdStrategy) -> Self {
69 self.threshold_strategy = strategy;
70 self
71 }
72
73 pub fn optimize_thresholds(mut self, optimize: bool) -> Self {
75 self.optimize_thresholds = optimize;
76 self
77 }
78
79 pub fn class_weight(mut self, weight: Option<String>) -> Self {
81 self.class_weight = weight;
82 self
83 }
84
85 pub fn random_state(mut self, random_state: Option<u64>) -> Self {
87 self.random_state = random_state;
88 self
89 }
90}
91
92impl Default for IndependentLabelPrediction<Untrained> {
93 fn default() -> Self {
94 Self::new()
95 }
96}
97
impl Estimator for IndependentLabelPrediction<Untrained> {
    // This estimator carries no external configuration object.
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    /// Returns the (empty) configuration. `&()` is sound here because a
    /// reference to a unit literal is promoted to `'static`.
    fn config(&self) -> &Self::Config {
        &()
    }
}
107
108impl Fit<ArrayView2<'_, Float>, ArrayView2<'_, i32>> for IndependentLabelPrediction<Untrained> {
109 type Fitted = IndependentLabelPrediction<IndependentLabelPredictionTrained>;
110
111 fn fit(self, x: &ArrayView2<'_, Float>, y: &ArrayView2<'_, i32>) -> SklResult<Self::Fitted> {
112 let (n_samples, n_features) = x.dim();
113 let (y_samples, n_labels) = y.dim();
114
115 if n_samples != y_samples {
116 return Err(SklearsError::InvalidInput(
117 "Number of samples in X and y must match".to_string(),
118 ));
119 }
120
121 if n_samples < 2 {
122 return Err(SklearsError::InvalidInput(
123 "Need at least 2 samples for training".to_string(),
124 ));
125 }
126
127 let mut rng = thread_rng();
129
130 let mut binary_classifiers = Vec::new();
132 for label_idx in 0..n_labels {
133 let label_column = y.column(label_idx);
134 let classifier = self.train_binary_classifier(x, &label_column, &mut rng)?;
135 binary_classifiers.push(classifier);
136 }
137
138 let thresholds = match &self.threshold_strategy {
140 ThresholdStrategy::Fixed(threshold) => vec![*threshold; n_labels],
141 ThresholdStrategy::PerLabel(thresholds) => {
142 if thresholds.len() != n_labels {
143 return Err(SklearsError::InvalidInput(
144 "Number of thresholds must match number of labels".to_string(),
145 ));
146 }
147 thresholds.clone()
148 }
149 ThresholdStrategy::Optimal => {
150 self.optimize_thresholds_for_accuracy(x, y, &binary_classifiers)?
151 }
152 ThresholdStrategy::FScore => {
153 self.optimize_thresholds_for_fscore(x, y, &binary_classifiers)?
154 }
155 };
156
157 Ok(IndependentLabelPrediction {
158 state: IndependentLabelPredictionTrained {
159 binary_classifiers,
160 thresholds,
161 n_labels,
162 },
163 threshold_strategy: self.threshold_strategy,
164 optimize_thresholds: self.optimize_thresholds,
165 class_weight: self.class_weight,
166 random_state: self.random_state,
167 })
168 }
169}
170
171impl IndependentLabelPrediction<Untrained> {
172 fn train_binary_classifier(
173 &self,
174 x: &ArrayView2<'_, Float>,
175 y_label: &ArrayView1<'_, i32>,
176 rng: &mut scirs2_core::random::CoreRandom,
177 ) -> SklResult<BinaryClassifierModel> {
178 let (n_samples, n_features) = x.dim();
179
180 let feature_means = x.mean_axis(Axis(0)).unwrap();
182 let feature_stds = x.mapv(|val| val * val).mean_axis(Axis(0)).unwrap()
183 - &feature_means.mapv(|mean| mean * mean);
184 let feature_stds = feature_stds.mapv(|var| (var.max(1e-10)).sqrt());
185
186 let mut x_normalized = x.to_owned();
188 for (i, mut row) in x_normalized.rows_mut().into_iter().enumerate() {
189 row -= &feature_means;
190 row /= &feature_stds;
191 }
192
193 let class_weights = if self.class_weight.as_deref() == Some("balanced") {
195 let pos_count = y_label.iter().filter(|&&y| y == 1).count();
196 let neg_count = n_samples - pos_count;
197
198 if pos_count == 0 || neg_count == 0 {
199 (1.0, 1.0)
200 } else {
201 let pos_weight = n_samples as Float / (2.0 * pos_count as Float);
202 let neg_weight = n_samples as Float / (2.0 * neg_count as Float);
203 (neg_weight, pos_weight)
204 }
205 } else {
206 (1.0, 1.0)
207 };
208
209 let mut weights = Array1::<Float>::zeros(n_features);
211 let mut bias = 0.0;
212
213 let learning_rate = 0.01;
214 let max_iter = 1000;
215 let tolerance = 1e-6;
216
217 for iteration in 0..max_iter {
218 let mut weight_gradient = Array1::<Float>::zeros(n_features);
219 let mut bias_gradient = 0.0;
220 let mut total_loss = 0.0;
221
222 for sample_idx in 0..n_samples {
223 let x_sample = x_normalized.row(sample_idx);
224 let y_true = y_label[sample_idx] as Float;
225
226 let logits = x_sample.dot(&weights) + bias;
228 let prediction = 1.0 / (1.0 + (-logits).exp());
229
230 let sample_weight = if y_true > 0.5 {
232 class_weights.1
233 } else {
234 class_weights.0
235 };
236 let loss = -sample_weight
237 * (y_true * prediction.ln() + (1.0 - y_true) * (1.0 - prediction).ln());
238 total_loss += loss;
239
240 let error = sample_weight * (prediction - y_true);
242 weight_gradient += &(x_sample.to_owned() * error);
243 bias_gradient += error;
244 }
245
246 weights -= &(weight_gradient * (learning_rate / n_samples as Float));
248 bias -= bias_gradient * (learning_rate / n_samples as Float);
249
250 if iteration > 10 {
252 let avg_loss = total_loss / n_samples as Float;
253 if avg_loss < tolerance {
254 break;
255 }
256 }
257 }
258
259 Ok(BinaryClassifierModel {
260 weights,
261 bias,
262 feature_means,
263 feature_stds,
264 })
265 }
266
267 fn optimize_thresholds_for_accuracy(
268 &self,
269 x: &ArrayView2<'_, Float>,
270 y: &ArrayView2<'_, i32>,
271 classifiers: &[BinaryClassifierModel],
272 ) -> SklResult<Vec<Float>> {
273 let n_labels = y.ncols();
274 let mut thresholds = Vec::new();
275
276 for label_idx in 0..n_labels {
277 let y_true = y.column(label_idx);
278 let y_scores = self.predict_probabilities_single_label(x, &classifiers[label_idx])?;
279
280 let mut best_threshold = 0.5;
281 let mut best_accuracy = 0.0;
282
283 for threshold_int in 1..100 {
285 let threshold = threshold_int as Float / 100.0;
286
287 let mut correct = 0;
288 for sample_idx in 0..x.nrows() {
289 let predicted = if y_scores[sample_idx] >= threshold {
290 1
291 } else {
292 0
293 };
294 if predicted == y_true[sample_idx] {
295 correct += 1;
296 }
297 }
298
299 let accuracy = correct as Float / x.nrows() as Float;
300 if accuracy > best_accuracy {
301 best_accuracy = accuracy;
302 best_threshold = threshold;
303 }
304 }
305
306 thresholds.push(best_threshold);
307 }
308
309 Ok(thresholds)
310 }
311
312 fn optimize_thresholds_for_fscore(
313 &self,
314 x: &ArrayView2<'_, Float>,
315 y: &ArrayView2<'_, i32>,
316 classifiers: &[BinaryClassifierModel],
317 ) -> SklResult<Vec<Float>> {
318 let n_labels = y.ncols();
319 let mut thresholds = Vec::new();
320
321 for label_idx in 0..n_labels {
322 let y_true = y.column(label_idx);
323 let y_scores = self.predict_probabilities_single_label(x, &classifiers[label_idx])?;
324
325 let mut best_threshold = 0.5;
326 let mut best_fscore = 0.0;
327
328 for threshold_int in 1..100 {
330 let threshold = threshold_int as Float / 100.0;
331
332 let mut tp = 0;
333 let mut fp = 0;
334 let mut fn_count = 0;
335
336 for sample_idx in 0..x.nrows() {
337 let predicted = if y_scores[sample_idx] >= threshold {
338 1
339 } else {
340 0
341 };
342 let actual = y_true[sample_idx];
343
344 match (actual, predicted) {
345 (1, 1) => tp += 1,
346 (0, 1) => fp += 1,
347 (1, 0) => fn_count += 1,
348 _ => {}
349 }
350 }
351
352 let precision = if tp + fp > 0 {
353 tp as Float / (tp + fp) as Float
354 } else {
355 0.0
356 };
357 let recall = if tp + fn_count > 0 {
358 tp as Float / (tp + fn_count) as Float
359 } else {
360 0.0
361 };
362 let fscore = if precision + recall > 0.0 {
363 2.0 * precision * recall / (precision + recall)
364 } else {
365 0.0
366 };
367
368 if fscore > best_fscore {
369 best_fscore = fscore;
370 best_threshold = threshold;
371 }
372 }
373
374 thresholds.push(best_threshold);
375 }
376
377 Ok(thresholds)
378 }
379
380 fn predict_probabilities_single_label(
381 &self,
382 x: &ArrayView2<'_, Float>,
383 classifier: &BinaryClassifierModel,
384 ) -> SklResult<Array1<Float>> {
385 let n_samples = x.nrows();
386 let mut probabilities = Array1::<Float>::zeros(n_samples);
387
388 for sample_idx in 0..n_samples {
389 let x_sample = x.row(sample_idx);
390
391 let x_normalized =
393 (&x_sample.to_owned() - &classifier.feature_means) / &classifier.feature_stds;
394
395 let logits = x_normalized.dot(&classifier.weights) + classifier.bias;
397 let probability = 1.0 / (1.0 + (-logits).exp());
398
399 probabilities[sample_idx] = probability;
400 }
401
402 Ok(probabilities)
403 }
404}
405
406impl Predict<ArrayView2<'_, Float>, Array2<i32>>
407 for IndependentLabelPrediction<IndependentLabelPredictionTrained>
408{
409 fn predict(&self, x: &ArrayView2<'_, Float>) -> SklResult<Array2<i32>> {
410 let (n_samples, n_features) = x.dim();
411 let mut predictions = Array2::<i32>::zeros((n_samples, self.state.n_labels));
412
413 for label_idx in 0..self.state.n_labels {
414 let classifier = &self.state.binary_classifiers[label_idx];
415 let threshold = self.state.thresholds[label_idx];
416
417 for sample_idx in 0..n_samples {
418 let x_sample = x.row(sample_idx);
419
420 let x_normalized =
422 (&x_sample.to_owned() - &classifier.feature_means) / &classifier.feature_stds;
423
424 let logits = x_normalized.dot(&classifier.weights) + classifier.bias;
426 let probability = 1.0 / (1.0 + (-logits).exp());
427
428 predictions[[sample_idx, label_idx]] = if probability >= threshold { 1 } else { 0 };
430 }
431 }
432
433 Ok(predictions)
434 }
435}
436
437impl IndependentLabelPrediction<IndependentLabelPredictionTrained> {
438 pub fn predict_proba(&self, x: &ArrayView2<'_, Float>) -> SklResult<Array2<Float>> {
440 let (n_samples, _) = x.dim();
441 let mut probabilities = Array2::<Float>::zeros((n_samples, self.state.n_labels));
442
443 for label_idx in 0..self.state.n_labels {
444 let classifier = &self.state.binary_classifiers[label_idx];
445
446 for sample_idx in 0..n_samples {
447 let x_sample = x.row(sample_idx);
448
449 let x_normalized =
451 (&x_sample.to_owned() - &classifier.feature_means) / &classifier.feature_stds;
452
453 let logits = x_normalized.dot(&classifier.weights) + classifier.bias;
455 let probability = 1.0 / (1.0 + (-logits).exp());
456
457 probabilities[[sample_idx, label_idx]] = probability;
458 }
459 }
460
461 Ok(probabilities)
462 }
463
464 pub fn thresholds(&self) -> &[Float] {
466 &self.state.thresholds
467 }
468
469 pub fn feature_importances(&self) -> Vec<Array1<Float>> {
471 self.state
472 .binary_classifiers
473 .iter()
474 .map(|classifier| classifier.weights.mapv(|w| w.abs()))
475 .collect()
476 }
477}