use std::collections::HashMap;
use std::sync::Arc;
use trustformers_core::errors::{Result, TrustformersError};
use trustformers_core::parallel::ModelParallelContext;
use trustformers_core::tensor::Tensor;
use trustformers_core::traits::Optimizer;

use super::{
    ZeROImplementationStage, ZeROMemoryStats, ZeROStage1, ZeROStage2, ZeROStage3, ZeROState,
};

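/// Configuration for the ZeRO (Zero Redundancy Optimizer) wrapper.
///
/// A minimal construction sketch (the overridden values are illustrative,
/// not tuned recommendations):
///
/// ```ignore
/// let config = ZeROConfig {
///     stage: ZeROStage::Stage2,
///     bucket_size_mb: 50,
///     ..Default::default()
/// };
/// ```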
#[derive(Debug, Clone)]
pub struct ZeROConfig {
    /// Which ZeRO stage to run.
    pub stage: ZeROStage,
    /// Size of communication buckets, in megabytes.
    pub bucket_size_mb: usize,
    /// Whether to overlap communication with computation.
    pub overlap_comm: bool,
    /// Bucket size used when reducing gradients.
    pub reduce_bucket_size: usize,
    /// Prefetch depth for upcoming parameter partitions.
    pub prefetch_depth: usize,
    /// Memory usage threshold, in megabytes (see `check_memory_usage`).
    pub max_memory_usage_mb: usize,
    /// Whether to compress gradients before communication.
    pub gradient_compression: bool,
    /// Whether to use pinned host memory for transfers.
    pub pin_memory: bool,
}

impl Default for ZeROConfig {
    fn default() -> Self {
        Self {
            stage: ZeROStage::Stage1,
            bucket_size_mb: 25,
            overlap_comm: true,
            reduce_bucket_size: 500_000_000,
            prefetch_depth: 2,
            max_memory_usage_mb: 1024,
            gradient_compression: false,
            pin_memory: true,
        }
    }
}

/// The three ZeRO partitioning stages.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ZeROStage {
    /// Stage 1: partition optimizer states across ranks.
    Stage1,
    /// Stage 2: additionally partition gradients.
    Stage2,
    /// Stage 3: additionally partition model parameters.
    Stage3,
}

impl From<ZeROStage> for ZeROImplementationStage {
    fn from(stage: ZeROStage) -> Self {
        match stage {
            ZeROStage::Stage1 => ZeROImplementationStage::Stage1,
            ZeROStage::Stage2 => ZeROImplementationStage::Stage2,
            ZeROStage::Stage3 => ZeROImplementationStage::Stage3,
        }
    }
}

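/// ZeRO optimizer wrapper that shards optimizer state (and, at higher stages,
/// gradients and parameters) across the ranks of a [`ModelParallelContext`],
/// delegating the actual parameter updates to a wrapped base optimizer.
///
/// A minimal sketch of the intended flow; `parameters` and `gradients` are
/// assumed to be `HashMap<String, Tensor>` maps produced by the caller:
///
/// ```ignore
/// let mut zero = ZeROOptimizer::new(adam, ZeROConfig::default(), mp_context)?;
/// zero.register_parameters(parameters)?;
/// // each training step:
/// zero.update_gradients(gradients)?;
/// zero.optimizer_step()?;
/// ```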
pub struct ZeROOptimizer<T: Optimizer> {
    /// The wrapped base optimizer (e.g. Adam).
    base_optimizer: T,
    /// ZeRO configuration.
    config: ZeROConfig,
    /// Model-parallel context used for communication.
    mp_context: Arc<ModelParallelContext>,
    /// Shared ZeRO bookkeeping state (step count, memory usage).
    zero_state: ZeROState,
    /// Stage 1 implementation, if active.
    stage1: Option<ZeROStage1<T>>,
    /// Stage 2 implementation, if active.
    stage2: Option<ZeROStage2<T>>,
    /// Stage 3 implementation, if active.
    stage3: Option<ZeROStage3<T>>,
    /// Aggregated memory statistics.
    memory_stats: ZeROMemoryStats,
    /// Names of all registered parameters.
    parameter_names: Vec<String>,
}

impl<T: Optimizer> ZeROOptimizer<T> {
    /// Creates a ZeRO optimizer wrapping `base_optimizer` and initializes the
    /// stage implementation selected by `config.stage`.
    pub fn new(
        base_optimizer: T,
        config: ZeROConfig,
        mp_context: Arc<ModelParallelContext>,
    ) -> Result<Self> {
        let mut optimizer = Self {
            base_optimizer,
            config: config.clone(),
            mp_context: mp_context.clone(),
            zero_state: ZeROState::new(),
            stage1: None,
            stage2: None,
            stage3: None,
            memory_stats: ZeROMemoryStats::new(),
            parameter_names: Vec::new(),
        };

        optimizer.initialize_stage(config.stage)?;

        Ok(optimizer)
    }

    /// Instantiates the stage implementation selected by `stage`.
    fn initialize_stage(&mut self, stage: ZeROStage) -> Result<()> {
        match stage {
            ZeROStage::Stage1 => {
                self.stage1 = Some(ZeROStage1::new(
                    self.mp_context.clone(),
                    self.config.clone(),
                )?);
            },
            ZeROStage::Stage2 => {
                self.stage2 = Some(ZeROStage2::new(
                    self.mp_context.clone(),
                    self.config.clone(),
                )?);
            },
            ZeROStage::Stage3 => {
                self.stage3 = Some(ZeROStage3::new(
                    self.mp_context.clone(),
                    self.config.clone(),
                )?);
            },
        }
        Ok(())
    }

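    /// Registers the full set of named model parameters with the active stage
    /// so it can build its partitions. A sketch of building the map, mirroring
    /// the unit tests below:
    ///
    /// ```ignore
    /// let mut parameters = HashMap::new();
    /// parameters.insert("weight1".to_string(), Tensor::ones(&[4, 4])?);
    /// parameters.insert("bias1".to_string(), Tensor::ones(&[4])?);
    /// zero.register_parameters(parameters)?;
    /// ```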
    pub fn register_parameters(&mut self, parameters: HashMap<String, Tensor>) -> Result<()> {
        self.parameter_names = parameters.keys().cloned().collect();

        match self.config.stage {
            ZeROStage::Stage1 => {
                if let Some(stage1) = &mut self.stage1 {
                    stage1.register_parameters(parameters)?;
                }
            },
            ZeROStage::Stage2 => {
                if let Some(stage2) = &mut self.stage2 {
                    stage2.register_parameters(parameters)?;
                }
            },
            ZeROStage::Stage3 => {
                if let Some(stage3) = &mut self.stage3 {
                    stage3.register_parameters(parameters)?;
                }
            },
        }

        self.update_memory_stats();
        Ok(())
    }

    /// Feeds a batch of named gradients to the active stage. Stage 1
    /// accumulates gradients one parameter at a time; Stages 2 and 3 hand the
    /// whole batch to their partitioned gradient handling.
    pub fn update_gradients(&mut self, gradients: HashMap<String, Tensor>) -> Result<()> {
        match self.config.stage {
            ZeROStage::Stage1 => {
                for (name, grad) in gradients {
                    if let Some(stage1) = &mut self.stage1 {
                        stage1.accumulate_gradient(&name, &grad)?;
                    }
                }
            },
            ZeROStage::Stage2 => {
                if let Some(stage2) = &mut self.stage2 {
                    stage2.update_gradients(gradients)?;
                }
            },
            ZeROStage::Stage3 => {
                if let Some(stage3) = &mut self.stage3 {
                    stage3.update_gradients(gradients)?;
                }
            },
        }
        Ok(())
    }

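    /// Gathers the full (unpartitioned) tensors for the named parameters.
    /// Only meaningful in Stage 3, where parameters are sharded; other stages
    /// return an error. A sketch of the Stage 3 gather/release pairing:
    ///
    /// ```ignore
    /// let full = zero.gather_parameters(&names)?; // materialize full tensors
    /// // ... use the full tensors ...
    /// zero.release_parameters(&names)?;           // free the gathered copies
    /// ```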
    pub fn gather_parameters(
        &mut self,
        parameter_names: &[String],
    ) -> Result<HashMap<String, Tensor>> {
        match self.config.stage {
            ZeROStage::Stage3 => {
                if let Some(stage3) = &mut self.stage3 {
                    stage3.gather_parameters(parameter_names)
                } else {
                    Err(TrustformersError::runtime_error(
                        "Stage 3 not initialized".into(),
                    ))
                }
            },
            _ => Err(TrustformersError::runtime_error(
                "Parameter gathering only available in Stage 3".into(),
            )),
        }
    }

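    /// Releases previously gathered parameters. Only Stage 3 holds gathered
    /// copies; for Stages 1 and 2 this is a no-op that returns `Ok(())`.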
    pub fn release_parameters(&mut self, parameter_names: &[String]) -> Result<()> {
        match self.config.stage {
            ZeROStage::Stage3 => {
                if let Some(stage3) = &mut self.stage3 {
                    stage3.release_parameters(parameter_names)
                } else {
                    Err(TrustformersError::runtime_error(
                        "Stage 3 not initialized".into(),
                    ))
                }
            },
            _ => Ok(()),
        }
    }

    /// Returns the current memory statistics.
    pub fn get_memory_stats(&self) -> &ZeROMemoryStats {
        &self.memory_stats
    }

    /// Refreshes the memory statistics from the shared ZeRO state.
    fn update_memory_stats(&mut self) {
        let memory_usage = self.zero_state.memory_usage();

        self.memory_stats.optimizer_memory_saved =
            memory_usage.get("optimizer_states").copied().unwrap_or(0);
        self.memory_stats.gradient_memory_saved =
            memory_usage.get("gradient_partitions").copied().unwrap_or(0);
        self.memory_stats.parameter_memory_saved =
            memory_usage.get("parameter_partitions").copied().unwrap_or(0);
        self.memory_stats.communication_overhead =
            memory_usage.get("communication_buffers").copied().unwrap_or(0);

        self.memory_stats.update_totals();
    }

    /// Returns `true` if the tracked memory total, converted to megabytes,
    /// exceeds the configured `max_memory_usage_mb`.
    pub fn check_memory_usage(&self) -> bool {
        let total_memory_mb = self.memory_stats.total_memory_saved / (1024 * 1024);
        total_memory_mb > self.config.max_memory_usage_mb
    }

    /// Returns the configured ZeRO stage.
    pub fn get_stage(&self) -> ZeROStage {
        self.config.stage
    }

    /// Shared access to the wrapped base optimizer.
    pub fn base_optimizer(&self) -> &T {
        &self.base_optimizer
    }

    /// Exclusive access to the wrapped base optimizer.
    pub fn base_optimizer_mut(&mut self) -> &mut T {
        &mut self.base_optimizer
    }

    /// Returns the model-parallel context used for communication.
    pub fn mp_context(&self) -> &Arc<ModelParallelContext> {
        &self.mp_context
    }

    /// Runs a distributed optimizer step through the active stage, then
    /// advances the shared ZeRO state and refreshes memory statistics.
    pub fn optimizer_step(&mut self) -> Result<()> {
        match self.config.stage {
            ZeROStage::Stage1 => {
                if let Some(stage1) = &mut self.stage1 {
                    stage1.optimizer_step(&mut self.base_optimizer)?;
                }
            },
            ZeROStage::Stage2 => {
                if let Some(stage2) = &mut self.stage2 {
                    stage2.optimizer_step(&mut self.base_optimizer)?;
                }
            },
            ZeROStage::Stage3 => {
                if let Some(stage3) = &mut self.stage3 {
                    stage3.optimizer_step(&mut self.base_optimizer)?;
                }
            },
        }

        self.zero_state.step();
        self.update_memory_stats();
        Ok(())
    }
}

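// Implement the `Optimizer` trait so `ZeROOptimizer` can stand in wherever a
// plain optimizer is expected. Note the asymmetry below: per-parameter entry
// points (`update`, `accumulate_grad`) only make sense in Stage 1; Stages 2
// and 3 work on whole gradient batches and report an error instead.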
impl<T: Optimizer> Optimizer for ZeROOptimizer<T> {
    fn update(&mut self, parameter: &mut Tensor, grad: &Tensor) -> Result<()> {
        match self.config.stage {
            ZeROStage::Stage1 => {
                if let Some(stage1) = &mut self.stage1 {
                    stage1.update_parameter(parameter, grad, &mut self.base_optimizer)
                } else {
                    self.base_optimizer.update(parameter, grad)
                }
            },
            ZeROStage::Stage2 | ZeROStage::Stage3 => Err(TrustformersError::runtime_error(
                "Individual parameter updates not supported in ZeRO Stage 2/3. Use batch updates."
                    .into(),
            )),
        }
    }

    fn zero_grad(&mut self) {
        self.zero_state.zero_grad();
        self.base_optimizer.zero_grad();
    }

    fn step(&mut self) {
        self.base_optimizer.step();
        self.zero_state.step();
    }

    fn get_lr(&self) -> f32 {
        self.base_optimizer.get_lr()
    }

    fn set_lr(&mut self, lr: f32) {
        self.base_optimizer.set_lr(lr);
    }

    fn accumulate_grad(&mut self, parameter: &mut Tensor, grad: &Tensor) -> Result<()> {
        match self.config.stage {
            ZeROStage::Stage1 => {
                if let Some(stage1) = &mut self.stage1 {
                    stage1.accumulate_gradient_for_parameter(parameter, grad)
                } else {
                    self.base_optimizer.accumulate_grad(parameter, grad)
                }
            },
            ZeROStage::Stage2 | ZeROStage::Stage3 => Err(TrustformersError::runtime_error(
                "Gradient accumulation in ZeRO Stage 2/3 should be handled through update_gradients"
                    .into(),
            )),
        }
    }

    fn apply_accumulated_grads(&mut self, accumulation_steps: usize) -> Result<()> {
        match self.config.stage {
            ZeROStage::Stage1 => {
                if let Some(stage1) = &mut self.stage1 {
                    stage1.apply_accumulated_gradients(&mut self.base_optimizer, accumulation_steps)
                } else {
                    self.base_optimizer.apply_accumulated_grads(accumulation_steps)
                }
            },
            ZeROStage::Stage2 => {
                if let Some(stage2) = &mut self.stage2 {
                    stage2.apply_accumulated_gradients(&mut self.base_optimizer, accumulation_steps)
                } else {
                    Err(TrustformersError::runtime_error(
                        "Stage 2 not initialized".into(),
                    ))
                }
            },
            ZeROStage::Stage3 => {
                if let Some(stage3) = &mut self.stage3 {
                    stage3.apply_accumulated_gradients(&mut self.base_optimizer, accumulation_steps)
                } else {
                    Err(TrustformersError::runtime_error(
                        "Stage 3 not initialized".into(),
                    ))
                }
            },
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::adam::Adam;
    use trustformers_core::parallel::{
        CommunicationBackend, ModelParallelConfig, ModelParallelStrategy,
    };

    #[test]
    fn test_zero_optimizer_creation() {
        let config = ModelParallelConfig {
            num_devices: 2,
            device_ids: vec![0, 1],
            strategy: ModelParallelStrategy::Pipeline,
            comm_backend: CommunicationBackend::Custom,
            ..Default::default()
        };
        let mp_context = Arc::new(ModelParallelContext::new(config).unwrap());

        let adam = Adam::new(0.001, (0.9, 0.999), 1e-8, 0.01);
        let zero_config = ZeROConfig::default();

        let zero_optimizer = ZeROOptimizer::new(adam, zero_config, mp_context);
        assert!(zero_optimizer.is_ok());

        let optimizer = zero_optimizer.unwrap();
        assert_eq!(optimizer.get_stage(), ZeROStage::Stage1);
    }

    #[test]
    fn test_zero_stage_initialization() {
        let config = ModelParallelConfig {
            num_devices: 4,
            device_ids: vec![0, 1, 2, 3],
            strategy: ModelParallelStrategy::Pipeline,
            comm_backend: CommunicationBackend::Custom,
            ..Default::default()
        };
        let mp_context = Arc::new(ModelParallelContext::new(config).unwrap());

        let adam = Adam::new(0.001, (0.9, 0.999), 1e-8, 0.01);
        let zero_config = ZeROConfig {
            stage: ZeROStage::Stage2,
            ..Default::default()
        };

        let zero_optimizer = ZeROOptimizer::new(adam, zero_config, mp_context.clone());
        assert!(zero_optimizer.is_ok());

        let adam = Adam::new(0.001, (0.9, 0.999), 1e-8, 0.01);
        let zero_config = ZeROConfig {
            stage: ZeROStage::Stage3,
            ..Default::default()
        };

        let zero_optimizer = ZeROOptimizer::new(adam, zero_config, mp_context);
        assert!(zero_optimizer.is_ok());
    }

    #[test]
    fn test_parameter_registration() {
        let config = ModelParallelConfig {
            num_devices: 2,
            device_ids: vec![0, 1],
            strategy: ModelParallelStrategy::Pipeline,
            comm_backend: CommunicationBackend::Custom,
            ..Default::default()
        };
        let mp_context = Arc::new(ModelParallelContext::new(config).unwrap());

        let adam = Adam::new(0.001, (0.9, 0.999), 1e-8, 0.01);
        let zero_config = ZeROConfig::default();
        let mut zero_optimizer = ZeROOptimizer::new(adam, zero_config, mp_context).unwrap();

        let mut parameters = HashMap::new();
        parameters.insert("weight1".to_string(), Tensor::ones(&[4, 4]).unwrap());
        parameters.insert("bias1".to_string(), Tensor::ones(&[4]).unwrap());

        let result = zero_optimizer.register_parameters(parameters);
        assert!(result.is_ok());
        assert_eq!(zero_optimizer.parameter_names.len(), 2);
    }

    #[test]
    fn test_memory_stats() {
        let config = ModelParallelConfig {
            num_devices: 2,
            device_ids: vec![0, 1],
            strategy: ModelParallelStrategy::Pipeline,
            comm_backend: CommunicationBackend::Custom,
            ..Default::default()
        };
        let mp_context = Arc::new(ModelParallelContext::new(config).unwrap());

        let adam = Adam::new(0.001, (0.9, 0.999), 1e-8, 0.01);
        let zero_config = ZeROConfig::default();
        let zero_optimizer = ZeROOptimizer::new(adam, zero_config, mp_context).unwrap();

        let stats = zero_optimizer.get_memory_stats();
        assert_eq!(stats.optimizer_memory_saved, 0);
        assert_eq!(stats.gradient_memory_saved, 0);
        assert_eq!(stats.parameter_memory_saved, 0);
    }
}
508}