1use crate::device::context::{DeviceContext, DEVICE_MANAGER};
4use crate::{DType, Device, Result, Tensor};
5use std::collections::HashMap;
6use std::sync::{Arc, Mutex, RwLock};
7use std::time::{Duration, Instant};
8
9use super::config::{CachedOperation, EagerExecutionConfig, ExecutionMetrics, OpSignature};
10use super::memory_pool::{MemoryGuard, MemoryPool};
11use super::reporting::{CacheStatistics, EagerPerformanceReport};
12
/// Core engine for eager (immediate) operation execution.
///
/// Wraps each operation with the optimizations enabled in `config` — an
/// operation cache, a memory pool, cached device contexts, and kernel-fusion
/// analysis — and records per-operation timing metrics.
pub struct EagerExecutionEngine {
    // Feature flags and tuning knobs (cache size, overhead target, ...).
    pub(super) config: EagerExecutionConfig,
    // Previously executed operation signatures; LRU-evicted at capacity.
    pub(super) op_cache: RwLock<HashMap<OpSignature, CachedOperation>>,
    // Reusable memory blocks, pre-warmed ahead of large operations.
    pub(super) memory_pool: MemoryPool,
    // Append-only log of per-operation execution metrics.
    pub(super) metrics: Mutex<Vec<ExecutionMetrics>>,
    // Device contexts created lazily on first use, keyed by device.
    pub(super) active_contexts: RwLock<HashMap<Device, Arc<dyn DeviceContext>>>,
    // Candidate operation chains for kernel fusion (list capped at 50).
    pub(super) fusion_opportunities: RwLock<Vec<FusionOpportunity>>,
}
22
/// A chain of operations identified as a candidate for kernel fusion,
/// together with the estimated benefit of fusing them.
#[derive(Debug)]
#[allow(dead_code)] // analysis output; not yet consumed by a fusion pass
pub(super) struct FusionOpportunity {
    // Operation names in chain order (e.g. ["matmul", "relu"]).
    pub(super) operations: Vec<String>,
    // Estimated multiplicative speedup from fusing the chain.
    pub(super) potential_speedup: f64,
    // Estimated bytes of intermediate buffers a fused kernel would avoid.
    pub(super) memory_savings: usize,
}
30
31impl EagerExecutionEngine {
32 pub fn new(config: EagerExecutionConfig) -> Self {
34 Self {
35 memory_pool: MemoryPool::new(config.clone()),
36 config,
37 op_cache: RwLock::new(HashMap::new()),
38 metrics: Mutex::new(Vec::new()),
39 active_contexts: RwLock::new(HashMap::new()),
40 fusion_opportunities: RwLock::new(Vec::new()),
41 }
42 }
43
44 pub fn execute_operation<T, F>(
46 &self,
47 operation: &str,
48 inputs: &[&Tensor<T>],
49 params: &HashMap<String, String>,
50 executor: F,
51 ) -> Result<(Tensor<T>, ExecutionMetrics)>
52 where
53 T: Clone + Send + Sync + 'static,
54 F: FnOnce(&[&Tensor<T>]) -> Result<Tensor<T>>,
55 {
56 let overall_start = Instant::now();
57
58 let signature = self.create_signature(operation, inputs, params)?;
60
61 let setup_start = Instant::now();
63 let cache_hit = self.check_cache(&signature);
64 let setup_time = setup_start.elapsed();
65
66 let exec_start = Instant::now();
68 let result = if cache_hit && self.config.enable_op_cache {
69 executor(inputs)?
71 } else {
72 let _memory_guard = if self.config.enable_memory_pool {
74 Some(self.prepare_memory_for_operation(&signature)?)
75 } else {
76 None
77 };
78
79 let result = if self.config.enable_context_optimization {
81 self.execute_with_context_optimization(inputs, executor)?
82 } else {
83 executor(inputs)?
84 };
85
86 if self.config.enable_op_cache {
88 self.cache_operation(&signature, &result, exec_start.elapsed())?;
89 }
90
91 result
92 };
93 let execution_time = exec_start.elapsed();
94
95 let teardown_start = Instant::now();
97 if self.config.enable_memory_pool {
98 self.cleanup_operation_memory(&signature)?;
99 }
100 let teardown_time = teardown_start.elapsed();
101
102 let total_time = overall_start.elapsed();
103 let total_overhead = total_time - execution_time;
104
105 let metrics = ExecutionMetrics {
107 operation: operation.to_string(),
108 device: *inputs[0].device(),
109 setup_time,
110 execution_time,
111 teardown_time,
112 total_overhead,
113 memory_allocation_time: Duration::ZERO, cache_hit,
115 meets_target: total_overhead.as_nanos() <= self.config.target_overhead_ns as u128,
116 };
117
118 self.metrics
119 .lock()
120 .expect("lock should not be poisoned")
121 .push(metrics.clone());
122
123 if self.config.enable_kernel_fusion {
125 self.analyze_fusion_opportunity(operation, &signature);
126 }
127
128 Ok((result, metrics))
129 }
130
131 fn create_signature<T: 'static>(
133 &self,
134 operation: &str,
135 inputs: &[&Tensor<T>],
136 params: &HashMap<String, String>,
137 ) -> Result<OpSignature> {
138 let input_shapes: Vec<Vec<usize>> =
139 inputs.iter().map(|t| t.shape().dims().to_vec()).collect();
140
141 let device = *inputs[0].device();
142 let dtype = inputs[0].dtype();
143
144 let params: Vec<(String, String)> =
145 params.iter().map(|(k, v)| (k.clone(), v.clone())).collect();
146
147 Ok(OpSignature {
148 operation: operation.to_string(),
149 input_shapes,
150 dtype,
151 device,
152 params,
153 })
154 }
155
156 fn check_cache(&self, signature: &OpSignature) -> bool {
158 let cache = self
159 .op_cache
160 .read()
161 .expect("read lock should not be poisoned");
162 cache.contains_key(signature)
163 }
164
165 fn cache_operation<T>(
167 &self,
168 signature: &OpSignature,
169 result: &Tensor<T>,
170 execution_time: Duration,
171 ) -> Result<()> {
172 let mut cache = self
173 .op_cache
174 .write()
175 .expect("write lock should not be poisoned");
176
177 if cache.len() >= self.config.max_cache_size {
179 let oldest_key = cache
181 .iter()
182 .min_by_key(|(_, cached_op)| cached_op.last_used)
183 .map(|(k, _)| k.clone());
184
185 if let Some(key) = oldest_key {
186 cache.remove(&key);
187 }
188 }
189
190 let cached_op = CachedOperation {
192 signature: signature.clone(),
193 result_shape: result.shape().dims().to_vec(),
194 execution_time,
195 memory_usage: result.shape().size() * std::mem::size_of::<T>(),
196 created_at: Instant::now(),
197 last_used: Instant::now(),
198 use_count: 1,
199 };
200
201 cache.insert(signature.clone(), cached_op);
202 Ok(())
203 }
204
205 fn prepare_memory_for_operation(&self, signature: &OpSignature) -> Result<MemoryGuard> {
207 let output_memory_required = self.estimate_output_memory_requirements(signature)?;
209 let intermediate_memory_required =
210 self.estimate_intermediate_memory_requirements(signature)?;
211
212 if output_memory_required > 1024 * 1024 {
214 self.pre_warm_memory_pool(&signature.device, output_memory_required)?;
216 }
217
218 self.optimize_memory_layout_for_operation(signature)?;
220
221 Ok(MemoryGuard {
223 device: signature.device,
224 estimated_memory: output_memory_required + intermediate_memory_required,
225 operation: signature.operation.clone(),
226 })
227 }
228
229 fn estimate_output_memory_requirements(&self, signature: &OpSignature) -> Result<usize> {
231 let element_size = self.get_dtype_size(&signature.dtype);
232
233 let output_elements = match signature.operation.as_str() {
234 "add" | "sub" | "mul" | "div" | "relu" | "sigmoid" | "tanh" | "gelu" => signature
236 .input_shapes
237 .iter()
238 .map(|shape| shape.iter().product::<usize>())
239 .max()
240 .unwrap_or(0),
241
242 "matmul" => {
244 if signature.input_shapes.len() >= 2
245 && signature.input_shapes[0].len() >= 2
246 && signature.input_shapes[1].len() >= 2
247 {
248 let m = signature.input_shapes[0][signature.input_shapes[0].len() - 2];
249 let n = signature.input_shapes[1][signature.input_shapes[1].len() - 1];
250 let batch_size = signature.input_shapes[0]
251 .iter()
252 .take(signature.input_shapes[0].len() - 2)
253 .product::<usize>();
254 batch_size * m * n
255 } else {
256 0
257 }
258 }
259
260 "sum" | "mean" | "max" | "min" => {
262 signature
264 .input_shapes
265 .iter()
266 .map(|shape| shape.iter().product::<usize>() / shape.len().max(1))
267 .sum()
268 }
269
270 "conv2d" => {
272 if !signature.input_shapes.is_empty() && signature.input_shapes[0].len() >= 4 {
273 let batch = signature.input_shapes[0][0];
274 let height = signature.input_shapes[0][2];
275 let width = signature.input_shapes[0][3];
276 let output_channels = signature.input_shapes[0][1]; batch * output_channels * height * width
279 } else {
280 0
281 }
282 }
283
284 _ => {
285 signature
287 .input_shapes
288 .iter()
289 .map(|shape| shape.iter().product::<usize>())
290 .max()
291 .unwrap_or(0)
292 }
293 };
294
295 Ok(output_elements * element_size)
296 }
297
298 fn estimate_intermediate_memory_requirements(&self, signature: &OpSignature) -> Result<usize> {
300 let element_size = self.get_dtype_size(&signature.dtype);
301 let total_input_elements: usize = signature
302 .input_shapes
303 .iter()
304 .map(|shape| shape.iter().product::<usize>())
305 .sum();
306
307 let intermediate_factor = match signature.operation.as_str() {
308 "add" | "sub" | "mul" | "div" => 0.1,
310
311 "relu" | "sigmoid" | "tanh" | "gelu" => 0.2,
313
314 "matmul" => 0.5,
316
317 "batch_norm" | "layer_norm" | "group_norm" => 0.8,
319
320 "conv2d" | "conv3d" => 1.2,
322
323 "sum" | "mean" | "max" | "min" => 0.3,
325
326 _ => 0.5, };
328
329 Ok((total_input_elements as f64 * intermediate_factor * element_size as f64) as usize)
330 }
331
332 fn pre_warm_memory_pool(&self, device: &Device, required_memory: usize) -> Result<()> {
334 if self.config.enable_memory_pool {
336 let warmup_size = required_memory.next_power_of_two();
337
338 let num_blocks = if warmup_size > 1024 * 1024 { 2 } else { 3 }; self.memory_pool.pre_warm(device, warmup_size, num_blocks)?;
344 }
345 Ok(())
346 }
347
    /// Per-operation memory-layout optimization hook.
    ///
    /// Currently all arms are placeholders: the match distinguishes the
    /// operation classes a future implementation would treat differently,
    /// but no layout work is performed yet.
    fn optimize_memory_layout_for_operation(&self, signature: &OpSignature) -> Result<()> {
        match signature.operation.as_str() {
            "matmul" | "conv2d" | "conv3d" => {
                // Placeholder: compute-heavy ops (e.g. tiling-friendly layouts).
            }

            "add" | "sub" | "mul" | "div" => {
                // Placeholder: element-wise ops (e.g. contiguous layouts).
            }

            _ => {
                // Placeholder: no layout preference for other operations.
            }
        }
        Ok(())
    }
368
369 fn get_dtype_size(&self, dtype: &DType) -> usize {
371 match dtype {
372 DType::Float16 => 2,
373 DType::BFloat16 => 2,
374 DType::Float32 => 4,
375 DType::Float64 => 8,
376 DType::Int8 => 1,
377 DType::Int16 => 2,
378 DType::Int32 => 4,
379 DType::Int64 => 8,
380 DType::Int4 => 1, DType::UInt8 => 1,
382 DType::UInt16 => 2,
383 DType::UInt32 => 4,
384 DType::UInt64 => 8,
385 DType::Bool => 1,
386 DType::Complex32 => 8,
387 DType::Complex64 => 16,
388 DType::String => 8, }
390 }
391
392 fn execute_with_context_optimization<T, F>(
394 &self,
395 inputs: &[&Tensor<T>],
396 executor: F,
397 ) -> Result<Tensor<T>>
398 where
399 F: FnOnce(&[&Tensor<T>]) -> Result<Tensor<T>>,
400 {
401 let device = *inputs[0].device();
402
403 {
405 let mut contexts = self
406 .active_contexts
407 .write()
408 .expect("write lock should not be poisoned");
409 if let std::collections::hash_map::Entry::Vacant(e) = contexts.entry(device) {
410 let context = DEVICE_MANAGER.get_context(&device)?;
411 e.insert(context);
412 }
413 }
414
415 executor(inputs)
417 }
418
    /// Post-operation memory cleanup hook.
    ///
    /// Currently a no-op: pooled blocks are reclaimed lazily via `cleanup()`
    /// rather than per operation. The signature parameter is kept so a
    /// per-operation strategy can be added later without changing callers.
    fn cleanup_operation_memory(&self, _signature: &OpSignature) -> Result<()> {
        Ok(())
    }
424
425 fn analyze_fusion_opportunity(&self, operation: &str, signature: &OpSignature) {
427 let mut opportunities = self
428 .fusion_opportunities
429 .write()
430 .expect("write lock should not be poisoned");
431
432 let fusion_speedup = match operation {
434 "add" | "sub" | "mul" | "div" => self.calculate_elementwise_fusion_benefit(signature),
436
437 "relu" | "sigmoid" | "tanh" | "gelu" => {
439 self.calculate_activation_fusion_benefit(signature)
440 }
441
442 "batch_norm" | "layer_norm" | "group_norm" => {
444 self.calculate_normalization_fusion_benefit(signature)
445 }
446
447 "matmul" | "conv2d" | "conv3d" => {
449 self.calculate_compute_intensive_fusion_benefit(signature)
450 }
451
452 "sum" | "mean" | "max" | "min" => self.calculate_reduction_fusion_benefit(signature),
454
455 _ => 1.0, };
457
458 if fusion_speedup > 1.1 && opportunities.len() < 50 {
460 let memory_savings = self.estimate_memory_savings(operation, signature);
461
462 if let Some(existing) = opportunities
464 .iter_mut()
465 .find(|opp| self.can_extend_fusion_chain(&opp.operations, operation))
466 {
467 existing.operations.push(operation.to_string());
468 existing.potential_speedup *= fusion_speedup.min(1.5); existing.memory_savings += memory_savings;
470 } else {
471 opportunities.push(FusionOpportunity {
472 operations: vec![operation.to_string()],
473 potential_speedup: fusion_speedup,
474 memory_savings,
475 });
476 }
477 }
478 }
479
480 fn calculate_elementwise_fusion_benefit(&self, signature: &OpSignature) -> f64 {
482 let total_elements: usize = signature
483 .input_shapes
484 .iter()
485 .map(|shape| shape.iter().product::<usize>())
486 .sum();
487
488 if total_elements > 10_000 {
490 1.8 } else if total_elements > 1_000 {
492 1.4 } else {
494 1.1 }
496 }
497
498 #[allow(unused_variables)] fn calculate_activation_fusion_benefit(&self, signature: &OpSignature) -> f64 {
501 let is_gpu = {
504 #[cfg(feature = "gpu")]
505 {
506 matches!(signature.device, Device::Gpu(_))
507 }
508 #[cfg(not(feature = "gpu"))]
509 {
510 false
511 }
512 };
513 if is_gpu {
514 1.6 } else {
516 1.3 }
518 }
519
520 fn calculate_normalization_fusion_benefit(&self, signature: &OpSignature) -> f64 {
522 let input_size: usize = signature
525 .input_shapes
526 .iter()
527 .map(|shape| shape.iter().product::<usize>())
528 .max()
529 .unwrap_or(0);
530
531 if input_size > 50_000 {
532 1.7 } else {
534 1.2 }
536 }
537
538 fn calculate_compute_intensive_fusion_benefit(&self, signature: &OpSignature) -> f64 {
540 let is_large_computation = signature
542 .input_shapes
543 .iter()
544 .any(|shape| shape.iter().product::<usize>() > 100_000);
545
546 if is_large_computation {
547 1.4 } else {
549 1.1 }
551 }
552
553 fn calculate_reduction_fusion_benefit(&self, signature: &OpSignature) -> f64 {
555 let input_size: usize = signature
557 .input_shapes
558 .iter()
559 .map(|shape| shape.iter().product::<usize>())
560 .max()
561 .unwrap_or(0);
562
563 if input_size > 20_000 {
564 1.5 } else {
566 1.2 }
568 }
569
570 fn estimate_memory_savings(&self, operation: &str, signature: &OpSignature) -> usize {
572 let element_size = match signature.dtype {
573 DType::Float16 => 2,
574 DType::BFloat16 => 2,
575 DType::Float32 => 4,
576 DType::Float64 => 8,
577 DType::Int8 => 1,
578 DType::Int16 => 2,
579 DType::Int32 => 4,
580 DType::Int64 => 8,
581 DType::Int4 => 1, DType::UInt8 => 1,
583 DType::UInt16 => 2,
584 DType::UInt32 => 4,
585 DType::UInt64 => 8,
586 DType::Bool => 1,
587 DType::Complex32 => 8,
588 DType::Complex64 => 16,
589 DType::String => 8, };
591
592 let total_elements: usize = signature
593 .input_shapes
594 .iter()
595 .map(|shape| shape.iter().product::<usize>())
596 .sum();
597
598 match operation {
600 "add" | "sub" | "mul" | "div" => total_elements * element_size, "relu" | "sigmoid" | "tanh" => total_elements * element_size / 2, "batch_norm" | "layer_norm" => total_elements * element_size * 2, _ => total_elements * element_size / 4, }
605 }
606
607 fn can_extend_fusion_chain(&self, existing_ops: &[String], new_op: &str) -> bool {
609 if existing_ops.is_empty() {
610 return false;
611 }
612
613 let last_op = &existing_ops[existing_ops.len() - 1];
614
615 match (last_op.as_str(), new_op) {
617 ("add" | "sub" | "mul" | "div", "add" | "sub" | "mul" | "div") => true,
619
620 ("matmul" | "conv2d" | "conv3d", "relu" | "sigmoid" | "tanh" | "gelu") => true,
622 ("add" | "sub", "relu" | "sigmoid" | "tanh" | "gelu") => true,
623
624 ("relu" | "sigmoid" | "tanh" | "gelu", "batch_norm" | "layer_norm") => true,
626
627 (_, "sum" | "mean" | "max" | "min") => existing_ops.len() < 3, _ => false,
631 }
632 }
633
634 pub fn get_metrics(&self) -> Vec<ExecutionMetrics> {
636 self.metrics
637 .lock()
638 .expect("lock should not be poisoned")
639 .clone()
640 }
641
642 pub fn get_cache_stats(&self) -> CacheStatistics {
644 let cache = self
645 .op_cache
646 .read()
647 .expect("read lock should not be poisoned");
648
649 let total_entries = cache.len();
650 let total_hits = cache.values().map(|op| op.use_count).sum();
651 let avg_execution_time = if total_entries > 0 {
652 cache
653 .values()
654 .map(|op| op.execution_time.as_nanos())
655 .sum::<u128>()
656 / total_entries as u128
657 } else {
658 0
659 };
660
661 CacheStatistics {
662 total_entries,
663 total_hits,
664 hit_rate: if total_hits > 0 {
665 cache.len() as f64 / total_hits as f64
666 } else {
667 0.0
668 },
669 avg_execution_time: Duration::from_nanos(avg_execution_time as u64),
670 }
671 }
672
673 pub fn generate_performance_report(&self) -> EagerPerformanceReport {
675 let metrics = self.get_metrics();
676 let cache_stats = self.get_cache_stats();
677
678 if metrics.is_empty() {
679 return EagerPerformanceReport::default();
680 }
681
682 let total_operations = metrics.len();
683 let meets_target = metrics.iter().filter(|m| m.meets_target).count();
684 let success_rate = meets_target as f64 / total_operations as f64;
685
686 let avg_overhead = Duration::from_nanos(
687 (metrics
688 .iter()
689 .map(|m| m.total_overhead.as_nanos())
690 .sum::<u128>()
691 / total_operations as u128) as u64,
692 );
693
694 let min_overhead = metrics
695 .iter()
696 .map(|m| m.total_overhead)
697 .min()
698 .unwrap_or(Duration::ZERO);
699
700 let max_overhead = metrics
701 .iter()
702 .map(|m| m.total_overhead)
703 .max()
704 .unwrap_or(Duration::ZERO);
705
706 let cache_hit_rate =
707 metrics.iter().filter(|m| m.cache_hit).count() as f64 / total_operations as f64;
708
709 EagerPerformanceReport {
710 total_operations,
711 operations_meeting_target: meets_target,
712 success_rate,
713 avg_overhead,
714 min_overhead,
715 max_overhead,
716 cache_statistics: cache_stats,
717 cache_hit_rate,
718 target_overhead: Duration::from_nanos(self.config.target_overhead_ns),
719 recommendations: self.generate_recommendations(&metrics),
720 }
721 }
722
723 fn generate_recommendations(&self, metrics: &[ExecutionMetrics]) -> Vec<String> {
725 let mut recommendations = Vec::new();
726
727 let avg_overhead = if !metrics.is_empty() {
728 metrics
729 .iter()
730 .map(|m| m.total_overhead.as_nanos())
731 .sum::<u128>()
732 / metrics.len() as u128
733 } else {
734 0
735 };
736
737 if avg_overhead > self.config.target_overhead_ns as u128 {
738 recommendations
739 .push("Consider enabling operation caching to reduce setup overhead".to_string());
740 recommendations.push("Enable memory pooling to reduce allocation overhead".to_string());
741 }
742
743 let cache_hit_rate = if !metrics.is_empty() {
744 metrics.iter().filter(|m| m.cache_hit).count() as f64 / metrics.len() as f64
745 } else {
746 0.0
747 };
748
749 if cache_hit_rate < 0.3 {
750 recommendations.push("Increase cache size to improve hit rates".to_string());
751 }
752
753 let high_setup_ops = metrics
754 .iter()
755 .filter(|m| m.setup_time > Duration::from_micros(100))
756 .count();
757
758 if high_setup_ops > metrics.len() / 4 {
759 recommendations
760 .push("Enable context optimization to reduce setup overhead".to_string());
761 }
762
763 recommendations
764 }
765
766 pub fn cleanup(&self) {
768 self.memory_pool.cleanup_old_blocks();
770
771 let threshold = Duration::from_secs(300); let now = Instant::now();
774
775 let mut cache = self
776 .op_cache
777 .write()
778 .expect("write lock should not be poisoned");
779 cache.retain(|_, cached_op| now.duration_since(cached_op.last_used) <= threshold);
780 }
781}