// trustformers_models/knowledge_distillation.rs
//! # Knowledge Distillation Framework
//!
//! This module provides a comprehensive framework for knowledge distillation,
//! enabling efficient transfer of knowledge from large teacher models to smaller student models.
//!
//! ## Features
//!
//! - **Multiple Distillation Strategies**: Response-based, feature-based, and attention-based distillation
//! - **Temperature Control**: Configurable temperature scaling for soft targets
//! - **Loss Combinations**: Flexible combination of distillation and task-specific losses
//! - **Multi-layer Feature Matching**: Deep feature alignment between teacher and student
//! - **Attention Transfer**: Transfer attention patterns from teacher to student
//! - **Progressive Knowledge Transfer**: Gradual knowledge transfer strategies
//!
//! ## Usage
//!
//! ```rust,ignore
//! use trustformers_models::knowledge_distillation::{
//!     KnowledgeDistillationTrainer, DistillationConfig, DistillationStrategy
//! };
//!
//! let config = DistillationConfig {
//!     temperature: 4.0,
//!     alpha: 0.7,
//!     strategy: DistillationStrategy::ResponseBased,
//!     ..Default::default()
//! };
//!
//! let trainer = KnowledgeDistillationTrainer::new(teacher_model, student_model, config)?;
//! trainer.train(dataloader)?;
//! ```

33use scirs2_core::ndarray::{ArrayD, Axis, IxDyn}; // SciRS2 Integration Policy
34use serde::{Deserialize, Serialize};
35use std::collections::HashMap;
36use trustformers_core::{
37    errors::{tensor_op_error, TrustformersError},
38    layers::Linear,
39    tensor::Tensor,
40    traits::{Layer, Model},
41    Result,
42};
43
/// Configuration for knowledge distillation.
///
/// All knobs have sensible defaults (see [`Default`]); the optional transfer
/// mechanisms (feature matching, attention transfer, progressive schedule)
/// are disabled unless explicitly enabled.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DistillationConfig {
    /// Temperature for softmax in distillation loss (higher = softer targets)
    pub temperature: f32,
    /// Weight for distillation loss vs. hard target loss (0.0-1.0);
    /// the hard-target loss is weighted by `1.0 - alpha`
    pub alpha: f32,
    /// Distillation strategy to use
    pub strategy: DistillationStrategy,
    /// Whether to use feature matching
    pub use_feature_matching: bool,
    /// Layers to match features (teacher_layer -> student_layer)
    pub feature_matching_layers: HashMap<usize, usize>,
    /// Whether to use attention transfer
    pub use_attention_transfer: bool,
    /// Weight for attention transfer loss
    pub attention_loss_weight: f32,
    /// Whether to use progressive distillation
    pub progressive: bool,
    /// Number of progressive stages
    pub progressive_stages: usize,
    /// Minimum temperature for progressive cooling (fallback temperature
    /// once all progressive stages are exhausted)
    pub min_temperature: f32,
}
68
69impl Default for DistillationConfig {
70    fn default() -> Self {
71        Self {
72            temperature: 4.0,
73            alpha: 0.7,
74            strategy: DistillationStrategy::ResponseBased,
75            use_feature_matching: false,
76            feature_matching_layers: HashMap::new(),
77            use_attention_transfer: false,
78            attention_loss_weight: 0.1,
79            progressive: false,
80            progressive_stages: 5,
81            min_temperature: 1.0,
82        }
83    }
84}
85
/// Different strategies for knowledge distillation.
///
/// Selected via [`DistillationConfig::strategy`]; the trainer matches on this
/// enum to decide which loss terms to compute.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum DistillationStrategy {
    /// Standard response-based distillation using soft targets
    ResponseBased,
    /// Feature-based distillation matching intermediate representations
    FeatureBased,
    /// Attention-based distillation transferring attention patterns
    AttentionBased,
    /// Combined approach using multiple strategies; each weight gates its
    /// corresponding loss term (a weight of 0.0 skips that term entirely)
    Combined {
        response_weight: f32,
        feature_weight: f32,
        attention_weight: f32,
    },
    /// Progressive distillation with curriculum learning; stages are
    /// consumed in order, each for its configured duration
    Progressive { stages: Vec<ProgressiveStage> },
}
104
/// Configuration for a progressive distillation stage.
///
/// Used by [`DistillationStrategy::Progressive`]; the trainer's `step()`
/// advances to the next stage once `duration` steps have elapsed.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProgressiveStage {
    /// Duration of this stage (in training steps)
    pub duration: usize,
    /// Temperature for this stage
    pub temperature: f32,
    /// Alpha weight for this stage
    pub alpha: f32,
    /// Whether to freeze teacher layers
    pub freeze_teacher: bool,
}
117
/// Output from distillation loss computation.
///
/// Bundles the scalar training loss with per-component diagnostics so
/// callers can log individual contributions without recomputing them.
#[derive(Debug, Clone)]
pub struct DistillationOutput {
    /// Total distillation loss
    pub total_loss: Tensor,
    /// Individual loss components keyed by name (e.g. "response", "feature")
    pub loss_components: HashMap<String, Tensor>,
    /// Soft predictions from teacher (raw logits, cloned for inspection)
    pub teacher_predictions: Tensor,
    /// Predictions from student (raw logits, cloned for inspection)
    pub student_predictions: Tensor,
    /// Feature matching losses (if used)
    pub feature_losses: Option<HashMap<String, Tensor>>,
    /// Attention transfer losses (if used)
    pub attention_losses: Option<HashMap<String, Tensor>>,
}
134
/// Knowledge distillation trainer.
///
/// Holds the teacher/student model pair plus the distillation configuration
/// and mutable progress counters for progressive schedules.
pub struct KnowledgeDistillationTrainer<T, S> {
    /// Teacher model (typically larger, pre-trained)
    #[allow(dead_code)]
    teacher: T,
    /// Student model (typically smaller, being trained)
    #[allow(dead_code)]
    student: S,
    /// Distillation configuration
    config: DistillationConfig,
    /// Feature matching projections, keyed by student layer index
    feature_projections: HashMap<usize, Linear>,
    /// Current training stage (for progressive distillation)
    current_stage: usize,
    /// Current training step, counted within the current stage
    current_step: usize,
}
152
153impl<T, S> KnowledgeDistillationTrainer<T, S>
154where
155    T: Model,
156    S: Model,
157{
158    /// Create a new knowledge distillation trainer
159    pub fn new(teacher: T, student: S, config: DistillationConfig) -> Result<Self> {
160        let mut feature_projections = HashMap::new();
161
162        // Initialize feature projection layers if needed
163        if config.use_feature_matching {
164            for (&_teacher_layer, &student_layer) in &config.feature_matching_layers {
165                // Note: In practice, you'd need to get the actual hidden sizes
166                // This is a simplified example
167                let projection = Linear::new(768, 768, true); // Assuming 768 hidden size
168                feature_projections.insert(student_layer, projection);
169            }
170        }
171
172        Ok(Self {
173            teacher,
174            student,
175            config,
176            feature_projections,
177            current_stage: 0,
178            current_step: 0,
179        })
180    }
181
182    /// Compute distillation loss
183    pub fn compute_distillation_loss(
184        &self,
185        teacher_outputs: &TeacherOutputs,
186        student_outputs: &StudentOutputs,
187        hard_targets: Option<&Tensor>,
188    ) -> Result<DistillationOutput> {
189        let mut loss_components = HashMap::new();
190        let mut total_loss = Tensor::zeros(&[1])?;
191
192        match &self.config.strategy {
193            DistillationStrategy::ResponseBased => {
194                let response_loss = self.compute_response_distillation_loss(
195                    &teacher_outputs.logits,
196                    &student_outputs.logits,
197                )?;
198                loss_components.insert("response".to_string(), response_loss.clone());
199                total_loss = total_loss.add(&response_loss)?;
200            },
201            DistillationStrategy::FeatureBased => {
202                let feature_loss = self.compute_feature_distillation_loss(
203                    &teacher_outputs.hidden_states,
204                    &student_outputs.hidden_states,
205                )?;
206                loss_components.insert("feature".to_string(), feature_loss.clone());
207                total_loss = total_loss.add(&feature_loss)?;
208            },
209            DistillationStrategy::AttentionBased => {
210                let attention_loss = self.compute_attention_distillation_loss(
211                    &teacher_outputs.attentions,
212                    &student_outputs.attentions,
213                )?;
214                loss_components.insert("attention".to_string(), attention_loss.clone());
215                total_loss = total_loss.add(&attention_loss)?;
216            },
217            DistillationStrategy::Combined {
218                response_weight,
219                feature_weight,
220                attention_weight,
221            } => {
222                if *response_weight > 0.0 {
223                    let response_loss = self.compute_response_distillation_loss(
224                        &teacher_outputs.logits,
225                        &student_outputs.logits,
226                    )?;
227                    let weighted_response_loss = response_loss.scalar_mul(*response_weight)?;
228                    loss_components.insert("response".to_string(), weighted_response_loss.clone());
229                    total_loss = total_loss.add(&weighted_response_loss)?;
230                }
231
232                if *feature_weight > 0.0 && !teacher_outputs.hidden_states.is_empty() {
233                    let feature_loss = self.compute_feature_distillation_loss(
234                        &teacher_outputs.hidden_states,
235                        &student_outputs.hidden_states,
236                    )?;
237                    let weighted_feature_loss = feature_loss.scalar_mul(*feature_weight)?;
238                    loss_components.insert("feature".to_string(), weighted_feature_loss.clone());
239                    total_loss = total_loss.add(&weighted_feature_loss)?;
240                }
241
242                if *attention_weight > 0.0 && !teacher_outputs.attentions.is_empty() {
243                    let attention_loss = self.compute_attention_distillation_loss(
244                        &teacher_outputs.attentions,
245                        &student_outputs.attentions,
246                    )?;
247                    let weighted_attention_loss = attention_loss.scalar_mul(*attention_weight)?;
248                    loss_components
249                        .insert("attention".to_string(), weighted_attention_loss.clone());
250                    total_loss = total_loss.add(&weighted_attention_loss)?;
251                }
252            },
253            DistillationStrategy::Progressive { stages } => {
254                let current_stage = &stages[self.current_stage.min(stages.len() - 1)];
255                let response_loss = self.compute_response_distillation_loss_with_temperature(
256                    &teacher_outputs.logits,
257                    &student_outputs.logits,
258                    current_stage.temperature,
259                )?;
260                loss_components.insert("progressive_response".to_string(), response_loss.clone());
261                total_loss = total_loss.add(&response_loss)?;
262            },
263        }
264
265        // Add hard target loss if provided
266        if let Some(targets) = hard_targets {
267            let hard_loss = self.compute_hard_target_loss(&student_outputs.logits, targets)?;
268            let weighted_hard_loss = hard_loss.scalar_mul(1.0 - self.config.alpha)?;
269            loss_components.insert("hard_target".to_string(), weighted_hard_loss.clone());
270            total_loss = total_loss.add(&weighted_hard_loss)?;
271        }
272
273        // Collect feature losses for tracking
274        let feature_losses = if !teacher_outputs.hidden_states.is_empty()
275            && !student_outputs.hidden_states.is_empty()
276        {
277            Some(self.compute_layer_wise_feature_losses(
278                &teacher_outputs.hidden_states,
279                &student_outputs.hidden_states,
280            )?)
281        } else {
282            None
283        };
284
285        // Collect attention losses for tracking
286        let attention_losses =
287            if !teacher_outputs.attentions.is_empty() && !student_outputs.attentions.is_empty() {
288                Some(self.compute_layer_wise_attention_losses(
289                    &teacher_outputs.attentions,
290                    &student_outputs.attentions,
291                )?)
292            } else {
293                None
294            };
295
296        Ok(DistillationOutput {
297            total_loss,
298            loss_components,
299            teacher_predictions: teacher_outputs.logits.clone(),
300            student_predictions: student_outputs.logits.clone(),
301            feature_losses,
302            attention_losses,
303        })
304    }
305
306    /// Compute response-based distillation loss (KL divergence of soft targets)
307    fn compute_response_distillation_loss(
308        &self,
309        teacher_logits: &Tensor,
310        student_logits: &Tensor,
311    ) -> Result<Tensor> {
312        self.compute_response_distillation_loss_with_temperature(
313            teacher_logits,
314            student_logits,
315            self.config.temperature,
316        )
317    }
318
    /// Compute response-based distillation loss with custom temperature.
    ///
    /// Implements the classic soft-target KL divergence:
    /// `T^2 * KL(softmax(t/T) || softmax(s/T))`, where the `T^2` factor keeps
    /// gradient magnitudes comparable across temperatures.
    ///
    /// NOTE(review): probabilities are logged directly (`softmax` then `log`)
    /// rather than via a fused log-softmax, so an exact zero in either softmax
    /// would yield -inf; assumes the Tensor softmax never returns exact zeros
    /// — TODO confirm.
    fn compute_response_distillation_loss_with_temperature(
        &self,
        teacher_logits: &Tensor,
        student_logits: &Tensor,
        temperature: f32,
    ) -> Result<Tensor> {
        // Apply temperature scaling
        let teacher_scaled = teacher_logits.scalar_div(temperature)?;
        let student_scaled = student_logits.scalar_div(temperature)?;

        // Compute soft targets (softmax with temperature)
        let teacher_soft = teacher_scaled.softmax(-1)?;
        let student_soft = student_scaled.softmax(-1)?;
        let student_log_soft = student_soft.log()?;

        // KL divergence loss: sum(p_t * (log p_t - log p_s)), then mean
        let teacher_log = teacher_soft.log()?;
        let log_diff = teacher_log.sub(&student_log_soft)?;
        let kl_div = teacher_soft.mul(&log_diff)?;
        let loss = kl_div.sum(None, false)?.mean()?;

        // Scale by temperature squared (standard in knowledge distillation)
        let temp_squared = temperature * temperature;
        loss.scalar_mul(temp_squared)
    }
345
346    /// Compute feature-based distillation loss
347    fn compute_feature_distillation_loss(
348        &self,
349        teacher_features: &[Tensor],
350        student_features: &[Tensor],
351    ) -> Result<Tensor> {
352        let mut total_loss = Tensor::zeros(&[1])?;
353        let mut num_matched = 0;
354
355        for (&teacher_layer, &student_layer) in &self.config.feature_matching_layers {
356            if teacher_layer < teacher_features.len() && student_layer < student_features.len() {
357                let teacher_feat = &teacher_features[teacher_layer];
358                let student_feat = &student_features[student_layer];
359
360                // Project student features to match teacher dimensionality if needed
361                let projected_student =
362                    if let Some(projection) = self.feature_projections.get(&student_layer) {
363                        projection.forward(student_feat.clone())?
364                    } else {
365                        student_feat.clone()
366                    };
367
368                // MSE loss between features
369                let diff = teacher_feat.sub(&projected_student)?;
370                let diff_squared = diff.mul(&diff)?;
371                let mse_loss = diff_squared.mean()?;
372                total_loss = total_loss.add(&mse_loss)?;
373                num_matched += 1;
374            }
375        }
376
377        if num_matched > 0 {
378            Ok(total_loss.scalar_div(num_matched as f32)?)
379        } else {
380            Ok(total_loss)
381        }
382    }
383
384    /// Compute attention-based distillation loss
385    fn compute_attention_distillation_loss(
386        &self,
387        teacher_attentions: &[Tensor],
388        student_attentions: &[Tensor],
389    ) -> Result<Tensor> {
390        let mut total_loss = Tensor::zeros(&[1])?;
391        let num_layers = teacher_attentions.len().min(student_attentions.len());
392
393        for i in 0..num_layers {
394            let teacher_attn = &teacher_attentions[i];
395            let student_attn = &student_attentions[i];
396
397            // MSE loss between attention matrices
398            let diff = teacher_attn.sub(student_attn)?;
399            let diff_squared = diff.mul(&diff)?;
400            let mse_loss = diff_squared.mean()?;
401            total_loss = total_loss.add(&mse_loss)?;
402        }
403
404        if num_layers > 0 {
405            Ok(total_loss.scalar_div(num_layers as f32)?)
406        } else {
407            Ok(total_loss)
408        }
409    }
410
    /// Compute hard target loss (standard cross-entropy).
    ///
    /// NOTE(review): this is an acknowledged placeholder — `_targets` is
    /// ignored and the value is the negative mean of log-softmax over ALL
    /// classes, not the cross-entropy at the target indices. A real
    /// implementation needs a gather/index-select over the target class
    /// per example — TODO when the Tensor API provides one.
    fn compute_hard_target_loss(&self, logits: &Tensor, _targets: &Tensor) -> Result<Tensor> {
        let probs = logits.softmax(-1)?;
        let log_probs = probs.log()?;

        // Simplified cross-entropy implementation - in practice this would need proper indexing
        // For now, compute mean of log probs as a placeholder
        let neg_log_probs = log_probs.scalar_mul(-1.0)?;
        neg_log_probs.mean()
    }
421
422    /// Update training step and potentially stage for progressive distillation
423    pub fn step(&mut self) {
424        self.current_step += 1;
425
426        if let DistillationStrategy::Progressive { stages } = &self.config.strategy {
427            // Check if we should advance to the next stage
428            if self.current_stage < stages.len() - 1 {
429                let current_stage_config = &stages[self.current_stage];
430                if self.current_step >= current_stage_config.duration {
431                    self.current_stage += 1;
432                    self.current_step = 0;
433                }
434            }
435        }
436    }
437
438    /// Get current temperature (useful for progressive distillation)
439    pub fn current_temperature(&self) -> f32 {
440        match &self.config.strategy {
441            DistillationStrategy::Progressive { stages } => {
442                if self.current_stage < stages.len() {
443                    stages[self.current_stage].temperature
444                } else {
445                    self.config.min_temperature
446                }
447            },
448            _ => self.config.temperature,
449        }
450    }
451
452    /// Get current alpha (useful for progressive distillation)
453    pub fn current_alpha(&self) -> f32 {
454        match &self.config.strategy {
455            DistillationStrategy::Progressive { stages } => {
456                if self.current_stage < stages.len() {
457                    stages[self.current_stage].alpha
458                } else {
459                    self.config.alpha
460                }
461            },
462            _ => self.config.alpha,
463        }
464    }
465
    /// Compute layer-wise feature losses for detailed tracking.
    ///
    /// Pairs teacher/student hidden states index-by-index up to the shorter
    /// list and records one MSE loss per layer under the key `layer_{idx}`.
    /// When the shapes differ only in the trailing (hidden) dimension, the
    /// student tensor is crudely realigned by value-scaling plus zero-padding
    /// or truncation; any other mismatch falls back to the raw student
    /// tensor, in which case the subsequent `sub` is expected to error.
    ///
    /// NOTE(review): the scale-then-pad/truncate "projection" is a heuristic,
    /// not a learned projection — presumably acceptable because these values
    /// are diagnostics only; confirm they are never backpropagated.
    fn compute_layer_wise_feature_losses(
        &self,
        teacher_hidden_states: &[Tensor],
        student_hidden_states: &[Tensor],
    ) -> Result<HashMap<String, Tensor>> {
        let mut feature_losses = HashMap::new();

        // Ensure we have matching layers (or use the minimum)
        let num_layers = teacher_hidden_states.len().min(student_hidden_states.len());

        for layer_idx in 0..num_layers {
            let teacher_hidden = &teacher_hidden_states[layer_idx];
            let student_hidden = &student_hidden_states[layer_idx];

            // Apply projection if dimensions don't match
            let aligned_student = if teacher_hidden.shape() != student_hidden.shape() {
                // Simple projection to match teacher dimensions
                match (teacher_hidden, student_hidden) {
                    (Tensor::F32(t_arr), Tensor::F32(s_arr)) => {
                        let teacher_shape = t_arr.shape();
                        let student_shape = s_arr.shape();

                        // Only attempt realignment when rank and all leading
                        // (batch/sequence) dimensions agree, i.e. just the
                        // hidden size differs.
                        if teacher_shape.len() == student_shape.len()
                            && teacher_shape[..teacher_shape.len() - 1]
                                == student_shape[..student_shape.len() - 1]
                        {
                            // Only hidden dimension differs, project student to teacher size
                            let teacher_hidden_dim = teacher_shape[teacher_shape.len() - 1];
                            let student_hidden_dim = student_shape[student_shape.len() - 1];

                            if student_hidden_dim != teacher_hidden_dim {
                                // Simple linear projection (in practice, this would be a learned projection)
                                let scale = teacher_hidden_dim as f32 / student_hidden_dim as f32;
                                let projected = s_arr.mapv(|x| x * scale);

                                // Reshape to match teacher dimensions
                                let new_shape = teacher_shape.to_vec();
                                let projected_data = if teacher_hidden_dim > student_hidden_dim {
                                    // Pad with zeros: copy each student-width row
                                    // into the front of a teacher-width row.
                                    let mut padded_data = vec![0.0; new_shape.iter().product()];
                                    let chunk_size = student_hidden_dim;
                                    let total_chunks = s_arr.len() / chunk_size;

                                    for chunk_idx in 0..total_chunks {
                                        let src_start = chunk_idx * chunk_size;
                                        let dst_start = chunk_idx * teacher_hidden_dim;
                                        for i in 0..chunk_size {
                                            padded_data[dst_start + i] = projected[src_start + i];
                                        }
                                    }
                                    padded_data
                                } else {
                                    // Truncate: keep only the first `teacher_hidden_dim`
                                    // values of each student-width row.
                                    let chunk_size = teacher_hidden_dim;
                                    let total_chunks = projected.len() / student_hidden_dim;
                                    let mut truncated_data = Vec::new();

                                    for chunk_idx in 0..total_chunks {
                                        let src_start = chunk_idx * student_hidden_dim;
                                        for i in 0..chunk_size {
                                            truncated_data.push(projected[src_start + i]);
                                        }
                                    }
                                    truncated_data
                                };

                                let projected_array =
                                    ArrayD::from_shape_vec(IxDyn(&new_shape), projected_data)
                                        .map_err(|_| {
                                            TrustformersError::shape_error(
                                                "Failed to project student features".to_string(),
                                            )
                                        })?;

                                Tensor::F32(projected_array)
                            } else {
                                student_hidden.clone()
                            }
                        } else {
                            student_hidden.clone()
                        }
                    },
                    // Non-F32 (or mixed-dtype) pairs are passed through unchanged.
                    _ => student_hidden.clone(),
                }
            } else {
                student_hidden.clone()
            };

            // Compute MSE loss between teacher and (aligned) student features
            let diff = teacher_hidden.sub(&aligned_student)?;
            let squared_diff = diff.mul(&diff)?;
            let mse_loss = squared_diff.mean()?;

            feature_losses.insert(format!("layer_{}", layer_idx), mse_loss);
        }

        Ok(feature_losses)
    }
565
    /// Compute layer-wise attention losses for detailed tracking.
    ///
    /// For each positionally-paired layer this records three entries in the
    /// returned map:
    /// - `layer_{i}`: MSE between teacher and (head-aligned) student maps
    /// - `layer_{i}_entropy`: squared difference of mean attention entropies
    /// - `layer_{i}_correlation`: Pearson correlation of the flattened maps
    ///
    /// NOTE(review): the correlation entry is a similarity score (higher is
    /// better), not a loss — confirm downstream consumers treat it as such.
    fn compute_layer_wise_attention_losses(
        &self,
        teacher_attentions: &[Tensor],
        student_attentions: &[Tensor],
    ) -> Result<HashMap<String, Tensor>> {
        let mut attention_losses = HashMap::new();

        // Ensure we have matching layers
        let num_layers = teacher_attentions.len().min(student_attentions.len());

        for layer_idx in 0..num_layers {
            let teacher_attn = &teacher_attentions[layer_idx];
            let student_attn = &student_attentions[layer_idx];

            // Handle different attention head counts
            let aligned_student_attn = if teacher_attn.shape() != student_attn.shape() {
                self.align_attention_tensors(teacher_attn, student_attn)?
            } else {
                student_attn.clone()
            };

            // Compute attention transfer loss (MSE between attention distributions)
            let diff = teacher_attn.sub(&aligned_student_attn)?;
            let squared_diff = diff.mul(&diff)?;
            let attn_loss = squared_diff.mean()?;

            attention_losses.insert(format!("layer_{}", layer_idx), attn_loss);

            // Additional attention-specific metrics
            // 1. Attention entropy similarity
            let teacher_entropy = self.compute_attention_entropy(teacher_attn)?;
            let student_entropy = self.compute_attention_entropy(&aligned_student_attn)?;
            let entropy_diff = teacher_entropy.sub(&student_entropy)?;
            let entropy_loss = entropy_diff.mul(&entropy_diff)?;
            attention_losses.insert(format!("layer_{}_entropy", layer_idx), entropy_loss);

            // 2. Attention pattern correlation
            let pattern_correlation =
                self.compute_attention_correlation(teacher_attn, &aligned_student_attn)?;
            attention_losses.insert(
                format!("layer_{}_correlation", layer_idx),
                pattern_correlation,
            );
        }

        Ok(attention_losses)
    }
614
    /// Align attention tensors when they have different shapes (e.g., different head counts).
    ///
    /// Expects 4-D attention of shape [batch, heads, seq_len, seq_len].
    /// If the student has fewer heads than the teacher, student heads are
    /// repeated cyclically (`h % student_heads`); if it has more, groups of
    /// `student_heads / teacher_heads` consecutive heads are averaged.
    /// Any other case (non-4D, non-F32, equal head counts) returns the
    /// student tensor unchanged.
    ///
    /// NOTE(review): when `student_heads` is not an exact multiple of
    /// `teacher_heads`, `group_size` truncates, so trailing student heads
    /// are ignored by the averaging path — confirm this is acceptable.
    /// Also assumes head-count alignment is the only mismatch; differing
    /// sequence lengths are not handled here — TODO confirm.
    fn align_attention_tensors(&self, teacher: &Tensor, student: &Tensor) -> Result<Tensor> {
        match (teacher, student) {
            (Tensor::F32(t_arr), Tensor::F32(s_arr)) => {
                let teacher_shape = t_arr.shape();
                let student_shape = s_arr.shape();

                // Assume attention shape is [batch, heads, seq_len, seq_len]
                if teacher_shape.len() == 4 && student_shape.len() == 4 {
                    let teacher_heads = teacher_shape[1];
                    let student_heads = student_shape[1];

                    if teacher_heads != student_heads {
                        // Simple head alignment: average/repeat heads to match teacher count
                        if student_heads < teacher_heads {
                            // Repeat student heads cyclically until the teacher
                            // head count is reached.
                            let _repeat_factor = teacher_heads / student_heads;
                            let mut aligned_data = Vec::new();

                            let batch_size = student_shape[0];
                            let seq_len = student_shape[2];
                            let seq_len_2 = student_shape[3];

                            for b in 0..batch_size {
                                for h in 0..teacher_heads {
                                    // Cyclic source head selection.
                                    let source_head = h % student_heads;
                                    for i in 0..seq_len {
                                        for j in 0..seq_len_2 {
                                            aligned_data.push(s_arr[[b, source_head, i, j]]);
                                        }
                                    }
                                }
                            }

                            let aligned_array = ArrayD::from_shape_vec(
                                IxDyn(&[batch_size, teacher_heads, seq_len, seq_len_2]),
                                aligned_data,
                            )
                            .map_err(|_| {
                                TrustformersError::shape_error(
                                    "Failed to align attention heads".to_string(),
                                )
                            })?;

                            Ok(Tensor::F32(aligned_array))
                        } else {
                            // Average student heads to match teacher count
                            let group_size = student_heads / teacher_heads;
                            let mut aligned_data = Vec::new();

                            let batch_size = student_shape[0];
                            let seq_len = student_shape[2];
                            let seq_len_2 = student_shape[3];

                            for b in 0..batch_size {
                                for h in 0..teacher_heads {
                                    for i in 0..seq_len {
                                        for j in 0..seq_len_2 {
                                            // Mean over the group of student heads
                                            // mapped onto teacher head `h`.
                                            let mut sum = 0.0;
                                            for g in 0..group_size {
                                                let student_head = h * group_size + g;
                                                if student_head < student_heads {
                                                    sum += s_arr[[b, student_head, i, j]];
                                                }
                                            }
                                            aligned_data.push(sum / group_size as f32);
                                        }
                                    }
                                }
                            }

                            let aligned_array = ArrayD::from_shape_vec(
                                IxDyn(&[batch_size, teacher_heads, seq_len, seq_len_2]),
                                aligned_data,
                            )
                            .map_err(|_| {
                                TrustformersError::shape_error(
                                    "Failed to align attention heads".to_string(),
                                )
                            })?;

                            Ok(Tensor::F32(aligned_array))
                        }
                    } else {
                        Ok(student.clone())
                    }
                } else {
                    Ok(student.clone())
                }
            },
            _ => Ok(student.clone()),
        }
    }
708
709    /// Compute attention entropy for measuring attention distribution sharpness
710    fn compute_attention_entropy(&self, attention: &Tensor) -> Result<Tensor> {
711        match attention {
712            Tensor::F32(arr) => {
713                // Compute entropy: -sum(p * log(p)) for each attention head
714                let epsilon = 1e-8_f32; // Small constant to avoid log(0)
715                let log_probs = arr.mapv(|x| (x + epsilon).ln());
716                let entropy_contributions = arr * &log_probs;
717                let entropy = entropy_contributions.sum_axis(Axis(3)); // Sum over last dimension
718                let mean_entropy = entropy.mean().expect("operation failed");
719
720                Ok(Tensor::F32(ArrayD::from_elem(IxDyn(&[1]), -mean_entropy)))
721            },
722            _ => Err(tensor_op_error(
723                "tensor_operation",
724                "Attention entropy computation only supports F32 tensors".to_string(),
725            )),
726        }
727    }
728
729    /// Compute correlation between teacher and student attention patterns
730    fn compute_attention_correlation(&self, teacher: &Tensor, student: &Tensor) -> Result<Tensor> {
731        match (teacher, student) {
732            (Tensor::F32(t_arr), Tensor::F32(s_arr)) => {
733                // Flatten attention matrices and compute Pearson correlation
734                let teacher_flat: Vec<f32> = t_arr.iter().cloned().collect();
735                let student_flat: Vec<f32> = s_arr.iter().cloned().collect();
736
737                if teacher_flat.len() != student_flat.len() {
738                    return Ok(Tensor::F32(ArrayD::from_elem(IxDyn(&[1]), 0.0)));
739                }
740
741                let n = teacher_flat.len() as f32;
742                let teacher_mean: f32 = teacher_flat.iter().sum::<f32>() / n;
743                let student_mean: f32 = student_flat.iter().sum::<f32>() / n;
744
745                let mut numerator = 0.0;
746                let mut teacher_var = 0.0;
747                let mut student_var = 0.0;
748
749                for i in 0..teacher_flat.len() {
750                    let teacher_centered = teacher_flat[i] - teacher_mean;
751                    let student_centered = student_flat[i] - student_mean;
752
753                    numerator += teacher_centered * student_centered;
754                    teacher_var += teacher_centered * teacher_centered;
755                    student_var += student_centered * student_centered;
756                }
757
758                let correlation = if teacher_var > 0.0 && student_var > 0.0 {
759                    numerator / (teacher_var.sqrt() * student_var.sqrt())
760                } else {
761                    0.0
762                };
763
764                Ok(Tensor::F32(ArrayD::from_elem(IxDyn(&[1]), correlation)))
765            },
766            _ => Err(tensor_op_error(
767                "tensor_operation",
768                "Attention correlation computation only supports F32 tensors".to_string(),
769            )),
770        }
771    }
772}
773
/// Outputs from teacher model for distillation
///
/// Collected in a single forward pass so the trainer can compare the
/// teacher's logits, per-layer hidden states, and per-layer attention
/// weights against the student's.
#[derive(Debug, Clone)]
pub struct TeacherOutputs {
    /// Final layer logits
    pub logits: Tensor,
    /// Hidden states from all layers, ordered from first to last layer
    pub hidden_states: Vec<Tensor>,
    /// Attention weights from all layers, ordered from first to last layer
    pub attentions: Vec<Tensor>,
}
784
/// Outputs from student model for distillation
///
/// Mirrors [`TeacherOutputs`] field-for-field so teacher and student
/// activations can be matched layer by layer during distillation.
#[derive(Debug, Clone)]
pub struct StudentOutputs {
    /// Final layer logits
    pub logits: Tensor,
    /// Hidden states from all layers, ordered from first to last layer
    pub hidden_states: Vec<Tensor>,
    /// Attention weights from all layers, ordered from first to last layer
    pub attentions: Vec<Tensor>,
}
795
796/// Utilities for knowledge distillation
797pub mod utils {
798    use super::*;
799
800    /// Create a basic response-based distillation config
801    pub fn response_distillation_config(temperature: f32, alpha: f32) -> DistillationConfig {
802        DistillationConfig {
803            temperature,
804            alpha,
805            strategy: DistillationStrategy::ResponseBased,
806            ..Default::default()
807        }
808    }
809
810    /// Create a feature-based distillation config
811    pub fn feature_distillation_config(
812        layer_mapping: HashMap<usize, usize>,
813        alpha: f32,
814    ) -> DistillationConfig {
815        DistillationConfig {
816            alpha,
817            strategy: DistillationStrategy::FeatureBased,
818            use_feature_matching: true,
819            feature_matching_layers: layer_mapping,
820            ..Default::default()
821        }
822    }
823
824    /// Create a combined distillation config
825    pub fn combined_distillation_config(
826        temperature: f32,
827        alpha: f32,
828        response_weight: f32,
829        feature_weight: f32,
830        attention_weight: f32,
831    ) -> DistillationConfig {
832        DistillationConfig {
833            temperature,
834            alpha,
835            strategy: DistillationStrategy::Combined {
836                response_weight,
837                feature_weight,
838                attention_weight,
839            },
840            use_feature_matching: feature_weight > 0.0,
841            use_attention_transfer: attention_weight > 0.0,
842            ..Default::default()
843        }
844    }
845
846    /// Create a progressive distillation config
847    pub fn progressive_distillation_config(stages: Vec<ProgressiveStage>) -> DistillationConfig {
848        DistillationConfig {
849            strategy: DistillationStrategy::Progressive { stages },
850            progressive: true,
851            ..Default::default()
852        }
853    }
854
855    /// Helper to create a linear decay schedule for progressive distillation
856    pub fn linear_decay_stages(
857        initial_temp: f32,
858        final_temp: f32,
859        initial_alpha: f32,
860        final_alpha: f32,
861        num_stages: usize,
862        steps_per_stage: usize,
863    ) -> Vec<ProgressiveStage> {
864        let mut stages = Vec::new();
865
866        for i in 0..num_stages {
867            let progress = i as f32 / (num_stages - 1) as f32;
868            let temp = initial_temp + progress * (final_temp - initial_temp);
869            let alpha = initial_alpha + progress * (final_alpha - initial_alpha);
870
871            stages.push(ProgressiveStage {
872                duration: steps_per_stage,
873                temperature: temp,
874                alpha,
875                freeze_teacher: false,
876            });
877        }
878
879        stages
880    }
881}
882
#[cfg(test)]
mod tests {
    use super::*;

    /// Default config should match the documented defaults (temperature 4.0,
    /// alpha 0.7, no feature matching or attention transfer).
    #[test]
    fn test_distillation_config_default() {
        let config = DistillationConfig::default();
        assert_eq!(config.temperature, 4.0);
        assert_eq!(config.alpha, 0.7);
        assert!(!config.use_feature_matching);
        assert!(!config.use_attention_transfer);
    }

    /// The response-based helper should set temperature/alpha and strategy.
    #[test]
    fn test_response_distillation_config() {
        let config = utils::response_distillation_config(3.0, 0.8);
        assert_eq!(config.temperature, 3.0);
        assert_eq!(config.alpha, 0.8);
        assert!(matches!(
            config.strategy,
            DistillationStrategy::ResponseBased
        ));
    }

    /// The feature-based helper should enable matching and keep the mapping.
    #[test]
    fn test_feature_distillation_config() {
        let mut layer_mapping = HashMap::new();
        layer_mapping.insert(11, 5); // Map teacher layer 11 to student layer 5

        let config = utils::feature_distillation_config(layer_mapping.clone(), 0.6);
        assert_eq!(config.alpha, 0.6);
        assert!(config.use_feature_matching);
        assert_eq!(config.feature_matching_layers, layer_mapping);
    }

    /// Combined helper should propagate all three loss weights and flip the
    /// feature/attention flags on for positive weights.
    #[test]
    fn test_combined_distillation_config() {
        let config = utils::combined_distillation_config(4.0, 0.7, 0.5, 0.3, 0.2);
        assert_eq!(config.temperature, 4.0);
        assert_eq!(config.alpha, 0.7);
        assert!(config.use_feature_matching);
        assert!(config.use_attention_transfer);

        if let DistillationStrategy::Combined {
            response_weight,
            feature_weight,
            attention_weight,
        } = config.strategy
        {
            assert_eq!(response_weight, 0.5);
            assert_eq!(feature_weight, 0.3);
            assert_eq!(attention_weight, 0.2);
        } else {
            panic!("Expected Combined strategy");
        }
    }

    /// Linear decay should hit both endpoints of temperature and alpha.
    #[test]
    fn test_progressive_stages() {
        let stages = utils::linear_decay_stages(5.0, 1.0, 0.8, 0.5, 4, 1000);
        assert_eq!(stages.len(), 4);
        assert_eq!(stages[0].temperature, 5.0);
        assert_eq!(stages[3].temperature, 1.0);
        assert_eq!(stages[0].alpha, 0.8);
        // Absolute-difference float comparison. The previous form
        // `stages[3].alpha - 0.5 < 1e-6` was vacuously true for any
        // alpha <= 0.5 (e.g. 0.0), so it never actually checked the value.
        assert!((stages[3].alpha - 0.5).abs() < 1e-6);
    }

    /// Progressive helper should mark the config progressive and keep stages.
    #[test]
    fn test_progressive_distillation_config() {
        let stages = vec![
            ProgressiveStage {
                duration: 1000,
                temperature: 5.0,
                alpha: 0.8,
                freeze_teacher: false,
            },
            ProgressiveStage {
                duration: 1000,
                temperature: 3.0,
                alpha: 0.6,
                freeze_teacher: false,
            },
        ];

        let config = utils::progressive_distillation_config(stages.clone());
        assert!(config.progressive);

        if let DistillationStrategy::Progressive {
            stages: config_stages,
        } = config.strategy
        {
            assert_eq!(config_stages.len(), 2);
            assert_eq!(config_stages[0].temperature, 5.0);
            assert_eq!(config_stages[1].temperature, 3.0);
        } else {
            panic!("Expected Progressive strategy");
        }
    }
}