use crate::{Loss, TrainError, TrainResult};
use scirs2_core::ndarray::{Array, ArrayView, Ix2};

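/// Cross-entropy loss with label smoothing.
///
/// Instead of training against hard one-hot targets, the true class is
/// assigned probability `1 - epsilon` and the remaining `epsilon` mass is
/// spread uniformly over the other `num_classes - 1` classes, which
/// discourages over-confident predictions.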
#[derive(Debug, Clone)]
pub struct LabelSmoothingLoss {
    /// Smoothing factor in `[0, 1]`; `0.0` recovers hard one-hot targets.
    pub epsilon: f64,
    /// Number of classes in the classification problem.
    pub num_classes: usize,
}

impl LabelSmoothingLoss {
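    /// Creates a new label-smoothing loss.
    ///
    /// Returns a `ConfigError` if `epsilon` lies outside `[0, 1]` or if
    /// fewer than two classes are configured.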
    pub fn new(epsilon: f64, num_classes: usize) -> TrainResult<Self> {
        if !(0.0..=1.0).contains(&epsilon) {
            return Err(TrainError::ConfigError(
                "Epsilon must be between 0 and 1".to_string(),
            ));
        }

        // At least two classes are required: with a single class the
        // off-target mass `epsilon / (num_classes - 1)` would divide by zero.
        if num_classes < 2 {
            return Err(TrainError::ConfigError(
                "Number of classes must be at least 2".to_string(),
            ));
        }

        Ok(Self {
            epsilon,
            num_classes,
        })
    }

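    /// Smooths one-hot targets: the hot class receives `1 - epsilon` and
    /// every other class receives `epsilon / (num_classes - 1)`.
    ///
    /// Targets whose column count does not match `num_classes` are
    /// returned unchanged.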
    pub fn smooth_labels(&self, targets: &ArrayView<f64, Ix2>) -> Array<f64, Ix2> {
        if targets.ncols() != self.num_classes {
            // Width mismatch: pass the targets through unchanged.
            return targets.to_owned();
        }

        let mut smoothed = Array::zeros(targets.raw_dim());

        let true_confidence = 1.0 - self.epsilon;
        let other_confidence = self.epsilon / (self.num_classes - 1) as f64;

        for i in 0..targets.nrows() {
            for j in 0..targets.ncols() {
                // Entries above 0.5 are treated as the hot class of a one-hot row.
                if targets[[i, j]] > 0.5 {
                    smoothed[[i, j]] = true_confidence;
                } else {
                    smoothed[[i, j]] = other_confidence;
                }
            }
        }

        smoothed
    }
}

impl Loss for LabelSmoothingLoss {
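    /// Computes the mean softmax cross-entropy between raw `predictions`
    /// (logits) and the label-smoothed `targets`.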
    fn compute(
        &self,
        predictions: &ArrayView<f64, Ix2>,
        targets: &ArrayView<f64, Ix2>,
    ) -> TrainResult<f64> {
        if predictions.shape() != targets.shape() {
            return Err(TrainError::LossError(format!(
                "Shape mismatch: predictions {:?} vs targets {:?}",
                predictions.shape(),
                targets.shape()
            )));
        }

        if predictions.ncols() != self.num_classes {
            return Err(TrainError::LossError(format!(
                "Number of classes mismatch: expected {}, got {}",
                self.num_classes,
                predictions.ncols()
            )));
        }

        let smoothed_targets = self.smooth_labels(targets);

        let mut total_loss = 0.0;
        let n_samples = predictions.nrows();

        for i in 0..n_samples {
            // Numerically stable softmax: subtract the row maximum before
            // exponentiating (log-sum-exp trick).
            let max_pred = predictions
                .row(i)
                .iter()
                .fold(f64::NEG_INFINITY, |a, &b| a.max(b));

            let exp_preds: Vec<f64> = predictions
                .row(i)
                .iter()
                .map(|&x| (x - max_pred).exp())
                .collect();

            let sum_exp: f64 = exp_preds.iter().sum();

            for j in 0..predictions.ncols() {
                let prob = exp_preds[j] / sum_exp;
                let target = smoothed_targets[[i, j]];

                // Cross-entropy term -q * ln(p); the small constant guards
                // against ln(0), and near-zero targets are skipped.
                if target > 1e-8 {
                    total_loss -= target * (prob + 1e-8).ln();
                }
            }
        }

        Ok(total_loss / n_samples as f64)
    }

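    /// Gradient of the smoothed cross-entropy with respect to the logits:
    /// `(softmax(predictions) - smoothed_targets) / n_samples`.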
    fn gradient(
        &self,
        predictions: &ArrayView<f64, Ix2>,
        targets: &ArrayView<f64, Ix2>,
    ) -> TrainResult<Array<f64, Ix2>> {
        if predictions.shape() != targets.shape() {
            return Err(TrainError::LossError(format!(
                "Shape mismatch: predictions {:?} vs targets {:?}",
                predictions.shape(),
                targets.shape()
            )));
        }

        if predictions.ncols() != self.num_classes {
            return Err(TrainError::LossError(format!(
                "Number of classes mismatch: expected {}, got {}",
                self.num_classes,
                predictions.ncols()
            )));
        }

        let smoothed_targets = self.smooth_labels(targets);

        let n_samples = predictions.nrows();
        let mut grad = Array::zeros(predictions.raw_dim());

        for i in 0..n_samples {
            // Same numerically stable softmax as in `compute`.
            let max_pred = predictions
                .row(i)
                .iter()
                .fold(f64::NEG_INFINITY, |a, &b| a.max(b));

            let exp_preds: Vec<f64> = predictions
                .row(i)
                .iter()
                .map(|&x| (x - max_pred).exp())
                .collect();

            let sum_exp: f64 = exp_preds.iter().sum();

            for j in 0..predictions.ncols() {
                let prob = exp_preds[j] / sum_exp;
                let target = smoothed_targets[[i, j]];
                grad[[i, j]] = (prob - target) / n_samples as f64;
            }
        }

        Ok(grad)
    }

    fn name(&self) -> &str {
        "label_smoothing"
    }
}

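/// Loss wrapper for mixup training (Zhang et al., "mixup: Beyond Empirical
/// Risk Minimization", 2018), where pairs of examples and labels are
/// combined by a convex mixture `lambda * x1 + (1 - lambda) * x2` before
/// the base loss is applied.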
#[derive(Debug)]
pub struct MixupLoss {
    /// Concentration parameter of the intended `Beta(alpha, alpha)`
    /// mixing distribution (see `sample_lambda`); must be positive.
    pub alpha: f64,
    /// Underlying loss applied to predictions and mixed targets.
    pub base_loss: Box<dyn Loss>,
}

impl MixupLoss {
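    /// Creates a new mixup loss around `base_loss`.
    ///
    /// Returns a `ConfigError` if `alpha` is not strictly positive.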
    pub fn new(alpha: f64, base_loss: Box<dyn Loss>) -> TrainResult<Self> {
        if alpha <= 0.0 {
            return Err(TrainError::ConfigError(
                "Alpha must be positive".to_string(),
            ));
        }

        Ok(Self { alpha, base_loss })
    }

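    /// Evaluates the base loss against already-mixed targets; mixup itself
    /// requires no change to the loss computation.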
    pub fn compute_mixup(
        &self,
        predictions: &ArrayView<f64, Ix2>,
        mixed_targets: &ArrayView<f64, Ix2>,
    ) -> TrainResult<f64> {
        self.base_loss.compute(predictions, mixed_targets)
    }

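    /// Mixes two equally shaped batches as
    /// `lambda * data1 + (1 - lambda) * data2`.
    ///
    /// `lambda` is expected to lie in `[0, 1]`: `1.0` returns `data1`
    /// and `0.0` returns `data2`.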
    pub fn mix_data(
        data1: &ArrayView<f64, Ix2>,
        data2: &ArrayView<f64, Ix2>,
        lambda: f64,
    ) -> TrainResult<Array<f64, Ix2>> {
        if data1.shape() != data2.shape() {
            return Err(TrainError::LossError(
                "Data shapes must match for mixing".to_string(),
            ));
        }

        // Elementwise convex combination; `&view * scalar` already yields
        // an owned array, so no extra clone is needed.
        Ok(data1 * lambda + data2 * (1.0 - lambda))
    }

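    /// Placeholder for sampling the mixing coefficient. A complete mixup
    /// implementation draws `lambda` from a `Beta(alpha, alpha)`
    /// distribution; this stub returns a fixed midpoint instead.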
    #[allow(dead_code)]
    fn sample_lambda(&self) -> f64 {
        0.5
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::array;

    #[test]
    fn test_label_smoothing_creation() {
        let loss = LabelSmoothingLoss::new(0.1, 10);
        assert!(loss.is_ok());

        let loss = loss.unwrap();
        assert_eq!(loss.epsilon, 0.1);
        assert_eq!(loss.num_classes, 10);
    }

    #[test]
    fn test_label_smoothing_invalid_epsilon() {
        assert!(LabelSmoothingLoss::new(-0.1, 10).is_err());
        assert!(LabelSmoothingLoss::new(1.5, 10).is_err());
    }

    #[test]
    fn test_label_smoothing_smooth_labels() {
        let loss = LabelSmoothingLoss::new(0.1, 3).unwrap();

        let targets = array![[0.0, 1.0, 0.0], [1.0, 0.0, 0.0]];
        let smoothed = loss.smooth_labels(&targets.view());

        // Hot class: 1 - epsilon = 0.9.
        assert!((smoothed[[0, 1]] - 0.9).abs() < 1e-6);

        // Remaining classes share epsilon: 0.1 / 2 = 0.05 each.
        assert!((smoothed[[0, 0]] - 0.05).abs() < 1e-6);
        assert!((smoothed[[0, 2]] - 0.05).abs() < 1e-6);
    }

    #[test]
    fn test_label_smoothing_loss_compute() {
        let loss = LabelSmoothingLoss::new(0.1, 3).unwrap();

        let predictions = array![[1.0, 2.0, 0.5], [0.5, 1.0, 2.0]];
        let targets = array![[0.0, 1.0, 0.0], [0.0, 0.0, 1.0]];

        let result = loss.compute(&predictions.view(), &targets.view());
        assert!(result.is_ok());

        let loss_value = result.unwrap();
        assert!(loss_value > 0.0);
        assert!(loss_value.is_finite());
    }

    #[test]
    fn test_mixup_loss_creation() {
        use crate::MseLoss;

        let loss = MixupLoss::new(1.0, Box::new(MseLoss));
        assert!(loss.is_ok());

        assert!(MixupLoss::new(0.0, Box::new(MseLoss)).is_err());
        assert!(MixupLoss::new(-1.0, Box::new(MseLoss)).is_err());
    }

    #[test]
    fn test_mixup_mix_data() {
        let data1 = array![[1.0, 2.0], [3.0, 4.0]];
        let data2 = array![[5.0, 6.0], [7.0, 8.0]];

        let mixed = MixupLoss::mix_data(&data1.view(), &data2.view(), 0.5).unwrap();

        // With lambda = 0.5 the result is the elementwise mean.
        assert!((mixed[[0, 0]] - 3.0).abs() < 1e-6);
        assert!((mixed[[0, 1]] - 4.0).abs() < 1e-6);
        assert!((mixed[[1, 0]] - 5.0).abs() < 1e-6);
        assert!((mixed[[1, 1]] - 6.0).abs() < 1e-6);
    }

    #[test]
    fn test_mixup_mix_data_lambda_extremes() {
        let data1 = array![[1.0, 2.0]];
        let data2 = array![[5.0, 6.0]];

        // lambda = 1.0 returns data1 unchanged.
        let mixed = MixupLoss::mix_data(&data1.view(), &data2.view(), 1.0).unwrap();
        assert!((mixed[[0, 0]] - 1.0).abs() < 1e-6);
        assert!((mixed[[0, 1]] - 2.0).abs() < 1e-6);

        // lambda = 0.0 returns data2 unchanged.
        let mixed = MixupLoss::mix_data(&data1.view(), &data2.view(), 0.0).unwrap();
        assert!((mixed[[0, 0]] - 5.0).abs() < 1e-6);
        assert!((mixed[[0, 1]] - 6.0).abs() < 1e-6);
    }

    #[test]
    fn test_label_smoothing_zero_epsilon() {
        // epsilon = 0 leaves one-hot targets unchanged.
        let loss = LabelSmoothingLoss::new(0.0, 3).unwrap();

        let targets = array![[0.0, 1.0, 0.0]];
        let smoothed = loss.smooth_labels(&targets.view());

        assert!((smoothed[[0, 0]] - 0.0).abs() < 1e-6);
        assert!((smoothed[[0, 1]] - 1.0).abs() < 1e-6);
        assert!((smoothed[[0, 2]] - 0.0).abs() < 1e-6);
    }
}