1use crate::{Loss, TrainError, TrainResult};
7use scirs2_core::ndarray::{Array, ArrayView, Ix2};
8
/// Knowledge-distillation loss: a temperature-softened soft term against
/// teacher logits combined with a conventional hard-target loss.
pub struct DistillationLoss {
    /// Softmax temperature; validated to be strictly positive by `new`.
    pub temperature: f64,
    /// Weight of the soft (teacher) term, in `[0, 1]`; the hard term is
    /// weighted by `1 - alpha`.
    pub alpha: f64,
    /// Loss applied between student logits and the hard targets.
    pub hard_loss: Box<dyn Loss>,
}
20
21impl DistillationLoss {
22 pub fn new(temperature: f64, alpha: f64, hard_loss: Box<dyn Loss>) -> TrainResult<Self> {
29 if temperature <= 0.0 {
30 return Err(TrainError::ConfigError(
31 "Temperature must be positive".to_string(),
32 ));
33 }
34
35 if !(0.0..=1.0).contains(&alpha) {
36 return Err(TrainError::ConfigError(
37 "Alpha must be between 0 and 1".to_string(),
38 ));
39 }
40
41 Ok(Self {
42 temperature,
43 alpha,
44 hard_loss,
45 })
46 }
47
48 pub fn compute_distillation(
58 &self,
59 student_logits: &ArrayView<f64, Ix2>,
60 teacher_logits: &ArrayView<f64, Ix2>,
61 hard_targets: &ArrayView<f64, Ix2>,
62 ) -> TrainResult<f64> {
63 if student_logits.shape() != teacher_logits.shape() {
64 return Err(TrainError::LossError(format!(
65 "Student and teacher logits must have same shape: {:?} vs {:?}",
66 student_logits.shape(),
67 teacher_logits.shape()
68 )));
69 }
70
71 let soft_loss =
73 self.compute_kl_divergence_with_temperature(student_logits, teacher_logits)?;
74
75 let hard_loss = self.hard_loss.compute(student_logits, hard_targets)?;
77
78 let t_squared = self.temperature * self.temperature;
81 let combined_loss = self.alpha * soft_loss * t_squared + (1.0 - self.alpha) * hard_loss;
82
83 Ok(combined_loss)
84 }
85
86 fn compute_kl_divergence_with_temperature(
88 &self,
89 student_logits: &ArrayView<f64, Ix2>,
90 teacher_logits: &ArrayView<f64, Ix2>,
91 ) -> TrainResult<f64> {
92 let t = self.temperature;
93
94 let mut total_loss = 0.0;
95 let n_samples = student_logits.nrows();
96
97 for i in 0..n_samples {
98 let student_probs = self.softmax_with_temperature(&student_logits.row(i), t);
100 let teacher_probs = self.softmax_with_temperature(&teacher_logits.row(i), t);
101
102 for j in 0..student_probs.len() {
104 if teacher_probs[j] > 1e-8 {
105 let ratio = teacher_probs[j] / (student_probs[j] + 1e-8);
106 total_loss += teacher_probs[j] * ratio.ln();
107 }
108 }
109 }
110
111 Ok(total_loss / n_samples as f64)
112 }
113
114 fn softmax_with_temperature(
116 &self,
117 logits: &ArrayView<f64, scirs2_core::ndarray::Ix1>,
118 temperature: f64,
119 ) -> Vec<f64> {
120 let scaled: Vec<f64> = logits.iter().map(|&x| x / temperature).collect();
121
122 let max_val = scaled.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
123 let exp_vals: Vec<f64> = scaled.iter().map(|&x| (x - max_val).exp()).collect();
124 let sum: f64 = exp_vals.iter().sum();
125
126 exp_vals.iter().map(|&x| x / sum).collect()
127 }
128}
129
/// Feature-map distillation loss: a weighted sum of per-layer mean
/// element-wise distances between student and teacher intermediate features.
pub struct FeatureDistillationLoss {
    /// One weight per compared layer; validated non-empty by `new`.
    pub layer_weights: Vec<f64>,
    /// Distance norm: `1.0` for absolute differences, `2.0` for squared
    /// differences (only these two values are accepted by `new`).
    pub p_norm: f64,
}
137
138impl FeatureDistillationLoss {
139 pub fn new(layer_weights: Vec<f64>, p_norm: f64) -> TrainResult<Self> {
145 if layer_weights.is_empty() {
146 return Err(TrainError::ConfigError(
147 "Must specify at least one layer weight".to_string(),
148 ));
149 }
150
151 if p_norm != 1.0 && p_norm != 2.0 {
152 return Err(TrainError::ConfigError(
153 "p_norm must be 1.0 or 2.0".to_string(),
154 ));
155 }
156
157 Ok(Self {
158 layer_weights,
159 p_norm,
160 })
161 }
162
163 pub fn compute_feature_loss(
172 &self,
173 student_features: &[ArrayView<f64, Ix2>],
174 teacher_features: &[ArrayView<f64, Ix2>],
175 ) -> TrainResult<f64> {
176 if student_features.len() != teacher_features.len() {
177 return Err(TrainError::LossError(
178 "Number of student and teacher feature layers must match".to_string(),
179 ));
180 }
181
182 if student_features.len() != self.layer_weights.len() {
183 return Err(TrainError::LossError(format!(
184 "Number of layers ({}) must match number of weights ({})",
185 student_features.len(),
186 self.layer_weights.len()
187 )));
188 }
189
190 let mut total_loss = 0.0;
191
192 for (i, (student_feat, teacher_feat)) in student_features
193 .iter()
194 .zip(teacher_features.iter())
195 .enumerate()
196 {
197 if student_feat.shape() != teacher_feat.shape() {
198 return Err(TrainError::LossError(format!(
199 "Layer {} shape mismatch: {:?} vs {:?}",
200 i,
201 student_feat.shape(),
202 teacher_feat.shape()
203 )));
204 }
205
206 let mut layer_loss = 0.0;
208 for (&s, &t) in student_feat.iter().zip(teacher_feat.iter()) {
209 let diff = (s - t).abs();
210 layer_loss += if self.p_norm == 2.0 {
211 diff * diff
212 } else {
213 diff
214 };
215 }
216
217 let n_elements = student_feat.len() as f64;
219 layer_loss /= n_elements;
220
221 total_loss += self.layer_weights[i] * layer_loss;
223 }
224
225 Ok(total_loss)
226 }
227}
228
/// Attention-transfer loss: mean squared error between row-normalized
/// `|a|^beta` attention maps of student and teacher.
pub struct AttentionTransferLoss {
    /// Exponent applied to the absolute attention values before each row is
    /// normalized to sum to one.
    pub beta: f64,
}
234
235impl AttentionTransferLoss {
236 pub fn new(beta: f64) -> Self {
241 Self { beta }
242 }
243
244 pub fn compute_attention_loss(
253 &self,
254 student_attention: &ArrayView<f64, Ix2>,
255 teacher_attention: &ArrayView<f64, Ix2>,
256 ) -> TrainResult<f64> {
257 if student_attention.shape() != teacher_attention.shape() {
258 return Err(TrainError::LossError(format!(
259 "Attention maps must have same shape: {:?} vs {:?}",
260 student_attention.shape(),
261 teacher_attention.shape()
262 )));
263 }
264
265 let student_norm = self.normalize_attention(student_attention);
267 let teacher_norm = self.normalize_attention(teacher_attention);
268
269 let mut loss = 0.0;
271 for (s, t) in student_norm.iter().zip(teacher_norm.iter()) {
272 let diff = s - t;
273 loss += diff * diff;
274 }
275
276 let n_elements = student_norm.len() as f64;
277 Ok(loss / n_elements)
278 }
279
280 fn normalize_attention(&self, attention: &ArrayView<f64, Ix2>) -> Array<f64, Ix2> {
282 let mut normalized = attention.mapv(|x| x.abs().powf(self.beta));
283
284 for mut row in normalized.rows_mut() {
286 let sum: f64 = row.iter().sum();
287 if sum > 1e-8 {
288 row.mapv_inplace(|x| x / sum);
289 }
290 }
291
292 normalized
293 }
294}
295
#[cfg(test)]
mod tests {
    use super::*;
    use crate::CrossEntropyLoss;
    use scirs2_core::array;

    // Valid parameters construct successfully and are stored unchanged.
    #[test]
    fn test_distillation_loss_creation() {
        let loss = DistillationLoss::new(3.0, 0.7, Box::new(CrossEntropyLoss::default()));
        assert!(loss.is_ok());

        let loss = loss.expect("unwrap");
        assert_eq!(loss.temperature, 3.0);
        assert_eq!(loss.alpha, 0.7);
    }

    // Zero and negative temperatures must both be rejected.
    #[test]
    fn test_distillation_invalid_temperature() {
        let result = DistillationLoss::new(0.0, 0.5, Box::new(CrossEntropyLoss::default()));
        assert!(result.is_err());

        let result = DistillationLoss::new(-1.0, 0.5, Box::new(CrossEntropyLoss::default()));
        assert!(result.is_err());
    }

    // Alpha outside [0, 1] on either side must be rejected.
    #[test]
    fn test_distillation_invalid_alpha() {
        let result = DistillationLoss::new(3.0, -0.1, Box::new(CrossEntropyLoss::default()));
        assert!(result.is_err());

        let result = DistillationLoss::new(3.0, 1.1, Box::new(CrossEntropyLoss::default()));
        assert!(result.is_err());
    }

    // A well-formed batch yields a finite, positive combined loss.
    #[test]
    fn test_distillation_compute() {
        let loss =
            DistillationLoss::new(2.0, 0.5, Box::new(CrossEntropyLoss::default())).expect("unwrap");

        let student_logits = array![[1.0, 2.0, 0.5], [0.5, 1.0, 2.0]];
        let teacher_logits = array![[1.2, 1.8, 0.6], [0.6, 1.1, 1.9]];
        let hard_targets = array![[0.0, 1.0, 0.0], [0.0, 0.0, 1.0]];

        let result = loss.compute_distillation(
            &student_logits.view(),
            &teacher_logits.view(),
            &hard_targets.view(),
        );

        assert!(result.is_ok());
        let loss_value = result.expect("unwrap");
        assert!(loss_value > 0.0);
        assert!(loss_value.is_finite());
    }

    // Three matching layers with small element-wise offsets produce a small
    // positive weighted L2 loss.
    #[test]
    fn test_feature_distillation_loss() {
        let loss = FeatureDistillationLoss::new(vec![0.5, 0.3, 0.2], 2.0).expect("unwrap");

        let s1 = array![[1.0, 2.0], [3.0, 4.0]];
        let s2 = array![[0.5, 1.5], [2.5, 3.5]];
        let s3 = array![[0.1, 0.2], [0.3, 0.4]];
        let student_features = vec![s1.view(), s2.view(), s3.view()];

        let t1 = array![[1.1, 2.1], [3.1, 4.1]];
        let t2 = array![[0.6, 1.6], [2.6, 3.6]];
        let t3 = array![[0.2, 0.3], [0.4, 0.5]];
        let teacher_features = vec![t1.view(), t2.view(), t3.view()];

        let result = loss.compute_feature_loss(&student_features, &teacher_features);
        assert!(result.is_ok());

        let loss_value = result.expect("unwrap");
        assert!(loss_value > 0.0);
        assert!(loss_value < 1.0);
    }

    // Same-shaped attention maps give a finite, non-negative MSE.
    #[test]
    fn test_attention_transfer_loss() {
        let loss = AttentionTransferLoss::new(2.0);

        let student_attention = array![[0.3, 0.5, 0.2], [0.4, 0.4, 0.2]];
        let teacher_attention = array![[0.35, 0.45, 0.2], [0.35, 0.45, 0.2]];

        let result =
            loss.compute_attention_loss(&student_attention.view(), &teacher_attention.view());
        assert!(result.is_ok());

        let loss_value = result.expect("unwrap");
        assert!(loss_value >= 0.0);
        assert!(loss_value.is_finite());
    }

    // Mismatched per-layer feature shapes must be reported as an error.
    #[test]
    fn test_feature_distillation_shape_mismatch() {
        let loss = FeatureDistillationLoss::new(vec![1.0], 2.0).expect("unwrap");

        let s1 = array![[1.0, 2.0]];
        let student_features = vec![s1.view()];

        let t1 = array![[1.0, 2.0, 3.0]];
        let teacher_features = vec![t1.view()];

        let result = loss.compute_feature_loss(&student_features, &teacher_features);
        assert!(result.is_err());
    }
}