trustformers_optim/zero/zero_optimizer.rs

//! Main ZeRO Optimizer Implementation
//!
//! This module provides the main ZeRO optimizer wrapper that coordinates
//! between different ZeRO stages and manages the optimization process.

use std::collections::HashMap;
use std::sync::Arc;
use trustformers_core::errors::{Result, TrustformersError};
use trustformers_core::parallel::ModelParallelContext;
use trustformers_core::tensor::Tensor;
use trustformers_core::traits::Optimizer;

use super::{
    ZeROImplementationStage, ZeROMemoryStats, ZeROStage1, ZeROStage2, ZeROStage3, ZeROState,
};

/// Configuration for ZeRO optimizer
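///
/// # Example
///
/// A minimal sketch of building a config: start from `Default::default()`
/// and override the fields you care about. The module path is assumed from
/// this file's location, and the values are illustrative rather than tuned
/// recommendations.
///
/// ```no_run
/// use trustformers_optim::zero::{ZeROConfig, ZeROStage};
///
/// let config = ZeROConfig {
///     stage: ZeROStage::Stage2,
///     bucket_size_mb: 50,
///     ..Default::default()
/// };
/// assert_eq!(config.stage, ZeROStage::Stage2);
/// ```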
#[derive(Debug, Clone)]
pub struct ZeROConfig {
    /// ZeRO stage to use
    pub stage: ZeROStage,
    /// Target bucket size for gradient communication (in MB)
    pub bucket_size_mb: usize,
    /// Whether to overlap communication with computation
    pub overlap_comm: bool,
    /// Reduce bucket size (number of elements to reduce at once)
    pub reduce_bucket_size: usize,
    /// Prefetch depth for parameter gathering
    pub prefetch_depth: usize,
    /// Maximum memory usage threshold before releasing parameters
    pub max_memory_usage_mb: usize,
    /// Enable gradient compression
    pub gradient_compression: bool,
    /// Pin memory for faster GPU transfers
    pub pin_memory: bool,
}

impl Default for ZeROConfig {
    fn default() -> Self {
        Self {
            stage: ZeROStage::Stage1,
            bucket_size_mb: 25,
            overlap_comm: true,
            reduce_bucket_size: 500_000_000, // 500M elements
            prefetch_depth: 2,
            max_memory_usage_mb: 1024, // 1GB
            gradient_compression: false,
            pin_memory: true,
        }
    }
}

/// ZeRO optimization stages
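///
/// Each stage strictly extends the partitioning of the one before it. A
/// `ZeROStage` converts into the internal implementation enum via `From`;
/// a minimal sketch (assuming both types are re-exported from the `zero`
/// module):
///
/// ```no_run
/// use trustformers_optim::zero::{ZeROImplementationStage, ZeROStage};
///
/// let _stage: ZeROImplementationStage = ZeROStage::Stage3.into();
/// ```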
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ZeROStage {
    /// Stage 1: Partition optimizer states only
    Stage1,
    /// Stage 2: Partition optimizer states + gradients
    Stage2,
    /// Stage 3: Partition optimizer states + gradients + parameters
    Stage3,
}

impl From<ZeROStage> for ZeROImplementationStage {
    fn from(stage: ZeROStage) -> Self {
        match stage {
            ZeROStage::Stage1 => ZeROImplementationStage::Stage1,
            ZeROStage::Stage2 => ZeROImplementationStage::Stage2,
            ZeROStage::Stage3 => ZeROImplementationStage::Stage3,
        }
    }
}

/// Main ZeRO optimizer that wraps an underlying optimizer
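///
/// # Example
///
/// An end-to-end sketch that mirrors the tests below, assuming the `Adam`
/// optimizer from this crate and a default `ModelParallelConfig` (crate
/// paths are assumptions; error handling uses `?`):
///
/// ```no_run
/// use std::collections::HashMap;
/// use std::sync::Arc;
/// use trustformers_core::parallel::{ModelParallelConfig, ModelParallelContext};
/// use trustformers_core::tensor::Tensor;
/// use trustformers_optim::adam::Adam;
/// use trustformers_optim::zero::{ZeROConfig, ZeROOptimizer};
///
/// fn build() -> trustformers_core::errors::Result<()> {
///     let mp_context = Arc::new(ModelParallelContext::new(ModelParallelConfig::default())?);
///     let adam = Adam::new(0.001, (0.9, 0.999), 1e-8, 0.01);
///     let mut zero = ZeROOptimizer::new(adam, ZeROConfig::default(), mp_context)?;
///
///     let mut params = HashMap::new();
///     params.insert("weight".to_string(), Tensor::ones(&[4, 4])?);
///     zero.register_parameters(params)?;
///     zero.optimizer_step()?;
///     Ok(())
/// }
/// ```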
pub struct ZeROOptimizer<T: Optimizer> {
    /// Underlying base optimizer
    base_optimizer: T,
    /// ZeRO configuration
    config: ZeROConfig,
    /// Model parallel context for communication
    mp_context: Arc<ModelParallelContext>,
    /// ZeRO-specific state
    zero_state: ZeROState,
    /// Stage 1 implementation
    stage1: Option<ZeROStage1<T>>,
    /// Stage 2 implementation
    stage2: Option<ZeROStage2<T>>,
    /// Stage 3 implementation
    stage3: Option<ZeROStage3<T>>,
    /// Memory statistics
    memory_stats: ZeROMemoryStats,
    /// Parameter names for tracking
    parameter_names: Vec<String>,
}

impl<T: Optimizer> ZeROOptimizer<T> {
    /// Create a new ZeRO optimizer
    pub fn new(
        base_optimizer: T,
        config: ZeROConfig,
        mp_context: Arc<ModelParallelContext>,
    ) -> Result<Self> {
        let mut optimizer = Self {
            base_optimizer,
            config: config.clone(),
            mp_context: mp_context.clone(),
            zero_state: ZeROState::new(),
            stage1: None,
            stage2: None,
            stage3: None,
            memory_stats: ZeROMemoryStats::new(),
            parameter_names: Vec::new(),
        };

        // Initialize the appropriate stage
        optimizer.initialize_stage(config.stage)?;

        Ok(optimizer)
    }

    /// Initialize the specified ZeRO stage
    fn initialize_stage(&mut self, stage: ZeROStage) -> Result<()> {
        match stage {
            ZeROStage::Stage1 => {
                self.stage1 = Some(ZeROStage1::new(
                    self.mp_context.clone(),
                    self.config.clone(),
                )?);
            },
            ZeROStage::Stage2 => {
                self.stage2 = Some(ZeROStage2::new(
                    self.mp_context.clone(),
                    self.config.clone(),
                )?);
            },
            ZeROStage::Stage3 => {
                self.stage3 = Some(ZeROStage3::new(
                    self.mp_context.clone(),
                    self.config.clone(),
                )?);
            },
        }
        Ok(())
    }

    /// Register parameters with ZeRO optimizer
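    ///
    /// A small sketch of the expected input, mirroring the tests below
    /// (parameter names are hypothetical):
    ///
    /// ```no_run
    /// # use std::collections::HashMap;
    /// # use trustformers_core::tensor::Tensor;
    /// let mut parameters = HashMap::new();
    /// parameters.insert("weight1".to_string(), Tensor::ones(&[4, 4]).unwrap());
    /// parameters.insert("bias1".to_string(), Tensor::ones(&[4]).unwrap());
    /// // zero_optimizer.register_parameters(parameters)?;
    /// ```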
    pub fn register_parameters(&mut self, parameters: HashMap<String, Tensor>) -> Result<()> {
        self.parameter_names = parameters.keys().cloned().collect();

        match self.config.stage {
            ZeROStage::Stage1 => {
                if let Some(stage1) = &mut self.stage1 {
                    stage1.register_parameters(parameters)?;
                }
            },
            ZeROStage::Stage2 => {
                if let Some(stage2) = &mut self.stage2 {
                    stage2.register_parameters(parameters)?;
                }
            },
            ZeROStage::Stage3 => {
                if let Some(stage3) = &mut self.stage3 {
                    stage3.register_parameters(parameters)?;
                }
            },
        }

        self.update_memory_stats();
        Ok(())
    }

    /// Update gradients for ZeRO optimization
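    ///
    /// For Stage 2/3 this is the batched entry point for gradients coming
    /// out of the backward pass; a sketch of a single step (names
    /// hypothetical):
    ///
    /// ```no_run
    /// # use std::collections::HashMap;
    /// # use trustformers_core::tensor::Tensor;
    /// let mut gradients = HashMap::new();
    /// gradients.insert("weight1".to_string(), Tensor::ones(&[4, 4]).unwrap());
    /// // zero_optimizer.update_gradients(gradients)?;
    /// // zero_optimizer.optimizer_step()?;
    /// ```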
    pub fn update_gradients(&mut self, gradients: HashMap<String, Tensor>) -> Result<()> {
        match self.config.stage {
            ZeROStage::Stage1 => {
                // Stage 1 does not partition gradients; accumulate each one locally
                if let Some(stage1) = &mut self.stage1 {
                    for (name, grad) in gradients {
                        stage1.accumulate_gradient(&name, &grad)?;
                    }
                }
            },
            ZeROStage::Stage2 => {
                if let Some(stage2) = &mut self.stage2 {
                    stage2.update_gradients(gradients)?;
                }
            },
            ZeROStage::Stage3 => {
                if let Some(stage3) = &mut self.stage3 {
                    stage3.update_gradients(gradients)?;
                }
            },
        }
        Ok(())
    }

    /// Gather parameters for the forward pass (Stage 3 only)
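    ///
    /// A sketch of the Stage 3 gather/compute/release cycle (crate paths
    /// assumed; requires an optimizer built with `ZeROStage::Stage3`, and the
    /// parameter name is hypothetical):
    ///
    /// ```no_run
    /// # use trustformers_core::traits::Optimizer;
    /// # fn forward<T: Optimizer>(
    /// #     zero: &mut trustformers_optim::zero::ZeROOptimizer<T>,
    /// # ) -> trustformers_core::errors::Result<()> {
    /// let names = vec!["weight1".to_string()];
    /// let full_params = zero.gather_parameters(&names)?; // all-gather the shards
    /// // ... run the forward pass with `full_params` ...
    /// zero.release_parameters(&names)?; // free the gathered copies
    /// # Ok(())
    /// # }
    /// ```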
    pub fn gather_parameters(
        &mut self,
        parameter_names: &[String],
    ) -> Result<HashMap<String, Tensor>> {
        match self.config.stage {
            ZeROStage::Stage3 => {
                if let Some(stage3) = &mut self.stage3 {
                    stage3.gather_parameters(parameter_names)
                } else {
                    Err(TrustformersError::runtime_error(
                        "Stage 3 not initialized".into(),
                    ))
                }
            },
            _ => {
                // For Stage 1 and 2, parameters are not partitioned
                Err(TrustformersError::runtime_error(
                    "Parameter gathering only available in Stage 3".into(),
                ))
            },
        }
    }

    /// Release gathered parameters to save memory (Stage 3 only)
    pub fn release_parameters(&mut self, parameter_names: &[String]) -> Result<()> {
        match self.config.stage {
            ZeROStage::Stage3 => {
                if let Some(stage3) = &mut self.stage3 {
                    stage3.release_parameters(parameter_names)
                } else {
                    Err(TrustformersError::runtime_error(
                        "Stage 3 not initialized".into(),
                    ))
                }
            },
            _ => Ok(()), // No-op for other stages
        }
    }

    /// Get memory statistics
    pub fn get_memory_stats(&self) -> &ZeROMemoryStats {
        &self.memory_stats
    }

    /// Update memory statistics
    fn update_memory_stats(&mut self) {
        let memory_usage = self.zero_state.memory_usage();

        self.memory_stats.optimizer_memory_saved =
            memory_usage.get("optimizer_states").copied().unwrap_or(0);
        self.memory_stats.gradient_memory_saved =
            memory_usage.get("gradient_partitions").copied().unwrap_or(0);
        self.memory_stats.parameter_memory_saved =
            memory_usage.get("parameter_partitions").copied().unwrap_or(0);
        self.memory_stats.communication_overhead =
            memory_usage.get("communication_buffers").copied().unwrap_or(0);

        self.memory_stats.update_totals();
    }

    /// Check whether the total memory tracked in the ZeRO statistics exceeds the configured `max_memory_usage_mb` threshold
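    ///
    /// A sketch of using this check to trigger parameter release (the release
    /// policy is hypothetical; `release_parameters` is a no-op outside Stage 3):
    ///
    /// ```no_run
    /// # use trustformers_core::traits::Optimizer;
    /// # fn maybe_release<T: Optimizer>(
    /// #     zero: &mut trustformers_optim::zero::ZeROOptimizer<T>,
    /// #     gathered: &[String],
    /// # ) -> trustformers_core::errors::Result<()> {
    /// if zero.check_memory_usage() {
    ///     zero.release_parameters(gathered)?;
    /// }
    /// # Ok(())
    /// # }
    /// ```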
    pub fn check_memory_usage(&self) -> bool {
        let total_memory_mb = self.memory_stats.total_memory_saved / (1024 * 1024);
        total_memory_mb > self.config.max_memory_usage_mb
    }

    /// Get current ZeRO stage
    pub fn get_stage(&self) -> ZeROStage {
        self.config.stage
    }

    /// Get the underlying base optimizer
    pub fn base_optimizer(&self) -> &T {
        &self.base_optimizer
    }

    /// Get mutable reference to base optimizer
    pub fn base_optimizer_mut(&mut self) -> &mut T {
        &mut self.base_optimizer
    }

    /// Get model parallel context
    pub fn mp_context(&self) -> &Arc<ModelParallelContext> {
        &self.mp_context
    }

    /// Perform an optimizer step with ZeRO optimizations
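    ///
    /// A sketch of one training step combining the batched gradient path with
    /// the step itself (flow assumed; the backward pass is elided):
    ///
    /// ```no_run
    /// # use std::collections::HashMap;
    /// # use trustformers_core::tensor::Tensor;
    /// # use trustformers_core::traits::Optimizer;
    /// # fn train_step<T: Optimizer>(
    /// #     zero: &mut trustformers_optim::zero::ZeROOptimizer<T>,
    /// #     grads: HashMap<String, Tensor>,
    /// # ) -> trustformers_core::errors::Result<()> {
    /// zero.update_gradients(grads)?; // stage-aware gradient handling
    /// zero.optimizer_step()?;        // partitioned update + stats refresh
    /// zero.zero_grad();              // clear gradients for the next step
    /// # Ok(())
    /// # }
    /// ```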
    pub fn optimizer_step(&mut self) -> Result<()> {
        match self.config.stage {
            ZeROStage::Stage1 => {
                if let Some(stage1) = &mut self.stage1 {
                    stage1.optimizer_step(&mut self.base_optimizer)?;
                }
            },
            ZeROStage::Stage2 => {
                if let Some(stage2) = &mut self.stage2 {
                    stage2.optimizer_step(&mut self.base_optimizer)?;
                }
            },
            ZeROStage::Stage3 => {
                if let Some(stage3) = &mut self.stage3 {
                    stage3.optimizer_step(&mut self.base_optimizer)?;
                }
            },
        }

        self.zero_state.step();
        self.update_memory_stats();
        Ok(())
    }
}

impl<T: Optimizer> Optimizer for ZeROOptimizer<T> {
    fn update(&mut self, parameter: &mut Tensor, grad: &Tensor) -> Result<()> {
        // ZeRO optimizer handles updates through its stage implementations
        // This method is called for individual parameter updates
        match self.config.stage {
            ZeROStage::Stage1 => {
                if let Some(stage1) = &mut self.stage1 {
                    stage1.update_parameter(parameter, grad, &mut self.base_optimizer)
                } else {
                    self.base_optimizer.update(parameter, grad)
                }
            },
            ZeROStage::Stage2 | ZeROStage::Stage3 => {
                // For Stage 2 and 3, gradients are handled in batches;
                // individual parameter updates are not supported
                Err(TrustformersError::runtime_error(
                    "Individual parameter updates not supported in ZeRO Stage 2/3. Use batch updates."
                        .into()
                ))
            },
        }
    }

    fn zero_grad(&mut self) {
        self.zero_state.zero_grad();
        self.base_optimizer.zero_grad();
    }

    fn step(&mut self) {
        self.base_optimizer.step();
        self.zero_state.step();
    }

    fn get_lr(&self) -> f32 {
        self.base_optimizer.get_lr()
    }

    fn set_lr(&mut self, lr: f32) {
        self.base_optimizer.set_lr(lr);
    }

    fn accumulate_grad(&mut self, parameter: &mut Tensor, grad: &Tensor) -> Result<()> {
        // Handle gradient accumulation through ZeRO stages
        match self.config.stage {
            ZeROStage::Stage1 => {
                if let Some(stage1) = &mut self.stage1 {
                    // Stage 1 can use regular gradient accumulation
                    stage1.accumulate_gradient_for_parameter(parameter, grad)
                } else {
                    self.base_optimizer.accumulate_grad(parameter, grad)
                }
            },
            ZeROStage::Stage2 | ZeROStage::Stage3 => {
                // For Stage 2/3, accumulation is handled inside the stage implementation
                Err(TrustformersError::runtime_error(
                    "Gradient accumulation in ZeRO Stage 2/3 should be handled through update_gradients"
                        .into()
                ))
            },
        }
    }

    fn apply_accumulated_grads(&mut self, accumulation_steps: usize) -> Result<()> {
        match self.config.stage {
            ZeROStage::Stage1 => {
                if let Some(stage1) = &mut self.stage1 {
                    stage1.apply_accumulated_gradients(&mut self.base_optimizer, accumulation_steps)
                } else {
                    self.base_optimizer.apply_accumulated_grads(accumulation_steps)
                }
            },
            ZeROStage::Stage2 => {
                if let Some(stage2) = &mut self.stage2 {
                    stage2.apply_accumulated_gradients(&mut self.base_optimizer, accumulation_steps)
                } else {
                    Err(TrustformersError::runtime_error(
                        "Stage 2 not initialized".into(),
                    ))
                }
            },
            ZeROStage::Stage3 => {
                if let Some(stage3) = &mut self.stage3 {
                    stage3.apply_accumulated_gradients(&mut self.base_optimizer, accumulation_steps)
                } else {
                    Err(TrustformersError::runtime_error(
                        "Stage 3 not initialized".into(),
                    ))
                }
            },
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::adam::Adam;
    use trustformers_core::parallel::{
        CommunicationBackend, ModelParallelConfig, ModelParallelStrategy,
    };

    #[test]
    fn test_zero_optimizer_creation() {
        let config = ModelParallelConfig {
            num_devices: 2,
            device_ids: vec![0, 1],
            strategy: ModelParallelStrategy::Pipeline,
            comm_backend: CommunicationBackend::Custom,
            ..Default::default()
        };
        let mp_context = Arc::new(ModelParallelContext::new(config).unwrap());

        let adam = Adam::new(0.001, (0.9, 0.999), 1e-8, 0.01);
        let zero_config = ZeROConfig::default();

        let zero_optimizer = ZeROOptimizer::new(adam, zero_config, mp_context);
        assert!(zero_optimizer.is_ok());

        let optimizer = zero_optimizer.unwrap();
        assert_eq!(optimizer.get_stage(), ZeROStage::Stage1);
    }

    #[test]
    fn test_zero_stage_initialization() {
        let config = ModelParallelConfig {
            num_devices: 4,
            device_ids: vec![0, 1, 2, 3],
            strategy: ModelParallelStrategy::Pipeline,
            comm_backend: CommunicationBackend::Custom,
            ..Default::default()
        };
        let mp_context = Arc::new(ModelParallelContext::new(config).unwrap());

        // Test Stage 2
        let adam = Adam::new(0.001, (0.9, 0.999), 1e-8, 0.01);
        let zero_config = ZeROConfig {
            stage: ZeROStage::Stage2,
            ..Default::default()
        };

        let zero_optimizer = ZeROOptimizer::new(adam, zero_config, mp_context.clone());
        assert!(zero_optimizer.is_ok());

        // Test Stage 3
        let adam = Adam::new(0.001, (0.9, 0.999), 1e-8, 0.01);
        let zero_config = ZeROConfig {
            stage: ZeROStage::Stage3,
            ..Default::default()
        };

        let zero_optimizer = ZeROOptimizer::new(adam, zero_config, mp_context);
        assert!(zero_optimizer.is_ok());
    }

    #[test]
    fn test_parameter_registration() {
        let config = ModelParallelConfig {
            num_devices: 2,
            device_ids: vec![0, 1],
            strategy: ModelParallelStrategy::Pipeline,
            comm_backend: CommunicationBackend::Custom,
            ..Default::default()
        };
        let mp_context = Arc::new(ModelParallelContext::new(config).unwrap());

        let adam = Adam::new(0.001, (0.9, 0.999), 1e-8, 0.01);
        let zero_config = ZeROConfig::default();
        let mut zero_optimizer = ZeROOptimizer::new(adam, zero_config, mp_context).unwrap();

        let mut parameters = HashMap::new();
        parameters.insert("weight1".to_string(), Tensor::ones(&[4, 4]).unwrap());
        parameters.insert("bias1".to_string(), Tensor::ones(&[4]).unwrap());

        let result = zero_optimizer.register_parameters(parameters);
        assert!(result.is_ok());
        assert_eq!(zero_optimizer.parameter_names.len(), 2);
    }

    #[test]
    fn test_memory_stats() {
        let config = ModelParallelConfig {
            num_devices: 2,
            device_ids: vec![0, 1],
            strategy: ModelParallelStrategy::Pipeline,
            comm_backend: CommunicationBackend::Custom,
            ..Default::default()
        };
        let mp_context = Arc::new(ModelParallelContext::new(config).unwrap());

        let adam = Adam::new(0.001, (0.9, 0.999), 1e-8, 0.01);
        let zero_config = ZeROConfig::default();
        let zero_optimizer = ZeROOptimizer::new(adam, zero_config, mp_context).unwrap();

        let stats = zero_optimizer.get_memory_stats();
        assert_eq!(stats.optimizer_memory_saved, 0); // No parameters registered yet
        assert_eq!(stats.gradient_memory_saved, 0);
        assert_eq!(stats.parameter_memory_saved, 0);
    }
}