Skip to main content

torsh_distributed/zero_3_cpu_offload/
config.rs

//! ZeRO-3 Configuration and Core Types
//!
//! This module provides the core configuration types and rank-mapping logic for
//! ZeRO-3 CPU offloading. It defines the configuration parameters, compression
//! methods, memory management strategies, distributed rank mapping, and
//! model-parameter bookkeeping.
6
7use std::collections::HashMap;
8
9/// ZeRO-3 CPU offloading configuration
10#[derive(Debug, Clone)]
11pub struct Zero3CpuOffloadConfig {
12    /// Enable parameter offloading to CPU
13    pub offload_params: bool,
14    /// Enable gradient offloading to CPU
15    pub offload_grads: bool,
16    /// Enable optimizer state offloading to CPU
17    pub offload_optimizer_states: bool,
18    /// CPU memory buffer size in bytes
19    pub cpu_memory_budget: usize,
20    /// GPU memory budget for parameters in bytes
21    pub gpu_param_memory_budget: usize,
22    /// Maximum GPU memory in MB (for memory pressure calculation)
23    pub max_gpu_memory_mb: usize,
24    /// Maximum CPU memory in MB (for memory pressure calculation)
25    pub max_cpu_memory_mb: usize,
26    /// Prefetch buffer size (number of parameters to prefetch)
27    pub prefetch_buffer_size: usize,
28    /// Enable asynchronous parameter prefetching
29    pub async_prefetch: bool,
30    /// Enable parameter overlapping (prefetch while computing)
31    pub overlap_computation: bool,
32    /// Pin CPU memory for faster transfers
33    pub pin_cpu_memory: bool,
34    /// Compression for CPU-stored parameters
35    pub cpu_compression: CpuCompressionMethod,
36    /// Automatic memory management strategy
37    pub auto_memory_management: AutoMemoryStrategy,
38}
39
40impl Default for Zero3CpuOffloadConfig {
41    fn default() -> Self {
42        Self {
43            offload_params: true,
44            offload_grads: true,
45            offload_optimizer_states: true,
46            cpu_memory_budget: 32 * 1024 * 1024 * 1024, // 32GB
47            gpu_param_memory_budget: 2 * 1024 * 1024 * 1024, // 2GB
48            max_gpu_memory_mb: 8 * 1024,                // 8GB
49            max_cpu_memory_mb: 64 * 1024,               // 64GB
50            prefetch_buffer_size: 16,
51            async_prefetch: true,
52            overlap_computation: true,
53            pin_cpu_memory: true,
54            cpu_compression: CpuCompressionMethod::None,
55            auto_memory_management: AutoMemoryStrategy::Aggressive,
56        }
57    }
58}
59
60impl Zero3CpuOffloadConfig {
61    /// Create a new configuration with custom settings
62    pub fn new() -> Self {
63        Self::default()
64    }
65
66    /// Set parameter offloading option
67    pub fn with_offload_params(mut self, offload: bool) -> Self {
68        self.offload_params = offload;
69        self
70    }
71
72    /// Set gradient offloading option
73    pub fn with_offload_grads(mut self, offload: bool) -> Self {
74        self.offload_grads = offload;
75        self
76    }
77
78    /// Set optimizer state offloading option
79    pub fn with_offload_optimizer_states(mut self, offload: bool) -> Self {
80        self.offload_optimizer_states = offload;
81        self
82    }
83
84    /// Set CPU memory budget
85    pub fn with_cpu_memory_budget(mut self, budget: usize) -> Self {
86        self.cpu_memory_budget = budget;
87        self
88    }
89
90    /// Set GPU parameter memory budget
91    pub fn with_gpu_param_memory_budget(mut self, budget: usize) -> Self {
92        self.gpu_param_memory_budget = budget;
93        self
94    }
95
96    /// Set prefetch buffer size
97    pub fn with_prefetch_buffer_size(mut self, size: usize) -> Self {
98        self.prefetch_buffer_size = size;
99        self
100    }
101
102    /// Set compression method
103    pub fn with_compression(mut self, compression: CpuCompressionMethod) -> Self {
104        self.cpu_compression = compression;
105        self
106    }
107
108    /// Set memory management strategy
109    pub fn with_memory_strategy(mut self, strategy: AutoMemoryStrategy) -> Self {
110        self.auto_memory_management = strategy;
111        self
112    }
113
114    /// Enable asynchronous prefetching
115    pub fn with_async_prefetch(mut self, async_prefetch: bool) -> Self {
116        self.async_prefetch = async_prefetch;
117        self
118    }
119
120    /// Enable computation overlap
121    pub fn with_overlap_computation(mut self, overlap: bool) -> Self {
122        self.overlap_computation = overlap;
123        self
124    }
125
126    /// Enable CPU memory pinning
127    pub fn with_pin_cpu_memory(mut self, pin: bool) -> Self {
128        self.pin_cpu_memory = pin;
129        self
130    }
131
132    /// Validate the configuration settings
133    pub fn validate(&self) -> Result<(), String> {
134        if self.cpu_memory_budget == 0 {
135            return Err("CPU memory budget cannot be zero".to_string());
136        }
137
138        if self.gpu_param_memory_budget == 0 {
139            return Err("GPU parameter memory budget cannot be zero".to_string());
140        }
141
142        if self.prefetch_buffer_size == 0 {
143            return Err("Prefetch buffer size cannot be zero".to_string());
144        }
145
146        if self.max_gpu_memory_mb == 0 {
147            return Err("Maximum GPU memory cannot be zero".to_string());
148        }
149
150        if self.max_cpu_memory_mb == 0 {
151            return Err("Maximum CPU memory cannot be zero".to_string());
152        }
153
154        // Check that GPU memory budget doesn't exceed maximum GPU memory
155        let gpu_budget_mb = self.gpu_param_memory_budget / (1024 * 1024);
156        if gpu_budget_mb > self.max_gpu_memory_mb {
157            return Err(format!(
158                "GPU parameter memory budget ({} MB) exceeds maximum GPU memory ({} MB)",
159                gpu_budget_mb, self.max_gpu_memory_mb
160            ));
161        }
162
163        // Check that CPU memory budget doesn't exceed maximum CPU memory
164        let cpu_budget_mb = self.cpu_memory_budget / (1024 * 1024);
165        if cpu_budget_mb > self.max_cpu_memory_mb {
166            return Err(format!(
167                "CPU memory budget ({} MB) exceeds maximum CPU memory ({} MB)",
168                cpu_budget_mb, self.max_cpu_memory_mb
169            ));
170        }
171
172        Ok(())
173    }
174
175    /// Get the effective compression ratio for CPU storage
176    pub fn compression_ratio(&self) -> f32 {
177        match self.cpu_compression {
178            CpuCompressionMethod::None => 1.0,
179            CpuCompressionMethod::FP16 => 0.5,
180            CpuCompressionMethod::BF16 => 0.5,
181            CpuCompressionMethod::INT8 => 0.25,
182            CpuCompressionMethod::Quantization => 0.25,
183            CpuCompressionMethod::LosslessCompression => 0.7, // Typical lossless compression ratio
184        }
185    }
186
187    /// Get the effective CPU memory budget after compression
188    pub fn effective_cpu_memory_budget(&self) -> usize {
189        (self.cpu_memory_budget as f32 / self.compression_ratio()) as usize
190    }
191}
192
/// Compression applied to data kept in CPU memory.
///
/// Every variant maps to a fixed storage ratio via [`CpuCompressionMethod::ratio`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CpuCompressionMethod {
    /// Store data as-is
    None,
    /// 16-bit floating point
    FP16,
    /// BFloat16
    BF16,
    /// 8-bit integer quantization
    INT8,
    /// Advanced quantization schemes
    Quantization,
    /// Lossless compression (e.g., LZ4, Snappy)
    LosslessCompression,
}

impl CpuCompressionMethod {
    /// Fraction of the original size occupied after compression.
    ///
    /// `1.0` means no reduction. The lossless figure is a typical estimate rather
    /// than a guarantee.
    pub fn ratio(&self) -> f32 {
        match *self {
            Self::None => 1.0,
            Self::FP16 | Self::BF16 => 0.5,
            Self::INT8 | Self::Quantization => 0.25,
            Self::LosslessCompression => 0.7,
        }
    }

    /// `true` for the reduced-precision and quantizing methods, `false` for
    /// `None` and `LosslessCompression`.
    pub fn is_lossy(&self) -> bool {
        // Exhaustive on purpose: adding a variant forces a decision here.
        match *self {
            Self::None | Self::LosslessCompression => false,
            Self::FP16 | Self::BF16 | Self::INT8 | Self::Quantization => true,
        }
    }

    /// Short human-readable name of the method.
    pub fn description(&self) -> &'static str {
        match *self {
            Self::None => "No compression",
            Self::FP16 => "16-bit floating point",
            Self::BF16 => "BFloat16",
            Self::INT8 => "8-bit integer quantization",
            Self::Quantization => "Advanced quantization",
            Self::LosslessCompression => "Lossless compression",
        }
    }
}
246
/// Automatic memory management strategies, listed from least to most offloading.
///
/// Each step down the list lowers [`AutoMemoryStrategy::pressure_threshold`] and
/// raises [`AutoMemoryStrategy::aggressiveness`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AutoMemoryStrategy {
    /// Minimal offloading
    Conservative,
    /// Moderate offloading based on memory pressure
    Balanced,
    /// Maximize CPU utilization
    Aggressive,
    /// Offload everything possible
    Extreme,
}

impl AutoMemoryStrategy {
    /// Single table holding `(pressure_threshold, aggressiveness, description)`
    /// for every strategy, so the three accessors cannot drift apart.
    fn profile(self) -> (f32, f32, &'static str) {
        match self {
            Self::Conservative => (0.9, 0.2, "Conservative - minimal offloading"),
            Self::Balanced => (0.75, 0.5, "Balanced - moderate offloading"),
            Self::Aggressive => (0.6, 0.8, "Aggressive - maximize CPU utilization"),
            Self::Extreme => (0.4, 1.0, "Extreme - offload everything possible"),
        }
    }

    /// Memory-pressure fraction at which offloading is triggered.
    pub fn pressure_threshold(&self) -> f32 {
        self.profile().0
    }

    /// Offloading aggressiveness factor in `[0.0, 1.0]`.
    pub fn aggressiveness(&self) -> f32 {
        self.profile().1
    }

    /// Human-readable summary of the strategy.
    pub fn description(&self) -> &'static str {
        self.profile().2
    }
}
291
/// ZeRO-3 rank mapping for parameter partitioning.
///
/// Ownership is round-robin: global index `i` belongs to rank `i % world_size`.
/// Within a rank, the owned indices are numbered `0, 1, 2, ...` in increasing
/// global order.
#[derive(Debug, Clone)]
pub struct Zero3RankMapping {
    rank: usize,
    world_size: usize,
}

impl Zero3RankMapping {
    /// Create a mapping for `rank` within a group of `world_size` ranks.
    ///
    /// # Panics
    ///
    /// Panics if `rank >= world_size`, which also rules out `world_size == 0`.
    pub fn new(rank: usize, world_size: usize) -> Self {
        assert!(rank < world_size, "Rank must be less than world size");
        Self { rank, world_size }
    }

    /// This process's rank.
    pub fn rank(&self) -> usize {
        self.rank
    }

    /// Number of ranks in the group.
    pub fn world_size(&self) -> usize {
        self.world_size
    }

    /// Whether partition `partition_idx` is assigned to this rank.
    pub fn owns_partition(&self, partition_idx: usize) -> bool {
        self.get_parameter_owner(partition_idx) == self.rank
    }

    /// Rank that owns parameter `param_idx`.
    pub fn get_parameter_owner(&self, param_idx: usize) -> usize {
        param_idx % self.world_size
    }

    /// All partitions in `0..total_partitions` owned by this rank, ascending.
    pub fn owned_partitions(&self, total_partitions: usize) -> Vec<usize> {
        // The owned indices are exactly rank, rank + world_size, ..., so stepping
        // visits only those. `world_size >= 1` is guaranteed by `new`.
        (self.rank..total_partitions)
            .step_by(self.world_size)
            .collect()
    }

    /// Number of partitions in `0..total_partitions` owned by this rank.
    pub fn owned_partition_count(&self, total_partitions: usize) -> usize {
        let full_rounds = total_partitions / self.world_size;
        let leftover = total_partitions % self.world_size;
        // Ranks below `leftover` pick up one extra partition from the final partial round.
        full_rounds + usize::from(self.rank < leftover)
    }

    /// Local index of `global_idx` on this rank, or `None` if another rank owns it.
    pub fn global_to_local_partition(&self, global_idx: usize) -> Option<usize> {
        self.owns_partition(global_idx)
            .then(|| global_idx / self.world_size)
    }

    /// Global index corresponding to this rank's `local_idx`-th partition.
    pub fn local_to_global_partition(&self, local_idx: usize) -> usize {
        self.rank + local_idx * self.world_size
    }

    /// Sorted, de-duplicated owner ranks of `param_indices`.
    pub fn communication_group(&self, param_indices: &[usize]) -> Vec<usize> {
        // A BTreeSet both removes duplicates and iterates in ascending order.
        param_indices
            .iter()
            .map(|&idx| self.get_parameter_owner(idx))
            .collect::<std::collections::BTreeSet<usize>>()
            .into_iter()
            .collect()
    }
}
370
/// Model parameters for ZeRO-3 initialization.
///
/// Tracks parameter tensors by name, together with running element and byte totals.
#[derive(Debug, Default)]
pub struct ModelParameters {
    /// Total number of scalar elements across all registered parameters
    pub parameter_count: usize,
    /// Parameter names in registration order
    pub parameter_names: Vec<String>,
    /// Shape of each parameter, keyed by name
    pub parameter_shapes: HashMap<String, Vec<usize>>,
    /// Total storage in bytes, accumulated from each parameter's element size
    pub total_memory_bytes: usize,
}

impl ModelParameters {
    /// Create an empty collection.
    pub fn new() -> Self {
        Self::default()
    }

    /// Register a parameter stored as `f32` (4 bytes per element).
    ///
    /// Registering a name that already exists overwrites its stored shape, but the
    /// running totals still grow and the name is appended again.
    pub fn add_parameter(&mut self, name: String, shape: Vec<usize>) {
        self.add_parameter_with_size(name, shape, std::mem::size_of::<f32>());
    }

    /// Whether a parameter named `name` has been registered.
    pub fn has_parameter(&self, name: &str) -> bool {
        self.parameter_shapes.contains_key(name)
    }

    /// Register a parameter whose elements occupy `element_size` bytes each.
    ///
    /// The duplicate-name behaviour is the same as for [`ModelParameters::add_parameter`].
    pub fn add_parameter_with_size(
        &mut self,
        name: String,
        shape: Vec<usize>,
        element_size: usize,
    ) {
        let elements: usize = shape.iter().product();
        self.parameter_count += elements;
        self.total_memory_bytes += elements * element_size;
        self.parameter_names.push(name.clone());
        self.parameter_shapes.insert(name, shape);
    }

    /// Shape of parameter `name`, if registered.
    pub fn get_parameter_shape(&self, name: &str) -> Option<&Vec<usize>> {
        self.parameter_shapes.get(name)
    }

    /// Element count (product of the shape) of parameter `name`, if registered.
    pub fn get_parameter_size(&self, name: &str) -> Option<usize> {
        let shape = self.parameter_shapes.get(name)?;
        Some(shape.iter().product())
    }

    /// Number of registered parameter names (not elements; see `parameter_count`).
    pub fn total_parameters(&self) -> usize {
        self.parameter_names.len()
    }

    /// `total_memory_bytes` expressed in MB (1 MB = 1024 * 1024 bytes).
    pub fn memory_usage_mb(&self) -> f64 {
        self.total_memory_bytes as f64 / 1_048_576.0
    }

    /// Summary statistics over per-parameter element counts.
    ///
    /// Returns all-zero statistics when no parameter is registered. Sizes come from
    /// `parameter_shapes`, so a name registered twice contributes one size entry
    /// but is counted twice in `total_parameters`.
    pub fn get_statistics(&self) -> ModelParameterStats {
        if self.parameter_names.is_empty() {
            return ModelParameterStats::default();
        }

        let mut element_counts = self
            .parameter_shapes
            .values()
            .map(|shape| shape.iter().product::<usize>())
            .collect::<Vec<_>>();
        element_counts.sort_unstable();

        let n = element_counts.len();
        let mid = n / 2;
        let total_elements: usize = element_counts.iter().sum();
        let median = if n % 2 == 1 {
            element_counts[mid] as f64
        } else {
            (element_counts[mid - 1] + element_counts[mid]) as f64 / 2.0
        };

        ModelParameterStats {
            total_parameters: self.parameter_names.len(),
            total_elements,
            mean_parameter_size: total_elements as f64 / n as f64,
            median_parameter_size: median,
            min_parameter_size: element_counts.first().copied().unwrap_or(0),
            max_parameter_size: element_counts.last().copied().unwrap_or(0),
            total_memory_bytes: self.total_memory_bytes,
        }
    }
}

/// Summary statistics about model parameters; `Default` yields all zeros.
#[derive(Debug, Clone, Default)]
pub struct ModelParameterStats {
    /// Number of registered parameter names
    pub total_parameters: usize,
    /// Sum of element counts over distinct parameter names
    pub total_elements: usize,
    /// Mean element count per distinct parameter
    pub mean_parameter_size: f64,
    /// Median element count per distinct parameter
    pub median_parameter_size: f64,
    /// Smallest element count
    pub min_parameter_size: usize,
    /// Largest element count
    pub max_parameter_size: usize,
    /// Copied from [`ModelParameters::total_memory_bytes`]
    pub total_memory_bytes: usize,
}
505
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_zero3_config_default() {
        let cfg = Zero3CpuOffloadConfig::default();
        let enabled_flags = [
            cfg.offload_params,
            cfg.offload_grads,
            cfg.offload_optimizer_states,
            cfg.async_prefetch,
        ];
        assert!(enabled_flags.iter().all(|&on| on));
        assert_eq!(cfg.cpu_compression, CpuCompressionMethod::None);
        assert_eq!(cfg.auto_memory_management, AutoMemoryStrategy::Aggressive);
    }

    #[test]
    fn test_zero3_config_builder() {
        let cfg = Zero3CpuOffloadConfig::new()
            .with_offload_params(false)
            .with_compression(CpuCompressionMethod::FP16)
            .with_memory_strategy(AutoMemoryStrategy::Conservative)
            .with_prefetch_buffer_size(32);

        assert!(!cfg.offload_params);
        assert_eq!(cfg.cpu_compression, CpuCompressionMethod::FP16);
        assert_eq!(cfg.auto_memory_management, AutoMemoryStrategy::Conservative);
        assert_eq!(cfg.prefetch_buffer_size, 32);
    }

    #[test]
    fn test_zero3_config_validation() {
        let base = Zero3CpuOffloadConfig::default();
        assert!(base.validate().is_ok());

        let zero_cpu = Zero3CpuOffloadConfig {
            cpu_memory_budget: 0,
            ..base.clone()
        };
        assert!(zero_cpu.validate().is_err());

        let zero_gpu = Zero3CpuOffloadConfig {
            gpu_param_memory_budget: 0,
            ..base.clone()
        };
        assert!(zero_gpu.validate().is_err());
    }

    #[test]
    fn test_compression_methods() {
        let expected_ratios: [(CpuCompressionMethod, f32); 3] = [
            (CpuCompressionMethod::None, 1.0),
            (CpuCompressionMethod::FP16, 0.5),
            (CpuCompressionMethod::INT8, 0.25),
        ];
        for &(method, ratio) in expected_ratios.iter() {
            assert_eq!(method.ratio(), ratio, "{:?}", method);
        }

        assert!(!CpuCompressionMethod::None.is_lossy());
        assert!(CpuCompressionMethod::FP16.is_lossy());
        assert!(!CpuCompressionMethod::LosslessCompression.is_lossy());
    }

    #[test]
    fn test_memory_strategies() {
        let conservative = AutoMemoryStrategy::Conservative;
        assert_eq!(conservative.pressure_threshold(), 0.9);
        assert_eq!(conservative.aggressiveness(), 0.2);

        assert_eq!(AutoMemoryStrategy::Aggressive.pressure_threshold(), 0.6);
        assert_eq!(AutoMemoryStrategy::Extreme.aggressiveness(), 1.0);
    }

    #[test]
    fn test_rank_mapping() {
        let mapping = Zero3RankMapping::new(1, 4);
        assert_eq!((mapping.rank(), mapping.world_size()), (1, 4));

        // Rank 1 of 4 owns exactly the indices congruent to 1 mod 4.
        for &idx in &[1usize, 5] {
            assert!(mapping.owns_partition(idx), "rank 1 should own {}", idx);
        }
        for &idx in &[0usize, 2] {
            assert!(!mapping.owns_partition(idx), "rank 1 should not own {}", idx);
        }

        assert_eq!(mapping.get_parameter_owner(5), 1);
        assert_eq!(mapping.get_parameter_owner(8), 0);

        assert_eq!(mapping.owned_partitions(10), vec![1, 5, 9]);
        assert_eq!(mapping.owned_partition_count(10), 3);
        assert_eq!(mapping.owned_partition_count(8), 2);
    }

    #[test]
    fn test_model_parameters() {
        let mut params = ModelParameters::new();
        params.add_parameter("layer1.weight".to_string(), vec![100, 50]);
        params.add_parameter("layer1.bias".to_string(), vec![50]);

        assert_eq!(params.total_parameters(), 2);
        assert_eq!(params.parameter_count, 100 * 50 + 50);
        assert_eq!(params.get_parameter_size("layer1.weight"), Some(5000));
        assert_eq!(params.get_parameter_size("layer1.bias"), Some(50));

        let stats = params.get_statistics();
        assert_eq!(stats.total_parameters, 2);
        assert_eq!(stats.total_elements, 5050);
        assert_eq!(
            (stats.min_parameter_size, stats.max_parameter_size),
            (50, 5000)
        );
    }

    #[test]
    fn test_rank_mapping_communication_group() {
        let mapping = Zero3RankMapping::new(1, 4);
        // Indices 0, 4, 8 live on rank 0; indices 1, 5, 9 live on rank 1.
        let group = mapping.communication_group(&[0, 1, 4, 5, 8, 9]);
        assert_eq!(group, vec![0, 1]);
    }

    #[test]
    fn test_effective_cpu_memory_budget() {
        let budget_for = |compression: CpuCompressionMethod| {
            Zero3CpuOffloadConfig::new()
                .with_cpu_memory_budget(1000)
                .with_compression(compression)
                .effective_cpu_memory_budget()
        };

        // FP16 halves storage (ratio 0.5), so twice as much data fits in the budget.
        assert_eq!(budget_for(CpuCompressionMethod::FP16), 2000);
        assert_eq!(budget_for(CpuCompressionMethod::None), 1000);
    }
}