//! Knowledge distillation losses: soft-target (response-based) distillation,
//! intermediate feature distillation, and attention transfer.

use crate::{Loss, TrainError, TrainResult};
use scirs2_core::ndarray::{Array, ArrayView, Ix1, Ix2};

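/// Soft-target knowledge distillation loss.
///
/// Blends a temperature-softened KL divergence between teacher and student
/// logits with a conventional hard-target loss:
/// `alpha * T^2 * KL(teacher_T || student_T) + (1 - alpha) * hard_loss`.
///
/// A minimal usage sketch, assuming `CrossEntropyLoss` implements [`Loss`] and
/// the `array!` macro is in scope (crate paths omitted):
///
/// ```ignore
/// let loss = DistillationLoss::new(2.0, 0.5, Box::new(CrossEntropyLoss::default()))?;
/// let student_logits = array![[1.0, 2.0, 0.5]];
/// let teacher_logits = array![[1.2, 1.8, 0.6]];
/// let hard_targets = array![[0.0, 1.0, 0.0]];
/// let value = loss.compute_distillation(
///     &student_logits.view(),
///     &teacher_logits.view(),
///     &hard_targets.view(),
/// )?;
/// ```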
pub struct DistillationLoss {
    /// Softening temperature `T` applied to both student and teacher logits (must be > 0).
    pub temperature: f64,
    /// Weight of the soft (distillation) term; the hard-target term is weighted `1 - alpha`.
    pub alpha: f64,
    /// Loss applied between student logits and hard targets.
    pub hard_loss: Box<dyn Loss>,
}

impl DistillationLoss {
    /// Creates a distillation loss, validating that `temperature > 0` and
    /// `alpha` lies in `[0, 1]`.
    pub fn new(temperature: f64, alpha: f64, hard_loss: Box<dyn Loss>) -> TrainResult<Self> {
        if temperature <= 0.0 {
            return Err(TrainError::ConfigError(
                "Temperature must be positive".to_string(),
            ));
        }

        if !(0.0..=1.0).contains(&alpha) {
            return Err(TrainError::ConfigError(
                "Alpha must be between 0 and 1".to_string(),
            ));
        }

        Ok(Self {
            temperature,
            alpha,
            hard_loss,
        })
    }

    /// Computes the combined distillation loss for a batch of logits.
    ///
    /// `student_logits`, `teacher_logits`, and `hard_targets` are all
    /// `(n_samples, n_classes)` arrays; the student and teacher logits must
    /// have identical shapes.
    pub fn compute_distillation(
        &self,
        student_logits: &ArrayView<f64, Ix2>,
        teacher_logits: &ArrayView<f64, Ix2>,
        hard_targets: &ArrayView<f64, Ix2>,
    ) -> TrainResult<f64> {
        if student_logits.shape() != teacher_logits.shape() {
            return Err(TrainError::LossError(format!(
                "Student and teacher logits must have same shape: {:?} vs {:?}",
                student_logits.shape(),
                teacher_logits.shape()
            )));
        }

        // Soft term: KL divergence between temperature-softened distributions.
        let soft_loss =
            self.compute_kl_divergence_with_temperature(student_logits, teacher_logits)?;

        // Hard term: the wrapped loss against the ground-truth targets.
        let hard_loss = self.hard_loss.compute(student_logits, hard_targets)?;

        // Scale the soft term by T^2 so its magnitude stays comparable to the
        // hard term, then blend the two with `alpha`.
        let t_squared = self.temperature * self.temperature;
        let combined_loss = self.alpha * soft_loss * t_squared + (1.0 - self.alpha) * hard_loss;

        Ok(combined_loss)
    }

    /// Mean (over samples) KL divergence `KL(teacher_T || student_T)` between
    /// the temperature-softened class distributions.
    fn compute_kl_divergence_with_temperature(
        &self,
        student_logits: &ArrayView<f64, Ix2>,
        teacher_logits: &ArrayView<f64, Ix2>,
    ) -> TrainResult<f64> {
        let t = self.temperature;

        let mut total_loss = 0.0;
        let n_samples = student_logits.nrows();

        for i in 0..n_samples {
            let student_probs = self.softmax_with_temperature(&student_logits.row(i), t);
            let teacher_probs = self.softmax_with_temperature(&teacher_logits.row(i), t);

            // Accumulate sum_j p_t[j] * ln(p_t[j] / p_s[j]), skipping classes with
            // negligible teacher probability and guarding the student term against
            // division by zero.
            for j in 0..student_probs.len() {
                if teacher_probs[j] > 1e-8 {
                    let ratio = teacher_probs[j] / (student_probs[j] + 1e-8);
                    total_loss += teacher_probs[j] * ratio.ln();
                }
            }
        }

        Ok(total_loss / n_samples as f64)
    }

    /// Numerically stable softmax of `logits / temperature`.
    fn softmax_with_temperature(
        &self,
        logits: &ArrayView<f64, Ix1>,
        temperature: f64,
    ) -> Vec<f64> {
        let scaled: Vec<f64> = logits.iter().map(|&x| x / temperature).collect();

        // Subtract the maximum before exponentiating to avoid overflow.
        let max_val = scaled.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
        let exp_vals: Vec<f64> = scaled.iter().map(|&x| (x - max_val).exp()).collect();
        let sum: f64 = exp_vals.iter().sum();

        exp_vals.iter().map(|&x| x / sum).collect()
    }
}

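/// Feature-based distillation loss.
///
/// Compares student and teacher intermediate feature maps layer by layer with
/// an element-wise L1 or L2 distance (selected by `p_norm`), averages each
/// layer over its elements, and sums the layers using `layer_weights`.
///
/// A minimal usage sketch (assumes the `array!` macro is in scope; crate paths
/// omitted):
///
/// ```ignore
/// let loss = FeatureDistillationLoss::new(vec![1.0], 2.0)?;
/// let student = array![[1.0, 2.0], [3.0, 4.0]];
/// let teacher = array![[1.1, 2.1], [3.1, 4.1]];
/// let value = loss.compute_feature_loss(&[student.view()], &[teacher.view()])?;
/// ```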
pub struct FeatureDistillationLoss {
    /// Per-layer weights applied to the averaged layer distances.
    pub layer_weights: Vec<f64>,
    /// Distance norm: `1.0` for L1 (absolute difference) or `2.0` for L2 (squared difference).
    pub p_norm: f64,
}

impl FeatureDistillationLoss {
    /// Creates a feature distillation loss, validating that at least one layer
    /// weight is given and that `p_norm` is either `1.0` or `2.0`.
    pub fn new(layer_weights: Vec<f64>, p_norm: f64) -> TrainResult<Self> {
        if layer_weights.is_empty() {
            return Err(TrainError::ConfigError(
                "Must specify at least one layer weight".to_string(),
            ));
        }

        if p_norm != 1.0 && p_norm != 2.0 {
            return Err(TrainError::ConfigError(
                "p_norm must be 1.0 or 2.0".to_string(),
            ));
        }

        Ok(Self {
            layer_weights,
            p_norm,
        })
    }

    /// Computes the weighted feature-matching loss across all layers.
    ///
    /// The two slices must contain the same number of layers as `layer_weights`,
    /// and corresponding student/teacher feature maps must have identical shapes.
    pub fn compute_feature_loss(
        &self,
        student_features: &[ArrayView<f64, Ix2>],
        teacher_features: &[ArrayView<f64, Ix2>],
    ) -> TrainResult<f64> {
        if student_features.len() != teacher_features.len() {
            return Err(TrainError::LossError(
                "Number of student and teacher feature layers must match".to_string(),
            ));
        }

        if student_features.len() != self.layer_weights.len() {
            return Err(TrainError::LossError(format!(
                "Number of layers ({}) must match number of weights ({})",
                student_features.len(),
                self.layer_weights.len()
            )));
        }

        let mut total_loss = 0.0;

        for (i, (student_feat, teacher_feat)) in student_features
            .iter()
            .zip(teacher_features.iter())
            .enumerate()
        {
            if student_feat.shape() != teacher_feat.shape() {
                return Err(TrainError::LossError(format!(
                    "Layer {} shape mismatch: {:?} vs {:?}",
                    i,
                    student_feat.shape(),
                    teacher_feat.shape()
                )));
            }

            // Element-wise L1 or L2 distance over the layer.
            let mut layer_loss = 0.0;
            for (&s, &t) in student_feat.iter().zip(teacher_feat.iter()) {
                let diff = (s - t).abs();
                layer_loss += if self.p_norm == 2.0 { diff * diff } else { diff };
            }

            // Average over elements, then apply the layer weight.
            let n_elements = student_feat.len() as f64;
            layer_loss /= n_elements;

            total_loss += self.layer_weights[i] * layer_loss;
        }

        Ok(total_loss)
    }
}

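/// Attention transfer loss.
///
/// Each attention map is mapped element-wise through `|x|^beta`, row-normalized
/// to sum to one, and the normalized maps are compared with a mean squared error.
///
/// A minimal usage sketch (assumes the `array!` macro is in scope; crate paths
/// omitted):
///
/// ```ignore
/// let loss = AttentionTransferLoss::new(2.0);
/// let student = array![[0.3, 0.5, 0.2]];
/// let teacher = array![[0.35, 0.45, 0.2]];
/// let value = loss.compute_attention_loss(&student.view(), &teacher.view())?;
/// ```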
pub struct AttentionTransferLoss {
    /// Exponent applied to the absolute attention values before normalization.
    pub beta: f64,
}

impl AttentionTransferLoss {
    /// Creates an attention transfer loss with the given exponent `beta`.
    pub fn new(beta: f64) -> Self {
        Self { beta }
    }

    /// Computes the mean squared error between the normalized student and
    /// teacher attention maps, which must have identical shapes.
    pub fn compute_attention_loss(
        &self,
        student_attention: &ArrayView<f64, Ix2>,
        teacher_attention: &ArrayView<f64, Ix2>,
    ) -> TrainResult<f64> {
        if student_attention.shape() != teacher_attention.shape() {
            return Err(TrainError::LossError(format!(
                "Attention maps must have same shape: {:?} vs {:?}",
                student_attention.shape(),
                teacher_attention.shape()
            )));
        }

        let student_norm = self.normalize_attention(student_attention);
        let teacher_norm = self.normalize_attention(teacher_attention);

        // Mean squared error over all elements of the normalized maps.
        let mut loss = 0.0;
        for (s, t) in student_norm.iter().zip(teacher_norm.iter()) {
            let diff = s - t;
            loss += diff * diff;
        }

        let n_elements = student_norm.len() as f64;
        Ok(loss / n_elements)
    }

    /// Applies `|x|^beta` element-wise and normalizes each row to sum to one;
    /// rows with a negligible sum are left unnormalized.
    fn normalize_attention(&self, attention: &ArrayView<f64, Ix2>) -> Array<f64, Ix2> {
        let mut normalized = attention.mapv(|x| x.abs().powf(self.beta));

        for mut row in normalized.rows_mut() {
            let sum: f64 = row.iter().sum();
            if sum > 1e-8 {
                row.mapv_inplace(|x| x / sum);
            }
        }

        normalized
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::CrossEntropyLoss;
    use scirs2_core::array;

    #[test]
    fn test_distillation_loss_creation() {
        let loss = DistillationLoss::new(3.0, 0.7, Box::new(CrossEntropyLoss::default()));
        assert!(loss.is_ok());

        let loss = loss.unwrap();
        assert_eq!(loss.temperature, 3.0);
        assert_eq!(loss.alpha, 0.7);
    }

    #[test]
    fn test_distillation_invalid_temperature() {
        let result = DistillationLoss::new(0.0, 0.5, Box::new(CrossEntropyLoss::default()));
        assert!(result.is_err());

        let result = DistillationLoss::new(-1.0, 0.5, Box::new(CrossEntropyLoss::default()));
        assert!(result.is_err());
    }

    #[test]
    fn test_distillation_invalid_alpha() {
        let result = DistillationLoss::new(3.0, -0.1, Box::new(CrossEntropyLoss::default()));
        assert!(result.is_err());

        let result = DistillationLoss::new(3.0, 1.1, Box::new(CrossEntropyLoss::default()));
        assert!(result.is_err());
    }

    #[test]
    fn test_distillation_compute() {
        let loss = DistillationLoss::new(2.0, 0.5, Box::new(CrossEntropyLoss::default())).unwrap();

        let student_logits = array![[1.0, 2.0, 0.5], [0.5, 1.0, 2.0]];
        let teacher_logits = array![[1.2, 1.8, 0.6], [0.6, 1.1, 1.9]];
        let hard_targets = array![[0.0, 1.0, 0.0], [0.0, 0.0, 1.0]];

        let result = loss.compute_distillation(
            &student_logits.view(),
            &teacher_logits.view(),
            &hard_targets.view(),
        );

        assert!(result.is_ok());
        let loss_value = result.unwrap();
        assert!(loss_value > 0.0);
        assert!(loss_value.is_finite());
    }

    #[test]
    fn test_feature_distillation_loss() {
        let loss = FeatureDistillationLoss::new(vec![0.5, 0.3, 0.2], 2.0).unwrap();

        let s1 = array![[1.0, 2.0], [3.0, 4.0]];
        let s2 = array![[0.5, 1.5], [2.5, 3.5]];
        let s3 = array![[0.1, 0.2], [0.3, 0.4]];
        let student_features = vec![s1.view(), s2.view(), s3.view()];

        let t1 = array![[1.1, 2.1], [3.1, 4.1]];
        let t2 = array![[0.6, 1.6], [2.6, 3.6]];
        let t3 = array![[0.2, 0.3], [0.4, 0.5]];
        let teacher_features = vec![t1.view(), t2.view(), t3.view()];

        let result = loss.compute_feature_loss(&student_features, &teacher_features);
        assert!(result.is_ok());

        let loss_value = result.unwrap();
        assert!(loss_value > 0.0);
        assert!(loss_value < 1.0);
    }

    #[test]
    fn test_attention_transfer_loss() {
        let loss = AttentionTransferLoss::new(2.0);

        let student_attention = array![[0.3, 0.5, 0.2], [0.4, 0.4, 0.2]];
        let teacher_attention = array![[0.35, 0.45, 0.2], [0.35, 0.45, 0.2]];

        let result =
            loss.compute_attention_loss(&student_attention.view(), &teacher_attention.view());
        assert!(result.is_ok());

        let loss_value = result.unwrap();
        assert!(loss_value >= 0.0);
        assert!(loss_value.is_finite());
    }

    #[test]
    fn test_feature_distillation_shape_mismatch() {
        let loss = FeatureDistillationLoss::new(vec![1.0], 2.0).unwrap();

        let s1 = array![[1.0, 2.0]];
        let student_features = vec![s1.view()];

        let t1 = array![[1.0, 2.0, 3.0]];
        let teacher_features = vec![t1.view()];

        let result = loss.compute_feature_loss(&student_features, &teacher_features);
        assert!(result.is_err());
    }
}