// Source: trustformers_optim/advanced_features.rs
1//! Advanced Optimizer Features
2//!
3//! This module implements advanced optimization techniques including optimizer fusion,
4//! multi-optimizer training, warm-up strategies, and checkpointing optimizations.
5
6use crate::LRScheduler;
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::sync::{Arc, Mutex};
10use trustformers_core::errors::{Result, TrustformersError};
11use trustformers_core::traits::Optimizer;
12use trustformers_core::Tensor;
13
/// Configuration for optimizer fusion.
///
/// Controls which artifacts (parameters, gradients, optimizer state) are
/// fused and the resource limits fusion must respect.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FusionConfig {
    /// Whether to enable parameter fusion
    pub fuse_parameters: bool,
    /// Whether to enable gradient fusion
    pub fuse_gradients: bool,
    /// Whether to enable state fusion
    pub fuse_state: bool,
    /// Fusion window size
    pub window_size: usize,
    /// Memory threshold for fusion (in bytes)
    pub memory_threshold: usize,
}
28
29impl Default for FusionConfig {
30    fn default() -> Self {
31        Self {
32            fuse_parameters: true,
33            fuse_gradients: true,
34            fuse_state: true,
35            window_size: 32,
36            memory_threshold: 1024 * 1024 * 100, // 100MB
37        }
38    }
39}
40
/// Fused optimizer that combines multiple optimizers for efficiency.
///
/// Fusable optimizers are grouped up front; each group's parameters can be
/// concatenated into a single tensor so the group is updated with one call.
pub struct FusedOptimizer {
    /// The wrapped optimizers, addressed by index in `fusion_groups`.
    optimizers: Vec<Box<dyn Optimizer>>,
    /// Fusion behavior knobs (what to fuse, window size, memory threshold).
    config: FusionConfig,
    /// Registry of concatenated parameter tensors, keyed `"fused_group_{i}"`.
    fused_parameters: Arc<Mutex<HashMap<String, Tensor>>>,
    /// Registry of gradients supplied via `register_gradient`.
    fused_gradients: Arc<Mutex<HashMap<String, Tensor>>>,
    fusion_groups: Vec<Vec<usize>>, // Groups of optimizer indices that can be fused
}
49
50impl FusedOptimizer {
51    /// Create a new fused optimizer
52    pub fn new(optimizers: Vec<Box<dyn Optimizer>>, config: FusionConfig) -> Result<Self> {
53        let fusion_groups = Self::compute_fusion_groups(&optimizers, &config);
54
55        Ok(Self {
56            optimizers,
57            config,
58            fused_parameters: Arc::new(Mutex::new(HashMap::new())),
59            fused_gradients: Arc::new(Mutex::new(HashMap::new())),
60            fusion_groups,
61        })
62    }
63
64    /// Compute fusion groups based on optimizer compatibility
65    fn compute_fusion_groups(
66        optimizers: &[Box<dyn Optimizer>],
67        config: &FusionConfig,
68    ) -> Vec<Vec<usize>> {
69        let mut groups = Vec::new();
70        let mut used = vec![false; optimizers.len()];
71
72        for i in 0..optimizers.len() {
73            if used[i] {
74                continue;
75            }
76
77            let mut group = vec![i];
78            used[i] = true;
79
80            // Find compatible optimizers to fuse with
81            for j in (i + 1)..optimizers.len() {
82                if used[j] {
83                    continue;
84                }
85
86                if Self::can_fuse(&optimizers[i], &optimizers[j], config) {
87                    group.push(j);
88                    used[j] = true;
89                }
90            }
91
92            groups.push(group);
93        }
94
95        groups
96    }
97
98    /// Check if two optimizers can be fused
99    fn can_fuse(
100        _opt1: &Box<dyn Optimizer>,
101        _opt2: &Box<dyn Optimizer>,
102        _config: &FusionConfig,
103    ) -> bool {
104        // Simple heuristic: for now, assume all optimizers can be fused
105        // In a real implementation, we would check optimizer types and configurations
106        true
107    }
108
109    /// Fuse parameters across optimizer groups
110    fn fuse_parameters(&self, parameters: &mut HashMap<String, Tensor>) -> Result<()> {
111        if !self.config.fuse_parameters {
112            return Ok(());
113        }
114
115        let mut fused_params = self.fused_parameters.lock().unwrap();
116        fused_params.clear();
117
118        // Group parameters by fusion groups
119        for group in &self.fusion_groups {
120            if group.len() > 1 {
121                // Create fused parameter tensor for this group
122                let group_params: Vec<_> = parameters
123                    .iter()
124                    .filter(|(name, _)| {
125                        // Check if parameter belongs to this group (simplified)
126                        group.iter().any(|&i| name.contains(&format!("opt_{}", i)))
127                    })
128                    .collect();
129
130                if !group_params.is_empty() {
131                    // Concatenate parameters
132                    let fused_name = format!("fused_group_{}", group[0]);
133                    let fused_tensor = self.concatenate_tensors(
134                        &group_params.iter().map(|(_, t)| *t).collect::<Vec<_>>(),
135                    )?;
136                    fused_params.insert(fused_name, fused_tensor);
137                }
138            }
139        }
140
141        Ok(())
142    }
143
144    /// Concatenate tensors for fusion
145    fn concatenate_tensors(&self, tensors: &[&Tensor]) -> Result<Tensor> {
146        if tensors.is_empty() {
147            return Err(TrustformersError::invalid_argument(
148                "Empty tensor list".to_string(),
149            ));
150        }
151
152        // Flatten all tensors and concatenate
153        let mut total_size = 0;
154        for tensor in tensors {
155            total_size += tensor.len();
156        }
157
158        // Create concatenated tensor (simplified implementation)
159        Tensor::zeros(&[total_size])
160    }
161
162    /// Perform fused optimization step
163    pub fn fused_step(&mut self, parameters: &mut HashMap<String, Tensor>) -> Result<()> {
164        // Fuse parameters
165        self.fuse_parameters(parameters)?;
166
167        // Apply optimization steps to fused groups
168        let fusion_groups = self.fusion_groups.clone();
169        for group in &fusion_groups {
170            if group.len() > 1 {
171                // Apply fused optimization
172                self.apply_fused_group_optimization(group)?;
173            } else {
174                // Apply single optimizer
175                let optimizer_idx = group[0];
176                // Apply optimization for single optimizer (simplified)
177                for (name, param) in parameters.iter_mut() {
178                    if let Some(grad) = self.get_gradient_for_param(name) {
179                        self.optimizers[optimizer_idx].update(param, &grad)?;
180                    }
181                }
182            }
183        }
184
185        Ok(())
186    }
187
188    /// Apply optimization to a fused group
189    fn apply_fused_group_optimization(&mut self, group: &[usize]) -> Result<()> {
190        // Use the first optimizer in the group as the representative
191        let primary_optimizer_idx = group[0];
192
193        let mut fused_params = self.fused_parameters.lock().unwrap();
194        let fused_gradients = self.fused_gradients.lock().unwrap();
195
196        let group_name = format!("fused_group_{}", primary_optimizer_idx);
197
198        if let (Some(param), Some(grad)) = (
199            fused_params.get_mut(&group_name),
200            fused_gradients.get(&group_name),
201        ) {
202            self.optimizers[primary_optimizer_idx].update(param, grad)?;
203        }
204
205        Ok(())
206    }
207
208    /// Get gradient for parameter
209    ///
210    /// Integrates with the automatic differentiation system to retrieve
211    /// accumulated gradients for the specified parameter.
212    fn get_gradient_for_param(&self, param_name: &str) -> Option<Tensor> {
213        // Check fused gradients first (higher priority)
214        {
215            let fused_gradients = self.fused_gradients.lock().ok()?;
216            if let Some(gradient) = fused_gradients.get(param_name) {
217                return Some(gradient.clone());
218            }
219        }
220
221        // Check individual optimizer gradients by parameter name
222        // Parameter names are typically formatted as "optimizer_{idx}_{param_name}"
223        for (idx, _optimizer) in self.optimizers.iter().enumerate() {
224            let full_param_name = format!("optimizer_{}_{}", idx, param_name);
225
226            // Try to get gradient from fused gradients with full name
227            let fused_gradients = self.fused_gradients.lock().ok()?;
228            if let Some(gradient) = fused_gradients.get(&full_param_name) {
229                return Some(gradient.clone());
230            }
231            drop(fused_gradients);
232
233            // For individual parameters, we would need access to the parameter
234            // registry maintained by the automatic differentiation system.
235            // This is typically maintained at the model level rather than optimizer level.
236        }
237
238        // Return None if gradient not found in any registry
239        None
240    }
241
242    /// Register gradient for parameter in the fused gradient registry
243    ///
244    /// This method allows external automatic differentiation systems to register
245    /// computed gradients with the fused optimizer for parameter updates.
246    pub fn register_gradient(&self, param_name: &str, gradient: Tensor) -> Result<()> {
247        let mut fused_gradients = self.fused_gradients.lock().map_err(|_| {
248            TrustformersError::tensor_op_error(
249                "Failed to lock fused gradients",
250                "register_gradient",
251            )
252        })?;
253
254        fused_gradients.insert(param_name.to_string(), gradient);
255        Ok(())
256    }
257
258    /// Clear all registered gradients
259    ///
260    /// This should be called after each optimization step to clear accumulated gradients.
261    pub fn clear_gradients(&self) -> Result<()> {
262        let mut fused_gradients = self.fused_gradients.lock().map_err(|_| {
263            TrustformersError::tensor_op_error("Failed to lock fused gradients", "clear_gradients")
264        })?;
265
266        fused_gradients.clear();
267        Ok(())
268    }
269
270    /// Get all available gradient parameter names
271    ///
272    /// Returns a list of parameter names for which gradients are currently available.
273    pub fn get_available_gradient_names(&self) -> Result<Vec<String>> {
274        let fused_gradients = self.fused_gradients.lock().map_err(|_| {
275            TrustformersError::tensor_op_error(
276                "Failed to lock fused gradients",
277                "get_available_gradient_names",
278            )
279        })?;
280
281        Ok(fused_gradients.keys().cloned().collect())
282    }
283
284    /// Get fusion statistics
285    pub fn get_fusion_stats(&self) -> FusionStats {
286        let total_optimizers = self.optimizers.len();
287        let fused_groups = self.fusion_groups.iter().filter(|group| group.len() > 1).count();
288        let unfused_optimizers = self.fusion_groups.iter().filter(|group| group.len() == 1).count();
289
290        FusionStats {
291            total_optimizers,
292            fused_groups,
293            unfused_optimizers,
294            fusion_ratio: fused_groups as f64 / total_optimizers as f64,
295            memory_saved: self.estimate_memory_savings(),
296        }
297    }
298
299    /// Estimate memory savings from fusion
300    fn estimate_memory_savings(&self) -> usize {
301        let fused_params = self.fused_parameters.lock().unwrap();
302        let total_fused_size: usize = fused_params.values()
303            .map(|t| t.len() * 4) // Assuming f32 tensors
304            .sum();
305
306        // Estimate original size
307        let estimated_original_size = total_fused_size * 2; // Conservative estimate
308
309        estimated_original_size.saturating_sub(total_fused_size)
310    }
311}
312
/// Statistics for optimizer fusion
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FusionStats {
    /// Total number of wrapped optimizers.
    pub total_optimizers: usize,
    /// Number of fusion groups containing more than one optimizer.
    pub fused_groups: usize,
    /// Number of singleton (unfused) groups.
    pub unfused_optimizers: usize,
    /// `fused_groups / total_optimizers`.
    pub fusion_ratio: f64,
    /// Estimated bytes saved by fusion (heuristic).
    pub memory_saved: usize,
}
322
/// Multi-optimizer training system.
///
/// Routes each parameter to one of several named optimizers, optionally
/// driven by per-optimizer LR schedulers and ensemble weights.
pub struct MultiOptimizerTrainer {
    /// Registered optimizers, keyed by name.
    optimizers: HashMap<String, Box<dyn Optimizer>>,
    parameter_assignments: HashMap<String, String>, // param_name -> optimizer_name
    /// LR schedulers, keyed by the optimizer name they drive.
    schedulers: HashMap<String, Box<dyn LRScheduler>>,
    weights: HashMap<String, f64>, // optimizer weights for ensemble
}
330
331impl Default for MultiOptimizerTrainer {
332    fn default() -> Self {
333        Self::new()
334    }
335}
336
337impl MultiOptimizerTrainer {
338    /// Create a new multi-optimizer trainer
339    pub fn new() -> Self {
340        Self {
341            optimizers: HashMap::new(),
342            parameter_assignments: HashMap::new(),
343            schedulers: HashMap::new(),
344            weights: HashMap::new(),
345        }
346    }
347
348    /// Add an optimizer with a name
349    pub fn add_optimizer(
350        &mut self,
351        name: String,
352        optimizer: Box<dyn Optimizer>,
353        weight: f64,
354    ) -> Result<()> {
355        self.optimizers.insert(name.clone(), optimizer);
356        self.weights.insert(name, weight);
357        Ok(())
358    }
359
360    /// Add a scheduler for an optimizer
361    pub fn add_scheduler(
362        &mut self,
363        optimizer_name: String,
364        scheduler: Box<dyn LRScheduler>,
365    ) -> Result<()> {
366        if !self.optimizers.contains_key(&optimizer_name) {
367            return Err(TrustformersError::invalid_argument(format!(
368                "Optimizer {} not found",
369                optimizer_name
370            )));
371        }
372
373        self.schedulers.insert(optimizer_name, scheduler);
374        Ok(())
375    }
376
377    /// Assign parameters to optimizers
378    pub fn assign_parameters(&mut self, assignments: HashMap<String, String>) -> Result<()> {
379        // Validate that all assigned optimizers exist
380        for optimizer_name in assignments.values() {
381            if !self.optimizers.contains_key(optimizer_name) {
382                return Err(TrustformersError::invalid_argument(format!(
383                    "Optimizer {} not found",
384                    optimizer_name
385                )));
386            }
387        }
388
389        self.parameter_assignments = assignments;
390        Ok(())
391    }
392
393    /// Perform multi-optimizer training step
394    pub fn step(
395        &mut self,
396        parameters: &HashMap<String, Tensor>,
397        gradients: &HashMap<String, Tensor>,
398    ) -> Result<()> {
399        // Group parameters by optimizer
400        let mut optimizer_params: HashMap<String, Vec<(String, Tensor, Tensor)>> = HashMap::new();
401
402        for (param_name, param) in parameters {
403            if let Some(grad) = gradients.get(param_name) {
404                let optimizer_name = self
405                    .parameter_assignments
406                    .get(param_name)
407                    .cloned()
408                    .unwrap_or_else(|| "default".to_string());
409
410                optimizer_params.entry(optimizer_name).or_default().push((
411                    param_name.clone(),
412                    param.clone(),
413                    grad.clone(),
414                ));
415            }
416        }
417
418        // Apply optimizers
419        for (optimizer_name, param_grad_pairs) in optimizer_params {
420            if let Some(optimizer) = self.optimizers.get_mut(&optimizer_name) {
421                let weight = self.weights.get(&optimizer_name).copied().unwrap_or(1.0);
422
423                for (_, param, grad) in param_grad_pairs {
424                    // Scale gradient by optimizer weight
425                    let scaled_grad = grad.mul_scalar(weight as f32)?;
426                    optimizer.update(&mut param.clone(), &scaled_grad)?;
427                }
428            }
429        }
430
431        Ok(())
432    }
433
434    /// Update learning rates using schedulers
435    pub fn step_schedulers(&mut self, epoch: usize) -> Result<()> {
436        for (optimizer_name, scheduler) in &mut self.schedulers {
437            let new_lr = scheduler.get_lr(epoch);
438
439            if let Some(optimizer) = self.optimizers.get_mut(optimizer_name) {
440                optimizer.set_lr(new_lr);
441            }
442        }
443
444        Ok(())
445    }
446
447    /// Get training statistics
448    pub fn get_stats(&self) -> MultiOptimizerStats {
449        MultiOptimizerStats {
450            num_optimizers: self.optimizers.len(),
451            num_schedulers: self.schedulers.len(),
452            num_assigned_params: self.parameter_assignments.len(),
453            optimizer_weights: self.weights.clone(),
454        }
455    }
456}
457
/// Statistics for multi-optimizer training
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MultiOptimizerStats {
    /// Number of registered optimizers.
    pub num_optimizers: usize,
    /// Number of attached LR schedulers.
    pub num_schedulers: usize,
    /// Number of parameters with an explicit optimizer assignment.
    pub num_assigned_params: usize,
    /// Ensemble weight per optimizer name.
    pub optimizer_weights: HashMap<String, f64>,
}
466
/// Optimizer warm-up strategies.
///
/// Each variant carries the number of warmup `steps` over which the learning
/// rate ramps from the base rate to the target rate.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum WarmupStrategy {
    /// Linear warmup
    Linear { steps: usize },
    /// Exponential warmup with the given `base` exponent
    Exponential { steps: usize, base: f64 },
    /// Cosine warmup
    Cosine { steps: usize },
    /// Custom warmup function (currently treated as a linear ramp)
    Custom { steps: usize },
}
479
/// Warmup optimizer wrapper.
///
/// Wraps another optimizer and ramps the learning rate from `base_lr`
/// towards `target_lr` over the strategy's warmup window.
pub struct WarmupOptimizer {
    /// The wrapped optimizer performing the actual parameter updates.
    inner: Box<dyn Optimizer>,
    /// Warmup schedule shape and duration.
    strategy: WarmupStrategy,
    /// Number of `step()` calls observed so far.
    current_step: usize,
    /// Learning rate at the start of warmup.
    base_lr: f64,
    /// Learning rate reached once warmup completes.
    target_lr: f64,
}
488
489impl WarmupOptimizer {
490    /// Create a new warmup optimizer
491    pub fn new(
492        optimizer: Box<dyn Optimizer>,
493        strategy: WarmupStrategy,
494        base_lr: f64,
495        target_lr: f64,
496    ) -> Self {
497        Self {
498            inner: optimizer,
499            strategy,
500            current_step: 0,
501            base_lr,
502            target_lr,
503        }
504    }
505
506    /// Get current learning rate based on warmup strategy
507    fn get_warmup_lr(&self) -> f64 {
508        let warmup_steps = match &self.strategy {
509            WarmupStrategy::Linear { steps } => *steps,
510            WarmupStrategy::Exponential { steps, .. } => *steps,
511            WarmupStrategy::Cosine { steps } => *steps,
512            WarmupStrategy::Custom { steps } => *steps,
513        };
514
515        if self.current_step >= warmup_steps {
516            return self.target_lr;
517        }
518
519        let progress = self.current_step as f64 / warmup_steps as f64;
520
521        match &self.strategy {
522            WarmupStrategy::Linear { .. } => {
523                self.base_lr + (self.target_lr - self.base_lr) * progress
524            },
525            WarmupStrategy::Exponential { base, .. } => {
526                self.base_lr + (self.target_lr - self.base_lr) * base.powf(1.0 - progress)
527            },
528            WarmupStrategy::Cosine { .. } => {
529                let cosine_progress = 0.5 * (1.0 - (std::f64::consts::PI * progress).cos());
530                self.base_lr + (self.target_lr - self.base_lr) * cosine_progress
531            },
532            WarmupStrategy::Custom { .. } => {
533                // Custom implementation would go here
534                self.base_lr + (self.target_lr - self.base_lr) * progress
535            },
536        }
537    }
538
539    /// Check if warmup is complete
540    pub fn is_warmup_complete(&self) -> bool {
541        let warmup_steps = match &self.strategy {
542            WarmupStrategy::Linear { steps } => *steps,
543            WarmupStrategy::Exponential { steps, .. } => *steps,
544            WarmupStrategy::Cosine { steps } => *steps,
545            WarmupStrategy::Custom { steps } => *steps,
546        };
547
548        self.current_step >= warmup_steps
549    }
550}
551
552impl Optimizer for WarmupOptimizer {
553    fn update(&mut self, parameter: &mut Tensor, grad: &Tensor) -> Result<()> {
554        // Update learning rate based on warmup strategy
555        let current_lr = self.get_warmup_lr();
556        self.inner.set_lr(current_lr as f32);
557
558        // Perform the actual optimization step
559        self.inner.update(parameter, grad)
560    }
561
562    fn zero_grad(&mut self) {
563        self.inner.zero_grad()
564    }
565
566    fn step(&mut self) {
567        self.inner.step();
568        self.current_step += 1;
569    }
570
571    fn get_lr(&self) -> f32 {
572        self.get_warmup_lr() as f32
573    }
574
575    fn set_lr(&mut self, lr: f32) {
576        self.target_lr = lr as f64;
577        self.inner.set_lr(lr);
578    }
579}
580
/// Checkpointing optimization configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CheckpointConfig {
    /// Save interval (in steps)
    pub save_interval: usize,
    /// Whether to compress checkpoints
    pub compress: bool,
    /// Maximum number of checkpoints to keep
    pub max_checkpoints: usize,
    /// Whether to save only state diffs
    pub incremental: bool,
}
593
594impl Default for CheckpointConfig {
595    fn default() -> Self {
596        Self {
597            save_interval: 1000,
598            compress: true,
599            max_checkpoints: 5,
600            incremental: false,
601        }
602    }
603}
604
/// Memory-bandwidth co-optimization.
///
/// Wraps an optimizer and adapts a batch size based on observed memory and
/// bandwidth pressure (see `adjust_batch_size`).
pub struct MemoryBandwidthOptimizer {
    /// The wrapped optimizer performing the actual updates.
    inner: Box<dyn Optimizer>,
    /// Memory budget in bytes used to compute memory pressure.
    memory_threshold: usize,
    /// Bandwidth budget used to compute bandwidth pressure.
    bandwidth_threshold: f64,
    /// Whether `adjust_batch_size` is allowed to change the batch size.
    adaptive_batch_size: bool,
    /// Batch size currently in effect.
    current_batch_size: usize,
    /// Starting batch size; growth is capped relative to this.
    base_batch_size: usize,
}
614
615impl MemoryBandwidthOptimizer {
616    /// Create a new memory-bandwidth co-optimizer
617    pub fn new(
618        optimizer: Box<dyn Optimizer>,
619        memory_threshold: usize,
620        bandwidth_threshold: f64,
621        base_batch_size: usize,
622    ) -> Self {
623        Self {
624            inner: optimizer,
625            memory_threshold,
626            bandwidth_threshold,
627            adaptive_batch_size: true,
628            current_batch_size: base_batch_size,
629            base_batch_size,
630        }
631    }
632
633    /// Adjust batch size based on memory and bandwidth usage
634    pub fn adjust_batch_size(&mut self, memory_usage: usize, bandwidth_usage: f64) -> usize {
635        if !self.adaptive_batch_size {
636            return self.current_batch_size;
637        }
638
639        let memory_pressure = memory_usage as f64 / self.memory_threshold as f64;
640        let bandwidth_pressure = bandwidth_usage / self.bandwidth_threshold;
641
642        let pressure = memory_pressure.max(bandwidth_pressure);
643
644        if pressure > 1.1 {
645            // High pressure - reduce batch size
646            self.current_batch_size = (self.current_batch_size as f64 * 0.9) as usize;
647            self.current_batch_size = self.current_batch_size.max(1);
648        } else if pressure < 0.8 {
649            // Low pressure - increase batch size
650            self.current_batch_size = (self.current_batch_size as f64 * 1.1) as usize;
651            self.current_batch_size = self.current_batch_size.min(self.base_batch_size * 4);
652        }
653
654        self.current_batch_size
655    }
656
657    /// Get current resource utilization
658    pub fn get_utilization(&self) -> ResourceUtilization {
659        ResourceUtilization {
660            current_batch_size: self.current_batch_size,
661            base_batch_size: self.base_batch_size,
662            memory_threshold: self.memory_threshold,
663            bandwidth_threshold: self.bandwidth_threshold,
664            adaptive_enabled: self.adaptive_batch_size,
665        }
666    }
667}
668
669impl Optimizer for MemoryBandwidthOptimizer {
670    fn update(&mut self, parameter: &mut Tensor, grad: &Tensor) -> Result<()> {
671        self.inner.update(parameter, grad)
672    }
673
674    fn zero_grad(&mut self) {
675        self.inner.zero_grad()
676    }
677
678    fn step(&mut self) {
679        self.inner.step()
680    }
681
682    fn get_lr(&self) -> f32 {
683        self.inner.get_lr()
684    }
685
686    fn set_lr(&mut self, lr: f32) {
687        self.inner.set_lr(lr)
688    }
689}
690
/// Resource utilization statistics.
///
/// Snapshot of a `MemoryBandwidthOptimizer`'s configuration and state.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ResourceUtilization {
    /// Batch size currently in effect.
    pub current_batch_size: usize,
    /// Configured starting batch size.
    pub base_batch_size: usize,
    /// Memory budget in bytes.
    pub memory_threshold: usize,
    /// Bandwidth budget.
    pub bandwidth_threshold: f64,
    /// Whether adaptive batch sizing is enabled.
    pub adaptive_enabled: bool,
}
700
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Adam;

    #[test]
    fn test_fusion_config_default() {
        let cfg = FusionConfig::default();
        assert!(cfg.fuse_parameters);
        assert!(cfg.fuse_gradients);
        assert!(cfg.fuse_state);
        assert_eq!(cfg.window_size, 32);
    }

    #[test]
    fn test_warmup_strategy_linear() {
        let inner = Adam::new(0.001, (0.9, 0.999), 1e-8, 0.0);
        let warmup = WarmupOptimizer::new(
            Box::new(inner),
            WarmupStrategy::Linear { steps: 100 },
            0.0,
            0.001,
        );

        // Before any step: warmup still running, LR at the base rate.
        assert!(!warmup.is_warmup_complete());
        assert_eq!(warmup.get_warmup_lr(), 0.0);
    }

    #[test]
    fn test_multi_optimizer_trainer_creation() {
        let mut trainer = MultiOptimizerTrainer::new();
        let adam = Adam::new(0.001, (0.9, 0.999), 1e-8, 0.0);
        trainer
            .add_optimizer("adam".to_string(), Box::new(adam), 1.0)
            .unwrap();

        let stats = trainer.get_stats();
        assert_eq!(stats.num_optimizers, 1);
        assert_eq!(stats.optimizer_weights.get("adam"), Some(&1.0));
    }

    #[test]
    fn test_memory_bandwidth_optimizer() {
        let adam = Adam::new(0.001, (0.9, 0.999), 1e-8, 0.0);
        let memory_budget = 1024 * 1024 * 100; // 100MB
        let mut mb_optimizer =
            MemoryBandwidthOptimizer::new(Box::new(adam), memory_budget, 100.0, 32);

        let before = mb_optimizer.get_utilization();
        assert_eq!(before.current_batch_size, 32);
        assert_eq!(before.base_batch_size, 32);

        // 120MB of usage exceeds the 100MB threshold, so the batch shrinks.
        let adjusted = mb_optimizer.adjust_batch_size(1024 * 1024 * 120, 50.0);
        assert!(adjusted < 32);
    }

    #[test]
    fn test_checkpoint_config_default() {
        let cfg = CheckpointConfig::default();
        assert_eq!(cfg.save_interval, 1000);
        assert!(cfg.compress);
        assert_eq!(cfg.max_checkpoints, 5);
        assert!(!cfg.incremental);
    }
}
769}