1use serde::{Deserialize, Serialize};
37use std::collections::HashMap;
38use trustformers_core::errors::TrustformersError;
39
/// Configuration for progressive (grow-as-you-train) model training.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProgressiveConfig {
    /// Which model dimension grows over the course of training.
    pub growth_dimension: GrowthDimension,
    /// How target sizes are distributed across the growth epochs.
    pub growth_strategy: GrowthStrategy,
    /// Size of the growth dimension at the start of training.
    pub initial_size: usize,
    /// Size the growth dimension should reach by the end of training.
    pub final_size: usize,
    /// Epochs at which growth steps are scheduled.
    pub growth_epochs: Vec<usize>,
    /// Number of optimizer steps of warmup after each growth event.
    pub warmup_steps: usize,
    /// Zero-initialize newly added parameters (used when gradual
    /// initialization is disabled).
    pub zero_init_new_params: bool,
    /// Learning-rate scaling factor around growth events.
    /// NOTE(review): not read anywhere in this module — presumably consumed
    /// by the surrounding training loop; confirm.
    pub lr_scaling_factor: f64,
    /// Blend new parameters in gradually (scaled) instead of zero-initializing.
    pub gradual_initialization: bool,
    /// Smoothing factor used when scaling in newly added parameters.
    pub transition_smoothing: f64,
    /// Freeze pre-existing parameters while the warmup window is active.
    pub freeze_old_params_during_warmup: bool,
}
66
67impl Default for ProgressiveConfig {
68 fn default() -> Self {
69 Self {
70 growth_dimension: GrowthDimension::Layers,
71 growth_strategy: GrowthStrategy::Linear,
72 initial_size: 6,
73 final_size: 12,
74 growth_epochs: vec![10, 20, 30, 40],
75 warmup_steps: 1000,
76 zero_init_new_params: true,
77 lr_scaling_factor: 0.5,
78 gradual_initialization: true,
79 transition_smoothing: 0.1,
80 freeze_old_params_during_warmup: false,
81 }
82 }
83}
84
/// The model dimension that grows during progressive training.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum GrowthDimension {
    /// Add transformer layers (depth).
    Layers,
    /// Widen the hidden dimension.
    HiddenDim,
    /// Add attention heads.
    AttentionHeads,
    /// Widen the feed-forward intermediate dimension.
    IntermediateDim,
    /// Expand the vocabulary size.
    VocabSize,
    /// Grow several dimensions together (depth plus hidden width).
    MultiDimensional,
}
101
/// How target sizes are distributed across the scheduled growth epochs.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum GrowthStrategy {
    /// Equal-sized increments at each growth epoch.
    Linear,
    /// Quadratic ramp: small steps early, large steps late.
    Exponential,
    /// Logarithmic ramp: large steps early, small steps late.
    Logarithmic,
    /// Linear baseline, plus plateau-triggered growth at runtime.
    Adaptive,
    /// Schedule supplied externally (no points are precomputed).
    Custom,
    /// Equal-sized stages at each growth epoch.
    Staged,
}
118
/// Precomputed growth plan mapping epochs to target sizes.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GrowthSchedule {
    /// Map from epoch to the target size to reach at that epoch.
    pub growth_points: HashMap<usize, usize>,
    /// Whether plateau-based growth may also fire between scheduled points.
    pub adaptive: bool,
    /// Minimum number of epochs between growth events.
    /// NOTE(review): set but never enforced in this module — confirm intent.
    pub min_growth_interval: usize,
    /// Upper bound on how much the size may grow in one event.
    /// NOTE(review): set but never enforced in this module — confirm intent.
    pub max_growth_per_step: usize,
}
131
/// Orchestrates progressive model growth during training: decides when to
/// grow, applies the growth to a `ProgressiveModel`, and tracks history,
/// warmup, and frozen parameters.
pub struct ProgressiveTrainer {
    /// Static configuration for growth behavior.
    config: ProgressiveConfig,
    /// Current size of the growth dimension.
    current_size: usize,
    /// Most recently reported epoch (via `set_epoch`).
    current_epoch: usize,
    /// Number of optimizer steps observed so far.
    current_step: usize,
    /// Precomputed epoch -> target-size plan.
    growth_schedule: GrowthSchedule,
    /// Log of completed growth events.
    growth_history: Vec<GrowthEvent>,
    /// Steps left in the post-growth warmup window (0 = not in warmup).
    warmup_remaining: usize,
    /// Names of parameters frozen for the current warmup window.
    frozen_parameters: HashSet<String>,
    /// Loss-based plateau detector driving adaptive growth.
    learning_progress: LearningProgress,
}
144
145use std::collections::HashSet;
146
147impl ProgressiveTrainer {
148 pub fn new(config: ProgressiveConfig) -> Result<Self, TrustformersError> {
150 let growth_schedule = Self::create_growth_schedule(&config)?;
151
152 Ok(Self {
153 current_size: config.initial_size,
154 current_epoch: 0,
155 current_step: 0,
156 growth_schedule,
157 growth_history: Vec::new(),
158 warmup_remaining: 0,
159 frozen_parameters: HashSet::new(),
160 learning_progress: LearningProgress::new(),
161 config,
162 })
163 }
164
165 fn create_growth_schedule(
167 config: &ProgressiveConfig,
168 ) -> Result<GrowthSchedule, TrustformersError> {
169 let mut growth_points = HashMap::new();
170
171 match config.growth_strategy {
172 GrowthStrategy::Linear => {
173 let total_growth = config.final_size - config.initial_size;
174 let num_steps = config.growth_epochs.len();
175 let growth_per_step = total_growth / num_steps.max(1);
176
177 for (i, &epoch) in config.growth_epochs.iter().enumerate() {
178 let new_size = config.initial_size + (i + 1) * growth_per_step;
179 growth_points.insert(epoch, new_size.min(config.final_size));
180 }
181 },
182 GrowthStrategy::Exponential => {
183 for (i, &epoch) in config.growth_epochs.iter().enumerate() {
184 let progress = (i + 1) as f64 / config.growth_epochs.len() as f64;
185 let exp_progress = progress.powf(2.0);
186 let new_size = config.initial_size
187 + ((config.final_size - config.initial_size) as f64 * exp_progress)
188 as usize;
189 growth_points.insert(epoch, new_size.min(config.final_size));
190 }
191 },
192 GrowthStrategy::Logarithmic => {
193 for (i, &epoch) in config.growth_epochs.iter().enumerate() {
194 let progress = (i + 1) as f64 / config.growth_epochs.len() as f64;
195 let log_progress = (1.0 + progress).ln() / (2.0_f64).ln();
196 let new_size = config.initial_size
197 + ((config.final_size - config.initial_size) as f64 * log_progress)
198 as usize;
199 growth_points.insert(epoch, new_size.min(config.final_size));
200 }
201 },
202 GrowthStrategy::Adaptive => {
203 for (i, &epoch) in config.growth_epochs.iter().enumerate() {
205 let progress = (i + 1) as f64 / config.growth_epochs.len() as f64;
206 let new_size = config.initial_size
207 + ((config.final_size - config.initial_size) as f64 * progress) as usize;
208 growth_points.insert(epoch, new_size.min(config.final_size));
209 }
210 },
211 GrowthStrategy::Staged => {
212 let stage_size =
213 (config.final_size - config.initial_size) / config.growth_epochs.len().max(1);
214 for (i, &epoch) in config.growth_epochs.iter().enumerate() {
215 let new_size = config.initial_size + (i + 1) * stage_size;
216 growth_points.insert(epoch, new_size.min(config.final_size));
217 }
218 },
219 GrowthStrategy::Custom => {
220 },
222 }
223
224 Ok(GrowthSchedule {
225 growth_points,
226 adaptive: matches!(config.growth_strategy, GrowthStrategy::Adaptive),
227 min_growth_interval: 5,
228 max_growth_per_step: (config.final_size - config.initial_size) / 2,
229 })
230 }
231
232 pub fn should_grow(&self, epoch: usize) -> bool {
234 if self.warmup_remaining > 0 {
235 return false;
236 }
237
238 if let Some(&target_size) = self.growth_schedule.growth_points.get(&epoch) {
239 return target_size > self.current_size;
240 }
241
242 if self.growth_schedule.adaptive {
244 return self.learning_progress.should_trigger_growth(epoch);
245 }
246
247 false
248 }
249
250 pub fn grow_model(
252 &mut self,
253 model: &mut dyn ProgressiveModel,
254 epoch: usize,
255 ) -> Result<GrowthResult, TrustformersError> {
256 let target_size = self
257 .growth_schedule
258 .growth_points
259 .get(&epoch)
260 .copied()
261 .unwrap_or_else(|| self.determine_adaptive_growth_size(epoch));
262
263 if target_size <= self.current_size {
264 return Ok(GrowthResult::NoGrowthNeeded);
265 }
266
267 let growth_amount = target_size - self.current_size;
268 let start_time = std::time::Instant::now();
269
270 let growth_info = match self.config.growth_dimension {
272 GrowthDimension::Layers => self.grow_layers(model, growth_amount)?,
273 GrowthDimension::HiddenDim => self.grow_hidden_dimension(model, target_size)?,
274 GrowthDimension::AttentionHeads => self.grow_attention_heads(model, target_size)?,
275 GrowthDimension::IntermediateDim => {
276 self.grow_intermediate_dimension(model, target_size)?
277 },
278 GrowthDimension::VocabSize => self.grow_vocabulary(model, target_size)?,
279 GrowthDimension::MultiDimensional => self.grow_multi_dimensional(model, target_size)?,
280 };
281
282 let growth_event = GrowthEvent {
284 epoch,
285 old_size: self.current_size,
286 new_size: target_size,
287 growth_dimension: self.config.growth_dimension,
288 growth_time: start_time.elapsed(),
289 growth_info: growth_info.clone(),
290 };
291
292 self.growth_history.push(growth_event);
293 self.current_size = target_size;
294 self.warmup_remaining = self.config.warmup_steps;
295
296 if self.config.freeze_old_params_during_warmup {
298 self.freeze_old_parameters(model)?;
299 }
300
301 Ok(GrowthResult::Grown {
302 old_size: self.current_size,
303 new_size: target_size,
304 growth_info,
305 })
306 }
307
308 fn grow_layers(
310 &mut self,
311 model: &mut dyn ProgressiveModel,
312 num_layers: usize,
313 ) -> Result<GrowthInfo, TrustformersError> {
314 let mut added_parameters = 0;
315 let mut initialization_method = String::new();
316
317 for i in 0..num_layers {
318 let layer_params = model.add_layer(self.current_size + i)?;
319 added_parameters += layer_params;
320
321 if self.config.gradual_initialization {
322 let scale = self.config.transition_smoothing * (i + 1) as f64 / num_layers as f64;
324 model.scale_layer_parameters(self.current_size + i, scale)?;
325 initialization_method = format!("Gradual scaling (factor: {})", scale);
326 } else if self.config.zero_init_new_params {
327 model.zero_initialize_layer(self.current_size + i)?;
328 initialization_method = "Zero initialization".to_string();
329 }
330 }
331
332 Ok(GrowthInfo {
333 added_parameters,
334 initialization_method,
335 growth_type: "Layer addition".to_string(),
336 })
337 }
338
339 fn grow_hidden_dimension(
341 &mut self,
342 model: &mut dyn ProgressiveModel,
343 target_dim: usize,
344 ) -> Result<GrowthInfo, TrustformersError> {
345 let old_dim = model.get_hidden_dimension()?;
346 let _growth = target_dim - old_dim;
347
348 let added_parameters = model.expand_hidden_dimension(target_dim)?;
349
350 if self.config.gradual_initialization {
352 model.initialize_expanded_dimensions(
353 old_dim,
354 target_dim,
355 self.config.transition_smoothing,
356 )?;
357 }
358
359 Ok(GrowthInfo {
360 added_parameters,
361 initialization_method: "Hidden dimension expansion".to_string(),
362 growth_type: format!("Hidden dim: {} -> {}", old_dim, target_dim),
363 })
364 }
365
366 fn grow_attention_heads(
368 &mut self,
369 model: &mut dyn ProgressiveModel,
370 target_heads: usize,
371 ) -> Result<GrowthInfo, TrustformersError> {
372 let old_heads = model.get_num_attention_heads()?;
373 let added_parameters = model.expand_attention_heads(target_heads)?;
374
375 Ok(GrowthInfo {
376 added_parameters,
377 initialization_method: "Attention head expansion".to_string(),
378 growth_type: format!("Attention heads: {} -> {}", old_heads, target_heads),
379 })
380 }
381
382 fn grow_intermediate_dimension(
384 &mut self,
385 model: &mut dyn ProgressiveModel,
386 target_dim: usize,
387 ) -> Result<GrowthInfo, TrustformersError> {
388 let old_dim = model.get_intermediate_dimension()?;
389 let added_parameters = model.expand_intermediate_dimension(target_dim)?;
390
391 Ok(GrowthInfo {
392 added_parameters,
393 initialization_method: "Intermediate dimension expansion".to_string(),
394 growth_type: format!("Intermediate dim: {} -> {}", old_dim, target_dim),
395 })
396 }
397
398 fn grow_vocabulary(
400 &mut self,
401 model: &mut dyn ProgressiveModel,
402 target_vocab: usize,
403 ) -> Result<GrowthInfo, TrustformersError> {
404 let old_vocab = model.get_vocab_size()?;
405 let added_parameters = model.expand_vocabulary(target_vocab)?;
406
407 Ok(GrowthInfo {
408 added_parameters,
409 initialization_method: "Vocabulary expansion".to_string(),
410 growth_type: format!("Vocab size: {} -> {}", old_vocab, target_vocab),
411 })
412 }
413
414 fn grow_multi_dimensional(
416 &mut self,
417 model: &mut dyn ProgressiveModel,
418 _target_size: usize,
419 ) -> Result<GrowthInfo, TrustformersError> {
420 let mut total_added_parameters = 0;
422
423 if self.current_size < self.config.final_size / 2 {
425 let layer_growth = self.grow_layers(model, 1)?;
426 total_added_parameters += layer_growth.added_parameters;
427 }
428
429 let current_hidden = model.get_hidden_dimension()?;
431 if current_hidden < 1024 {
432 let width_growth = self.grow_hidden_dimension(model, current_hidden + 64)?;
434 total_added_parameters += width_growth.added_parameters;
435 }
436
437 Ok(GrowthInfo {
438 added_parameters: total_added_parameters,
439 initialization_method: "Multi-dimensional growth".to_string(),
440 growth_type: "Combined layer and width growth".to_string(),
441 })
442 }
443
444 fn determine_adaptive_growth_size(&self, _epoch: usize) -> usize {
446 if self.learning_progress.is_plateau() {
448 (self.current_size as f64 * 1.2) as usize } else {
450 self.current_size + 1 }
452 }
453
454 fn freeze_old_parameters(
456 &mut self,
457 model: &mut dyn ProgressiveModel,
458 ) -> Result<(), TrustformersError> {
459 let old_param_names = model.get_parameter_names()?;
460 for name in old_param_names {
461 self.frozen_parameters.insert(name);
462 }
463 model.freeze_parameters(&self.frozen_parameters)?;
464 Ok(())
465 }
466
467 fn unfreeze_parameters(
469 &mut self,
470 model: &mut dyn ProgressiveModel,
471 ) -> Result<(), TrustformersError> {
472 model.unfreeze_parameters(&self.frozen_parameters)?;
473 self.frozen_parameters.clear();
474 Ok(())
475 }
476
477 pub fn step(
479 &mut self,
480 model: &mut dyn ProgressiveModel,
481 loss: f64,
482 ) -> Result<(), TrustformersError> {
483 self.current_step += 1;
484
485 self.learning_progress.update(loss);
487
488 if self.warmup_remaining > 0 {
490 self.warmup_remaining -= 1;
491 if self.warmup_remaining == 0 && !self.frozen_parameters.is_empty() {
492 self.unfreeze_parameters(model)?;
493 }
494 }
495
496 Ok(())
497 }
498
499 pub fn set_epoch(&mut self, epoch: usize) {
501 self.current_epoch = epoch;
502 self.learning_progress.new_epoch();
503 }
504
505 pub fn current_size(&self) -> usize {
507 self.current_size
508 }
509
510 pub fn growth_history(&self) -> &[GrowthEvent] {
512 &self.growth_history
513 }
514
515 pub fn is_in_warmup(&self) -> bool {
517 self.warmup_remaining > 0
518 }
519
520 pub fn learning_progress(&self) -> &LearningProgress {
522 &self.learning_progress
523 }
524
525 pub fn update_growth_schedule(&mut self, new_points: HashMap<usize, usize>) {
527 self.growth_schedule.growth_points.extend(new_points);
528 }
529}
530
/// Details of a single growth operation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GrowthInfo {
    /// Number of parameters added by the operation.
    pub added_parameters: usize,
    /// Human-readable description of how new parameters were initialized.
    pub initialization_method: String,
    /// Human-readable description of what grew (e.g. "Layer addition").
    pub growth_type: String,
}
541
/// Outcome of a `ProgressiveTrainer::grow_model` call.
#[derive(Debug)]
pub enum GrowthResult {
    /// The model grew from `old_size` to `new_size`.
    Grown {
        old_size: usize,
        new_size: usize,
        growth_info: GrowthInfo,
    },
    /// The scheduled/adaptive target did not exceed the current size.
    NoGrowthNeeded,
}
554
/// Record of one completed growth event, kept in the trainer's history.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GrowthEvent {
    /// Epoch at which the growth occurred.
    pub epoch: usize,
    /// Size before growth.
    pub old_size: usize,
    /// Size after growth.
    pub new_size: usize,
    /// Dimension that grew.
    pub growth_dimension: GrowthDimension,
    /// Wall-clock duration of the growth operation.
    pub growth_time: std::time::Duration,
    /// Details of the growth operation.
    pub growth_info: GrowthInfo,
}
571
/// Tracks training loss to detect plateaus that can trigger adaptive growth.
#[derive(Debug)]
pub struct LearningProgress {
    /// Full loss history since construction.
    loss_history: Vec<f64>,
    /// Sliding window of the most recent losses (capped at `RECENT_WINDOW`).
    recent_losses: std::collections::VecDeque<f64>,
    /// Minimum average improvement below which training counts as plateaued.
    plateau_threshold: f64,
    /// Minimum number of recent samples before a plateau can be declared.
    plateau_patience: usize,
    #[allow(dead_code)]
    improvement_threshold: f64,
    /// Epoch counter advanced by `new_epoch`.
    current_epoch: usize,
}

impl Default for LearningProgress {
    fn default() -> Self {
        LearningProgress::new()
    }
}

impl LearningProgress {
    /// Capacity of the recent-loss sliding window.
    const RECENT_WINDOW: usize = 10;

    /// Creates a tracker with default plateau-detection thresholds.
    pub fn new() -> Self {
        Self {
            loss_history: Vec::new(),
            recent_losses: std::collections::VecDeque::with_capacity(Self::RECENT_WINDOW),
            plateau_threshold: 0.001,
            plateau_patience: 5,
            improvement_threshold: 0.01,
            current_epoch: 0,
        }
    }

    /// Records one loss sample in both the full history and the window.
    pub fn update(&mut self, loss: f64) {
        self.loss_history.push(loss);
        self.recent_losses.push_back(loss);
        while self.recent_losses.len() > Self::RECENT_WINDOW {
            self.recent_losses.pop_front();
        }
    }

    /// Returns true when the average recent loss has stopped improving
    /// relative to the immediately preceding slice of history.
    pub fn is_plateau(&self) -> bool {
        let window = self.recent_losses.len();
        if window < self.plateau_patience {
            return false;
        }

        let total = self.loss_history.len();
        let older = &self.loss_history[total.saturating_sub(2 * Self::RECENT_WINDOW)
            ..total.saturating_sub(Self::RECENT_WINDOW)];
        if older.is_empty() {
            // Not enough history for a before/after comparison yet.
            return false;
        }

        let recent_avg = self.recent_losses.iter().sum::<f64>() / window as f64;
        let older_avg = older.iter().sum::<f64>() / older.len() as f64;

        older_avg - recent_avg < self.plateau_threshold
    }

    /// Adaptive growth fires only on a plateau, and only after enough
    /// history (more than 100 samples) has accumulated.
    pub fn should_trigger_growth(&self, _epoch: usize) -> bool {
        self.loss_history.len() > 100 && self.is_plateau()
    }

    /// Advances the internal epoch counter.
    pub fn new_epoch(&mut self) {
        self.current_epoch += 1;
    }
}
637
/// Hooks a model must implement to support progressive growth.
///
/// Each `add_*`/`expand_*` method returns the number of parameters added.
pub trait ProgressiveModel {
    /// Adds a new layer at `layer_index`; returns the parameter count added.
    fn add_layer(&mut self, layer_index: usize) -> Result<usize, TrustformersError>;

    /// Widens the hidden dimension to `target_dim`; returns parameters added.
    fn expand_hidden_dimension(&mut self, target_dim: usize) -> Result<usize, TrustformersError>;

    /// Grows the attention-head count to `target_heads`; returns parameters added.
    fn expand_attention_heads(&mut self, target_heads: usize) -> Result<usize, TrustformersError>;

    /// Widens the feed-forward intermediate dimension to `target_dim`;
    /// returns parameters added.
    fn expand_intermediate_dimension(
        &mut self,
        target_dim: usize,
    ) -> Result<usize, TrustformersError>;

    /// Expands the vocabulary to `target_vocab`; returns parameters added.
    fn expand_vocabulary(&mut self, target_vocab: usize) -> Result<usize, TrustformersError>;

    /// Current hidden dimension.
    fn get_hidden_dimension(&self) -> Result<usize, TrustformersError>;

    /// Current number of attention heads.
    fn get_num_attention_heads(&self) -> Result<usize, TrustformersError>;

    /// Current feed-forward intermediate dimension.
    fn get_intermediate_dimension(&self) -> Result<usize, TrustformersError>;

    /// Current vocabulary size.
    fn get_vocab_size(&self) -> Result<usize, TrustformersError>;

    /// Sets all parameters of the layer at `layer_index` to zero.
    fn zero_initialize_layer(&mut self, layer_index: usize) -> Result<(), TrustformersError>;

    /// Multiplies the parameters of the layer at `layer_index` by `scale`
    /// (used for gradual blending of new layers).
    fn scale_layer_parameters(
        &mut self,
        layer_index: usize,
        scale: f64,
    ) -> Result<(), TrustformersError>;

    /// Initializes the newly added slice (`old_dim..new_dim`) after a hidden
    /// dimension expansion, using `smoothing` as the blend factor.
    fn initialize_expanded_dimensions(
        &mut self,
        old_dim: usize,
        new_dim: usize,
        smoothing: f64,
    ) -> Result<(), TrustformersError>;

    /// Names of all parameters currently in the model.
    fn get_parameter_names(&self) -> Result<Vec<String>, TrustformersError>;

    /// Excludes the named parameters from gradient updates.
    fn freeze_parameters(&mut self, param_names: &HashSet<String>)
        -> Result<(), TrustformersError>;

    /// Re-enables gradient updates for the named parameters.
    fn unfreeze_parameters(
        &mut self,
        param_names: &HashSet<String>,
    ) -> Result<(), TrustformersError>;
}
701
/// Helper routines for building growth schedules and estimating model sizes.
pub mod utils {
    /// Builds a linearly spaced list of growth epochs.
    ///
    /// Returns `num_steps` epochs starting at `start_epoch`, spaced
    /// `epoch_interval` apart. The size arguments are accepted for signature
    /// symmetry with the other schedule builders; epoch placement does not
    /// depend on them.
    pub fn create_linear_schedule(
        initial_size: usize,
        final_size: usize,
        num_steps: usize,
        start_epoch: usize,
        epoch_interval: usize,
    ) -> Vec<usize> {
        // The sizes are intentionally unused for epoch placement. (A
        // previous unused per-step computation here could panic on underflow
        // when `final_size < initial_size`.)
        let _ = (initial_size, final_size);
        (0..num_steps).map(|i| start_epoch + i * epoch_interval).collect()
    }

    /// Builds an exponentially spaced list of growth epochs.
    ///
    /// Epoch `i` lands at `start_epoch + epoch_interval * 1.5^i`
    /// (truncated), so growth events get progressively farther apart.
    pub fn create_exponential_schedule(
        _initial_size: usize,
        _final_size: usize,
        num_steps: usize,
        start_epoch: usize,
        epoch_interval: usize,
    ) -> Vec<usize> {
        (0..num_steps)
            .map(|i| start_epoch + (epoch_interval as f64 * (1.5_f64.powi(i as i32))) as usize)
            .collect()
    }

    /// Rough transformer parameter count for the given dimensions.
    ///
    /// Per layer: Q/K/V/output projections (`4 * hidden^2`), a two-matrix
    /// feed-forward block (`2 * hidden * intermediate`), and two norms
    /// (`2 * hidden`); plus token embeddings and a trailing `hidden_dim`
    /// term. Head count does not affect the total here (heads partition the
    /// same projection matrices) and is accepted only for completeness.
    pub fn estimate_parameter_count(
        vocab_size: usize,
        hidden_dim: usize,
        num_layers: usize,
        _num_heads: usize,
        intermediate_dim: usize,
    ) -> usize {
        let embedding_params = vocab_size * hidden_dim;

        let attention_params = 4 * hidden_dim * hidden_dim;
        let ffn_params = 2 * hidden_dim * intermediate_dim;
        let norm_params = 2 * hidden_dim;
        let layer_params = attention_params + ffn_params + norm_params;

        // Trailing `hidden_dim` approximates one final per-feature term.
        embedding_params + num_layers * layer_params + hidden_dim
    }

    /// Derives a growth-epoch schedule spread over `total_epochs`.
    ///
    /// Uses `sqrt(final - initial)` growth steps (saturating, so an inverted
    /// size range yields an empty schedule instead of a panic), placed at
    /// square-root-spaced epochs so growth happens more often early in
    /// training. The computational budget is currently unused.
    pub fn calculate_optimal_schedule(
        initial_size: usize,
        final_size: usize,
        total_epochs: usize,
        _computational_budget: f64,
    ) -> Vec<usize> {
        let num_growth_steps = (final_size.saturating_sub(initial_size) as f64).sqrt() as usize;

        (0..num_growth_steps)
            .map(|i| {
                let progress = i as f64 / num_growth_steps as f64;
                (total_epochs as f64 * progress.sqrt()) as usize
            })
            .collect()
    }
}
771
#[cfg(test)]
mod tests {
    use super::*;

    // The default config describes the documented 6 -> 12 layer setup.
    #[test]
    fn test_progressive_config_default() {
        let config = ProgressiveConfig::default();
        assert_eq!(config.initial_size, 6);
        assert_eq!(config.final_size, 12);
        assert!(config.zero_init_new_params);
    }

    // A linear strategy yields exactly one growth point per scheduled epoch.
    #[test]
    fn test_growth_schedule_creation() {
        let config = ProgressiveConfig {
            growth_strategy: GrowthStrategy::Linear,
            initial_size: 4,
            final_size: 12,
            growth_epochs: vec![10, 20, 30, 40],
            ..Default::default()
        };

        let schedule =
            ProgressiveTrainer::create_growth_schedule(&config).expect("operation failed");
        assert!(!schedule.growth_points.is_empty());
        assert_eq!(schedule.growth_points.len(), 4);
    }

    // A fresh trainer starts at the configured initial size, outside warmup.
    #[test]
    fn test_progressive_trainer_creation() {
        let config = ProgressiveConfig::default();
        let trainer = ProgressiveTrainer::new(config);
        assert!(trainer.is_ok());

        let trainer = trainer.expect("operation failed");
        assert_eq!(trainer.current_size(), 6);
        assert!(!trainer.is_in_warmup());
    }

    // Steadily improving losses are not a plateau; flat losses are.
    #[test]
    fn test_learning_progress() {
        let mut progress = LearningProgress::new();

        // Decreasing loss: still clearly improving.
        for i in 0..20 {
            progress.update(1.0 - i as f64 * 0.01);
        }

        assert!(!progress.is_plateau());

        // Constant loss long enough to fill both comparison windows.
        for _ in 0..25 {
            progress.update(0.8);
        }

        assert!(progress.is_plateau());
    }

    // Discriminants start at 0 and variants are distinguishable.
    #[test]
    fn test_growth_dimensions() {
        assert_eq!(GrowthDimension::Layers as u8, 0);
        assert_ne!(GrowthDimension::Layers, GrowthDimension::HiddenDim);
    }

    #[test]
    fn test_growth_strategies() {
        assert_eq!(GrowthStrategy::Linear as u8, 0);
        assert_ne!(GrowthStrategy::Linear, GrowthStrategy::Exponential);
    }

    // Base-transformer-like dimensions should land above 100M parameters.
    #[test]
    fn test_utils_parameter_estimation() {
        let params = utils::estimate_parameter_count(30000, 768, 12, 12, 3072);
        assert!(params > 100_000_000);
    }

    // Linear schedule: `num_steps` epochs spaced `epoch_interval` apart.
    #[test]
    fn test_utils_linear_schedule() {
        let schedule = utils::create_linear_schedule(6, 12, 3, 10, 5);
        assert_eq!(schedule.len(), 3);
        assert_eq!(schedule[0], 10);
        assert_eq!(schedule[1], 15);
        assert_eq!(schedule[2], 20);
    }
}