1use scirs2_core::ndarray::{ArrayD, Axis, IxDyn}; use serde::{Deserialize, Serialize};
35use std::collections::HashMap;
36use trustformers_core::{
37 errors::{tensor_op_error, TrustformersError},
38 layers::Linear,
39 tensor::Tensor,
40 traits::{Layer, Model},
41 Result,
42};
43
/// Configuration for knowledge distillation from a teacher model to a
/// (typically smaller) student model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DistillationConfig {
    /// Softmax temperature applied to teacher and student logits when
    /// computing the soft (KL) distillation loss.
    pub temperature: f32,
    /// Soft/hard loss balance: the hard-target loss term is weighted by
    /// `1.0 - alpha` in `compute_distillation_loss`.
    pub alpha: f32,
    /// Which distillation signal(s) to use (response, feature, attention,
    /// combined, or progressive).
    pub strategy: DistillationStrategy,
    /// If true, `KnowledgeDistillationTrainer::new` builds a `Linear`
    /// projection per entry in `feature_matching_layers`.
    pub use_feature_matching: bool,
    /// Maps a teacher layer index (key) to the student layer index (value)
    /// whose hidden states should be matched against it.
    pub feature_matching_layers: HashMap<usize, usize>,
    /// Whether attention-map transfer is intended to be used.
    /// NOTE(review): not read by the trainer itself — the strategy variant
    /// controls which losses run; confirm this flag is consumed elsewhere.
    pub use_attention_transfer: bool,
    /// Weight intended for attention-transfer loss.
    /// NOTE(review): also not read by the trainer; see above.
    pub attention_loss_weight: f32,
    /// Marker that a progressive schedule is in use (set by
    /// `utils::progressive_distillation_config`).
    pub progressive: bool,
    /// Number of progressive stages; informational alongside
    /// `DistillationStrategy::Progressive`.
    pub progressive_stages: usize,
    /// Temperature reported once a progressive schedule has run past its
    /// last stage (see `current_temperature`).
    pub min_temperature: f32,
}
68
69impl Default for DistillationConfig {
70 fn default() -> Self {
71 Self {
72 temperature: 4.0,
73 alpha: 0.7,
74 strategy: DistillationStrategy::ResponseBased,
75 use_feature_matching: false,
76 feature_matching_layers: HashMap::new(),
77 use_attention_transfer: false,
78 attention_loss_weight: 0.1,
79 progressive: false,
80 progressive_stages: 5,
81 min_temperature: 1.0,
82 }
83 }
84}
85
/// Selects which distillation signal(s) contribute to the training loss.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum DistillationStrategy {
    /// Match soft logits only (temperature-scaled KL divergence).
    ResponseBased,
    /// Match intermediate hidden states per `feature_matching_layers`.
    FeatureBased,
    /// Match attention maps layer-by-layer (MSE).
    AttentionBased,
    /// Weighted sum of the three signals; a term with weight 0.0 (or with no
    /// corresponding teacher outputs) is skipped entirely.
    Combined {
        response_weight: f32,
        feature_weight: f32,
        attention_weight: f32,
    },
    /// Response-based distillation whose temperature/alpha follow a staged
    /// schedule advanced by `KnowledgeDistillationTrainer::step`.
    Progressive { stages: Vec<ProgressiveStage> },
}
104
/// One stage of a progressive distillation schedule.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProgressiveStage {
    /// Number of trainer steps this stage lasts (compared against
    /// `current_step` in `KnowledgeDistillationTrainer::step`).
    pub duration: usize,
    /// Softmax temperature used while this stage is active.
    pub temperature: f32,
    /// Soft/hard balance reported by `current_alpha` during this stage.
    pub alpha: f32,
    /// Stage-level flag for freezing the teacher.
    /// NOTE(review): not consulted anywhere in this file — confirm it is
    /// honored by the training loop.
    pub freeze_teacher: bool,
}
117
/// Result bundle produced by `compute_distillation_loss`.
#[derive(Debug, Clone)]
pub struct DistillationOutput {
    /// Sum of all active (weighted) loss terms.
    pub total_loss: Tensor,
    /// Each contributing term keyed by name (e.g. "response", "feature",
    /// "attention", "hard_target", "progressive_response").
    pub loss_components: HashMap<String, Tensor>,
    /// Clone of the teacher's raw logits, for inspection/logging.
    pub teacher_predictions: Tensor,
    /// Clone of the student's raw logits, for inspection/logging.
    pub student_predictions: Tensor,
    /// Per-layer hidden-state MSE diagnostics, present whenever both sides
    /// supplied hidden states (computed regardless of strategy).
    pub feature_losses: Option<HashMap<String, Tensor>>,
    /// Per-layer attention diagnostics (MSE, entropy gap, correlation),
    /// present whenever both sides supplied attention maps.
    pub attention_losses: Option<HashMap<String, Tensor>>,
}
134
/// Drives knowledge distillation between a teacher model `T` and a student
/// model `S`. Loss computation operates on pre-computed model outputs
/// (`TeacherOutputs` / `StudentOutputs`); the models themselves are stored
/// but not invoked in this file.
pub struct KnowledgeDistillationTrainer<T, S> {
    // Held for ownership; forward passes happen outside this module.
    #[allow(dead_code)]
    teacher: T,
    #[allow(dead_code)]
    student: S,
    config: DistillationConfig,
    /// Optional per-student-layer projections used before feature MSE,
    /// keyed by student layer index (built in `new` when
    /// `use_feature_matching` is set).
    feature_projections: HashMap<usize, Linear>,
    /// Index into the progressive schedule (0-based).
    current_stage: usize,
    /// Steps taken within the current stage; reset on stage advance.
    current_step: usize,
}
152
impl<T, S> KnowledgeDistillationTrainer<T, S>
where
    T: Model,
    S: Model,
{
    /// Builds a trainer from a teacher/student pair and a config.
    ///
    /// When feature matching is enabled, one `Linear` projection is created
    /// per mapped student layer. NOTE(review): the projection is hard-coded
    /// to 768x768 with bias — this assumes both models use a 768-dim hidden
    /// size; confirm before using with other architectures.
    pub fn new(teacher: T, student: S, config: DistillationConfig) -> Result<Self> {
        let mut feature_projections = HashMap::new();

        if config.use_feature_matching {
            // Keyed by student layer index; the teacher layer index is only
            // consulted later, in compute_feature_distillation_loss.
            for (&_teacher_layer, &student_layer) in &config.feature_matching_layers {
                let projection = Linear::new(768, 768, true);
                feature_projections.insert(student_layer, projection);
            }
        }

        Ok(Self {
            teacher,
            student,
            config,
            feature_projections,
            current_stage: 0,
            current_step: 0,
        })
    }

    /// Computes the total distillation loss for one batch of pre-computed
    /// teacher/student outputs.
    ///
    /// The configured strategy decides which soft terms are summed into
    /// `total_loss`; each term is also recorded in `loss_components` under a
    /// descriptive key. If `hard_targets` is provided, a hard-target term
    /// weighted by `1.0 - alpha` is added.
    ///
    /// NOTE(review): the soft terms are NOT scaled by `alpha`, so the usual
    /// `alpha * soft + (1 - alpha) * hard` blend is only half applied —
    /// confirm this asymmetry is intended.
    ///
    /// Per-layer feature/attention diagnostics are computed whenever both
    /// sides supplied the corresponding outputs, independent of strategy.
    pub fn compute_distillation_loss(
        &self,
        teacher_outputs: &TeacherOutputs,
        student_outputs: &StudentOutputs,
        hard_targets: Option<&Tensor>,
    ) -> Result<DistillationOutput> {
        let mut loss_components = HashMap::new();
        let mut total_loss = Tensor::zeros(&[1])?;

        match &self.config.strategy {
            DistillationStrategy::ResponseBased => {
                let response_loss = self.compute_response_distillation_loss(
                    &teacher_outputs.logits,
                    &student_outputs.logits,
                )?;
                loss_components.insert("response".to_string(), response_loss.clone());
                total_loss = total_loss.add(&response_loss)?;
            },
            DistillationStrategy::FeatureBased => {
                let feature_loss = self.compute_feature_distillation_loss(
                    &teacher_outputs.hidden_states,
                    &student_outputs.hidden_states,
                )?;
                loss_components.insert("feature".to_string(), feature_loss.clone());
                total_loss = total_loss.add(&feature_loss)?;
            },
            DistillationStrategy::AttentionBased => {
                let attention_loss = self.compute_attention_distillation_loss(
                    &teacher_outputs.attentions,
                    &student_outputs.attentions,
                )?;
                loss_components.insert("attention".to_string(), attention_loss.clone());
                total_loss = total_loss.add(&attention_loss)?;
            },
            DistillationStrategy::Combined {
                response_weight,
                feature_weight,
                attention_weight,
            } => {
                // Each term is skipped when its weight is 0.0 or (for
                // feature/attention) the teacher supplied no outputs.
                if *response_weight > 0.0 {
                    let response_loss = self.compute_response_distillation_loss(
                        &teacher_outputs.logits,
                        &student_outputs.logits,
                    )?;
                    let weighted_response_loss = response_loss.scalar_mul(*response_weight)?;
                    loss_components.insert("response".to_string(), weighted_response_loss.clone());
                    total_loss = total_loss.add(&weighted_response_loss)?;
                }

                if *feature_weight > 0.0 && !teacher_outputs.hidden_states.is_empty() {
                    let feature_loss = self.compute_feature_distillation_loss(
                        &teacher_outputs.hidden_states,
                        &student_outputs.hidden_states,
                    )?;
                    let weighted_feature_loss = feature_loss.scalar_mul(*feature_weight)?;
                    loss_components.insert("feature".to_string(), weighted_feature_loss.clone());
                    total_loss = total_loss.add(&weighted_feature_loss)?;
                }

                if *attention_weight > 0.0 && !teacher_outputs.attentions.is_empty() {
                    let attention_loss = self.compute_attention_distillation_loss(
                        &teacher_outputs.attentions,
                        &student_outputs.attentions,
                    )?;
                    let weighted_attention_loss = attention_loss.scalar_mul(*attention_weight)?;
                    loss_components
                        .insert("attention".to_string(), weighted_attention_loss.clone());
                    total_loss = total_loss.add(&weighted_attention_loss)?;
                }
            },
            DistillationStrategy::Progressive { stages } => {
                // Clamp to the final stage once the schedule is exhausted.
                // NOTE(review): panics if `stages` is empty (len() - 1
                // underflows); validate the schedule upstream.
                let current_stage = &stages[self.current_stage.min(stages.len() - 1)];
                let response_loss = self.compute_response_distillation_loss_with_temperature(
                    &teacher_outputs.logits,
                    &student_outputs.logits,
                    current_stage.temperature,
                )?;
                loss_components.insert("progressive_response".to_string(), response_loss.clone());
                total_loss = total_loss.add(&response_loss)?;
            },
        }

        // Optional supervised term on ground-truth labels, weighted by
        // (1 - alpha).
        if let Some(targets) = hard_targets {
            let hard_loss = self.compute_hard_target_loss(&student_outputs.logits, targets)?;
            let weighted_hard_loss = hard_loss.scalar_mul(1.0 - self.config.alpha)?;
            loss_components.insert("hard_target".to_string(), weighted_hard_loss.clone());
            total_loss = total_loss.add(&weighted_hard_loss)?;
        }

        // Per-layer diagnostics; these are reported but never added to
        // `total_loss`.
        let feature_losses = if !teacher_outputs.hidden_states.is_empty()
            && !student_outputs.hidden_states.is_empty()
        {
            Some(self.compute_layer_wise_feature_losses(
                &teacher_outputs.hidden_states,
                &student_outputs.hidden_states,
            )?)
        } else {
            None
        };

        let attention_losses =
            if !teacher_outputs.attentions.is_empty() && !student_outputs.attentions.is_empty() {
                Some(self.compute_layer_wise_attention_losses(
                    &teacher_outputs.attentions,
                    &student_outputs.attentions,
                )?)
            } else {
                None
            };

        Ok(DistillationOutput {
            total_loss,
            loss_components,
            teacher_predictions: teacher_outputs.logits.clone(),
            student_predictions: student_outputs.logits.clone(),
            feature_losses,
            attention_losses,
        })
    }

    /// Response (logit) distillation at the configured base temperature.
    fn compute_response_distillation_loss(
        &self,
        teacher_logits: &Tensor,
        student_logits: &Tensor,
    ) -> Result<Tensor> {
        self.compute_response_distillation_loss_with_temperature(
            teacher_logits,
            student_logits,
            self.config.temperature,
        )
    }

    /// KL(teacher || student) between temperature-softened distributions,
    /// scaled by T^2 (the standard Hinton-style correction so gradient
    /// magnitudes stay comparable across temperatures).
    ///
    /// NOTE(review): the log-probabilities are obtained as `log(softmax(x))`
    /// rather than a fused log-softmax; near-zero probabilities can
    /// underflow to -inf — confirm `Tensor::log` handles this acceptably.
    fn compute_response_distillation_loss_with_temperature(
        &self,
        teacher_logits: &Tensor,
        student_logits: &Tensor,
        temperature: f32,
    ) -> Result<Tensor> {
        // Soften both distributions with the same temperature.
        let teacher_scaled = teacher_logits.scalar_div(temperature)?;
        let student_scaled = student_logits.scalar_div(temperature)?;

        let teacher_soft = teacher_scaled.softmax(-1)?;
        let student_soft = student_scaled.softmax(-1)?;
        let student_log_soft = student_soft.log()?;

        // KL divergence: sum p_t * (log p_t - log p_s).
        let teacher_log = teacher_soft.log()?;
        let log_diff = teacher_log.sub(&student_log_soft)?;
        let kl_div = teacher_soft.mul(&log_diff)?;
        let loss = kl_div.sum(None, false)?.mean()?;

        // Multiply by T^2 to compensate for the 1/T^2 gradient scaling.
        let temp_squared = temperature * temperature;
        loss.scalar_mul(temp_squared)
    }

    /// MSE between mapped teacher/student hidden states, averaged over the
    /// layer pairs that were in range. Out-of-range entries in
    /// `feature_matching_layers` are silently skipped; with no valid pairs
    /// the zero tensor is returned.
    fn compute_feature_distillation_loss(
        &self,
        teacher_features: &[Tensor],
        student_features: &[Tensor],
    ) -> Result<Tensor> {
        let mut total_loss = Tensor::zeros(&[1])?;
        let mut num_matched = 0;

        for (&teacher_layer, &student_layer) in &self.config.feature_matching_layers {
            if teacher_layer < teacher_features.len() && student_layer < student_features.len() {
                let teacher_feat = &teacher_features[teacher_layer];
                let student_feat = &student_features[student_layer];

                // Apply the learned projection if one was built in `new`;
                // otherwise compare the raw student features.
                let projected_student =
                    if let Some(projection) = self.feature_projections.get(&student_layer) {
                        projection.forward(student_feat.clone())?
                    } else {
                        student_feat.clone()
                    };

                let diff = teacher_feat.sub(&projected_student)?;
                let diff_squared = diff.mul(&diff)?;
                let mse_loss = diff_squared.mean()?;
                total_loss = total_loss.add(&mse_loss)?;
                num_matched += 1;
            }
        }

        if num_matched > 0 {
            Ok(total_loss.scalar_div(num_matched as f32)?)
        } else {
            Ok(total_loss)
        }
    }

    /// MSE between attention maps, layer-aligned by index and averaged over
    /// the layers both sides provide. Assumes matching shapes per layer
    /// (unlike the layer-wise diagnostic, no head alignment is attempted).
    fn compute_attention_distillation_loss(
        &self,
        teacher_attentions: &[Tensor],
        student_attentions: &[Tensor],
    ) -> Result<Tensor> {
        let mut total_loss = Tensor::zeros(&[1])?;
        let num_layers = teacher_attentions.len().min(student_attentions.len());

        for i in 0..num_layers {
            let teacher_attn = &teacher_attentions[i];
            let student_attn = &student_attentions[i];

            let diff = teacher_attn.sub(student_attn)?;
            let diff_squared = diff.mul(&diff)?;
            let mse_loss = diff_squared.mean()?;
            total_loss = total_loss.add(&mse_loss)?;
        }

        if num_layers > 0 {
            Ok(total_loss.scalar_div(num_layers as f32)?)
        } else {
            Ok(total_loss)
        }
    }

    /// Supervised loss on ground-truth targets.
    ///
    /// NOTE(review): `_targets` is unused — the current implementation
    /// averages -log p over ALL classes instead of selecting the target
    /// class, so this is not a real cross-entropy. TODO: gather the
    /// per-example target log-probabilities once the tensor API supports it.
    fn compute_hard_target_loss(&self, logits: &Tensor, _targets: &Tensor) -> Result<Tensor> {
        let probs = logits.softmax(-1)?;
        let log_probs = probs.log()?;

        let neg_log_probs = log_probs.scalar_mul(-1.0)?;
        neg_log_probs.mean()
    }

    /// Advances the step counter; under a progressive strategy, moves to the
    /// next stage (and resets the step counter) once the current stage's
    /// `duration` is reached. The final stage never advances further.
    pub fn step(&mut self) {
        self.current_step += 1;

        if let DistillationStrategy::Progressive { stages } = &self.config.strategy {
            if self.current_stage < stages.len() - 1 {
                let current_stage_config = &stages[self.current_stage];
                if self.current_step >= current_stage_config.duration {
                    self.current_stage += 1;
                    self.current_step = 0;
                }
            }
        }
    }

    /// Temperature in effect right now: the active progressive stage's
    /// temperature, `min_temperature` past the end of the schedule, or the
    /// base config temperature for non-progressive strategies.
    pub fn current_temperature(&self) -> f32 {
        match &self.config.strategy {
            DistillationStrategy::Progressive { stages } => {
                if self.current_stage < stages.len() {
                    stages[self.current_stage].temperature
                } else {
                    self.config.min_temperature
                }
            },
            _ => self.config.temperature,
        }
    }

    /// Alpha in effect right now: the active progressive stage's alpha, or
    /// the base config alpha otherwise (including past a finished schedule).
    pub fn current_alpha(&self) -> f32 {
        match &self.config.strategy {
            DistillationStrategy::Progressive { stages } => {
                if self.current_stage < stages.len() {
                    stages[self.current_stage].alpha
                } else {
                    self.config.alpha
                }
            },
            _ => self.config.alpha,
        }
    }

    /// Per-layer hidden-state MSE diagnostics ("layer_{i}" keys), aligned by
    /// layer index over the layers both sides provide.
    ///
    /// When the last-axis (hidden) dimensions differ but all leading dims
    /// match, the student tensor is crudely adapted: values are scaled by
    /// teacher_dim / student_dim, then the last axis is zero-padded or
    /// truncated to the teacher's width. This is a cheap heuristic, not a
    /// learned projection. NOTE(review): the pad/truncate loops index the
    /// flat buffer assuming standard row-major (contiguous) layout — confirm
    /// the tensors are contiguous here.
    fn compute_layer_wise_feature_losses(
        &self,
        teacher_hidden_states: &[Tensor],
        student_hidden_states: &[Tensor],
    ) -> Result<HashMap<String, Tensor>> {
        let mut feature_losses = HashMap::new();

        let num_layers = teacher_hidden_states.len().min(student_hidden_states.len());

        for layer_idx in 0..num_layers {
            let teacher_hidden = &teacher_hidden_states[layer_idx];
            let student_hidden = &student_hidden_states[layer_idx];

            let aligned_student = if teacher_hidden.shape() != student_hidden.shape() {
                match (teacher_hidden, student_hidden) {
                    (Tensor::F32(t_arr), Tensor::F32(s_arr)) => {
                        let teacher_shape = t_arr.shape();
                        let student_shape = s_arr.shape();

                        // Only adapt when rank and all leading dims agree;
                        // any other mismatch falls through to the raw
                        // student tensor (and the later `sub` decides).
                        if teacher_shape.len() == student_shape.len()
                            && teacher_shape[..teacher_shape.len() - 1]
                                == student_shape[..student_shape.len() - 1]
                        {
                            let teacher_hidden_dim = teacher_shape[teacher_shape.len() - 1];
                            let student_hidden_dim = student_shape[student_shape.len() - 1];

                            if student_hidden_dim != teacher_hidden_dim {
                                // Rescale values so magnitudes stay roughly
                                // comparable after the width change.
                                let scale = teacher_hidden_dim as f32 / student_hidden_dim as f32;
                                let projected = s_arr.mapv(|x| x * scale);

                                let new_shape = teacher_shape.to_vec();
                                let projected_data = if teacher_hidden_dim > student_hidden_dim {
                                    // Widen: copy each row and leave the
                                    // tail zero-padded.
                                    let mut padded_data = vec![0.0; new_shape.iter().product()];
                                    let chunk_size = student_hidden_dim;
                                    let total_chunks = s_arr.len() / chunk_size;

                                    for chunk_idx in 0..total_chunks {
                                        let src_start = chunk_idx * chunk_size;
                                        let dst_start = chunk_idx * teacher_hidden_dim;
                                        for i in 0..chunk_size {
                                            padded_data[dst_start + i] = projected[src_start + i];
                                        }
                                    }
                                    padded_data
                                } else {
                                    // Narrow: keep only the first
                                    // teacher_hidden_dim values of each row.
                                    let chunk_size = teacher_hidden_dim;
                                    let total_chunks = projected.len() / student_hidden_dim;
                                    let mut truncated_data = Vec::new();

                                    for chunk_idx in 0..total_chunks {
                                        let src_start = chunk_idx * student_hidden_dim;
                                        for i in 0..chunk_size {
                                            truncated_data.push(projected[src_start + i]);
                                        }
                                    }
                                    truncated_data
                                };

                                let projected_array =
                                    ArrayD::from_shape_vec(IxDyn(&new_shape), projected_data)
                                        .map_err(|_| {
                                            TrustformersError::shape_error(
                                                "Failed to project student features".to_string(),
                                            )
                                        })?;

                                Tensor::F32(projected_array)
                            } else {
                                student_hidden.clone()
                            }
                        } else {
                            student_hidden.clone()
                        }
                    },
                    // Non-F32 tensors: no adaptation attempted.
                    _ => student_hidden.clone(),
                }
            } else {
                student_hidden.clone()
            };

            let diff = teacher_hidden.sub(&aligned_student)?;
            let squared_diff = diff.mul(&diff)?;
            let mse_loss = squared_diff.mean()?;

            feature_losses.insert(format!("layer_{}", layer_idx), mse_loss);
        }

        Ok(feature_losses)
    }

    /// Per-layer attention diagnostics: for each layer both sides provide,
    /// records "layer_{i}" (MSE after head alignment), "layer_{i}_entropy"
    /// (squared entropy gap), and "layer_{i}_correlation" (Pearson
    /// correlation of the flattened maps — reported as-is, even though
    /// higher correlation means better agreement).
    fn compute_layer_wise_attention_losses(
        &self,
        teacher_attentions: &[Tensor],
        student_attentions: &[Tensor],
    ) -> Result<HashMap<String, Tensor>> {
        let mut attention_losses = HashMap::new();

        let num_layers = teacher_attentions.len().min(student_attentions.len());

        for layer_idx in 0..num_layers {
            let teacher_attn = &teacher_attentions[layer_idx];
            let student_attn = &student_attentions[layer_idx];

            // Reconcile differing head counts before comparing.
            let aligned_student_attn = if teacher_attn.shape() != student_attn.shape() {
                self.align_attention_tensors(teacher_attn, student_attn)?
            } else {
                student_attn.clone()
            };

            let diff = teacher_attn.sub(&aligned_student_attn)?;
            let squared_diff = diff.mul(&diff)?;
            let attn_loss = squared_diff.mean()?;

            attention_losses.insert(format!("layer_{}", layer_idx), attn_loss);

            // Squared difference of mean attention entropies.
            let teacher_entropy = self.compute_attention_entropy(teacher_attn)?;
            let student_entropy = self.compute_attention_entropy(&aligned_student_attn)?;
            let entropy_diff = teacher_entropy.sub(&student_entropy)?;
            let entropy_loss = entropy_diff.mul(&entropy_diff)?;
            attention_losses.insert(format!("layer_{}_entropy", layer_idx), entropy_loss);

            let pattern_correlation =
                self.compute_attention_correlation(teacher_attn, &aligned_student_attn)?;
            attention_losses.insert(
                format!("layer_{}_correlation", layer_idx),
                pattern_correlation,
            );
        }

        Ok(attention_losses)
    }

    /// Reshapes a student attention tensor so its head count matches the
    /// teacher's, for 4-D F32 tensors laid out [batch, heads, seq, seq].
    ///
    /// - Fewer student heads: heads are repeated cyclically (h % student_heads).
    /// - More student heads: groups of `student_heads / teacher_heads` heads
    ///   are averaged. NOTE(review): when the counts don't divide evenly the
    ///   integer division truncates and trailing student heads are ignored —
    ///   confirm head counts are expected to be multiples of each other.
    ///
    /// Any other shape/dtype combination returns the student tensor
    /// unchanged.
    fn align_attention_tensors(&self, teacher: &Tensor, student: &Tensor) -> Result<Tensor> {
        match (teacher, student) {
            (Tensor::F32(t_arr), Tensor::F32(s_arr)) => {
                let teacher_shape = t_arr.shape();
                let student_shape = s_arr.shape();

                if teacher_shape.len() == 4 && student_shape.len() == 4 {
                    let teacher_heads = teacher_shape[1];
                    let student_heads = student_shape[1];

                    if teacher_heads != student_heads {
                        if student_heads < teacher_heads {
                            // Expand: repeat student heads cyclically until
                            // the teacher head count is reached.
                            let _repeat_factor = teacher_heads / student_heads;
                            let mut aligned_data = Vec::new();

                            let batch_size = student_shape[0];
                            let seq_len = student_shape[2];
                            let seq_len_2 = student_shape[3];

                            for b in 0..batch_size {
                                for h in 0..teacher_heads {
                                    let source_head = h % student_heads;
                                    for i in 0..seq_len {
                                        for j in 0..seq_len_2 {
                                            aligned_data.push(s_arr[[b, source_head, i, j]]);
                                        }
                                    }
                                }
                            }

                            let aligned_array = ArrayD::from_shape_vec(
                                IxDyn(&[batch_size, teacher_heads, seq_len, seq_len_2]),
                                aligned_data,
                            )
                            .map_err(|_| {
                                TrustformersError::shape_error(
                                    "Failed to align attention heads".to_string(),
                                )
                            })?;

                            Ok(Tensor::F32(aligned_array))
                        } else {
                            // Reduce: average each group of student heads
                            // down to one teacher head.
                            let group_size = student_heads / teacher_heads;
                            let mut aligned_data = Vec::new();

                            let batch_size = student_shape[0];
                            let seq_len = student_shape[2];
                            let seq_len_2 = student_shape[3];

                            for b in 0..batch_size {
                                for h in 0..teacher_heads {
                                    for i in 0..seq_len {
                                        for j in 0..seq_len_2 {
                                            let mut sum = 0.0;
                                            for g in 0..group_size {
                                                let student_head = h * group_size + g;
                                                if student_head < student_heads {
                                                    sum += s_arr[[b, student_head, i, j]];
                                                }
                                            }
                                            aligned_data.push(sum / group_size as f32);
                                        }
                                    }
                                }
                            }

                            let aligned_array = ArrayD::from_shape_vec(
                                IxDyn(&[batch_size, teacher_heads, seq_len, seq_len_2]),
                                aligned_data,
                            )
                            .map_err(|_| {
                                TrustformersError::shape_error(
                                    "Failed to align attention heads".to_string(),
                                )
                            })?;

                            Ok(Tensor::F32(aligned_array))
                        }
                    } else {
                        Ok(student.clone())
                    }
                } else {
                    Ok(student.clone())
                }
            },
            _ => Ok(student.clone()),
        }
    }

    /// Mean Shannon entropy of an attention tensor, computed along axis 3
    /// (assumes a [batch, heads, seq, seq] layout where the last axis is a
    /// probability distribution). Returns a 1-element F32 tensor; errors for
    /// non-F32 inputs.
    ///
    /// NOTE(review): `.expect("operation failed")` panics on an empty array
    /// (mean of zero elements) — confirm attention maps are never empty.
    fn compute_attention_entropy(&self, attention: &Tensor) -> Result<Tensor> {
        match attention {
            Tensor::F32(arr) => {
                // Epsilon guards ln(0) on exactly-zero attention weights.
                let epsilon = 1e-8_f32;
                let log_probs = arr.mapv(|x| (x + epsilon).ln());
                let entropy_contributions = arr * &log_probs;
                // Sum p*ln(p) over the distribution axis, then average;
                // negate at the end to get entropy.
                let entropy = entropy_contributions.sum_axis(Axis(3));
                let mean_entropy = entropy.mean().expect("operation failed");

                Ok(Tensor::F32(ArrayD::from_elem(IxDyn(&[1]), -mean_entropy)))
            },
            _ => Err(tensor_op_error(
                "tensor_operation",
                "Attention entropy computation only supports F32 tensors".to_string(),
            )),
        }
    }

    /// Pearson correlation coefficient between the flattened teacher and
    /// student attention tensors, returned as a 1-element F32 tensor.
    /// Returns 0.0 on length mismatch or when either side has zero variance;
    /// errors for non-F32 inputs.
    fn compute_attention_correlation(&self, teacher: &Tensor, student: &Tensor) -> Result<Tensor> {
        match (teacher, student) {
            (Tensor::F32(t_arr), Tensor::F32(s_arr)) => {
                let teacher_flat: Vec<f32> = t_arr.iter().cloned().collect();
                let student_flat: Vec<f32> = s_arr.iter().cloned().collect();

                if teacher_flat.len() != student_flat.len() {
                    return Ok(Tensor::F32(ArrayD::from_elem(IxDyn(&[1]), 0.0)));
                }

                let n = teacher_flat.len() as f32;
                let teacher_mean: f32 = teacher_flat.iter().sum::<f32>() / n;
                let student_mean: f32 = student_flat.iter().sum::<f32>() / n;

                // Single pass accumulating covariance and both variances.
                let mut numerator = 0.0;
                let mut teacher_var = 0.0;
                let mut student_var = 0.0;

                for i in 0..teacher_flat.len() {
                    let teacher_centered = teacher_flat[i] - teacher_mean;
                    let student_centered = student_flat[i] - student_mean;

                    numerator += teacher_centered * student_centered;
                    teacher_var += teacher_centered * teacher_centered;
                    student_var += student_centered * student_centered;
                }

                let correlation = if teacher_var > 0.0 && student_var > 0.0 {
                    numerator / (teacher_var.sqrt() * student_var.sqrt())
                } else {
                    0.0
                };

                Ok(Tensor::F32(ArrayD::from_elem(IxDyn(&[1]), correlation)))
            },
            _ => Err(tensor_op_error(
                "tensor_operation",
                "Attention correlation computation only supports F32 tensors".to_string(),
            )),
        }
    }
}
773
/// Pre-computed forward-pass outputs of the teacher model.
#[derive(Debug, Clone)]
pub struct TeacherOutputs {
    /// Raw (unsoftened) logits.
    pub logits: Tensor,
    /// Hidden states per layer; may be empty if not collected.
    pub hidden_states: Vec<Tensor>,
    /// Attention maps per layer; may be empty if not collected.
    pub attentions: Vec<Tensor>,
}
784
/// Pre-computed forward-pass outputs of the student model (mirrors
/// `TeacherOutputs`).
#[derive(Debug, Clone)]
pub struct StudentOutputs {
    /// Raw (unsoftened) logits.
    pub logits: Tensor,
    /// Hidden states per layer; may be empty if not collected.
    pub hidden_states: Vec<Tensor>,
    /// Attention maps per layer; may be empty if not collected.
    pub attentions: Vec<Tensor>,
}
795
796pub mod utils {
798 use super::*;
799
800 pub fn response_distillation_config(temperature: f32, alpha: f32) -> DistillationConfig {
802 DistillationConfig {
803 temperature,
804 alpha,
805 strategy: DistillationStrategy::ResponseBased,
806 ..Default::default()
807 }
808 }
809
810 pub fn feature_distillation_config(
812 layer_mapping: HashMap<usize, usize>,
813 alpha: f32,
814 ) -> DistillationConfig {
815 DistillationConfig {
816 alpha,
817 strategy: DistillationStrategy::FeatureBased,
818 use_feature_matching: true,
819 feature_matching_layers: layer_mapping,
820 ..Default::default()
821 }
822 }
823
824 pub fn combined_distillation_config(
826 temperature: f32,
827 alpha: f32,
828 response_weight: f32,
829 feature_weight: f32,
830 attention_weight: f32,
831 ) -> DistillationConfig {
832 DistillationConfig {
833 temperature,
834 alpha,
835 strategy: DistillationStrategy::Combined {
836 response_weight,
837 feature_weight,
838 attention_weight,
839 },
840 use_feature_matching: feature_weight > 0.0,
841 use_attention_transfer: attention_weight > 0.0,
842 ..Default::default()
843 }
844 }
845
846 pub fn progressive_distillation_config(stages: Vec<ProgressiveStage>) -> DistillationConfig {
848 DistillationConfig {
849 strategy: DistillationStrategy::Progressive { stages },
850 progressive: true,
851 ..Default::default()
852 }
853 }
854
855 pub fn linear_decay_stages(
857 initial_temp: f32,
858 final_temp: f32,
859 initial_alpha: f32,
860 final_alpha: f32,
861 num_stages: usize,
862 steps_per_stage: usize,
863 ) -> Vec<ProgressiveStage> {
864 let mut stages = Vec::new();
865
866 for i in 0..num_stages {
867 let progress = i as f32 / (num_stages - 1) as f32;
868 let temp = initial_temp + progress * (final_temp - initial_temp);
869 let alpha = initial_alpha + progress * (final_alpha - initial_alpha);
870
871 stages.push(ProgressiveStage {
872 duration: steps_per_stage,
873 temperature: temp,
874 alpha,
875 freeze_teacher: false,
876 });
877 }
878
879 stages
880 }
881}
882
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_distillation_config_default() {
        let config = DistillationConfig::default();
        assert_eq!(config.temperature, 4.0);
        assert_eq!(config.alpha, 0.7);
        assert!(!config.use_feature_matching);
        assert!(!config.use_attention_transfer);
    }

    #[test]
    fn test_response_distillation_config() {
        let config = utils::response_distillation_config(3.0, 0.8);
        assert_eq!(config.temperature, 3.0);
        assert_eq!(config.alpha, 0.8);
        assert!(matches!(
            config.strategy,
            DistillationStrategy::ResponseBased
        ));
    }

    #[test]
    fn test_feature_distillation_config() {
        let mut layer_mapping = HashMap::new();
        // Map teacher layer 11 onto student layer 5.
        layer_mapping.insert(11, 5);
        let config = utils::feature_distillation_config(layer_mapping.clone(), 0.6);
        assert_eq!(config.alpha, 0.6);
        assert!(config.use_feature_matching);
        assert_eq!(config.feature_matching_layers, layer_mapping);
    }

    #[test]
    fn test_combined_distillation_config() {
        let config = utils::combined_distillation_config(4.0, 0.7, 0.5, 0.3, 0.2);
        assert_eq!(config.temperature, 4.0);
        assert_eq!(config.alpha, 0.7);
        assert!(config.use_feature_matching);
        assert!(config.use_attention_transfer);

        if let DistillationStrategy::Combined {
            response_weight,
            feature_weight,
            attention_weight,
        } = config.strategy
        {
            assert_eq!(response_weight, 0.5);
            assert_eq!(feature_weight, 0.3);
            assert_eq!(attention_weight, 0.2);
        } else {
            panic!("Expected Combined strategy");
        }
    }

    #[test]
    fn test_progressive_stages() {
        let stages = utils::linear_decay_stages(5.0, 1.0, 0.8, 0.5, 4, 1000);
        assert_eq!(stages.len(), 4);
        assert_eq!(stages[0].temperature, 5.0);
        assert_eq!(stages[3].temperature, 1.0);
        assert_eq!(stages[0].alpha, 0.8);
        // Fixed: the original `stages[3].alpha - 0.5 < 1e-6` was not an
        // absolute-difference check, so any final alpha below ~0.5 would
        // have passed. Compare with .abs() so deviations in either
        // direction fail.
        assert!((stages[3].alpha - 0.5).abs() < 1e-6);
    }

    #[test]
    fn test_progressive_distillation_config() {
        let stages = vec![
            ProgressiveStage {
                duration: 1000,
                temperature: 5.0,
                alpha: 0.8,
                freeze_teacher: false,
            },
            ProgressiveStage {
                duration: 1000,
                temperature: 3.0,
                alpha: 0.6,
                freeze_teacher: false,
            },
        ];

        let config = utils::progressive_distillation_config(stages.clone());
        assert!(config.progressive);

        if let DistillationStrategy::Progressive {
            stages: config_stages,
        } = config.strategy
        {
            assert_eq!(config_stages.len(), 2);
            assert_eq!(config_stages[0].temperature, 5.0);
            assert_eq!(config_stages[1].temperature, 3.0);
        } else {
            panic!("Expected Progressive strategy");
        }
    }
}