Skip to main content

trustformers_models/
performance_optimization.rs

1//! Performance Optimization Utilities
2//!
3//! This module provides performance optimization utilities for model inference,
4//! including batch processing, memory optimization, and caching strategies.
5
6use serde::{Deserialize, Serialize};
7use std::collections::{HashMap, VecDeque};
8use trustformers_core::errors::{Result, TrustformersError};
9use trustformers_core::Tensor;
10
11/// Configuration for performance optimization
12#[derive(Debug, Clone, Serialize, Deserialize)]
13pub struct PerformanceConfig {
14    /// Maximum batch size for inference
15    pub max_batch_size: usize,
16    /// Whether to enable dynamic batching
17    pub enable_dynamic_batching: bool,
18    /// Cache size for frequently used tensors
19    pub cache_size: usize,
20    /// Whether to enable memory optimization
21    pub enable_memory_optimization: bool,
22    /// Number of threads for parallel processing
23    pub num_threads: Option<usize>,
24}
25
26impl Default for PerformanceConfig {
27    fn default() -> Self {
28        Self {
29            max_batch_size: 32,
30            enable_dynamic_batching: true,
31            cache_size: 1000,
32            enable_memory_optimization: true,
33            num_threads: None, // Use system default
34        }
35    }
36}
37
38/// LRU Cache implementation for tensors
39#[derive(Debug)]
40pub struct LruCache {
41    capacity: usize,
42    cache: HashMap<String, (Tensor, usize)>, // (tensor, access_order)
43    access_order: usize,
44    access_history: VecDeque<String>,
45    hits: usize,
46    misses: usize,
47}
48
49impl LruCache {
50    pub fn new(capacity: usize) -> Self {
51        Self {
52            capacity,
53            cache: HashMap::new(),
54            access_order: 0,
55            access_history: VecDeque::new(),
56            hits: 0,
57            misses: 0,
58        }
59    }
60
61    pub fn get(&mut self, key: &str) -> Option<&Tensor> {
62        if let Some((tensor, _)) = self.cache.get(key).cloned() {
63            self.access_order += 1;
64            self.cache.insert(key.to_string(), (tensor, self.access_order));
65            self.hits += 1;
66            self.cache.get(key).map(|(tensor, _)| tensor)
67        } else {
68            self.misses += 1;
69            None
70        }
71    }
72
73    pub fn put(&mut self, key: String, tensor: Tensor) {
74        if self.cache.len() >= self.capacity && !self.cache.contains_key(&key) {
75            self.evict_lru();
76        }
77
78        self.access_order += 1;
79        self.cache.insert(key.clone(), (tensor, self.access_order));
80        self.access_history.push_back(key);
81
82        // Keep access history manageable
83        if self.access_history.len() > self.capacity * 2 {
84            self.access_history.pop_front();
85        }
86    }
87
88    fn evict_lru(&mut self) {
89        if let Some(lru_key) = self.find_lru_key() {
90            self.cache.remove(&lru_key);
91        }
92    }
93
94    fn find_lru_key(&self) -> Option<String> {
95        self.cache
96            .iter()
97            .min_by_key(|(_, (_, access_order))| *access_order)
98            .map(|(key, _)| key.clone())
99    }
100
101    pub fn clear(&mut self) {
102        self.cache.clear();
103        self.access_history.clear();
104        self.access_order = 0;
105        self.hits = 0;
106        self.misses = 0;
107    }
108
109    pub fn len(&self) -> usize {
110        self.cache.len()
111    }
112
113    pub fn hit_rate(&self) -> f64 {
114        let total = self.hits + self.misses;
115        if total > 0 {
116            self.hits as f64 / total as f64
117        } else {
118            0.0
119        }
120    }
121
122    pub fn statistics(&self) -> CacheStatistics {
123        CacheStatistics {
124            current_size: self.cache.len(),
125            max_size: self.capacity,
126            hit_rate: self.hit_rate(),
127        }
128    }
129}
130
131/// Batch processor for efficient inference
132#[derive(Debug)]
133pub struct BatchProcessor {
134    config: PerformanceConfig,
135    cache: LruCache,
136    batch_buffer: Vec<Tensor>,
137}
138
139impl BatchProcessor {
140    /// Create a new batch processor
141    pub fn new(config: PerformanceConfig) -> Self {
142        Self {
143            cache: LruCache::new(config.cache_size),
144            config,
145            batch_buffer: Vec::new(),
146        }
147    }
148
149    /// Add a tensor to the current batch
150    pub fn add_to_batch(&mut self, tensor: Tensor) -> Result<Option<Vec<Tensor>>> {
151        self.batch_buffer.push(tensor);
152
153        if self.batch_buffer.len() >= self.config.max_batch_size {
154            Ok(Some(self.flush_batch()?))
155        } else {
156            Ok(None)
157        }
158    }
159
160    /// Flush the current batch and return it
161    pub fn flush_batch(&mut self) -> Result<Vec<Tensor>> {
162        let batch = std::mem::take(&mut self.batch_buffer);
163        Ok(batch)
164    }
165
166    /// Cache a tensor with a given key
167    pub fn cache_tensor(&mut self, key: String, tensor: Tensor) -> Result<()> {
168        self.cache.put(key, tensor);
169        Ok(())
170    }
171
172    /// Cache statistics
173    pub fn cache_stats(&self) -> CacheStatistics {
174        self.cache.statistics()
175    }
176
177    /// Retrieve a cached tensor
178    pub fn get_cached_tensor(&mut self, key: &str) -> Option<&Tensor> {
179        self.cache.get(key)
180    }
181
182    /// Clear the cache
183    pub fn clear_cache(&mut self) {
184        self.cache.clear();
185    }
186
187    /// Get current batch size
188    pub fn current_batch_size(&self) -> usize {
189        self.batch_buffer.len()
190    }
191}
192
193/// Memory optimization utilities
194pub struct MemoryOptimizer;
195
196impl MemoryOptimizer {
197    /// Optimize tensor memory layout for better cache performance
198    pub fn optimize_memory_layout(tensors: &mut [Tensor]) -> Result<()> {
199        // Sort tensors by size (larger tensors first) for better memory allocation patterns
200        tensors.sort_by(|a, b| {
201            let size_a = a.shape().iter().product::<usize>();
202            let size_b = b.shape().iter().product::<usize>();
203            size_b.cmp(&size_a) // Descending order
204        });
205
206        // Apply memory layout optimizations per tensor
207        for tensor in tensors.iter_mut() {
208            Self::optimize_single_tensor_layout(tensor)?;
209        }
210
211        Ok(())
212    }
213
214    /// Optimize memory layout for a single tensor
215    fn optimize_single_tensor_layout(tensor: &mut Tensor) -> Result<()> {
216        match tensor {
217            Tensor::F32(ref mut data)
218                // For multidimensional tensors, consider reshaping for better cache locality
219                // This is a simplified optimization - in practice, you'd analyze access patterns
220                if data.ndim() > 2
221                    // Ensure the tensor is in contiguous memory layout
222                    && !data.is_standard_layout() => {
223                        let owned = data.to_owned();
224                        *data = owned;
225                    },
226            Tensor::I64(ref mut data)
227                // Similar optimization for integer tensors
228                if data.ndim() > 2 && !data.is_standard_layout() => {
229                    let owned = data.to_owned();
230                    *data = owned;
231                },
232            _ => {
233                // For other tensor types, ensure standard layout if possible
234            },
235        }
236        Ok(())
237    }
238
239    /// Analyze memory access patterns and suggest optimizations
240    pub fn analyze_memory_patterns(tensors: &[Tensor]) -> Vec<String> {
241        let mut recommendations = Vec::new();
242
243        // Check for fragmentation patterns
244        let total_elements: usize =
245            tensors.iter().map(|t| t.shape().iter().product::<usize>()).sum();
246
247        if total_elements > 1_000_000 {
248            recommendations
249                .push("Consider using memory pooling for large tensor operations".to_string());
250        }
251
252        // Check for small tensor overhead
253        let small_tensors =
254            tensors.iter().filter(|t| t.shape().iter().product::<usize>() < 1000).count();
255
256        if small_tensors > 10 {
257            recommendations
258                .push("Consider tensor batching to reduce small tensor overhead".to_string());
259        }
260
261        // Check tensor alignment and suggest SIMD optimization
262        for (i, tensor) in tensors.iter().enumerate() {
263            let shape = tensor.shape();
264            if shape.len() >= 2 {
265                let last_dim = shape[shape.len() - 1];
266                if last_dim % 4 != 0 {
267                    recommendations.push(format!(
268                        "Tensor {} last dimension ({}) not aligned for SIMD operations",
269                        i, last_dim
270                    ));
271                }
272            }
273        }
274
275        recommendations
276    }
277
278    /// Estimate memory usage for a batch of tensors
279    pub fn estimate_memory_usage(tensors: &[Tensor]) -> Result<usize> {
280        let mut total_bytes = 0;
281
282        for tensor in tensors {
283            let shape = tensor.shape();
284            let elements = shape.iter().product::<usize>();
285            // Assuming f32 elements (4 bytes each)
286            total_bytes += elements * 4;
287        }
288
289        Ok(total_bytes)
290    }
291
292    /// Check if a batch fits within memory constraints
293    pub fn check_memory_constraints(tensors: &[Tensor], max_memory_mb: usize) -> Result<bool> {
294        let estimated_bytes = Self::estimate_memory_usage(tensors)?;
295        let max_bytes = max_memory_mb * 1024 * 1024;
296        Ok(estimated_bytes <= max_bytes)
297    }
298}
299
300/// Dynamic batching strategy
301#[derive(Debug, Clone, Serialize, Deserialize)]
302pub enum BatchingStrategy {
303    /// Fixed batch size
304    Fixed(usize),
305    /// Dynamic batching based on sequence length
306    DynamicByLength {
307        max_length: usize,
308        max_batch_size: usize,
309    },
310    /// Dynamic batching based on memory constraints
311    DynamicByMemory { max_memory_mb: usize },
312    /// Adaptive batching that adjusts based on performance metrics
313    Adaptive {
314        initial_batch_size: usize,
315        max_batch_size: usize,
316        target_latency_ms: f64,
317        adjustment_factor: f64,
318    },
319    /// Priority-based batching with different priorities
320    PriorityBased {
321        high_priority_batch_size: usize,
322        normal_priority_batch_size: usize,
323        low_priority_batch_size: usize,
324    },
325}
326
327/// Dynamic batch manager
328#[derive(Debug)]
329pub struct DynamicBatchManager {
330    strategy: BatchingStrategy,
331    pending_tensors: Vec<(Tensor, usize)>, // (tensor, priority)
332    current_batch_size: usize,
333    recent_latencies: VecDeque<f64>,
334    total_batches_processed: usize,
335}
336
337impl DynamicBatchManager {
338    /// Create a new dynamic batch manager
339    pub fn new(strategy: BatchingStrategy) -> Self {
340        let initial_batch_size = match &strategy {
341            BatchingStrategy::Fixed(size) => *size,
342            BatchingStrategy::DynamicByLength { max_batch_size, .. } => *max_batch_size / 2,
343            BatchingStrategy::DynamicByMemory { .. } => 16,
344            BatchingStrategy::Adaptive {
345                initial_batch_size, ..
346            } => *initial_batch_size,
347            BatchingStrategy::PriorityBased {
348                normal_priority_batch_size,
349                ..
350            } => *normal_priority_batch_size,
351        };
352
353        Self {
354            strategy,
355            pending_tensors: Vec::new(),
356            current_batch_size: initial_batch_size,
357            recent_latencies: VecDeque::new(),
358            total_batches_processed: 0,
359        }
360    }
361
362    /// Record latency for adaptive batching
363    pub fn record_latency(&mut self, latency_ms: f64) {
364        self.recent_latencies.push_back(latency_ms);
365
366        // Keep only recent latencies (last 20 batches)
367        if self.recent_latencies.len() > 20 {
368            self.recent_latencies.pop_front();
369        }
370
371        self.total_batches_processed += 1;
372
373        // Adjust batch size for adaptive strategy
374        if let BatchingStrategy::Adaptive {
375            target_latency_ms,
376            max_batch_size,
377            adjustment_factor,
378            ..
379        } = &self.strategy
380        {
381            if self.recent_latencies.len() >= 5 {
382                let avg_latency: f64 =
383                    self.recent_latencies.iter().sum::<f64>() / self.recent_latencies.len() as f64;
384
385                if avg_latency > *target_latency_ms {
386                    // Latency too high, reduce batch size
387                    self.current_batch_size = std::cmp::max(
388                        1,
389                        (self.current_batch_size as f64 * (1.0 - adjustment_factor)) as usize,
390                    );
391                } else if avg_latency < *target_latency_ms * 0.8 {
392                    // Latency acceptable, can increase batch size
393                    self.current_batch_size = std::cmp::min(
394                        *max_batch_size,
395                        (self.current_batch_size as f64 * (1.0 + adjustment_factor)) as usize,
396                    );
397                }
398            }
399        }
400    }
401
402    /// Add a tensor to the pending queue with priority
403    pub fn add_tensor(&mut self, tensor: Tensor, priority: usize) -> Result<()> {
404        self.pending_tensors.push((tensor, priority));
405        // Sort by priority (higher priority first)
406        self.pending_tensors.sort_by_key(|item| std::cmp::Reverse(item.1));
407        Ok(())
408    }
409
410    /// Get the next optimal batch based on the strategy
411    pub fn get_next_batch(&mut self) -> Result<Option<Vec<Tensor>>> {
412        if self.pending_tensors.is_empty() {
413            return Ok(None);
414        }
415
416        match &self.strategy {
417            BatchingStrategy::Fixed(batch_size) => {
418                if self.pending_tensors.len() >= *batch_size {
419                    let batch: Vec<Tensor> = self
420                        .pending_tensors
421                        .drain(0..*batch_size)
422                        .map(|(tensor, _)| tensor)
423                        .collect();
424                    Ok(Some(batch))
425                } else {
426                    Ok(None)
427                }
428            },
429            BatchingStrategy::DynamicByLength {
430                max_length: _,
431                max_batch_size,
432            } => {
433                let batch_size = std::cmp::min(self.pending_tensors.len(), *max_batch_size);
434                if batch_size > 0 {
435                    let batch: Vec<Tensor> = self
436                        .pending_tensors
437                        .drain(0..batch_size)
438                        .map(|(tensor, _)| tensor)
439                        .collect();
440                    Ok(Some(batch))
441                } else {
442                    Ok(None)
443                }
444            },
445            BatchingStrategy::DynamicByMemory { max_memory_mb } => {
446                let mut batch = Vec::new();
447                let mut current_memory = 0;
448
449                while !self.pending_tensors.is_empty() {
450                    let tensor_memory = self.estimate_tensor_memory(&self.pending_tensors[0].0)?;
451                    if current_memory + tensor_memory <= *max_memory_mb * 1024 * 1024 {
452                        let (tensor, _) = self.pending_tensors.remove(0);
453                        batch.push(tensor);
454                        current_memory += tensor_memory;
455                    } else {
456                        break;
457                    }
458                }
459
460                if batch.is_empty() {
461                    Ok(None)
462                } else {
463                    Ok(Some(batch))
464                }
465            },
466            BatchingStrategy::Adaptive { .. } => {
467                if self.pending_tensors.len() >= self.current_batch_size {
468                    let batch: Vec<Tensor> = self
469                        .pending_tensors
470                        .drain(0..self.current_batch_size)
471                        .map(|(tensor, _)| tensor)
472                        .collect();
473                    Ok(Some(batch))
474                } else {
475                    Ok(None)
476                }
477            },
478            BatchingStrategy::PriorityBased {
479                high_priority_batch_size,
480                normal_priority_batch_size,
481                low_priority_batch_size,
482            } => {
483                // Group by priority
484                let high_priority: Vec<_> = self
485                    .pending_tensors
486                    .iter()
487                    .filter(|(_, priority)| *priority >= 80)
488                    .cloned()
489                    .collect();
490                let normal_priority: Vec<_> = self
491                    .pending_tensors
492                    .iter()
493                    .filter(|(_, priority)| *priority >= 40 && *priority < 80)
494                    .cloned()
495                    .collect();
496                let low_priority: Vec<_> = self
497                    .pending_tensors
498                    .iter()
499                    .filter(|(_, priority)| *priority < 40)
500                    .cloned()
501                    .collect();
502
503                if high_priority.len() >= *high_priority_batch_size {
504                    let batch: Vec<Tensor> = high_priority
505                        .into_iter()
506                        .take(*high_priority_batch_size)
507                        .map(|(tensor, _)| tensor)
508                        .collect();
509                    // Remove processed tensors
510                    self.pending_tensors.retain(|(_, priority)| *priority < 80);
511                    Ok(Some(batch))
512                } else if normal_priority.len() >= *normal_priority_batch_size {
513                    let batch: Vec<Tensor> = normal_priority
514                        .into_iter()
515                        .take(*normal_priority_batch_size)
516                        .map(|(tensor, _)| tensor)
517                        .collect();
518                    // Remove processed tensors
519                    self.pending_tensors.retain(|(_, priority)| *priority < 40 || *priority >= 80);
520                    Ok(Some(batch))
521                } else if low_priority.len() >= *low_priority_batch_size {
522                    let batch: Vec<Tensor> = low_priority
523                        .into_iter()
524                        .take(*low_priority_batch_size)
525                        .map(|(tensor, _)| tensor)
526                        .collect();
527                    // Remove processed tensors
528                    self.pending_tensors.retain(|(_, priority)| *priority >= 40);
529                    Ok(Some(batch))
530                } else {
531                    Ok(None)
532                }
533            },
534        }
535    }
536
537    /// Estimate memory usage for a single tensor
538    fn estimate_tensor_memory(&self, tensor: &Tensor) -> Result<usize> {
539        let shape = tensor.shape();
540        let elements = shape.iter().product::<usize>();
541        // Assuming f32 elements (4 bytes each)
542        Ok(elements * 4)
543    }
544
545    /// Get number of pending tensors
546    pub fn pending_count(&self) -> usize {
547        self.pending_tensors.len()
548    }
549
550    /// Get current batch size for adaptive strategies
551    pub fn current_batch_size(&self) -> usize {
552        self.current_batch_size
553    }
554
555    /// Get average latency for performance analysis
556    pub fn average_latency(&self) -> f64 {
557        if self.recent_latencies.is_empty() {
558            0.0
559        } else {
560            self.recent_latencies.iter().sum::<f64>() / self.recent_latencies.len() as f64
561        }
562    }
563
564    /// Get batch processing statistics
565    pub fn get_batch_statistics(&self) -> BatchStatistics {
566        BatchStatistics {
567            total_batches_processed: self.total_batches_processed,
568            current_batch_size: self.current_batch_size,
569            pending_tensors: self.pending_tensors.len(),
570            average_latency_ms: self.average_latency(),
571            strategy_type: match &self.strategy {
572                BatchingStrategy::Fixed(_) => "Fixed".to_string(),
573                BatchingStrategy::DynamicByLength { .. } => "DynamicByLength".to_string(),
574                BatchingStrategy::DynamicByMemory { .. } => "DynamicByMemory".to_string(),
575                BatchingStrategy::Adaptive { .. } => "Adaptive".to_string(),
576                BatchingStrategy::PriorityBased { .. } => "PriorityBased".to_string(),
577            },
578        }
579    }
580}
581
582/// Performance monitoring utilities
583#[derive(Debug, Default)]
584pub struct PerformanceMonitor {
585    total_inference_time: f64,
586    total_inferences: usize,
587    batch_sizes: Vec<usize>,
588    memory_usage: Vec<usize>,
589}
590
591impl PerformanceMonitor {
592    /// Record an inference time
593    pub fn record_inference(&mut self, time_ms: f64, batch_size: usize, memory_usage: usize) {
594        self.total_inference_time += time_ms;
595        self.total_inferences += 1;
596        self.batch_sizes.push(batch_size);
597        self.memory_usage.push(memory_usage);
598    }
599
600    /// Get average inference time
601    pub fn average_inference_time(&self) -> f64 {
602        if self.total_inferences > 0 {
603            self.total_inference_time / self.total_inferences as f64
604        } else {
605            0.0
606        }
607    }
608
609    /// Get average batch size
610    pub fn average_batch_size(&self) -> f64 {
611        if self.batch_sizes.is_empty() {
612            0.0
613        } else {
614            self.batch_sizes.iter().sum::<usize>() as f64 / self.batch_sizes.len() as f64
615        }
616    }
617
618    /// Get peak memory usage
619    pub fn peak_memory_usage(&self) -> usize {
620        self.memory_usage.iter().max().copied().unwrap_or(0)
621    }
622
623    /// Get performance statistics
624    pub fn get_statistics(&self) -> PerformanceStatistics {
625        PerformanceStatistics {
626            total_inferences: self.total_inferences,
627            average_inference_time_ms: self.average_inference_time(),
628            average_batch_size: self.average_batch_size(),
629            peak_memory_usage_bytes: self.peak_memory_usage(),
630            throughput_inferences_per_second: if self.total_inference_time > 0.0 {
631                (self.total_inferences as f64) / (self.total_inference_time / 1000.0)
632            } else {
633                0.0
634            },
635        }
636    }
637}
638
639/// Cache statistics
640#[derive(Debug, Clone, Serialize, Deserialize)]
641pub struct CacheStatistics {
642    pub current_size: usize,
643    pub max_size: usize,
644    pub hit_rate: f64,
645}
646
647/// Performance statistics
648#[derive(Debug, Clone, Serialize, Deserialize)]
649pub struct PerformanceStatistics {
650    pub total_inferences: usize,
651    pub average_inference_time_ms: f64,
652    pub average_batch_size: f64,
653    pub peak_memory_usage_bytes: usize,
654    pub throughput_inferences_per_second: f64,
655}
656
657/// Advanced performance optimizer with workload analysis
658#[derive(Debug)]
659pub struct AdvancedPerformanceOptimizer {
660    #[allow(dead_code)]
661    config: PerformanceConfig,
662    workload_history: Vec<WorkloadMetrics>,
663    optimization_recommendations: Vec<String>,
664}
665
/// One observed inference workload, fed to `AdvancedPerformanceOptimizer::record_workload`.
#[derive(Debug)]
#[derive(Clone)]
pub struct WorkloadMetrics {
    /// Number of inputs processed together.
    pub batch_size: usize,
    /// Sequence length of the workload.
    pub sequence_length: usize,
    /// Memory consumed by the workload. NOTE(review): the optimizer reports this as bytes —
    /// confirm callers pass bytes.
    pub memory_usage: usize,
    /// Inference duration, in milliseconds.
    pub inference_time_ms: f64,
    /// Moment the sample was captured.
    pub timestamp: std::time::Instant,
}
675
676impl AdvancedPerformanceOptimizer {
677    /// Create a new advanced optimizer
678    pub fn new(config: PerformanceConfig) -> Self {
679        Self {
680            config,
681            workload_history: Vec::new(),
682            optimization_recommendations: Vec::new(),
683        }
684    }
685
686    /// Record workload metrics
687    pub fn record_workload(&mut self, metrics: WorkloadMetrics) {
688        self.workload_history.push(metrics);
689
690        // Keep only recent history (last 1000 entries)
691        if self.workload_history.len() > 1000 {
692            self.workload_history.remove(0);
693        }
694
695        // Generate recommendations based on patterns
696        self.generate_recommendations();
697    }
698
699    /// Generate optimization recommendations
700    fn generate_recommendations(&mut self) {
701        self.optimization_recommendations.clear();
702
703        if self.workload_history.len() < 10 {
704            return;
705        }
706
707        // Analyze recent performance patterns
708        let recent_metrics: Vec<_> = self.workload_history.iter().rev().take(50).collect();
709
710        // Check for small batch sizes
711        let avg_batch_size: f64 = recent_metrics.iter().map(|m| m.batch_size as f64).sum::<f64>()
712            / recent_metrics.len() as f64;
713
714        if avg_batch_size < 8.0 {
715            self.optimization_recommendations
716                .push("Consider increasing batch size for better throughput".to_string());
717        }
718
719        // Check for high memory usage variation
720        let memory_usages: Vec<usize> = recent_metrics.iter().map(|m| m.memory_usage).collect();
721        let max_memory = memory_usages.iter().max().unwrap_or(&0);
722        let min_memory = memory_usages.iter().min().unwrap_or(&0);
723
724        if *max_memory > min_memory * 2 {
725            self.optimization_recommendations.push(
726                "High memory usage variation detected - consider dynamic batching".to_string(),
727            );
728        }
729
730        // Check for performance degradation
731        if recent_metrics.len() >= 20 {
732            let first_half_avg: f64 =
733                recent_metrics[10..].iter().map(|m| m.inference_time_ms).sum::<f64>() / 10.0;
734            let second_half_avg: f64 =
735                recent_metrics[..10].iter().map(|m| m.inference_time_ms).sum::<f64>() / 10.0;
736
737            if second_half_avg > first_half_avg * 1.2 {
738                self.optimization_recommendations.push(
739                    "Performance degradation detected - consider cache clearing or model reloading"
740                        .to_string(),
741                );
742            }
743        }
744    }
745
746    /// Get current optimization recommendations
747    pub fn get_recommendations(&self) -> &[String] {
748        &self.optimization_recommendations
749    }
750
751    /// Get workload analysis summary
752    pub fn get_workload_analysis(&self) -> WorkloadAnalysis {
753        if self.workload_history.is_empty() {
754            return WorkloadAnalysis::default();
755        }
756
757        let total_metrics = self.workload_history.len();
758        let avg_batch_size = self.workload_history.iter().map(|m| m.batch_size as f64).sum::<f64>()
759            / total_metrics as f64;
760
761        let avg_inference_time =
762            self.workload_history.iter().map(|m| m.inference_time_ms).sum::<f64>()
763                / total_metrics as f64;
764
765        let peak_memory = self.workload_history.iter().map(|m| m.memory_usage).max().unwrap_or(0);
766
767        WorkloadAnalysis {
768            total_samples: total_metrics,
769            average_batch_size: avg_batch_size,
770            average_inference_time_ms: avg_inference_time,
771            peak_memory_usage_bytes: peak_memory,
772            recommendations_count: self.optimization_recommendations.len(),
773        }
774    }
775}
776
777/// Workload analysis summary
778#[derive(Debug, Default, Clone, Serialize, Deserialize)]
779pub struct WorkloadAnalysis {
780    pub total_samples: usize,
781    pub average_batch_size: f64,
782    pub average_inference_time_ms: f64,
783    pub peak_memory_usage_bytes: usize,
784    pub recommendations_count: usize,
785}
786
787/// Batch processing statistics
788#[derive(Debug, Clone, Serialize, Deserialize)]
789pub struct BatchStatistics {
790    pub total_batches_processed: usize,
791    pub current_batch_size: usize,
792    pub pending_tensors: usize,
793    pub average_latency_ms: f64,
794    pub strategy_type: String,
795}
796
797#[cfg(test)]
798mod tests {
799    use super::*;
800
801    #[test]
802    fn test_performance_config_default() {
803        let config = PerformanceConfig::default();
804        assert_eq!(config.max_batch_size, 32);
805        assert!(config.enable_dynamic_batching);
806        assert_eq!(config.cache_size, 1000);
807        assert!(config.enable_memory_optimization);
808    }
809
810    #[test]
811    fn test_batch_processor_creation() {
812        let config = PerformanceConfig::default();
813        let processor = BatchProcessor::new(config);
814        assert_eq!(processor.current_batch_size(), 0);
815    }
816
817    #[test]
818    fn test_memory_optimizer_estimate() {
819        // Create a simple test tensor
820        let tensor = Tensor::zeros(&[2, 3]).expect("operation failed");
821        let tensors = vec![tensor];
822
823        let estimated = MemoryOptimizer::estimate_memory_usage(&tensors).expect("operation failed");
824        // 2 * 3 elements * 4 bytes per f32 element = 24 bytes
825        assert_eq!(estimated, 24);
826    }
827
828    #[test]
829    fn test_dynamic_batch_manager() {
830        let strategy = BatchingStrategy::Fixed(2);
831        let mut manager = DynamicBatchManager::new(strategy);
832
833        let tensor1 = Tensor::zeros(&[1, 2]).expect("operation failed");
834        let tensor2 = Tensor::zeros(&[1, 2]).expect("operation failed");
835
836        manager.add_tensor(tensor1, 1).expect("operation failed");
837        manager.add_tensor(tensor2, 2).expect("operation failed");
838
839        let batch = manager.get_next_batch().expect("operation failed");
840        assert!(batch.is_some());
841        assert_eq!(batch.expect("operation failed").len(), 2);
842    }
843
844    #[test]
845    fn test_performance_monitor() {
846        let mut monitor = PerformanceMonitor::default();
847
848        monitor.record_inference(100.0, 4, 1024);
849        monitor.record_inference(200.0, 8, 2048);
850
851        let stats = monitor.get_statistics();
852        assert_eq!(stats.total_inferences, 2);
853        assert_eq!(stats.average_inference_time_ms, 150.0);
854        assert_eq!(stats.average_batch_size, 6.0);
855        assert_eq!(stats.peak_memory_usage_bytes, 2048);
856    }
857
858    #[test]
859    fn test_cache_statistics() {
860        let config = PerformanceConfig::default();
861        let processor = BatchProcessor::new(config);
862        let stats = processor.cache_stats();
863
864        assert_eq!(stats.current_size, 0);
865        assert_eq!(stats.max_size, 1000);
866        assert_eq!(stats.hit_rate, 0.0);
867    }
868
869    #[test]
870    fn test_advanced_performance_optimizer() {
871        let config = PerformanceConfig::default();
872        let mut optimizer = AdvancedPerformanceOptimizer::new(config);
873
874        // Record some sample workloads
875        for i in 1..=20 {
876            let metrics = WorkloadMetrics {
877                batch_size: if i < 10 { 2 } else { 16 }, // Small then large batches
878                sequence_length: 512,
879                memory_usage: 1024 * i,
880                inference_time_ms: 100.0 + (i as f64 * 5.0),
881                timestamp: std::time::Instant::now(),
882            };
883            optimizer.record_workload(metrics);
884        }
885
886        let analysis = optimizer.get_workload_analysis();
887        assert_eq!(analysis.total_samples, 20);
888        assert!(analysis.average_batch_size > 2.0); // Should be higher due to mix
889
890        let recommendations = optimizer.get_recommendations();
891        assert!(!recommendations.is_empty()); // Should have some recommendations
892    }
893
894    #[test]
895    fn test_lru_cache() {
896        let mut cache = LruCache::new(2);
897
898        let tensor1 = Tensor::zeros(&[1, 2]).expect("operation failed");
899        let tensor2 = Tensor::zeros(&[1, 3]).expect("operation failed");
900        let tensor3 = Tensor::zeros(&[1, 4]).expect("operation failed");
901
902        // Add tensors
903        cache.put("key1".to_string(), tensor1);
904        cache.put("key2".to_string(), tensor2);
905
906        // Access key1 to make it recently used
907        let _ = cache.get("key1");
908
909        // Add key3 - should evict key2 (least recently used)
910        cache.put("key3".to_string(), tensor3);
911
912        // key1 and key3 should be present, key2 should be evicted
913        assert!(cache.get("key1").is_some());
914        assert!(cache.get("key3").is_some());
915        assert!(cache.get("key2").is_none());
916
917        // Check statistics
918        let stats = cache.statistics();
919        assert_eq!(stats.current_size, 2);
920        assert_eq!(stats.max_size, 2);
921        assert!(stats.hit_rate > 0.0);
922    }
923
924    #[test]
925    fn test_adaptive_batching() {
926        let strategy = BatchingStrategy::Adaptive {
927            initial_batch_size: 4,
928            max_batch_size: 16,
929            target_latency_ms: 100.0,
930            adjustment_factor: 0.2,
931        };
932        let mut manager = DynamicBatchManager::new(strategy);
933
934        // Record high latency - should reduce batch size
935        for _ in 0..10 {
936            manager.record_latency(150.0); // Higher than target
937        }
938
939        assert!(manager.current_batch_size() < 4); // Should have reduced
940
941        // Record low latency - should increase batch size
942        for _ in 0..10 {
943            manager.record_latency(50.0); // Lower than target
944        }
945
946        // Note: size might not increase immediately due to adaptation logic
947        let stats = manager.get_batch_statistics();
948        assert_eq!(stats.strategy_type, "Adaptive");
949        assert!(stats.average_latency_ms > 0.0);
950    }
951
952    #[test]
953    fn test_priority_batching() {
954        let strategy = BatchingStrategy::PriorityBased {
955            high_priority_batch_size: 2,
956            normal_priority_batch_size: 4,
957            low_priority_batch_size: 8,
958        };
959        let mut manager = DynamicBatchManager::new(strategy);
960
961        // Add tensors with different priorities
962        let tensor = Tensor::zeros(&[1, 2]).expect("operation failed");
963        manager.add_tensor(tensor.clone(), 90).expect("operation failed"); // High priority
964        manager.add_tensor(tensor.clone(), 50).expect("operation failed"); // Normal priority
965        manager.add_tensor(tensor.clone(), 90).expect("operation failed"); // High priority
966        manager.add_tensor(tensor.clone(), 20).expect("operation failed"); // Low priority
967
968        // Should get high priority batch first
969        let batch = manager.get_next_batch().expect("operation failed");
970        assert!(batch.is_some());
971        assert_eq!(batch.expect("operation failed").len(), 2); // High priority batch size
972
973        let stats = manager.get_batch_statistics();
974        assert_eq!(stats.strategy_type, "PriorityBased");
975    }
976}
977
// Advanced GPU Memory Management
//
// GPU memory pooling, memory-aware tensor caching and optimization
// recommendations for high-performance inference and training workloads.

/// Size-bucketed pool of GPU memory chunks that keeps released chunks for reuse
/// instead of discarding them.
///
/// NOTE(review): as written, the pool only does bookkeeping (sizes, ids, timestamps and
/// statistics). The code shown here performs no device allocation — confirm whether a
/// backend is meant to be wired in.
#[derive(Debug)]
pub struct GpuMemoryPool {
    /// Released chunks available for reuse, keyed by chunk size in bytes
    pools: HashMap<usize, VecDeque<GpuMemoryChunk>>,
    /// Bytes held by the pool: chunks handed out plus chunks waiting in `pools`.
    /// It is lowered only when pooled chunks are discarded (idle cleanup or defragmentation).
    total_allocated: usize,
    /// Maximum memory limit (in bytes) that `total_allocated` may reach when a new chunk is created
    max_memory_limit: usize,
    /// Ratio of pooled bytes to `total_allocated` above which defragmentation
    /// discards every pooled chunk
    fragmentation_threshold: f32,
    /// Memory allocation statistics
    stats: GpuMemoryStats,
}
996
/// Handle describing one chunk handed out by (or pooled inside) a [`GpuMemoryPool`].
#[derive(Debug, Clone)]
pub struct GpuMemoryChunk {
    /// Unique identifier for this chunk (a v4 UUID string assigned at creation)
    pub id: String,
    /// Size in bytes; this is also the key of the pool bucket the chunk returns to
    pub size_bytes: usize,
    /// `true` while handed out to a caller, `false` while sitting in a pool
    pub in_use: bool,
    /// Time the chunk was first created by the pool
    pub allocated_at: std::time::Instant,
    /// Last access timestamp; the pool's idle cleanup compares it against its timeout
    pub last_accessed: std::time::Instant,
    /// Set to 1 when the chunk is handed out and 0 when it is returned
    pub ref_count: usize,
}
1012
/// Counters maintained by [`GpuMemoryPool`] and returned (cloned) by `get_statistics`.
#[derive(Debug, Default, Clone)]
pub struct GpuMemoryStats {
    /// Fresh chunks created; reuses of pooled chunks are counted in `cache_hits` instead
    pub total_allocations: usize,
    /// Chunks returned through `deallocate`
    pub total_deallocations: usize,
    /// Chunks currently handed out (saturates at zero on deallocation)
    pub active_allocations: usize,
    /// Highest value `current_memory_usage` has reached (bytes)
    pub peak_memory_usage: usize,
    /// Current memory usage (bytes); decreased, saturating at zero, on deallocation
    pub current_memory_usage: usize,
    /// Pooled bytes divided by all bytes held by the pool (0.0 - 1.0).
    /// Only refreshed when defragmentation/cleanup runs.
    pub fragmentation_ratio: f32,
    /// Running mean size (bytes) of fresh allocations
    pub average_allocation_size: f32,
    /// Allocations satisfied by reusing a pooled chunk
    pub cache_hits: usize,
    /// Allocations that had to create a new chunk
    pub cache_misses: usize,
}
1034
1035impl GpuMemoryPool {
1036    /// Create a new GPU memory pool with specified limit
1037    pub fn new(max_memory_limit: usize) -> Self {
1038        Self {
1039            pools: HashMap::new(),
1040            total_allocated: 0,
1041            max_memory_limit,
1042            fragmentation_threshold: 0.25, // 25% fragmentation threshold
1043            stats: GpuMemoryStats::default(),
1044        }
1045    }
1046
1047    /// Allocate memory from the pool
1048    pub fn allocate(&mut self, size_bytes: usize) -> Result<GpuMemoryChunk> {
1049        // Check if we have available memory
1050        if self.total_allocated + size_bytes > self.max_memory_limit {
1051            self.try_defragment()?;
1052            if self.total_allocated + size_bytes > self.max_memory_limit {
1053                return Err(TrustformersError::invalid_operation(
1054                    "GPU memory limit exceeded".to_string(),
1055                ));
1056            }
1057        }
1058
1059        // Try to find existing chunk from pool
1060        if let Some(chunk) = self.find_suitable_chunk(size_bytes) {
1061            self.stats.cache_hits += 1;
1062            self.stats.active_allocations += 1;
1063            return Ok(chunk);
1064        }
1065
1066        // Allocate new chunk
1067        let chunk = GpuMemoryChunk {
1068            id: uuid::Uuid::new_v4().to_string(),
1069            size_bytes,
1070            in_use: true,
1071            allocated_at: std::time::Instant::now(),
1072            last_accessed: std::time::Instant::now(),
1073            ref_count: 1,
1074        };
1075
1076        self.total_allocated += size_bytes;
1077        self.stats.total_allocations += 1;
1078        self.stats.active_allocations += 1;
1079        self.stats.cache_misses += 1;
1080        self.stats.current_memory_usage += size_bytes;
1081
1082        if self.stats.current_memory_usage > self.stats.peak_memory_usage {
1083            self.stats.peak_memory_usage = self.stats.current_memory_usage;
1084        }
1085
1086        // Update average allocation size
1087        self.stats.average_allocation_size = (self.stats.average_allocation_size
1088            * (self.stats.total_allocations - 1) as f32
1089            + size_bytes as f32)
1090            / self.stats.total_allocations as f32;
1091
1092        Ok(chunk)
1093    }
1094
1095    /// Deallocate memory back to the pool
1096    pub fn deallocate(&mut self, mut chunk: GpuMemoryChunk) -> Result<()> {
1097        chunk.in_use = false;
1098        chunk.ref_count = 0;
1099
1100        // Add back to appropriate pool
1101        let pool = self.pools.entry(chunk.size_bytes).or_default();
1102        pool.push_back(chunk.clone());
1103
1104        self.stats.total_deallocations += 1;
1105        self.stats.active_allocations = self.stats.active_allocations.saturating_sub(1);
1106        self.stats.current_memory_usage =
1107            self.stats.current_memory_usage.saturating_sub(chunk.size_bytes);
1108
1109        // Check if we need to free some pooled memory
1110        self.cleanup_unused_chunks()?;
1111
1112        Ok(())
1113    }
1114
1115    /// Find a suitable chunk from existing pools
1116    fn find_suitable_chunk(&mut self, size_bytes: usize) -> Option<GpuMemoryChunk> {
1117        // Look for exact size match first
1118        if let Some(pool) = self.pools.get_mut(&size_bytes) {
1119            if let Some(mut chunk) = pool.pop_front() {
1120                chunk.in_use = true;
1121                chunk.last_accessed = std::time::Instant::now();
1122                chunk.ref_count = 1;
1123                return Some(chunk);
1124            }
1125        }
1126
1127        // Look for larger chunks that can be split
1128        let suitable_sizes: Vec<usize> = self.pools.keys()
1129            .filter(|&&size| size > size_bytes && size <= size_bytes * 2) // Avoid too much waste
1130            .copied()
1131            .collect();
1132
1133        for pool_size in suitable_sizes {
1134            if let Some(pool) = self.pools.get_mut(&pool_size) {
1135                if let Some(mut chunk) = pool.pop_front() {
1136                    chunk.in_use = true;
1137                    chunk.last_accessed = std::time::Instant::now();
1138                    chunk.ref_count = 1;
1139                    return Some(chunk);
1140                }
1141            }
1142        }
1143
1144        None
1145    }
1146
1147    /// Cleanup unused chunks to free memory
1148    fn cleanup_unused_chunks(&mut self) -> Result<()> {
1149        let now = std::time::Instant::now();
1150        let cleanup_threshold = std::time::Duration::from_secs(300); // 5 minutes
1151
1152        for pool in self.pools.values_mut() {
1153            pool.retain(|chunk| {
1154                let should_keep =
1155                    chunk.in_use || now.duration_since(chunk.last_accessed) < cleanup_threshold;
1156                if !should_keep {
1157                    self.total_allocated = self.total_allocated.saturating_sub(chunk.size_bytes);
1158                }
1159                should_keep
1160            });
1161        }
1162
1163        Ok(())
1164    }
1165
1166    /// Attempt to defragment memory
1167    fn try_defragment(&mut self) -> Result<()> {
1168        // Calculate current fragmentation ratio
1169        let total_pooled = self
1170            .pools
1171            .values()
1172            .map(|pool| pool.iter().map(|chunk| chunk.size_bytes).sum::<usize>())
1173            .sum::<usize>();
1174
1175        self.stats.fragmentation_ratio = if self.total_allocated > 0 {
1176            total_pooled as f32 / self.total_allocated as f32
1177        } else {
1178            0.0
1179        };
1180
1181        // If fragmentation is above threshold, force cleanup
1182        if self.stats.fragmentation_ratio > self.fragmentation_threshold {
1183            self.force_cleanup()?;
1184        }
1185
1186        Ok(())
1187    }
1188
1189    /// Force cleanup of all unused memory
1190    fn force_cleanup(&mut self) -> Result<()> {
1191        for pool in self.pools.values_mut() {
1192            let initial_size: usize = pool.iter().map(|chunk| chunk.size_bytes).sum();
1193            pool.retain(|chunk| chunk.in_use);
1194            let final_size: usize = pool.iter().map(|chunk| chunk.size_bytes).sum();
1195            self.total_allocated = self.total_allocated.saturating_sub(initial_size - final_size);
1196        }
1197
1198        // Recalculate fragmentation
1199        self.try_defragment()?;
1200
1201        Ok(())
1202    }
1203
1204    /// Get memory pool statistics
1205    pub fn get_statistics(&self) -> GpuMemoryStats {
1206        self.stats.clone()
1207    }
1208
1209    /// Get current memory usage as percentage of limit
1210    pub fn get_memory_usage_percentage(&self) -> f32 {
1211        (self.total_allocated as f32 / self.max_memory_limit as f32) * 100.0
1212    }
1213
1214    /// Get cache efficiency (hit rate)
1215    pub fn get_cache_efficiency(&self) -> f32 {
1216        let total_requests = self.stats.cache_hits + self.stats.cache_misses;
1217        if total_requests > 0 {
1218            self.stats.cache_hits as f32 / total_requests as f32
1219        } else {
1220            0.0
1221        }
1222    }
1223}
1224
/// Advanced GPU tensor caching with memory-aware eviction.
///
/// Holds up to `max_cache_size` tensors, each paired with a chunk obtained from an internal
/// [`GpuMemoryPool`]. When the cache is full, the entry with the lowest score is evicted and
/// its chunk returned to the pool. The score is importance x access frequency, discounted by age.
#[derive(Debug)]
pub struct GpuTensorCache {
    /// Memory pool that backs every cached tensor's chunk
    memory_pool: GpuMemoryPool,
    /// Cached tensors with metadata, keyed by caller-supplied name
    tensor_cache: HashMap<String, CachedTensor>,
    /// Keys ordered from least to most recently accessed. It is maintained on insertion,
    /// access and eviction, but the eviction choice itself is score-based, not LRU.
    lru_order: VecDeque<String>,
    /// Maximum cache size (number of tensors)
    max_cache_size: usize,
    /// Cache statistics.
    // NOTE(review): only `current_size` is updated by this type's methods; `hit_rate`
    // stays at its initial 0.0.
    stats: CacheStatistics,
}
1239
/// A tensor held by [`GpuTensorCache`] together with the bookkeeping used to score it for eviction.
#[derive(Debug, Clone)]
pub struct CachedTensor {
    /// The cached tensor data
    pub tensor: Tensor,
    /// Pool chunk reserved for this tensor; returned to the pool on eviction or `clear`
    pub memory_chunk: GpuMemoryChunk,
    /// Access count score: starts at 1.0 and grows by 1.0 on each successful lookup
    pub access_frequency: f32,
    /// Caller-supplied importance (0.5 when omitted); higher values make eviction less likely
    pub importance_score: f32,
    /// Time of the most recent lookup, or of insertion if never looked up
    pub last_access: std::time::Instant,
    /// Insertion time; drives the age discount in the eviction score
    pub created_at: std::time::Instant,
}
1255
1256impl GpuTensorCache {
1257    /// Create a new GPU tensor cache
1258    pub fn new(max_cache_size: usize, max_memory_limit: usize) -> Self {
1259        Self {
1260            memory_pool: GpuMemoryPool::new(max_memory_limit),
1261            tensor_cache: HashMap::new(),
1262            lru_order: VecDeque::new(),
1263            max_cache_size,
1264            stats: CacheStatistics {
1265                current_size: 0,
1266                max_size: max_cache_size,
1267                hit_rate: 0.0,
1268            },
1269        }
1270    }
1271
1272    /// Cache a tensor with optional importance score
1273    pub fn cache_tensor(
1274        &mut self,
1275        key: String,
1276        tensor: Tensor,
1277        importance_score: Option<f32>,
1278    ) -> Result<()> {
1279        // Calculate tensor size (simplified estimation)
1280        let tensor_size = self.estimate_tensor_size(&tensor);
1281
1282        // Allocate memory chunk
1283        let memory_chunk = self.memory_pool.allocate(tensor_size)?;
1284
1285        // Create cached tensor
1286        let cached_tensor = CachedTensor {
1287            tensor,
1288            memory_chunk,
1289            access_frequency: 1.0,
1290            importance_score: importance_score.unwrap_or(0.5),
1291            last_access: std::time::Instant::now(),
1292            created_at: std::time::Instant::now(),
1293        };
1294
1295        // Check if we need to evict
1296        if self.tensor_cache.len() >= self.max_cache_size {
1297            self.evict_least_important()?;
1298        }
1299
1300        // Insert new tensor
1301        self.tensor_cache.insert(key.clone(), cached_tensor);
1302        self.lru_order.push_back(key);
1303        self.stats.current_size = self.tensor_cache.len();
1304
1305        Ok(())
1306    }
1307
1308    /// Retrieve a tensor from cache
1309    pub fn get_tensor(&mut self, key: &str) -> Option<&Tensor> {
1310        // Check if key exists first
1311        if !self.tensor_cache.contains_key(key) {
1312            return None;
1313        }
1314
1315        // Update LRU order first
1316        self.update_lru_order(key);
1317
1318        // Update access information and return tensor
1319        if let Some(cached_tensor) = self.tensor_cache.get_mut(key) {
1320            cached_tensor.access_frequency += 1.0;
1321            cached_tensor.last_access = std::time::Instant::now();
1322            Some(&cached_tensor.tensor)
1323        } else {
1324            None
1325        }
1326    }
1327
1328    /// Update LRU order for a key
1329    fn update_lru_order(&mut self, key: &str) {
1330        // Remove from current position and add to back
1331        if let Some(pos) = self.lru_order.iter().position(|k| k == key) {
1332            self.lru_order.remove(pos);
1333            self.lru_order.push_back(key.to_string());
1334        }
1335    }
1336
1337    /// Evict the least important tensor
1338    fn evict_least_important(&mut self) -> Result<()> {
1339        // Calculate eviction scores for all cached tensors
1340        let mut eviction_candidates: Vec<(String, f32)> = self
1341            .tensor_cache
1342            .iter()
1343            .map(|(key, cached_tensor)| {
1344                let age_factor = cached_tensor.created_at.elapsed().as_secs() as f32 / 3600.0; // Hours
1345                let frequency_factor = cached_tensor.access_frequency;
1346                let importance_factor = cached_tensor.importance_score;
1347
1348                // Lower score = higher priority for eviction
1349                let eviction_score = importance_factor * frequency_factor / (1.0 + age_factor);
1350                (key.clone(), eviction_score)
1351            })
1352            .collect();
1353
1354        // Sort by eviction score (lowest first)
1355        eviction_candidates
1356            .sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
1357
1358        // Evict the least important tensor
1359        if let Some((key_to_evict, _)) = eviction_candidates.first() {
1360            if let Some(cached_tensor) = self.tensor_cache.remove(key_to_evict) {
1361                self.memory_pool.deallocate(cached_tensor.memory_chunk)?;
1362
1363                // Remove from LRU order
1364                if let Some(pos) = self.lru_order.iter().position(|k| k == key_to_evict) {
1365                    self.lru_order.remove(pos);
1366                }
1367
1368                self.stats.current_size = self.tensor_cache.len();
1369            }
1370        }
1371
1372        Ok(())
1373    }
1374
1375    /// Estimate tensor size in bytes (simplified)
1376    fn estimate_tensor_size(&self, tensor: &Tensor) -> usize {
1377        match tensor {
1378            Tensor::F32(arr) => arr.len() * 4, // 4 bytes per f32
1379            Tensor::F64(arr) => arr.len() * 8, // 8 bytes per f64
1380            _ => 1024,                         // Default estimate for other types
1381        }
1382    }
1383
1384    /// Get comprehensive cache statistics
1385    pub fn get_comprehensive_stats(&self) -> GpuCacheStatistics {
1386        let memory_stats = self.memory_pool.get_statistics();
1387        let fragmentation_ratio = memory_stats.fragmentation_ratio;
1388
1389        GpuCacheStatistics {
1390            cache_stats: self.stats.clone(),
1391            memory_stats,
1392            memory_usage_percentage: self.memory_pool.get_memory_usage_percentage(),
1393            cache_efficiency: self.memory_pool.get_cache_efficiency(),
1394            average_tensor_age: self.calculate_average_tensor_age(),
1395            fragmentation_ratio,
1396        }
1397    }
1398
1399    /// Calculate average age of cached tensors
1400    fn calculate_average_tensor_age(&self) -> f32 {
1401        if self.tensor_cache.is_empty() {
1402            return 0.0;
1403        }
1404
1405        let total_age: f32 = self
1406            .tensor_cache
1407            .values()
1408            .map(|cached_tensor| cached_tensor.created_at.elapsed().as_secs() as f32)
1409            .sum();
1410
1411        total_age / self.tensor_cache.len() as f32
1412    }
1413
1414    /// Clear all cached tensors
1415    pub fn clear(&mut self) -> Result<()> {
1416        for (_, cached_tensor) in self.tensor_cache.drain() {
1417            self.memory_pool.deallocate(cached_tensor.memory_chunk)?;
1418        }
1419        self.lru_order.clear();
1420        self.stats.current_size = 0;
1421        Ok(())
1422    }
1423}
1424
/// Comprehensive GPU cache statistics, as assembled by `GpuTensorCache::get_comprehensive_stats`.
#[derive(Debug, Clone)]
pub struct GpuCacheStatistics {
    /// Entry-count statistics of the tensor cache
    pub cache_stats: CacheStatistics,
    /// Snapshot of the backing memory pool's statistics
    pub memory_stats: GpuMemoryStats,
    /// Bytes held by the memory pool as a percentage of its limit
    pub memory_usage_percentage: f32,
    /// Memory-pool chunk reuse rate: hits / (hits + misses), or 0.0 with no requests.
    /// This measures chunk reuse, not tensor lookup hits.
    pub cache_efficiency: f32,
    /// Mean age of cached tensors in seconds, measured from insertion
    pub average_tensor_age: f32,
    /// Same value as `memory_stats.fragmentation_ratio`
    pub fragmentation_ratio: f32,
}
1435
/// GPU memory optimization recommendations produced by `GpuMemoryOptimizer::analyze_and_recommend`.
#[derive(Debug, Clone)]
pub struct GpuOptimizationRecommendations {
    /// Human-readable recommended actions. Never empty: an "optimal" message is used when
    /// no threshold is crossed.
    pub recommendations: Vec<String>,
    /// Priority level: "High", "Medium" or "Low"
    pub priority: String,
    /// Estimated performance improvement percentage, capped at 50.0
    pub estimated_improvement: f32,
}
1446
/// GPU memory optimizer producing threshold-based recommendations.
///
/// Stateless: all functionality is exposed as associated functions.
pub struct GpuMemoryOptimizer;
1449
1450impl GpuMemoryOptimizer {
1451    /// Analyze GPU memory usage and provide optimization recommendations
1452    pub fn analyze_and_recommend(stats: &GpuCacheStatistics) -> GpuOptimizationRecommendations {
1453        let mut recommendations = Vec::new();
1454        let mut priority = "Low".to_string();
1455        let mut estimated_improvement: f32 = 0.0;
1456
1457        // Analyze memory usage
1458        if stats.memory_usage_percentage > 90.0 {
1459            recommendations.push("Critical: Memory usage is very high. Consider increasing memory limit or improving eviction strategy.".to_string());
1460            priority = "High".to_string();
1461            estimated_improvement += 25.0;
1462        } else if stats.memory_usage_percentage > 75.0 {
1463            recommendations.push(
1464                "Warning: Memory usage is high. Monitor for potential memory pressure.".to_string(),
1465            );
1466            priority = "Medium".to_string();
1467            estimated_improvement += 10.0;
1468        }
1469
1470        // Analyze fragmentation
1471        if stats.fragmentation_ratio > 0.4 {
1472            recommendations.push(
1473                "High memory fragmentation detected. Consider running defragmentation.".to_string(),
1474            );
1475            if priority == "Low" {
1476                priority = "Medium".to_string();
1477            }
1478            estimated_improvement += 15.0;
1479        }
1480
1481        // Analyze cache efficiency
1482        if stats.cache_efficiency < 0.7 {
1483            recommendations.push(
1484                "Low cache hit rate. Consider adjusting cache size or eviction policy.".to_string(),
1485            );
1486            if priority == "Low" {
1487                priority = "Medium".to_string();
1488            }
1489            estimated_improvement += 20.0;
1490        }
1491
1492        // Analyze tensor age
1493        if stats.average_tensor_age > 3600.0 {
1494            // 1 hour
1495            recommendations.push(
1496                "Cached tensors are aging. Consider more aggressive eviction for unused tensors."
1497                    .to_string(),
1498            );
1499            estimated_improvement += 5.0;
1500        }
1501
1502        // Provide specific optimization suggestions
1503        if stats.memory_stats.active_allocations > 1000 {
1504            recommendations.push(
1505                "High number of active allocations. Consider batching or pooling strategies."
1506                    .to_string(),
1507            );
1508            estimated_improvement += 12.0;
1509        }
1510
1511        if recommendations.is_empty() {
1512            recommendations
1513                .push("GPU memory usage is optimal. No immediate action required.".to_string());
1514        }
1515
1516        GpuOptimizationRecommendations {
1517            recommendations,
1518            priority,
1519            estimated_improvement: estimated_improvement.min(50.0), // Cap at 50%
1520        }
1521    }
1522
1523    /// Perform automatic GPU memory optimization
1524    pub fn auto_optimize(cache: &mut GpuTensorCache) -> Result<Vec<String>> {
1525        let stats = cache.get_comprehensive_stats();
1526        let recommendations = Self::analyze_and_recommend(&stats);
1527        let mut actions_taken = Vec::new();
1528
1529        // Auto-apply some optimizations based on priority
1530        if recommendations.priority == "High" {
1531            // Force cleanup if memory usage is critical
1532            if stats.memory_usage_percentage > 90.0 {
1533                cache.memory_pool.force_cleanup()?;
1534                actions_taken.push("Performed emergency memory cleanup".to_string());
1535            }
1536        }
1537
1538        if stats.fragmentation_ratio > 0.4 {
1539            cache.memory_pool.try_defragment()?;
1540            actions_taken.push("Performed memory defragmentation".to_string());
1541        }
1542
1543        if actions_taken.is_empty() {
1544            actions_taken.push("No automatic optimizations were necessary".to_string());
1545        }
1546
1547        Ok(actions_taken)
1548    }
1549}
1550
#[cfg(test)]
mod gpu_memory_tests {
    use super::*;

    /// Byte limit shared by the pool and cache tests below.
    const ONE_MIB: usize = 1024 * 1024;

    /// Allocating and then releasing a chunk is reflected in the active-allocation count.
    #[test]
    fn test_gpu_memory_pool_basic() {
        let mut pool = GpuMemoryPool::new(ONE_MIB);

        let chunk = pool.allocate(1024).expect("operation failed");
        assert_eq!(chunk.size_bytes, 1024);
        assert!(chunk.in_use);
        assert_eq!(pool.get_statistics().active_allocations, 1);

        pool.deallocate(chunk).expect("operation failed");
        assert_eq!(pool.get_statistics().active_allocations, 0);
    }

    /// A released chunk of the requested size is served from the pool on the next request.
    #[test]
    fn test_gpu_memory_pool_reuse() {
        let mut pool = GpuMemoryPool::new(ONE_MIB);

        let released = pool.allocate(1024).expect("operation failed");
        pool.deallocate(released).expect("operation failed");

        let hits_before = pool.get_statistics().cache_hits;
        let _reused = pool.allocate(1024).expect("operation failed");
        let hits_after = pool.get_statistics().cache_hits;

        assert_eq!(hits_after, hits_before + 1);
    }

    /// When the cache is full, the entry with the lowest importance-weighted score is evicted.
    #[test]
    fn test_gpu_tensor_cache() -> Result<()> {
        let mut cache = GpuTensorCache::new(2, ONE_MIB);

        let tensor1 = Tensor::zeros(&[10, 10])?;
        let tensor2 = Tensor::zeros(&[5, 5])?;
        let tensor3 = Tensor::zeros(&[20, 20])?;

        cache.cache_tensor("tensor1".to_string(), tensor1, Some(0.8))?;
        cache.cache_tensor("tensor2".to_string(), tensor2, Some(0.6))?;

        // Looking tensor1 up raises its access frequency and protects it further.
        assert!(cache.get_tensor("tensor1").is_some());

        // A third insert into the 2-slot cache forces one eviction.
        cache.cache_tensor("tensor3".to_string(), tensor3, Some(0.9))?;

        // tensor2 carries the lowest score, so it is the one that must be gone.
        for (key, expected_present) in [("tensor2", false), ("tensor1", true), ("tensor3", true)] {
            assert_eq!(cache.get_tensor(key).is_some(), expected_present);
        }

        Ok(())
    }

    /// Critical usage, heavy fragmentation, a poor hit rate and old tensors yield "High".
    #[test]
    fn test_gpu_optimization_recommendations() {
        let cache_stats = CacheStatistics {
            current_size: 100,
            max_size: 100,
            hit_rate: 0.5, // Low hit rate
        };
        let memory_stats = GpuMemoryStats {
            fragmentation_ratio: 0.5, // High fragmentation
            ..Default::default()
        };
        let stats = GpuCacheStatistics {
            cache_stats,
            memory_stats,
            memory_usage_percentage: 95.0, // Very high usage
            cache_efficiency: 0.5,
            average_tensor_age: 7200.0, // 2 hours
            fragmentation_ratio: 0.5,
        };

        let advice = GpuMemoryOptimizer::analyze_and_recommend(&stats);

        assert_eq!(advice.priority, "High");
        assert!(!advice.recommendations.is_empty());
        assert!(advice.estimated_improvement > 0.0);
    }
}