// Source file: torsh_distributed/zero_3_cpu_offload/prefetch.rs
1//! Prefetch Scheduling for ZeRO-3 CPU Offloading
2//!
3//! This module implements intelligent prefetch scheduling for ZeRO-3 (Zero Redundancy
4//! Optimizer Stage 3) with CPU offloading capabilities. It provides asynchronous
5//! parameter prefetching, intelligent scheduling, and batch prefetch operations
6//! to minimize memory transfer latency and maximize training throughput.
7
8// Framework infrastructure - components designed for future use
9#![allow(dead_code)]
10#![allow(clippy::await_holding_lock)]
11use crate::{ProcessGroup, TorshResult};
12use log::info;
13use std::collections::VecDeque;
14use std::sync::{Arc, Mutex};
15use tokio::sync::Semaphore;
16
17use super::config::Zero3CpuOffloadConfig;
18
19/// Prefetch scheduler for async parameter loading
20///
21/// Implements intelligent prefetch scheduling including:
22/// - Asynchronous parameter prefetching from CPU to GPU
23/// - Intelligent scheduling based on execution patterns
24/// - Batch prefetch operations with controlled concurrency
25/// - Adaptive prefetch distance based on system resources
26/// - Background prefetch execution with minimal overhead
/// Prefetch scheduler for async parameter loading
///
/// Implements intelligent prefetch scheduling including:
/// - Asynchronous parameter prefetching from CPU to GPU
/// - Intelligent scheduling based on execution patterns
/// - Batch prefetch operations with controlled concurrency
/// - Adaptive prefetch distance based on system resources
/// - Background prefetch execution with minimal overhead
pub struct PrefetchScheduler {
    /// Configuration for prefetch scheduling (cloned at construction)
    config: Zero3CpuOffloadConfig,
    /// Process group for distributed coordination; currently only passed
    /// through to the mock transfer path, which ignores it
    process_group: Arc<ProcessGroup>,
    /// Queue of layers to prefetch; bounded by
    /// `AdaptivePrefetchConfig::max_queue_size` (oldest entries dropped)
    prefetch_queue: Mutex<VecDeque<PrefetchRequest>>,
    /// In-flight prefetch operations; `Arc` so spawned tasks can de-register
    /// themselves when they finish
    active_prefetches: Arc<Mutex<Vec<PrefetchOperation>>>,
    /// Prefetch performance metrics, shared with background tasks
    metrics: Arc<Mutex<PrefetchMetrics>>,
    /// Adaptive prefetch configuration, tuned at runtime by
    /// `adapt_prefetch_strategy` / `calculate_optimal_prefetch_distance`
    adaptive_config: Arc<Mutex<AdaptivePrefetchConfig>>,
    /// Background task coordination state
    /// (NOTE(review): framework scaffolding — not consulted anywhere in this module yet)
    task_coordination: Arc<Mutex<TaskCoordination>>,
}
43
44impl PrefetchScheduler {
45    /// Create a new prefetch scheduler
46    pub fn new(config: &Zero3CpuOffloadConfig, process_group: Arc<ProcessGroup>) -> Self {
47        Self {
48            config: config.clone(),
49            process_group,
50            prefetch_queue: Mutex::new(VecDeque::new()),
51            active_prefetches: Arc::new(Mutex::new(Vec::new())),
52            metrics: Arc::new(Mutex::new(PrefetchMetrics::new())),
53            adaptive_config: Arc::new(Mutex::new(AdaptivePrefetchConfig::new(config))),
54            task_coordination: Arc::new(Mutex::new(TaskCoordination::new())),
55        }
56    }
57
58    /// Schedule a single layer for prefetch
59    ///
60    /// Adds the layer to the prefetch queue and optionally starts immediate prefetch
61    /// if async prefetching is enabled and system resources are available.
62    pub async fn schedule_prefetch(&self, layer_name: &str) -> TorshResult<()> {
63        if !self.config.async_prefetch {
64            return Ok(());
65        }
66
67        let request = PrefetchRequest {
68            layer_name: layer_name.to_string(),
69            priority: PrefetchPriority::Normal,
70            requested_at: std::time::Instant::now(),
71            estimated_size_bytes: self.estimate_layer_size(layer_name),
72        };
73
74        // Add to queue
75        {
76            let mut queue = self
77                .prefetch_queue
78                .lock()
79                .expect("lock should not be poisoned");
80            queue.push_back(request.clone());
81
82            // Maintain queue size limit
83            let max_queue_size = self
84                .adaptive_config
85                .lock()
86                .expect("lock should not be poisoned")
87                .max_queue_size;
88            while queue.len() > max_queue_size {
89                if let Some(dropped) = queue.pop_front() {
90                    info!(
91                        "     Dropped prefetch request for {} (queue full)",
92                        dropped.layer_name
93                    );
94                }
95            }
96        }
97
98        info!(
99            "    Scheduled prefetch for layer: {} ({} bytes)",
100            layer_name, request.estimated_size_bytes
101        );
102
103        // Start async prefetch task
104        self.execute_async_prefetch(request).await?;
105
106        Ok(())
107    }
108
    /// Execute asynchronous prefetching for a single layer
    ///
    /// Registers a `PrefetchOperation` in `active_prefetches`, then spawns a
    /// detached tokio task (the `JoinHandle` is dropped) that performs the
    /// simulated transfer, records metrics, and de-registers the operation.
    ///
    /// NOTE(review): when `can_start_prefetch` returns `false` the request is
    /// skipped here but remains in `prefetch_queue`, and nothing in this
    /// module later drains that queue — confirm drop-on-busy is intended.
    async fn execute_async_prefetch(&self, request: PrefetchRequest) -> TorshResult<()> {
        // Clone the shared handles the background task will need.
        let process_group = self.process_group.clone();
        let metrics = self.metrics.clone();
        let active_prefetches = self.active_prefetches.clone();

        // Check if we can start a new prefetch operation
        if !self.can_start_prefetch().await? {
            info!(
                "   = Delaying prefetch for {} (system busy)",
                request.layer_name
            );
            return Ok(());
        }

        // Create operation tracking
        let operation = PrefetchOperation {
            layer_name: request.layer_name.clone(),
            started_at: std::time::Instant::now(),
            status: PrefetchStatus::InProgress,
        };

        // Add to active operations (registered *before* spawning so the
        // concurrency check sees this operation immediately)
        {
            let mut active = active_prefetches
                .lock()
                .expect("lock should not be poisoned");
            active.push(operation);
        }

        // Spawn background task for prefetching
        let layer_name = request.layer_name.clone();
        tokio::spawn(async move {
            let start_time = std::time::Instant::now();
            let result = Self::prefetch_layer_data(&layer_name, process_group).await;

            // Update metrics (guard scoped to this block; no await inside)
            {
                let mut metrics_guard = metrics.lock().expect("lock should not be poisoned");
                let duration = start_time.elapsed();

                match result {
                    Ok(()) => {
                        metrics_guard.record_successful_prefetch(duration, 0); // Size would be real in production
                        info!(
                            "   = Async prefetch completed for layer: {} in {:?}",
                            layer_name, duration
                        );
                    }
                    Err(e) => {
                        metrics_guard.record_failed_prefetch(duration, e.to_string());
                        tracing::error!("Async prefetch failed for layer {}: {}", layer_name, e);
                    }
                }
            }

            // Remove from active operations
            // (matches by layer name: concurrent prefetches of the *same*
            // layer would all be removed together)
            {
                let mut active = active_prefetches
                    .lock()
                    .expect("lock should not be poisoned");
                active.retain(|op| op.layer_name != layer_name);
            }
        });

        Ok(())
    }
176
177    /// Check if a new prefetch operation can be started
178    async fn can_start_prefetch(&self) -> TorshResult<bool> {
179        let adaptive_config = self
180            .adaptive_config
181            .lock()
182            .expect("lock should not be poisoned");
183        let active_count = self
184            .active_prefetches
185            .lock()
186            .expect("lock should not be poisoned")
187            .len();
188
189        // Check concurrent prefetch limit
190        if active_count >= adaptive_config.max_concurrent_prefetches {
191            return Ok(false);
192        }
193
194        // Check system resource availability
195        // In a real implementation, this would check:
196        // - Available memory bandwidth
197        // - GPU memory availability
198        // - Current CPU/GPU workload
199        // - Network bandwidth for distributed setups
200
201        Ok(true)
202    }
203
204    /// Actually prefetch layer data from CPU to GPU
205    async fn prefetch_layer_data(
206        layer_name: &str,
207        _process_group: Arc<ProcessGroup>,
208    ) -> TorshResult<()> {
209        // In a real implementation, this would:
210        // 1. Check if layer parameters are needed soon
211        // 2. Load parameters from CPU memory to staging buffer
212        // 3. Transfer data to GPU in background
213        // 4. Update GPU cache with prefetched data
214        // 5. Mark parameters as ready for immediate use
215        // 6. Handle prefetch cancellation if needed
216        // 7. Implement memory-efficient transfer strategies
217
218        // Simulate async data transfer with realistic timing
219        let estimated_transfer_time = Self::estimate_transfer_time(layer_name);
220        tokio::time::sleep(estimated_transfer_time).await;
221
222        Ok(())
223    }
224
225    /// Estimate transfer time for a layer based on size and bandwidth
226    fn estimate_transfer_time(layer_name: &str) -> tokio::time::Duration {
227        // Mock estimation based on layer name
228        let base_time_ms = if layer_name.contains("large") {
229            50 // Large layers take 50ms
230        } else if layer_name.contains("medium") {
231            25 // Medium layers take 25ms
232        } else {
233            10 // Small layers take 10ms
234        };
235
236        tokio::time::Duration::from_millis(base_time_ms)
237    }
238
239    /// Execute prefetch for multiple layers in parallel
240    pub async fn batch_prefetch(&self, layer_names: Vec<String>) -> TorshResult<()> {
241        if !self.config.async_prefetch || layer_names.is_empty() {
242            return Ok(());
243        }
244
245        info!(
246            "   = Starting batch prefetch for {} layers",
247            layer_names.len()
248        );
249
250        #[allow(clippy::await_holding_lock)]
251        let adaptive_config = self
252            .adaptive_config
253            .lock()
254            .expect("lock should not be poisoned");
255        let max_concurrent = adaptive_config.max_concurrent_prefetches;
256        drop(adaptive_config);
257
258        // Execute prefetches in parallel with controlled concurrency
259        let semaphore = Arc::new(Semaphore::new(max_concurrent));
260        let mut tasks = Vec::new();
261
262        for layer_name in layer_names {
263            let sem = semaphore.clone();
264            let process_group = self.process_group.clone();
265            let metrics = self.metrics.clone();
266
267            let task = tokio::spawn(async move {
268                let _permit = sem.acquire().await.expect("semaphore should not be closed");
269                let start_time = std::time::Instant::now();
270                let result = Self::prefetch_layer_data(&layer_name, process_group).await;
271
272                // Record metrics
273                {
274                    let mut metrics_guard = metrics.lock().expect("lock should not be poisoned");
275                    let duration = start_time.elapsed();
276                    match result {
277                        Ok(()) => metrics_guard.record_successful_prefetch(duration, 0),
278                        Err(ref e) => metrics_guard.record_failed_prefetch(duration, e.to_string()),
279                    }
280                }
281
282                result
283            });
284
285            tasks.push(task);
286        }
287
288        // Wait for all prefetch tasks to complete
289        let results: Vec<_> = futures::future::join_all(tasks).await;
290
291        let mut successful = 0;
292        let mut failed = 0;
293
294        for result in results {
295            match result {
296                Ok(Ok(())) => successful += 1,
297                Ok(Err(e)) => {
298                    failed += 1;
299                    tracing::error!("Prefetch task failed: {}", e);
300                }
301                Err(e) => {
302                    failed += 1;
303                    tracing::error!("Prefetch task panicked: {}", e);
304                }
305            }
306        }
307
308        info!(
309            "    Batch prefetch completed: {} successful, {} failed",
310            successful, failed
311        );
312
313        // Update batch metrics
314        {
315            let mut metrics = self.metrics.lock().expect("lock should not be poisoned");
316            metrics.record_batch_prefetch(successful, failed);
317        }
318
319        Ok(())
320    }
321
322    /// Intelligent prefetch scheduling based on execution patterns
323    pub async fn intelligent_prefetch(
324        &self,
325        current_layer: &str,
326        execution_graph: &[String],
327    ) -> TorshResult<()> {
328        if !self.config.async_prefetch {
329            return Ok(());
330        }
331
332        // Find current layer position in execution graph
333        let current_pos = execution_graph.iter().position(|l| l == current_layer);
334
335        if let Some(pos) = current_pos {
336            // Determine how many layers ahead to prefetch based on memory availability
337            let prefetch_distance = self.calculate_optimal_prefetch_distance().await?;
338
339            // Collect layers to prefetch with priority assignment
340            let mut layers_to_prefetch = Vec::new();
341            for i in 1..=prefetch_distance {
342                if pos + i < execution_graph.len() {
343                    layers_to_prefetch.push(execution_graph[pos + i].clone());
344                }
345            }
346
347            if !layers_to_prefetch.is_empty() {
348                info!(
349                    "   > Intelligent prefetch: {} layers ahead from {}",
350                    layers_to_prefetch.len(),
351                    current_layer
352                );
353
354                // Use prioritized batch prefetch
355                self.prioritized_batch_prefetch(layers_to_prefetch, pos)
356                    .await?;
357            }
358        }
359
360        Ok(())
361    }
362
363    /// Execute batch prefetch with priority-based scheduling
364    async fn prioritized_batch_prefetch(
365        &self,
366        layer_names: Vec<String>,
367        _current_pos: usize,
368    ) -> TorshResult<()> {
369        let mut prioritized_requests = Vec::new();
370
371        for (i, layer_name) in layer_names.iter().enumerate() {
372            let priority = match i {
373                0 => PrefetchPriority::High,       // Next layer is high priority
374                1..=2 => PrefetchPriority::Normal, // Next 2 layers are normal priority
375                _ => PrefetchPriority::Low,        // Further layers are low priority
376            };
377
378            let request = PrefetchRequest {
379                layer_name: layer_name.clone(),
380                priority,
381                requested_at: std::time::Instant::now(),
382                estimated_size_bytes: self.estimate_layer_size(layer_name),
383            };
384
385            prioritized_requests.push(request);
386        }
387
388        // Sort by priority (high first)
389        prioritized_requests.sort_by(|a, b| b.priority.cmp(&a.priority));
390
391        #[allow(clippy::await_holding_lock)]
392        // Execute prefetches with priority consideration
393        let adaptive_config = self
394            .adaptive_config
395            .lock()
396            .expect("lock should not be poisoned");
397        let max_concurrent = adaptive_config.max_concurrent_prefetches;
398        drop(adaptive_config);
399
400        let semaphore = Arc::new(Semaphore::new(max_concurrent));
401        let mut tasks = Vec::new();
402
403        for request in prioritized_requests {
404            let sem = semaphore.clone();
405            let process_group = self.process_group.clone();
406            let metrics = self.metrics.clone();
407
408            // Higher priority requests get processed first
409            let delay = match request.priority {
410                PrefetchPriority::High => tokio::time::Duration::from_millis(0),
411                PrefetchPriority::Normal => tokio::time::Duration::from_millis(10),
412                PrefetchPriority::Low => tokio::time::Duration::from_millis(25),
413            };
414
415            let task = tokio::spawn(async move {
416                tokio::time::sleep(delay).await; // Priority-based delay
417                let _permit = sem.acquire().await.expect("semaphore should not be closed");
418                let start_time = std::time::Instant::now();
419                let result = Self::prefetch_layer_data(&request.layer_name, process_group).await;
420
421                // Record metrics
422                {
423                    let mut metrics_guard = metrics.lock().expect("lock should not be poisoned");
424                    let duration = start_time.elapsed();
425                    match result {
426                        Ok(()) => metrics_guard
427                            .record_successful_prefetch(duration, request.estimated_size_bytes),
428                        Err(ref e) => metrics_guard.record_failed_prefetch(duration, e.to_string()),
429                    }
430                }
431
432                (request.layer_name, result)
433            });
434
435            tasks.push(task);
436        }
437
438        // Wait for all prioritized prefetch tasks
439        let results: Vec<_> = futures::future::join_all(tasks).await;
440
441        let mut successful = 0;
442        let mut failed = 0;
443
444        for result in results {
445            match result {
446                Ok((layer_name, Ok(()))) => {
447                    successful += 1;
448                    info!("    Prioritized prefetch completed: {}", layer_name);
449                }
450                Ok((layer_name, Err(e))) => {
451                    failed += 1;
452                    tracing::error!("Prioritized prefetch failed for {}: {}", layer_name, e);
453                }
454                Err(e) => {
455                    failed += 1;
456                    tracing::error!("Prioritized prefetch task panicked: {}", e);
457                }
458            }
459        }
460
461        info!(
462            "   < Prioritized batch prefetch completed: {} successful, {} failed",
463            successful, failed
464        );
465
466        Ok(())
467    }
468
469    /// Calculate optimal prefetch distance based on system resources
470    pub async fn calculate_optimal_prefetch_distance(&self) -> TorshResult<usize> {
471        // In a real implementation, this would consider:
472        // 1. Available GPU memory
473        // 2. Network bandwidth for distributed setups
474        // 3. CPU-GPU transfer bandwidth
475        // 4. Historical execution timing
476        // 5. Layer parameter sizes
477        // 6. Current memory pressure
478        // 7. Prefetch success/failure rates
479
480        let adaptive_config = self
481            .adaptive_config
482            .lock()
483            .expect("lock should not be poisoned");
484        let base_distance = adaptive_config.base_prefetch_distance;
485        let current_performance = self
486            .metrics
487            .lock()
488            .expect("lock should not be poisoned")
489            .get_success_rate();
490
491        // Adjust based on recent prefetch performance
492        let performance_multiplier = if current_performance > 0.9 {
493            1.5 // High success rate, increase distance
494        } else if current_performance > 0.7 {
495            1.0 // Normal success rate, keep distance
496        } else {
497            0.7 // Low success rate, reduce distance
498        };
499
500        let optimal_distance = (base_distance as f32 * performance_multiplier) as usize;
501        let optimal_distance = optimal_distance
502            .max(1)
503            .min(adaptive_config.max_prefetch_distance);
504
505        // Update adaptive configuration
506        drop(adaptive_config);
507        {
508            let mut adaptive_config = self
509                .adaptive_config
510                .lock()
511                .expect("lock should not be poisoned");
512            adaptive_config.current_prefetch_distance = optimal_distance;
513        }
514
515        Ok(optimal_distance)
516    }
517
518    /// Adaptive prefetch management based on system performance
519    pub async fn adapt_prefetch_strategy(&self) -> TorshResult<()> {
520        let metrics = self
521            .metrics
522            .lock()
523            .expect("lock should not be poisoned")
524            .clone();
525        let mut adaptive_config = self
526            .adaptive_config
527            .lock()
528            .expect("lock should not be poisoned");
529
530        info!("   <  Adapting prefetch strategy based on performance");
531
532        // Adapt concurrent prefetch limit based on success rate
533        if metrics.get_success_rate() > 0.95 && metrics.total_prefetches > 10 {
534            // High success rate, increase concurrency
535            adaptive_config.max_concurrent_prefetches =
536                (adaptive_config.max_concurrent_prefetches + 1).min(8);
537            info!(
538                "       Increased max concurrent prefetches to {}",
539                adaptive_config.max_concurrent_prefetches
540            );
541        } else if metrics.get_success_rate() < 0.8 && adaptive_config.max_concurrent_prefetches > 1
542        {
543            // Low success rate, decrease concurrency
544            adaptive_config.max_concurrent_prefetches =
545                (adaptive_config.max_concurrent_prefetches - 1).max(1);
546            info!(
547                "       Decreased max concurrent prefetches to {}",
548                adaptive_config.max_concurrent_prefetches
549            );
550        }
551
552        // Adapt queue size based on utilization
553        let queue_size = self
554            .prefetch_queue
555            .lock()
556            .expect("lock should not be poisoned")
557            .len();
558        if queue_size > adaptive_config.max_queue_size * 3 / 4 {
559            // Queue is mostly full, increase size
560            adaptive_config.max_queue_size = (adaptive_config.max_queue_size + 2).min(32);
561            info!(
562                "     = Increased max queue size to {}",
563                adaptive_config.max_queue_size
564            );
565        } else if queue_size < adaptive_config.max_queue_size / 4
566            && adaptive_config.max_queue_size > 4
567        {
568            // Queue is mostly empty, decrease size
569            adaptive_config.max_queue_size = (adaptive_config.max_queue_size - 1).max(4);
570            info!(
571                "     = Decreased max queue size to {}",
572                adaptive_config.max_queue_size
573            );
574        }
575
576        // Adapt prefetch distance based on timing
577        if metrics.average_prefetch_time > tokio::time::Duration::from_millis(100) {
578            // Prefetches are taking too long, reduce distance
579            adaptive_config.base_prefetch_distance =
580                (adaptive_config.base_prefetch_distance - 1).max(1);
581            info!(
582                "     =; Decreased base prefetch distance to {}",
583                adaptive_config.base_prefetch_distance
584            );
585        } else if metrics.average_prefetch_time < tokio::time::Duration::from_millis(20) {
586            // Prefetches are fast, increase distance
587            adaptive_config.base_prefetch_distance =
588                (adaptive_config.base_prefetch_distance + 1).min(16);
589            info!(
590                "     =: Increased base prefetch distance to {}",
591                adaptive_config.base_prefetch_distance
592            );
593        }
594
595        Ok(())
596    }
597
598    /// Cancel all pending prefetch operations
599    pub async fn cancel_all_prefetches(&self) -> TorshResult<()> {
600        info!("   = Cancelling all pending prefetch operations");
601
602        // Clear prefetch queue
603        {
604            let mut queue = self
605                .prefetch_queue
606                .lock()
607                .expect("lock should not be poisoned");
608            let cancelled_count = queue.len();
609            queue.clear();
610            if cancelled_count > 0 {
611                info!(
612                    "     = Cancelled {} queued prefetch requests",
613                    cancelled_count
614                );
615            }
616        }
617
618        // Note: In a real implementation, you would also:
619        // 1. Cancel active prefetch operations
620        // 2. Clean up partial transfers
621        // 3. Update metrics for cancelled operations
622        // 4. Free allocated resources
623
624        {
625            let mut metrics = self.metrics.lock().expect("lock should not be poisoned");
626            metrics.record_cancellation();
627        }
628
629        Ok(())
630    }
631
632    /// Get current prefetch queue status
633    pub fn get_queue_status(&self) -> PrefetchQueueStatus {
634        let queue = self
635            .prefetch_queue
636            .lock()
637            .expect("lock should not be poisoned");
638        let active = self
639            .active_prefetches
640            .lock()
641            .expect("lock should not be poisoned");
642        let adaptive_config = self
643            .adaptive_config
644            .lock()
645            .expect("lock should not be poisoned");
646
647        PrefetchQueueStatus {
648            queued_requests: queue.len(),
649            active_operations: active.len(),
650            max_queue_size: adaptive_config.max_queue_size,
651            max_concurrent: adaptive_config.max_concurrent_prefetches,
652            current_prefetch_distance: adaptive_config.current_prefetch_distance,
653        }
654    }
655
656    /// Get prefetch performance metrics
657    pub fn get_metrics(&self) -> PrefetchMetrics {
658        self.metrics
659            .lock()
660            .expect("lock should not be poisoned")
661            .clone()
662    }
663
664    /// Get adaptive configuration
665    pub fn get_adaptive_config(&self) -> AdaptivePrefetchConfig {
666        self.adaptive_config
667            .lock()
668            .expect("lock should not be poisoned")
669            .clone()
670    }
671
672    // Helper methods
673
674    fn estimate_layer_size(&self, layer_name: &str) -> usize {
675        // Mock size estimation based on layer name
676        if layer_name.contains("large") {
677            64 * 1024 * 1024 // 64MB
678        } else if layer_name.contains("medium") {
679            16 * 1024 * 1024 // 16MB
680        } else {
681            4 * 1024 * 1024 // 4MB
682        }
683    }
684}
685
/// Prefetch request with priority and metadata
///
/// Created for each layer to be prefetched; cloned into the bounded scheduler
/// queue and handed to the background transfer task.
#[derive(Debug, Clone)]
pub struct PrefetchRequest {
    /// Layer name to prefetch
    pub layer_name: String,
    /// Priority of this prefetch request (orders prioritized batch prefetches)
    pub priority: PrefetchPriority,
    /// When this request was made
    pub requested_at: std::time::Instant,
    /// Estimated size in bytes (mock heuristic keyed off the layer name)
    pub estimated_size_bytes: usize,
}
698
/// Priority levels for prefetch operations
///
/// Derives `Ord` so requests can be sorted by urgency; the explicit
/// discriminants establish `Low < Normal < High`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum PrefetchPriority {
    /// Layers several steps ahead in the execution graph
    Low = 0,
    /// Default for directly scheduled requests
    Normal = 1,
    /// The immediately upcoming layer
    High = 2,
}
706
/// Status of a prefetch operation
///
/// Only `InProgress` is assigned by code in this module; the remaining
/// variants are framework scaffolding for future state tracking.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PrefetchStatus {
    /// Waiting in the prefetch queue
    Queued,
    /// Transfer currently running
    InProgress,
    /// Transfer finished successfully
    Completed,
    /// Transfer errored out
    Failed,
    /// Cancelled before completion
    Cancelled,
}
716
/// Active prefetch operation tracking
///
/// One entry lives in `PrefetchScheduler::active_prefetches` per in-flight
/// transfer; entries are matched (and removed) by `layer_name`.
#[derive(Debug, Clone)]
pub struct PrefetchOperation {
    /// Layer being prefetched
    pub layer_name: String,
    /// When prefetch started
    pub started_at: std::time::Instant,
    /// Current status (always `InProgress` while tracked here)
    pub status: PrefetchStatus,
}
727
/// Prefetch queue status information
///
/// Point-in-time snapshot returned by `PrefetchScheduler::get_queue_status`.
#[derive(Debug, Clone)]
pub struct PrefetchQueueStatus {
    /// Number of requests waiting in the queue
    pub queued_requests: usize,
    /// Number of prefetch operations currently in flight
    pub active_operations: usize,
    /// Maximum queue size before the oldest requests are dropped
    pub max_queue_size: usize,
    /// Maximum concurrent prefetch operations
    pub max_concurrent: usize,
    /// Current (adaptively tuned) prefetch distance
    pub current_prefetch_distance: usize,
}
742
/// Adaptive prefetch configuration
///
/// Runtime-tunable knobs adjusted by `adapt_prefetch_strategy` and
/// `calculate_optimal_prefetch_distance`.
#[derive(Debug, Clone)]
pub struct AdaptivePrefetchConfig {
    /// Base prefetch distance (number of layers ahead); adapted within [1, 16]
    pub base_prefetch_distance: usize,
    /// Current dynamic prefetch distance (base scaled by recent success rate)
    pub current_prefetch_distance: usize,
    /// Maximum prefetch distance allowed
    pub max_prefetch_distance: usize,
    /// Maximum concurrent prefetch operations; adapted within [1, 8]
    pub max_concurrent_prefetches: usize,
    /// Maximum queue size; adapted within [4, 32]
    pub max_queue_size: usize,
    /// Whether adaptive optimization is enabled
    /// (NOTE(review): not consulted anywhere in this module yet)
    pub adaptive_optimization_enabled: bool,
}
759
760impl AdaptivePrefetchConfig {
761    pub fn new(config: &Zero3CpuOffloadConfig) -> Self {
762        Self {
763            base_prefetch_distance: config.prefetch_buffer_size / 4,
764            current_prefetch_distance: config.prefetch_buffer_size / 4,
765            max_prefetch_distance: config.prefetch_buffer_size,
766            max_concurrent_prefetches: 4,
767            max_queue_size: 16,
768            adaptive_optimization_enabled: true,
769        }
770    }
771}
772
/// Task coordination for background operations
///
/// Bookkeeping for background-task limits; currently created but not yet
/// consulted by the scheduler logic in this module.
#[derive(Debug)]
pub struct TaskCoordination {
    /// Number of active background tasks
    pub active_tasks: usize,
    /// Maximum allowed background tasks
    pub max_background_tasks: usize,
    /// Whether task coordination is enabled
    pub coordination_enabled: bool,
}

impl TaskCoordination {
    /// Fresh coordination state: no tasks running, up to 8 background tasks,
    /// coordination switched on.
    pub fn new() -> Self {
        TaskCoordination {
            active_tasks: 0,
            max_background_tasks: 8,
            coordination_enabled: true,
        }
    }
}

impl Default for TaskCoordination {
    /// Equivalent to [`TaskCoordination::new`].
    fn default() -> Self {
        TaskCoordination::new()
    }
}
799
/// Prefetch performance metrics
///
/// Updated by the per-operation recorders; shared across background tasks via
/// `Arc<Mutex<_>>` and cloned out for read-only inspection.
#[derive(Debug, Clone)]
pub struct PrefetchMetrics {
    /// Total number of prefetch operations attempted (successes + failures)
    pub total_prefetches: u64,
    /// Number of successful prefetches
    pub successful_prefetches: u64,
    /// Number of failed prefetches
    pub failed_prefetches: u64,
    /// Number of cancellation events (one per `cancel_all_prefetches` call,
    /// not per cancelled request)
    pub cancelled_prefetches: u64,
    /// Total wall-clock time spent prefetching
    pub total_prefetch_time: tokio::time::Duration,
    /// Average prefetch time (refreshed via `update_average_time` — defined
    /// outside this excerpt)
    pub average_prefetch_time: tokio::time::Duration,
    /// Total bytes prefetched (mock paths often record 0)
    pub total_bytes_prefetched: usize,
    /// Number of batch operations
    pub batch_operations: u64,
    /// Batches that contained at least one failed layer
    pub failed_batch_operations: u64,
    /// Most recent failure messages (capped at 10, oldest evicted first)
    pub recent_failures: Vec<String>,
}
824
825impl PrefetchMetrics {
826    pub fn new() -> Self {
827        Self {
828            total_prefetches: 0,
829            successful_prefetches: 0,
830            failed_prefetches: 0,
831            cancelled_prefetches: 0,
832            total_prefetch_time: tokio::time::Duration::ZERO,
833            average_prefetch_time: tokio::time::Duration::ZERO,
834            total_bytes_prefetched: 0,
835            batch_operations: 0,
836            failed_batch_operations: 0,
837            recent_failures: Vec::new(),
838        }
839    }
840
841    /// Record a successful prefetch operation
842    pub fn record_successful_prefetch(&mut self, duration: tokio::time::Duration, bytes: usize) {
843        self.total_prefetches += 1;
844        self.successful_prefetches += 1;
845        self.total_prefetch_time += duration;
846        self.total_bytes_prefetched += bytes;
847        self.update_average_time();
848    }
849
850    /// Record a failed prefetch operation
851    pub fn record_failed_prefetch(&mut self, duration: tokio::time::Duration, error: String) {
852        self.total_prefetches += 1;
853        self.failed_prefetches += 1;
854        self.total_prefetch_time += duration;
855
856        // Keep recent failure reasons (max 10)
857        self.recent_failures.push(error);
858        if self.recent_failures.len() > 10 {
859            self.recent_failures.remove(0);
860        }
861
862        self.update_average_time();
863    }
864
865    /// Record a batch prefetch operation
866    pub fn record_batch_prefetch(&mut self, _successful: usize, failed: usize) {
867        self.batch_operations += 1;
868        if failed > 0 {
869            self.failed_batch_operations += 1;
870        }
871    }
872
873    /// Record prefetch cancellation
874    pub fn record_cancellation(&mut self) {
875        self.cancelled_prefetches += 1;
876    }
877
878    /// Get success rate as a percentage
879    pub fn get_success_rate(&self) -> f32 {
880        if self.total_prefetches > 0 {
881            self.successful_prefetches as f32 / self.total_prefetches as f32
882        } else {
883            1.0 // No operations yet, assume 100% success
884        }
885    }
886
887    /// Get failure rate as a percentage
888    pub fn get_failure_rate(&self) -> f32 {
889        if self.total_prefetches > 0 {
890            self.failed_prefetches as f32 / self.total_prefetches as f32
891        } else {
892            0.0
893        }
894    }
895
896    /// Get average throughput in bytes per second
897    pub fn get_throughput_bps(&self) -> f64 {
898        if !self.total_prefetch_time.is_zero() {
899            self.total_bytes_prefetched as f64 / self.total_prefetch_time.as_secs_f64()
900        } else {
901            0.0
902        }
903    }
904
905    /// Update average prefetch time
906    fn update_average_time(&mut self) {
907        if self.total_prefetches > 0 {
908            self.average_prefetch_time = self.total_prefetch_time / self.total_prefetches as u32;
909        }
910    }
911
912    /// Reset all metrics
913    pub fn reset(&mut self) {
914        *self = Self::new();
915    }
916}
917
918impl Default for PrefetchMetrics {
919    fn default() -> Self {
920        Self::new()
921    }
922}
923
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{init_process_group, BackendType};

    // NOTE(review): every async test below binds a Gloo process group on
    // 127.0.0.1:29500 — if the harness runs them concurrently they may
    // contend for the same port; confirm the runner serializes these tests
    // or stagger the ports.

    // Sorting requests descending by `priority` must order High > Normal > Low.
    #[test]
    fn test_prefetch_request_priority_ordering() {
        let mut requests = [
            PrefetchRequest {
                layer_name: "low".to_string(),
                priority: PrefetchPriority::Low,
                requested_at: std::time::Instant::now(),
                estimated_size_bytes: 1000,
            },
            PrefetchRequest {
                layer_name: "high".to_string(),
                priority: PrefetchPriority::High,
                requested_at: std::time::Instant::now(),
                estimated_size_bytes: 1000,
            },
            PrefetchRequest {
                layer_name: "normal".to_string(),
                priority: PrefetchPriority::Normal,
                requested_at: std::time::Instant::now(),
                estimated_size_bytes: 1000,
            },
        ];

        // `b.cmp(&a)` yields a descending sort, so the highest-priority
        // request lands at index 0.
        requests.sort_by(|a, b| b.priority.cmp(&a.priority));
        assert_eq!(requests[0].layer_name, "high");
        assert_eq!(requests[1].layer_name, "normal");
        assert_eq!(requests[2].layer_name, "low");
    }

    // The adaptive config derives its distances from the configured buffer
    // size: base = buffer / 4, max = buffer.
    #[test]
    fn test_adaptive_prefetch_config() {
        let zero3_config = Zero3CpuOffloadConfig::default();
        let config = AdaptivePrefetchConfig::new(&zero3_config);

        assert_eq!(
            config.base_prefetch_distance,
            zero3_config.prefetch_buffer_size / 4
        );
        assert_eq!(
            config.max_prefetch_distance,
            zero3_config.prefetch_buffer_size
        );
        assert!(config.adaptive_optimization_enabled);
    }

    // Success rate is a fraction in [0.0, 1.0] (1.0 = all succeeded), and a
    // failure reason is retained after a failed prefetch.
    #[test]
    fn test_prefetch_metrics() {
        let mut metrics = PrefetchMetrics::new();

        metrics.record_successful_prefetch(tokio::time::Duration::from_millis(100), 1000);
        assert_eq!(metrics.total_prefetches, 1);
        assert_eq!(metrics.successful_prefetches, 1);
        assert_eq!(metrics.get_success_rate(), 1.0);

        metrics.record_failed_prefetch(
            tokio::time::Duration::from_millis(50),
            "test error".to_string(),
        );
        assert_eq!(metrics.total_prefetches, 2);
        assert_eq!(metrics.failed_prefetches, 1);
        assert_eq!(metrics.get_success_rate(), 0.5);
        assert_eq!(metrics.recent_failures.len(), 1);
    }

    // A freshly created scheduler starts with empty queue/active-op state
    // and zeroed metrics.
    #[tokio::test]
    async fn test_prefetch_scheduler_creation() {
        let config = Zero3CpuOffloadConfig::default();
        let pg = init_process_group(BackendType::Gloo, 0, 1, "127.0.0.1", 29500)
            .await
            .unwrap();
        let scheduler = PrefetchScheduler::new(&config, Arc::new(pg));

        let status = scheduler.get_queue_status();
        assert_eq!(status.queued_requests, 0);
        assert_eq!(status.active_operations, 0);

        let metrics = scheduler.get_metrics();
        assert_eq!(metrics.total_prefetches, 0);
    }

    // The computed optimal prefetch distance stays within [1, buffer size].
    #[tokio::test]
    async fn test_prefetch_distance_calculation() {
        let config = Zero3CpuOffloadConfig::default();
        let pg = init_process_group(BackendType::Gloo, 0, 1, "127.0.0.1", 29500)
            .await
            .unwrap();
        let scheduler = PrefetchScheduler::new(&config, Arc::new(pg));

        let distance = scheduler
            .calculate_optimal_prefetch_distance()
            .await
            .unwrap();
        assert!(distance >= 1);
        assert!(distance <= config.prefetch_buffer_size);
    }

    // A batch prefetch of two layers counts as exactly one batch operation.
    #[tokio::test]
    async fn test_batch_prefetch() {
        let config = Zero3CpuOffloadConfig::default();
        let pg = init_process_group(BackendType::Gloo, 0, 1, "127.0.0.1", 29500)
            .await
            .unwrap();
        let scheduler = PrefetchScheduler::new(&config, Arc::new(pg));

        let layers = vec!["layer1".to_string(), "layer2".to_string()];
        scheduler.batch_prefetch(layers).await.unwrap();

        let metrics = scheduler.get_metrics();
        assert_eq!(metrics.batch_operations, 1);
    }

    // Defaults for TaskCoordination: nothing active, coordination on, cap 8.
    #[test]
    fn test_task_coordination() {
        let coordination = TaskCoordination::new();
        assert_eq!(coordination.active_tasks, 0);
        assert!(coordination.coordination_enabled);
        assert_eq!(coordination.max_background_tasks, 8);
    }

    // Cancelling with an empty queue succeeds and leaves the queue empty.
    #[tokio::test]
    async fn test_cancel_prefetches() {
        let config = Zero3CpuOffloadConfig::default();
        let pg = init_process_group(BackendType::Gloo, 0, 1, "127.0.0.1", 29500)
            .await
            .unwrap();
        let scheduler = PrefetchScheduler::new(&config, Arc::new(pg));

        // Add some requests to queue (we'll mock this)
        scheduler.cancel_all_prefetches().await.unwrap();

        let status = scheduler.get_queue_status();
        assert_eq!(status.queued_requests, 0);
    }
}