use crate::LRScheduler;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use trustformers_core::errors::{Result, TrustformersError};
use trustformers_core::traits::Optimizer;
use trustformers_core::Tensor;

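/// Configuration for fusing compatible optimizers into shared parameter,
/// gradient, and state buffers.
///
/// A minimal sketch of overriding one default (the field values here are
/// illustrative, not tuned recommendations):
///
/// ```ignore
/// let config = FusionConfig { window_size: 64, ..FusionConfig::default() };
/// ```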
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FusionConfig {
    /// Fuse parameter tensors across compatible optimizers.
    pub fuse_parameters: bool,
    /// Fuse gradient tensors across compatible optimizers.
    pub fuse_gradients: bool,
    /// Fuse internal optimizer state (e.g. momentum buffers).
    pub fuse_state: bool,
    /// Maximum number of tensors considered per fusion group (reserved; not
    /// yet consulted by the fusion logic below).
    pub window_size: usize,
    /// Memory budget in bytes for fused buffers (reserved; not yet consulted
    /// by the fusion logic below).
    pub memory_threshold: usize,
}

impl Default for FusionConfig {
    fn default() -> Self {
        Self {
            fuse_parameters: true,
            fuse_gradients: true,
            fuse_state: true,
            window_size: 32,
            memory_threshold: 1024 * 1024 * 100, // 100 MB
        }
    }
}

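/// Fuses multiple optimizers so that groups of compatible parameters and
/// gradients can be updated through shared, concatenated buffers.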
pub struct FusedOptimizer {
    /// The wrapped optimizers, indexed by position.
    optimizers: Vec<Box<dyn Optimizer>>,
    /// Fusion behaviour configuration.
    config: FusionConfig,
    /// Fused parameter buffers, keyed by group name.
    fused_parameters: Arc<Mutex<HashMap<String, Tensor>>>,
    /// Fused gradient buffers, keyed by parameter or group name.
    fused_gradients: Arc<Mutex<HashMap<String, Tensor>>>,
    /// Indices of optimizers grouped together for fusion.
    fusion_groups: Vec<Vec<usize>>,
}

impl FusedOptimizer {
    /// Builds a fused optimizer, computing the fusion groups up front.
    pub fn new(optimizers: Vec<Box<dyn Optimizer>>, config: FusionConfig) -> Result<Self> {
        let fusion_groups = Self::compute_fusion_groups(&optimizers, &config);

        Ok(Self {
            optimizers,
            config,
            fused_parameters: Arc::new(Mutex::new(HashMap::new())),
            fused_gradients: Arc::new(Mutex::new(HashMap::new())),
            fusion_groups,
        })
    }

    /// Greedily partitions optimizer indices into fusable groups: each
    /// not-yet-used optimizer seeds a group and absorbs every later
    /// optimizer that `can_fuse` accepts.
    fn compute_fusion_groups(
        optimizers: &[Box<dyn Optimizer>],
        config: &FusionConfig,
    ) -> Vec<Vec<usize>> {
        let mut groups = Vec::new();
        let mut used = vec![false; optimizers.len()];

        for i in 0..optimizers.len() {
            if used[i] {
                continue;
            }

            let mut group = vec![i];
            used[i] = true;

            for j in (i + 1)..optimizers.len() {
                if used[j] {
                    continue;
                }

                if Self::can_fuse(optimizers[i].as_ref(), optimizers[j].as_ref(), config) {
                    group.push(j);
                    used[j] = true;
                }
            }

            groups.push(group);
        }

        groups
    }

    /// Placeholder compatibility check: currently every pair of optimizers
    /// is treated as fusable. A real implementation would compare optimizer
    /// types, hyperparameters, and state layouts, and consult `config`.
    fn can_fuse(
        _opt1: &dyn Optimizer,
        _opt2: &dyn Optimizer,
        _config: &FusionConfig,
    ) -> bool {
        true
    }

    /// Builds one fused buffer per multi-optimizer group from the parameters
    /// that belong to it. Parameters are matched to a group by the
    /// `opt_{index}` substring naming convention.
    fn fuse_parameters(&self, parameters: &mut HashMap<String, Tensor>) -> Result<()> {
        if !self.config.fuse_parameters {
            return Ok(());
        }

        let mut fused_params = self.fused_parameters.lock().map_err(|_| {
            TrustformersError::tensor_op_error("Failed to lock fused parameters", "fuse_parameters")
        })?;
        fused_params.clear();

        for group in &self.fusion_groups {
            if group.len() > 1 {
                let group_params: Vec<_> = parameters
                    .iter()
                    .filter(|(name, _)| {
                        group.iter().any(|&i| name.contains(&format!("opt_{}", i)))
                    })
                    .collect();

                if !group_params.is_empty() {
                    let fused_name = format!("fused_group_{}", group[0]);
                    let fused_tensor = self.concatenate_tensors(
                        &group_params.iter().map(|(_, t)| *t).collect::<Vec<_>>(),
                    )?;
                    fused_params.insert(fused_name, fused_tensor);
                }
            }
        }

        Ok(())
    }

    /// Allocates a buffer sized to hold every input tensor end to end.
    ///
    /// Note: this is a sizing placeholder. It returns a zeroed tensor of the
    /// combined length rather than copying the source elements.
    fn concatenate_tensors(&self, tensors: &[&Tensor]) -> Result<Tensor> {
        if tensors.is_empty() {
            return Err(TrustformersError::invalid_argument(
                "Empty tensor list".to_string(),
            ));
        }

        // Total number of elements across all inputs.
        let total_size: usize = tensors.iter().map(|t| t.len()).sum();

        Tensor::zeros(&[total_size])
    }

    /// Runs one optimization step, using fused buffers for multi-optimizer
    /// groups and per-parameter updates for singleton groups.
    pub fn fused_step(&mut self, parameters: &mut HashMap<String, Tensor>) -> Result<()> {
        self.fuse_parameters(parameters)?;

        // Clone the group list so `self` can be borrowed mutably below.
        let fusion_groups = self.fusion_groups.clone();
        for group in &fusion_groups {
            if group.len() > 1 {
                self.apply_fused_group_optimization(group)?;
            } else {
                let optimizer_idx = group[0];
                for (name, param) in parameters.iter_mut() {
                    if let Some(grad) = self.get_gradient_for_param(name) {
                        self.optimizers[optimizer_idx].update(param, &grad)?;
                    }
                }
            }
        }

        Ok(())
    }

    /// Applies the group's primary optimizer to the group's fused buffers.
    ///
    /// Note: the update is applied to the fused buffer only; results are not
    /// scattered back to the individual parameter tensors here.
    fn apply_fused_group_optimization(&mut self, group: &[usize]) -> Result<()> {
        let primary_optimizer_idx = group[0];

        let mut fused_params = self.fused_parameters.lock().map_err(|_| {
            TrustformersError::tensor_op_error(
                "Failed to lock fused parameters",
                "apply_fused_group_optimization",
            )
        })?;
        let fused_gradients = self.fused_gradients.lock().map_err(|_| {
            TrustformersError::tensor_op_error(
                "Failed to lock fused gradients",
                "apply_fused_group_optimization",
            )
        })?;

        let group_name = format!("fused_group_{}", primary_optimizer_idx);

        if let (Some(param), Some(grad)) = (
            fused_params.get_mut(&group_name),
            fused_gradients.get(&group_name),
        ) {
            self.optimizers[primary_optimizer_idx].update(param, grad)?;
        }

        Ok(())
    }

    /// Looks up a gradient for `param_name`, first directly and then under
    /// each optimizer's `optimizer_{index}_{name}` prefix.
    fn get_gradient_for_param(&self, param_name: &str) -> Option<Tensor> {
        let fused_gradients = self.fused_gradients.lock().ok()?;

        // Direct hit: the gradient was registered under the parameter name.
        if let Some(gradient) = fused_gradients.get(param_name) {
            return Some(gradient.clone());
        }

        // Fall back to the per-optimizer prefixed names.
        for idx in 0..self.optimizers.len() {
            let full_param_name = format!("optimizer_{}_{}", idx, param_name);
            if let Some(gradient) = fused_gradients.get(&full_param_name) {
                return Some(gradient.clone());
            }
        }

        None
    }

    /// Registers a gradient tensor under `param_name` for later lookup.
    pub fn register_gradient(&self, param_name: &str, gradient: Tensor) -> Result<()> {
        let mut fused_gradients = self.fused_gradients.lock().map_err(|_| {
            TrustformersError::tensor_op_error(
                "Failed to lock fused gradients",
                "register_gradient",
            )
        })?;

        fused_gradients.insert(param_name.to_string(), gradient);
        Ok(())
    }

    /// Clears all registered gradients.
    pub fn clear_gradients(&self) -> Result<()> {
        let mut fused_gradients = self.fused_gradients.lock().map_err(|_| {
            TrustformersError::tensor_op_error("Failed to lock fused gradients", "clear_gradients")
        })?;

        fused_gradients.clear();
        Ok(())
    }

    /// Lists the names under which gradients are currently registered.
    pub fn get_available_gradient_names(&self) -> Result<Vec<String>> {
        let fused_gradients = self.fused_gradients.lock().map_err(|_| {
            TrustformersError::tensor_op_error(
                "Failed to lock fused gradients",
                "get_available_gradient_names",
            )
        })?;

        Ok(fused_gradients.keys().cloned().collect())
    }

    /// Summarizes how much fusion was achieved.
    pub fn get_fusion_stats(&self) -> FusionStats {
        let total_optimizers = self.optimizers.len();
        let fused_groups = self.fusion_groups.iter().filter(|group| group.len() > 1).count();
        let unfused_optimizers = self.fusion_groups.iter().filter(|group| group.len() == 1).count();

        FusionStats {
            total_optimizers,
            fused_groups,
            unfused_optimizers,
            // Guard against division by zero when no optimizers are present.
            fusion_ratio: if total_optimizers > 0 {
                fused_groups as f64 / total_optimizers as f64
            } else {
                0.0
            },
            memory_saved: self.estimate_memory_savings(),
        }
    }

    /// Rough estimate of the memory saved by fusion, assuming 4-byte (f32)
    /// elements and that unfused storage would roughly double the footprint.
    fn estimate_memory_savings(&self) -> usize {
        let total_fused_size: usize = self
            .fused_parameters
            .lock()
            .map(|params| params.values().map(|t| t.len() * 4).sum())
            .unwrap_or(0);

        // Heuristic: unfused layouts are assumed to use about twice the memory.
        let estimated_original_size = total_fused_size * 2;
        estimated_original_size.saturating_sub(total_fused_size)
    }
}

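/// Summary statistics describing the outcome of optimizer fusion.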
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FusionStats {
    /// Total number of wrapped optimizers.
    pub total_optimizers: usize,
    /// Number of groups containing more than one optimizer.
    pub fused_groups: usize,
    /// Number of singleton (unfused) groups.
    pub unfused_optimizers: usize,
    /// Fraction of optimizers that ended up in fused groups.
    pub fusion_ratio: f64,
    /// Estimated bytes saved by fusion.
    pub memory_saved: usize,
}

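/// Coordinates several named optimizers over disjoint parameter groups, with
/// an optional learning-rate scheduler and gradient weight per optimizer.
///
/// A minimal usage sketch (the `Adam` constructor mirrors the one used in
/// this module's tests):
///
/// ```ignore
/// let mut trainer = MultiOptimizerTrainer::new();
/// let adam = Adam::new(0.001, (0.9, 0.999), 1e-8, 0.0);
/// trainer.add_optimizer("adam".to_string(), Box::new(adam), 1.0)?;
/// ```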
pub struct MultiOptimizerTrainer {
    /// Optimizers, keyed by name.
    optimizers: HashMap<String, Box<dyn Optimizer>>,
    /// Maps a parameter name to the name of the optimizer that owns it.
    parameter_assignments: HashMap<String, String>,
    /// Optional learning-rate scheduler per optimizer.
    schedulers: HashMap<String, Box<dyn LRScheduler>>,
    /// Per-optimizer gradient scaling weight.
    weights: HashMap<String, f64>,
}

impl Default for MultiOptimizerTrainer {
    fn default() -> Self {
        Self::new()
    }
}

impl MultiOptimizerTrainer {
    /// Creates an empty trainer with no optimizers, schedulers, or
    /// parameter assignments.
    pub fn new() -> Self {
        Self {
            optimizers: HashMap::new(),
            parameter_assignments: HashMap::new(),
            schedulers: HashMap::new(),
            weights: HashMap::new(),
        }
    }

    /// Registers an optimizer under `name` with a gradient scaling `weight`.
    pub fn add_optimizer(
        &mut self,
        name: String,
        optimizer: Box<dyn Optimizer>,
        weight: f64,
    ) -> Result<()> {
        self.optimizers.insert(name.clone(), optimizer);
        self.weights.insert(name, weight);
        Ok(())
    }

    /// Attaches a learning-rate scheduler to a previously added optimizer.
    pub fn add_scheduler(
        &mut self,
        optimizer_name: String,
        scheduler: Box<dyn LRScheduler>,
    ) -> Result<()> {
        if !self.optimizers.contains_key(&optimizer_name) {
            return Err(TrustformersError::invalid_argument(format!(
                "Optimizer {} not found",
                optimizer_name
            )));
        }

        self.schedulers.insert(optimizer_name, scheduler);
        Ok(())
    }

    /// Assigns parameters to optimizers by name, validating that every
    /// referenced optimizer exists.
    pub fn assign_parameters(&mut self, assignments: HashMap<String, String>) -> Result<()> {
        for optimizer_name in assignments.values() {
            if !self.optimizers.contains_key(optimizer_name) {
                return Err(TrustformersError::invalid_argument(format!(
                    "Optimizer {} not found",
                    optimizer_name
                )));
            }
        }

        self.parameter_assignments = assignments;
        Ok(())
    }

    /// Runs one training step: each parameter with a gradient is routed to
    /// its assigned optimizer (falling back to the optimizer named
    /// "default") and updated in place with the weight-scaled gradient.
    pub fn step(
        &mut self,
        parameters: &mut HashMap<String, Tensor>,
        gradients: &HashMap<String, Tensor>,
    ) -> Result<()> {
        // Group parameter names by the optimizer responsible for them.
        let mut optimizer_params: HashMap<String, Vec<String>> = HashMap::new();

        for param_name in parameters.keys() {
            if gradients.contains_key(param_name) {
                let optimizer_name = self
                    .parameter_assignments
                    .get(param_name)
                    .cloned()
                    .unwrap_or_else(|| "default".to_string());

                optimizer_params
                    .entry(optimizer_name)
                    .or_default()
                    .push(param_name.clone());
            }
        }

        for (optimizer_name, param_names) in optimizer_params {
            if let Some(optimizer) = self.optimizers.get_mut(&optimizer_name) {
                let weight = self.weights.get(&optimizer_name).copied().unwrap_or(1.0);

                for param_name in param_names {
                    if let (Some(param), Some(grad)) =
                        (parameters.get_mut(&param_name), gradients.get(&param_name))
                    {
                        // Scale the gradient by the optimizer's weight and
                        // apply the update to the parameter itself (updating
                        // a clone would discard the result).
                        let scaled_grad = grad.mul_scalar(weight as f32)?;
                        optimizer.update(param, &scaled_grad)?;
                    }
                }
            }
        }

        Ok(())
    }

    /// Advances every scheduler to `epoch` and pushes the new learning rate
    /// into its optimizer.
    pub fn step_schedulers(&mut self, epoch: usize) -> Result<()> {
        for (optimizer_name, scheduler) in &mut self.schedulers {
            let new_lr = scheduler.get_lr(epoch);

            if let Some(optimizer) = self.optimizers.get_mut(optimizer_name) {
                optimizer.set_lr(new_lr);
            }
        }

        Ok(())
    }

    /// Reports counts of registered optimizers, schedulers, and assignments.
    pub fn get_stats(&self) -> MultiOptimizerStats {
        MultiOptimizerStats {
            num_optimizers: self.optimizers.len(),
            num_schedulers: self.schedulers.len(),
            num_assigned_params: self.parameter_assignments.len(),
            optimizer_weights: self.weights.clone(),
        }
    }
}

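/// Summary of a `MultiOptimizerTrainer`'s registered components.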
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MultiOptimizerStats {
    /// Number of registered optimizers.
    pub num_optimizers: usize,
    /// Number of attached schedulers.
    pub num_schedulers: usize,
    /// Number of explicit parameter-to-optimizer assignments.
    pub num_assigned_params: usize,
    /// Gradient scaling weight per optimizer.
    pub optimizer_weights: HashMap<String, f64>,
}

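/// Shape of the learning-rate ramp used during warmup.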
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum WarmupStrategy {
    /// Linear interpolation from base to target over `steps`.
    Linear { steps: usize },
    /// Exponential ramp controlled by `base` over `steps`.
    Exponential { steps: usize, base: f64 },
    /// Half-cosine ramp over `steps`.
    Cosine { steps: usize },
    /// User-defined ramp; currently falls back to linear interpolation.
    Custom { steps: usize },
}

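/// Wraps an optimizer and ramps its learning rate from `base_lr` to
/// `target_lr` according to a `WarmupStrategy`; once warmup completes the
/// target rate is used unchanged.
///
/// A minimal usage sketch (the `Adam` constructor mirrors the one used in
/// this module's tests):
///
/// ```ignore
/// let adam = Adam::new(0.001, (0.9, 0.999), 1e-8, 0.0);
/// let strategy = WarmupStrategy::Linear { steps: 100 };
/// let warmed = WarmupOptimizer::new(Box::new(adam), strategy, 0.0, 0.001);
/// ```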
pub struct WarmupOptimizer {
    /// The wrapped optimizer.
    inner: Box<dyn Optimizer>,
    /// Warmup schedule shape.
    strategy: WarmupStrategy,
    /// Number of warmup steps taken so far.
    current_step: usize,
    /// Learning rate at step zero.
    base_lr: f64,
    /// Learning rate reached when warmup completes.
    target_lr: f64,
}

impl WarmupOptimizer {
    /// Wraps `optimizer` with the given warmup strategy and rate bounds.
    pub fn new(
        optimizer: Box<dyn Optimizer>,
        strategy: WarmupStrategy,
        base_lr: f64,
        target_lr: f64,
    ) -> Self {
        Self {
            inner: optimizer,
            strategy,
            current_step: 0,
            base_lr,
            target_lr,
        }
    }

    /// Computes the learning rate for the current warmup step; once warmup
    /// is complete the target rate is returned unchanged.
    fn get_warmup_lr(&self) -> f64 {
        let warmup_steps = match &self.strategy {
            WarmupStrategy::Linear { steps } => *steps,
            WarmupStrategy::Exponential { steps, .. } => *steps,
            WarmupStrategy::Cosine { steps } => *steps,
            WarmupStrategy::Custom { steps } => *steps,
        };

        if self.current_step >= warmup_steps {
            return self.target_lr;
        }

        let progress = self.current_step as f64 / warmup_steps as f64;

        match &self.strategy {
            WarmupStrategy::Linear { .. } => {
                self.base_lr + (self.target_lr - self.base_lr) * progress
            },
            WarmupStrategy::Exponential { base, .. } => {
                // base^(1 - progress) decays toward 1 as progress approaches
                // 1, so the rate reaches target_lr by the end of warmup.
                self.base_lr + (self.target_lr - self.base_lr) * base.powf(1.0 - progress)
            },
            WarmupStrategy::Cosine { .. } => {
                // Half-cosine ramp: 0 at progress = 0, 1 at progress = 1.
                let cosine_progress = 0.5 * (1.0 - (std::f64::consts::PI * progress).cos());
                self.base_lr + (self.target_lr - self.base_lr) * cosine_progress
            },
            WarmupStrategy::Custom { .. } => {
                // No custom curve is plumbed through yet; fall back to linear.
                self.base_lr + (self.target_lr - self.base_lr) * progress
            },
        }
    }

    /// Returns true once `current_step` has reached the configured number of
    /// warmup steps.
    pub fn is_warmup_complete(&self) -> bool {
        let warmup_steps = match &self.strategy {
            WarmupStrategy::Linear { steps } => *steps,
            WarmupStrategy::Exponential { steps, .. } => *steps,
            WarmupStrategy::Cosine { steps } => *steps,
            WarmupStrategy::Custom { steps } => *steps,
        };

        self.current_step >= warmup_steps
    }
}

impl Optimizer for WarmupOptimizer {
    fn update(&mut self, parameter: &mut Tensor, grad: &Tensor) -> Result<()> {
        // Refresh the inner optimizer's learning rate before each update.
        let current_lr = self.get_warmup_lr();
        self.inner.set_lr(current_lr as f32);

        self.inner.update(parameter, grad)
    }

    fn zero_grad(&mut self) {
        self.inner.zero_grad()
    }

    fn step(&mut self) {
        self.inner.step();
        // Warmup progress is tracked per optimizer step, not per update.
        self.current_step += 1;
    }

    fn get_lr(&self) -> f32 {
        self.get_warmup_lr() as f32
    }

    fn set_lr(&mut self, lr: f32) {
        // Callers adjusting the learning rate move the warmup target.
        self.target_lr = lr as f64;
        self.inner.set_lr(lr);
    }
}

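/// Configuration for periodic optimizer checkpointing.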
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CheckpointConfig {
    /// Save a checkpoint every `save_interval` steps.
    pub save_interval: usize,
    /// Compress checkpoints on disk.
    pub compress: bool,
    /// Maximum number of checkpoints retained before old ones are dropped.
    pub max_checkpoints: usize,
    /// Save only state that changed since the previous checkpoint.
    pub incremental: bool,
}

impl Default for CheckpointConfig {
    fn default() -> Self {
        Self {
            save_interval: 1000,
            compress: true,
            max_checkpoints: 5,
            incremental: false,
        }
    }
}

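/// Wraps an optimizer and adaptively scales the batch size in response to
/// observed memory and bandwidth pressure.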
pub struct MemoryBandwidthOptimizer {
    /// The wrapped optimizer.
    inner: Box<dyn Optimizer>,
    /// Memory usage in bytes treated as full pressure.
    memory_threshold: usize,
    /// Bandwidth usage treated as full pressure.
    bandwidth_threshold: f64,
    /// Whether adaptive batch sizing is enabled.
    adaptive_batch_size: bool,
    /// Batch size currently in effect.
    current_batch_size: usize,
    /// Batch size the optimizer started with.
    base_batch_size: usize,
}

impl MemoryBandwidthOptimizer {
    /// Wraps `optimizer` with the given resource thresholds; adaptive batch
    /// sizing starts enabled at `base_batch_size`.
    pub fn new(
        optimizer: Box<dyn Optimizer>,
        memory_threshold: usize,
        bandwidth_threshold: f64,
        base_batch_size: usize,
    ) -> Self {
        Self {
            inner: optimizer,
            memory_threshold,
            bandwidth_threshold,
            adaptive_batch_size: true,
            current_batch_size: base_batch_size,
            base_batch_size,
        }
    }

    /// Adjusts the batch size based on the dominant resource pressure: above
    /// 110% pressure the batch shrinks by 10% (floored at 1); below 80% it
    /// grows by 10% (capped at 4x the base batch size).
    pub fn adjust_batch_size(&mut self, memory_usage: usize, bandwidth_usage: f64) -> usize {
        if !self.adaptive_batch_size {
            return self.current_batch_size;
        }

        let memory_pressure = memory_usage as f64 / self.memory_threshold as f64;
        let bandwidth_pressure = bandwidth_usage / self.bandwidth_threshold;

        // React to whichever resource is under more pressure.
        let pressure = memory_pressure.max(bandwidth_pressure);

        if pressure > 1.1 {
            self.current_batch_size = (self.current_batch_size as f64 * 0.9) as usize;
            self.current_batch_size = self.current_batch_size.max(1);
        } else if pressure < 0.8 {
            self.current_batch_size = (self.current_batch_size as f64 * 1.1) as usize;
            self.current_batch_size = self.current_batch_size.min(self.base_batch_size * 4);
        }

        self.current_batch_size
    }

    /// Reports the current batch-size and threshold settings.
    pub fn get_utilization(&self) -> ResourceUtilization {
        ResourceUtilization {
            current_batch_size: self.current_batch_size,
            base_batch_size: self.base_batch_size,
            memory_threshold: self.memory_threshold,
            bandwidth_threshold: self.bandwidth_threshold,
            adaptive_enabled: self.adaptive_batch_size,
        }
    }
}

impl Optimizer for MemoryBandwidthOptimizer {
    fn update(&mut self, parameter: &mut Tensor, grad: &Tensor) -> Result<()> {
        self.inner.update(parameter, grad)
    }

    fn zero_grad(&mut self) {
        self.inner.zero_grad()
    }

    fn step(&mut self) {
        self.inner.step()
    }

    fn get_lr(&self) -> f32 {
        self.inner.get_lr()
    }

    fn set_lr(&mut self, lr: f32) {
        self.inner.set_lr(lr)
    }
}

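/// Snapshot of a `MemoryBandwidthOptimizer`'s batch-size and threshold state.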
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ResourceUtilization {
    /// Batch size currently in effect.
    pub current_batch_size: usize,
    /// Batch size configured at construction.
    pub base_batch_size: usize,
    /// Memory threshold in bytes.
    pub memory_threshold: usize,
    /// Bandwidth threshold.
    pub bandwidth_threshold: f64,
    /// Whether adaptive batch sizing is enabled.
    pub adaptive_enabled: bool,
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::Adam;

    #[test]
    fn test_fusion_config_default() {
        let config = FusionConfig::default();
        assert!(config.fuse_parameters);
        assert!(config.fuse_gradients);
        assert!(config.fuse_state);
        assert_eq!(config.window_size, 32);
    }

    #[test]
    fn test_warmup_strategy_linear() {
        let strategy = WarmupStrategy::Linear { steps: 100 };

        let adam = Adam::new(0.001, (0.9, 0.999), 1e-8, 0.0);

        let warmup_optimizer = WarmupOptimizer::new(Box::new(adam), strategy, 0.0, 0.001);

        // At step 0 the linear ramp starts exactly at base_lr.
        assert!(!warmup_optimizer.is_warmup_complete());
        assert_eq!(warmup_optimizer.get_warmup_lr(), 0.0);
    }

    #[test]
    fn test_multi_optimizer_trainer_creation() {
        let mut trainer = MultiOptimizerTrainer::new();

        let adam = Adam::new(0.001, (0.9, 0.999), 1e-8, 0.0);
        trainer.add_optimizer("adam".to_string(), Box::new(adam), 1.0).unwrap();

        let stats = trainer.get_stats();
        assert_eq!(stats.num_optimizers, 1);
        assert_eq!(stats.optimizer_weights.get("adam"), Some(&1.0));
    }

    #[test]
    fn test_memory_bandwidth_optimizer() {
        let adam = Adam::new(0.001, (0.9, 0.999), 1e-8, 0.0);
        let mut mb_optimizer = MemoryBandwidthOptimizer::new(
            Box::new(adam),
            1024 * 1024 * 100, // memory threshold: 100 MB
            100.0,             // bandwidth threshold
            32,                // base batch size
        );

        let utilization = mb_optimizer.get_utilization();
        assert_eq!(utilization.current_batch_size, 32);
        assert_eq!(utilization.base_batch_size, 32);

        // 120 MB of usage exceeds the 100 MB threshold, so the batch shrinks.
        let new_batch_size = mb_optimizer.adjust_batch_size(
            1024 * 1024 * 120, // memory usage: 120 MB
            50.0,              // bandwidth usage: half the threshold
        );
        assert!(new_batch_size < 32);
    }

    #[test]
    fn test_checkpoint_config_default() {
        let config = CheckpointConfig::default();
        assert_eq!(config.save_interval, 1000);
        assert!(config.compress);
        assert_eq!(config.max_checkpoints, 5);
        assert!(!config.incremental);
    }
}