1use std::collections::HashMap;
4use trustformers_core::errors::Result;
5use trustformers_core::parallel::ModelParallelContext;
6use trustformers_core::tensor::Tensor;
7
/// Mutable runtime state for a ZeRO-style sharded optimizer.
///
/// Holds this rank's partitions of optimizer state, gradients, and
/// parameters, plus scratch buffers used for collective communication.
#[derive(Debug, Clone)]
pub struct ZeROState {
    /// Number of optimizer steps taken so far (incremented by `step()`).
    pub step: usize,
    /// Optimizer state tensors keyed by parameter name, then by state name
    /// (presumably moment estimates etc. — state names are not visible here).
    pub optimizer_states: HashMap<String, HashMap<String, Tensor>>,
    /// This rank's gradient partition buffers, keyed by parameter name.
    pub gradient_partitions: HashMap<String, GradientBuffer>,
    /// This rank's parameter shards, keyed by parameter name.
    pub parameter_partitions: HashMap<String, ParameterPartition>,
    /// Reusable tensors for collective-communication staging, keyed by name.
    pub communication_buffers: HashMap<String, Tensor>,
}
22
23impl Default for ZeROState {
24 fn default() -> Self {
25 Self::new()
26 }
27}
28
29impl ZeROState {
30 pub fn new() -> Self {
31 Self {
32 step: 0,
33 optimizer_states: HashMap::new(),
34 gradient_partitions: HashMap::new(),
35 parameter_partitions: HashMap::new(),
36 communication_buffers: HashMap::new(),
37 }
38 }
39
40 pub fn zero_grad(&mut self) {
42 for buffer in self.gradient_partitions.values_mut() {
43 buffer.zero();
44 }
45 }
46
47 pub fn step(&mut self) {
49 self.step += 1;
50 }
51
52 pub fn memory_usage(&self) -> HashMap<String, usize> {
54 let mut stats = HashMap::new();
55
56 let mut optimizer_memory = 0;
58 for states in self.optimizer_states.values() {
59 for tensor in states.values() {
60 optimizer_memory += tensor.memory_usage();
61 }
62 }
63 stats.insert("optimizer_states".to_string(), optimizer_memory);
64
65 let mut gradient_memory = 0;
67 for buffer in self.gradient_partitions.values() {
68 gradient_memory += buffer.memory_usage();
69 }
70 stats.insert("gradient_partitions".to_string(), gradient_memory);
71
72 let mut parameter_memory = 0;
74 for partition in self.parameter_partitions.values() {
75 parameter_memory += partition.memory_usage();
76 }
77 stats.insert("parameter_partitions".to_string(), parameter_memory);
78
79 let mut comm_memory = 0;
81 for tensor in self.communication_buffers.values() {
82 comm_memory += tensor.memory_usage();
83 }
84 stats.insert("communication_buffers".to_string(), comm_memory);
85
86 stats
87 }
88}
89
/// A named collection of parameters managed together, holding this rank's
/// locally materialized tensors for those parameters.
#[derive(Debug, Clone)]
pub struct ParameterGroup {
    /// Human-readable group name.
    pub name: String,
    /// Names of all parameters belonging to this group (local or not).
    pub parameter_names: Vec<String>,
    /// Tensors for the parameters materialized on this rank, keyed by name.
    pub local_parameters: HashMap<String, Tensor>,
    /// How this group's parameters are partitioned across ranks.
    pub partition_info: PartitionInfo,
}
102
103impl ParameterGroup {
104 pub fn new(name: String, parameter_names: Vec<String>) -> Self {
105 Self {
106 name,
107 parameter_names,
108 local_parameters: HashMap::new(),
109 partition_info: PartitionInfo::default(),
110 }
111 }
112
113 pub fn add_parameter(&mut self, name: String, tensor: Tensor) {
115 self.local_parameters.insert(name.clone(), tensor);
116 if !self.parameter_names.contains(&name) {
117 self.parameter_names.push(name);
118 }
119 }
120
121 pub fn memory_usage(&self) -> usize {
123 self.local_parameters.values().map(|t| t.memory_usage()).sum()
124 }
125}
126
/// Describes one rank's slice of a globally partitioned tensor.
#[derive(Debug, Clone)]
pub struct PartitionInfo {
    /// This rank's index within the partition group.
    pub rank: usize,
    /// Total number of ranks the tensor is split across.
    pub world_size: usize,
    /// First flat element index owned by this rank (inclusive).
    pub start_idx: usize,
    /// One past the last flat element index owned by this rank (exclusive).
    pub end_idx: usize,
    /// Shape of the full, unpartitioned tensor.
    pub global_shape: Vec<usize>,
    /// Shape of this rank's local slice.
    pub local_shape: Vec<usize>,
}

impl Default for PartitionInfo {
    /// Single-process default: rank 0 of a world of 1, empty range and shapes.
    /// (Hand-written because `#[derive(Default)]` would set `world_size` to 0.)
    fn default() -> Self {
        PartitionInfo {
            rank: 0,
            world_size: 1,
            start_idx: 0,
            end_idx: 0,
            global_shape: Vec::new(),
            local_shape: Vec::new(),
        }
    }
}
156
/// One parameter's local shard plus the machinery to temporarily
/// re-materialize the full parameter via all-gather.
#[derive(Debug, Clone)]
pub struct ParameterPartition {
    /// Parameter name this partition belongs to.
    pub name: String,
    /// This rank's slice of the parameter.
    pub local_shard: Tensor,
    /// Partition layout (rank, range, shapes) for this shard.
    pub partition_info: PartitionInfo,
    /// True while `full_parameter` holds a gathered copy.
    pub is_gathered: bool,
    /// The fully gathered parameter, present only between
    /// `gather()` and `release()`.
    pub full_parameter: Option<Tensor>,
}
171
172impl ParameterPartition {
173 pub fn new(name: String, local_shard: Tensor, partition_info: PartitionInfo) -> Self {
174 Self {
175 name,
176 local_shard,
177 partition_info,
178 is_gathered: false,
179 full_parameter: None,
180 }
181 }
182
183 pub fn memory_usage(&self) -> usize {
185 let mut usage = self.local_shard.memory_usage();
186 if let Some(full_param) = &self.full_parameter {
187 usage += full_param.memory_usage();
188 }
189 usage
190 }
191
192 pub fn gather(&mut self, mp_context: &ModelParallelContext) -> Result<()> {
194 if self.is_gathered {
195 return Ok(());
196 }
197
198 let full_param =
200 mp_context.all_gather(&trustformers_core::parallel::DistributedTensor::new(
201 self.local_shard.clone(),
202 self.partition_info.global_shape.clone(),
203 trustformers_core::parallel::TensorPartition {
204 split_dim: 0, start_idx: self.partition_info.start_idx,
206 end_idx: self.partition_info.end_idx,
207 num_partitions: self.partition_info.world_size,
208 partition_rank: self.partition_info.rank,
209 },
210 self.partition_info.rank,
211 ))?;
212
213 self.full_parameter = Some(full_param);
214 self.is_gathered = true;
215 Ok(())
216 }
217
218 pub fn release(&mut self) {
220 self.full_parameter = None;
221 self.is_gathered = false;
222 }
223}
224
/// One parameter's local gradient partition with optional accumulation
/// across micro-batches.
#[derive(Debug, Clone)]
pub struct GradientBuffer {
    /// Parameter name this gradient belongs to.
    pub name: String,
    /// This rank's slice of the gradient.
    pub local_gradient: Tensor,
    /// Running sum of accumulated gradients; `None` until the first
    /// `accumulate()` call after a `zero()`.
    pub accumulated_gradient: Option<Tensor>,
    /// How many gradients have been accumulated so far.
    pub accumulation_steps: usize,
    /// Partition layout (rank, range, shapes) for this gradient.
    pub partition_info: PartitionInfo,
}
239
240impl GradientBuffer {
241 pub fn new(name: String, local_gradient: Tensor, partition_info: PartitionInfo) -> Self {
242 Self {
243 name,
244 local_gradient,
245 accumulated_gradient: None,
246 accumulation_steps: 0,
247 partition_info,
248 }
249 }
250
251 pub fn zero(&mut self) {
253 self.local_gradient = Tensor::zeros(&self.local_gradient.shape()).unwrap();
254 self.accumulated_gradient = None;
255 self.accumulation_steps = 0;
256 }
257
258 pub fn accumulate(&mut self, gradient: &Tensor) -> Result<()> {
260 if let Some(acc_grad) = &mut self.accumulated_gradient {
261 *acc_grad = acc_grad.add(gradient)?;
262 } else {
263 self.accumulated_gradient = Some(gradient.clone());
264 }
265 self.accumulation_steps += 1;
266 Ok(())
267 }
268
269 pub fn get_accumulated(&self) -> Option<Tensor> {
271 if let Some(acc_grad) = &self.accumulated_gradient {
272 if self.accumulation_steps > 1 {
273 acc_grad.scalar_div(self.accumulation_steps as f32).ok()
274 } else {
275 Some(acc_grad.clone())
276 }
277 } else {
278 None
279 }
280 }
281
282 pub fn memory_usage(&self) -> usize {
284 let mut usage = self.local_gradient.memory_usage();
285 if let Some(acc_grad) = &self.accumulated_gradient {
286 usage += acc_grad.memory_usage();
287 }
288 usage
289 }
290}
291
292pub fn partition_parameters(
294 parameters: &HashMap<String, Tensor>,
295 world_size: usize,
296 rank: usize,
297) -> Result<HashMap<String, ParameterPartition>> {
298 let mut partitions = HashMap::new();
299
300 for (name, param) in parameters {
301 let shape = param.shape();
302 let total_elements = shape.iter().product::<usize>();
303
304 let elements_per_rank = total_elements.div_ceil(world_size);
306 let start_idx = rank * elements_per_rank;
307 let end_idx = ((rank + 1) * elements_per_rank).min(total_elements);
308
309 let local_shard = if world_size == 1 || total_elements <= elements_per_rank {
313 param.clone()
315 } else {
316 let scale_factor = 1.0 / (world_size as f32);
319
320 param.mul_scalar(scale_factor)?
321 };
322
323 let partition_info = PartitionInfo {
324 rank,
325 world_size,
326 start_idx,
327 end_idx,
328 global_shape: shape.to_vec(),
329 local_shape: local_shard.shape().to_vec(),
330 };
331
332 let partition = ParameterPartition::new(name.clone(), local_shard, partition_info);
333 partitions.insert(name.clone(), partition);
334 }
335
336 Ok(partitions)
337}
338
339pub fn gather_parameters(
341 partitions: &mut HashMap<String, ParameterPartition>,
342 mp_context: &ModelParallelContext,
343) -> Result<HashMap<String, Tensor>> {
344 let mut gathered = HashMap::new();
345
346 for (name, partition) in partitions.iter_mut() {
347 partition.gather(mp_context)?;
348 if let Some(full_param) = &partition.full_parameter {
349 gathered.insert(name.clone(), full_param.clone());
350 }
351 }
352
353 Ok(gathered)
354}
355
356pub fn partition_gradients(
358 gradients: &HashMap<String, Tensor>,
359 world_size: usize,
360 rank: usize,
361) -> Result<HashMap<String, GradientBuffer>> {
362 let mut buffers = HashMap::new();
363
364 for (name, grad) in gradients {
365 let shape = grad.shape();
366 let total_elements = shape.iter().product::<usize>();
367
368 let elements_per_rank = total_elements.div_ceil(world_size);
370 let start_idx = rank * elements_per_rank;
371 let end_idx = ((rank + 1) * elements_per_rank).min(total_elements);
372
373 let local_gradient = if world_size == 1 || total_elements <= elements_per_rank {
376 grad.clone()
378 } else {
379 let scale_factor = 1.0 / (world_size as f32);
381
382 grad.mul_scalar(scale_factor)?
383 };
384
385 let partition_info = PartitionInfo {
386 rank,
387 world_size,
388 start_idx,
389 end_idx,
390 global_shape: shape.to_vec(),
391 local_shape: local_gradient.shape().to_vec(),
392 };
393
394 let buffer = GradientBuffer::new(name.clone(), local_gradient, partition_info);
395 buffers.insert(name.clone(), buffer);
396 }
397
398 Ok(buffers)
399}
400
401pub fn all_gather_gradients(
403 buffers: &HashMap<String, GradientBuffer>,
404 mp_context: &ModelParallelContext,
405) -> Result<HashMap<String, Tensor>> {
406 let mut gathered = HashMap::new();
407
408 for (name, buffer) in buffers {
409 let distributed_tensor = trustformers_core::parallel::DistributedTensor::new(
410 buffer.local_gradient.clone(),
411 buffer.partition_info.global_shape.clone(),
412 trustformers_core::parallel::TensorPartition {
413 split_dim: 0,
414 start_idx: buffer.partition_info.start_idx,
415 end_idx: buffer.partition_info.end_idx,
416 num_partitions: buffer.partition_info.world_size,
417 partition_rank: buffer.partition_info.rank,
418 },
419 buffer.partition_info.rank,
420 );
421
422 let full_gradient = mp_context.all_gather(&distributed_tensor)?;
423 gathered.insert(name.clone(), full_gradient);
424 }
425
426 Ok(gathered)
427}
428
429pub fn reduce_scatter_gradients(
431 gradients: &HashMap<String, Tensor>,
432 mp_context: &ModelParallelContext,
433) -> Result<HashMap<String, Tensor>> {
434 let mut scattered = HashMap::new();
435
436 for (name, grad) in gradients {
437 let scattered_grad = mp_context.reduce_scatter(grad, 0)?;
438 scattered.insert(name.clone(), scattered_grad);
439 }
440
441 Ok(scattered)
442}
443
/// Greedily groups parameter indices into buckets whose total size stays at
/// or under `target_bucket_size`, preserving input order.
///
/// A single parameter larger than the target still gets its own bucket, so
/// every index appears exactly once.
pub fn calculate_bucket_size(
    parameter_sizes: &[usize],
    target_bucket_size: usize,
) -> Vec<Vec<usize>> {
    let mut buckets: Vec<Vec<usize>> = Vec::new();
    let mut bucket: Vec<usize> = Vec::new();
    let mut bucket_bytes: usize = 0;

    for (idx, &bytes) in parameter_sizes.iter().enumerate() {
        // Flush the current bucket when adding this entry would overflow it;
        // an empty bucket always accepts the entry, however large.
        if bucket_bytes + bytes > target_bucket_size && !bucket.is_empty() {
            buckets.push(std::mem::take(&mut bucket));
            bucket_bytes = 0;
        }
        bucket.push(idx);
        bucket_bytes += bytes;
    }

    if !bucket.is_empty() {
        buckets.push(bucket);
    }

    buckets
}
470
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_zero_state_creation() {
        let state = ZeROState::new();
        assert_eq!(state.step, 0);
        assert!(state.optimizer_states.is_empty());
        assert!(state.gradient_partitions.is_empty());
        assert!(state.parameter_partitions.is_empty());
    }

    #[test]
    fn test_parameter_group() {
        let mut group = ParameterGroup::new("test_group".to_string(), vec!["param1".to_string()]);
        let tensor = Tensor::ones(&[2, 2]).unwrap();
        group.add_parameter("param1".to_string(), tensor);

        assert_eq!(group.parameter_names.len(), 1);
        assert_eq!(group.local_parameters.len(), 1);
        assert!(group.memory_usage() > 0);
    }

    #[test]
    fn test_gradient_buffer() {
        let tensor = Tensor::ones(&[2, 2]).unwrap();
        let partition_info = PartitionInfo::default();
        let mut buffer = GradientBuffer::new("test_grad".to_string(), tensor, partition_info);

        let grad = Tensor::ones(&[2, 2]).unwrap();
        buffer.accumulate(&grad).unwrap();

        assert_eq!(buffer.accumulation_steps, 1);
        assert!(buffer.get_accumulated().is_some());
    }

    #[test]
    fn test_partition_parameters() {
        let mut params = HashMap::new();
        params.insert("param1".to_string(), Tensor::ones(&[4, 4]).unwrap());
        params.insert("param2".to_string(), Tensor::ones(&[2, 2]).unwrap());

        // Fixed: `&params` had been mangled into `¶ms` (HTML-entity
        // decoding of "&para"), which does not compile.
        let partitions = partition_parameters(&params, 2, 0).unwrap();
        assert_eq!(partitions.len(), 2);

        for partition in partitions.values() {
            assert_eq!(partition.partition_info.world_size, 2);
            assert_eq!(partition.partition_info.rank, 0);
        }
    }

    #[test]
    fn test_calculate_bucket_size() {
        let sizes = vec![100, 200, 150, 300, 50];
        let buckets = calculate_bucket_size(&sizes, 400);

        assert!(!buckets.is_empty());

        for bucket in &buckets {
            let bucket_size: usize = bucket.iter().map(|&i| sizes[i]).sum();
            // Oversized single parameters are allowed a bucket of their own.
            assert!(bucket_size <= 400 || bucket.len() == 1);
        }
    }
}