1use serde::{Deserialize, Serialize};
7use std::collections::{HashMap, VecDeque};
8use trustformers_core::errors::{Result, TrustformersError};
9use trustformers_core::Tensor;
10
/// Configuration knobs for inference performance tuning.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PerformanceConfig {
    /// Maximum number of tensors collected before a batch is flushed.
    pub max_batch_size: usize,
    /// Enables dynamic (adaptive) batching strategies.
    pub enable_dynamic_batching: bool,
    /// Maximum number of entries held by the tensor LRU cache.
    pub cache_size: usize,
    /// Enables memory-layout optimization passes.
    pub enable_memory_optimization: bool,
    /// Worker thread count; `None` lets the runtime decide.
    pub num_threads: Option<usize>,
}
25
26impl Default for PerformanceConfig {
27 fn default() -> Self {
28 Self {
29 max_batch_size: 32,
30 enable_dynamic_batching: true,
31 cache_size: 1000,
32 enable_memory_optimization: true,
33 num_threads: None, }
35 }
36}
37
/// Least-recently-used tensor cache with hit/miss accounting.
#[derive(Debug)]
pub struct LruCache {
    // Maximum number of entries before eviction kicks in.
    capacity: usize,
    // key -> (tensor, logical access timestamp); the entry with the
    // smallest timestamp is the LRU victim.
    cache: HashMap<String, (Tensor, usize)>,
    // Monotonically increasing logical clock, bumped on each access/insert.
    access_order: usize,
    // Recently inserted keys (bounded in `put`); kept for diagnostics.
    access_history: VecDeque<String>,
    // Lookup counters feeding `hit_rate`.
    hits: usize,
    misses: usize,
}
48
49impl LruCache {
50 pub fn new(capacity: usize) -> Self {
51 Self {
52 capacity,
53 cache: HashMap::new(),
54 access_order: 0,
55 access_history: VecDeque::new(),
56 hits: 0,
57 misses: 0,
58 }
59 }
60
61 pub fn get(&mut self, key: &str) -> Option<&Tensor> {
62 if let Some((tensor, _)) = self.cache.get(key).cloned() {
63 self.access_order += 1;
64 self.cache.insert(key.to_string(), (tensor, self.access_order));
65 self.hits += 1;
66 self.cache.get(key).map(|(tensor, _)| tensor)
67 } else {
68 self.misses += 1;
69 None
70 }
71 }
72
73 pub fn put(&mut self, key: String, tensor: Tensor) {
74 if self.cache.len() >= self.capacity && !self.cache.contains_key(&key) {
75 self.evict_lru();
76 }
77
78 self.access_order += 1;
79 self.cache.insert(key.clone(), (tensor, self.access_order));
80 self.access_history.push_back(key);
81
82 if self.access_history.len() > self.capacity * 2 {
84 self.access_history.pop_front();
85 }
86 }
87
88 fn evict_lru(&mut self) {
89 if let Some(lru_key) = self.find_lru_key() {
90 self.cache.remove(&lru_key);
91 }
92 }
93
94 fn find_lru_key(&self) -> Option<String> {
95 self.cache
96 .iter()
97 .min_by_key(|(_, (_, access_order))| *access_order)
98 .map(|(key, _)| key.clone())
99 }
100
101 pub fn clear(&mut self) {
102 self.cache.clear();
103 self.access_history.clear();
104 self.access_order = 0;
105 self.hits = 0;
106 self.misses = 0;
107 }
108
109 pub fn len(&self) -> usize {
110 self.cache.len()
111 }
112
113 pub fn hit_rate(&self) -> f64 {
114 let total = self.hits + self.misses;
115 if total > 0 {
116 self.hits as f64 / total as f64
117 } else {
118 0.0
119 }
120 }
121
122 pub fn statistics(&self) -> CacheStatistics {
123 CacheStatistics {
124 current_size: self.cache.len(),
125 max_size: self.capacity,
126 hit_rate: self.hit_rate(),
127 }
128 }
129}
130
/// Buffers tensors into fixed-size batches and fronts an LRU tensor cache.
#[derive(Debug)]
pub struct BatchProcessor {
    // Tuning knobs; `max_batch_size` drives flushing, `cache_size` sizes
    // the cache.
    config: PerformanceConfig,
    // Tensor cache sized from `config.cache_size`.
    cache: LruCache,
    // Tensors accumulated for the current batch.
    batch_buffer: Vec<Tensor>,
}
138
139impl BatchProcessor {
140 pub fn new(config: PerformanceConfig) -> Self {
142 Self {
143 cache: LruCache::new(config.cache_size),
144 config,
145 batch_buffer: Vec::new(),
146 }
147 }
148
149 pub fn add_to_batch(&mut self, tensor: Tensor) -> Result<Option<Vec<Tensor>>> {
151 self.batch_buffer.push(tensor);
152
153 if self.batch_buffer.len() >= self.config.max_batch_size {
154 Ok(Some(self.flush_batch()?))
155 } else {
156 Ok(None)
157 }
158 }
159
160 pub fn flush_batch(&mut self) -> Result<Vec<Tensor>> {
162 let batch = std::mem::take(&mut self.batch_buffer);
163 Ok(batch)
164 }
165
166 pub fn cache_tensor(&mut self, key: String, tensor: Tensor) -> Result<()> {
168 self.cache.put(key, tensor);
169 Ok(())
170 }
171
172 pub fn cache_stats(&self) -> CacheStatistics {
174 self.cache.statistics()
175 }
176
177 pub fn get_cached_tensor(&mut self, key: &str) -> Option<&Tensor> {
179 self.cache.get(key)
180 }
181
182 pub fn clear_cache(&mut self) {
184 self.cache.clear();
185 }
186
187 pub fn current_batch_size(&self) -> usize {
189 self.batch_buffer.len()
190 }
191}
192
/// Stateless helpers for tensor memory-layout analysis and optimization.
pub struct MemoryOptimizer;
195
196impl MemoryOptimizer {
197 pub fn optimize_memory_layout(tensors: &mut [Tensor]) -> Result<()> {
199 tensors.sort_by(|a, b| {
201 let size_a = a.shape().iter().product::<usize>();
202 let size_b = b.shape().iter().product::<usize>();
203 size_b.cmp(&size_a) });
205
206 for tensor in tensors.iter_mut() {
208 Self::optimize_single_tensor_layout(tensor)?;
209 }
210
211 Ok(())
212 }
213
214 fn optimize_single_tensor_layout(tensor: &mut Tensor) -> Result<()> {
216 match tensor {
217 Tensor::F32(ref mut data)
218 if data.ndim() > 2
221 && !data.is_standard_layout() => {
223 let owned = data.to_owned();
224 *data = owned;
225 },
226 Tensor::I64(ref mut data)
227 if data.ndim() > 2 && !data.is_standard_layout() => {
229 let owned = data.to_owned();
230 *data = owned;
231 },
232 _ => {
233 },
235 }
236 Ok(())
237 }
238
239 pub fn analyze_memory_patterns(tensors: &[Tensor]) -> Vec<String> {
241 let mut recommendations = Vec::new();
242
243 let total_elements: usize =
245 tensors.iter().map(|t| t.shape().iter().product::<usize>()).sum();
246
247 if total_elements > 1_000_000 {
248 recommendations
249 .push("Consider using memory pooling for large tensor operations".to_string());
250 }
251
252 let small_tensors =
254 tensors.iter().filter(|t| t.shape().iter().product::<usize>() < 1000).count();
255
256 if small_tensors > 10 {
257 recommendations
258 .push("Consider tensor batching to reduce small tensor overhead".to_string());
259 }
260
261 for (i, tensor) in tensors.iter().enumerate() {
263 let shape = tensor.shape();
264 if shape.len() >= 2 {
265 let last_dim = shape[shape.len() - 1];
266 if last_dim % 4 != 0 {
267 recommendations.push(format!(
268 "Tensor {} last dimension ({}) not aligned for SIMD operations",
269 i, last_dim
270 ));
271 }
272 }
273 }
274
275 recommendations
276 }
277
278 pub fn estimate_memory_usage(tensors: &[Tensor]) -> Result<usize> {
280 let mut total_bytes = 0;
281
282 for tensor in tensors {
283 let shape = tensor.shape();
284 let elements = shape.iter().product::<usize>();
285 total_bytes += elements * 4;
287 }
288
289 Ok(total_bytes)
290 }
291
292 pub fn check_memory_constraints(tensors: &[Tensor], max_memory_mb: usize) -> Result<bool> {
294 let estimated_bytes = Self::estimate_memory_usage(tensors)?;
295 let max_bytes = max_memory_mb * 1024 * 1024;
296 Ok(estimated_bytes <= max_bytes)
297 }
298}
299
/// How `DynamicBatchManager` decides when a batch is ready.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum BatchingStrategy {
    /// Always emit batches of exactly this many tensors.
    Fixed(usize),
    /// Batch up to `max_batch_size` tensors.
    /// NOTE(review): `max_length` is currently ignored by `get_next_batch`
    /// — confirm whether sequence-length bucketing is still planned.
    DynamicByLength {
        max_length: usize,
        max_batch_size: usize,
    },
    /// Fill a batch until the estimated memory budget (in MB) is reached.
    DynamicByMemory { max_memory_mb: usize },
    /// Grow/shrink the batch size to track a target latency.
    Adaptive {
        initial_batch_size: usize,
        max_batch_size: usize,
        target_latency_ms: f64,
        adjustment_factor: f64,
    },
    /// Distinct batch sizes per priority class (high >= 80, normal 40..80,
    /// low < 40 — thresholds applied in `get_next_batch`).
    PriorityBased {
        high_priority_batch_size: usize,
        normal_priority_batch_size: usize,
        low_priority_batch_size: usize,
    },
}
326
/// Strategy-driven batch assembler with latency feedback.
#[derive(Debug)]
pub struct DynamicBatchManager {
    // Active batching policy.
    strategy: BatchingStrategy,
    // Queued (tensor, priority) pairs, kept sorted by descending priority
    // (see `add_tensor`).
    pending_tensors: Vec<(Tensor, usize)>,
    // Current target batch size (adapted when strategy is `Adaptive`).
    current_batch_size: usize,
    // Rolling window (max 20 samples) of recent batch latencies in ms.
    recent_latencies: VecDeque<f64>,
    // Lifetime count of latency samples recorded.
    total_batches_processed: usize,
}
336
337impl DynamicBatchManager {
338 pub fn new(strategy: BatchingStrategy) -> Self {
340 let initial_batch_size = match &strategy {
341 BatchingStrategy::Fixed(size) => *size,
342 BatchingStrategy::DynamicByLength { max_batch_size, .. } => *max_batch_size / 2,
343 BatchingStrategy::DynamicByMemory { .. } => 16,
344 BatchingStrategy::Adaptive {
345 initial_batch_size, ..
346 } => *initial_batch_size,
347 BatchingStrategy::PriorityBased {
348 normal_priority_batch_size,
349 ..
350 } => *normal_priority_batch_size,
351 };
352
353 Self {
354 strategy,
355 pending_tensors: Vec::new(),
356 current_batch_size: initial_batch_size,
357 recent_latencies: VecDeque::new(),
358 total_batches_processed: 0,
359 }
360 }
361
362 pub fn record_latency(&mut self, latency_ms: f64) {
364 self.recent_latencies.push_back(latency_ms);
365
366 if self.recent_latencies.len() > 20 {
368 self.recent_latencies.pop_front();
369 }
370
371 self.total_batches_processed += 1;
372
373 if let BatchingStrategy::Adaptive {
375 target_latency_ms,
376 max_batch_size,
377 adjustment_factor,
378 ..
379 } = &self.strategy
380 {
381 if self.recent_latencies.len() >= 5 {
382 let avg_latency: f64 =
383 self.recent_latencies.iter().sum::<f64>() / self.recent_latencies.len() as f64;
384
385 if avg_latency > *target_latency_ms {
386 self.current_batch_size = std::cmp::max(
388 1,
389 (self.current_batch_size as f64 * (1.0 - adjustment_factor)) as usize,
390 );
391 } else if avg_latency < *target_latency_ms * 0.8 {
392 self.current_batch_size = std::cmp::min(
394 *max_batch_size,
395 (self.current_batch_size as f64 * (1.0 + adjustment_factor)) as usize,
396 );
397 }
398 }
399 }
400 }
401
402 pub fn add_tensor(&mut self, tensor: Tensor, priority: usize) -> Result<()> {
404 self.pending_tensors.push((tensor, priority));
405 self.pending_tensors.sort_by_key(|item| std::cmp::Reverse(item.1));
407 Ok(())
408 }
409
410 pub fn get_next_batch(&mut self) -> Result<Option<Vec<Tensor>>> {
412 if self.pending_tensors.is_empty() {
413 return Ok(None);
414 }
415
416 match &self.strategy {
417 BatchingStrategy::Fixed(batch_size) => {
418 if self.pending_tensors.len() >= *batch_size {
419 let batch: Vec<Tensor> = self
420 .pending_tensors
421 .drain(0..*batch_size)
422 .map(|(tensor, _)| tensor)
423 .collect();
424 Ok(Some(batch))
425 } else {
426 Ok(None)
427 }
428 },
429 BatchingStrategy::DynamicByLength {
430 max_length: _,
431 max_batch_size,
432 } => {
433 let batch_size = std::cmp::min(self.pending_tensors.len(), *max_batch_size);
434 if batch_size > 0 {
435 let batch: Vec<Tensor> = self
436 .pending_tensors
437 .drain(0..batch_size)
438 .map(|(tensor, _)| tensor)
439 .collect();
440 Ok(Some(batch))
441 } else {
442 Ok(None)
443 }
444 },
445 BatchingStrategy::DynamicByMemory { max_memory_mb } => {
446 let mut batch = Vec::new();
447 let mut current_memory = 0;
448
449 while !self.pending_tensors.is_empty() {
450 let tensor_memory = self.estimate_tensor_memory(&self.pending_tensors[0].0)?;
451 if current_memory + tensor_memory <= *max_memory_mb * 1024 * 1024 {
452 let (tensor, _) = self.pending_tensors.remove(0);
453 batch.push(tensor);
454 current_memory += tensor_memory;
455 } else {
456 break;
457 }
458 }
459
460 if batch.is_empty() {
461 Ok(None)
462 } else {
463 Ok(Some(batch))
464 }
465 },
466 BatchingStrategy::Adaptive { .. } => {
467 if self.pending_tensors.len() >= self.current_batch_size {
468 let batch: Vec<Tensor> = self
469 .pending_tensors
470 .drain(0..self.current_batch_size)
471 .map(|(tensor, _)| tensor)
472 .collect();
473 Ok(Some(batch))
474 } else {
475 Ok(None)
476 }
477 },
478 BatchingStrategy::PriorityBased {
479 high_priority_batch_size,
480 normal_priority_batch_size,
481 low_priority_batch_size,
482 } => {
483 let high_priority: Vec<_> = self
485 .pending_tensors
486 .iter()
487 .filter(|(_, priority)| *priority >= 80)
488 .cloned()
489 .collect();
490 let normal_priority: Vec<_> = self
491 .pending_tensors
492 .iter()
493 .filter(|(_, priority)| *priority >= 40 && *priority < 80)
494 .cloned()
495 .collect();
496 let low_priority: Vec<_> = self
497 .pending_tensors
498 .iter()
499 .filter(|(_, priority)| *priority < 40)
500 .cloned()
501 .collect();
502
503 if high_priority.len() >= *high_priority_batch_size {
504 let batch: Vec<Tensor> = high_priority
505 .into_iter()
506 .take(*high_priority_batch_size)
507 .map(|(tensor, _)| tensor)
508 .collect();
509 self.pending_tensors.retain(|(_, priority)| *priority < 80);
511 Ok(Some(batch))
512 } else if normal_priority.len() >= *normal_priority_batch_size {
513 let batch: Vec<Tensor> = normal_priority
514 .into_iter()
515 .take(*normal_priority_batch_size)
516 .map(|(tensor, _)| tensor)
517 .collect();
518 self.pending_tensors.retain(|(_, priority)| *priority < 40 || *priority >= 80);
520 Ok(Some(batch))
521 } else if low_priority.len() >= *low_priority_batch_size {
522 let batch: Vec<Tensor> = low_priority
523 .into_iter()
524 .take(*low_priority_batch_size)
525 .map(|(tensor, _)| tensor)
526 .collect();
527 self.pending_tensors.retain(|(_, priority)| *priority >= 40);
529 Ok(Some(batch))
530 } else {
531 Ok(None)
532 }
533 },
534 }
535 }
536
537 fn estimate_tensor_memory(&self, tensor: &Tensor) -> Result<usize> {
539 let shape = tensor.shape();
540 let elements = shape.iter().product::<usize>();
541 Ok(elements * 4)
543 }
544
545 pub fn pending_count(&self) -> usize {
547 self.pending_tensors.len()
548 }
549
550 pub fn current_batch_size(&self) -> usize {
552 self.current_batch_size
553 }
554
555 pub fn average_latency(&self) -> f64 {
557 if self.recent_latencies.is_empty() {
558 0.0
559 } else {
560 self.recent_latencies.iter().sum::<f64>() / self.recent_latencies.len() as f64
561 }
562 }
563
564 pub fn get_batch_statistics(&self) -> BatchStatistics {
566 BatchStatistics {
567 total_batches_processed: self.total_batches_processed,
568 current_batch_size: self.current_batch_size,
569 pending_tensors: self.pending_tensors.len(),
570 average_latency_ms: self.average_latency(),
571 strategy_type: match &self.strategy {
572 BatchingStrategy::Fixed(_) => "Fixed".to_string(),
573 BatchingStrategy::DynamicByLength { .. } => "DynamicByLength".to_string(),
574 BatchingStrategy::DynamicByMemory { .. } => "DynamicByMemory".to_string(),
575 BatchingStrategy::Adaptive { .. } => "Adaptive".to_string(),
576 BatchingStrategy::PriorityBased { .. } => "PriorityBased".to_string(),
577 },
578 }
579 }
580}
581
/// Accumulates per-inference timing, batch-size, and memory samples.
#[derive(Debug, Default)]
pub struct PerformanceMonitor {
    // Sum of all recorded inference times (ms).
    total_inference_time: f64,
    // Number of inferences recorded.
    total_inferences: usize,
    // One entry per recorded inference.
    batch_sizes: Vec<usize>,
    // One entry per recorded inference (bytes).
    memory_usage: Vec<usize>,
}
590
591impl PerformanceMonitor {
592 pub fn record_inference(&mut self, time_ms: f64, batch_size: usize, memory_usage: usize) {
594 self.total_inference_time += time_ms;
595 self.total_inferences += 1;
596 self.batch_sizes.push(batch_size);
597 self.memory_usage.push(memory_usage);
598 }
599
600 pub fn average_inference_time(&self) -> f64 {
602 if self.total_inferences > 0 {
603 self.total_inference_time / self.total_inferences as f64
604 } else {
605 0.0
606 }
607 }
608
609 pub fn average_batch_size(&self) -> f64 {
611 if self.batch_sizes.is_empty() {
612 0.0
613 } else {
614 self.batch_sizes.iter().sum::<usize>() as f64 / self.batch_sizes.len() as f64
615 }
616 }
617
618 pub fn peak_memory_usage(&self) -> usize {
620 self.memory_usage.iter().max().copied().unwrap_or(0)
621 }
622
623 pub fn get_statistics(&self) -> PerformanceStatistics {
625 PerformanceStatistics {
626 total_inferences: self.total_inferences,
627 average_inference_time_ms: self.average_inference_time(),
628 average_batch_size: self.average_batch_size(),
629 peak_memory_usage_bytes: self.peak_memory_usage(),
630 throughput_inferences_per_second: if self.total_inference_time > 0.0 {
631 (self.total_inferences as f64) / (self.total_inference_time / 1000.0)
632 } else {
633 0.0
634 },
635 }
636 }
637}
638
/// Point-in-time cache metrics.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CacheStatistics {
    /// Entries currently stored.
    pub current_size: usize,
    /// Configured capacity.
    pub max_size: usize,
    /// hits / (hits + misses); 0.0 before any lookup.
    pub hit_rate: f64,
}
646
/// Aggregated inference statistics produced by `PerformanceMonitor`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PerformanceStatistics {
    /// Number of inferences recorded.
    pub total_inferences: usize,
    /// Mean wall time per inference (ms).
    pub average_inference_time_ms: f64,
    /// Mean batch size across inferences.
    pub average_batch_size: f64,
    /// Highest single-inference memory sample (bytes).
    pub peak_memory_usage_bytes: usize,
    /// Inferences per second over the total recorded time.
    pub throughput_inferences_per_second: f64,
}
656
/// Collects workload samples and derives tuning recommendations from them.
#[derive(Debug)]
pub struct AdvancedPerformanceOptimizer {
    // Retained for future use; not read by the visible methods.
    #[allow(dead_code)]
    config: PerformanceConfig,
    // Bounded history: only the most recent 1000 samples are kept.
    workload_history: Vec<WorkloadMetrics>,
    // Regenerated from scratch after every recorded sample.
    optimization_recommendations: Vec<String>,
}
665
/// One observed inference workload sample.
#[derive(Debug, Clone)]
pub struct WorkloadMetrics {
    /// Tensors in the batch that produced this sample.
    pub batch_size: usize,
    /// Sequence length of the workload.
    pub sequence_length: usize,
    /// Memory used by the inference (bytes).
    pub memory_usage: usize,
    /// Wall time of the inference (ms).
    pub inference_time_ms: f64,
    /// When the sample was taken.
    pub timestamp: std::time::Instant,
}
675
676impl AdvancedPerformanceOptimizer {
677 pub fn new(config: PerformanceConfig) -> Self {
679 Self {
680 config,
681 workload_history: Vec::new(),
682 optimization_recommendations: Vec::new(),
683 }
684 }
685
686 pub fn record_workload(&mut self, metrics: WorkloadMetrics) {
688 self.workload_history.push(metrics);
689
690 if self.workload_history.len() > 1000 {
692 self.workload_history.remove(0);
693 }
694
695 self.generate_recommendations();
697 }
698
699 fn generate_recommendations(&mut self) {
701 self.optimization_recommendations.clear();
702
703 if self.workload_history.len() < 10 {
704 return;
705 }
706
707 let recent_metrics: Vec<_> = self.workload_history.iter().rev().take(50).collect();
709
710 let avg_batch_size: f64 = recent_metrics.iter().map(|m| m.batch_size as f64).sum::<f64>()
712 / recent_metrics.len() as f64;
713
714 if avg_batch_size < 8.0 {
715 self.optimization_recommendations
716 .push("Consider increasing batch size for better throughput".to_string());
717 }
718
719 let memory_usages: Vec<usize> = recent_metrics.iter().map(|m| m.memory_usage).collect();
721 let max_memory = memory_usages.iter().max().unwrap_or(&0);
722 let min_memory = memory_usages.iter().min().unwrap_or(&0);
723
724 if *max_memory > min_memory * 2 {
725 self.optimization_recommendations.push(
726 "High memory usage variation detected - consider dynamic batching".to_string(),
727 );
728 }
729
730 if recent_metrics.len() >= 20 {
732 let first_half_avg: f64 =
733 recent_metrics[10..].iter().map(|m| m.inference_time_ms).sum::<f64>() / 10.0;
734 let second_half_avg: f64 =
735 recent_metrics[..10].iter().map(|m| m.inference_time_ms).sum::<f64>() / 10.0;
736
737 if second_half_avg > first_half_avg * 1.2 {
738 self.optimization_recommendations.push(
739 "Performance degradation detected - consider cache clearing or model reloading"
740 .to_string(),
741 );
742 }
743 }
744 }
745
746 pub fn get_recommendations(&self) -> &[String] {
748 &self.optimization_recommendations
749 }
750
751 pub fn get_workload_analysis(&self) -> WorkloadAnalysis {
753 if self.workload_history.is_empty() {
754 return WorkloadAnalysis::default();
755 }
756
757 let total_metrics = self.workload_history.len();
758 let avg_batch_size = self.workload_history.iter().map(|m| m.batch_size as f64).sum::<f64>()
759 / total_metrics as f64;
760
761 let avg_inference_time =
762 self.workload_history.iter().map(|m| m.inference_time_ms).sum::<f64>()
763 / total_metrics as f64;
764
765 let peak_memory = self.workload_history.iter().map(|m| m.memory_usage).max().unwrap_or(0);
766
767 WorkloadAnalysis {
768 total_samples: total_metrics,
769 average_batch_size: avg_batch_size,
770 average_inference_time_ms: avg_inference_time,
771 peak_memory_usage_bytes: peak_memory,
772 recommendations_count: self.optimization_recommendations.len(),
773 }
774 }
775}
776
/// Summary of the optimizer's full retained workload history.
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
pub struct WorkloadAnalysis {
    /// Samples contributing to the averages below.
    pub total_samples: usize,
    /// Mean batch size across the history.
    pub average_batch_size: f64,
    /// Mean inference time (ms) across the history.
    pub average_inference_time_ms: f64,
    /// Highest single memory sample (bytes).
    pub peak_memory_usage_bytes: usize,
    /// Number of currently active recommendations.
    pub recommendations_count: usize,
}
786
/// Snapshot of `DynamicBatchManager` activity for monitoring.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BatchStatistics {
    /// Latency samples recorded so far.
    pub total_batches_processed: usize,
    /// Current (possibly adapted) target batch size.
    pub current_batch_size: usize,
    /// Tensors still waiting to be batched.
    pub pending_tensors: usize,
    /// Mean of the recent latency window (ms).
    pub average_latency_ms: f64,
    /// Name of the active `BatchingStrategy` variant.
    pub strategy_type: String,
}
796
#[cfg(test)]
mod tests {
    use super::*;

    // Defaults must match the values documented on PerformanceConfig.
    #[test]
    fn test_performance_config_default() {
        let config = PerformanceConfig::default();
        assert_eq!(config.max_batch_size, 32);
        assert!(config.enable_dynamic_batching);
        assert_eq!(config.cache_size, 1000);
        assert!(config.enable_memory_optimization);
    }

    // A fresh processor starts with an empty batch buffer.
    #[test]
    fn test_batch_processor_creation() {
        let config = PerformanceConfig::default();
        let processor = BatchProcessor::new(config);
        assert_eq!(processor.current_batch_size(), 0);
    }

    // A [2, 3] tensor: 6 elements at 4 bytes each = 24 bytes.
    #[test]
    fn test_memory_optimizer_estimate() {
        let tensor = Tensor::zeros(&[2, 3]).expect("operation failed");
        let tensors = vec![tensor];

        let estimated = MemoryOptimizer::estimate_memory_usage(&tensors).expect("operation failed");
        assert_eq!(estimated, 24);
    }

    // Fixed(2): two queued tensors produce one batch of exactly two.
    #[test]
    fn test_dynamic_batch_manager() {
        let strategy = BatchingStrategy::Fixed(2);
        let mut manager = DynamicBatchManager::new(strategy);

        let tensor1 = Tensor::zeros(&[1, 2]).expect("operation failed");
        let tensor2 = Tensor::zeros(&[1, 2]).expect("operation failed");

        manager.add_tensor(tensor1, 1).expect("operation failed");
        manager.add_tensor(tensor2, 2).expect("operation failed");

        let batch = manager.get_next_batch().expect("operation failed");
        assert!(batch.is_some());
        assert_eq!(batch.expect("operation failed").len(), 2);
    }

    // Averages and peak must reflect both recorded inferences.
    #[test]
    fn test_performance_monitor() {
        let mut monitor = PerformanceMonitor::default();

        monitor.record_inference(100.0, 4, 1024);
        monitor.record_inference(200.0, 8, 2048);

        let stats = monitor.get_statistics();
        assert_eq!(stats.total_inferences, 2);
        assert_eq!(stats.average_inference_time_ms, 150.0);
        assert_eq!(stats.average_batch_size, 6.0);
        assert_eq!(stats.peak_memory_usage_bytes, 2048);
    }

    // A new cache reports zero usage and the configured capacity.
    #[test]
    fn test_cache_statistics() {
        let config = PerformanceConfig::default();
        let processor = BatchProcessor::new(config);
        let stats = processor.cache_stats();

        assert_eq!(stats.current_size, 0);
        assert_eq!(stats.max_size, 1000);
        assert_eq!(stats.hit_rate, 0.0);
    }

    // Mixed small/large batches should yield at least one recommendation.
    #[test]
    fn test_advanced_performance_optimizer() {
        let config = PerformanceConfig::default();
        let mut optimizer = AdvancedPerformanceOptimizer::new(config);

        for i in 1..=20 {
            let metrics = WorkloadMetrics {
                // First half small batches, second half large ones.
                batch_size: if i < 10 { 2 } else { 16 },
                sequence_length: 512,
                memory_usage: 1024 * i,
                inference_time_ms: 100.0 + (i as f64 * 5.0),
                timestamp: std::time::Instant::now(),
            };
            optimizer.record_workload(metrics);
        }

        let analysis = optimizer.get_workload_analysis();
        assert_eq!(analysis.total_samples, 20);
        assert!(analysis.average_batch_size > 2.0);
        let recommendations = optimizer.get_recommendations();
        assert!(!recommendations.is_empty());
    }

    // Touching key1 makes key2 the LRU victim when key3 is inserted.
    #[test]
    fn test_lru_cache() {
        let mut cache = LruCache::new(2);

        let tensor1 = Tensor::zeros(&[1, 2]).expect("operation failed");
        let tensor2 = Tensor::zeros(&[1, 3]).expect("operation failed");
        let tensor3 = Tensor::zeros(&[1, 4]).expect("operation failed");

        cache.put("key1".to_string(), tensor1);
        cache.put("key2".to_string(), tensor2);

        // Refresh key1 so key2 becomes least recently used.
        let _ = cache.get("key1");

        cache.put("key3".to_string(), tensor3);

        assert!(cache.get("key1").is_some());
        assert!(cache.get("key3").is_some());
        assert!(cache.get("key2").is_none());

        let stats = cache.statistics();
        assert_eq!(stats.current_size, 2);
        assert_eq!(stats.max_size, 2);
        assert!(stats.hit_rate > 0.0);
    }

    // High latency shrinks the adaptive batch size; low latency grows it.
    #[test]
    fn test_adaptive_batching() {
        let strategy = BatchingStrategy::Adaptive {
            initial_batch_size: 4,
            max_batch_size: 16,
            target_latency_ms: 100.0,
            adjustment_factor: 0.2,
        };
        let mut manager = DynamicBatchManager::new(strategy);

        // Consistently over target: the batch size must shrink below 4.
        for _ in 0..10 {
            manager.record_latency(150.0);
        }

        assert!(manager.current_batch_size() < 4);
        // Consistently under target: the batch size may grow again.
        for _ in 0..10 {
            manager.record_latency(50.0);
        }

        let stats = manager.get_batch_statistics();
        assert_eq!(stats.strategy_type, "Adaptive");
        assert!(stats.average_latency_ms > 0.0);
    }

    // Two queued high-priority (>= 80) tensors fill the high batch first.
    #[test]
    fn test_priority_batching() {
        let strategy = BatchingStrategy::PriorityBased {
            high_priority_batch_size: 2,
            normal_priority_batch_size: 4,
            low_priority_batch_size: 8,
        };
        let mut manager = DynamicBatchManager::new(strategy);

        let tensor = Tensor::zeros(&[1, 2]).expect("operation failed");
        manager.add_tensor(tensor.clone(), 90).expect("operation failed");
        manager.add_tensor(tensor.clone(), 50).expect("operation failed");
        manager.add_tensor(tensor.clone(), 90).expect("operation failed");
        manager.add_tensor(tensor.clone(), 20).expect("operation failed");
        let batch = manager.get_next_batch().expect("operation failed");
        assert!(batch.is_some());
        assert_eq!(batch.expect("operation failed").len(), 2);
        let stats = manager.get_batch_statistics();
        assert_eq!(stats.strategy_type, "PriorityBased");
    }
}
977
/// Size-bucketed pool of reusable GPU memory chunks with usage statistics.
#[derive(Debug)]
pub struct GpuMemoryPool {
    // size_bytes -> idle chunks of exactly that size.
    pools: HashMap<usize, VecDeque<GpuMemoryChunk>>,
    // Bytes ever allocated and not yet reclaimed by cleanup.
    total_allocated: usize,
    // Hard cap on `total_allocated`.
    max_memory_limit: usize,
    // Pooled/allocated byte ratio above which a forced cleanup runs.
    fragmentation_threshold: f32,
    // Lifetime counters (allocations, reuse hits, peak usage, ...).
    stats: GpuMemoryStats,
}
996
/// Handle for one pooled GPU allocation.
#[derive(Debug, Clone)]
pub struct GpuMemoryChunk {
    /// Unique identifier (UUID v4 string).
    pub id: String,
    /// Size of the chunk in bytes.
    pub size_bytes: usize,
    /// True while handed out to a caller; false while parked in a pool.
    pub in_use: bool,
    /// When the chunk was first created.
    pub allocated_at: std::time::Instant,
    /// Last time the chunk was handed out; drives idle cleanup.
    pub last_accessed: std::time::Instant,
    /// NOTE(review): only ever set to 0 or 1 in the visible code — confirm
    /// whether shared ownership is actually intended.
    pub ref_count: usize,
}
1012
/// Lifetime counters describing pool behaviour.
#[derive(Debug, Default, Clone)]
pub struct GpuMemoryStats {
    /// Fresh (non-reused) allocations performed.
    pub total_allocations: usize,
    /// Chunks returned via `deallocate`.
    pub total_deallocations: usize,
    /// Chunks currently handed out to callers.
    pub active_allocations: usize,
    /// High-water mark of `current_memory_usage`.
    pub peak_memory_usage: usize,
    /// Bytes currently handed out to callers.
    pub current_memory_usage: usize,
    /// Idle-pooled bytes / total allocated bytes (set by defragmentation).
    pub fragmentation_ratio: f32,
    /// Running mean size of fresh allocations (bytes).
    pub average_allocation_size: f32,
    /// Requests satisfied by reusing a pooled chunk.
    pub cache_hits: usize,
    /// Requests that required a fresh allocation.
    pub cache_misses: usize,
}
1034
1035impl GpuMemoryPool {
1036 pub fn new(max_memory_limit: usize) -> Self {
1038 Self {
1039 pools: HashMap::new(),
1040 total_allocated: 0,
1041 max_memory_limit,
1042 fragmentation_threshold: 0.25, stats: GpuMemoryStats::default(),
1044 }
1045 }
1046
1047 pub fn allocate(&mut self, size_bytes: usize) -> Result<GpuMemoryChunk> {
1049 if self.total_allocated + size_bytes > self.max_memory_limit {
1051 self.try_defragment()?;
1052 if self.total_allocated + size_bytes > self.max_memory_limit {
1053 return Err(TrustformersError::invalid_operation(
1054 "GPU memory limit exceeded".to_string(),
1055 ));
1056 }
1057 }
1058
1059 if let Some(chunk) = self.find_suitable_chunk(size_bytes) {
1061 self.stats.cache_hits += 1;
1062 self.stats.active_allocations += 1;
1063 return Ok(chunk);
1064 }
1065
1066 let chunk = GpuMemoryChunk {
1068 id: uuid::Uuid::new_v4().to_string(),
1069 size_bytes,
1070 in_use: true,
1071 allocated_at: std::time::Instant::now(),
1072 last_accessed: std::time::Instant::now(),
1073 ref_count: 1,
1074 };
1075
1076 self.total_allocated += size_bytes;
1077 self.stats.total_allocations += 1;
1078 self.stats.active_allocations += 1;
1079 self.stats.cache_misses += 1;
1080 self.stats.current_memory_usage += size_bytes;
1081
1082 if self.stats.current_memory_usage > self.stats.peak_memory_usage {
1083 self.stats.peak_memory_usage = self.stats.current_memory_usage;
1084 }
1085
1086 self.stats.average_allocation_size = (self.stats.average_allocation_size
1088 * (self.stats.total_allocations - 1) as f32
1089 + size_bytes as f32)
1090 / self.stats.total_allocations as f32;
1091
1092 Ok(chunk)
1093 }
1094
1095 pub fn deallocate(&mut self, mut chunk: GpuMemoryChunk) -> Result<()> {
1097 chunk.in_use = false;
1098 chunk.ref_count = 0;
1099
1100 let pool = self.pools.entry(chunk.size_bytes).or_default();
1102 pool.push_back(chunk.clone());
1103
1104 self.stats.total_deallocations += 1;
1105 self.stats.active_allocations = self.stats.active_allocations.saturating_sub(1);
1106 self.stats.current_memory_usage =
1107 self.stats.current_memory_usage.saturating_sub(chunk.size_bytes);
1108
1109 self.cleanup_unused_chunks()?;
1111
1112 Ok(())
1113 }
1114
1115 fn find_suitable_chunk(&mut self, size_bytes: usize) -> Option<GpuMemoryChunk> {
1117 if let Some(pool) = self.pools.get_mut(&size_bytes) {
1119 if let Some(mut chunk) = pool.pop_front() {
1120 chunk.in_use = true;
1121 chunk.last_accessed = std::time::Instant::now();
1122 chunk.ref_count = 1;
1123 return Some(chunk);
1124 }
1125 }
1126
1127 let suitable_sizes: Vec<usize> = self.pools.keys()
1129 .filter(|&&size| size > size_bytes && size <= size_bytes * 2) .copied()
1131 .collect();
1132
1133 for pool_size in suitable_sizes {
1134 if let Some(pool) = self.pools.get_mut(&pool_size) {
1135 if let Some(mut chunk) = pool.pop_front() {
1136 chunk.in_use = true;
1137 chunk.last_accessed = std::time::Instant::now();
1138 chunk.ref_count = 1;
1139 return Some(chunk);
1140 }
1141 }
1142 }
1143
1144 None
1145 }
1146
1147 fn cleanup_unused_chunks(&mut self) -> Result<()> {
1149 let now = std::time::Instant::now();
1150 let cleanup_threshold = std::time::Duration::from_secs(300); for pool in self.pools.values_mut() {
1153 pool.retain(|chunk| {
1154 let should_keep =
1155 chunk.in_use || now.duration_since(chunk.last_accessed) < cleanup_threshold;
1156 if !should_keep {
1157 self.total_allocated = self.total_allocated.saturating_sub(chunk.size_bytes);
1158 }
1159 should_keep
1160 });
1161 }
1162
1163 Ok(())
1164 }
1165
1166 fn try_defragment(&mut self) -> Result<()> {
1168 let total_pooled = self
1170 .pools
1171 .values()
1172 .map(|pool| pool.iter().map(|chunk| chunk.size_bytes).sum::<usize>())
1173 .sum::<usize>();
1174
1175 self.stats.fragmentation_ratio = if self.total_allocated > 0 {
1176 total_pooled as f32 / self.total_allocated as f32
1177 } else {
1178 0.0
1179 };
1180
1181 if self.stats.fragmentation_ratio > self.fragmentation_threshold {
1183 self.force_cleanup()?;
1184 }
1185
1186 Ok(())
1187 }
1188
1189 fn force_cleanup(&mut self) -> Result<()> {
1191 for pool in self.pools.values_mut() {
1192 let initial_size: usize = pool.iter().map(|chunk| chunk.size_bytes).sum();
1193 pool.retain(|chunk| chunk.in_use);
1194 let final_size: usize = pool.iter().map(|chunk| chunk.size_bytes).sum();
1195 self.total_allocated = self.total_allocated.saturating_sub(initial_size - final_size);
1196 }
1197
1198 self.try_defragment()?;
1200
1201 Ok(())
1202 }
1203
1204 pub fn get_statistics(&self) -> GpuMemoryStats {
1206 self.stats.clone()
1207 }
1208
1209 pub fn get_memory_usage_percentage(&self) -> f32 {
1211 (self.total_allocated as f32 / self.max_memory_limit as f32) * 100.0
1212 }
1213
1214 pub fn get_cache_efficiency(&self) -> f32 {
1216 let total_requests = self.stats.cache_hits + self.stats.cache_misses;
1217 if total_requests > 0 {
1218 self.stats.cache_hits as f32 / total_requests as f32
1219 } else {
1220 0.0
1221 }
1222 }
1223}
1224
/// LRU- and importance-scored tensor cache backed by a GPU memory pool.
#[derive(Debug)]
pub struct GpuTensorCache {
    // Tracks the GPU memory chunk behind each cached tensor.
    memory_pool: GpuMemoryPool,
    // key -> cached tensor plus eviction bookkeeping.
    tensor_cache: HashMap<String, CachedTensor>,
    // Keys ordered oldest -> most recently used.
    lru_order: VecDeque<String>,
    // Maximum number of cached tensors.
    max_cache_size: usize,
    // NOTE(review): `hit_rate` in these stats is never updated by the
    // visible methods — confirm whether it is maintained elsewhere.
    stats: CacheStatistics,
}
1239
/// A cached tensor plus the metadata used for eviction scoring.
#[derive(Debug, Clone)]
pub struct CachedTensor {
    /// The cached tensor itself.
    pub tensor: Tensor,
    /// Pool chunk backing this tensor; returned to the pool on eviction.
    pub memory_chunk: GpuMemoryChunk,
    /// Incremented on every `get_tensor` hit; raises retention score.
    pub access_frequency: f32,
    /// Caller-supplied retention bias; higher survives eviction longer.
    pub importance_score: f32,
    /// Last read time.
    pub last_access: std::time::Instant,
    /// Insertion time; age lowers the retention score.
    pub created_at: std::time::Instant,
}
1255
1256impl GpuTensorCache {
1257 pub fn new(max_cache_size: usize, max_memory_limit: usize) -> Self {
1259 Self {
1260 memory_pool: GpuMemoryPool::new(max_memory_limit),
1261 tensor_cache: HashMap::new(),
1262 lru_order: VecDeque::new(),
1263 max_cache_size,
1264 stats: CacheStatistics {
1265 current_size: 0,
1266 max_size: max_cache_size,
1267 hit_rate: 0.0,
1268 },
1269 }
1270 }
1271
1272 pub fn cache_tensor(
1274 &mut self,
1275 key: String,
1276 tensor: Tensor,
1277 importance_score: Option<f32>,
1278 ) -> Result<()> {
1279 let tensor_size = self.estimate_tensor_size(&tensor);
1281
1282 let memory_chunk = self.memory_pool.allocate(tensor_size)?;
1284
1285 let cached_tensor = CachedTensor {
1287 tensor,
1288 memory_chunk,
1289 access_frequency: 1.0,
1290 importance_score: importance_score.unwrap_or(0.5),
1291 last_access: std::time::Instant::now(),
1292 created_at: std::time::Instant::now(),
1293 };
1294
1295 if self.tensor_cache.len() >= self.max_cache_size {
1297 self.evict_least_important()?;
1298 }
1299
1300 self.tensor_cache.insert(key.clone(), cached_tensor);
1302 self.lru_order.push_back(key);
1303 self.stats.current_size = self.tensor_cache.len();
1304
1305 Ok(())
1306 }
1307
1308 pub fn get_tensor(&mut self, key: &str) -> Option<&Tensor> {
1310 if !self.tensor_cache.contains_key(key) {
1312 return None;
1313 }
1314
1315 self.update_lru_order(key);
1317
1318 if let Some(cached_tensor) = self.tensor_cache.get_mut(key) {
1320 cached_tensor.access_frequency += 1.0;
1321 cached_tensor.last_access = std::time::Instant::now();
1322 Some(&cached_tensor.tensor)
1323 } else {
1324 None
1325 }
1326 }
1327
1328 fn update_lru_order(&mut self, key: &str) {
1330 if let Some(pos) = self.lru_order.iter().position(|k| k == key) {
1332 self.lru_order.remove(pos);
1333 self.lru_order.push_back(key.to_string());
1334 }
1335 }
1336
1337 fn evict_least_important(&mut self) -> Result<()> {
1339 let mut eviction_candidates: Vec<(String, f32)> = self
1341 .tensor_cache
1342 .iter()
1343 .map(|(key, cached_tensor)| {
1344 let age_factor = cached_tensor.created_at.elapsed().as_secs() as f32 / 3600.0; let frequency_factor = cached_tensor.access_frequency;
1346 let importance_factor = cached_tensor.importance_score;
1347
1348 let eviction_score = importance_factor * frequency_factor / (1.0 + age_factor);
1350 (key.clone(), eviction_score)
1351 })
1352 .collect();
1353
1354 eviction_candidates
1356 .sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
1357
1358 if let Some((key_to_evict, _)) = eviction_candidates.first() {
1360 if let Some(cached_tensor) = self.tensor_cache.remove(key_to_evict) {
1361 self.memory_pool.deallocate(cached_tensor.memory_chunk)?;
1362
1363 if let Some(pos) = self.lru_order.iter().position(|k| k == key_to_evict) {
1365 self.lru_order.remove(pos);
1366 }
1367
1368 self.stats.current_size = self.tensor_cache.len();
1369 }
1370 }
1371
1372 Ok(())
1373 }
1374
1375 fn estimate_tensor_size(&self, tensor: &Tensor) -> usize {
1377 match tensor {
1378 Tensor::F32(arr) => arr.len() * 4, Tensor::F64(arr) => arr.len() * 8, _ => 1024, }
1382 }
1383
1384 pub fn get_comprehensive_stats(&self) -> GpuCacheStatistics {
1386 let memory_stats = self.memory_pool.get_statistics();
1387 let fragmentation_ratio = memory_stats.fragmentation_ratio;
1388
1389 GpuCacheStatistics {
1390 cache_stats: self.stats.clone(),
1391 memory_stats,
1392 memory_usage_percentage: self.memory_pool.get_memory_usage_percentage(),
1393 cache_efficiency: self.memory_pool.get_cache_efficiency(),
1394 average_tensor_age: self.calculate_average_tensor_age(),
1395 fragmentation_ratio,
1396 }
1397 }
1398
1399 fn calculate_average_tensor_age(&self) -> f32 {
1401 if self.tensor_cache.is_empty() {
1402 return 0.0;
1403 }
1404
1405 let total_age: f32 = self
1406 .tensor_cache
1407 .values()
1408 .map(|cached_tensor| cached_tensor.created_at.elapsed().as_secs() as f32)
1409 .sum();
1410
1411 total_age / self.tensor_cache.len() as f32
1412 }
1413
1414 pub fn clear(&mut self) -> Result<()> {
1416 for (_, cached_tensor) in self.tensor_cache.drain() {
1417 self.memory_pool.deallocate(cached_tensor.memory_chunk)?;
1418 }
1419 self.lru_order.clear();
1420 self.stats.current_size = 0;
1421 Ok(())
1422 }
1423}
1424
/// Aggregated snapshot of cache and GPU memory health, produced by
/// `GpuTensorCache::get_comprehensive_stats`.
#[derive(Debug, Clone)]
pub struct GpuCacheStatistics {
    // Entry-count statistics (current size, max size, hit rate).
    pub cache_stats: CacheStatistics,
    // Raw statistics reported by the underlying memory pool.
    pub memory_stats: GpuMemoryStats,
    // Pool memory in use, expressed as a percentage (compared against 75/90).
    pub memory_usage_percentage: f32,
    // Cache hit ratio in [0.0, 1.0], hits / total requests.
    pub cache_efficiency: f32,
    // Mean age of cached tensors, in seconds.
    pub average_tensor_age: f32,
    // Pool fragmentation ratio (0.0 = no fragmentation).
    pub fragmentation_ratio: f32,
}
1435
/// Tuning advice produced by `GpuMemoryOptimizer::analyze_and_recommend`.
#[derive(Debug, Clone)]
pub struct GpuOptimizationRecommendations {
    // Human-readable recommendation messages; never empty (a "no action
    // required" message is pushed when everything looks healthy).
    pub recommendations: Vec<String>,
    // Overall urgency: "Low", "Medium", or "High".
    pub priority: String,
    // Rough expected improvement in percent, capped at 50.0.
    pub estimated_improvement: f32,
}
1446
1447pub struct GpuMemoryOptimizer;
1449
1450impl GpuMemoryOptimizer {
1451 pub fn analyze_and_recommend(stats: &GpuCacheStatistics) -> GpuOptimizationRecommendations {
1453 let mut recommendations = Vec::new();
1454 let mut priority = "Low".to_string();
1455 let mut estimated_improvement: f32 = 0.0;
1456
1457 if stats.memory_usage_percentage > 90.0 {
1459 recommendations.push("Critical: Memory usage is very high. Consider increasing memory limit or improving eviction strategy.".to_string());
1460 priority = "High".to_string();
1461 estimated_improvement += 25.0;
1462 } else if stats.memory_usage_percentage > 75.0 {
1463 recommendations.push(
1464 "Warning: Memory usage is high. Monitor for potential memory pressure.".to_string(),
1465 );
1466 priority = "Medium".to_string();
1467 estimated_improvement += 10.0;
1468 }
1469
1470 if stats.fragmentation_ratio > 0.4 {
1472 recommendations.push(
1473 "High memory fragmentation detected. Consider running defragmentation.".to_string(),
1474 );
1475 if priority == "Low" {
1476 priority = "Medium".to_string();
1477 }
1478 estimated_improvement += 15.0;
1479 }
1480
1481 if stats.cache_efficiency < 0.7 {
1483 recommendations.push(
1484 "Low cache hit rate. Consider adjusting cache size or eviction policy.".to_string(),
1485 );
1486 if priority == "Low" {
1487 priority = "Medium".to_string();
1488 }
1489 estimated_improvement += 20.0;
1490 }
1491
1492 if stats.average_tensor_age > 3600.0 {
1494 recommendations.push(
1496 "Cached tensors are aging. Consider more aggressive eviction for unused tensors."
1497 .to_string(),
1498 );
1499 estimated_improvement += 5.0;
1500 }
1501
1502 if stats.memory_stats.active_allocations > 1000 {
1504 recommendations.push(
1505 "High number of active allocations. Consider batching or pooling strategies."
1506 .to_string(),
1507 );
1508 estimated_improvement += 12.0;
1509 }
1510
1511 if recommendations.is_empty() {
1512 recommendations
1513 .push("GPU memory usage is optimal. No immediate action required.".to_string());
1514 }
1515
1516 GpuOptimizationRecommendations {
1517 recommendations,
1518 priority,
1519 estimated_improvement: estimated_improvement.min(50.0), }
1521 }
1522
1523 pub fn auto_optimize(cache: &mut GpuTensorCache) -> Result<Vec<String>> {
1525 let stats = cache.get_comprehensive_stats();
1526 let recommendations = Self::analyze_and_recommend(&stats);
1527 let mut actions_taken = Vec::new();
1528
1529 if recommendations.priority == "High" {
1531 if stats.memory_usage_percentage > 90.0 {
1533 cache.memory_pool.force_cleanup()?;
1534 actions_taken.push("Performed emergency memory cleanup".to_string());
1535 }
1536 }
1537
1538 if stats.fragmentation_ratio > 0.4 {
1539 cache.memory_pool.try_defragment()?;
1540 actions_taken.push("Performed memory defragmentation".to_string());
1541 }
1542
1543 if actions_taken.is_empty() {
1544 actions_taken.push("No automatic optimizations were necessary".to_string());
1545 }
1546
1547 Ok(actions_taken)
1548 }
1549}
1550
#[cfg(test)]
mod gpu_memory_tests {
    use super::*;

    // Allocate/deallocate round-trip: the pool tracks active allocations.
    #[test]
    fn test_gpu_memory_pool_basic() {
        let mut pool = GpuMemoryPool::new(1024 * 1024);
        let chunk = pool.allocate(1024).expect("operation failed");
        assert_eq!(chunk.size_bytes, 1024);
        assert!(chunk.in_use);
        assert_eq!(pool.get_statistics().active_allocations, 1);

        pool.deallocate(chunk).expect("operation failed");
        assert_eq!(pool.get_statistics().active_allocations, 0);
    }

    // Re-allocating a size that was just freed should be served from the
    // pool's free list, registering as a cache hit.
    #[test]
    fn test_gpu_memory_pool_reuse() {
        let mut pool = GpuMemoryPool::new(1024 * 1024);

        let chunk = pool.allocate(1024).expect("operation failed");
        pool.deallocate(chunk).expect("operation failed");

        let stats_before = pool.get_statistics();
        let _chunk2 = pool.allocate(1024).expect("operation failed");
        let stats_after = pool.get_statistics();

        assert_eq!(stats_after.cache_hits, stats_before.cache_hits + 1);
    }

    // Capacity-2 cache: after touching tensor1 (raising its access
    // frequency), inserting tensor3 should evict tensor2, which has the
    // lowest importance * frequency eviction score.
    #[test]
    fn test_gpu_tensor_cache() -> Result<()> {
        let mut cache = GpuTensorCache::new(2, 1024 * 1024);

        let tensor1 = Tensor::zeros(&[10, 10])?;
        let tensor2 = Tensor::zeros(&[5, 5])?;
        let tensor3 = Tensor::zeros(&[20, 20])?;

        cache.cache_tensor("tensor1".to_string(), tensor1, Some(0.8))?;
        cache.cache_tensor("tensor2".to_string(), tensor2, Some(0.6))?;

        assert!(cache.get_tensor("tensor1").is_some());

        cache.cache_tensor("tensor3".to_string(), tensor3, Some(0.9))?;

        assert!(cache.get_tensor("tensor2").is_none());
        assert!(cache.get_tensor("tensor1").is_some());
        assert!(cache.get_tensor("tensor3").is_some());

        Ok(())
    }

    // A stats snapshot tripping every threshold (usage > 90%, fragmentation
    // > 0.4, efficiency < 0.7, age > 3600s) must yield "High" priority with
    // at least one recommendation and a positive improvement estimate.
    #[test]
    fn test_gpu_optimization_recommendations() {
        let stats = GpuCacheStatistics {
            cache_stats: CacheStatistics {
                current_size: 100,
                max_size: 100,
                hit_rate: 0.5,
            },
            memory_stats: GpuMemoryStats {
                fragmentation_ratio: 0.5,
                ..Default::default()
            },
            memory_usage_percentage: 95.0,
            cache_efficiency: 0.5,
            average_tensor_age: 7200.0,
            fragmentation_ratio: 0.5,
        };

        let recommendations = GpuMemoryOptimizer::analyze_and_recommend(&stats);

        assert_eq!(recommendations.priority, "High");
        assert!(!recommendations.recommendations.is_empty());
        assert!(recommendations.estimated_improvement > 0.0);
    }
}