1use std::collections::HashMap;
8
9#[derive(Debug, Clone)]
11pub struct Zero3CpuOffloadConfig {
12 pub offload_params: bool,
14 pub offload_grads: bool,
16 pub offload_optimizer_states: bool,
18 pub cpu_memory_budget: usize,
20 pub gpu_param_memory_budget: usize,
22 pub max_gpu_memory_mb: usize,
24 pub max_cpu_memory_mb: usize,
26 pub prefetch_buffer_size: usize,
28 pub async_prefetch: bool,
30 pub overlap_computation: bool,
32 pub pin_cpu_memory: bool,
34 pub cpu_compression: CpuCompressionMethod,
36 pub auto_memory_management: AutoMemoryStrategy,
38}
39
40impl Default for Zero3CpuOffloadConfig {
41 fn default() -> Self {
42 Self {
43 offload_params: true,
44 offload_grads: true,
45 offload_optimizer_states: true,
46 cpu_memory_budget: 32 * 1024 * 1024 * 1024, gpu_param_memory_budget: 2 * 1024 * 1024 * 1024, max_gpu_memory_mb: 8 * 1024, max_cpu_memory_mb: 64 * 1024, prefetch_buffer_size: 16,
51 async_prefetch: true,
52 overlap_computation: true,
53 pin_cpu_memory: true,
54 cpu_compression: CpuCompressionMethod::None,
55 auto_memory_management: AutoMemoryStrategy::Aggressive,
56 }
57 }
58}
59
60impl Zero3CpuOffloadConfig {
61 pub fn new() -> Self {
63 Self::default()
64 }
65
66 pub fn with_offload_params(mut self, offload: bool) -> Self {
68 self.offload_params = offload;
69 self
70 }
71
72 pub fn with_offload_grads(mut self, offload: bool) -> Self {
74 self.offload_grads = offload;
75 self
76 }
77
78 pub fn with_offload_optimizer_states(mut self, offload: bool) -> Self {
80 self.offload_optimizer_states = offload;
81 self
82 }
83
84 pub fn with_cpu_memory_budget(mut self, budget: usize) -> Self {
86 self.cpu_memory_budget = budget;
87 self
88 }
89
90 pub fn with_gpu_param_memory_budget(mut self, budget: usize) -> Self {
92 self.gpu_param_memory_budget = budget;
93 self
94 }
95
96 pub fn with_prefetch_buffer_size(mut self, size: usize) -> Self {
98 self.prefetch_buffer_size = size;
99 self
100 }
101
102 pub fn with_compression(mut self, compression: CpuCompressionMethod) -> Self {
104 self.cpu_compression = compression;
105 self
106 }
107
108 pub fn with_memory_strategy(mut self, strategy: AutoMemoryStrategy) -> Self {
110 self.auto_memory_management = strategy;
111 self
112 }
113
114 pub fn with_async_prefetch(mut self, async_prefetch: bool) -> Self {
116 self.async_prefetch = async_prefetch;
117 self
118 }
119
120 pub fn with_overlap_computation(mut self, overlap: bool) -> Self {
122 self.overlap_computation = overlap;
123 self
124 }
125
126 pub fn with_pin_cpu_memory(mut self, pin: bool) -> Self {
128 self.pin_cpu_memory = pin;
129 self
130 }
131
132 pub fn validate(&self) -> Result<(), String> {
134 if self.cpu_memory_budget == 0 {
135 return Err("CPU memory budget cannot be zero".to_string());
136 }
137
138 if self.gpu_param_memory_budget == 0 {
139 return Err("GPU parameter memory budget cannot be zero".to_string());
140 }
141
142 if self.prefetch_buffer_size == 0 {
143 return Err("Prefetch buffer size cannot be zero".to_string());
144 }
145
146 if self.max_gpu_memory_mb == 0 {
147 return Err("Maximum GPU memory cannot be zero".to_string());
148 }
149
150 if self.max_cpu_memory_mb == 0 {
151 return Err("Maximum CPU memory cannot be zero".to_string());
152 }
153
154 let gpu_budget_mb = self.gpu_param_memory_budget / (1024 * 1024);
156 if gpu_budget_mb > self.max_gpu_memory_mb {
157 return Err(format!(
158 "GPU parameter memory budget ({} MB) exceeds maximum GPU memory ({} MB)",
159 gpu_budget_mb, self.max_gpu_memory_mb
160 ));
161 }
162
163 let cpu_budget_mb = self.cpu_memory_budget / (1024 * 1024);
165 if cpu_budget_mb > self.max_cpu_memory_mb {
166 return Err(format!(
167 "CPU memory budget ({} MB) exceeds maximum CPU memory ({} MB)",
168 cpu_budget_mb, self.max_cpu_memory_mb
169 ));
170 }
171
172 Ok(())
173 }
174
175 pub fn compression_ratio(&self) -> f32 {
177 match self.cpu_compression {
178 CpuCompressionMethod::None => 1.0,
179 CpuCompressionMethod::FP16 => 0.5,
180 CpuCompressionMethod::BF16 => 0.5,
181 CpuCompressionMethod::INT8 => 0.25,
182 CpuCompressionMethod::Quantization => 0.25,
183 CpuCompressionMethod::LosslessCompression => 0.7, }
185 }
186
187 pub fn effective_cpu_memory_budget(&self) -> usize {
189 (self.cpu_memory_budget as f32 / self.compression_ratio()) as usize
190 }
191}
192
/// Encoding applied to data held in CPU memory.
///
/// The nominal stored-size ratio of each variant is reported by
/// `CpuCompressionMethod::ratio`. `Hash` is derived so methods can key sets and maps.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CpuCompressionMethod {
    /// Store data without compression.
    None,
    /// 16-bit floating point.
    FP16,
    /// BFloat16.
    BF16,
    /// 8-bit integer quantization.
    INT8,
    /// Advanced quantization.
    /// NOTE(review): the concrete scheme is not defined in this file.
    Quantization,
    /// Lossless compression.
    LosslessCompression,
}
209
210impl CpuCompressionMethod {
211 pub fn ratio(&self) -> f32 {
213 match self {
214 CpuCompressionMethod::None => 1.0,
215 CpuCompressionMethod::FP16 => 0.5,
216 CpuCompressionMethod::BF16 => 0.5,
217 CpuCompressionMethod::INT8 => 0.25,
218 CpuCompressionMethod::Quantization => 0.25,
219 CpuCompressionMethod::LosslessCompression => 0.7,
220 }
221 }
222
223 pub fn is_lossy(&self) -> bool {
225 matches!(
226 self,
227 CpuCompressionMethod::FP16
228 | CpuCompressionMethod::BF16
229 | CpuCompressionMethod::INT8
230 | CpuCompressionMethod::Quantization
231 )
232 }
233
234 pub fn description(&self) -> &'static str {
236 match self {
237 CpuCompressionMethod::None => "No compression",
238 CpuCompressionMethod::FP16 => "16-bit floating point",
239 CpuCompressionMethod::BF16 => "BFloat16",
240 CpuCompressionMethod::INT8 => "8-bit integer quantization",
241 CpuCompressionMethod::Quantization => "Advanced quantization",
242 CpuCompressionMethod::LosslessCompression => "Lossless compression",
243 }
244 }
245}
246
/// How eagerly memory is offloaded automatically.
///
/// Variants are ordered from least to most aggressive. `Hash` is derived so
/// strategies can key sets and maps.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AutoMemoryStrategy {
    /// Minimal offloading.
    Conservative,
    /// Moderate offloading.
    Balanced,
    /// Maximize CPU utilization.
    Aggressive,
    /// Offload everything possible.
    Extreme,
}
259
260impl AutoMemoryStrategy {
261 pub fn pressure_threshold(&self) -> f32 {
263 match self {
264 AutoMemoryStrategy::Conservative => 0.9, AutoMemoryStrategy::Balanced => 0.75, AutoMemoryStrategy::Aggressive => 0.6, AutoMemoryStrategy::Extreme => 0.4, }
269 }
270
271 pub fn aggressiveness(&self) -> f32 {
273 match self {
274 AutoMemoryStrategy::Conservative => 0.2,
275 AutoMemoryStrategy::Balanced => 0.5,
276 AutoMemoryStrategy::Aggressive => 0.8,
277 AutoMemoryStrategy::Extreme => 1.0,
278 }
279 }
280
281 pub fn description(&self) -> &'static str {
283 match self {
284 AutoMemoryStrategy::Conservative => "Conservative - minimal offloading",
285 AutoMemoryStrategy::Balanced => "Balanced - moderate offloading",
286 AutoMemoryStrategy::Aggressive => "Aggressive - maximize CPU utilization",
287 AutoMemoryStrategy::Extreme => "Extreme - offload everything possible",
288 }
289 }
290}
291
/// Round-robin assignment of partitions/parameters to the ranks of a process group,
/// as seen from one rank.
///
/// Fields are private so every instance goes through `Zero3RankMapping::new`,
/// which enforces `rank < world_size`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Zero3RankMapping {
    /// This process's rank; always `< world_size`.
    rank: usize,
    /// Number of ranks in the group; always non-zero.
    world_size: usize,
}
298
299impl Zero3RankMapping {
300 pub fn new(rank: usize, world_size: usize) -> Self {
302 assert!(rank < world_size, "Rank must be less than world size");
303 Self { rank, world_size }
304 }
305
306 pub fn rank(&self) -> usize {
308 self.rank
309 }
310
311 pub fn world_size(&self) -> usize {
313 self.world_size
314 }
315
316 pub fn owns_partition(&self, partition_idx: usize) -> bool {
318 partition_idx % self.world_size == self.rank
319 }
320
321 pub fn get_parameter_owner(&self, param_idx: usize) -> usize {
323 param_idx % self.world_size
324 }
325
326 pub fn owned_partitions(&self, total_partitions: usize) -> Vec<usize> {
328 (0..total_partitions)
329 .filter(|&i| self.owns_partition(i))
330 .collect()
331 }
332
333 pub fn owned_partition_count(&self, total_partitions: usize) -> usize {
335 let base_count = total_partitions / self.world_size;
336 let remainder = total_partitions % self.world_size;
337
338 if self.rank < remainder {
339 base_count + 1
340 } else {
341 base_count
342 }
343 }
344
345 pub fn global_to_local_partition(&self, global_idx: usize) -> Option<usize> {
347 if self.owns_partition(global_idx) {
348 Some(global_idx / self.world_size)
349 } else {
350 None
351 }
352 }
353
354 pub fn local_to_global_partition(&self, local_idx: usize) -> usize {
356 local_idx * self.world_size + self.rank
357 }
358
359 pub fn communication_group(&self, param_indices: &[usize]) -> Vec<usize> {
361 let mut ranks = std::collections::HashSet::new();
362 for ¶m_idx in param_indices {
363 ranks.insert(self.get_parameter_owner(param_idx));
364 }
365 let mut result: Vec<usize> = ranks.into_iter().collect();
366 result.sort();
367 result
368 }
369}
370
/// Registry of named model parameters and their shapes, with running totals.
///
/// `Clone`, `PartialEq` and `Eq` are derived so registries can be copied and
/// compared (e.g. in tests).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ModelParameters {
    /// Total number of scalar elements registered: the sum of each shape's product
    /// (not the number of named parameters).
    pub parameter_count: usize,
    /// Parameter names in registration order; a name registered twice appears twice.
    pub parameter_names: Vec<String>,
    /// Shape of each parameter, keyed by name.
    pub parameter_shapes: HashMap<String, Vec<usize>>,
    /// Accumulated storage in bytes (element count times element size per registration).
    pub total_memory_bytes: usize,
}
379
380impl ModelParameters {
381 pub fn new() -> Self {
383 Self {
384 parameter_count: 0,
385 parameter_names: Vec::new(),
386 parameter_shapes: HashMap::new(),
387 total_memory_bytes: 0,
388 }
389 }
390
391 pub fn add_parameter(&mut self, name: String, shape: Vec<usize>) {
393 let param_size = shape.iter().product::<usize>();
394 self.parameter_count += param_size;
395 self.total_memory_bytes += param_size * std::mem::size_of::<f32>();
396 self.parameter_shapes.insert(name.clone(), shape);
397 self.parameter_names.push(name);
398 }
399
400 pub fn has_parameter(&self, name: &str) -> bool {
402 self.parameter_shapes.contains_key(name)
403 }
404
405 pub fn add_parameter_with_size(
407 &mut self,
408 name: String,
409 shape: Vec<usize>,
410 element_size: usize,
411 ) {
412 let param_size = shape.iter().product::<usize>();
413 self.parameter_count += param_size;
414 self.total_memory_bytes += param_size * element_size;
415 self.parameter_shapes.insert(name.clone(), shape);
416 self.parameter_names.push(name);
417 }
418
419 pub fn get_parameter_shape(&self, name: &str) -> Option<&Vec<usize>> {
421 self.parameter_shapes.get(name)
422 }
423
424 pub fn get_parameter_size(&self, name: &str) -> Option<usize> {
426 self.parameter_shapes
427 .get(name)
428 .map(|shape| shape.iter().product::<usize>())
429 }
430
431 pub fn total_parameters(&self) -> usize {
433 self.parameter_names.len()
434 }
435
436 pub fn memory_usage_mb(&self) -> f64 {
438 self.total_memory_bytes as f64 / (1024.0 * 1024.0)
439 }
440
441 pub fn get_statistics(&self) -> ModelParameterStats {
443 if self.parameter_names.is_empty() {
444 return ModelParameterStats::default();
445 }
446
447 let mut sizes: Vec<usize> = self
448 .parameter_shapes
449 .values()
450 .map(|shape| shape.iter().product::<usize>())
451 .collect();
452 sizes.sort();
453
454 let total_elements = sizes.iter().sum::<usize>();
455 let mean_size = total_elements as f64 / sizes.len() as f64;
456 let median_size = if sizes.len() % 2 == 0 {
457 (sizes[sizes.len() / 2 - 1] + sizes[sizes.len() / 2]) as f64 / 2.0
458 } else {
459 sizes[sizes.len() / 2] as f64
460 };
461
462 ModelParameterStats {
463 total_parameters: self.parameter_names.len(),
464 total_elements,
465 mean_parameter_size: mean_size,
466 median_parameter_size: median_size,
467 min_parameter_size: *sizes.first().unwrap_or(&0),
468 max_parameter_size: *sizes.last().unwrap_or(&0),
469 total_memory_bytes: self.total_memory_bytes,
470 }
471 }
472}
473
474impl Default for ModelParameters {
475 fn default() -> Self {
476 Self::new()
477 }
478}
479
/// Summary of per-parameter element counts, produced by
/// `ModelParameters::get_statistics`.
///
/// `PartialEq` is derived for comparisons; `Eq` is not possible because of the
/// `f64` fields.
#[derive(Debug, Clone, PartialEq)]
pub struct ModelParameterStats {
    /// Number of registrations (named parameters).
    pub total_parameters: usize,
    /// Sum of element counts across all parameters.
    pub total_elements: usize,
    /// Mean element count per parameter.
    pub mean_parameter_size: f64,
    /// Median element count per parameter.
    pub median_parameter_size: f64,
    /// Smallest element count of any parameter.
    pub min_parameter_size: usize,
    /// Largest element count of any parameter.
    pub max_parameter_size: usize,
    /// Total storage in bytes, as accumulated by the registry.
    pub total_memory_bytes: usize,
}
491
492impl Default for ModelParameterStats {
493 fn default() -> Self {
494 Self {
495 total_parameters: 0,
496 total_elements: 0,
497 mean_parameter_size: 0.0,
498 median_parameter_size: 0.0,
499 min_parameter_size: 0,
500 max_parameter_size: 0,
501 total_memory_bytes: 0,
502 }
503 }
504}
505
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_zero3_config_default() {
        let config = Zero3CpuOffloadConfig::default();
        assert!(config.offload_params);
        assert!(config.offload_grads);
        assert!(config.offload_optimizer_states);
        assert!(config.async_prefetch);
        assert_eq!(config.cpu_compression, CpuCompressionMethod::None);
        assert_eq!(
            config.auto_memory_management,
            AutoMemoryStrategy::Aggressive
        );
    }

    #[test]
    fn test_zero3_config_builder() {
        let config = Zero3CpuOffloadConfig::new()
            .with_offload_params(false)
            .with_compression(CpuCompressionMethod::FP16)
            .with_memory_strategy(AutoMemoryStrategy::Conservative)
            .with_prefetch_buffer_size(32);

        assert!(!config.offload_params);
        assert_eq!(config.cpu_compression, CpuCompressionMethod::FP16);
        assert_eq!(
            config.auto_memory_management,
            AutoMemoryStrategy::Conservative
        );
        assert_eq!(config.prefetch_buffer_size, 32);
    }

    #[test]
    fn test_zero3_config_validation() {
        let config = Zero3CpuOffloadConfig::default();
        assert!(config.validate().is_ok());

        let mut invalid_config = config.clone();
        invalid_config.cpu_memory_budget = 0;
        assert!(invalid_config.validate().is_err());

        let mut invalid_config = config.clone();
        invalid_config.gpu_param_memory_budget = 0;
        assert!(invalid_config.validate().is_err());
    }

    #[test]
    fn test_compression_methods() {
        assert_eq!(CpuCompressionMethod::None.ratio(), 1.0);
        assert_eq!(CpuCompressionMethod::FP16.ratio(), 0.5);
        assert_eq!(CpuCompressionMethod::INT8.ratio(), 0.25);

        assert!(!CpuCompressionMethod::None.is_lossy());
        assert!(CpuCompressionMethod::FP16.is_lossy());
        assert!(!CpuCompressionMethod::LosslessCompression.is_lossy());
    }

    #[test]
    fn test_memory_strategies() {
        assert_eq!(AutoMemoryStrategy::Conservative.pressure_threshold(), 0.9);
        assert_eq!(AutoMemoryStrategy::Aggressive.pressure_threshold(), 0.6);

        assert_eq!(AutoMemoryStrategy::Conservative.aggressiveness(), 0.2);
        assert_eq!(AutoMemoryStrategy::Extreme.aggressiveness(), 1.0);
    }

    #[test]
    fn test_rank_mapping() {
        let mapping = Zero3RankMapping::new(1, 4);

        assert_eq!(mapping.rank(), 1);
        assert_eq!(mapping.world_size(), 4);

        assert!(mapping.owns_partition(1));
        assert!(mapping.owns_partition(5));
        assert!(!mapping.owns_partition(0));
        assert!(!mapping.owns_partition(2));
        assert_eq!(mapping.get_parameter_owner(5), 1);
        assert_eq!(mapping.get_parameter_owner(8), 0);

        let owned = mapping.owned_partitions(10);
        assert_eq!(owned, vec![1, 5, 9]);

        assert_eq!(mapping.owned_partition_count(10), 3);
        assert_eq!(mapping.owned_partition_count(8), 2);
    }

    #[test]
    fn test_rank_mapping_local_global_round_trip() {
        let mapping = Zero3RankMapping::new(2, 3);
        for local in 0..5 {
            let global = mapping.local_to_global_partition(local);
            assert!(mapping.owns_partition(global));
            assert_eq!(mapping.global_to_local_partition(global), Some(local));
        }
        assert_eq!(mapping.global_to_local_partition(0), None);
    }

    #[test]
    #[should_panic(expected = "Rank must be less than world size")]
    fn test_rank_mapping_rejects_out_of_range_rank() {
        let _ = Zero3RankMapping::new(4, 4);
    }

    #[test]
    fn test_model_parameters() {
        let mut params = ModelParameters::new();

        params.add_parameter("layer1.weight".to_string(), vec![100, 50]);
        params.add_parameter("layer1.bias".to_string(), vec![50]);

        assert_eq!(params.total_parameters(), 2);
        assert_eq!(params.parameter_count, 5050);
        assert_eq!(params.get_parameter_size("layer1.weight"), Some(5000));
        assert_eq!(params.get_parameter_size("layer1.bias"), Some(50));

        let stats = params.get_statistics();
        assert_eq!(stats.total_parameters, 2);
        assert_eq!(stats.total_elements, 5050);
        assert_eq!(stats.min_parameter_size, 50);
        assert_eq!(stats.max_parameter_size, 5000);
    }

    #[test]
    fn test_rank_mapping_communication_group() {
        let mapping = Zero3RankMapping::new(1, 4);
        let param_indices = vec![0, 1, 4, 5, 8, 9];
        let comm_group = mapping.communication_group(&param_indices);

        assert_eq!(comm_group, vec![0, 1]);
    }

    #[test]
    fn test_effective_cpu_memory_budget() {
        let config = Zero3CpuOffloadConfig::new()
            .with_cpu_memory_budget(1000)
            .with_compression(CpuCompressionMethod::FP16);

        assert_eq!(config.effective_cpu_memory_budget(), 2000);

        let config_no_compression = Zero3CpuOffloadConfig::new()
            .with_cpu_memory_budget(1000)
            .with_compression(CpuCompressionMethod::None);

        assert_eq!(config_no_compression.effective_cpu_memory_budget(), 1000);
    }
}