#![allow(dead_code)]
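//! ZeRO-3 CPU offload for distributed training.
//!
//! Partitions parameters, gradients, and optimizer states across ranks and
//! pages each rank's partitions between GPU and (optionally compressed) CPU
//! memory around the forward/backward/optimizer loop. Submodules provide
//! configuration, parameter and gradient management, optimizer state
//! handling, memory budgeting, prefetching, and performance statistics.
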
pub mod config;
pub mod gradient_management;
pub mod memory_management;
pub mod optimizer_state;
pub mod parameter_management;
pub mod prefetch;
pub mod stats;

pub use config::{
    AutoMemoryStrategy, CpuCompressionMethod, ModelParameterStats,
    ModelParameters as ConfigModelParameters, Zero3CpuOffloadConfig,
    Zero3RankMapping as ConfigZero3RankMapping,
};
pub use gradient_management::*;
pub use memory_management::*;
pub use optimizer_state::{
    OptimizerState, OptimizerStateManager, OptimizerStateMemoryStats,
    Zero3RankMapping as OptimizerZero3RankMapping,
};
pub use parameter_management::*;
pub use prefetch::*;
pub use stats::*;

use crate::{ProcessGroup, TorshDistributedError, TorshResult};
use half::{bf16, f16};
use log::info;
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use std::time::Instant;
use torsh_core::device::DeviceType;
use torsh_tensor::Tensor;

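/// Coordinates ZeRO-3 training with CPU offload. Each rank owns a partition
/// of the model state; parameters are cached on GPU while in use and staged
/// back to the CPU store once a layer falls outside the prefetch window.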
pub struct Zero3CpuOffloadManager {
    config: Zero3CpuOffloadConfig,
    process_group: Arc<ProcessGroup>,
    rank_mapping: ConfigZero3RankMapping,

    // Parameter partitioning with a CPU store and a GPU cache.
    param_partitioner: ParameterPartitioner,
    cpu_param_store: CpuParameterStore,
    gpu_param_cache: GpuParameterCache,

    // Gradient partitioning with a CPU store and a GPU buffer.
    gradient_partitioner: GradientPartitioner,
    cpu_gradient_store: CpuGradientStore,
    gpu_gradient_buffer: GpuGradientBuffer,

    // Partitioned optimizer state.
    optimizer_state_manager: OptimizerStateManager,

    // Memory budgeting and async parameter prefetching.
    memory_manager: Zero3MemoryManager,
    prefetch_scheduler: PrefetchScheduler,

    performance_stats: Arc<Mutex<Zero3PerformanceStats>>,
}

impl Zero3CpuOffloadManager {
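    /// Builds a manager for the given config, process group, and model
    /// description, partitioning state across the group's ranks.
    ///
    /// A minimal usage sketch (assumes a `ProcessGroup` `pg` has already been
    /// initialized, e.g. via `init_process_group` as in the tests below):
    ///
    /// ```ignore
    /// let mut params = ConfigModelParameters::new();
    /// params.add_parameter("layer1.weight".to_string(), vec![512, 1024]);
    /// let manager = Zero3CpuOffloadManager::new(
    ///     Zero3CpuOffloadConfig::default(),
    ///     Arc::new(pg),
    ///     &params,
    /// )?;
    /// ```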
    pub fn new(
        config: Zero3CpuOffloadConfig,
        process_group: Arc<ProcessGroup>,
        model_parameters: &ConfigModelParameters,
    ) -> TorshResult<Self> {
        let world_size = process_group.world_size() as usize;
        let rank = process_group.rank() as usize;

        info!(
            " Initializing ZeRO-3 CPU Offload Manager: rank {}/{}, {} parameters",
            rank, world_size, model_parameters.parameter_count
        );

        let rank_mapping = ConfigZero3RankMapping::new(rank, world_size);

        let param_partitioner =
            ParameterPartitioner::new(&config, &rank_mapping, model_parameters)?;
        let cpu_param_store = CpuParameterStore::new(&config)?;
        let gpu_param_cache = GpuParameterCache::new(&config)?;

        let gradient_partitioner = GradientPartitioner::new(&config, &rank_mapping)?;
        let cpu_gradient_store = CpuGradientStore::new(&config)?;
        let gpu_gradient_buffer = GpuGradientBuffer::new(&config)?;

        let optimizer_rank_mapping = OptimizerZero3RankMapping::new(rank, world_size);
        let optimizer_state_manager = OptimizerStateManager::new(&config, &optimizer_rank_mapping)?;

        let memory_manager = Zero3MemoryManager::new(&config);
        let prefetch_scheduler = PrefetchScheduler::new(&config, process_group.clone());

        let performance_stats = Arc::new(Mutex::new(Zero3PerformanceStats::new()));

        info!(" ZeRO-3 CPU Offload initialized successfully:");
        info!(
            "   Parameters: {} total, partitioned across {} ranks",
            model_parameters.parameter_count, world_size
        );
        info!(
            "   Memory: CPU budget {}GB, GPU budget {}GB",
            config.cpu_memory_budget / (1024 * 1024 * 1024),
            config.gpu_param_memory_budget / (1024 * 1024 * 1024)
        );
        info!(
            "   🔧 Features: params={}, grads={}, optimizer={}, prefetch={}",
            config.offload_params,
            config.offload_grads,
            config.offload_optimizer_states,
            config.async_prefetch
        );

        Ok(Self {
            config,
            process_group,
            rank_mapping,
            param_partitioner,
            cpu_param_store,
            gpu_param_cache,
            gradient_partitioner,
            cpu_gradient_store,
            gpu_gradient_buffer,
            optimizer_state_manager,
            memory_manager,
            prefetch_scheduler,
            performance_stats,
        })
    }

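    /// Runs a forward pass over `layer_names` in order: optionally prefetches
    /// upcoming layers, materializes each layer's parameters on GPU, executes
    /// the layer, and offloads parameters back to CPU once the layer falls
    /// outside the prefetch window. Memory pressure is re-checked every
    /// fourth layer.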
    pub async fn forward_pass(
        &mut self,
        input: &Tensor<f32>,
        layer_names: &[String],
    ) -> TorshResult<Tensor<f32>> {
        let start_time = Instant::now();
        let mut current_input = input.clone();

        info!(" ZeRO-3 Forward Pass: {} layers", layer_names.len());

        for (layer_idx, layer_name) in layer_names.iter().enumerate() {
            let layer_start = Instant::now();

            if self.config.async_prefetch {
                self.prefetch_scheduler
                    .intelligent_prefetch(layer_name, layer_names)
                    .await?;
            }

            let layer_params = self.ensure_parameters_on_gpu(layer_name).await?;

            current_input = self
                .execute_layer_computation(&current_input, &layer_params, layer_name)
                .await?;

            if self.should_offload_layer_params(layer_name, layer_idx, layer_names.len()) {
                self.offload_parameters_to_cpu(layer_name, &layer_params)
                    .await?;
            }

            let layer_duration = layer_start.elapsed();
            {
                let mut stats = self
                    .performance_stats
                    .lock()
                    .expect("lock should not be poisoned");
                stats.record_layer_execution(layer_name.clone(), layer_duration);
            }

            if layer_idx % 4 == 0 {
                self.memory_manager.check_and_optimize_memory().await?;
            }
        }

        let total_duration = start_time.elapsed();
        {
            let mut stats = self
                .performance_stats
                .lock()
                .expect("lock should not be poisoned");
            stats.record_forward_pass(total_duration, input.numel());
        }

        info!(" Forward pass completed in {:?}", total_duration);
        Ok(current_input)
    }

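    /// Runs the backward pass over `layer_names` in reverse order, computing
    /// per-layer gradients, routing owned gradient partitions to CPU or GPU
    /// storage, and finally reducing the partitioned gradients across ranks.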
    pub async fn backward_pass(
        &mut self,
        grad_output: &Tensor<f32>,
        layer_names: &[String],
    ) -> TorshResult<()> {
        let start_time = Instant::now();
        let mut current_grad = grad_output.clone();

        info!(" ZeRO-3 Backward Pass: {} layers", layer_names.len());

        for (rev_idx, layer_name) in layer_names.iter().rev().enumerate() {
            let layer_start = Instant::now();

            let layer_params = self.ensure_parameters_on_gpu(layer_name).await?;

            let (grad_input, param_grads) = self
                .compute_layer_gradients(&current_grad, &layer_params, layer_name)
                .await?;

            self.handle_parameter_gradients(layer_name, &param_grads)
                .await?;

            current_grad = grad_input;

            if self.should_offload_layer_params(layer_name, rev_idx, layer_names.len()) {
                self.offload_parameters_to_cpu(layer_name, &layer_params)
                    .await?;
            }

            let layer_duration = layer_start.elapsed();
            {
                let mut stats = self
                    .performance_stats
                    .lock()
                    .expect("lock should not be poisoned");
                stats.record_layer_backward(layer_name.clone(), layer_duration);
            }
        }

        self.all_reduce_partitioned_gradients().await?;

        let total_duration = start_time.elapsed();
        {
            let mut stats = self
                .performance_stats
                .lock()
                .expect("lock should not be poisoned");
            stats.record_backward_pass(total_duration, grad_output.numel());
        }

        info!(" Backward pass completed in {:?}", total_duration);
        Ok(())
    }

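    /// Applies one optimizer step to the parameter partitions this rank owns:
    /// gathers owned gradients, fetches per-parameter optimizer state,
    /// computes and applies the update, persists the result to the CPU store,
    /// and broadcasts the updated parameters to the other ranks.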
    pub async fn optimizer_step(&mut self, learning_rate: f32) -> TorshResult<()> {
        let start_time = Instant::now();

        info!(" ZeRO-3 Optimizer Step (lr: {})", learning_rate);

        let owned_param_grads = self.gather_owned_parameter_gradients().await?;

        info!(
            " Processing {} owned parameter gradients",
            owned_param_grads.len()
        );

        for (param_name, gradient) in owned_param_grads.iter() {
            let optimizer_state = self.optimizer_state_manager.fetch_state(param_name).await?;

            let param_update =
                self.compute_parameter_update(&optimizer_state, gradient, learning_rate)?;

            let mut parameter = self.fetch_parameter_for_update(param_name).await?;
            parameter = parameter.sub(&param_update)?;

            self.store_updated_parameter(param_name, &parameter).await?;
            self.optimizer_state_manager
                .store_state(param_name, &optimizer_state)
                .await?;
        }

        self.broadcast_parameter_updates().await?;

        let duration = start_time.elapsed();
        {
            let mut stats = self
                .performance_stats
                .lock()
                .expect("lock should not be poisoned");
            stats.record_optimizer_step(duration, owned_param_grads.len());
        }

        info!(" Optimizer step completed in {:?}", duration);
        Ok(())
    }

    /// Returns a snapshot of the accumulated performance statistics.
    pub fn get_performance_stats(&self) -> Zero3PerformanceStats {
        self.performance_stats
            .lock()
            .expect("lock should not be poisoned")
            .clone()
    }

    /// Returns current memory usage statistics from the memory manager.
    pub fn get_memory_stats(&self) -> Zero3MemoryStats {
        self.memory_manager.get_memory_stats()
    }

    /// Forces an immediate memory optimization pass.
    pub async fn force_memory_optimization(&self) -> TorshResult<()> {
        self.memory_manager.force_memory_optimization().await
    }

    /// Returns the current state of the prefetch queue.
    pub fn get_prefetch_status(&self) -> PrefetchQueueStatus {
        self.prefetch_scheduler.get_queue_status()
    }

    /// Lets the prefetch scheduler adapt its strategy to observed behavior.
    pub async fn adapt_performance(&self) -> TorshResult<()> {
        self.prefetch_scheduler.adapt_prefetch_strategy().await
    }

    /// Clears optimizer state and cancels any outstanding prefetches.
    pub async fn reset_state(&self) -> TorshResult<()> {
        self.optimizer_state_manager.clear_states().await?;
        self.prefetch_scheduler.cancel_all_prefetches().await?;
        info!("🧹 ZeRO-3 manager state reset completed");
        Ok(())
    }

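    /// Returns the layer's parameters on GPU, serving from the cache when
    /// possible and paging in from the CPU store on a miss.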
    async fn ensure_parameters_on_gpu(&mut self, layer_name: &str) -> TorshResult<LayerParameters> {
        if let Some(cached_params) = self.gpu_param_cache.get(layer_name).await? {
            return Ok(cached_params);
        }

        let cpu_params = self.cpu_param_store.fetch(layer_name).await?;

        let gpu_params = self.transfer_params_to_gpu(&cpu_params).await?;

        self.gpu_param_cache.store(layer_name, &gpu_params).await?;

        Ok(gpu_params)
    }

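    /// Moves a layer's parameters from CPU to GPU: decompresses according to
    /// the configured compression method, builds tensors on `Cuda(0)`, and
    /// records the transfer in the performance stats.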
    async fn transfer_params_to_gpu(
        &self,
        cpu_params: &CpuParameterData,
    ) -> TorshResult<LayerParameters> {
        let transfer_start = Instant::now();

        let decompressed_data = match self.config.cpu_compression {
            CpuCompressionMethod::None => cpu_params.data.clone(),
            CpuCompressionMethod::FP16 => self.decompress_fp16(&cpu_params.data)?,
            CpuCompressionMethod::BF16 => self.decompress_bf16(&cpu_params.data)?,
            CpuCompressionMethod::INT8 => self.decompress_int8(&cpu_params.data)?,
            _ => {
                return Err(TorshDistributedError::feature_not_available(
                    "compression_method",
                    "Compression method not implemented",
                ));
            }
        };

        let weight = Tensor::from_data(
            decompressed_data,
            cpu_params.weight_shape.clone(),
            DeviceType::Cuda(0),
        )?;
        let bias = if let Some(ref bias_data) = cpu_params.bias_data {
            Some(Tensor::from_data(
                bias_data.clone(),
                cpu_params
                    .bias_shape
                    .as_ref()
                    .expect("bias_shape should be present when bias_data exists")
                    .clone(),
                DeviceType::Cuda(0),
            )?)
        } else {
            None
        };

        let transfer_duration = transfer_start.elapsed();
        {
            let mut stats = self
                .performance_stats
                .lock()
                .expect("lock should not be poisoned");
            stats.record_parameter_transfer(
                transfer_duration,
                cpu_params.size_bytes,
                TransferDirection::CpuToGpu,
            );
        }

        info!(
            " Transferred parameters to GPU: {} bytes in {:?}",
            cpu_params.size_bytes, transfer_duration
        );

        Ok(LayerParameters { weight, bias })
    }

    async fn execute_layer_computation(
        &self,
        input: &Tensor<f32>,
        params: &LayerParameters,
        layer_name: &str,
    ) -> TorshResult<Tensor<f32>> {
        info!(" 🧮 Computing layer: {}", layer_name);

        // Simple linear layer with optional bias, followed by ReLU.
        let mut output = input.matmul(&params.weight)?;
        if let Some(ref bias) = params.bias {
            output = output.add(bias)?;
        }
        Ok(output.relu()?)
    }

    fn should_offload_layer_params(
        &self,
        _layer_name: &str,
        current_idx: usize,
        total_layers: usize,
    ) -> bool {
        // Keep a layer resident only while it falls within the prefetch
        // window; offload when more layers remain than the buffer covers.
        let remaining_layers = total_layers - current_idx;
        remaining_layers > self.config.prefetch_buffer_size
    }

    async fn offload_parameters_to_cpu(
        &mut self,
        layer_name: &str,
        params: &LayerParameters,
    ) -> TorshResult<()> {
        if !self.config.offload_params {
            return Ok(());
        }

        let offload_start = Instant::now();

        let compressed_data = self.compress_parameters(params).await?;

        self.cpu_param_store
            .store(layer_name, &compressed_data)
            .await?;

        self.gpu_param_cache.remove(layer_name).await?;

        let offload_duration = offload_start.elapsed();
        {
            let mut stats = self
                .performance_stats
                .lock()
                .expect("lock should not be poisoned");
            stats.record_parameter_transfer(
                offload_duration,
                compressed_data.size_bytes,
                TransferDirection::GpuToCpu,
            );
        }

        info!(
            " Offloaded parameters to CPU: {} ({} bytes in {:?})",
            layer_name, compressed_data.size_bytes, offload_duration
        );

        Ok(())
    }

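    /// Pulls parameter data off the GPU and compresses the weight according
    /// to the configured method; the bias, when present, is stored
    /// uncompressed.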
    async fn compress_parameters(&self, params: &LayerParameters) -> TorshResult<CpuParameterData> {
        let weight_data = params.weight.to_vec()?;
        let bias_data = if let Some(ref bias) = params.bias {
            Some(bias.to_vec()?)
        } else {
            None
        };

        let (compressed_weight, weight_shape) = match self.config.cpu_compression {
            CpuCompressionMethod::None => (weight_data, params.weight.shape().dims().to_vec()),
            CpuCompressionMethod::FP16 => {
                self.compress_to_fp16(&weight_data, params.weight.shape().dims())?
            }
            CpuCompressionMethod::BF16 => {
                self.compress_to_bf16(&weight_data, params.weight.shape().dims())?
            }
            CpuCompressionMethod::INT8 => {
                self.compress_to_int8(&weight_data, params.weight.shape().dims())?
            }
            _ => {
                return Err(TorshDistributedError::feature_not_available(
                    "compression_method",
                    "Compression method not implemented",
                ));
            }
        };

        // The compress_to_* helpers materialize their output as f32 (only the
        // precision loss is simulated), so sizes are counted in f32 units.
        let size_bytes = compressed_weight.len() * std::mem::size_of::<f32>()
            + bias_data
                .as_ref()
                .map(|b: &Vec<f32>| b.len() * std::mem::size_of::<f32>())
                .unwrap_or(0);

        Ok(CpuParameterData {
            data: compressed_weight,
            bias_data,
            weight_shape,
            bias_shape: params.bias.as_ref().map(|b| b.shape().dims().to_vec()),
            size_bytes,
            compression: self.config.cpu_compression,
        })
    }

    async fn compute_layer_gradients(
        &self,
        grad_output: &Tensor<f32>,
        params: &LayerParameters,
        layer_name: &str,
    ) -> TorshResult<(Tensor<f32>, ParameterGradients)> {
        info!(" 🔢 Computing gradients for layer: {}", layer_name);

        let grad_input = grad_output.matmul(&params.weight.transpose(-2, -1)?)?;
        // Simplified stand-in for the weight gradient: the true value is
        // input^T @ grad_output, which would require cached activations.
        let grad_weight = grad_output.clone();
        let grad_bias = if params.bias.is_some() {
            Some(grad_output.sum_dim(&[0], false)?)
        } else {
            None
        };

        let param_grads = ParameterGradients {
            weight_grad: grad_weight,
            bias_grad: grad_bias,
        };

        Ok((grad_input, param_grads))
    }

    async fn handle_parameter_gradients(
        &mut self,
        layer_name: &str,
        grads: &ParameterGradients,
    ) -> TorshResult<()> {
        let partitioned_grads = self
            .gradient_partitioner
            .partition_gradients(layer_name, grads)?;

        for (partition_idx, grad_partition) in partitioned_grads.into_iter().enumerate() {
            if self.rank_mapping.owns_partition(partition_idx) {
                if self.config.offload_grads {
                    self.cpu_gradient_store
                        .store(layer_name, partition_idx, &grad_partition.weight_gradient)
                        .await?;
                } else {
                    self.gpu_gradient_buffer
                        .store(layer_name, partition_idx, &grad_partition.weight_gradient)
                        .await?;
                }
            }
        }

        Ok(())
    }

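    /// Reduces the partitioned gradients across ranks. Note this is a
    /// simplified placeholder: it only pre-scales local partitions by
    /// `1/world_size`; no inter-rank communication happens here yet.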
    async fn all_reduce_partitioned_gradients(&mut self) -> TorshResult<()> {
        let sync_start = Instant::now();
        info!(" All-reducing partitioned gradients");

        let local_gradients = self.cpu_gradient_store.get_all_gradients().await?;
        let gradients_count = local_gradients.len();

        for (layer_partition_key, gradient) in local_gradients {
            // Pre-scale by 1/world_size; a full implementation would also
            // all-reduce the scaled partition via the process group.
            let world_size = self.process_group.world_size() as f32;
            let grad_tensor = gradient.div_scalar(world_size)?;

            self.cpu_gradient_store
                .store_reduced_gradient(&layer_partition_key, &grad_tensor)
                .await?;
        }

        let sync_duration = sync_start.elapsed();
        {
            let mut stats = self
                .performance_stats
                .lock()
                .expect("lock should not be poisoned");
            stats.record_gradient_sync(
                sync_duration,
                gradients_count,
                self.process_group.world_size() as usize,
            );
        }

        info!(
            " Gradient synchronization completed in {:?}",
            sync_duration
        );
        Ok(())
    }

    /// Collects the gradients for the parameter partitions this rank owns.
    async fn gather_owned_parameter_gradients(
        &mut self,
    ) -> TorshResult<HashMap<String, Tensor<f32>>> {
        self.cpu_gradient_store
            .get_owned_gradients(self.rank_mapping.rank(), self.rank_mapping.world_size())
            .await
    }

    fn compute_parameter_update(
        &self,
        _optimizer_state: &OptimizerState,
        gradient: &Tensor<f32>,
        learning_rate: f32,
    ) -> TorshResult<Tensor<f32>> {
        // Plain SGD (update = lr * grad); the fetched optimizer state is
        // currently unused.
        Ok(gradient.mul_scalar(learning_rate)?)
    }

    async fn fetch_parameter_for_update(&mut self, param_name: &str) -> TorshResult<Tensor<f32>> {
        let cpu_param_data = self.cpu_param_store.fetch(param_name).await?;
        let gpu_params = self.transfer_params_to_gpu(&cpu_param_data).await?;
        Ok(gpu_params.weight)
    }

    async fn store_updated_parameter(
        &mut self,
        param_name: &str,
        parameter: &Tensor<f32>,
    ) -> TorshResult<()> {
        let layer_params = LayerParameters {
            weight: parameter.clone(),
            // Only the weight is round-tripped here; bias updates are not
            // handled by this simplified path.
            bias: None,
        };

        let compressed_data = self.compress_parameters(&layer_params).await?;
        self.cpu_param_store
            .store(param_name, &compressed_data)
            .await?;

        Ok(())
    }

    async fn broadcast_parameter_updates(&mut self) -> TorshResult<()> {
        let broadcast_start = Instant::now();
        info!(" Broadcasting parameter updates across process group");

        // Placeholder for the actual collective: simulate broadcast latency.
        tokio::time::sleep(tokio::time::Duration::from_millis(50)).await;

        let broadcast_duration = broadcast_start.elapsed();
        {
            let mut stats = self
                .performance_stats
                .lock()
                .expect("lock should not be poisoned");
            stats.record_communication(
                CommunicationOperation::Broadcast,
                broadcast_duration,
                1024 * 1024, // nominal payload size for the simulated broadcast
            );
        }

        info!(
            " Parameter broadcasting completed in {:?}",
            broadcast_duration
        );
        Ok(())
    }

    /// Simulates FP16 compression by rounding each value through `f16` and
    /// materializing the result back as `f32`; the shape is unchanged.
    fn compress_to_fp16(
        &self,
        data: &[f32],
        shape: &[usize],
    ) -> TorshResult<(Vec<f32>, Vec<usize>)> {
        let compressed: Vec<f32> = data
            .iter()
            .map(|&val| f16::from_f32(val).to_f32())
            .collect();
        Ok((compressed, shape.to_vec()))
    }

    /// Simulates BF16 compression by rounding each value through `bf16` and
    /// materializing the result back as `f32`; the shape is unchanged.
    fn compress_to_bf16(
        &self,
        data: &[f32],
        shape: &[usize],
    ) -> TorshResult<(Vec<f32>, Vec<usize>)> {
        let compressed: Vec<f32> = data
            .iter()
            .map(|&val| bf16::from_f32(val).to_f32())
            .collect();
        Ok((compressed, shape.to_vec()))
    }

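    /// Simulates symmetric int8 quantization in f32: values are scaled into
    /// [-127, 127], rounded, then immediately dequantized, so only the
    /// round-trip precision loss is modeled. For example, with
    /// `max_abs = 2.0` the scale is `127 / 2.0 = 63.5`, so `1.0` quantizes to
    /// `round(63.5) = 64` and dequantizes to `64 / 63.5 ≈ 1.0079`.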
    fn compress_to_int8(
        &self,
        data: &[f32],
        shape: &[usize],
    ) -> TorshResult<(Vec<f32>, Vec<usize>)> {
        if data.is_empty() {
            return Ok((Vec::new(), shape.to_vec()));
        }

        let max_abs = data
            .iter()
            .map(|&x| x.abs())
            .fold(f32::NEG_INFINITY, f32::max);
        if max_abs == 0.0 {
            return Ok((vec![0.0; data.len()], shape.to_vec()));
        }

        let scale = 127.0 / max_abs;
        let inv_scale = max_abs / 127.0;

        let quantized: Vec<f32> = data
            .iter()
            .map(|&val| {
                let quantized_val = (val * scale).round().clamp(-127.0, 127.0);
                quantized_val * inv_scale
            })
            .collect();

        Ok((quantized, shape.to_vec()))
    }

    // The compress_to_* helpers above already materialize their lossy output
    // as f32, so decompression is a pass-through copy.
    fn decompress_fp16(&self, data: &[f32]) -> TorshResult<Vec<f32>> {
        Ok(data.to_vec())
    }

    fn decompress_bf16(&self, data: &[f32]) -> TorshResult<Vec<f32>> {
        Ok(data.to_vec())
    }

    fn decompress_int8(&self, data: &[f32]) -> TorshResult<Vec<f32>> {
        Ok(data.to_vec())
    }
}

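/// Aggregate description of a model's parameters: total element count,
/// parameter names, per-parameter shapes, and total f32 memory footprint.
/// A local counterpart to the config module's `ModelParameters`
/// (re-exported above as `ConfigModelParameters`).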
#[derive(Debug)]
pub struct ModelParameters {
    pub parameter_count: usize,
    pub parameter_names: Vec<String>,
    pub parameter_shapes: HashMap<String, Vec<usize>>,
    pub total_memory_bytes: usize,
}

impl ModelParameters {
    pub fn new() -> Self {
        Self {
            parameter_count: 0,
            parameter_names: Vec::new(),
            parameter_shapes: HashMap::new(),
            total_memory_bytes: 0,
        }
    }

    /// Registers a parameter, accumulating its element count and f32 size.
    pub fn add_parameter(&mut self, name: String, shape: Vec<usize>) {
        let param_size = shape.iter().product::<usize>();
        self.parameter_count += param_size;
        self.total_memory_bytes += param_size * std::mem::size_of::<f32>();
        self.parameter_shapes.insert(name.clone(), shape);
        self.parameter_names.push(name);
    }

    pub fn get_parameter_shape(&self, name: &str) -> Option<&Vec<usize>> {
        self.parameter_shapes.get(name)
    }

    pub fn has_parameter(&self, name: &str) -> bool {
        self.parameter_shapes.contains_key(name)
    }
}

impl Default for ModelParameters {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::{init_process_group, BackendType};

    #[test]
    fn test_model_parameters() {
        let mut model_params = ConfigModelParameters::new();
        model_params.add_parameter("layer1.weight".to_string(), vec![512, 1024]);
        model_params.add_parameter("layer1.bias".to_string(), vec![1024]);

        assert_eq!(model_params.parameter_names.len(), 2);
        assert_eq!(model_params.parameter_count, 512 * 1024 + 1024);
        assert!(model_params.has_parameter("layer1.weight"));
        assert!(!model_params.has_parameter("nonexistent"));
    }

    #[tokio::test]
    async fn test_zero3_manager_creation() {
        let pg = init_process_group(BackendType::Gloo, 0, 4, "127.0.0.1", 29500)
            .await
            .unwrap();
        let config = Zero3CpuOffloadConfig::default();

        let mut model_params = ConfigModelParameters::new();
        model_params.add_parameter("layer1.weight".to_string(), vec![512, 512]);
        model_params.add_parameter("layer2.weight".to_string(), vec![512, 512]);

        let manager = Zero3CpuOffloadManager::new(config, Arc::new(pg), &model_params);
        assert!(manager.is_ok());

        let manager = manager.unwrap();
        let stats = manager.get_performance_stats();
        assert_eq!(stats.forward_passes, 0);

        let _memory_stats = manager.get_memory_stats();
    }

    #[tokio::test]
    async fn test_manager_operations() {
        let pg = init_process_group(BackendType::Gloo, 0, 1, "127.0.0.1", 29500)
            .await
            .unwrap();
        let config = Zero3CpuOffloadConfig::default();

        let mut model_params = ConfigModelParameters::new();
        model_params.add_parameter("test_layer".to_string(), vec![10, 10]);

        let manager = Zero3CpuOffloadManager::new(config, Arc::new(pg), &model_params).unwrap();

        manager.reset_state().await.unwrap();

        manager.force_memory_optimization().await.unwrap();

        let prefetch_status = manager.get_prefetch_status();
        assert_eq!(prefetch_status.queued_requests, 0);
    }

    #[tokio::test]
    async fn test_compression_methods() {
        let config = Zero3CpuOffloadConfig::default();
        let pg = init_process_group(BackendType::Gloo, 0, 1, "127.0.0.1", 29500)
            .await
            .unwrap();
        let model_params = ConfigModelParameters::new();
        let manager = Zero3CpuOffloadManager::new(config, Arc::new(pg), &model_params).unwrap();

        let test_data = vec![1.0, 2.0, -1.5, 0.5];
        let shape = vec![2, 2];

        let (compressed, result_shape) = manager.compress_to_fp16(&test_data, &shape).unwrap();
        assert_eq!(result_shape, shape);
        assert_eq!(compressed.len(), test_data.len());

        let (compressed, result_shape) = manager.compress_to_bf16(&test_data, &shape).unwrap();
        assert_eq!(result_shape, shape);
        assert_eq!(compressed.len(), test_data.len());

        let (compressed, result_shape) = manager.compress_to_int8(&test_data, &shape).unwrap();
        assert_eq!(result_shape, shape);
        assert_eq!(compressed.len(), test_data.len());
    }
}