// trustformers_models/progressive_training.rs
/*!
# Progressive Training Module

This module provides progressive training capabilities for transformer models, enabling
models to grow in capacity during training for improved efficiency and performance.

## Features

- **Layer Progressive Training**: Gradually add layers during training
- **Width Progressive Training**: Progressively increase hidden dimensions
- **Head Progressive Training**: Add attention heads incrementally
- **Multiple Growth Strategies**: Linear, exponential, adaptive growth schedules
- **Smooth Transitions**: Gradual parameter initialization and adaptation
- **Curriculum Integration**: Compatible with curriculum learning frameworks

## Usage

```rust
use trustformers_models::progressive_training::{
    ProgressiveTrainer, ProgressiveConfig, GrowthStrategy, GrowthDimension
};

let config = ProgressiveConfig {
    growth_dimension: GrowthDimension::Layers,
    growth_strategy: GrowthStrategy::Linear,
    initial_size: 6,
    final_size: 12,
    growth_epochs: vec![10, 20, 30],
    warmup_steps: 1000,
    ..ProgressiveConfig::default()
};

let mut trainer = ProgressiveTrainer::new(config)?;
```
*/
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use trustformers_core::errors::TrustformersError;
/// Configuration for progressive training.
///
/// `initial_size`/`final_size` are measured in units of the chosen
/// `growth_dimension` (layer count, hidden dim, head count, ...).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProgressiveConfig {
    /// Which dimension to grow progressively
    pub growth_dimension: GrowthDimension,
    /// Growth strategy to use
    pub growth_strategy: GrowthStrategy,
    /// Initial model size (layers, hidden dim, heads, etc.)
    pub initial_size: usize,
    /// Final model size (same unit as `initial_size`)
    pub final_size: usize,
    /// Epochs at which to trigger growth
    pub growth_epochs: Vec<usize>,
    /// Steps to warm up after each growth
    pub warmup_steps: usize,
    /// Whether to initialize new parameters with zeros
    pub zero_init_new_params: bool,
    /// Learning rate scaling factor after growth
    /// (NOTE(review): not consulted by `ProgressiveTrainer` in this module —
    /// presumably applied by the optimizer side; confirm)
    pub lr_scaling_factor: f64,
    /// Whether to use gradual weight initialization
    /// (takes precedence over `zero_init_new_params` when both are set)
    pub gradual_initialization: bool,
    /// Smoothing factor for parameter transitions
    pub transition_smoothing: f64,
    /// Whether to freeze old parameters during warmup
    pub freeze_old_params_during_warmup: bool,
}
67impl Default for ProgressiveConfig {
68    fn default() -> Self {
69        Self {
70            growth_dimension: GrowthDimension::Layers,
71            growth_strategy: GrowthStrategy::Linear,
72            initial_size: 6,
73            final_size: 12,
74            growth_epochs: vec![10, 20, 30, 40],
75            warmup_steps: 1000,
76            zero_init_new_params: true,
77            lr_scaling_factor: 0.5,
78            gradual_initialization: true,
79            transition_smoothing: 0.1,
80            freeze_old_params_during_warmup: false,
81        }
82    }
83}
84
/// Dimensions along which models can grow
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum GrowthDimension {
    /// Add transformer layers
    Layers,
    /// Increase hidden dimension
    HiddenDim,
    /// Add attention heads
    AttentionHeads,
    /// Increase intermediate (FFN) dimension
    IntermediateDim,
    /// Grow vocabulary size
    VocabSize,
    /// Multi-dimensional growth (combined layer and width growth)
    MultiDimensional,
}
/// Growth strategies for progressive training
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum GrowthStrategy {
    /// Linear growth at fixed intervals
    Linear,
    /// Exponential growth (larger jumps later)
    Exponential,
    /// Logarithmic growth (larger jumps earlier)
    Logarithmic,
    /// Adaptive growth based on learning progress
    Adaptive,
    /// Custom growth schedule (growth points supplied externally, e.g. via
    /// `ProgressiveTrainer::update_growth_schedule`)
    Custom,
    /// Staged growth (fixed size increases)
    Staged,
}
/// Growth schedule definition
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GrowthSchedule {
    /// Epoch -> new size mapping
    pub growth_points: HashMap<usize, usize>,
    /// Whether the schedule is adaptive (refined at runtime from learning progress)
    pub adaptive: bool,
    /// Minimum epochs between growth steps
    /// (NOTE(review): not currently consulted by `ProgressiveTrainer` — confirm intended use)
    pub min_growth_interval: usize,
    /// Maximum growth per step
    /// (NOTE(review): not currently enforced by `ProgressiveTrainer` — confirm intended use)
    pub max_growth_per_step: usize,
}
/// Progressive trainer for growing models during training
pub struct ProgressiveTrainer {
    /// Static growth configuration.
    config: ProgressiveConfig,
    /// Current size along the configured growth dimension.
    current_size: usize,
    /// Epoch counter (advanced via `set_epoch`).
    current_epoch: usize,
    /// Optimizer-step counter (advanced via `step`).
    current_step: usize,
    /// Epoch -> target-size schedule derived from `config`.
    growth_schedule: GrowthSchedule,
    /// Record of growth operations performed so far, in order.
    growth_history: Vec<GrowthEvent>,
    /// Steps left in the post-growth warmup (0 when not warming up).
    warmup_remaining: usize,
    /// Names of parameters currently frozen for warmup.
    frozen_parameters: HashSet<String>,
    /// Loss statistics used for adaptive growth decisions.
    learning_progress: LearningProgress,
}

// Mid-file import kept in place; `HashSet` is used by the trainer above and
// the `ProgressiveModel` trait below.
use std::collections::HashSet;
147impl ProgressiveTrainer {
148    /// Create a new progressive trainer
149    pub fn new(config: ProgressiveConfig) -> Result<Self, TrustformersError> {
150        let growth_schedule = Self::create_growth_schedule(&config)?;
151
152        Ok(Self {
153            current_size: config.initial_size,
154            current_epoch: 0,
155            current_step: 0,
156            growth_schedule,
157            growth_history: Vec::new(),
158            warmup_remaining: 0,
159            frozen_parameters: HashSet::new(),
160            learning_progress: LearningProgress::new(),
161            config,
162        })
163    }
164
165    /// Create growth schedule based on configuration
166    fn create_growth_schedule(
167        config: &ProgressiveConfig,
168    ) -> Result<GrowthSchedule, TrustformersError> {
169        let mut growth_points = HashMap::new();
170
171        match config.growth_strategy {
172            GrowthStrategy::Linear => {
173                let total_growth = config.final_size - config.initial_size;
174                let num_steps = config.growth_epochs.len();
175                let growth_per_step = total_growth / num_steps.max(1);
176
177                for (i, &epoch) in config.growth_epochs.iter().enumerate() {
178                    let new_size = config.initial_size + (i + 1) * growth_per_step;
179                    growth_points.insert(epoch, new_size.min(config.final_size));
180                }
181            },
182            GrowthStrategy::Exponential => {
183                for (i, &epoch) in config.growth_epochs.iter().enumerate() {
184                    let progress = (i + 1) as f64 / config.growth_epochs.len() as f64;
185                    let exp_progress = progress.powf(2.0);
186                    let new_size = config.initial_size
187                        + ((config.final_size - config.initial_size) as f64 * exp_progress)
188                            as usize;
189                    growth_points.insert(epoch, new_size.min(config.final_size));
190                }
191            },
192            GrowthStrategy::Logarithmic => {
193                for (i, &epoch) in config.growth_epochs.iter().enumerate() {
194                    let progress = (i + 1) as f64 / config.growth_epochs.len() as f64;
195                    let log_progress = (1.0 + progress).ln() / (2.0_f64).ln();
196                    let new_size = config.initial_size
197                        + ((config.final_size - config.initial_size) as f64 * log_progress)
198                            as usize;
199                    growth_points.insert(epoch, new_size.min(config.final_size));
200                }
201            },
202            GrowthStrategy::Adaptive => {
203                // Initial schedule, will be updated based on learning progress
204                for (i, &epoch) in config.growth_epochs.iter().enumerate() {
205                    let progress = (i + 1) as f64 / config.growth_epochs.len() as f64;
206                    let new_size = config.initial_size
207                        + ((config.final_size - config.initial_size) as f64 * progress) as usize;
208                    growth_points.insert(epoch, new_size.min(config.final_size));
209                }
210            },
211            GrowthStrategy::Staged => {
212                let stage_size =
213                    (config.final_size - config.initial_size) / config.growth_epochs.len().max(1);
214                for (i, &epoch) in config.growth_epochs.iter().enumerate() {
215                    let new_size = config.initial_size + (i + 1) * stage_size;
216                    growth_points.insert(epoch, new_size.min(config.final_size));
217                }
218            },
219            GrowthStrategy::Custom => {
220                // Custom schedule should be provided externally
221            },
222        }
223
224        Ok(GrowthSchedule {
225            growth_points,
226            adaptive: matches!(config.growth_strategy, GrowthStrategy::Adaptive),
227            min_growth_interval: 5,
228            max_growth_per_step: (config.final_size - config.initial_size) / 2,
229        })
230    }
231
232    /// Check if model should grow at current epoch
233    pub fn should_grow(&self, epoch: usize) -> bool {
234        if self.warmup_remaining > 0 {
235            return false;
236        }
237
238        if let Some(&target_size) = self.growth_schedule.growth_points.get(&epoch) {
239            return target_size > self.current_size;
240        }
241
242        // Adaptive growth based on learning plateau
243        if self.growth_schedule.adaptive {
244            return self.learning_progress.should_trigger_growth(epoch);
245        }
246
247        false
248    }
249
250    /// Grow the model to the target size
251    pub fn grow_model(
252        &mut self,
253        model: &mut dyn ProgressiveModel,
254        epoch: usize,
255    ) -> Result<GrowthResult, TrustformersError> {
256        let target_size = self
257            .growth_schedule
258            .growth_points
259            .get(&epoch)
260            .copied()
261            .unwrap_or_else(|| self.determine_adaptive_growth_size(epoch));
262
263        if target_size <= self.current_size {
264            return Ok(GrowthResult::NoGrowthNeeded);
265        }
266
267        let growth_amount = target_size - self.current_size;
268        let start_time = std::time::Instant::now();
269
270        // Perform the actual growth based on dimension
271        let growth_info = match self.config.growth_dimension {
272            GrowthDimension::Layers => self.grow_layers(model, growth_amount)?,
273            GrowthDimension::HiddenDim => self.grow_hidden_dimension(model, target_size)?,
274            GrowthDimension::AttentionHeads => self.grow_attention_heads(model, target_size)?,
275            GrowthDimension::IntermediateDim => {
276                self.grow_intermediate_dimension(model, target_size)?
277            },
278            GrowthDimension::VocabSize => self.grow_vocabulary(model, target_size)?,
279            GrowthDimension::MultiDimensional => self.grow_multi_dimensional(model, target_size)?,
280        };
281
282        // Record growth event
283        let growth_event = GrowthEvent {
284            epoch,
285            old_size: self.current_size,
286            new_size: target_size,
287            growth_dimension: self.config.growth_dimension,
288            growth_time: start_time.elapsed(),
289            growth_info: growth_info.clone(),
290        };
291
292        self.growth_history.push(growth_event);
293        self.current_size = target_size;
294        self.warmup_remaining = self.config.warmup_steps;
295
296        // Optionally freeze old parameters during warmup
297        if self.config.freeze_old_params_during_warmup {
298            self.freeze_old_parameters(model)?;
299        }
300
301        Ok(GrowthResult::Grown {
302            old_size: self.current_size,
303            new_size: target_size,
304            growth_info,
305        })
306    }
307
308    /// Grow model by adding layers
309    fn grow_layers(
310        &mut self,
311        model: &mut dyn ProgressiveModel,
312        num_layers: usize,
313    ) -> Result<GrowthInfo, TrustformersError> {
314        let mut added_parameters = 0;
315        let mut initialization_method = String::new();
316
317        for i in 0..num_layers {
318            let layer_params = model.add_layer(self.current_size + i)?;
319            added_parameters += layer_params;
320
321            if self.config.gradual_initialization {
322                // Initialize with small weights that gradually increase
323                let scale = self.config.transition_smoothing * (i + 1) as f64 / num_layers as f64;
324                model.scale_layer_parameters(self.current_size + i, scale)?;
325                initialization_method = format!("Gradual scaling (factor: {})", scale);
326            } else if self.config.zero_init_new_params {
327                model.zero_initialize_layer(self.current_size + i)?;
328                initialization_method = "Zero initialization".to_string();
329            }
330        }
331
332        Ok(GrowthInfo {
333            added_parameters,
334            initialization_method,
335            growth_type: "Layer addition".to_string(),
336        })
337    }
338
339    /// Grow model by increasing hidden dimension
340    fn grow_hidden_dimension(
341        &mut self,
342        model: &mut dyn ProgressiveModel,
343        target_dim: usize,
344    ) -> Result<GrowthInfo, TrustformersError> {
345        let old_dim = model.get_hidden_dimension()?;
346        let _growth = target_dim - old_dim;
347
348        let added_parameters = model.expand_hidden_dimension(target_dim)?;
349
350        // Initialize new dimensions
351        if self.config.gradual_initialization {
352            model.initialize_expanded_dimensions(
353                old_dim,
354                target_dim,
355                self.config.transition_smoothing,
356            )?;
357        }
358
359        Ok(GrowthInfo {
360            added_parameters,
361            initialization_method: "Hidden dimension expansion".to_string(),
362            growth_type: format!("Hidden dim: {} -> {}", old_dim, target_dim),
363        })
364    }
365
366    /// Grow model by adding attention heads
367    fn grow_attention_heads(
368        &mut self,
369        model: &mut dyn ProgressiveModel,
370        target_heads: usize,
371    ) -> Result<GrowthInfo, TrustformersError> {
372        let old_heads = model.get_num_attention_heads()?;
373        let added_parameters = model.expand_attention_heads(target_heads)?;
374
375        Ok(GrowthInfo {
376            added_parameters,
377            initialization_method: "Attention head expansion".to_string(),
378            growth_type: format!("Attention heads: {} -> {}", old_heads, target_heads),
379        })
380    }
381
382    /// Grow intermediate (FFN) dimension
383    fn grow_intermediate_dimension(
384        &mut self,
385        model: &mut dyn ProgressiveModel,
386        target_dim: usize,
387    ) -> Result<GrowthInfo, TrustformersError> {
388        let old_dim = model.get_intermediate_dimension()?;
389        let added_parameters = model.expand_intermediate_dimension(target_dim)?;
390
391        Ok(GrowthInfo {
392            added_parameters,
393            initialization_method: "Intermediate dimension expansion".to_string(),
394            growth_type: format!("Intermediate dim: {} -> {}", old_dim, target_dim),
395        })
396    }
397
398    /// Grow vocabulary size
399    fn grow_vocabulary(
400        &mut self,
401        model: &mut dyn ProgressiveModel,
402        target_vocab: usize,
403    ) -> Result<GrowthInfo, TrustformersError> {
404        let old_vocab = model.get_vocab_size()?;
405        let added_parameters = model.expand_vocabulary(target_vocab)?;
406
407        Ok(GrowthInfo {
408            added_parameters,
409            initialization_method: "Vocabulary expansion".to_string(),
410            growth_type: format!("Vocab size: {} -> {}", old_vocab, target_vocab),
411        })
412    }
413
414    /// Multi-dimensional growth
415    fn grow_multi_dimensional(
416        &mut self,
417        model: &mut dyn ProgressiveModel,
418        _target_size: usize,
419    ) -> Result<GrowthInfo, TrustformersError> {
420        // Implement coordinated growth across multiple dimensions
421        let mut total_added_parameters = 0;
422
423        // Grow layers first
424        if self.current_size < self.config.final_size / 2 {
425            let layer_growth = self.grow_layers(model, 1)?;
426            total_added_parameters += layer_growth.added_parameters;
427        }
428
429        // Then grow width
430        let current_hidden = model.get_hidden_dimension()?;
431        if current_hidden < 1024 {
432            // Example threshold
433            let width_growth = self.grow_hidden_dimension(model, current_hidden + 64)?;
434            total_added_parameters += width_growth.added_parameters;
435        }
436
437        Ok(GrowthInfo {
438            added_parameters: total_added_parameters,
439            initialization_method: "Multi-dimensional growth".to_string(),
440            growth_type: "Combined layer and width growth".to_string(),
441        })
442    }
443
444    /// Determine adaptive growth size based on learning progress
445    fn determine_adaptive_growth_size(&self, _epoch: usize) -> usize {
446        // Adaptive logic based on learning plateau detection
447        if self.learning_progress.is_plateau() {
448            (self.current_size as f64 * 1.2) as usize // 20% increase
449        } else {
450            self.current_size + 1 // Conservative growth
451        }
452    }
453
454    /// Freeze old parameters during warmup
455    fn freeze_old_parameters(
456        &mut self,
457        model: &mut dyn ProgressiveModel,
458    ) -> Result<(), TrustformersError> {
459        let old_param_names = model.get_parameter_names()?;
460        for name in old_param_names {
461            self.frozen_parameters.insert(name);
462        }
463        model.freeze_parameters(&self.frozen_parameters)?;
464        Ok(())
465    }
466
467    /// Unfreeze parameters after warmup
468    fn unfreeze_parameters(
469        &mut self,
470        model: &mut dyn ProgressiveModel,
471    ) -> Result<(), TrustformersError> {
472        model.unfreeze_parameters(&self.frozen_parameters)?;
473        self.frozen_parameters.clear();
474        Ok(())
475    }
476
477    /// Update training state
478    pub fn step(
479        &mut self,
480        model: &mut dyn ProgressiveModel,
481        loss: f64,
482    ) -> Result<(), TrustformersError> {
483        self.current_step += 1;
484
485        // Update learning progress
486        self.learning_progress.update(loss);
487
488        // Handle warmup
489        if self.warmup_remaining > 0 {
490            self.warmup_remaining -= 1;
491            if self.warmup_remaining == 0 && !self.frozen_parameters.is_empty() {
492                self.unfreeze_parameters(model)?;
493            }
494        }
495
496        Ok(())
497    }
498
499    /// Set current epoch
500    pub fn set_epoch(&mut self, epoch: usize) {
501        self.current_epoch = epoch;
502        self.learning_progress.new_epoch();
503    }
504
505    /// Get current model size
506    pub fn current_size(&self) -> usize {
507        self.current_size
508    }
509
510    /// Get growth history
511    pub fn growth_history(&self) -> &[GrowthEvent] {
512        &self.growth_history
513    }
514
515    /// Get warmup status
516    pub fn is_in_warmup(&self) -> bool {
517        self.warmup_remaining > 0
518    }
519
520    /// Get learning progress
521    pub fn learning_progress(&self) -> &LearningProgress {
522        &self.learning_progress
523    }
524
525    /// Update growth schedule (for adaptive training)
526    pub fn update_growth_schedule(&mut self, new_points: HashMap<usize, usize>) {
527        self.growth_schedule.growth_points.extend(new_points);
528    }
529}
530
/// Information about a growth operation
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GrowthInfo {
    /// Number of parameters added
    pub added_parameters: usize,
    /// Method used to initialize new parameters
    pub initialization_method: String,
    /// Type of growth performed (human-readable summary)
    pub growth_type: String,
}
/// Result of a growth operation
#[derive(Debug)]
pub enum GrowthResult {
    /// Model was grown successfully
    Grown {
        /// Size before growth
        old_size: usize,
        /// Size after growth
        new_size: usize,
        /// Details of the growth operation
        growth_info: GrowthInfo,
    },
    /// No growth was needed (target size not above the current size)
    NoGrowthNeeded,
}
/// Record of a growth event
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GrowthEvent {
    /// Epoch when growth occurred
    pub epoch: usize,
    /// Size before growth
    pub old_size: usize,
    /// Size after growth
    pub new_size: usize,
    /// Dimension that was grown
    pub growth_dimension: GrowthDimension,
    /// Time taken for the growth operation
    pub growth_time: std::time::Duration,
    /// Additional growth information
    pub growth_info: GrowthInfo,
}
/// Tracks learning progress for adaptive growth decisions
#[derive(Debug)]
pub struct LearningProgress {
    loss_history: Vec<f64>,
    recent_losses: std::collections::VecDeque<f64>,
    plateau_threshold: f64,
    plateau_patience: usize,
    #[allow(dead_code)]
    improvement_threshold: f64,
    current_epoch: usize,
}

impl Default for LearningProgress {
    fn default() -> Self {
        Self::new()
    }
}

impl LearningProgress {
    /// Build a tracker with a 10-sample recency window, a plateau threshold
    /// of 0.001 and a patience of 5 samples.
    pub fn new() -> Self {
        Self {
            loss_history: Vec::new(),
            recent_losses: std::collections::VecDeque::with_capacity(10),
            plateau_threshold: 0.001,
            plateau_patience: 5,
            improvement_threshold: 0.01,
            current_epoch: 0,
        }
    }

    /// Record one loss sample; the recency window keeps only the last 10.
    pub fn update(&mut self, loss: f64) {
        self.loss_history.push(loss);
        if self.recent_losses.len() == 10 {
            self.recent_losses.pop_front();
        }
        self.recent_losses.push_back(loss);
    }

    /// True when the average of the recent window has improved by less than
    /// `plateau_threshold` relative to the 10 losses preceding it.
    pub fn is_plateau(&self) -> bool {
        if self.recent_losses.len() < self.plateau_patience {
            return false;
        }

        // The "older" comparison window is history[len-20 .. len-10];
        // saturating bounds make it safely empty early in training.
        let n = self.loss_history.len();
        let older = &self.loss_history[n.saturating_sub(20)..n.saturating_sub(10)];
        if older.is_empty() {
            return false;
        }

        let recent_avg =
            self.recent_losses.iter().copied().sum::<f64>() / self.recent_losses.len() as f64;
        let older_avg = older.iter().copied().sum::<f64>() / older.len() as f64;

        older_avg - recent_avg < self.plateau_threshold
    }

    /// Adaptive trigger: plateaued and past the initial 100-step burn-in.
    pub fn should_trigger_growth(&self, _epoch: usize) -> bool {
        self.loss_history.len() > 100 && self.is_plateau()
    }

    /// Advance the epoch counter.
    pub fn new_epoch(&mut self) {
        self.current_epoch += 1;
    }
}
/// Trait that models must implement to support progressive training
pub trait ProgressiveModel {
    /// Add a new layer at `layer_index`; returns the number of parameters added
    fn add_layer(&mut self, layer_index: usize) -> Result<usize, TrustformersError>;

    /// Expand hidden dimension to target size; returns the number of parameters added
    fn expand_hidden_dimension(&mut self, target_dim: usize) -> Result<usize, TrustformersError>;

    /// Expand number of attention heads; returns the number of parameters added
    fn expand_attention_heads(&mut self, target_heads: usize) -> Result<usize, TrustformersError>;

    /// Expand intermediate (FFN) dimension; returns the number of parameters added
    fn expand_intermediate_dimension(
        &mut self,
        target_dim: usize,
    ) -> Result<usize, TrustformersError>;

    /// Expand vocabulary size; returns the number of parameters added
    fn expand_vocabulary(&mut self, target_vocab: usize) -> Result<usize, TrustformersError>;

    /// Get current hidden dimension
    fn get_hidden_dimension(&self) -> Result<usize, TrustformersError>;

    /// Get current number of attention heads
    fn get_num_attention_heads(&self) -> Result<usize, TrustformersError>;

    /// Get current intermediate dimension
    fn get_intermediate_dimension(&self) -> Result<usize, TrustformersError>;

    /// Get current vocabulary size
    fn get_vocab_size(&self) -> Result<usize, TrustformersError>;

    /// Initialize the layer at `layer_index` with zeros
    fn zero_initialize_layer(&mut self, layer_index: usize) -> Result<(), TrustformersError>;

    /// Scale the parameters of the layer at `layer_index` by `scale`
    fn scale_layer_parameters(
        &mut self,
        layer_index: usize,
        scale: f64,
    ) -> Result<(), TrustformersError>;

    /// Initialize dimensions added between `old_dim` and `new_dim` with
    /// gradual scaling controlled by `smoothing`
    fn initialize_expanded_dimensions(
        &mut self,
        old_dim: usize,
        new_dim: usize,
        smoothing: f64,
    ) -> Result<(), TrustformersError>;

    /// Get names of all parameters
    fn get_parameter_names(&self) -> Result<Vec<String>, TrustformersError>;

    /// Freeze the named parameters
    fn freeze_parameters(&mut self, param_names: &HashSet<String>)
        -> Result<(), TrustformersError>;

    /// Unfreeze the named parameters
    fn unfreeze_parameters(
        &mut self,
        param_names: &HashSet<String>,
    ) -> Result<(), TrustformersError>;
}
/// Progressive training utilities
pub mod utils {

    /// Create a linear growth schedule: the epochs at which growth fires,
    /// evenly spaced from `start_epoch` at `epoch_interval` apart.
    ///
    /// `initial_size`/`final_size` are accepted for interface compatibility;
    /// target sizes are computed by the trainer's schedule, not here.
    pub fn create_linear_schedule(
        initial_size: usize,
        final_size: usize,
        num_steps: usize,
        start_epoch: usize,
        epoch_interval: usize,
    ) -> Vec<usize> {
        // An unused per-step size computation was removed here as dead code.
        let _ = (initial_size, final_size);
        (0..num_steps).map(|i| start_epoch + i * epoch_interval).collect()
    }

    /// Create an exponential growth schedule: the gap after `start_epoch`
    /// widens by 1.5x each step.
    pub fn create_exponential_schedule(
        _initial_size: usize,
        _final_size: usize,
        num_steps: usize,
        start_epoch: usize,
        epoch_interval: usize,
    ) -> Vec<usize> {
        (0..num_steps)
            .map(|i| start_epoch + (epoch_interval as f64 * (1.5_f64.powi(i as i32))) as usize)
            .collect()
    }

    /// Estimate parameter count for a standard transformer.
    ///
    /// `_num_heads` does not affect the total: the combined Q/K/V/O
    /// projection size is `4 * hidden_dim^2` regardless of how it is
    /// partitioned into heads.
    pub fn estimate_parameter_count(
        vocab_size: usize,
        hidden_dim: usize,
        num_layers: usize,
        _num_heads: usize,
        intermediate_dim: usize,
    ) -> usize {
        // Token embedding table.
        let embedding_params = vocab_size * hidden_dim;

        // Per-layer parameters.
        let attention_params = 4 * hidden_dim * hidden_dim; // Q, K, V, O projections
        let ffn_params = 2 * hidden_dim * intermediate_dim; // Up and down projections
        let norm_params = 2 * hidden_dim; // Layer norm
        let layer_params = attention_params + ffn_params + norm_params;

        // Total, plus the final layer norm.
        embedding_params + num_layers * layer_params + hidden_dim
    }

    /// Calculate an optimal growth schedule for a target epoch budget.
    ///
    /// Heuristic: grow more aggressively early, when each step is cheaper.
    /// Uses `saturating_sub` so `final_size < initial_size` yields an empty
    /// schedule instead of panicking on integer underflow.
    pub fn calculate_optimal_schedule(
        initial_size: usize,
        final_size: usize,
        total_epochs: usize,
        _computational_budget: f64,
    ) -> Vec<usize> {
        // sqrt of the total growth picks the number of growth steps.
        let num_growth_steps =
            (final_size.saturating_sub(initial_size) as f64).sqrt() as usize;

        (0..num_growth_steps)
            .map(|i| {
                let progress = i as f64 / num_growth_steps as f64;
                (total_epochs as f64 * progress.sqrt()) as usize
            })
            .collect()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_progressive_config_default() {
        let config = ProgressiveConfig::default();
        assert_eq!(config.initial_size, 6);
        assert_eq!(config.final_size, 12);
        assert!(config.zero_init_new_params);
    }

    #[test]
    fn test_growth_schedule_creation() {
        let config = ProgressiveConfig {
            growth_strategy: GrowthStrategy::Linear,
            initial_size: 4,
            final_size: 12,
            growth_epochs: vec![10, 20, 30, 40],
            ..Default::default()
        };

        let schedule =
            ProgressiveTrainer::create_growth_schedule(&config).expect("operation failed");
        assert!(!schedule.growth_points.is_empty());
        // One growth point per configured epoch.
        assert_eq!(schedule.growth_points.len(), 4);
    }

    #[test]
    fn test_progressive_trainer_creation() {
        let config = ProgressiveConfig::default();
        let trainer = ProgressiveTrainer::new(config);
        assert!(trainer.is_ok());

        let trainer = trainer.expect("operation failed");
        assert_eq!(trainer.current_size(), 6);
        assert!(!trainer.is_in_warmup());
    }

    #[test]
    fn test_learning_progress() {
        let mut progress = LearningProgress::new();

        // Add some losses
        for i in 0..20 {
            progress.update(1.0 - i as f64 * 0.01); // Decreasing loss
        }

        assert!(!progress.is_plateau());

        // Add plateau losses - need enough constant losses so both
        // "older" (indices len-20 to len-10) and "recent" (last 10)
        // are all constant, making improvement near 0
        for _ in 0..25 {
            progress.update(0.8); // Constant loss
        }

        assert!(progress.is_plateau());
    }

    #[test]
    fn test_growth_dimensions() {
        // Discriminants follow declaration order.
        assert_eq!(GrowthDimension::Layers as u8, 0);
        assert_ne!(GrowthDimension::Layers, GrowthDimension::HiddenDim);
    }

    #[test]
    fn test_growth_strategies() {
        assert_eq!(GrowthStrategy::Linear as u8, 0);
        assert_ne!(GrowthStrategy::Linear, GrowthStrategy::Exponential);
    }

    #[test]
    fn test_utils_parameter_estimation() {
        let params = utils::estimate_parameter_count(30000, 768, 12, 12, 3072);
        assert!(params > 100_000_000); // Should be > 100M for BERT-base size
    }

    #[test]
    fn test_utils_linear_schedule() {
        let schedule = utils::create_linear_schedule(6, 12, 3, 10, 5);
        assert_eq!(schedule.len(), 3);
        assert_eq!(schedule[0], 10);
        assert_eq!(schedule[1], 15);
        assert_eq!(schedule[2], 20);
    }
}