Skip to main content

trustformers_optim/zero/
zero_utils.rs

1//! Utility functions and data structures for ZeRO optimization
2
3use std::collections::HashMap;
4use trustformers_core::errors::Result;
5use trustformers_core::parallel::ModelParallelContext;
6use trustformers_core::tensor::Tensor;
7
/// ZeRO optimizer state management.
///
/// Holds everything a single rank needs for partitioned optimization:
/// optimizer moments, gradient shards, parameter shards, and scratch
/// buffers for collective communication.
#[derive(Debug, Clone)]
pub struct ZeROState {
    /// Current step number (incremented once per `step()` call)
    pub step: usize,
    /// Partitioned optimizer states per parameter group:
    /// outer key = group/parameter name, inner key = state name (e.g. moments)
    pub optimizer_states: HashMap<String, HashMap<String, Tensor>>,
    /// Partitioned gradients (for Stage 2+), keyed by parameter name
    pub gradient_partitions: HashMap<String, GradientBuffer>,
    /// Partitioned parameters (for Stage 3), keyed by parameter name
    pub parameter_partitions: HashMap<String, ParameterPartition>,
    /// Communication buffers for all-gather operations, keyed by buffer name
    pub communication_buffers: HashMap<String, Tensor>,
}
22
23impl Default for ZeROState {
24    fn default() -> Self {
25        Self::new()
26    }
27}
28
29impl ZeROState {
30    pub fn new() -> Self {
31        Self {
32            step: 0,
33            optimizer_states: HashMap::new(),
34            gradient_partitions: HashMap::new(),
35            parameter_partitions: HashMap::new(),
36            communication_buffers: HashMap::new(),
37        }
38    }
39
40    /// Reset gradients for next iteration
41    pub fn zero_grad(&mut self) {
42        for buffer in self.gradient_partitions.values_mut() {
43            buffer.zero();
44        }
45    }
46
47    /// Increment step counter
48    pub fn step(&mut self) {
49        self.step += 1;
50    }
51
52    /// Get memory usage statistics
53    pub fn memory_usage(&self) -> HashMap<String, usize> {
54        let mut stats = HashMap::new();
55
56        // Calculate optimizer state memory
57        let mut optimizer_memory = 0;
58        for states in self.optimizer_states.values() {
59            for tensor in states.values() {
60                optimizer_memory += tensor.memory_usage();
61            }
62        }
63        stats.insert("optimizer_states".to_string(), optimizer_memory);
64
65        // Calculate gradient memory
66        let mut gradient_memory = 0;
67        for buffer in self.gradient_partitions.values() {
68            gradient_memory += buffer.memory_usage();
69        }
70        stats.insert("gradient_partitions".to_string(), gradient_memory);
71
72        // Calculate parameter memory
73        let mut parameter_memory = 0;
74        for partition in self.parameter_partitions.values() {
75            parameter_memory += partition.memory_usage();
76        }
77        stats.insert("parameter_partitions".to_string(), parameter_memory);
78
79        // Calculate communication buffer memory
80        let mut comm_memory = 0;
81        for tensor in self.communication_buffers.values() {
82            comm_memory += tensor.memory_usage();
83        }
84        stats.insert("communication_buffers".to_string(), comm_memory);
85
86        stats
87    }
88}
89
/// Parameter group for ZeRO optimization.
///
/// Tracks the set of parameter names belonging to one logical group and
/// this rank's locally held partition of those parameters.
#[derive(Debug, Clone)]
pub struct ParameterGroup {
    /// Group name/identifier
    pub name: String,
    /// Parameter names in this group (registry of all members,
    /// including ones not held locally)
    pub parameter_names: Vec<String>,
    /// Local partition of parameters, keyed by parameter name
    pub local_parameters: HashMap<String, Tensor>,
    /// Metadata for parameter partitioning
    pub partition_info: PartitionInfo,
}
102
103impl ParameterGroup {
104    pub fn new(name: String, parameter_names: Vec<String>) -> Self {
105        Self {
106            name,
107            parameter_names,
108            local_parameters: HashMap::new(),
109            partition_info: PartitionInfo::default(),
110        }
111    }
112
113    /// Add a parameter to this group
114    pub fn add_parameter(&mut self, name: String, tensor: Tensor) {
115        self.local_parameters.insert(name.clone(), tensor);
116        if !self.parameter_names.contains(&name) {
117            self.parameter_names.push(name);
118        }
119    }
120
121    /// Get total memory usage of this group
122    pub fn memory_usage(&self) -> usize {
123        self.local_parameters.values().map(|t| t.memory_usage()).sum()
124    }
125}
126
/// Partition information for distributed parameters.
///
/// Describes which slice of a globally sharded tensor this rank owns.
/// Indices are in flattened-element space (see `partition_parameters`,
/// which computes them from the product of the global shape).
#[derive(Debug, Clone)]
pub struct PartitionInfo {
    /// Rank of this partition
    pub rank: usize,
    /// Total number of partitions
    pub world_size: usize,
    /// Start index (inclusive) in the global parameter
    pub start_idx: usize,
    /// End index (exclusive) in the global parameter
    pub end_idx: usize,
    /// Global shape of the full parameter
    pub global_shape: Vec<usize>,
    /// Local shape of this partition's shard
    pub local_shape: Vec<usize>,
}
143
144impl Default for PartitionInfo {
145    fn default() -> Self {
146        Self {
147            rank: 0,
148            world_size: 1,
149            start_idx: 0,
150            end_idx: 0,
151            global_shape: vec![],
152            local_shape: vec![],
153        }
154    }
155}
156
/// Parameter partition for ZeRO Stage 3.
///
/// Holds this rank's shard of one parameter, plus an optional gathered
/// full copy that is materialized on demand and released to save memory.
#[derive(Debug, Clone)]
pub struct ParameterPartition {
    /// Parameter name
    pub name: String,
    /// Local shard of the parameter
    pub local_shard: Tensor,
    /// Partition metadata
    pub partition_info: PartitionInfo,
    /// Whether this parameter is currently gathered
    pub is_gathered: bool,
    /// Full parameter (only valid when `is_gathered == true`)
    pub full_parameter: Option<Tensor>,
}
171
172impl ParameterPartition {
173    pub fn new(name: String, local_shard: Tensor, partition_info: PartitionInfo) -> Self {
174        Self {
175            name,
176            local_shard,
177            partition_info,
178            is_gathered: false,
179            full_parameter: None,
180        }
181    }
182
183    /// Get memory usage of this partition
184    pub fn memory_usage(&self) -> usize {
185        let mut usage = self.local_shard.memory_usage();
186        if let Some(full_param) = &self.full_parameter {
187            usage += full_param.memory_usage();
188        }
189        usage
190    }
191
192    /// Gather full parameter from all partitions
193    pub fn gather(&mut self, mp_context: &ModelParallelContext) -> Result<()> {
194        if self.is_gathered {
195            return Ok(());
196        }
197
198        // Use model parallel context to gather the parameter
199        let full_param =
200            mp_context.all_gather(&trustformers_core::parallel::DistributedTensor::new(
201                self.local_shard.clone(),
202                self.partition_info.global_shape.clone(),
203                trustformers_core::parallel::TensorPartition {
204                    split_dim: 0, // Assume partitioning along first dimension
205                    start_idx: self.partition_info.start_idx,
206                    end_idx: self.partition_info.end_idx,
207                    num_partitions: self.partition_info.world_size,
208                    partition_rank: self.partition_info.rank,
209                },
210                self.partition_info.rank,
211            ))?;
212
213        self.full_parameter = Some(full_param);
214        self.is_gathered = true;
215        Ok(())
216    }
217
218    /// Release gathered parameter to save memory
219    pub fn release(&mut self) {
220        self.full_parameter = None;
221        self.is_gathered = false;
222    }
223}
224
/// Gradient buffer for ZeRO Stage 2+.
///
/// Holds this rank's gradient shard and an optional running sum used for
/// gradient accumulation across micro-steps.
#[derive(Debug, Clone)]
pub struct GradientBuffer {
    /// Buffer name
    pub name: String,
    /// Local gradient shard
    pub local_gradient: Tensor,
    /// Running sum of accumulated gradients (`None` until first accumulate)
    pub accumulated_gradient: Option<Tensor>,
    /// Number of accumulated steps since the last `zero()`
    pub accumulation_steps: usize,
    /// Partition metadata
    pub partition_info: PartitionInfo,
}
239
240impl GradientBuffer {
241    pub fn new(name: String, local_gradient: Tensor, partition_info: PartitionInfo) -> Self {
242        Self {
243            name,
244            local_gradient,
245            accumulated_gradient: None,
246            accumulation_steps: 0,
247            partition_info,
248        }
249    }
250
251    /// Zero the gradient buffer
252    pub fn zero(&mut self) {
253        self.local_gradient = Tensor::zeros(&self.local_gradient.shape()).unwrap();
254        self.accumulated_gradient = None;
255        self.accumulation_steps = 0;
256    }
257
258    /// Accumulate gradient
259    pub fn accumulate(&mut self, gradient: &Tensor) -> Result<()> {
260        if let Some(acc_grad) = &mut self.accumulated_gradient {
261            *acc_grad = acc_grad.add(gradient)?;
262        } else {
263            self.accumulated_gradient = Some(gradient.clone());
264        }
265        self.accumulation_steps += 1;
266        Ok(())
267    }
268
269    /// Get the accumulated gradient (averaged if needed)
270    pub fn get_accumulated(&self) -> Option<Tensor> {
271        if let Some(acc_grad) = &self.accumulated_gradient {
272            if self.accumulation_steps > 1 {
273                acc_grad.scalar_div(self.accumulation_steps as f32).ok()
274            } else {
275                Some(acc_grad.clone())
276            }
277        } else {
278            None
279        }
280    }
281
282    /// Get memory usage of this buffer
283    pub fn memory_usage(&self) -> usize {
284        let mut usage = self.local_gradient.memory_usage();
285        if let Some(acc_grad) = &self.accumulated_gradient {
286            usage += acc_grad.memory_usage();
287        }
288        usage
289    }
290}
291
292/// Partition parameters across devices for ZeRO Stage 3
293pub fn partition_parameters(
294    parameters: &HashMap<String, Tensor>,
295    world_size: usize,
296    rank: usize,
297) -> Result<HashMap<String, ParameterPartition>> {
298    let mut partitions = HashMap::new();
299
300    for (name, param) in parameters {
301        let shape = param.shape();
302        let total_elements = shape.iter().product::<usize>();
303
304        // Calculate partition size
305        let elements_per_rank = total_elements.div_ceil(world_size);
306        let start_idx = rank * elements_per_rank;
307        let end_idx = ((rank + 1) * elements_per_rank).min(total_elements);
308
309        // Create local shard using simplified slicing approach
310        // In a full implementation, this would use proper distributed tensor slicing
311        // For now, we create a scaled-down version to simulate partitioning
312        let local_shard = if world_size == 1 || total_elements <= elements_per_rank {
313            // If single device or small parameter, each rank gets a copy
314            param.clone()
315        } else {
316            // For demonstration, create a smaller tensor that represents the local shard
317            // This simulates the effect of partitioning without complex slicing logic
318            let scale_factor = 1.0 / (world_size as f32);
319
320            param.mul_scalar(scale_factor)?
321        };
322
323        let partition_info = PartitionInfo {
324            rank,
325            world_size,
326            start_idx,
327            end_idx,
328            global_shape: shape.to_vec(),
329            local_shape: local_shard.shape().to_vec(),
330        };
331
332        let partition = ParameterPartition::new(name.clone(), local_shard, partition_info);
333        partitions.insert(name.clone(), partition);
334    }
335
336    Ok(partitions)
337}
338
339/// Gather parameters from all devices
340pub fn gather_parameters(
341    partitions: &mut HashMap<String, ParameterPartition>,
342    mp_context: &ModelParallelContext,
343) -> Result<HashMap<String, Tensor>> {
344    let mut gathered = HashMap::new();
345
346    for (name, partition) in partitions.iter_mut() {
347        partition.gather(mp_context)?;
348        if let Some(full_param) = &partition.full_parameter {
349            gathered.insert(name.clone(), full_param.clone());
350        }
351    }
352
353    Ok(gathered)
354}
355
356/// Partition gradients across devices for ZeRO Stage 2+
357pub fn partition_gradients(
358    gradients: &HashMap<String, Tensor>,
359    world_size: usize,
360    rank: usize,
361) -> Result<HashMap<String, GradientBuffer>> {
362    let mut buffers = HashMap::new();
363
364    for (name, grad) in gradients {
365        let shape = grad.shape();
366        let total_elements = shape.iter().product::<usize>();
367
368        // Calculate partition size
369        let elements_per_rank = total_elements.div_ceil(world_size);
370        let start_idx = rank * elements_per_rank;
371        let end_idx = ((rank + 1) * elements_per_rank).min(total_elements);
372
373        // Create local gradient shard using simplified approach
374        // In a full implementation, this would use proper distributed gradient slicing
375        let local_gradient = if world_size == 1 || total_elements <= elements_per_rank {
376            // If single device or small gradient, each rank gets a copy
377            grad.clone()
378        } else {
379            // For demonstration, create a scaled version to simulate partitioning
380            let scale_factor = 1.0 / (world_size as f32);
381
382            grad.mul_scalar(scale_factor)?
383        };
384
385        let partition_info = PartitionInfo {
386            rank,
387            world_size,
388            start_idx,
389            end_idx,
390            global_shape: shape.to_vec(),
391            local_shape: local_gradient.shape().to_vec(),
392        };
393
394        let buffer = GradientBuffer::new(name.clone(), local_gradient, partition_info);
395        buffers.insert(name.clone(), buffer);
396    }
397
398    Ok(buffers)
399}
400
401/// All-gather gradients from all devices
402pub fn all_gather_gradients(
403    buffers: &HashMap<String, GradientBuffer>,
404    mp_context: &ModelParallelContext,
405) -> Result<HashMap<String, Tensor>> {
406    let mut gathered = HashMap::new();
407
408    for (name, buffer) in buffers {
409        let distributed_tensor = trustformers_core::parallel::DistributedTensor::new(
410            buffer.local_gradient.clone(),
411            buffer.partition_info.global_shape.clone(),
412            trustformers_core::parallel::TensorPartition {
413                split_dim: 0,
414                start_idx: buffer.partition_info.start_idx,
415                end_idx: buffer.partition_info.end_idx,
416                num_partitions: buffer.partition_info.world_size,
417                partition_rank: buffer.partition_info.rank,
418            },
419            buffer.partition_info.rank,
420        );
421
422        let full_gradient = mp_context.all_gather(&distributed_tensor)?;
423        gathered.insert(name.clone(), full_gradient);
424    }
425
426    Ok(gathered)
427}
428
429/// Reduce-scatter gradients across devices
430pub fn reduce_scatter_gradients(
431    gradients: &HashMap<String, Tensor>,
432    mp_context: &ModelParallelContext,
433) -> Result<HashMap<String, Tensor>> {
434    let mut scattered = HashMap::new();
435
436    for (name, grad) in gradients {
437        let scattered_grad = mp_context.reduce_scatter(grad, 0)?;
438        scattered.insert(name.clone(), scattered_grad);
439    }
440
441    Ok(scattered)
442}
443
/// Calculate optimal bucket assignment for gradient communication.
///
/// Greedily packs parameter *indices* (into `parameter_sizes`) into buckets,
/// starting a new bucket whenever adding the next parameter would exceed
/// `target_bucket_size`. A single parameter larger than the target still gets
/// its own bucket, so every index appears exactly once.
pub fn calculate_bucket_size(
    parameter_sizes: &[usize],
    target_bucket_size: usize,
) -> Vec<Vec<usize>> {
    let mut buckets: Vec<Vec<usize>> = Vec::new();
    let mut pending: Vec<usize> = Vec::new();
    let mut pending_bytes: usize = 0;

    for (idx, &bytes) in parameter_sizes.iter().enumerate() {
        // Flush the open bucket when this item would overflow it —
        // unless the bucket is empty (oversized items go in alone).
        if pending_bytes + bytes > target_bucket_size && !pending.is_empty() {
            buckets.push(std::mem::take(&mut pending));
            pending_bytes = 0;
        }
        pending.push(idx);
        pending_bytes += bytes;
    }

    if !pending.is_empty() {
        buckets.push(pending);
    }

    buckets
}
470
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh state starts at step 0 with every map empty.
    #[test]
    fn test_zero_state_creation() {
        let fresh = ZeROState::new();
        assert_eq!(fresh.step, 0);
        assert!(fresh.optimizer_states.is_empty());
        assert!(fresh.gradient_partitions.is_empty());
        assert!(fresh.parameter_partitions.is_empty());
    }

    /// Adding a parameter registers it exactly once and uses memory.
    #[test]
    fn test_parameter_group() {
        let mut group = ParameterGroup::new("test_group".to_string(), vec!["param1".to_string()]);
        group.add_parameter("param1".to_string(), Tensor::ones(&[2, 2]).unwrap());

        assert_eq!(group.parameter_names.len(), 1);
        assert_eq!(group.local_parameters.len(), 1);
        assert!(group.memory_usage() > 0);
    }

    /// A single accumulation produces a retrievable gradient.
    #[test]
    fn test_gradient_buffer() {
        let mut buffer = GradientBuffer::new(
            "test_grad".to_string(),
            Tensor::ones(&[2, 2]).unwrap(),
            PartitionInfo::default(),
        );
        buffer.accumulate(&Tensor::ones(&[2, 2]).unwrap()).unwrap();

        assert_eq!(buffer.accumulation_steps, 1);
        assert!(buffer.get_accumulated().is_some());
    }

    /// Partitioning produces one partition per parameter with the
    /// requested rank/world metadata.
    #[test]
    fn test_partition_parameters() {
        let mut params = HashMap::new();
        params.insert("param1".to_string(), Tensor::ones(&[4, 4]).unwrap());
        params.insert("param2".to_string(), Tensor::ones(&[2, 2]).unwrap());

        let partitions = partition_parameters(&params, 2, 0).unwrap();
        assert_eq!(partitions.len(), 2);
        for partition in partitions.values() {
            assert_eq!(partition.partition_info.world_size, 2);
            assert_eq!(partition.partition_info.rank, 0);
        }
    }

    /// Buckets never exceed the target size unless a single item does.
    #[test]
    fn test_calculate_bucket_size() {
        let sizes = vec![100, 200, 150, 300, 50];
        let buckets = calculate_bucket_size(&sizes, 400);

        assert!(!buckets.is_empty());
        for bucket in &buckets {
            let bucket_total: usize = bucket.iter().map(|&i| sizes[i]).sum();
            assert!(bucket_total <= 400 || bucket.len() == 1);
        }
    }
}