sklears_linear/loss_functions.rs

//! Pluggable Loss Functions for Linear Models
//!
//! This module implements various loss functions that can be used with the modular framework.
//! All loss functions implement the `LossFunction` trait for consistency and pluggability.
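//!
//! # Example
//!
//! A minimal usage sketch (the `sklears_linear::loss_functions` and
//! `sklears_linear::modular_framework` paths are assumed here; adjust them to the
//! actual crate layout):
//!
//! ```ignore
//! use scirs2_core::ndarray::Array;
//! use sklears_linear::loss_functions::LossFactory;
//! use sklears_linear::modular_framework::LossFunction;
//!
//! let loss = LossFactory::huber(1.0);
//! let y_true = Array::from_vec(vec![1.0, 2.0, 3.0]);
//! let y_pred = Array::from_vec(vec![1.1, 1.9, 3.5]);
//! let value = loss.loss(&y_true, &y_pred).unwrap();
//! let gradient = loss.loss_derivative(&y_true, &y_pred).unwrap();
//! ```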

use crate::modular_framework::LossFunction;
use scirs2_core::ndarray::Array1;
use sklears_core::{
    error::{Result, SklearsError},
    types::Float,
};

/// Squared error loss for regression (one-half mean squared error, so the gradient is the mean residual)
#[derive(Debug, Clone)]
pub struct SquaredLoss;

impl LossFunction for SquaredLoss {
    fn loss(&self, y_true: &Array1<Float>, y_pred: &Array1<Float>) -> Result<Float> {
        if y_true.len() != y_pred.len() {
            return Err(SklearsError::DimensionMismatch {
                expected: y_true.len(),
                actual: y_pred.len(),
            });
        }

        let diff = y_pred - y_true;
        let mse = diff.mapv(|x| x * x).sum() / (2.0 * y_true.len() as Float);
        Ok(mse)
    }

    fn loss_derivative(
        &self,
        y_true: &Array1<Float>,
        y_pred: &Array1<Float>,
    ) -> Result<Array1<Float>> {
        if y_true.len() != y_pred.len() {
            return Err(SklearsError::DimensionMismatch {
                expected: y_true.len(),
                actual: y_pred.len(),
            });
        }

        Ok((y_pred - y_true) / (y_true.len() as Float))
    }

    fn name(&self) -> &'static str {
        "SquaredLoss"
    }
}

/// Mean Absolute Error loss for robust regression
#[derive(Debug, Clone)]
pub struct AbsoluteLoss;

impl LossFunction for AbsoluteLoss {
    fn loss(&self, y_true: &Array1<Float>, y_pred: &Array1<Float>) -> Result<Float> {
        if y_true.len() != y_pred.len() {
            return Err(SklearsError::DimensionMismatch {
                expected: y_true.len(),
                actual: y_pred.len(),
            });
        }

        let mae = (y_pred - y_true).mapv(|x| x.abs()).sum() / (y_true.len() as Float);
        Ok(mae)
    }

    fn loss_derivative(
        &self,
        y_true: &Array1<Float>,
        y_pred: &Array1<Float>,
    ) -> Result<Array1<Float>> {
        if y_true.len() != y_pred.len() {
            return Err(SklearsError::DimensionMismatch {
                expected: y_true.len(),
                actual: y_pred.len(),
            });
        }

        // Subgradient of the absolute value: the sign of the residual (0 at exactly zero)
        let derivative = (y_pred - y_true).mapv(|x| {
            if x > 0.0 {
                1.0
            } else if x < 0.0 {
                -1.0
            } else {
                0.0
            }
        });
        Ok(derivative / (y_true.len() as Float))
    }

    fn name(&self) -> &'static str {
        "AbsoluteLoss"
    }
}

/// Huber loss for robust regression (quadratic for residuals within `delta`, linear beyond it)
#[derive(Debug, Clone)]
pub struct HuberLoss {
    /// Threshold parameter controlling the transition between squared and absolute loss
    pub delta: Float,
}

impl HuberLoss {
    /// Create a new Huber loss with the specified delta parameter
    pub fn new(delta: Float) -> Self {
        Self { delta }
    }
}

impl Default for HuberLoss {
    fn default() -> Self {
        Self::new(1.0)
    }
}

impl LossFunction for HuberLoss {
    fn loss(&self, y_true: &Array1<Float>, y_pred: &Array1<Float>) -> Result<Float> {
        if y_true.len() != y_pred.len() {
            return Err(SklearsError::DimensionMismatch {
                expected: y_true.len(),
                actual: y_pred.len(),
            });
        }

        let residuals = y_pred - y_true;
        let loss_sum = residuals
            .mapv(|r| {
                let abs_r = r.abs();
                if abs_r <= self.delta {
                    0.5 * r * r
                } else {
                    self.delta * abs_r - 0.5 * self.delta * self.delta
                }
            })
            .sum();

        Ok(loss_sum / (y_true.len() as Float))
    }

    fn loss_derivative(
        &self,
        y_true: &Array1<Float>,
        y_pred: &Array1<Float>,
    ) -> Result<Array1<Float>> {
        if y_true.len() != y_pred.len() {
            return Err(SklearsError::DimensionMismatch {
                expected: y_true.len(),
                actual: y_pred.len(),
            });
        }

        let residuals = y_pred - y_true;
        let derivative = residuals.mapv(|r| {
            if r.abs() <= self.delta {
                r
            } else if r > 0.0 {
                self.delta
            } else {
                -self.delta
            }
        });

        Ok(derivative / (y_true.len() as Float))
    }

    fn name(&self) -> &'static str {
        "HuberLoss"
    }
}

/// Logistic loss for binary classification
#[derive(Debug, Clone)]
pub struct LogisticLoss;

impl LossFunction for LogisticLoss {
    fn loss(&self, y_true: &Array1<Float>, y_pred: &Array1<Float>) -> Result<Float> {
        if y_true.len() != y_pred.len() {
            return Err(SklearsError::DimensionMismatch {
                expected: y_true.len(),
                actual: y_pred.len(),
            });
        }

        // y_pred are logits, y_true should be in {-1, 1} or {0, 1}
        let loss_sum = y_true
            .iter()
            .zip(y_pred.iter())
            .map(|(&y, &pred)| {
                // Convert y_true to {-1, 1} format if it's in {0, 1}
                let y_adjusted = if y == 0.0 { -1.0 } else { y };

                // Numerically stable computation: log(1 + exp(-y * pred))
                let margin = y_adjusted * pred;
                if margin > 0.0 {
                    (1.0 + (-margin).exp()).ln()
                } else {
                    -margin + (1.0 + margin.exp()).ln()
                }
            })
            .sum::<Float>();

        Ok(loss_sum / (y_true.len() as Float))
    }

    fn loss_derivative(
        &self,
        y_true: &Array1<Float>,
        y_pred: &Array1<Float>,
    ) -> Result<Array1<Float>> {
        if y_true.len() != y_pred.len() {
            return Err(SklearsError::DimensionMismatch {
                expected: y_true.len(),
                actual: y_pred.len(),
            });
        }

        let derivative = y_true
            .iter()
            .zip(y_pred.iter())
            .map(|(&y, &pred)| {
                // Convert y_true to {-1, 1} format if it's in {0, 1}
                let y_adjusted = if y == 0.0 { -1.0 } else { y };

                // Derivative: -y / (1 + exp(y * pred))
                let margin = y_adjusted * pred;
                -y_adjusted / (1.0 + margin.exp())
            })
            .collect::<Vec<Float>>();

        let result = Array1::from_vec(derivative) / (y_true.len() as Float);
        Ok(result)
    }

    fn is_classification(&self) -> bool {
        true
    }

    fn name(&self) -> &'static str {
        "LogisticLoss"
    }
}

/// Hinge loss for Support Vector Machines
#[derive(Debug, Clone)]
pub struct HingeLoss;

impl LossFunction for HingeLoss {
    fn loss(&self, y_true: &Array1<Float>, y_pred: &Array1<Float>) -> Result<Float> {
        if y_true.len() != y_pred.len() {
            return Err(SklearsError::DimensionMismatch {
                expected: y_true.len(),
                actual: y_pred.len(),
            });
        }

        // y_true should be in {-1, 1}, y_pred are decision function values
        let loss_sum = y_true
            .iter()
            .zip(y_pred.iter())
            .map(|(&y, &pred)| {
                let y_adjusted = if y == 0.0 { -1.0 } else { y };
                let margin = y_adjusted * pred;
                (1.0 - margin).max(0.0)
            })
            .sum::<Float>();

        Ok(loss_sum / (y_true.len() as Float))
    }

    fn loss_derivative(
        &self,
        y_true: &Array1<Float>,
        y_pred: &Array1<Float>,
    ) -> Result<Array1<Float>> {
        if y_true.len() != y_pred.len() {
            return Err(SklearsError::DimensionMismatch {
                expected: y_true.len(),
                actual: y_pred.len(),
            });
        }

        let derivative = y_true
            .iter()
            .zip(y_pred.iter())
            .map(|(&y, &pred)| {
                let y_adjusted = if y == 0.0 { -1.0 } else { y };
                let margin = y_adjusted * pred;
                if margin < 1.0 {
                    -y_adjusted
                } else {
                    0.0
                }
            })
            .collect::<Vec<Float>>();

        let result = Array1::from_vec(derivative) / (y_true.len() as Float);
        Ok(result)
    }

    fn is_classification(&self) -> bool {
        true
    }

    fn name(&self) -> &'static str {
        "HingeLoss"
    }
}

/// Squared Hinge loss (smooth variant of hinge loss)
#[derive(Debug, Clone)]
pub struct SquaredHingeLoss;

impl LossFunction for SquaredHingeLoss {
    fn loss(&self, y_true: &Array1<Float>, y_pred: &Array1<Float>) -> Result<Float> {
        if y_true.len() != y_pred.len() {
            return Err(SklearsError::DimensionMismatch {
                expected: y_true.len(),
                actual: y_pred.len(),
            });
        }

        let loss_sum = y_true
            .iter()
            .zip(y_pred.iter())
            .map(|(&y, &pred)| {
                let y_adjusted = if y == 0.0 { -1.0 } else { y };
                let margin = y_adjusted * pred;
                let hinge = (1.0 - margin).max(0.0);
                hinge * hinge
            })
            .sum::<Float>();

        Ok(loss_sum / (y_true.len() as Float))
    }

    fn loss_derivative(
        &self,
        y_true: &Array1<Float>,
        y_pred: &Array1<Float>,
    ) -> Result<Array1<Float>> {
        if y_true.len() != y_pred.len() {
            return Err(SklearsError::DimensionMismatch {
                expected: y_true.len(),
                actual: y_pred.len(),
            });
        }

        let derivative = y_true
            .iter()
            .zip(y_pred.iter())
            .map(|(&y, &pred)| {
                let y_adjusted = if y == 0.0 { -1.0 } else { y };
                let margin = y_adjusted * pred;
                if margin < 1.0 {
                    -2.0 * y_adjusted * (1.0 - margin)
                } else {
                    0.0
                }
            })
            .collect::<Vec<Float>>();

        let result = Array1::from_vec(derivative) / (y_true.len() as Float);
        Ok(result)
    }

    fn is_classification(&self) -> bool {
        true
    }

    fn name(&self) -> &'static str {
        "SquaredHingeLoss"
    }
}

/// Quantile loss for quantile regression
#[derive(Debug, Clone)]
pub struct QuantileLoss {
    /// Quantile parameter (strictly between 0 and 1)
    pub quantile: Float,
}

impl QuantileLoss {
    /// Create a new quantile loss with the specified quantile
    pub fn new(quantile: Float) -> Result<Self> {
        if quantile <= 0.0 || quantile >= 1.0 {
            return Err(SklearsError::InvalidParameter {
                name: "quantile".to_string(),
                reason: format!("Quantile must be between 0 and 1, got {}", quantile),
            });
        }
        Ok(Self { quantile })
    }

    /// Create median regression (quantile = 0.5)
    pub fn median() -> Self {
        Self { quantile: 0.5 }
    }
}

impl LossFunction for QuantileLoss {
    fn loss(&self, y_true: &Array1<Float>, y_pred: &Array1<Float>) -> Result<Float> {
        if y_true.len() != y_pred.len() {
            return Err(SklearsError::DimensionMismatch {
                expected: y_true.len(),
                actual: y_pred.len(),
            });
        }

        let residuals = y_true - y_pred;
        let loss_sum = residuals
            .mapv(|r| {
                if r >= 0.0 {
                    self.quantile * r
                } else {
                    (self.quantile - 1.0) * r
                }
            })
            .sum();

        Ok(loss_sum / (y_true.len() as Float))
    }

    fn loss_derivative(
        &self,
        y_true: &Array1<Float>,
        y_pred: &Array1<Float>,
    ) -> Result<Array1<Float>> {
        if y_true.len() != y_pred.len() {
            return Err(SklearsError::DimensionMismatch {
                expected: y_true.len(),
                actual: y_pred.len(),
            });
        }

        let residuals = y_true - y_pred;
        let derivative = residuals.mapv(|r| {
            if r > 0.0 {
                -self.quantile
            } else if r < 0.0 {
                -(self.quantile - 1.0)
            } else {
                0.0 // Subgradient at 0
            }
        });

        Ok(derivative / (y_true.len() as Float))
    }

    fn name(&self) -> &'static str {
        "QuantileLoss"
    }
}

/// Epsilon-insensitive loss for Support Vector Regression
#[derive(Debug, Clone)]
pub struct EpsilonInsensitiveLoss {
    /// Epsilon parameter (absolute errors within this tolerance incur no loss)
    pub epsilon: Float,
}

impl EpsilonInsensitiveLoss {
    /// Create a new epsilon-insensitive loss
    pub fn new(epsilon: Float) -> Result<Self> {
        if epsilon < 0.0 {
            return Err(SklearsError::InvalidParameter {
                name: "epsilon".to_string(),
                reason: format!("Epsilon must be non-negative, got {}", epsilon),
            });
        }
        Ok(Self { epsilon })
    }
}

impl LossFunction for EpsilonInsensitiveLoss {
    fn loss(&self, y_true: &Array1<Float>, y_pred: &Array1<Float>) -> Result<Float> {
        if y_true.len() != y_pred.len() {
            return Err(SklearsError::DimensionMismatch {
                expected: y_true.len(),
                actual: y_pred.len(),
            });
        }

        let loss_sum = (y_true - y_pred)
            .mapv(|r| (r.abs() - self.epsilon).max(0.0))
            .sum();

        Ok(loss_sum / (y_true.len() as Float))
    }

    fn loss_derivative(
        &self,
        y_true: &Array1<Float>,
        y_pred: &Array1<Float>,
    ) -> Result<Array1<Float>> {
        if y_true.len() != y_pred.len() {
            return Err(SklearsError::DimensionMismatch {
                expected: y_true.len(),
                actual: y_pred.len(),
            });
        }

        let residuals = y_true - y_pred;
        let derivative = residuals.mapv(|r| {
            if r > self.epsilon {
                -1.0
            } else if r < -self.epsilon {
                1.0
            } else {
                0.0
            }
        });

        Ok(derivative / (y_true.len() as Float))
    }

    fn name(&self) -> &'static str {
        "EpsilonInsensitiveLoss"
    }
}

/// Factory for creating common loss functions
pub struct LossFactory;

impl LossFactory {
    /// Create a squared loss (MSE)
    pub fn squared() -> Box<dyn LossFunction> {
        Box::new(SquaredLoss)
    }

    /// Create an absolute loss (MAE)
    pub fn absolute() -> Box<dyn LossFunction> {
        Box::new(AbsoluteLoss)
    }

    /// Create a Huber loss with specified delta
    pub fn huber(delta: Float) -> Box<dyn LossFunction> {
        Box::new(HuberLoss::new(delta))
    }

    /// Create a logistic loss for binary classification
    pub fn logistic() -> Box<dyn LossFunction> {
        Box::new(LogisticLoss)
    }

    /// Create a hinge loss for SVM
    pub fn hinge() -> Box<dyn LossFunction> {
        Box::new(HingeLoss)
    }

    /// Create a squared hinge loss
    pub fn squared_hinge() -> Box<dyn LossFunction> {
        Box::new(SquaredHingeLoss)
    }

    /// Create a quantile loss
    pub fn quantile(quantile: Float) -> Result<Box<dyn LossFunction>> {
        Ok(Box::new(QuantileLoss::new(quantile)?))
    }

    /// Create an epsilon-insensitive loss for SVR
    pub fn epsilon_insensitive(epsilon: Float) -> Result<Box<dyn LossFunction>> {
        Ok(Box::new(EpsilonInsensitiveLoss::new(epsilon)?))
    }
}

#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array;

    #[test]
    fn test_squared_loss() {
        let loss = SquaredLoss;
        let y_true = Array::from_vec(vec![1.0, 2.0, 3.0]);
        let y_pred = Array::from_vec(vec![1.1, 1.9, 3.1]);

        let loss_value = loss.loss(&y_true, &y_pred).unwrap();
        let expected = ((0.1 * 0.1) + (0.1 * 0.1) + (0.1 * 0.1)) / (2.0 * 3.0);
        assert!((loss_value - expected).abs() < 1e-10);

        let derivative = loss.loss_derivative(&y_true, &y_pred).unwrap();
        let expected_grad = Array::from_vec(vec![0.1, -0.1, 0.1]) / 3.0;
        for (actual, expected) in derivative.iter().zip(expected_grad.iter()) {
            assert!((actual - expected).abs() < 1e-10);
        }
    }

    #[test]
    fn test_absolute_loss() {
        let loss = AbsoluteLoss;
        let y_true = Array::from_vec(vec![1.0, 2.0, 3.0]);
        let y_pred = Array::from_vec(vec![1.2, 1.8, 3.1]);

        let loss_value = loss.loss(&y_true, &y_pred).unwrap();
        let expected = (0.2 + 0.2 + 0.1) / 3.0;
        assert!((loss_value - expected).abs() < 1e-10);
    }

    #[test]
    fn test_huber_loss() {
        let loss = HuberLoss::new(1.0);
        let y_true = Array::from_vec(vec![0.0, 0.0]);
        let y_pred = Array::from_vec(vec![0.5, 2.0]); // First within delta, second outside

        let loss_value = loss.loss(&y_true, &y_pred).unwrap();
        // First: 0.5 * 0.5^2 = 0.125
        // Second: 1.0 * 2.0 - 0.5 * 1.0^2 = 1.5
        let expected = (0.125 + 1.5) / 2.0;
        assert!((loss_value - expected).abs() < 1e-10);
    }

    #[test]
    fn test_logistic_loss() {
        let loss = LogisticLoss;
        let y_true = Array::from_vec(vec![1.0, -1.0]);
        let y_pred = Array::from_vec(vec![2.0, -2.0]); // Strong correct predictions

        let loss_value = loss.loss(&y_true, &y_pred).unwrap();
        // Should be small for correct predictions
        assert!(loss_value < 0.5);
    }
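
    // Additional illustrative checks (sketches added for coverage of the margin-based
    // losses; expected values are worked out by hand from the formulas above).
    #[test]
    fn test_hinge_loss() {
        let loss = HingeLoss;
        let y_true = Array::from_vec(vec![1.0, -1.0]);
        let y_pred = Array::from_vec(vec![0.5, 0.5]);

        let loss_value = loss.loss(&y_true, &y_pred).unwrap();
        // Margins: 0.5 and -0.5 -> hinge terms 0.5 and 1.5
        let expected = (0.5 + 1.5) / 2.0;
        assert!((loss_value - expected).abs() < 1e-10);

        let derivative = loss.loss_derivative(&y_true, &y_pred).unwrap();
        // Both margins are below 1, so the subgradient is -y / n
        assert!((derivative[0] - (-0.5)).abs() < 1e-10);
        assert!((derivative[1] - 0.5).abs() < 1e-10);
    }

    #[test]
    fn test_squared_hinge_loss() {
        let loss = SquaredHingeLoss;
        let y_true = Array::from_vec(vec![1.0]);
        let y_pred = Array::from_vec(vec![0.0]);

        // Margin 0 -> hinge term 1, squared 1; gradient -2 * y * (1 - margin) = -2
        let loss_value = loss.loss(&y_true, &y_pred).unwrap();
        assert!((loss_value - 1.0).abs() < 1e-10);

        let derivative = loss.loss_derivative(&y_true, &y_pred).unwrap();
        assert!((derivative[0] - (-2.0)).abs() < 1e-10);
    }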

    #[test]
    fn test_quantile_loss() {
        let loss = QuantileLoss::new(0.7).unwrap();
        let y_true = Array::from_vec(vec![1.0, 2.0]);
        let y_pred = Array::from_vec(vec![0.5, 2.5]); // Under-predict, over-predict

        let loss_value = loss.loss(&y_true, &y_pred).unwrap();
        // Under-prediction: 0.7 * 0.5 = 0.35
        // Over-prediction: (0.7 - 1.0) * (-0.5) = 0.15
        let expected = (0.35 + 0.15) / 2.0;
        assert!((loss_value - expected).abs() < 1e-10);
    }
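
    #[test]
    fn test_epsilon_insensitive_loss() {
        // Illustrative sketch: residuals inside the epsilon tube contribute nothing.
        let loss = EpsilonInsensitiveLoss::new(0.5).unwrap();
        let y_true = Array::from_vec(vec![1.0, 2.0, 3.0]);
        let y_pred = Array::from_vec(vec![1.2, 2.8, 3.0]);

        let loss_value = loss.loss(&y_true, &y_pred).unwrap();
        // |residuals| = 0.2, 0.8, 0.0 -> contributions 0.0, 0.3, 0.0
        let expected = 0.3 / 3.0;
        assert!((loss_value - expected).abs() < 1e-10);
    }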

    #[test]
    fn test_loss_factory() {
        let squared = LossFactory::squared();
        assert_eq!(squared.name(), "SquaredLoss");

        let huber = LossFactory::huber(1.5);
        assert_eq!(huber.name(), "HuberLoss");

        let quantile = LossFactory::quantile(0.8).unwrap();
        assert_eq!(quantile.name(), "QuantileLoss");
    }
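
    #[test]
    fn test_invalid_parameters_are_rejected() {
        // Illustrative sketch of the constructor validation paths: the quantile must be
        // strictly inside (0, 1) and epsilon must be non-negative.
        assert!(QuantileLoss::new(0.0).is_err());
        assert!(QuantileLoss::new(1.0).is_err());
        assert!(QuantileLoss::new(0.3).is_ok());
        assert!(EpsilonInsensitiveLoss::new(-0.1).is_err());
        assert!(EpsilonInsensitiveLoss::new(0.0).is_ok());
    }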

    #[test]
    fn test_dimension_mismatch() {
        let loss = SquaredLoss;
        let y_true = Array::from_vec(vec![1.0, 2.0]);
        let y_pred = Array::from_vec(vec![1.0, 2.0, 3.0]);

        let result = loss.loss(&y_true, &y_pred);
        assert!(result.is_err());
    }
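
    #[test]
    fn test_huber_gradient_matches_finite_difference() {
        // Illustrative numerical sketch: the analytic derivative should agree with a
        // central finite difference of the loss at points away from the |r| = delta kink.
        let loss = HuberLoss::new(1.0);
        let y_true = Array::from_vec(vec![0.5, -1.0, 2.0]);
        let y_pred = Array::from_vec(vec![0.0, 1.5, 2.2]);
        let grad = loss.loss_derivative(&y_true, &y_pred).unwrap();

        let eps = 1e-6;
        for i in 0..y_pred.len() {
            let mut plus = y_pred.clone();
            plus[i] += eps;
            let mut minus = y_pred.clone();
            minus[i] -= eps;
            let numeric = (loss.loss(&y_true, &plus).unwrap()
                - loss.loss(&y_true, &minus).unwrap())
                / (2.0 * eps);
            assert!((grad[i] - numeric).abs() < 1e-5);
        }
    }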
}