tensorlogic_train/label_smoothing.rs

//! Label smoothing regularization for improved generalization.
//!
//! Label smoothing is a regularization technique that prevents the model from becoming
//! overconfident by smoothing the target distribution.

use crate::{Loss, TrainError, TrainResult};
use scirs2_core::ndarray::{Array, ArrayView, Ix2};

/// Label smoothing cross-entropy loss.
///
/// Based on "Rethinking the Inception Architecture for Computer Vision" (Szegedy et al., 2016).
/// Replaces hard 0/1 labels with a smoothed distribution:
/// - True class: 1 - epsilon
/// - Other classes: epsilon / (num_classes - 1)
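///
/// For example, with `epsilon = 0.1` and `num_classes = 3`, a one-hot target
/// `[0.0, 1.0, 0.0]` becomes `[0.05, 0.9, 0.05]`.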
#[derive(Debug, Clone)]
pub struct LabelSmoothingLoss {
    /// Smoothing parameter (typically 0.1).
    pub epsilon: f64,
    /// Number of classes.
    pub num_classes: usize,
}

impl LabelSmoothingLoss {
    /// Create a new label smoothing loss.
    ///
    /// # Arguments
    /// * `epsilon` - Smoothing parameter (0 = no smoothing, higher = more smoothing)
    /// * `num_classes` - Number of classes
    pub fn new(epsilon: f64, num_classes: usize) -> TrainResult<Self> {
        if !(0.0..=1.0).contains(&epsilon) {
            return Err(TrainError::ConfigError(
                "Epsilon must be between 0 and 1".to_string(),
            ));
        }

        if num_classes == 0 {
            return Err(TrainError::ConfigError(
                "Number of classes must be positive".to_string(),
            ));
        }

        Ok(Self {
            epsilon,
            num_classes,
        })
    }

    /// Apply label smoothing to targets.
    ///
    /// # Arguments
    /// * `targets` - One-hot encoded targets
    ///
    /// # Returns
    /// Smoothed targets
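    ///
    /// # Example
    ///
    /// A minimal sketch of smoothing a one-hot batch; it assumes this module's
    /// types are importable from the `tensorlogic_train` crate root, which may
    /// differ in your setup.
    ///
    /// ```ignore
    /// use scirs2_core::array;
    /// use tensorlogic_train::LabelSmoothingLoss;
    ///
    /// let loss = LabelSmoothingLoss::new(0.1, 3).unwrap();
    /// let targets = array![[0.0, 1.0, 0.0]];
    /// let smoothed = loss.smooth_labels(&targets.view());
    /// assert!((smoothed[[0, 1]] - 0.9).abs() < 1e-6);
    /// assert!((smoothed[[0, 0]] - 0.05).abs() < 1e-6);
    /// ```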
    pub fn smooth_labels(&self, targets: &ArrayView<f64, Ix2>) -> Array<f64, Ix2> {
        if targets.ncols() != self.num_classes {
            // If mismatch, return original (will error in loss computation)
            return targets.to_owned();
        }

        let mut smoothed = Array::zeros(targets.raw_dim());

        let true_confidence = 1.0 - self.epsilon;
        let other_confidence = self.epsilon / (self.num_classes - 1) as f64;

        for i in 0..targets.nrows() {
            for j in 0..targets.ncols() {
                if targets[[i, j]] > 0.5 {
                    // True class
                    smoothed[[i, j]] = true_confidence;
                } else {
                    // Other classes
                    smoothed[[i, j]] = other_confidence;
                }
            }
        }

        smoothed
    }
}

impl Loss for LabelSmoothingLoss {
    fn compute(
        &self,
        predictions: &ArrayView<f64, Ix2>,
        targets: &ArrayView<f64, Ix2>,
    ) -> TrainResult<f64> {
        if predictions.shape() != targets.shape() {
            return Err(TrainError::LossError(format!(
                "Shape mismatch: predictions {:?} vs targets {:?}",
                predictions.shape(),
                targets.shape()
            )));
        }

        if predictions.ncols() != self.num_classes {
            return Err(TrainError::LossError(format!(
                "Number of classes mismatch: expected {}, got {}",
                self.num_classes,
                predictions.ncols()
            )));
        }

        // Apply label smoothing
        let smoothed_targets = self.smooth_labels(targets);

        // Cross-entropy with smoothed labels:
        // loss = -(1/N) * sum_i sum_j smoothed_target[i, j] * ln(softmax(pred_i)[j])
        let mut total_loss = 0.0;
        let n_samples = predictions.nrows();

        for i in 0..n_samples {
            // Numerically stable softmax: subtract the row maximum before exponentiating
            let max_pred = predictions
                .row(i)
                .iter()
                .fold(f64::NEG_INFINITY, |a, &b| a.max(b));

            let exp_preds: Vec<f64> = predictions
                .row(i)
                .iter()
                .map(|&x| (x - max_pred).exp())
                .collect();

            let sum_exp: f64 = exp_preds.iter().sum();

            // Cross-entropy
            for j in 0..predictions.ncols() {
                let prob = exp_preds[j] / sum_exp;
                let target = smoothed_targets[[i, j]];

                if target > 1e-8 {
                    total_loss -= target * (prob + 1e-8).ln();
                }
            }
        }

        Ok(total_loss / n_samples as f64)
    }

    fn gradient(
        &self,
        predictions: &ArrayView<f64, Ix2>,
        targets: &ArrayView<f64, Ix2>,
    ) -> TrainResult<Array<f64, Ix2>> {
        if predictions.shape() != targets.shape() {
            return Err(TrainError::LossError(format!(
                "Shape mismatch: predictions {:?} vs targets {:?}",
                predictions.shape(),
                targets.shape()
            )));
        }

        if predictions.ncols() != self.num_classes {
            return Err(TrainError::LossError(format!(
                "Number of classes mismatch: expected {}, got {}",
                self.num_classes,
                predictions.ncols()
            )));
        }

        // Apply label smoothing
        let smoothed_targets = self.smooth_labels(targets);

        let n_samples = predictions.nrows();
        let mut grad = Array::zeros(predictions.raw_dim());

        for i in 0..n_samples {
            // Numerically stable softmax: subtract the row maximum before exponentiating
            let max_pred = predictions
                .row(i)
                .iter()
                .fold(f64::NEG_INFINITY, |a, &b| a.max(b));

            let exp_preds: Vec<f64> = predictions
                .row(i)
                .iter()
                .map(|&x| (x - max_pred).exp())
                .collect();

            let sum_exp: f64 = exp_preds.iter().sum();

            // Gradient of the softmax cross-entropy: softmax(predictions) - smoothed_targets,
            // averaged over the batch
            for j in 0..predictions.ncols() {
                let prob = exp_preds[j] / sum_exp;
                let target = smoothed_targets[[i, j]];
                grad[[i, j]] = (prob - target) / n_samples as f64;
            }
        }

        Ok(grad)
    }

    fn name(&self) -> &str {
        "label_smoothing"
    }
}

/// Mixup data augmentation that mixes training examples and their labels.
///
/// Based on "mixup: Beyond Empirical Risk Minimization" (Zhang et al., 2018).
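///
/// Training pairs are mixed as `x = lambda * x1 + (1 - lambda) * x2` and
/// `y = lambda * y1 + (1 - lambda) * y2`, with `lambda` drawn from `Beta(alpha, alpha)`.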
#[derive(Debug)]
pub struct MixupLoss {
    /// Alpha parameter for Beta distribution (typically 1.0).
    pub alpha: f64,
    /// Base loss function.
    pub base_loss: Box<dyn Loss>,
}

impl MixupLoss {
    /// Create a new Mixup loss.
    ///
    /// # Arguments
    /// * `alpha` - Beta distribution parameter (higher = more mixing)
    /// * `base_loss` - Underlying loss function
    pub fn new(alpha: f64, base_loss: Box<dyn Loss>) -> TrainResult<Self> {
        if alpha <= 0.0 {
            return Err(TrainError::ConfigError(
                "Alpha must be positive".to_string(),
            ));
        }

        Ok(Self { alpha, base_loss })
    }

    /// Compute mixup loss.
    ///
    /// Note: In practice, mixup is applied during data loading, not in the loss function.
    /// This implementation assumes pre-mixed inputs and targets.
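    ///
    /// # Example
    ///
    /// A minimal sketch of the pre-mixed workflow, assuming `MixupLoss` and
    /// `MseLoss` are exported from the `tensorlogic_train` crate root (the
    /// exact paths may differ in your setup):
    ///
    /// ```ignore
    /// use scirs2_core::array;
    /// use tensorlogic_train::{MixupLoss, MseLoss};
    ///
    /// let mixup = MixupLoss::new(1.0, Box::new(MseLoss)).unwrap();
    /// let lambda = 0.3;
    ///
    /// // Mix two (input, target) pairs with the same lambda.
    /// let x1 = array![[1.0, 2.0]];
    /// let x2 = array![[5.0, 6.0]];
    /// let y1 = array![[1.0, 0.0]];
    /// let y2 = array![[0.0, 1.0]];
    /// let mixed_x = MixupLoss::mix_data(&x1.view(), &x2.view(), lambda).unwrap();
    /// let mixed_y = MixupLoss::mix_data(&y1.view(), &y2.view(), lambda).unwrap();
    ///
    /// // `predictions` would normally come from a forward pass on `mixed_x`.
    /// let predictions = array![[0.4, 0.6]];
    /// let loss = mixup.compute_mixup(&predictions.view(), &mixed_y.view()).unwrap();
    /// assert!(loss.is_finite());
    /// ```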
    pub fn compute_mixup(
        &self,
        predictions: &ArrayView<f64, Ix2>,
        mixed_targets: &ArrayView<f64, Ix2>,
    ) -> TrainResult<f64> {
        // With pre-mixed targets, just use the base loss
        self.base_loss.compute(predictions, mixed_targets)
    }

    /// Mix two batches of data with a given mixing coefficient.
    ///
    /// # Arguments
    /// * `data1` - First batch of data
    /// * `data2` - Second batch of data
    /// * `lambda` - Mixing coefficient (0 to 1)
    ///
    /// # Returns
    /// Mixed data
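    ///
    /// For example, mixing `[1.0, 2.0]` and `[5.0, 6.0]` with `lambda = 0.5`
    /// yields `[3.0, 4.0]`.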
    pub fn mix_data(
        data1: &ArrayView<f64, Ix2>,
        data2: &ArrayView<f64, Ix2>,
        lambda: f64,
    ) -> TrainResult<Array<f64, Ix2>> {
        if data1.shape() != data2.shape() {
            return Err(TrainError::LossError(
                "Data shapes must match for mixing".to_string(),
            ));
        }

        let mixed = data1 * lambda + data2 * (1.0 - lambda);
        Ok(mixed)
    }

    /// Sample mixing coefficient from Beta distribution.
    ///
    /// In practice, use a proper random number generator.
    /// This is a simplified implementation for demonstration
    /// (see the Beta-sampling sketch after this `impl` block).
    #[allow(dead_code)]
    fn sample_lambda(&self) -> f64 {
        // Simplified: return midpoint
        // In real implementation, sample from Beta(alpha, alpha)
        0.5
    }
}
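
// A minimal sketch of sampling the mixing coefficient from Beta(alpha, alpha),
// as the mixup paper prescribes. It assumes the `rand` and `rand_distr` crates
// are available as dependencies of this crate (they are not imported above);
// `MixupLoss::sample_lambda` returns a fixed 0.5 instead.
#[allow(dead_code)]
fn sample_beta_lambda(alpha: f64) -> f64 {
    use rand_distr::{Beta, Distribution};

    // Beta::new only fails for non-positive parameters, which MixupLoss::new rejects.
    let beta = Beta::new(alpha, alpha).expect("alpha must be positive");
    beta.sample(&mut rand::thread_rng())
}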

#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::array;

    #[test]
    fn test_label_smoothing_creation() {
        let loss = LabelSmoothingLoss::new(0.1, 10);
        assert!(loss.is_ok());

        let loss = loss.unwrap();
        assert_eq!(loss.epsilon, 0.1);
        assert_eq!(loss.num_classes, 10);
    }

    #[test]
    fn test_label_smoothing_invalid_epsilon() {
        assert!(LabelSmoothingLoss::new(-0.1, 10).is_err());
        assert!(LabelSmoothingLoss::new(1.5, 10).is_err());
    }

    #[test]
    fn test_label_smoothing_smooth_labels() {
        let loss = LabelSmoothingLoss::new(0.1, 3).unwrap();

        // One-hot encoded: class 1 is true
        let targets = array![[0.0, 1.0, 0.0], [1.0, 0.0, 0.0]];
        let smoothed = loss.smooth_labels(&targets.view());

        // True class should have 1 - epsilon = 0.9
        assert!((smoothed[[0, 1]] - 0.9).abs() < 1e-6);

        // Other classes should have epsilon / (num_classes - 1) = 0.1 / 2 = 0.05
        assert!((smoothed[[0, 0]] - 0.05).abs() < 1e-6);
        assert!((smoothed[[0, 2]] - 0.05).abs() < 1e-6);
    }

    #[test]
    fn test_label_smoothing_loss_compute() {
        let loss = LabelSmoothingLoss::new(0.1, 3).unwrap();

        let predictions = array![[1.0, 2.0, 0.5], [0.5, 1.0, 2.0]];
        let targets = array![[0.0, 1.0, 0.0], [0.0, 0.0, 1.0]];

        let result = loss.compute(&predictions.view(), &targets.view());
        assert!(result.is_ok());

        let loss_value = result.unwrap();
        assert!(loss_value > 0.0);
        assert!(loss_value.is_finite());
    }

    #[test]
    fn test_mixup_loss_creation() {
        use crate::MseLoss;

        let loss = MixupLoss::new(1.0, Box::new(MseLoss));
        assert!(loss.is_ok());

        assert!(MixupLoss::new(0.0, Box::new(MseLoss)).is_err());
        assert!(MixupLoss::new(-1.0, Box::new(MseLoss)).is_err());
    }

    #[test]
    fn test_mixup_mix_data() {
        let data1 = array![[1.0, 2.0], [3.0, 4.0]];
        let data2 = array![[5.0, 6.0], [7.0, 8.0]];

        let mixed = MixupLoss::mix_data(&data1.view(), &data2.view(), 0.5).unwrap();

        // With lambda=0.5, should be average
        assert!((mixed[[0, 0]] - 3.0).abs() < 1e-6);
        assert!((mixed[[0, 1]] - 4.0).abs() < 1e-6);
        assert!((mixed[[1, 0]] - 5.0).abs() < 1e-6);
        assert!((mixed[[1, 1]] - 6.0).abs() < 1e-6);
    }

    #[test]
    fn test_mixup_mix_data_lambda_extremes() {
        let data1 = array![[1.0, 2.0]];
        let data2 = array![[5.0, 6.0]];

        // Lambda = 1.0 should return data1
        let mixed = MixupLoss::mix_data(&data1.view(), &data2.view(), 1.0).unwrap();
        assert!((mixed[[0, 0]] - 1.0).abs() < 1e-6);
        assert!((mixed[[0, 1]] - 2.0).abs() < 1e-6);

        // Lambda = 0.0 should return data2
        let mixed = MixupLoss::mix_data(&data1.view(), &data2.view(), 0.0).unwrap();
        assert!((mixed[[0, 0]] - 5.0).abs() < 1e-6);
        assert!((mixed[[0, 1]] - 6.0).abs() < 1e-6);
    }

    #[test]
    fn test_label_smoothing_zero_epsilon() {
        // With epsilon = 0, should behave like regular one-hot
        let loss = LabelSmoothingLoss::new(0.0, 3).unwrap();

        let targets = array![[0.0, 1.0, 0.0]];
        let smoothed = loss.smooth_labels(&targets.view());

        assert!((smoothed[[0, 0]] - 0.0).abs() < 1e-6);
        assert!((smoothed[[0, 1]] - 1.0).abs() < 1e-6);
        assert!((smoothed[[0, 2]] - 0.0).abs() < 1e-6);
    }
}
379}