1use super::dynamic::{DynamicExecutionContext, DynamicOp, ExecutionPlan, JitCompiler};
5use crate::error::{RusTorchError, RusTorchResult};
6use crate::tensor::Tensor;
7use num_traits::Float;
8use std::collections::HashMap;
9use std::sync::{Arc, RwLock};
10use std::time::Instant;
11
12pub struct RuntimeEngine<T: Float + Send + Sync + 'static> {
15 pub context: DynamicExecutionContext<T>,
17 jit_compiler: JitCompiler<T>,
19 pub execution_cache: HashMap<String, CachedExecution<T>>,
21 pub config: RuntimeConfig,
23 metrics: Arc<RwLock<RuntimeMetrics>>,
25}
26
/// Tunables controlling which runtime optimizations are enabled.
#[derive(Clone, Debug)]
pub struct RuntimeConfig {
    /// Enable JIT compilation of operation sequences.
    pub enable_jit: bool,
    /// Enable operation fusion.
    pub enable_fusion: bool,
    /// Enable memory optimizations.
    pub enable_memory_opt: bool,
    /// Enable parallel execution.
    pub enable_parallel: bool,
    /// Maximum number of entries kept in the execution cache.
    pub max_cache_size: usize,
    /// Minimum number of operations in a plan before JIT compilation is attempted.
    pub jit_threshold: usize,
}

impl Default for RuntimeConfig {
    /// All optimizations on, a 1000-entry cache and a JIT threshold of 5 operations.
    fn default() -> Self {
        Self {
            enable_jit: true,
            enable_fusion: true,
            enable_memory_opt: true,
            enable_parallel: true,
            max_cache_size: 1000,
            jit_threshold: 5,
        }
    }
}
57
58#[derive(Clone)]
61pub struct CachedExecution<T: Float + Send + Sync + 'static> {
62 pub plan: ExecutionPlan<T>,
64 pub input_shapes: Vec<Vec<usize>>,
66 pub output_shape: Vec<usize>,
68 pub hit_count: usize,
70 pub last_used: Instant,
72}
73
74#[derive(Debug, Default, Clone)]
77pub struct RuntimeMetrics {
78 pub total_executions: usize,
80 pub cache_hit_rate: f64,
82 pub avg_execution_time: std::time::Duration,
84 pub jit_stats: JitCompilationMetrics,
86 pub memory_stats: MemoryMetrics,
88 pub parallel_stats: ParallelExecutionMetrics,
90}
91
/// Statistics about JIT compilation attempts.
#[derive(Clone, Debug, Default)]
pub struct JitCompilationMetrics {
    /// Number of compilations counted.
    pub total_compilations: usize,
    /// Number of compilations that succeeded.
    pub successful_compilations: usize,
    /// Mean time spent per compilation.
    pub avg_compilation_time: std::time::Duration,
    /// Mean speedup factor attributed to compiled code (0.0 until something sets it).
    pub avg_speedup: f64,
}
105
/// Memory usage statistics recorded by the runtime.
#[derive(Clone, Debug, Default)]
pub struct MemoryMetrics {
    /// Largest estimated memory footprint observed, in bytes.
    pub peak_memory: usize,
    /// Current estimated memory footprint, in bytes.
    pub current_memory: usize,
    /// Memory efficiency ratio as computed by the engine.
    pub memory_efficiency: f64,
    /// Number of allocations counted.
    pub allocations: usize,
    /// Number of deallocations counted.
    pub deallocations: usize,
}
121
/// Statistics about parallel execution of graph operations.
#[derive(Clone, Debug, Default)]
pub struct ParallelExecutionMetrics {
    /// Number of opportunities for parallel execution identified.
    pub parallel_opportunities: usize,
    /// Number of executions that actually ran in parallel.
    pub parallel_executions: usize,
    /// Mean degree of parallelism.
    pub avg_parallelism: f64,
    /// Parallel efficiency ratio.
    pub parallel_efficiency: f64,
}
135
136impl<T: Float + Send + Sync + 'static + ndarray::ScalarOperand + num_traits::FromPrimitive>
137 RuntimeEngine<T>
138{
139 pub fn new(config: RuntimeConfig) -> Self {
141 RuntimeEngine {
142 context: DynamicExecutionContext::new(),
143 jit_compiler: JitCompiler::new(),
144 execution_cache: HashMap::new(),
145 config,
146 metrics: Arc::new(RwLock::new(RuntimeMetrics::default())),
147 }
148 }
149
150 pub fn execute_graph(
152 &mut self,
153 graph_builder: impl FnOnce(&mut GraphBuilder<T>) -> RusTorchResult<usize>,
154 ) -> RusTorchResult<Tensor<T>> {
155 let start_time = Instant::now();
156
157 let mut builder = GraphBuilder::new(&mut self.context);
159 let output_node_id = graph_builder(&mut builder)?;
160
161 let pattern_key = self.generate_pattern_key(output_node_id)?;
163
164 if self.execution_cache.contains_key(&pattern_key) {
165 if let Some(cached) = self.execution_cache.get_mut(&pattern_key) {
167 cached.hit_count += 1;
168 cached.last_used = Instant::now();
169 }
170
171 let mut metrics = self.metrics.write().unwrap();
173 metrics.cache_hit_rate = (metrics.cache_hit_rate * metrics.total_executions as f64
174 + 1.0)
175 / (metrics.total_executions as f64 + 1.0);
176 }
177
178 let execution_plan = self.context.create_execution_plan(output_node_id)?;
180
181 if self.config.enable_jit && execution_plan.operations.len() >= self.config.jit_threshold {
183 self.apply_jit_compilation(&execution_plan)?;
184 }
185
186 let result = self.context.execute(output_node_id)?;
188
189 let mut metrics = self.metrics.write().unwrap();
191 metrics.total_executions += 1;
192 metrics.memory_stats.allocations += 1;
193
194 let estimated_memory = result.data.len() * std::mem::size_of::<T>();
196 if estimated_memory > metrics.memory_stats.peak_memory {
197 metrics.memory_stats.peak_memory = estimated_memory;
198 }
199
200 metrics.memory_stats.memory_efficiency =
202 metrics.memory_stats.allocations as f64 / (metrics.total_executions as f64 + 1.0);
203
204 metrics.avg_execution_time = (metrics.avg_execution_time
205 * (metrics.total_executions - 1) as u32
206 + start_time.elapsed())
207 / metrics.total_executions as u32;
208
209 Ok(result)
210 }
211
212 fn generate_pattern_key(&self, output_node_id: usize) -> RusTorchResult<String> {
214 let mut pattern_parts = Vec::new();
216 self.collect_pattern_recursive(
217 output_node_id,
218 &mut pattern_parts,
219 &mut std::collections::HashSet::new(),
220 )?;
221 Ok(pattern_parts.join("->"))
222 }
223
224 fn collect_pattern_recursive(
226 &self,
227 node_id: usize,
228 pattern: &mut Vec<String>,
229 visited: &mut std::collections::HashSet<usize>,
230 ) -> RusTorchResult<()> {
231 if visited.contains(&node_id) {
232 return Ok(());
233 }
234 visited.insert(node_id);
235
236 if let Some(node) = self.context.get_dynamic_node(&node_id) {
237 pattern.push(format!("{:?}", node.op));
239
240 for input_node in &node.inputs {
242 self.collect_pattern_recursive(input_node.id, pattern, visited)?;
243 }
244 }
245
246 Ok(())
247 }
248
249 fn apply_jit_compilation(&mut self, plan: &ExecutionPlan<T>) -> RusTorchResult<()> {
251 let ops: Vec<DynamicOp> = plan.operations.iter().map(|op| op.op.clone()).collect();
253
254 if ops.len() >= self.config.jit_threshold {
255 let start_time = Instant::now();
256 let _compiled_fn = self.jit_compiler.compile_operations(&ops)?;
257
258 let mut metrics = self.metrics.write().unwrap();
260 metrics.jit_stats.total_compilations += 1;
261 metrics.jit_stats.avg_compilation_time = (metrics.jit_stats.avg_compilation_time
262 * (metrics.jit_stats.total_compilations - 1) as u32
263 + start_time.elapsed())
264 / metrics.jit_stats.total_compilations as u32;
265 }
266
267 Ok(())
268 }
269
270 pub fn get_metrics(&self) -> RuntimeMetrics {
272 self.metrics.read().unwrap().clone()
273 }
274
275 pub fn reset_metrics(&mut self) {
277 *self.metrics.write().unwrap() = RuntimeMetrics::default();
278 }
279
280 pub fn warmup(&mut self) -> RusTorchResult<()> {
282 let common_patterns = vec![
284 vec![
285 DynamicOp::Linear {
286 in_features: 784,
287 out_features: 128,
288 },
289 DynamicOp::ReLU,
290 ],
291 vec![
292 DynamicOp::Conv2d {
293 kernel_size: (3, 3),
294 stride: (1, 1),
295 padding: (1, 1),
296 },
297 DynamicOp::ReLU,
298 ],
299 vec![DynamicOp::Add, DynamicOp::ReLU],
300 vec![DynamicOp::MatMul, DynamicOp::Sigmoid],
301 ];
302
303 for pattern in common_patterns {
304 self.jit_compiler.compile_operations(&pattern)?;
305
306 let mut metrics = self.metrics.write().unwrap();
308 metrics.jit_stats.total_compilations += 1;
309 metrics.jit_stats.successful_compilations += 1;
310 }
311
312 Ok(())
313 }
314
315 pub fn cleanup_cache(&mut self) {
317 let now = Instant::now();
318 let max_age = std::time::Duration::from_secs(3600); self.execution_cache
321 .retain(|_, cached| now.duration_since(cached.last_used) < max_age);
322
323 if self.execution_cache.len() > self.config.max_cache_size {
325 let entries: Vec<_> = self
327 .execution_cache
328 .iter()
329 .map(|(k, v)| (k.clone(), v.last_used))
330 .collect();
331 let mut sorted_entries = entries;
332 sorted_entries.sort_by_key(|(_, last_used)| *last_used);
333
334 let to_remove = sorted_entries.len() - self.config.max_cache_size;
335 for (key, _) in sorted_entries.into_iter().take(to_remove) {
336 self.execution_cache.remove(&key);
337 }
338 }
339 }
340
341 pub fn profile_execution(&mut self, iterations: usize) -> RusTorchResult<ProfileResult> {
343 let mut profile_result = ProfileResult::new();
344
345 for i in 0..iterations {
346 let start_time = Instant::now();
347
348 let result = self.execute_graph(|builder| {
350 let input1 = builder.add_input(Tensor::ones(&[32, 784]))?;
351 let weight1 = builder.add_parameter(Tensor::ones(&[128, 784]))?;
352 let bias1 = builder.add_parameter(Tensor::ones(&[128]))?;
353
354 let linear1 = builder.add_operation(
355 DynamicOp::Linear {
356 in_features: 784,
357 out_features: 128,
358 },
359 vec![input1, weight1, bias1],
360 )?;
361
362 let relu1 = builder.add_operation(DynamicOp::ReLU, vec![linear1])?;
363
364 let weight2 = builder.add_parameter(Tensor::ones(&[10, 128]))?;
365 let bias2 = builder.add_parameter(Tensor::ones(&[10]))?;
366
367 let output = builder.add_operation(
368 DynamicOp::Linear {
369 in_features: 128,
370 out_features: 10,
371 },
372 vec![relu1, weight2, bias2],
373 )?;
374
375 Ok(output)
376 })?;
377
378 let execution_time = start_time.elapsed();
379 profile_result.add_execution_time(execution_time);
380
381 if i % 100 == 0 {
382 println!(
383 "Profile iteration {}/{}: {:?}",
384 i + 1,
385 iterations,
386 execution_time
387 );
388 }
389 }
390
391 profile_result.analyze_performance(&self.get_metrics());
393
394 Ok(profile_result)
395 }
396}
397
398pub struct GraphBuilder<'a, T: Float + Send + Sync + 'static> {
401 context: &'a mut DynamicExecutionContext<T>,
402}
403
404impl<'a, T: Float + Send + Sync + 'static + ndarray::ScalarOperand + num_traits::FromPrimitive>
405 GraphBuilder<'a, T>
406{
407 pub fn new(context: &'a mut DynamicExecutionContext<T>) -> Self {
409 GraphBuilder { context }
410 }
411
412 pub fn add_input(&mut self, tensor: Tensor<T>) -> RusTorchResult<usize> {
414 self.context.add_leaf(tensor)
415 }
416
417 pub fn add_parameter(&mut self, tensor: Tensor<T>) -> RusTorchResult<usize> {
419 self.context.add_leaf(tensor)
420 }
421
422 pub fn add_operation(&mut self, op: DynamicOp, inputs: Vec<usize>) -> RusTorchResult<usize> {
424 self.context.add_operation(op, inputs)
425 }
426
427 pub fn linear(
429 &mut self,
430 input: usize,
431 weight: usize,
432 bias: Option<usize>,
433 ) -> RusTorchResult<usize> {
434 let inputs = if let Some(b) = bias {
435 vec![input, weight, b]
436 } else {
437 vec![input, weight]
438 };
439
440 if let Some(weight_node) = self.context.get_dynamic_node(&weight) {
442 if let Some(weight_tensor) = weight_node.get_cached_output() {
443 let shape = weight_tensor.shape();
444 if shape.len() == 2 && shape[0] > 0 && shape[1] > 0 {
445 return self.add_operation(
446 DynamicOp::Linear {
447 in_features: shape[1],
448 out_features: shape[0],
449 },
450 inputs,
451 );
452 }
453 }
454 }
455
456 self.add_operation(
458 DynamicOp::Linear {
459 in_features: 784,
460 out_features: 128,
461 },
462 inputs,
463 )
464 }
465
466 pub fn conv2d(
468 &mut self,
469 input: usize,
470 weight: usize,
471 kernel_size: (usize, usize),
472 stride: (usize, usize),
473 padding: (usize, usize),
474 ) -> RusTorchResult<usize> {
475 self.add_operation(
476 DynamicOp::Conv2d {
477 kernel_size,
478 stride,
479 padding,
480 },
481 vec![input, weight],
482 )
483 }
484
485 pub fn relu(&mut self, input: usize) -> RusTorchResult<usize> {
487 self.add_operation(DynamicOp::ReLU, vec![input])
488 }
489
490 pub fn sigmoid(&mut self, input: usize) -> RusTorchResult<usize> {
492 self.add_operation(DynamicOp::Sigmoid, vec![input])
493 }
494
495 pub fn add(&mut self, input1: usize, input2: usize) -> RusTorchResult<usize> {
497 if let (Some(node1), Some(node2)) = (
499 self.context.get_dynamic_node(&input1),
500 self.context.get_dynamic_node(&input2),
501 ) {
502 if let (Some(tensor1), Some(tensor2)) =
503 (node1.get_cached_output(), node2.get_cached_output())
504 {
505 let shape1 = tensor1.shape();
506 let shape2 = tensor2.shape();
507
508 if shape1 != shape2 && !Self::can_broadcast(shape1, shape2) {
510 return Err(RusTorchError::shape_mismatch(shape1, shape2));
511 }
512 }
513 }
514
515 self.add_operation(DynamicOp::Add, vec![input1, input2])
516 }
517
518 fn can_broadcast(shape1: &[usize], shape2: &[usize]) -> bool {
520 let (s1, s2) = if shape1.len() > shape2.len() {
521 (shape1, shape2)
522 } else {
523 (shape2, shape1)
524 };
525
526 for (i, (&dim2, &dim1)) in s2.iter().rev().zip(s1.iter().rev()).enumerate() {
527 if dim2 != 1 && dim1 != 1 && dim2 != dim1 {
528 return false;
529 }
530 }
531 true
532 }
533
534 pub fn matmul(&mut self, input1: usize, input2: usize) -> RusTorchResult<usize> {
536 self.add_operation(DynamicOp::MatMul, vec![input1, input2])
537 }
538
539 pub fn reshape(&mut self, input: usize, shape: Vec<usize>) -> RusTorchResult<usize> {
541 self.add_operation(DynamicOp::Reshape { shape }, vec![input])
542 }
543}
544
545pub struct ProfileResult {
548 execution_times: Vec<std::time::Duration>,
550 recommendations: Vec<String>,
552 bottlenecks: Vec<BottleneckInfo>,
554}
555
/// A single performance bottleneck found during profiling.
#[derive(Debug)]
pub struct BottleneckInfo {
    /// Description of the operation or behavior involved.
    pub operation: String,
    /// Severity, expressed as a percentage.
    pub time_percentage: f64,
    /// Suggested remedy.
    pub recommendation: String,
}
567
568impl ProfileResult {
569 pub fn new() -> Self {
571 ProfileResult {
572 execution_times: Vec::new(),
573 recommendations: Vec::new(),
574 bottlenecks: Vec::new(),
575 }
576 }
577
578 pub fn add_execution_time(&mut self, time: std::time::Duration) {
580 self.execution_times.push(time);
581 }
582
583 pub fn analyze_performance(&mut self, metrics: &RuntimeMetrics) {
585 let avg_time = if !self.execution_times.is_empty() {
587 self.execution_times.iter().sum::<std::time::Duration>()
588 / self.execution_times.len() as u32
589 } else {
590 std::time::Duration::default()
591 };
592 let min_time = self
593 .execution_times
594 .iter()
595 .min()
596 .copied()
597 .unwrap_or_default();
598 let max_time = self
599 .execution_times
600 .iter()
601 .max()
602 .copied()
603 .unwrap_or_default();
604
605 if metrics.cache_hit_rate < 0.5 {
607 self.recommendations.push(
608 "Consider increasing cache size or improving cache key generation".to_string(),
609 );
610 }
611
612 if metrics.jit_stats.avg_speedup < 2.0 {
613 self.recommendations
614 .push("JIT compilation showing limited benefit, consider disabling".to_string());
615 }
616
617 if metrics.memory_stats.memory_efficiency < 0.7 {
618 self.recommendations
619 .push("Memory efficiency low, consider memory pooling optimization".to_string());
620 }
621
622 if metrics.parallel_stats.parallel_efficiency < 0.6 {
623 self.recommendations.push(
624 "Parallel execution efficiency low, review operation dependencies".to_string(),
625 );
626 }
627
628 if max_time > avg_time * 2 {
630 self.bottlenecks.push(BottleneckInfo {
631 operation: "Variable execution time".to_string(),
632 time_percentage: ((max_time.as_nanos() - min_time.as_nanos()) as f64
633 / max_time.as_nanos() as f64)
634 * 100.0,
635 recommendation: "Investigate inconsistent operation performance".to_string(),
636 });
637 }
638 }
639
640 pub fn summary(&self) -> String {
642 let avg_time = if !self.execution_times.is_empty() {
643 self.execution_times.iter().sum::<std::time::Duration>()
644 / self.execution_times.len() as u32
645 } else {
646 std::time::Duration::default()
647 };
648
649 format!(
650 "Performance Profile Summary:\n\
651 - Executions: {}\n\
652 - Average time: {:?}\n\
653 - Recommendations: {}\n\
654 - Bottlenecks: {}",
655 self.execution_times.len(),
656 avg_time,
657 self.recommendations.len(),
658 self.bottlenecks.len()
659 )
660 }
661}
662
#[cfg(test)]
mod tests {
    use super::*;

    /// Constructing an engine with the default configuration must not panic.
    #[test]
    fn test_runtime_engine_creation() {
        let _engine: RuntimeEngine<f32> = RuntimeEngine::new(RuntimeConfig::default());
    }

    /// A single linear layer built via `GraphBuilder::linear` executes successfully.
    #[test]
    fn test_graph_builder() {
        let mut engine = RuntimeEngine::<f32>::new(RuntimeConfig::default());

        let outcome = engine.execute_graph(|builder| {
            let input = builder.add_input(Tensor::ones(&[2, 3]))?;
            let weight = builder.add_parameter(Tensor::ones(&[4, 3]))?;
            builder.linear(input, weight, None)
        });

        if let Err(e) = outcome {
            panic!("Test failed with error: {:?}", e);
        }
    }

    /// Warmup compiles at least one pattern through the engine's JIT compiler.
    #[test]
    fn test_warmup() {
        let mut engine = RuntimeEngine::<f32>::new(RuntimeConfig::default());

        engine.warmup().unwrap();

        assert!(engine.jit_compiler.get_stats().compilations > 0);
    }

    /// After cleanup, the cache never holds more than `max_cache_size` entries.
    #[test]
    fn test_cache_cleanup() {
        let config = RuntimeConfig {
            max_cache_size: 2,
            ..RuntimeConfig::default()
        };
        let mut engine = RuntimeEngine::<f32>::new(config);

        for i in 0..5 {
            let _result = engine
                .execute_graph(|builder| {
                    let input = builder.add_input(Tensor::ones(&[i + 1, 3]))?;
                    builder.relu(input)
                })
                .unwrap();
        }

        engine.cleanup_cache();
        assert!(engine.execution_cache.len() <= 2);
    }

    /// The profiling summary reports the requested number of executions.
    #[test]
    fn test_profiling() {
        let mut engine = RuntimeEngine::<f32>::new(RuntimeConfig::default());

        let summary = engine.profile_execution(3).unwrap().summary();

        assert!(summary.contains("Executions: 3"));
    }
}