1#![allow(dead_code)]
10#![allow(clippy::await_holding_lock)]
11use crate::{ProcessGroup, TorshResult};
12use log::info;
13use std::collections::VecDeque;
14use std::sync::{Arc, Mutex};
15use tokio::sync::Semaphore;
16
17use super::config::Zero3CpuOffloadConfig;
18
/// Schedules and executes asynchronous parameter prefetches for ZeRO-3
/// CPU offloading, adapting its concurrency and queue limits at runtime.
pub struct PrefetchScheduler {
    // Offload configuration captured (cloned) at construction time.
    config: Zero3CpuOffloadConfig,
    // Communication group handed to each spawned prefetch task.
    process_group: Arc<ProcessGroup>,
    // FIFO of pending requests; trimmed when it exceeds the adaptive limit.
    prefetch_queue: Mutex<VecDeque<PrefetchRequest>>,
    // Operations currently in flight; shared with the spawned tasks.
    active_prefetches: Arc<Mutex<Vec<PrefetchOperation>>>,
    // Aggregated success/failure/timing statistics.
    metrics: Arc<Mutex<PrefetchMetrics>>,
    // Tunables adjusted over time by `adapt_prefetch_strategy`.
    adaptive_config: Arc<Mutex<AdaptivePrefetchConfig>>,
    // Background-task bookkeeping (currently informational only).
    task_coordination: Arc<Mutex<TaskCoordination>>,
}
43
44impl PrefetchScheduler {
45 pub fn new(config: &Zero3CpuOffloadConfig, process_group: Arc<ProcessGroup>) -> Self {
47 Self {
48 config: config.clone(),
49 process_group,
50 prefetch_queue: Mutex::new(VecDeque::new()),
51 active_prefetches: Arc::new(Mutex::new(Vec::new())),
52 metrics: Arc::new(Mutex::new(PrefetchMetrics::new())),
53 adaptive_config: Arc::new(Mutex::new(AdaptivePrefetchConfig::new(config))),
54 task_coordination: Arc::new(Mutex::new(TaskCoordination::new())),
55 }
56 }
57
58 pub async fn schedule_prefetch(&self, layer_name: &str) -> TorshResult<()> {
63 if !self.config.async_prefetch {
64 return Ok(());
65 }
66
67 let request = PrefetchRequest {
68 layer_name: layer_name.to_string(),
69 priority: PrefetchPriority::Normal,
70 requested_at: std::time::Instant::now(),
71 estimated_size_bytes: self.estimate_layer_size(layer_name),
72 };
73
74 {
76 let mut queue = self
77 .prefetch_queue
78 .lock()
79 .expect("lock should not be poisoned");
80 queue.push_back(request.clone());
81
82 let max_queue_size = self
84 .adaptive_config
85 .lock()
86 .expect("lock should not be poisoned")
87 .max_queue_size;
88 while queue.len() > max_queue_size {
89 if let Some(dropped) = queue.pop_front() {
90 info!(
91 " Dropped prefetch request for {} (queue full)",
92 dropped.layer_name
93 );
94 }
95 }
96 }
97
98 info!(
99 " Scheduled prefetch for layer: {} ({} bytes)",
100 layer_name, request.estimated_size_bytes
101 );
102
103 self.execute_async_prefetch(request).await?;
105
106 Ok(())
107 }
108
    /// Starts an async prefetch for `request` on a background Tokio task.
    ///
    /// If the concurrency limit is currently reached, the request is skipped
    /// with `Ok(())` (delayed, not failed). The spawned task records its
    /// outcome in the metrics and removes itself from the active list.
    async fn execute_async_prefetch(&self, request: PrefetchRequest) -> TorshResult<()> {
        // Clone the shared handles the background task will own.
        let process_group = self.process_group.clone();
        let metrics = self.metrics.clone();
        let active_prefetches = self.active_prefetches.clone();

        if !self.can_start_prefetch().await? {
            info!(
                " = Delaying prefetch for {} (system busy)",
                request.layer_name
            );
            return Ok(());
        }

        let operation = PrefetchOperation {
            layer_name: request.layer_name.clone(),
            started_at: std::time::Instant::now(),
            status: PrefetchStatus::InProgress,
        };

        // Register the operation before spawning so status queries see it.
        {
            let mut active = active_prefetches
                .lock()
                .expect("lock should not be poisoned");
            active.push(operation);
        }

        let layer_name = request.layer_name.clone();
        tokio::spawn(async move {
            let start_time = std::time::Instant::now();
            let result = Self::prefetch_layer_data(&layer_name, process_group).await;

            // Record the outcome; the guard is confined to this block so it
            // is released before the active list is locked below.
            {
                let mut metrics_guard = metrics.lock().expect("lock should not be poisoned");
                let duration = start_time.elapsed();

                match result {
                    Ok(()) => {
                        metrics_guard.record_successful_prefetch(duration, 0);
                        info!(
                            " = Async prefetch completed for layer: {} in {:?}",
                            layer_name, duration
                        );
                    }
                    Err(e) => {
                        metrics_guard.record_failed_prefetch(duration, e.to_string());
                        tracing::error!("Async prefetch failed for layer {}: {}", layer_name, e);
                    }
                }
            }

            // De-register the finished operation by layer name.
            {
                let mut active = active_prefetches
                    .lock()
                    .expect("lock should not be poisoned");
                active.retain(|op| op.layer_name != layer_name);
            }
        });

        Ok(())
    }
176
177 async fn can_start_prefetch(&self) -> TorshResult<bool> {
179 let adaptive_config = self
180 .adaptive_config
181 .lock()
182 .expect("lock should not be poisoned");
183 let active_count = self
184 .active_prefetches
185 .lock()
186 .expect("lock should not be poisoned")
187 .len();
188
189 if active_count >= adaptive_config.max_concurrent_prefetches {
191 return Ok(false);
192 }
193
194 Ok(true)
202 }
203
    /// Simulates fetching `layer_name`'s data — a placeholder for the real
    /// CPU-to-device transfer (the process group is currently unused).
    ///
    /// Sleeps for a name-based estimated duration so timing metrics behave
    /// realistically during development.
    async fn prefetch_layer_data(
        layer_name: &str,
        _process_group: Arc<ProcessGroup>,
    ) -> TorshResult<()> {
        let estimated_transfer_time = Self::estimate_transfer_time(layer_name);
        tokio::time::sleep(estimated_transfer_time).await;

        Ok(())
    }
224
225 fn estimate_transfer_time(layer_name: &str) -> tokio::time::Duration {
227 let base_time_ms = if layer_name.contains("large") {
229 50 } else if layer_name.contains("medium") {
231 25 } else {
233 10 };
235
236 tokio::time::Duration::from_millis(base_time_ms)
237 }
238
239 pub async fn batch_prefetch(&self, layer_names: Vec<String>) -> TorshResult<()> {
241 if !self.config.async_prefetch || layer_names.is_empty() {
242 return Ok(());
243 }
244
245 info!(
246 " = Starting batch prefetch for {} layers",
247 layer_names.len()
248 );
249
250 #[allow(clippy::await_holding_lock)]
251 let adaptive_config = self
252 .adaptive_config
253 .lock()
254 .expect("lock should not be poisoned");
255 let max_concurrent = adaptive_config.max_concurrent_prefetches;
256 drop(adaptive_config);
257
258 let semaphore = Arc::new(Semaphore::new(max_concurrent));
260 let mut tasks = Vec::new();
261
262 for layer_name in layer_names {
263 let sem = semaphore.clone();
264 let process_group = self.process_group.clone();
265 let metrics = self.metrics.clone();
266
267 let task = tokio::spawn(async move {
268 let _permit = sem.acquire().await.expect("semaphore should not be closed");
269 let start_time = std::time::Instant::now();
270 let result = Self::prefetch_layer_data(&layer_name, process_group).await;
271
272 {
274 let mut metrics_guard = metrics.lock().expect("lock should not be poisoned");
275 let duration = start_time.elapsed();
276 match result {
277 Ok(()) => metrics_guard.record_successful_prefetch(duration, 0),
278 Err(ref e) => metrics_guard.record_failed_prefetch(duration, e.to_string()),
279 }
280 }
281
282 result
283 });
284
285 tasks.push(task);
286 }
287
288 let results: Vec<_> = futures::future::join_all(tasks).await;
290
291 let mut successful = 0;
292 let mut failed = 0;
293
294 for result in results {
295 match result {
296 Ok(Ok(())) => successful += 1,
297 Ok(Err(e)) => {
298 failed += 1;
299 tracing::error!("Prefetch task failed: {}", e);
300 }
301 Err(e) => {
302 failed += 1;
303 tracing::error!("Prefetch task panicked: {}", e);
304 }
305 }
306 }
307
308 info!(
309 " Batch prefetch completed: {} successful, {} failed",
310 successful, failed
311 );
312
313 {
315 let mut metrics = self.metrics.lock().expect("lock should not be poisoned");
316 metrics.record_batch_prefetch(successful, failed);
317 }
318
319 Ok(())
320 }
321
322 pub async fn intelligent_prefetch(
324 &self,
325 current_layer: &str,
326 execution_graph: &[String],
327 ) -> TorshResult<()> {
328 if !self.config.async_prefetch {
329 return Ok(());
330 }
331
332 let current_pos = execution_graph.iter().position(|l| l == current_layer);
334
335 if let Some(pos) = current_pos {
336 let prefetch_distance = self.calculate_optimal_prefetch_distance().await?;
338
339 let mut layers_to_prefetch = Vec::new();
341 for i in 1..=prefetch_distance {
342 if pos + i < execution_graph.len() {
343 layers_to_prefetch.push(execution_graph[pos + i].clone());
344 }
345 }
346
347 if !layers_to_prefetch.is_empty() {
348 info!(
349 " > Intelligent prefetch: {} layers ahead from {}",
350 layers_to_prefetch.len(),
351 current_layer
352 );
353
354 self.prioritized_batch_prefetch(layers_to_prefetch, pos)
356 .await?;
357 }
358 }
359
360 Ok(())
361 }
362
363 async fn prioritized_batch_prefetch(
365 &self,
366 layer_names: Vec<String>,
367 _current_pos: usize,
368 ) -> TorshResult<()> {
369 let mut prioritized_requests = Vec::new();
370
371 for (i, layer_name) in layer_names.iter().enumerate() {
372 let priority = match i {
373 0 => PrefetchPriority::High, 1..=2 => PrefetchPriority::Normal, _ => PrefetchPriority::Low, };
377
378 let request = PrefetchRequest {
379 layer_name: layer_name.clone(),
380 priority,
381 requested_at: std::time::Instant::now(),
382 estimated_size_bytes: self.estimate_layer_size(layer_name),
383 };
384
385 prioritized_requests.push(request);
386 }
387
388 prioritized_requests.sort_by(|a, b| b.priority.cmp(&a.priority));
390
391 #[allow(clippy::await_holding_lock)]
392 let adaptive_config = self
394 .adaptive_config
395 .lock()
396 .expect("lock should not be poisoned");
397 let max_concurrent = adaptive_config.max_concurrent_prefetches;
398 drop(adaptive_config);
399
400 let semaphore = Arc::new(Semaphore::new(max_concurrent));
401 let mut tasks = Vec::new();
402
403 for request in prioritized_requests {
404 let sem = semaphore.clone();
405 let process_group = self.process_group.clone();
406 let metrics = self.metrics.clone();
407
408 let delay = match request.priority {
410 PrefetchPriority::High => tokio::time::Duration::from_millis(0),
411 PrefetchPriority::Normal => tokio::time::Duration::from_millis(10),
412 PrefetchPriority::Low => tokio::time::Duration::from_millis(25),
413 };
414
415 let task = tokio::spawn(async move {
416 tokio::time::sleep(delay).await; let _permit = sem.acquire().await.expect("semaphore should not be closed");
418 let start_time = std::time::Instant::now();
419 let result = Self::prefetch_layer_data(&request.layer_name, process_group).await;
420
421 {
423 let mut metrics_guard = metrics.lock().expect("lock should not be poisoned");
424 let duration = start_time.elapsed();
425 match result {
426 Ok(()) => metrics_guard
427 .record_successful_prefetch(duration, request.estimated_size_bytes),
428 Err(ref e) => metrics_guard.record_failed_prefetch(duration, e.to_string()),
429 }
430 }
431
432 (request.layer_name, result)
433 });
434
435 tasks.push(task);
436 }
437
438 let results: Vec<_> = futures::future::join_all(tasks).await;
440
441 let mut successful = 0;
442 let mut failed = 0;
443
444 for result in results {
445 match result {
446 Ok((layer_name, Ok(()))) => {
447 successful += 1;
448 info!(" Prioritized prefetch completed: {}", layer_name);
449 }
450 Ok((layer_name, Err(e))) => {
451 failed += 1;
452 tracing::error!("Prioritized prefetch failed for {}: {}", layer_name, e);
453 }
454 Err(e) => {
455 failed += 1;
456 tracing::error!("Prioritized prefetch task panicked: {}", e);
457 }
458 }
459 }
460
461 info!(
462 " < Prioritized batch prefetch completed: {} successful, {} failed",
463 successful, failed
464 );
465
466 Ok(())
467 }
468
469 pub async fn calculate_optimal_prefetch_distance(&self) -> TorshResult<usize> {
471 let adaptive_config = self
481 .adaptive_config
482 .lock()
483 .expect("lock should not be poisoned");
484 let base_distance = adaptive_config.base_prefetch_distance;
485 let current_performance = self
486 .metrics
487 .lock()
488 .expect("lock should not be poisoned")
489 .get_success_rate();
490
491 let performance_multiplier = if current_performance > 0.9 {
493 1.5 } else if current_performance > 0.7 {
495 1.0 } else {
497 0.7 };
499
500 let optimal_distance = (base_distance as f32 * performance_multiplier) as usize;
501 let optimal_distance = optimal_distance
502 .max(1)
503 .min(adaptive_config.max_prefetch_distance);
504
505 drop(adaptive_config);
507 {
508 let mut adaptive_config = self
509 .adaptive_config
510 .lock()
511 .expect("lock should not be poisoned");
512 adaptive_config.current_prefetch_distance = optimal_distance;
513 }
514
515 Ok(optimal_distance)
516 }
517
518 pub async fn adapt_prefetch_strategy(&self) -> TorshResult<()> {
520 let metrics = self
521 .metrics
522 .lock()
523 .expect("lock should not be poisoned")
524 .clone();
525 let mut adaptive_config = self
526 .adaptive_config
527 .lock()
528 .expect("lock should not be poisoned");
529
530 info!(" < Adapting prefetch strategy based on performance");
531
532 if metrics.get_success_rate() > 0.95 && metrics.total_prefetches > 10 {
534 adaptive_config.max_concurrent_prefetches =
536 (adaptive_config.max_concurrent_prefetches + 1).min(8);
537 info!(
538 " Increased max concurrent prefetches to {}",
539 adaptive_config.max_concurrent_prefetches
540 );
541 } else if metrics.get_success_rate() < 0.8 && adaptive_config.max_concurrent_prefetches > 1
542 {
543 adaptive_config.max_concurrent_prefetches =
545 (adaptive_config.max_concurrent_prefetches - 1).max(1);
546 info!(
547 " Decreased max concurrent prefetches to {}",
548 adaptive_config.max_concurrent_prefetches
549 );
550 }
551
552 let queue_size = self
554 .prefetch_queue
555 .lock()
556 .expect("lock should not be poisoned")
557 .len();
558 if queue_size > adaptive_config.max_queue_size * 3 / 4 {
559 adaptive_config.max_queue_size = (adaptive_config.max_queue_size + 2).min(32);
561 info!(
562 " = Increased max queue size to {}",
563 adaptive_config.max_queue_size
564 );
565 } else if queue_size < adaptive_config.max_queue_size / 4
566 && adaptive_config.max_queue_size > 4
567 {
568 adaptive_config.max_queue_size = (adaptive_config.max_queue_size - 1).max(4);
570 info!(
571 " = Decreased max queue size to {}",
572 adaptive_config.max_queue_size
573 );
574 }
575
576 if metrics.average_prefetch_time > tokio::time::Duration::from_millis(100) {
578 adaptive_config.base_prefetch_distance =
580 (adaptive_config.base_prefetch_distance - 1).max(1);
581 info!(
582 " =; Decreased base prefetch distance to {}",
583 adaptive_config.base_prefetch_distance
584 );
585 } else if metrics.average_prefetch_time < tokio::time::Duration::from_millis(20) {
586 adaptive_config.base_prefetch_distance =
588 (adaptive_config.base_prefetch_distance + 1).min(16);
589 info!(
590 " =: Increased base prefetch distance to {}",
591 adaptive_config.base_prefetch_distance
592 );
593 }
594
595 Ok(())
596 }
597
598 pub async fn cancel_all_prefetches(&self) -> TorshResult<()> {
600 info!(" = Cancelling all pending prefetch operations");
601
602 {
604 let mut queue = self
605 .prefetch_queue
606 .lock()
607 .expect("lock should not be poisoned");
608 let cancelled_count = queue.len();
609 queue.clear();
610 if cancelled_count > 0 {
611 info!(
612 " = Cancelled {} queued prefetch requests",
613 cancelled_count
614 );
615 }
616 }
617
618 {
625 let mut metrics = self.metrics.lock().expect("lock should not be poisoned");
626 metrics.record_cancellation();
627 }
628
629 Ok(())
630 }
631
632 pub fn get_queue_status(&self) -> PrefetchQueueStatus {
634 let queue = self
635 .prefetch_queue
636 .lock()
637 .expect("lock should not be poisoned");
638 let active = self
639 .active_prefetches
640 .lock()
641 .expect("lock should not be poisoned");
642 let adaptive_config = self
643 .adaptive_config
644 .lock()
645 .expect("lock should not be poisoned");
646
647 PrefetchQueueStatus {
648 queued_requests: queue.len(),
649 active_operations: active.len(),
650 max_queue_size: adaptive_config.max_queue_size,
651 max_concurrent: adaptive_config.max_concurrent_prefetches,
652 current_prefetch_distance: adaptive_config.current_prefetch_distance,
653 }
654 }
655
656 pub fn get_metrics(&self) -> PrefetchMetrics {
658 self.metrics
659 .lock()
660 .expect("lock should not be poisoned")
661 .clone()
662 }
663
664 pub fn get_adaptive_config(&self) -> AdaptivePrefetchConfig {
666 self.adaptive_config
667 .lock()
668 .expect("lock should not be poisoned")
669 .clone()
670 }
671
672 fn estimate_layer_size(&self, layer_name: &str) -> usize {
675 if layer_name.contains("large") {
677 64 * 1024 * 1024 } else if layer_name.contains("medium") {
679 16 * 1024 * 1024 } else {
681 4 * 1024 * 1024 }
683 }
684}
685
/// A single prefetch work item.
#[derive(Debug, Clone)]
pub struct PrefetchRequest {
    // Name of the layer whose parameters should be fetched.
    pub layer_name: String,
    // Scheduling priority (affects start delay in prioritized batches).
    pub priority: PrefetchPriority,
    // When the request was created.
    pub requested_at: std::time::Instant,
    // Heuristic size estimate, fed into throughput metrics.
    pub estimated_size_bytes: usize,
}
698
/// Prefetch urgency. The explicit discriminants make the derived `Ord` sort
/// Low < Normal < High, so a descending sort places High first.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum PrefetchPriority {
    Low = 0,
    Normal = 1,
    High = 2,
}
706
/// Lifecycle state of a prefetch operation.
/// (Only `InProgress` is currently assigned in this file; the remaining
/// variants exist for callers/future use.)
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PrefetchStatus {
    Queued,
    InProgress,
    Completed,
    Failed,
    Cancelled,
}
716
/// An in-flight prefetch tracked in `PrefetchScheduler::active_prefetches`.
#[derive(Debug, Clone)]
pub struct PrefetchOperation {
    // Layer being fetched; used as the removal key when the task finishes.
    pub layer_name: String,
    // When the operation started.
    pub started_at: std::time::Instant,
    // Current lifecycle state.
    pub status: PrefetchStatus,
}
727
/// Snapshot of scheduler occupancy and limits, as returned by
/// `PrefetchScheduler::get_queue_status`.
#[derive(Debug, Clone)]
pub struct PrefetchQueueStatus {
    // Requests currently waiting in the queue.
    pub queued_requests: usize,
    // Operations currently running.
    pub active_operations: usize,
    // Adaptive cap on queued requests.
    pub max_queue_size: usize,
    // Adaptive cap on concurrent operations.
    pub max_concurrent: usize,
    // Lookahead distance most recently computed.
    pub current_prefetch_distance: usize,
}
742
/// Runtime-tunable prefetch parameters, adjusted over time by
/// `PrefetchScheduler::adapt_prefetch_strategy`.
#[derive(Debug, Clone)]
pub struct AdaptivePrefetchConfig {
    // Baseline lookahead distance before performance scaling.
    pub base_prefetch_distance: usize,
    // Most recently computed effective lookahead distance.
    pub current_prefetch_distance: usize,
    // Hard upper bound on the lookahead distance.
    pub max_prefetch_distance: usize,
    // Cap on simultaneously running prefetch operations.
    pub max_concurrent_prefetches: usize,
    // Cap on queued prefetch requests.
    pub max_queue_size: usize,
    // Whether adaptive retuning is enabled.
    pub adaptive_optimization_enabled: bool,
}

impl AdaptivePrefetchConfig {
    /// Derives the initial tunables from the offload config: the lookahead
    /// starts at a quarter of the prefetch buffer size and may grow up to
    /// the full buffer size.
    pub fn new(config: &Zero3CpuOffloadConfig) -> Self {
        Self {
            base_prefetch_distance: config.prefetch_buffer_size / 4,
            current_prefetch_distance: config.prefetch_buffer_size / 4,
            max_prefetch_distance: config.prefetch_buffer_size,
            max_concurrent_prefetches: 4,
            max_queue_size: 16,
            adaptive_optimization_enabled: true,
        }
    }
}
772
/// Bookkeeping for background-task coordination.
/// NOTE(review): nothing in this file mutates these fields after
/// construction — presumably callers elsewhere do; confirm before removing.
#[derive(Debug)]
pub struct TaskCoordination {
    pub active_tasks: usize,
    pub max_background_tasks: usize,
    pub coordination_enabled: bool,
}

impl TaskCoordination {
    /// Creates the initial coordination state: no active tasks, coordination
    /// enabled, room for eight background tasks.
    pub fn new() -> Self {
        TaskCoordination {
            coordination_enabled: true,
            max_background_tasks: 8,
            active_tasks: 0,
        }
    }
}

impl Default for TaskCoordination {
    fn default() -> Self {
        Self::new()
    }
}
799
/// Cumulative statistics for prefetch operations.
///
/// Durations use `std::time::Duration` directly — `tokio::time::Duration`
/// is a re-export of the same type, so existing callers are unaffected and
/// the type no longer depends on tokio paths.
#[derive(Debug, Clone)]
pub struct PrefetchMetrics {
    // Total prefetches attempted (successful + failed).
    pub total_prefetches: u64,
    // Prefetches that completed successfully.
    pub successful_prefetches: u64,
    // Prefetches that returned an error.
    pub failed_prefetches: u64,
    // Cancellation events recorded via `record_cancellation`.
    pub cancelled_prefetches: u64,
    // Sum of all prefetch durations.
    pub total_prefetch_time: std::time::Duration,
    // total_prefetch_time / total_prefetches, kept up to date on record.
    pub average_prefetch_time: std::time::Duration,
    // Sum of the byte counts passed to `record_successful_prefetch`.
    pub total_bytes_prefetched: usize,
    // Number of batch operations recorded.
    pub batch_operations: u64,
    // Batch operations that contained at least one failure.
    pub failed_batch_operations: u64,
    // Up to the 10 most recent error messages, oldest first.
    pub recent_failures: Vec<String>,
}

impl PrefetchMetrics {
    /// Creates a zeroed metrics record.
    pub fn new() -> Self {
        Self {
            total_prefetches: 0,
            successful_prefetches: 0,
            failed_prefetches: 0,
            cancelled_prefetches: 0,
            total_prefetch_time: std::time::Duration::ZERO,
            average_prefetch_time: std::time::Duration::ZERO,
            total_bytes_prefetched: 0,
            batch_operations: 0,
            failed_batch_operations: 0,
            recent_failures: Vec::new(),
        }
    }

    /// Records a completed prefetch and refreshes the running average.
    pub fn record_successful_prefetch(&mut self, duration: std::time::Duration, bytes: usize) {
        self.total_prefetches += 1;
        self.successful_prefetches += 1;
        self.total_prefetch_time += duration;
        self.total_bytes_prefetched += bytes;
        self.update_average_time();
    }

    /// Records a failed prefetch, retaining only the 10 most recent error
    /// messages.
    pub fn record_failed_prefetch(&mut self, duration: std::time::Duration, error: String) {
        self.total_prefetches += 1;
        self.failed_prefetches += 1;
        self.total_prefetch_time += duration;

        self.recent_failures.push(error);
        if self.recent_failures.len() > 10 {
            self.recent_failures.remove(0);
        }

        self.update_average_time();
    }

    /// Records one batch operation; it counts as failed if any layer failed.
    pub fn record_batch_prefetch(&mut self, _successful: usize, failed: usize) {
        self.batch_operations += 1;
        if failed > 0 {
            self.failed_batch_operations += 1;
        }
    }

    /// Records one cancellation event.
    pub fn record_cancellation(&mut self) {
        self.cancelled_prefetches += 1;
    }

    /// Fraction of prefetches that succeeded; 1.0 when nothing has been
    /// recorded yet, so the adaptive logic starts optimistic.
    pub fn get_success_rate(&self) -> f32 {
        if self.total_prefetches > 0 {
            self.successful_prefetches as f32 / self.total_prefetches as f32
        } else {
            1.0
        }
    }

    /// Fraction of prefetches that failed; 0.0 when nothing recorded yet.
    pub fn get_failure_rate(&self) -> f32 {
        if self.total_prefetches > 0 {
            self.failed_prefetches as f32 / self.total_prefetches as f32
        } else {
            0.0
        }
    }

    /// Observed throughput in bytes per second (0.0 before any prefetch).
    pub fn get_throughput_bps(&self) -> f64 {
        if !self.total_prefetch_time.is_zero() {
            self.total_bytes_prefetched as f64 / self.total_prefetch_time.as_secs_f64()
        } else {
            0.0
        }
    }

    /// Recomputes the running average duration.
    fn update_average_time(&mut self) {
        if self.total_prefetches > 0 {
            // Divide in nanoseconds: the previous `Duration / count as u32`
            // silently truncated counts above u32::MAX.
            let avg_nanos = self.total_prefetch_time.as_nanos() / self.total_prefetches as u128;
            self.average_prefetch_time = std::time::Duration::from_nanos(avg_nanos as u64);
        }
    }

    /// Clears all counters back to their initial state.
    pub fn reset(&mut self) {
        *self = Self::new();
    }
}

impl Default for PrefetchMetrics {
    fn default() -> Self {
        Self::new()
    }
}
923
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{init_process_group, BackendType};

    // NOTE(review): the async tests below all bind a Gloo process group on
    // 127.0.0.1:29500; if the harness runs them in parallel they may contend
    // for the same port — confirm the runner serializes them.

    /// A descending sort on priority must order High, Normal, Low.
    #[test]
    fn test_prefetch_request_priority_ordering() {
        let mut requests = [
            PrefetchRequest {
                layer_name: "low".to_string(),
                priority: PrefetchPriority::Low,
                requested_at: std::time::Instant::now(),
                estimated_size_bytes: 1000,
            },
            PrefetchRequest {
                layer_name: "high".to_string(),
                priority: PrefetchPriority::High,
                requested_at: std::time::Instant::now(),
                estimated_size_bytes: 1000,
            },
            PrefetchRequest {
                layer_name: "normal".to_string(),
                priority: PrefetchPriority::Normal,
                requested_at: std::time::Instant::now(),
                estimated_size_bytes: 1000,
            },
        ];

        requests.sort_by(|a, b| b.priority.cmp(&a.priority));
        assert_eq!(requests[0].layer_name, "high");
        assert_eq!(requests[1].layer_name, "normal");
        assert_eq!(requests[2].layer_name, "low");
    }

    /// Initial tunables are derived from the offload config.
    #[test]
    fn test_adaptive_prefetch_config() {
        let zero3_config = Zero3CpuOffloadConfig::default();
        let config = AdaptivePrefetchConfig::new(&zero3_config);

        assert_eq!(
            config.base_prefetch_distance,
            zero3_config.prefetch_buffer_size / 4
        );
        assert_eq!(
            config.max_prefetch_distance,
            zero3_config.prefetch_buffer_size
        );
        assert!(config.adaptive_optimization_enabled);
    }

    /// Counters and rates update as successes/failures are recorded.
    #[test]
    fn test_prefetch_metrics() {
        let mut metrics = PrefetchMetrics::new();

        metrics.record_successful_prefetch(tokio::time::Duration::from_millis(100), 1000);
        assert_eq!(metrics.total_prefetches, 1);
        assert_eq!(metrics.successful_prefetches, 1);
        assert_eq!(metrics.get_success_rate(), 1.0);

        metrics.record_failed_prefetch(
            tokio::time::Duration::from_millis(50),
            "test error".to_string(),
        );
        assert_eq!(metrics.total_prefetches, 2);
        assert_eq!(metrics.failed_prefetches, 1);
        assert_eq!(metrics.get_success_rate(), 0.5);
        assert_eq!(metrics.recent_failures.len(), 1);
    }

    /// A fresh scheduler starts with empty queues and zeroed metrics.
    #[tokio::test]
    async fn test_prefetch_scheduler_creation() {
        let config = Zero3CpuOffloadConfig::default();
        let pg = init_process_group(BackendType::Gloo, 0, 1, "127.0.0.1", 29500)
            .await
            .unwrap();
        let scheduler = PrefetchScheduler::new(&config, Arc::new(pg));

        let status = scheduler.get_queue_status();
        assert_eq!(status.queued_requests, 0);
        assert_eq!(status.active_operations, 0);

        let metrics = scheduler.get_metrics();
        assert_eq!(metrics.total_prefetches, 0);
    }

    /// The computed distance stays within [1, prefetch_buffer_size].
    #[tokio::test]
    async fn test_prefetch_distance_calculation() {
        let config = Zero3CpuOffloadConfig::default();
        let pg = init_process_group(BackendType::Gloo, 0, 1, "127.0.0.1", 29500)
            .await
            .unwrap();
        let scheduler = PrefetchScheduler::new(&config, Arc::new(pg));

        let distance = scheduler
            .calculate_optimal_prefetch_distance()
            .await
            .unwrap();
        assert!(distance >= 1);
        assert!(distance <= config.prefetch_buffer_size);
    }

    /// A batch prefetch records exactly one batch operation.
    #[tokio::test]
    async fn test_batch_prefetch() {
        let config = Zero3CpuOffloadConfig::default();
        let pg = init_process_group(BackendType::Gloo, 0, 1, "127.0.0.1", 29500)
            .await
            .unwrap();
        let scheduler = PrefetchScheduler::new(&config, Arc::new(pg));

        let layers = vec!["layer1".to_string(), "layer2".to_string()];
        scheduler.batch_prefetch(layers).await.unwrap();

        let metrics = scheduler.get_metrics();
        assert_eq!(metrics.batch_operations, 1);
    }

    /// Default coordination state matches `TaskCoordination::new`.
    #[test]
    fn test_task_coordination() {
        let coordination = TaskCoordination::new();
        assert_eq!(coordination.active_tasks, 0);
        assert!(coordination.coordination_enabled);
        assert_eq!(coordination.max_background_tasks, 8);
    }

    /// Cancelling with nothing queued leaves the queue empty and succeeds.
    #[tokio::test]
    async fn test_cancel_prefetches() {
        let config = Zero3CpuOffloadConfig::default();
        let pg = init_process_group(BackendType::Gloo, 0, 1, "127.0.0.1", 29500)
            .await
            .unwrap();
        let scheduler = PrefetchScheduler::new(&config, Arc::new(pg));

        scheduler.cancel_all_prefetches().await.unwrap();

        let status = scheduler.get_queue_status();
        assert_eq!(status.queued_requests, 0);
    }
}