//! Transfer-learning utilities: layer freezing, progressive unfreezing,
//! discriminative fine-tuning, and a manager that coordinates them.

use crate::{TrainError, TrainResult};
use std::collections::{HashMap, HashSet};

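/// Tracks which layers of a model are frozen (excluded from parameter
/// updates) during transfer learning.
///
/// A minimal usage sketch; the layer names are illustrative and the import
/// path is assumed to follow this crate's layout:
///
/// ```ignore
/// let mut config = LayerFreezingConfig::new();
/// config.freeze_layers(&["encoder.layer1", "encoder.layer2"]);
/// assert!(config.is_frozen("encoder.layer1"));
/// assert!(!config.is_frozen("classifier.fc"));
/// ```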
#[derive(Debug, Clone)]
pub struct LayerFreezingConfig {
    /// Names of individually frozen layers.
    frozen_layers: HashSet<String>,
    /// When true, every layer is treated as frozen regardless of `frozen_layers`.
    freeze_all: bool,
}

impl LayerFreezingConfig {
    /// Creates a configuration with no frozen layers.
    pub fn new() -> Self {
        Self {
            frozen_layers: HashSet::new(),
            freeze_all: false,
        }
    }

    /// Marks the named layers as frozen.
    pub fn freeze_layers(&mut self, layer_names: &[&str]) {
        for name in layer_names {
            self.frozen_layers.insert(name.to_string());
        }
    }

    /// Removes the named layers from the frozen set. Does not clear the
    /// `freeze_all` flag; use [`unfreeze_all`](Self::unfreeze_all) for that.
    pub fn unfreeze_layers(&mut self, layer_names: &[&str]) {
        for name in layer_names {
            self.frozen_layers.remove(*name);
        }
    }

    /// Freezes every layer, regardless of the individual frozen set.
    pub fn freeze_all(&mut self) {
        self.freeze_all = true;
    }

    /// Unfreezes every layer and clears the individual frozen set.
    pub fn unfreeze_all(&mut self) {
        self.freeze_all = false;
        self.frozen_layers.clear();
    }

    /// Returns true if the named layer is currently frozen.
    pub fn is_frozen(&self, layer_name: &str) -> bool {
        self.freeze_all || self.frozen_layers.contains(layer_name)
    }

    /// Returns the names of all individually frozen layers.
    pub fn frozen_layers(&self) -> Vec<String> {
        self.frozen_layers.iter().cloned().collect()
    }

    /// Returns the number of frozen layers. When `freeze_all` is set the
    /// total layer count is unknown here, so `usize::MAX` is returned as a
    /// sentinel.
    pub fn num_frozen(&self) -> usize {
        if self.freeze_all {
            usize::MAX
        } else {
            self.frozen_layers.len()
        }
    }
}

impl Default for LayerFreezingConfig {
    fn default() -> Self {
        Self::new()
    }
}

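/// Gradually unfreezes layers over training, starting from the end of
/// `layer_order` (the layers closest to the output) and working backward.
///
/// A minimal sketch of the intended flow; the layer names are illustrative
/// and error handling is elided:
///
/// ```ignore
/// let layers = vec!["layer1".to_string(), "layer2".to_string()];
/// let mut unfreezing = ProgressiveUnfreezing::new(layers, 5)?;
/// // One stage unlocks per `unfreeze_interval` epochs.
/// unfreezing.update_stage(5);
/// assert_eq!(unfreezing.get_trainable_layers(), vec!["layer2".to_string()]);
/// ```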
#[derive(Debug, Clone)]
pub struct ProgressiveUnfreezing {
    /// Layers ordered from input to output; unfreezing proceeds from the end.
    layer_order: Vec<String>,
    /// Number of epochs between successive unfreezing stages.
    unfreeze_interval: usize,
    /// Number of layers unfrozen so far, counted from the end of `layer_order`.
    current_stage: usize,
}

impl ProgressiveUnfreezing {
    /// Creates a schedule that unfreezes one layer from the end of
    /// `layer_order` every `unfreeze_interval` epochs.
    pub fn new(layer_order: Vec<String>, unfreeze_interval: usize) -> TrainResult<Self> {
        if layer_order.is_empty() {
            return Err(TrainError::InvalidParameter(
                "layer_order cannot be empty".to_string(),
            ));
        }
        if unfreeze_interval == 0 {
            return Err(TrainError::InvalidParameter(
                "unfreeze_interval must be positive".to_string(),
            ));
        }
        Ok(Self {
            layer_order,
            unfreeze_interval,
            current_stage: 0,
        })
    }

    /// Advances the schedule for the given epoch. Returns true if the stage
    /// changed, signalling that the freezing configuration should be rebuilt.
    pub fn update_stage(&mut self, epoch: usize) -> bool {
        // Clamp before comparing so that epochs past the final stage do not
        // keep reporting a spurious stage change.
        let new_stage = (epoch / self.unfreeze_interval).min(self.layer_order.len());
        if new_stage > self.current_stage {
            self.current_stage = new_stage;
            true
        } else {
            false
        }
    }

    /// Returns the currently trainable layers: the last `current_stage`
    /// entries of `layer_order`.
    pub fn get_trainable_layers(&self) -> Vec<String> {
        let num_trainable = self.current_stage.min(self.layer_order.len());
        let start_idx = self.layer_order.len().saturating_sub(num_trainable);

        self.layer_order[start_idx..].to_vec()
    }

    /// Returns the layers that are still frozen.
    pub fn get_frozen_layers(&self) -> Vec<String> {
        let num_trainable = self.current_stage.min(self.layer_order.len());
        let end_idx = self.layer_order.len().saturating_sub(num_trainable);

        self.layer_order[..end_idx].to_vec()
    }

    /// Returns true once every layer has been unfrozen.
    pub fn is_complete(&self) -> bool {
        self.current_stage >= self.layer_order.len()
    }

    /// Returns the current unfreezing stage.
    pub fn current_stage(&self) -> usize {
        self.current_stage
    }

    /// Returns the total number of stages (one per layer).
    pub fn total_stages(&self) -> usize {
        self.layer_order.len()
    }
}

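/// Assigns layer-wise learning rates in the style of discriminative
/// fine-tuning (as in ULMFiT): layers closer to the input receive
/// geometrically smaller rates.
///
/// A minimal sketch; error handling is elided. With `base_lr = 1e-3` and
/// `decay_factor = 0.5`, the last layer trains at 1e-3 and the one before
/// it at 5e-4:
///
/// ```ignore
/// let mut finetuning = DiscriminativeFineTuning::new(1e-3, 0.5)?;
/// finetuning.compute_layer_lrs(&["layer1".to_string(), "layer2".to_string()]);
/// assert!((finetuning.get_layer_lr("layer2") - 1e-3).abs() < 1e-10);
/// assert!((finetuning.get_layer_lr("layer1") - 5e-4).abs() < 1e-10);
/// ```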
#[derive(Debug, Clone)]
pub struct DiscriminativeFineTuning {
    /// Learning rate applied to the last (output-most) layer.
    pub base_lr: f64,
    /// Per-layer multiplicative decay toward the input, in [0, 1].
    pub decay_factor: f64,
    /// Per-layer learning rates filled in by `compute_layer_lrs`.
    layer_lrs: HashMap<String, f64>,
}

impl DiscriminativeFineTuning {
    /// Creates a configuration. `base_lr` must be positive and
    /// `decay_factor` must lie in [0, 1].
    pub fn new(base_lr: f64, decay_factor: f64) -> TrainResult<Self> {
        if base_lr <= 0.0 {
            return Err(TrainError::InvalidParameter(
                "base_lr must be positive".to_string(),
            ));
        }
        if !(0.0..=1.0).contains(&decay_factor) {
            return Err(TrainError::InvalidParameter(
                "decay_factor must be in [0, 1]".to_string(),
            ));
        }
        Ok(Self {
            base_lr,
            decay_factor,
            layer_lrs: HashMap::new(),
        })
    }

    /// Computes per-layer learning rates for `layer_order` (input to
    /// output): the last layer gets `base_lr`, and each step toward the
    /// input multiplies by `decay_factor`.
    pub fn compute_layer_lrs(&mut self, layer_order: &[String]) {
        self.layer_lrs.clear();

        let num_layers = layer_order.len();
        for (i, layer_name) in layer_order.iter().enumerate() {
            // Depth counted from the output: the last layer has depth 0.
            let depth = num_layers - 1 - i;
            let lr = self.base_lr * self.decay_factor.powi(depth as i32);
            self.layer_lrs.insert(layer_name.clone(), lr);
        }
    }

    /// Returns the learning rate for the named layer, falling back to
    /// `base_lr` for layers not covered by `compute_layer_lrs`.
    pub fn get_layer_lr(&self, layer_name: &str) -> f64 {
        self.layer_lrs
            .get(layer_name)
            .copied()
            .unwrap_or(self.base_lr)
    }

    /// Returns the full map of computed per-layer learning rates.
    pub fn layer_lrs(&self) -> &HashMap<String, f64> {
        &self.layer_lrs
    }
}

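/// Splits a model into a frozen feature extractor and a trainable head,
/// matching layers by name prefix.
///
/// A minimal sketch with illustrative prefixes and layer names:
///
/// ```ignore
/// let mode = FeatureExtractorMode::new("encoder".to_string(), "classifier".to_string());
/// let all_layers = vec!["encoder.layer1".to_string(), "classifier.fc".to_string()];
/// let config = mode.get_freezing_config(&all_layers);
/// assert!(config.is_frozen("encoder.layer1"));
/// assert!(!config.is_frozen("classifier.fc"));
/// ```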
#[derive(Debug, Clone)]
pub struct FeatureExtractorMode {
    /// Name prefix identifying feature-extractor (frozen) layers.
    pub feature_extractor_name: String,
    /// Name prefix identifying head (trainable) layers.
    pub head_name: String,
}

impl FeatureExtractorMode {
    /// Creates a mode from the feature-extractor and head name prefixes.
    pub fn new(feature_extractor_name: String, head_name: String) -> Self {
        Self {
            feature_extractor_name,
            head_name,
        }
    }

    /// Returns true if the layer belongs to the feature extractor.
    pub fn is_feature_extractor(&self, layer_name: &str) -> bool {
        layer_name.starts_with(&self.feature_extractor_name)
    }

    /// Returns true if the layer belongs to the head.
    pub fn is_head(&self, layer_name: &str) -> bool {
        layer_name.starts_with(&self.head_name)
    }

    /// Builds a freezing configuration that freezes every feature-extractor
    /// layer and leaves all other layers (including the head) trainable.
    pub fn get_freezing_config(&self, all_layers: &[String]) -> LayerFreezingConfig {
        let mut config = LayerFreezingConfig::new();

        let feature_layers: Vec<&str> = all_layers
            .iter()
            .filter(|name| self.is_feature_extractor(name))
            .map(|s| s.as_str())
            .collect();

        config.freeze_layers(&feature_layers);
        config
    }
}

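/// Coordinates the strategies above: layer freezing, optional progressive
/// unfreezing, and optional discriminative fine-tuning.
///
/// A minimal sketch of a training-loop integration; `named_parameters` and
/// `apply_update` are hypothetical stand-ins for the caller's model and
/// optimizer APIs:
///
/// ```ignore
/// let mut manager = TransferLearningManager::new()
///     .with_progressive_unfreezing(unfreezing);
/// for epoch in 0..num_epochs {
///     manager.on_epoch_begin(epoch);
///     for (name, param) in model.named_parameters() {
///         if manager.should_update_layer(&name) {
///             let lr = manager.get_layer_lr(&name, base_lr);
///             apply_update(param, lr); // hypothetical optimizer step
///         }
///     }
/// }
/// ```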
#[derive(Debug)]
pub struct TransferLearningManager {
    /// The active layer-freezing configuration.
    pub freezing_config: LayerFreezingConfig,
    /// Optional progressive-unfreezing schedule.
    pub progressive_unfreezing: Option<ProgressiveUnfreezing>,
    /// Optional per-layer learning-rate configuration.
    pub discriminative_finetuning: Option<DiscriminativeFineTuning>,
    /// The epoch most recently passed to `on_epoch_begin`.
    current_epoch: usize,
}

impl TransferLearningManager {
    /// Creates a manager with no layers frozen and no strategies attached.
    pub fn new() -> Self {
        Self {
            freezing_config: LayerFreezingConfig::new(),
            progressive_unfreezing: None,
            discriminative_finetuning: None,
            current_epoch: 0,
        }
    }

    /// Attaches a progressive-unfreezing schedule and freezes the layers
    /// that are frozen at its current stage.
    pub fn with_progressive_unfreezing(mut self, strategy: ProgressiveUnfreezing) -> Self {
        let frozen = strategy.get_frozen_layers();
        let frozen_refs: Vec<&str> = frozen.iter().map(|s| s.as_str()).collect();
        self.freezing_config.freeze_layers(&frozen_refs);

        self.progressive_unfreezing = Some(strategy);
        self
    }

    /// Attaches a discriminative fine-tuning configuration.
    pub fn with_discriminative_finetuning(mut self, config: DiscriminativeFineTuning) -> Self {
        self.discriminative_finetuning = Some(config);
        self
    }

    /// Switches to feature-extraction mode: the feature extractor is frozen
    /// and the head stays trainable. Replaces any existing freezing
    /// configuration.
    pub fn with_feature_extraction(
        mut self,
        mode: FeatureExtractorMode,
        all_layers: &[String],
    ) -> Self {
        self.freezing_config = mode.get_freezing_config(all_layers);
        self
    }

    /// Updates the manager at the start of an epoch, advancing the
    /// progressive-unfreezing schedule if one is attached.
    pub fn on_epoch_begin(&mut self, epoch: usize) {
        self.current_epoch = epoch;

        if let Some(ref mut unfreezing) = self.progressive_unfreezing {
            if unfreezing.update_stage(epoch) {
                let frozen = unfreezing.get_frozen_layers();
                let trainable = unfreezing.get_trainable_layers();

                // Rebuild the freezing configuration from scratch for the new stage.
                self.freezing_config.unfreeze_all();
                let frozen_refs: Vec<&str> = frozen.iter().map(|s| s.as_str()).collect();
                self.freezing_config.freeze_layers(&frozen_refs);

                log::info!(
                    "Progressive unfreezing: Stage {}/{}, {} layers trainable",
                    unfreezing.current_stage(),
                    unfreezing.total_stages(),
                    trainable.len()
                );
            }
        }
    }

    /// Returns true if gradient updates should be applied to the named layer.
    pub fn should_update_layer(&self, layer_name: &str) -> bool {
        !self.freezing_config.is_frozen(layer_name)
    }

    /// Returns the learning rate for the named layer. With discriminative
    /// fine-tuning attached, its per-layer rate is used (falling back to
    /// that configuration's own base rate); otherwise `base_lr` is returned
    /// unchanged.
    pub fn get_layer_lr(&self, layer_name: &str, base_lr: f64) -> f64 {
        if let Some(ref finetuning) = self.discriminative_finetuning {
            finetuning.get_layer_lr(layer_name)
        } else {
            base_lr
        }
    }

    /// Returns the epoch most recently passed to `on_epoch_begin`.
    pub fn current_epoch(&self) -> usize {
        self.current_epoch
    }
}

impl Default for TransferLearningManager {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_layer_freezing_config() {
        let mut config = LayerFreezingConfig::new();
        assert!(!config.is_frozen("layer1"));

        config.freeze_layers(&["layer1", "layer2"]);
        assert!(config.is_frozen("layer1"));
        assert!(config.is_frozen("layer2"));
        assert!(!config.is_frozen("layer3"));

        config.unfreeze_layers(&["layer1"]);
        assert!(!config.is_frozen("layer1"));
        assert!(config.is_frozen("layer2"));

        assert_eq!(config.num_frozen(), 1);
    }

    #[test]
    fn test_layer_freezing_all() {
        let mut config = LayerFreezingConfig::new();
        config.freeze_all();

        assert!(config.is_frozen("any_layer"));
        assert!(config.is_frozen("another_layer"));

        config.unfreeze_all();
        assert!(!config.is_frozen("any_layer"));
    }

    #[test]
    fn test_progressive_unfreezing() {
        let layers = vec![
            "layer1".to_string(),
            "layer2".to_string(),
            "layer3".to_string(),
        ];
        let mut unfreezing = ProgressiveUnfreezing::new(layers, 5).unwrap();

        // Stage 0: everything is frozen.
        assert_eq!(unfreezing.get_trainable_layers().len(), 0);
        assert_eq!(unfreezing.get_frozen_layers().len(), 3);
        assert!(!unfreezing.is_complete());

        // Epoch 5 advances to stage 1: the last layer becomes trainable.
        unfreezing.update_stage(5);
        assert_eq!(unfreezing.current_stage(), 1);
        assert_eq!(unfreezing.get_trainable_layers().len(), 1);
        assert_eq!(unfreezing.get_frozen_layers().len(), 2);

        // Epoch 10 advances to stage 2.
        unfreezing.update_stage(10);
        assert_eq!(unfreezing.current_stage(), 2);
        assert_eq!(unfreezing.get_trainable_layers().len(), 2);

        // Epoch 15 advances to the final stage: all layers trainable.
        unfreezing.update_stage(15);
        assert_eq!(unfreezing.current_stage(), 3);
        assert_eq!(unfreezing.get_trainable_layers().len(), 3);
        assert!(unfreezing.is_complete());
    }

    #[test]
    fn test_progressive_unfreezing_invalid() {
        let result = ProgressiveUnfreezing::new(vec![], 5);
        assert!(result.is_err());

        let result = ProgressiveUnfreezing::new(vec!["layer1".to_string()], 0);
        assert!(result.is_err());
    }

    #[test]
    fn test_discriminative_finetuning() {
        let mut finetuning = DiscriminativeFineTuning::new(1e-3, 0.5).unwrap();

        let layers = vec![
            "layer1".to_string(),
            "layer2".to_string(),
            "layer3".to_string(),
        ];
        finetuning.compute_layer_lrs(&layers);

        // The last layer keeps the base learning rate.
        assert!((finetuning.get_layer_lr("layer3") - 1e-3).abs() < 1e-10);
        // One layer deeper: base_lr * 0.5.
        assert!((finetuning.get_layer_lr("layer2") - 5e-4).abs() < 1e-10);
        // Two layers deeper: base_lr * 0.25.
        assert!((finetuning.get_layer_lr("layer1") - 2.5e-4).abs() < 1e-10);
    }

    #[test]
    fn test_discriminative_finetuning_invalid() {
        assert!(DiscriminativeFineTuning::new(0.0, 0.5).is_err());
        assert!(DiscriminativeFineTuning::new(-1e-3, 0.5).is_err());
        assert!(DiscriminativeFineTuning::new(1e-3, 1.5).is_err());
        assert!(DiscriminativeFineTuning::new(1e-3, -0.1).is_err());
    }

    #[test]
    fn test_feature_extractor_mode() {
        let mode = FeatureExtractorMode::new("encoder".to_string(), "classifier".to_string());

        assert!(mode.is_feature_extractor("encoder.layer1"));
        assert!(mode.is_feature_extractor("encoder.layer2"));
        assert!(!mode.is_feature_extractor("classifier.fc"));

        assert!(mode.is_head("classifier.fc"));
        assert!(mode.is_head("classifier.output"));
        assert!(!mode.is_head("encoder.layer1"));

        let all_layers = vec![
            "encoder.layer1".to_string(),
            "encoder.layer2".to_string(),
            "classifier.fc".to_string(),
        ];

        let config = mode.get_freezing_config(&all_layers);
        assert!(config.is_frozen("encoder.layer1"));
        assert!(config.is_frozen("encoder.layer2"));
        assert!(!config.is_frozen("classifier.fc"));
    }

    #[test]
    fn test_transfer_learning_manager() {
        let mut manager = TransferLearningManager::new();

        // Nothing is frozen by default.
        assert!(manager.should_update_layer("layer1"));

        // Frozen layers are excluded from updates.
        manager.freezing_config.freeze_layers(&["layer1"]);
        assert!(!manager.should_update_layer("layer1"));
        assert!(manager.should_update_layer("layer2"));
    }

    #[test]
    fn test_transfer_learning_with_progressive_unfreezing() {
        let layers = vec![
            "layer1".to_string(),
            "layer2".to_string(),
            "layer3".to_string(),
        ];
        let unfreezing = ProgressiveUnfreezing::new(layers.clone(), 5).unwrap();

        let mut manager = TransferLearningManager::new().with_progressive_unfreezing(unfreezing);

        // Stage 0: all layers frozen.
        manager.on_epoch_begin(0);
        assert!(!manager.should_update_layer("layer1"));
        assert!(!manager.should_update_layer("layer2"));
        assert!(!manager.should_update_layer("layer3"));

        // Stage 1: the last layer becomes trainable.
        manager.on_epoch_begin(5);
        assert!(!manager.should_update_layer("layer1"));
        assert!(!manager.should_update_layer("layer2"));
        assert!(manager.should_update_layer("layer3"));
    }

    #[test]
    fn test_transfer_learning_with_discriminative_finetuning() {
        let layers = vec![
            "layer1".to_string(),
            "layer2".to_string(),
            "layer3".to_string(),
        ];
        let mut finetuning = DiscriminativeFineTuning::new(1e-3, 0.5).unwrap();
        finetuning.compute_layer_lrs(&layers);

        let manager = TransferLearningManager::new().with_discriminative_finetuning(finetuning);

        // Per-layer rates take precedence over the base_lr argument.
        assert!((manager.get_layer_lr("layer3", 1e-3) - 1e-3).abs() < 1e-10);
        assert!((manager.get_layer_lr("layer2", 1e-3) - 5e-4).abs() < 1e-10);
        assert!((manager.get_layer_lr("layer1", 1e-3) - 2.5e-4).abs() < 1e-10);
    }

    #[test]
    fn test_transfer_learning_with_feature_extraction() {
        let mode = FeatureExtractorMode::new("encoder".to_string(), "classifier".to_string());
        let all_layers = vec![
            "encoder.layer1".to_string(),
            "encoder.layer2".to_string(),
            "classifier.fc".to_string(),
        ];

        let manager = TransferLearningManager::new().with_feature_extraction(mode, &all_layers);

        // The feature extractor is frozen.
        assert!(!manager.should_update_layer("encoder.layer1"));
        assert!(!manager.should_update_layer("encoder.layer2"));

        // The head stays trainable.
        assert!(manager.should_update_layer("classifier.fc"));
    }

    #[test]
    fn test_frozen_layers_getter() {
        let mut config = LayerFreezingConfig::new();
        config.freeze_layers(&["layer1", "layer2"]);

        let frozen = config.frozen_layers();
        assert_eq!(frozen.len(), 2);
        assert!(frozen.contains(&"layer1".to_string()));
        assert!(frozen.contains(&"layer2".to_string()));
    }

    #[test]
    fn test_progressive_unfreezing_total_stages() {
        let layers = vec!["layer1".to_string(), "layer2".to_string()];
        let unfreezing = ProgressiveUnfreezing::new(layers, 5).unwrap();

        assert_eq!(unfreezing.total_stages(), 2);
    }
}