
tensorlogic_train/transfer.rs

//! Transfer learning utilities for model fine-tuning.
//!
//! This module provides utilities for transfer learning:
//! - Layer freezing and unfreezing
//! - Progressive fine-tuning strategies
//! - Feature extraction mode
//! - Learning rate scheduling for transfer learning
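//!
//! # Example
//!
//! A minimal sketch of the intended workflow. It is marked `ignore` because the
//! public re-export path used below is an assumption, not something this file states:
//!
//! ```ignore
//! use tensorlogic_train::transfer::{ProgressiveUnfreezing, TransferLearningManager};
//!
//! let layers = vec!["encoder.0".to_string(), "encoder.1".to_string(), "head".to_string()];
//! let unfreezing = ProgressiveUnfreezing::new(layers, 5).unwrap();
//! let mut manager = TransferLearningManager::new().with_progressive_unfreezing(unfreezing);
//!
//! // At the start of each epoch, refresh the freezing schedule, then skip
//! // gradient updates for any layer the manager reports as frozen.
//! manager.on_epoch_begin(0);
//! assert!(!manager.should_update_layer("encoder.0"));
//! ```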

use crate::{TrainError, TrainResult};
use std::collections::{HashMap, HashSet};

/// Layer freezing configuration for transfer learning.
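///
/// # Examples
///
/// A small sketch of the freeze/unfreeze API (the layer names are made up;
/// kept as a non-compiled doc example):
///
/// ```ignore
/// let mut config = LayerFreezingConfig::new();
/// config.freeze_layers(&["encoder.0", "encoder.1"]);
/// assert!(config.is_frozen("encoder.0"));
///
/// config.unfreeze_layers(&["encoder.0"]);
/// assert!(!config.is_frozen("encoder.0"));
/// assert_eq!(config.num_frozen(), 1);
/// ```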
#[derive(Debug, Clone)]
pub struct LayerFreezingConfig {
    /// Set of frozen layer names.
    frozen_layers: HashSet<String>,
    /// Whether to freeze all layers by default.
    freeze_all: bool,
}

impl LayerFreezingConfig {
    /// Create a new layer freezing configuration.
    pub fn new() -> Self {
        Self {
            frozen_layers: HashSet::new(),
            freeze_all: false,
        }
    }

    /// Freeze specific layers.
    ///
    /// # Arguments
    /// * `layer_names` - Names of layers to freeze
    pub fn freeze_layers(&mut self, layer_names: &[&str]) {
        for name in layer_names {
            self.frozen_layers.insert(name.to_string());
        }
    }

    /// Unfreeze specific layers.
    ///
    /// # Arguments
    /// * `layer_names` - Names of layers to unfreeze
    pub fn unfreeze_layers(&mut self, layer_names: &[&str]) {
        for name in layer_names {
            self.frozen_layers.remove(*name);
        }
    }

    /// Freeze all layers.
    pub fn freeze_all(&mut self) {
        self.freeze_all = true;
    }

    /// Unfreeze all layers.
    pub fn unfreeze_all(&mut self) {
        self.freeze_all = false;
        self.frozen_layers.clear();
    }

    /// Check if a layer is frozen.
    ///
    /// # Arguments
    /// * `layer_name` - Name of the layer to check
    pub fn is_frozen(&self, layer_name: &str) -> bool {
        self.freeze_all || self.frozen_layers.contains(layer_name)
    }

    /// Get all frozen layer names.
    pub fn frozen_layers(&self) -> Vec<String> {
        self.frozen_layers.iter().cloned().collect()
    }

    /// Get the number of frozen layers.
    ///
    /// Returns `usize::MAX` when `freeze_all` is enabled, since every layer is
    /// then treated as frozen regardless of the explicit set.
    pub fn num_frozen(&self) -> usize {
        if self.freeze_all {
            usize::MAX
        } else {
            self.frozen_layers.len()
        }
    }
}

impl Default for LayerFreezingConfig {
    fn default() -> Self {
        Self::new()
    }
}

/// Progressive unfreezing strategy for transfer learning.
///
/// Gradually unfreezes layers from top to bottom during training.
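///
/// # Examples
///
/// A sketch of the epoch-driven schedule (layer names are illustrative;
/// kept as a non-compiled doc example):
///
/// ```ignore
/// let layers = vec!["block1".to_string(), "block2".to_string(), "block3".to_string()];
/// let mut schedule = ProgressiveUnfreezing::new(layers, 5).unwrap();
///
/// // Stage 0: everything is still frozen.
/// assert!(schedule.get_trainable_layers().is_empty());
///
/// // After 5 epochs the topmost layer becomes trainable.
/// schedule.update_stage(5);
/// assert_eq!(schedule.get_trainable_layers(), vec!["block3".to_string()]);
/// ```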
#[derive(Debug, Clone)]
pub struct ProgressiveUnfreezing {
    /// Layer names ordered from bottom (early) to top (late).
    layer_order: Vec<String>,
    /// Number of epochs to wait before unfreezing the next layer.
    unfreeze_interval: usize,
    /// Current unfreezing stage.
    current_stage: usize,
}

impl ProgressiveUnfreezing {
    /// Create a new progressive unfreezing strategy.
    ///
    /// # Arguments
    /// * `layer_order` - Layer names ordered from bottom to top
    /// * `unfreeze_interval` - Epochs between unfreezing stages
    pub fn new(layer_order: Vec<String>, unfreeze_interval: usize) -> TrainResult<Self> {
        if layer_order.is_empty() {
            return Err(TrainError::InvalidParameter(
                "layer_order cannot be empty".to_string(),
            ));
        }
        if unfreeze_interval == 0 {
            return Err(TrainError::InvalidParameter(
                "unfreeze_interval must be positive".to_string(),
            ));
        }
        Ok(Self {
            layer_order,
            unfreeze_interval,
            current_stage: 0,
        })
    }

    /// Update the unfreezing stage based on the current epoch.
    ///
    /// # Arguments
    /// * `epoch` - Current training epoch
    ///
    /// # Returns
    /// Whether the stage was updated
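    ///
    /// # Examples
    ///
    /// A quick sketch of the stage arithmetic (here `unfreeze_interval = 5`;
    /// non-compiled doc example):
    ///
    /// ```ignore
    /// let layers = vec!["a".to_string(), "b".to_string()];
    /// let mut schedule = ProgressiveUnfreezing::new(layers, 5).unwrap();
    /// assert!(!schedule.update_stage(4)); // 4 / 5 == 0, still stage 0
    /// assert!(schedule.update_stage(5));  // 5 / 5 == 1, advance to stage 1
    /// assert!(!schedule.update_stage(5)); // no further change this stage
    /// ```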
    pub fn update_stage(&mut self, epoch: usize) -> bool {
        let new_stage = epoch / self.unfreeze_interval;
        if new_stage > self.current_stage {
            self.current_stage = new_stage.min(self.layer_order.len());
            true
        } else {
            false
        }
    }

    /// Get layers that should be unfrozen at the current stage.
    ///
    /// # Returns
    /// Layer names that should be trainable
    pub fn get_trainable_layers(&self) -> Vec<String> {
        // Unfreeze from top to bottom: start with last layers
        let num_trainable = self.current_stage.min(self.layer_order.len());
        let start_idx = self.layer_order.len().saturating_sub(num_trainable);

        self.layer_order[start_idx..].to_vec()
    }

    /// Get layers that should be frozen at the current stage.
    ///
    /// # Returns
    /// Layer names that should be frozen
    pub fn get_frozen_layers(&self) -> Vec<String> {
        let num_trainable = self.current_stage.min(self.layer_order.len());
        let end_idx = self.layer_order.len().saturating_sub(num_trainable);

        self.layer_order[..end_idx].to_vec()
    }

    /// Check if unfreezing is complete.
    pub fn is_complete(&self) -> bool {
        self.current_stage >= self.layer_order.len()
    }

    /// Get current stage number.
    pub fn current_stage(&self) -> usize {
        self.current_stage
    }

    /// Get total number of stages.
    pub fn total_stages(&self) -> usize {
        self.layer_order.len()
    }
}

/// Discriminative fine-tuning: use different learning rates for different layers.
///
/// Typically, earlier layers use smaller learning rates than later layers.
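///
/// # Examples
///
/// A sketch of per-layer learning rates decayed toward earlier layers
/// (layer names are illustrative; non-compiled doc example):
///
/// ```ignore
/// let mut ft = DiscriminativeFineTuning::new(1e-3, 0.5).unwrap();
/// ft.compute_layer_lrs(&["bottom".to_string(), "middle".to_string(), "top".to_string()]);
///
/// assert!((ft.get_layer_lr("top") - 1e-3).abs() < 1e-10);      // depth 0
/// assert!((ft.get_layer_lr("middle") - 5e-4).abs() < 1e-10);   // depth 1
/// assert!((ft.get_layer_lr("bottom") - 2.5e-4).abs() < 1e-10); // depth 2
/// ```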
#[derive(Debug, Clone)]
pub struct DiscriminativeFineTuning {
    /// Base learning rate for the last layer.
    pub base_lr: f64,
    /// Learning rate decay factor (each earlier layer uses lr * decay_factor).
    pub decay_factor: f64,
    /// Layer-specific learning rates.
    layer_lrs: HashMap<String, f64>,
}

impl DiscriminativeFineTuning {
    /// Create a new discriminative fine-tuning configuration.
    ///
    /// # Arguments
    /// * `base_lr` - Learning rate for the last layer
    /// * `decay_factor` - Decay factor for earlier layers (e.g., 0.5 means half LR)
    pub fn new(base_lr: f64, decay_factor: f64) -> TrainResult<Self> {
        if base_lr <= 0.0 {
            return Err(TrainError::InvalidParameter(
                "base_lr must be positive".to_string(),
            ));
        }
        if !(0.0..=1.0).contains(&decay_factor) {
            return Err(TrainError::InvalidParameter(
                "decay_factor must be in [0, 1]".to_string(),
            ));
        }
        Ok(Self {
            base_lr,
            decay_factor,
            layer_lrs: HashMap::new(),
        })
    }

    /// Compute learning rates for all layers.
    ///
    /// # Arguments
    /// * `layer_order` - Layer names ordered from bottom to top
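    ///
    /// Each layer at depth `d` from the top of `layer_order` (the last entry
    /// has depth 0) receives `base_lr * decay_factor.powi(d)`.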
    pub fn compute_layer_lrs(&mut self, layer_order: &[String]) {
        self.layer_lrs.clear();

        let num_layers = layer_order.len();
        for (i, layer_name) in layer_order.iter().enumerate() {
            // Later layers get higher learning rates
            let depth = num_layers - 1 - i;
            let lr = self.base_lr * self.decay_factor.powi(depth as i32);
            self.layer_lrs.insert(layer_name.clone(), lr);
        }
    }

    /// Get the learning rate for a specific layer.
    ///
    /// # Arguments
    /// * `layer_name` - Name of the layer
    ///
    /// # Returns
    /// Learning rate for the layer, or base_lr if not found
    pub fn get_layer_lr(&self, layer_name: &str) -> f64 {
        self.layer_lrs
            .get(layer_name)
            .copied()
            .unwrap_or(self.base_lr)
    }

    /// Get all layer learning rates.
    pub fn layer_lrs(&self) -> &HashMap<String, f64> {
        &self.layer_lrs
    }
}

/// Feature extraction mode: freeze the entire feature extractor.
///
/// Only trains the final classification/regression head.
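///
/// # Examples
///
/// A sketch using prefix-style layer names (the names are illustrative;
/// non-compiled doc example):
///
/// ```ignore
/// let mode = FeatureExtractorMode::new("encoder".to_string(), "head".to_string());
/// let all_layers = vec!["encoder.0".to_string(), "head.fc".to_string()];
///
/// let config = mode.get_freezing_config(&all_layers);
/// assert!(config.is_frozen("encoder.0"));
/// assert!(!config.is_frozen("head.fc"));
/// ```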
#[derive(Debug, Clone)]
pub struct FeatureExtractorMode {
    /// Name of the feature extractor (typically all layers except last).
    pub feature_extractor_name: String,
    /// Name of the head/classifier (typically the last layer).
    pub head_name: String,
}

impl FeatureExtractorMode {
    /// Create a new feature extractor mode.
    ///
    /// # Arguments
    /// * `feature_extractor_name` - Name/prefix of feature extractor layers
    /// * `head_name` - Name/prefix of head layers
    pub fn new(feature_extractor_name: String, head_name: String) -> Self {
        Self {
            feature_extractor_name,
            head_name,
        }
    }

    /// Check if a layer is part of the feature extractor.
    ///
    /// # Arguments
    /// * `layer_name` - Name of the layer
    pub fn is_feature_extractor(&self, layer_name: &str) -> bool {
        layer_name.starts_with(&self.feature_extractor_name)
    }

    /// Check if a layer is part of the head.
    ///
    /// # Arguments
    /// * `layer_name` - Name of the layer
    pub fn is_head(&self, layer_name: &str) -> bool {
        layer_name.starts_with(&self.head_name)
    }

    /// Get freezing configuration for feature extraction.
    ///
    /// # Arguments
    /// * `all_layers` - All layer names in the model
    ///
    /// # Returns
    /// Layer freezing configuration
    pub fn get_freezing_config(&self, all_layers: &[String]) -> LayerFreezingConfig {
        let mut config = LayerFreezingConfig::new();

        // Freeze all feature extractor layers
        let feature_layers: Vec<&str> = all_layers
            .iter()
            .filter(|name| self.is_feature_extractor(name))
            .map(|s| s.as_str())
            .collect();

        config.freeze_layers(&feature_layers);
        config
    }
}

/// Transfer learning strategy manager.
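///
/// # Examples
///
/// A sketch combining progressive unfreezing with per-layer learning rates
/// (layer names are illustrative; non-compiled doc example):
///
/// ```ignore
/// let layers = vec!["stem".to_string(), "body".to_string(), "head".to_string()];
///
/// let unfreezing = ProgressiveUnfreezing::new(layers.clone(), 2).unwrap();
/// let mut finetuning = DiscriminativeFineTuning::new(1e-3, 0.5).unwrap();
/// finetuning.compute_layer_lrs(&layers);
///
/// let mut manager = TransferLearningManager::new()
///     .with_progressive_unfreezing(unfreezing)
///     .with_discriminative_finetuning(finetuning);
///
/// manager.on_epoch_begin(2);
/// assert!(manager.should_update_layer("head"));   // unfrozen at stage 1
/// assert!(!manager.should_update_layer("stem"));  // still frozen
/// assert!((manager.get_layer_lr("head", 1e-3) - 1e-3).abs() < 1e-10);
/// ```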
#[derive(Debug)]
pub struct TransferLearningManager {
    /// Layer freezing configuration.
    pub freezing_config: LayerFreezingConfig,
    /// Optional progressive unfreezing strategy.
    pub progressive_unfreezing: Option<ProgressiveUnfreezing>,
    /// Optional discriminative fine-tuning.
    pub discriminative_finetuning: Option<DiscriminativeFineTuning>,
    /// Current epoch counter.
    current_epoch: usize,
}

impl TransferLearningManager {
    /// Create a new transfer learning manager.
    pub fn new() -> Self {
        Self {
            freezing_config: LayerFreezingConfig::new(),
            progressive_unfreezing: None,
            discriminative_finetuning: None,
            current_epoch: 0,
        }
    }

    /// Set progressive unfreezing strategy.
    ///
    /// # Arguments
    /// * `strategy` - Progressive unfreezing configuration
    pub fn with_progressive_unfreezing(mut self, strategy: ProgressiveUnfreezing) -> Self {
        // Initialize freezing config with all layers frozen (stage 0)
        let frozen = strategy.get_frozen_layers();
        let frozen_refs: Vec<&str> = frozen.iter().map(|s| s.as_str()).collect();
        self.freezing_config.freeze_layers(&frozen_refs);

        self.progressive_unfreezing = Some(strategy);
        self
    }

    /// Set discriminative fine-tuning.
    ///
    /// # Arguments
    /// * `config` - Discriminative fine-tuning configuration
    pub fn with_discriminative_finetuning(mut self, config: DiscriminativeFineTuning) -> Self {
        self.discriminative_finetuning = Some(config);
        self
    }

    /// Set feature extraction mode.
    ///
    /// # Arguments
    /// * `mode` - Feature extraction configuration
    /// * `all_layers` - All layer names in the model
    pub fn with_feature_extraction(
        mut self,
        mode: FeatureExtractorMode,
        all_layers: &[String],
    ) -> Self {
        self.freezing_config = mode.get_freezing_config(all_layers);
        self
    }

    /// Update for new epoch.
    ///
    /// # Arguments
    /// * `epoch` - Current training epoch
    pub fn on_epoch_begin(&mut self, epoch: usize) {
        self.current_epoch = epoch;

        // Update progressive unfreezing if enabled
        if let Some(ref mut unfreezing) = self.progressive_unfreezing {
            if unfreezing.update_stage(epoch) {
                // Update freezing config based on new stage
                let frozen = unfreezing.get_frozen_layers();
                let trainable = unfreezing.get_trainable_layers();

                // Clear and rebuild freezing config
                self.freezing_config.unfreeze_all();
                let frozen_refs: Vec<&str> = frozen.iter().map(|s| s.as_str()).collect();
                self.freezing_config.freeze_layers(&frozen_refs);

                log::info!(
                    "Progressive unfreezing: Stage {}/{}, {} layers trainable",
                    unfreezing.current_stage(),
                    unfreezing.total_stages(),
                    trainable.len()
                );
            }
        }
    }

    /// Check if a layer should be updated during training.
    ///
    /// # Arguments
    /// * `layer_name` - Name of the layer
    pub fn should_update_layer(&self, layer_name: &str) -> bool {
        !self.freezing_config.is_frozen(layer_name)
    }

    /// Get the learning rate for a specific layer.
    ///
    /// # Arguments
    /// * `layer_name` - Name of the layer
    /// * `base_lr` - Fallback learning rate, used only when discriminative
    ///   fine-tuning is not configured
    ///
    /// # Returns
    /// Layer-specific learning rate from discriminative fine-tuning, or `base_lr`
    pub fn get_layer_lr(&self, layer_name: &str, base_lr: f64) -> f64 {
        if let Some(ref finetuning) = self.discriminative_finetuning {
            finetuning.get_layer_lr(layer_name)
        } else {
            base_lr
        }
    }

    /// Get current epoch.
    pub fn current_epoch(&self) -> usize {
        self.current_epoch
    }
}

impl Default for TransferLearningManager {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_layer_freezing_config() {
        let mut config = LayerFreezingConfig::new();
        assert!(!config.is_frozen("layer1"));

        config.freeze_layers(&["layer1", "layer2"]);
        assert!(config.is_frozen("layer1"));
        assert!(config.is_frozen("layer2"));
        assert!(!config.is_frozen("layer3"));

        config.unfreeze_layers(&["layer1"]);
        assert!(!config.is_frozen("layer1"));
        assert!(config.is_frozen("layer2"));

        assert_eq!(config.num_frozen(), 1);
    }

    #[test]
    fn test_layer_freezing_all() {
        let mut config = LayerFreezingConfig::new();
        config.freeze_all();

        assert!(config.is_frozen("any_layer"));
        assert!(config.is_frozen("another_layer"));

        config.unfreeze_all();
        assert!(!config.is_frozen("any_layer"));
    }

    #[test]
    fn test_progressive_unfreezing() {
        let layers = vec![
            "layer1".to_string(),
            "layer2".to_string(),
            "layer3".to_string(),
        ];
        let mut unfreezing = ProgressiveUnfreezing::new(layers, 5).unwrap();

        // Stage 0: all frozen
        assert_eq!(unfreezing.get_trainable_layers().len(), 0);
        assert_eq!(unfreezing.get_frozen_layers().len(), 3);
        assert!(!unfreezing.is_complete());

        // Epoch 5: unfreeze last layer
        unfreezing.update_stage(5);
        assert_eq!(unfreezing.current_stage(), 1);
        assert_eq!(unfreezing.get_trainable_layers().len(), 1);
        assert_eq!(unfreezing.get_frozen_layers().len(), 2);

        // Epoch 10: unfreeze the last two layers
        unfreezing.update_stage(10);
        assert_eq!(unfreezing.current_stage(), 2);
        assert_eq!(unfreezing.get_trainable_layers().len(), 2);

        // Epoch 15: all unfrozen
        unfreezing.update_stage(15);
        assert_eq!(unfreezing.current_stage(), 3);
        assert_eq!(unfreezing.get_trainable_layers().len(), 3);
        assert!(unfreezing.is_complete());
    }

    #[test]
    fn test_progressive_unfreezing_invalid() {
        let result = ProgressiveUnfreezing::new(vec![], 5);
        assert!(result.is_err());

        let result = ProgressiveUnfreezing::new(vec!["layer1".to_string()], 0);
        assert!(result.is_err());
    }

    #[test]
    fn test_discriminative_finetuning() {
        let mut finetuning = DiscriminativeFineTuning::new(1e-3, 0.5).unwrap();

        let layers = vec![
            "layer1".to_string(),
            "layer2".to_string(),
            "layer3".to_string(),
        ];
        finetuning.compute_layer_lrs(&layers);

        // Last layer should have base_lr
        assert!((finetuning.get_layer_lr("layer3") - 1e-3).abs() < 1e-10);

        // Second layer should have base_lr * decay_factor
        assert!((finetuning.get_layer_lr("layer2") - 5e-4).abs() < 1e-10);

        // First layer should have base_lr * decay_factor^2
        assert!((finetuning.get_layer_lr("layer1") - 2.5e-4).abs() < 1e-10);
    }

    #[test]
    fn test_discriminative_finetuning_invalid() {
        assert!(DiscriminativeFineTuning::new(0.0, 0.5).is_err());
        assert!(DiscriminativeFineTuning::new(-1e-3, 0.5).is_err());
        assert!(DiscriminativeFineTuning::new(1e-3, 1.5).is_err());
        assert!(DiscriminativeFineTuning::new(1e-3, -0.1).is_err());
    }

    #[test]
    fn test_feature_extractor_mode() {
        let mode = FeatureExtractorMode::new("encoder".to_string(), "classifier".to_string());

        assert!(mode.is_feature_extractor("encoder.layer1"));
        assert!(mode.is_feature_extractor("encoder.layer2"));
        assert!(!mode.is_feature_extractor("classifier.fc"));

        assert!(mode.is_head("classifier.fc"));
        assert!(mode.is_head("classifier.output"));
        assert!(!mode.is_head("encoder.layer1"));

        let all_layers = vec![
            "encoder.layer1".to_string(),
            "encoder.layer2".to_string(),
            "classifier.fc".to_string(),
        ];

        let config = mode.get_freezing_config(&all_layers);
        assert!(config.is_frozen("encoder.layer1"));
        assert!(config.is_frozen("encoder.layer2"));
        assert!(!config.is_frozen("classifier.fc"));
    }

    #[test]
    fn test_transfer_learning_manager() {
        let mut manager = TransferLearningManager::new();

        // Initially, all layers are trainable
        assert!(manager.should_update_layer("layer1"));

        // Freeze some layers
        manager.freezing_config.freeze_layers(&["layer1"]);
        assert!(!manager.should_update_layer("layer1"));
        assert!(manager.should_update_layer("layer2"));
    }

    #[test]
    fn test_transfer_learning_with_progressive_unfreezing() {
        let layers = vec![
            "layer1".to_string(),
            "layer2".to_string(),
            "layer3".to_string(),
        ];
        let unfreezing = ProgressiveUnfreezing::new(layers.clone(), 5).unwrap();

        let mut manager = TransferLearningManager::new().with_progressive_unfreezing(unfreezing);

        // Epoch 0: all should be frozen
        manager.on_epoch_begin(0);
        assert!(!manager.should_update_layer("layer1"));
        assert!(!manager.should_update_layer("layer2"));
        assert!(!manager.should_update_layer("layer3"));

        // Epoch 5: last layer unfrozen
        manager.on_epoch_begin(5);
        assert!(!manager.should_update_layer("layer1"));
        assert!(!manager.should_update_layer("layer2"));
        assert!(manager.should_update_layer("layer3"));
    }

    #[test]
    fn test_transfer_learning_with_discriminative_finetuning() {
        let layers = vec![
            "layer1".to_string(),
            "layer2".to_string(),
            "layer3".to_string(),
        ];
        let mut finetuning = DiscriminativeFineTuning::new(1e-3, 0.5).unwrap();
        finetuning.compute_layer_lrs(&layers);

        let manager = TransferLearningManager::new().with_discriminative_finetuning(finetuning);

        // Check layer-specific learning rates
        assert!((manager.get_layer_lr("layer3", 1e-3) - 1e-3).abs() < 1e-10);
        assert!((manager.get_layer_lr("layer2", 1e-3) - 5e-4).abs() < 1e-10);
        assert!((manager.get_layer_lr("layer1", 1e-3) - 2.5e-4).abs() < 1e-10);
    }

    #[test]
    fn test_transfer_learning_with_feature_extraction() {
        let mode = FeatureExtractorMode::new("encoder".to_string(), "classifier".to_string());
        let all_layers = vec![
            "encoder.layer1".to_string(),
            "encoder.layer2".to_string(),
            "classifier.fc".to_string(),
        ];

        let manager = TransferLearningManager::new().with_feature_extraction(mode, &all_layers);

        // Encoder should be frozen
        assert!(!manager.should_update_layer("encoder.layer1"));
        assert!(!manager.should_update_layer("encoder.layer2"));

        // Classifier should be trainable
        assert!(manager.should_update_layer("classifier.fc"));
    }

    #[test]
    fn test_frozen_layers_getter() {
        let mut config = LayerFreezingConfig::new();
        config.freeze_layers(&["layer1", "layer2"]);

        let frozen = config.frozen_layers();
        assert_eq!(frozen.len(), 2);
        assert!(frozen.contains(&"layer1".to_string()));
        assert!(frozen.contains(&"layer2".to_string()));
    }

    #[test]
    fn test_progressive_unfreezing_total_stages() {
        let layers = vec!["layer1".to_string(), "layer2".to_string()];
        let unfreezing = ProgressiveUnfreezing::new(layers, 5).unwrap();

        assert_eq!(unfreezing.total_stages(), 2);
    }
}