Skip to main content

tensorlogic_train/
distillation.rs

//! Knowledge distillation utilities for model compression and transfer learning.
//!
//! This module provides utilities for knowledge distillation, where a smaller "student"
//! model learns from a larger "teacher" model's outputs.

6use crate::{Loss, TrainError, TrainResult};
7use scirs2_core::ndarray::{Array, ArrayView, Ix2};
8
9/// Knowledge distillation loss that combines student predictions with teacher soft targets.
10///
11/// Based on "Distilling the Knowledge in a Neural Network" (Hinton et al., 2015).
12pub struct DistillationLoss {
13    /// Temperature for softening probabilities (higher = softer).
14    pub temperature: f64,
15    /// Weight for distillation loss (1 - alpha for hard target loss).
16    pub alpha: f64,
17    /// Base loss function for hard targets.
18    pub hard_loss: Box<dyn Loss>,
19}
20
21impl DistillationLoss {
22    /// Create a new distillation loss.
23    ///
24    /// # Arguments
25    /// * `temperature` - Temperature for softening (typically 2.0-5.0)
26    /// * `alpha` - Weight for soft targets (typically 0.5-0.9)
27    /// * `hard_loss` - Loss function for hard targets
28    pub fn new(temperature: f64, alpha: f64, hard_loss: Box<dyn Loss>) -> TrainResult<Self> {
29        if temperature <= 0.0 {
30            return Err(TrainError::ConfigError(
31                "Temperature must be positive".to_string(),
32            ));
33        }
34
35        if !(0.0..=1.0).contains(&alpha) {
36            return Err(TrainError::ConfigError(
37                "Alpha must be between 0 and 1".to_string(),
38            ));
39        }
40
41        Ok(Self {
42            temperature,
43            alpha,
44            hard_loss,
45        })
46    }
47
48    /// Compute distillation loss combining soft and hard targets.
49    ///
50    /// # Arguments
51    /// * `student_logits` - Raw student model outputs (before softmax)
52    /// * `teacher_logits` - Raw teacher model outputs (before softmax)
53    /// * `hard_targets` - True labels (one-hot encoded)
54    ///
55    /// # Returns
56    /// Combined distillation loss
57    pub fn compute_distillation(
58        &self,
59        student_logits: &ArrayView<f64, Ix2>,
60        teacher_logits: &ArrayView<f64, Ix2>,
61        hard_targets: &ArrayView<f64, Ix2>,
62    ) -> TrainResult<f64> {
63        if student_logits.shape() != teacher_logits.shape() {
64            return Err(TrainError::LossError(format!(
65                "Student and teacher logits must have same shape: {:?} vs {:?}",
66                student_logits.shape(),
67                teacher_logits.shape()
68            )));
69        }
70
71        // Soft targets loss (KL divergence of softened distributions)
72        let soft_loss =
73            self.compute_kl_divergence_with_temperature(student_logits, teacher_logits)?;
74
75        // Hard targets loss
76        let hard_loss = self.hard_loss.compute(student_logits, hard_targets)?;
77
78        // Combine with weighting
79        // Note: soft loss is scaled by T^2 as per original paper
80        let t_squared = self.temperature * self.temperature;
81        let combined_loss = self.alpha * soft_loss * t_squared + (1.0 - self.alpha) * hard_loss;
82
83        Ok(combined_loss)
84    }
85
86    /// Compute KL divergence between temperature-scaled distributions.
87    fn compute_kl_divergence_with_temperature(
88        &self,
89        student_logits: &ArrayView<f64, Ix2>,
90        teacher_logits: &ArrayView<f64, Ix2>,
91    ) -> TrainResult<f64> {
92        let t = self.temperature;
93
94        let mut total_loss = 0.0;
95        let n_samples = student_logits.nrows();
96
97        for i in 0..n_samples {
98            // Scale by temperature and apply softmax
99            let student_probs = self.softmax_with_temperature(&student_logits.row(i), t);
100            let teacher_probs = self.softmax_with_temperature(&teacher_logits.row(i), t);
101
102            // KL divergence: sum(teacher * log(teacher / student))
103            for j in 0..student_probs.len() {
104                if teacher_probs[j] > 1e-8 {
105                    let ratio = teacher_probs[j] / (student_probs[j] + 1e-8);
106                    total_loss += teacher_probs[j] * ratio.ln();
107                }
108            }
109        }
110
111        Ok(total_loss / n_samples as f64)
112    }
113
114    /// Apply softmax with temperature scaling.
115    fn softmax_with_temperature(
116        &self,
117        logits: &ArrayView<f64, scirs2_core::ndarray::Ix1>,
118        temperature: f64,
119    ) -> Vec<f64> {
120        let scaled: Vec<f64> = logits.iter().map(|&x| x / temperature).collect();
121
122        let max_val = scaled.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
123        let exp_vals: Vec<f64> = scaled.iter().map(|&x| (x - max_val).exp()).collect();
124        let sum: f64 = exp_vals.iter().sum();
125
126        exp_vals.iter().map(|&x| x / sum).collect()
127    }
128}
129
130/// Feature-based distillation that matches intermediate layer representations.
131pub struct FeatureDistillationLoss {
132    /// Weight for each feature layer.
133    pub layer_weights: Vec<f64>,
134    /// Distance metric (2.0 for L2, 1.0 for L1).
135    pub p_norm: f64,
136}
137
138impl FeatureDistillationLoss {
139    /// Create a new feature distillation loss.
140    ///
141    /// # Arguments
142    /// * `layer_weights` - Weights for each intermediate layer
143    /// * `p_norm` - Norm to use for distance (1.0 or 2.0)
144    pub fn new(layer_weights: Vec<f64>, p_norm: f64) -> TrainResult<Self> {
145        if layer_weights.is_empty() {
146            return Err(TrainError::ConfigError(
147                "Must specify at least one layer weight".to_string(),
148            ));
149        }
150
151        if p_norm != 1.0 && p_norm != 2.0 {
152            return Err(TrainError::ConfigError(
153                "p_norm must be 1.0 or 2.0".to_string(),
154            ));
155        }
156
157        Ok(Self {
158            layer_weights,
159            p_norm,
160        })
161    }
162
163    /// Compute feature matching loss for intermediate representations.
164    ///
165    /// # Arguments
166    /// * `student_features` - Student model's intermediate features
167    /// * `teacher_features` - Teacher model's intermediate features
168    ///
169    /// # Returns
170    /// Weighted sum of feature matching losses
171    pub fn compute_feature_loss(
172        &self,
173        student_features: &[ArrayView<f64, Ix2>],
174        teacher_features: &[ArrayView<f64, Ix2>],
175    ) -> TrainResult<f64> {
176        if student_features.len() != teacher_features.len() {
177            return Err(TrainError::LossError(
178                "Number of student and teacher feature layers must match".to_string(),
179            ));
180        }
181
182        if student_features.len() != self.layer_weights.len() {
183            return Err(TrainError::LossError(format!(
184                "Number of layers ({}) must match number of weights ({})",
185                student_features.len(),
186                self.layer_weights.len()
187            )));
188        }
189
190        let mut total_loss = 0.0;
191
192        for (i, (student_feat, teacher_feat)) in student_features
193            .iter()
194            .zip(teacher_features.iter())
195            .enumerate()
196        {
197            if student_feat.shape() != teacher_feat.shape() {
198                return Err(TrainError::LossError(format!(
199                    "Layer {} shape mismatch: {:?} vs {:?}",
200                    i,
201                    student_feat.shape(),
202                    teacher_feat.shape()
203                )));
204            }
205
206            // Compute distance
207            let mut layer_loss = 0.0;
208            for (&s, &t) in student_feat.iter().zip(teacher_feat.iter()) {
209                let diff = (s - t).abs();
210                layer_loss += if self.p_norm == 2.0 {
211                    diff * diff
212                } else {
213                    diff
214                };
215            }
216
217            // Normalize by number of elements
218            let n_elements = student_feat.len() as f64;
219            layer_loss /= n_elements;
220
221            // Apply layer weight
222            total_loss += self.layer_weights[i] * layer_loss;
223        }
224
225        Ok(total_loss)
226    }
227}
228
229/// Attention transfer for distillation based on attention maps.
230pub struct AttentionTransferLoss {
231    /// Beta parameter for attention map normalization.
232    pub beta: f64,
233}
234
235impl AttentionTransferLoss {
236    /// Create a new attention transfer loss.
237    ///
238    /// # Arguments
239    /// * `beta` - Power for attention map normalization (typically 2.0)
240    pub fn new(beta: f64) -> Self {
241        Self { beta }
242    }
243
244    /// Compute attention transfer loss.
245    ///
246    /// # Arguments
247    /// * `student_attention` - Student attention maps
248    /// * `teacher_attention` - Teacher attention maps
249    ///
250    /// # Returns
251    /// Attention transfer loss
252    pub fn compute_attention_loss(
253        &self,
254        student_attention: &ArrayView<f64, Ix2>,
255        teacher_attention: &ArrayView<f64, Ix2>,
256    ) -> TrainResult<f64> {
257        if student_attention.shape() != teacher_attention.shape() {
258            return Err(TrainError::LossError(format!(
259                "Attention maps must have same shape: {:?} vs {:?}",
260                student_attention.shape(),
261                teacher_attention.shape()
262            )));
263        }
264
265        // Normalize attention maps
266        let student_norm = self.normalize_attention(student_attention);
267        let teacher_norm = self.normalize_attention(teacher_attention);
268
269        // Compute L2 distance
270        let mut loss = 0.0;
271        for (s, t) in student_norm.iter().zip(teacher_norm.iter()) {
272            let diff = s - t;
273            loss += diff * diff;
274        }
275
276        let n_elements = student_norm.len() as f64;
277        Ok(loss / n_elements)
278    }
279
280    /// Normalize attention map using beta-power normalization.
281    fn normalize_attention(&self, attention: &ArrayView<f64, Ix2>) -> Array<f64, Ix2> {
282        let mut normalized = attention.mapv(|x| x.abs().powf(self.beta));
283
284        // Normalize each sample
285        for mut row in normalized.rows_mut() {
286            let sum: f64 = row.iter().sum();
287            if sum > 1e-8 {
288                row.mapv_inplace(|x| x / sum);
289            }
290        }
291
292        normalized
293    }
294}
295
#[cfg(test)]
mod tests {
    use super::*;
    use crate::CrossEntropyLoss;
    use scirs2_core::array;

    /// Builds a distillation loss around a default cross-entropy hard-target loss.
    fn distillation(temperature: f64, alpha: f64) -> TrainResult<DistillationLoss> {
        DistillationLoss::new(temperature, alpha, Box::new(CrossEntropyLoss::default()))
    }

    #[test]
    fn test_distillation_loss_creation() {
        let built = distillation(3.0, 0.7);
        assert!(built.is_ok());

        let built = built.unwrap();
        assert_eq!(built.temperature, 3.0);
        assert_eq!(built.alpha, 0.7);
    }

    #[test]
    fn test_distillation_invalid_temperature() {
        // Zero and negative temperatures are both rejected.
        for bad_temperature in [0.0, -1.0] {
            assert!(distillation(bad_temperature, 0.5).is_err());
        }
    }

    #[test]
    fn test_distillation_invalid_alpha() {
        // Alpha just below 0 and just above 1 are both rejected.
        for bad_alpha in [-0.1, 1.1] {
            assert!(distillation(3.0, bad_alpha).is_err());
        }
    }

    #[test]
    fn test_distillation_compute() {
        let loss_fn = distillation(2.0, 0.5).unwrap();

        let student = array![[1.0, 2.0, 0.5], [0.5, 1.0, 2.0]];
        let teacher = array![[1.2, 1.8, 0.6], [0.6, 1.1, 1.9]];
        let targets = array![[0.0, 1.0, 0.0], [0.0, 0.0, 1.0]];

        let outcome =
            loss_fn.compute_distillation(&student.view(), &teacher.view(), &targets.view());
        assert!(outcome.is_ok());

        let value = outcome.unwrap();
        assert!(value > 0.0);
        assert!(value.is_finite());
    }

    #[test]
    fn test_feature_distillation_loss() {
        let loss_fn = FeatureDistillationLoss::new(vec![0.5, 0.3, 0.2], 2.0).unwrap();

        let student_arrays = [
            array![[1.0, 2.0], [3.0, 4.0]],
            array![[0.5, 1.5], [2.5, 3.5]],
            array![[0.1, 0.2], [0.3, 0.4]],
        ];
        let teacher_arrays = [
            array![[1.1, 2.1], [3.1, 4.1]],
            array![[0.6, 1.6], [2.6, 3.6]],
            array![[0.2, 0.3], [0.4, 0.5]],
        ];
        let student_views: Vec<_> = student_arrays.iter().map(|a| a.view()).collect();
        let teacher_views: Vec<_> = teacher_arrays.iter().map(|a| a.view()).collect();

        let outcome = loss_fn.compute_feature_loss(&student_views, &teacher_views);
        assert!(outcome.is_ok());

        let value = outcome.unwrap();
        assert!(value > 0.0);
        assert!(value < 1.0); // Should be small for similar features
    }

    #[test]
    fn test_attention_transfer_loss() {
        let loss_fn = AttentionTransferLoss::new(2.0);

        let student = array![[0.3, 0.5, 0.2], [0.4, 0.4, 0.2]];
        let teacher = array![[0.35, 0.45, 0.2], [0.35, 0.45, 0.2]];

        let outcome = loss_fn.compute_attention_loss(&student.view(), &teacher.view());
        assert!(outcome.is_ok());

        let value = outcome.unwrap();
        assert!(value >= 0.0);
        assert!(value.is_finite());
    }

    #[test]
    fn test_feature_distillation_shape_mismatch() {
        let loss_fn = FeatureDistillationLoss::new(vec![1.0], 2.0).unwrap();

        let narrow = array![[1.0, 2.0]];
        let wide = array![[1.0, 2.0, 3.0]];

        let outcome = loss_fn.compute_feature_loss(&[narrow.view()], &[wide.view()]);
        assert!(outcome.is_err());
    }
}