1use crate::autograd::function::Function;
5use crate::autograd::graph::{ComputationGraph, GraphNode};
6use crate::error::{RusTorchError, RusTorchResult};
7use crate::tensor::Tensor;
8use num_traits::Float;
9use std::collections::{HashMap, VecDeque};
10use std::sync::{Arc, RwLock, Weak};
11use std::time::Instant;
12
/// Kinds of operation a [`DynamicNode`] can carry.
///
/// Only `Clone`, `Debug` and `PartialEq` are derived; the `Dropout` variant
/// holds an `f64`, so `Eq`/`Hash` cannot be derived for this enum.
#[derive(Clone, Debug, PartialEq)]
pub enum DynamicOp {
    /// Matrix multiplication.
    MatMul,
    /// Element-wise addition.
    Add,
    /// Element-wise multiplication.
    Mul,
    /// Rectified linear unit activation.
    ReLU,
    /// Logistic sigmoid activation.
    Sigmoid,
    /// 2-D convolution described by its kernel size, stride and padding.
    Conv2d {
        kernel_size: (usize, usize),
        stride: (usize, usize),
        padding: (usize, usize),
    },
    /// Fully connected layer described by its input/output feature counts.
    Linear { in_features: usize, out_features: usize },
    /// Batch normalisation over `num_features` features.
    BatchNorm { num_features: usize },
    /// Dropout parameterised by `p`.
    Dropout { p: f64 },
    /// Reshape to the given target shape.
    Reshape { shape: Vec<usize> },
    /// Free-form operation identified by name.
    Custom(String),
}
47
48pub struct DynamicNode<T: Float + Send + Sync + 'static> {
51 pub op: DynamicOp,
53 pub inputs: Vec<Arc<DynamicNode<T>>>,
55 pub cached_output: RwLock<Option<Tensor<T>>>,
57 pub dirty: RwLock<bool>,
59 pub id: usize,
61 pub execution_time: RwLock<Option<std::time::Duration>>,
63 pub memory_usage: RwLock<Option<usize>>,
65}
66
67impl<T: Float + Send + Sync + 'static> DynamicNode<T> {
68 pub fn new(op: DynamicOp, inputs: Vec<Arc<DynamicNode<T>>>, id: usize) -> Arc<Self> {
70 Arc::new(DynamicNode {
71 op,
72 inputs,
73 cached_output: RwLock::new(None),
74 dirty: RwLock::new(true),
75 id,
76 execution_time: RwLock::new(None),
77 memory_usage: RwLock::new(None),
78 })
79 }
80
81 pub fn mark_dirty(&self) {
83 *self.dirty.write().unwrap() = true;
84 *self.cached_output.write().unwrap() = None;
85 }
86
87 pub fn is_dirty(&self) -> bool {
89 *self.dirty.read().unwrap()
90 }
91
92 pub fn get_cached_output(&self) -> Option<Tensor<T>> {
94 self.cached_output.read().unwrap().clone()
95 }
96
97 pub fn set_cached_output(&self, output: Tensor<T>) {
99 *self.cached_output.write().unwrap() = Some(output);
100 *self.dirty.write().unwrap() = false;
101 }
102}
103
104pub struct DynamicExecutionContext<T: Float + Send + Sync + 'static> {
107 graph: Arc<RwLock<ComputationGraph<T>>>,
109 dynamic_nodes: HashMap<usize, Arc<DynamicNode<T>>>,
111 execution_order: RwLock<Option<Vec<usize>>>,
113 compiled_ops: HashMap<Vec<DynamicOp>, Arc<dyn Function<T>>>,
115 next_node_id: usize,
117 stats: DynamicExecutionStats,
119}
120
/// Counters describing the work done by a dynamic execution context.
///
/// All fields start at their zero value via `Default`.
#[derive(Default, Debug)]
pub struct DynamicExecutionStats {
    /// Number of node executions performed.
    pub total_ops: usize,
    /// Running estimate of the cache hit rate.
    pub cache_hit_rate: f64,
    /// Accumulated wall-clock time spent executing.
    pub total_execution_time: std::time::Duration,
    /// Number of memory allocations recorded.
    pub memory_allocations: usize,
    /// Number of JIT compilations recorded.
    pub jit_compilations: usize,
}
136
137impl<T: Float + Send + Sync + 'static + ndarray::ScalarOperand + num_traits::FromPrimitive>
138 DynamicExecutionContext<T>
139{
140 pub fn new() -> Self {
142 DynamicExecutionContext {
143 graph: Arc::new(RwLock::new(ComputationGraph::new())),
144 dynamic_nodes: HashMap::new(),
145 execution_order: RwLock::new(None),
146 compiled_ops: HashMap::new(),
147 next_node_id: 0,
148 stats: DynamicExecutionStats::default(),
149 }
150 }
151
152 pub fn add_operation(&mut self, op: DynamicOp, input_ids: Vec<usize>) -> RusTorchResult<usize> {
154 let node_id = self.next_node_id;
155 self.next_node_id += 1;
156
157 let input_nodes: Vec<Arc<DynamicNode<T>>> = input_ids
159 .iter()
160 .filter_map(|&id| self.dynamic_nodes.get(&id).cloned())
161 .collect();
162
163 if input_nodes.len() != input_ids.len() {
164 return Err(RusTorchError::tensor_op("Some input nodes not found"));
165 }
166
167 let dynamic_node = DynamicNode::new(op, input_nodes, node_id);
169 self.dynamic_nodes.insert(node_id, dynamic_node);
170
171 *self.execution_order.write().unwrap() = None;
173
174 Ok(node_id)
175 }
176
177 pub fn add_leaf(&mut self, tensor: Tensor<T>) -> RusTorchResult<usize> {
179 let node_id = self.next_node_id;
180 self.next_node_id += 1;
181
182 let dynamic_node = DynamicNode::new(DynamicOp::Custom("leaf".to_string()), vec![], node_id);
184 dynamic_node.set_cached_output(tensor);
185
186 self.dynamic_nodes.insert(node_id, dynamic_node);
187
188 Ok(node_id)
189 }
190
191 pub fn get_dynamic_node(&self, id: &usize) -> Option<&Arc<DynamicNode<T>>> {
193 self.dynamic_nodes.get(id)
194 }
195
196 pub fn execute(&mut self, output_node_id: usize) -> RusTorchResult<Tensor<T>> {
198 let start_time = Instant::now();
199
200 self.build_execution_order(output_node_id)?;
202
203 let execution_order = self
204 .execution_order
205 .read()
206 .unwrap()
207 .clone()
208 .ok_or_else(|| RusTorchError::tensor_op("Failed to build execution order"))?;
209
210 for &node_id in &execution_order {
212 if let Some(node) = self.dynamic_nodes.get(&node_id).cloned() {
213 if node.is_dirty() || node.get_cached_output().is_none() {
214 let output = self.execute_node(&node)?;
215 node.set_cached_output(output);
216 self.stats.total_ops += 1;
217 } else {
218 self.stats.cache_hit_rate =
220 (self.stats.cache_hit_rate * (self.stats.total_ops as f64) + 1.0)
221 / (self.stats.total_ops as f64 + 1.0);
222 }
223 }
224 }
225
226 self.stats.total_execution_time += start_time.elapsed();
228
229 if let Some(output_node) = self.dynamic_nodes.get(&output_node_id) {
231 output_node
232 .get_cached_output()
233 .ok_or_else(|| RusTorchError::tensor_op("Output node has no result"))
234 } else {
235 Err(RusTorchError::tensor_op("Output node not found"))
236 }
237 }
238
239 pub fn execute_node(&self, node: &DynamicNode<T>) -> RusTorchResult<Tensor<T>> {
241 let start_time = Instant::now();
242
243 let mut input_tensors = Vec::new();
245 for input_node in &node.inputs {
246 if let Some(tensor) = input_node.get_cached_output() {
247 input_tensors.push(tensor);
248 } else {
249 return Err(RusTorchError::tensor_op(format!(
250 "Input node {} has no cached output",
251 input_node.id
252 )));
253 }
254 }
255
256 let output = match &node.op {
258 DynamicOp::Add => {
259 if input_tensors.len() != 2 {
260 return Err(RusTorchError::tensor_op("Add requires 2 inputs"));
261 }
262 &input_tensors[0] + &input_tensors[1]
263 }
264 DynamicOp::Mul => {
265 if input_tensors.len() != 2 {
266 return Err(RusTorchError::tensor_op("Mul requires 2 inputs"));
267 }
268 &input_tensors[0] * &input_tensors[1]
269 }
270 DynamicOp::MatMul => {
271 if input_tensors.len() != 2 {
272 return Err(RusTorchError::tensor_op("MatMul requires 2 inputs"));
273 }
274 input_tensors[0].matmul(&input_tensors[1])?
275 }
276 DynamicOp::ReLU => {
277 if input_tensors.len() != 1 {
278 return Err(RusTorchError::tensor_op("ReLU requires 1 input"));
279 }
280 let input_data = &input_tensors[0].data;
282 let relu_data: Vec<T> = input_data
283 .iter()
284 .map(|&x| if x > T::zero() { x } else { T::zero() })
285 .collect();
286 Tensor::from_vec(relu_data, input_tensors[0].shape().to_vec())
287 }
288 DynamicOp::Sigmoid => {
289 if input_tensors.len() != 1 {
290 return Err(RusTorchError::tensor_op("Sigmoid requires 1 input"));
291 }
292 let input_data = &input_tensors[0].data;
294 let sigmoid_data: Vec<T> = input_data
295 .iter()
296 .map(|&x| T::one() / (T::one() + (-x).exp()))
297 .collect();
298 Tensor::from_vec(sigmoid_data, input_tensors[0].shape().to_vec())
299 }
300 DynamicOp::Reshape { shape } => {
301 if input_tensors.len() != 1 {
302 return Err(RusTorchError::tensor_op("Reshape requires 1 input"));
303 }
304 input_tensors[0].reshape(shape)?
305 }
306 DynamicOp::Linear {
307 in_features: _,
308 out_features: _,
309 } => {
310 if input_tensors.len() < 2 || input_tensors.len() > 3 {
311 return Err(RusTorchError::tensor_op(
312 "Linear requires 2-3 inputs (input, weight, [bias])",
313 ));
314 }
315 self.execute_linear(&input_tensors)?
316 }
317 DynamicOp::Conv2d {
318 kernel_size: _,
319 stride: _,
320 padding: _,
321 } => {
322 if input_tensors.len() != 2 {
323 return Err(RusTorchError::tensor_op(
324 "Conv2d requires 2 inputs (input, weight)",
325 ));
326 }
327 self.execute_conv2d(&input_tensors)?
328 }
329 _ => {
330 return Err(RusTorchError::tensor_op(format!(
331 "Operation {:?} not implemented yet",
332 node.op
333 )));
334 }
335 };
336
337 let execution_time = start_time.elapsed();
339 *node.execution_time.write().unwrap() = Some(execution_time);
340
341 let memory_usage = output.data.len() * std::mem::size_of::<T>();
343 *node.memory_usage.write().unwrap() = Some(memory_usage);
344
345 Ok(output)
346 }
347
348 fn execute_linear(&self, inputs: &[Tensor<T>]) -> RusTorchResult<Tensor<T>> {
350 let input = &inputs[0];
351 let weight = &inputs[1];
352 let bias = inputs.get(2);
353
354 let mut output = input.matmul(&weight.transpose()?)?;
356
357 if let Some(bias_tensor) = bias {
359 output = &output + bias_tensor;
360 }
361
362 Ok(output)
363 }
364
365 fn execute_conv2d(&self, inputs: &[Tensor<T>]) -> RusTorchResult<Tensor<T>> {
367 let input = &inputs[0];
368 let weight = &inputs[1];
369
370 let input_shape = input.shape();
373 let weight_shape = weight.shape();
374
375 let batch_size = input_shape[0];
378 let in_channels = input_shape[1];
379 let out_channels = weight_shape[0];
380
381 let output_data =
383 vec![T::one(); batch_size * out_channels * input_shape[2] * input_shape[3]];
384 let output = Tensor::from_vec(
385 output_data,
386 vec![batch_size, out_channels, input_shape[2], input_shape[3]],
387 );
388
389 Ok(output)
390 }
391
392 fn build_execution_order(&mut self, output_node_id: usize) -> RusTorchResult<()> {
394 let mut visited = std::collections::HashSet::new();
395 let mut temp_visited = std::collections::HashSet::new();
396 let mut order = Vec::new();
397
398 self.topological_sort(output_node_id, &mut visited, &mut temp_visited, &mut order)?;
399
400 *self.execution_order.write().unwrap() = Some(order);
401 Ok(())
402 }
403
404 fn topological_sort(
406 &self,
407 node_id: usize,
408 visited: &mut std::collections::HashSet<usize>,
409 temp_visited: &mut std::collections::HashSet<usize>,
410 order: &mut Vec<usize>,
411 ) -> RusTorchResult<()> {
412 if temp_visited.contains(&node_id) {
413 return Err(RusTorchError::tensor_op("Circular dependency detected"));
414 }
415
416 if visited.contains(&node_id) {
417 return Ok(());
418 }
419
420 temp_visited.insert(node_id);
421
422 if let Some(node) = self.dynamic_nodes.get(&node_id) {
423 for input_node in &node.inputs {
424 self.topological_sort(input_node.id, visited, temp_visited, order)?;
425 }
426 }
427
428 temp_visited.remove(&node_id);
429 visited.insert(node_id);
430 order.push(node_id);
431
432 Ok(())
433 }
434
435 pub fn get_stats(&self) -> &DynamicExecutionStats {
437 &self.stats
438 }
439
440 pub fn clear_cache(&mut self) {
442 for node in self.dynamic_nodes.values() {
443 node.mark_dirty();
444 }
445 *self.execution_order.write().unwrap() = None;
446 }
447
448 pub fn create_execution_plan(&self, output_node_id: usize) -> RusTorchResult<ExecutionPlan<T>> {
450 let mut plan = ExecutionPlan::new();
451
452 let mut visited = std::collections::HashSet::new();
454 self.build_execution_plan_recursive(output_node_id, &mut visited, &mut plan)?;
455
456 plan.optimize_memory_usage();
458 plan.optimize_execution_order();
459
460 Ok(plan)
461 }
462
463 fn build_execution_plan_recursive(
465 &self,
466 node_id: usize,
467 visited: &mut std::collections::HashSet<usize>,
468 plan: &mut ExecutionPlan<T>,
469 ) -> RusTorchResult<()> {
470 if visited.contains(&node_id) {
471 return Ok(());
472 }
473
474 if let Some(node) = self.dynamic_nodes.get(&node_id) {
475 for input_node in &node.inputs {
477 self.build_execution_plan_recursive(input_node.id, visited, plan)?;
478 }
479
480 plan.add_operation(
482 node_id,
483 node.op.clone(),
484 node.inputs.iter().map(|n| n.id).collect(),
485 );
486 visited.insert(node_id);
487 }
488
489 Ok(())
490 }
491}
492
493#[derive(Clone)]
496pub struct ExecutionPlan<T: Float + Send + Sync + 'static> {
497 pub operations: Vec<PlannedOperation>,
499 pub memory_plan: MemoryPlan,
501 pub parallel_groups: Vec<Vec<usize>>,
503 _phantom: std::marker::PhantomData<T>,
504}
505
506#[derive(Debug, Clone)]
509pub struct PlannedOperation {
510 pub node_id: usize,
512 pub op: DynamicOp,
514 pub input_ids: Vec<usize>,
516 pub estimated_time: Option<std::time::Duration>,
518 pub memory_requirement: usize,
520 pub parallel_safe: bool,
522}
523
524#[derive(Debug, Default, Clone)]
527pub struct MemoryPlan {
528 pub peak_memory: usize,
530 pub allocations: Vec<MemoryAllocation>,
532 pub reuse_map: HashMap<usize, usize>,
534}
535
/// Allocation record for the output of one planned operation.
#[derive(Clone, Debug)]
pub struct MemoryAllocation {
    /// Node id of the operation that owns this allocation.
    pub operation_id: usize,
    /// Size of the allocation.
    pub size: usize,
    /// Index of the last planned operation that still needs this allocation.
    pub lifetime_end: usize,
    /// Allocation this one reuses, if any.
    pub reuse_from: Option<usize>,
}
549
550impl<T: Float + Send + Sync + 'static> ExecutionPlan<T> {
551 pub fn new() -> Self {
553 ExecutionPlan {
554 operations: Vec::new(),
555 memory_plan: MemoryPlan::default(),
556 parallel_groups: Vec::new(),
557 _phantom: std::marker::PhantomData,
558 }
559 }
560
561 pub fn add_operation(&mut self, node_id: usize, op: DynamicOp, input_ids: Vec<usize>) {
563 let planned_op = PlannedOperation {
564 node_id,
565 op,
566 input_ids,
567 estimated_time: None,
568 memory_requirement: 0,
569 parallel_safe: false,
570 };
571 self.operations.push(planned_op);
572 }
573
574 pub fn optimize_memory_usage(&mut self) {
576 let mut last_use = HashMap::new();
578
579 for (op_idx, op) in self.operations.iter().enumerate() {
580 for &input_id in &op.input_ids {
581 last_use.insert(input_id, op_idx);
582 }
583 }
584
585 for (op_idx, op) in self.operations.iter().enumerate() {
587 let allocation = MemoryAllocation {
588 operation_id: op.node_id,
589 size: op.memory_requirement,
590 lifetime_end: last_use.get(&op.node_id).copied().unwrap_or(op_idx),
591 reuse_from: None,
592 };
593 self.memory_plan.allocations.push(allocation);
594 }
595 }
596
597 pub fn optimize_execution_order(&mut self) {
599 let mut current_group = Vec::new();
601
602 for (idx, op) in self.operations.iter().enumerate() {
603 let has_dependency = current_group.iter().any(|&group_idx: &usize| {
605 op.input_ids.contains(&self.operations[group_idx].node_id)
606 });
607
608 if has_dependency {
609 if !current_group.is_empty() {
611 self.parallel_groups.push(current_group.clone());
612 current_group.clear();
613 }
614 }
615
616 current_group.push(idx);
617 }
618
619 if !current_group.is_empty() {
620 self.parallel_groups.push(current_group);
621 }
622 }
623
624 pub fn estimated_execution_time(&self) -> std::time::Duration {
626 let mut total_time = std::time::Duration::default();
627
628 for group in &self.parallel_groups {
629 let group_time = group
631 .iter()
632 .filter_map(|&idx| self.operations[idx].estimated_time)
633 .max()
634 .unwrap_or_default();
635 total_time += group_time;
636 }
637
638 total_time
639 }
640
641 pub fn peak_memory_usage(&self) -> usize {
643 self.memory_plan.peak_memory
644 }
645}
646
647pub struct JitCompiler<T: Float + Send + Sync + 'static> {
650 compiled_cache: HashMap<String, Arc<dyn Function<T>>>,
652 compilation_stats: JitStats,
654}
655
/// Counters describing JIT compilation activity.
///
/// All fields start at their zero value via `Default`.
#[derive(Default, Debug)]
pub struct JitStats {
    /// Number of compilations performed.
    pub compilations: usize,
    /// Number of requests answered from the cache.
    pub cache_hits: usize,
    /// Accumulated wall-clock time spent compiling.
    pub compilation_time: std::time::Duration,
    /// Average speed-up figure.
    pub average_speedup: f64,
}
669
670impl<T: Float + Send + Sync + 'static + ndarray::ScalarOperand + num_traits::FromPrimitive>
671 JitCompiler<T>
672{
673 pub fn new() -> Self {
675 JitCompiler {
676 compiled_cache: HashMap::new(),
677 compilation_stats: JitStats::default(),
678 }
679 }
680
681 pub fn compile_operations(
683 &mut self,
684 ops: &[DynamicOp],
685 ) -> RusTorchResult<Arc<dyn Function<T>>> {
686 let ops_key = format!("{:?}", ops);
687
688 if let Some(cached) = self.compiled_cache.get(&ops_key) {
689 self.compilation_stats.cache_hits += 1;
690 return Ok(cached.clone());
691 }
692
693 let start_time = Instant::now();
694
695 let fused_op = self.create_fused_operation(ops)?;
697
698 self.compilation_stats.compilations += 1;
699 self.compilation_stats.compilation_time += start_time.elapsed();
700
701 let fused_fn = Arc::new(fused_op);
702 self.compiled_cache.insert(ops_key, fused_fn.clone());
703
704 Ok(fused_fn)
705 }
706
707 fn create_fused_operation(&self, ops: &[DynamicOp]) -> RusTorchResult<FusedOperation<T>> {
709 Ok(FusedOperation::new(ops.to_vec()))
710 }
711
712 pub fn get_stats(&self) -> &JitStats {
714 &self.compilation_stats
715 }
716}
717
718pub struct FusedOperation<T: Float + Send + Sync + 'static> {
721 operations: Vec<DynamicOp>,
722 _phantom: std::marker::PhantomData<T>,
723}
724
725impl<T: Float + Send + Sync + 'static> FusedOperation<T> {
726 pub fn new(operations: Vec<DynamicOp>) -> Self {
728 FusedOperation {
729 operations,
730 _phantom: std::marker::PhantomData,
731 }
732 }
733}
734
735impl<T: Float + Send + Sync + 'static + ndarray::ScalarOperand + num_traits::FromPrimitive>
736 Function<T> for FusedOperation<T>
737{
738 fn forward(&self, inputs: &[&Tensor<T>]) -> Tensor<T> {
739 if inputs.is_empty() {
742 Tensor::zeros(&[1])
743 } else {
744 inputs[0].clone()
745 }
746 }
747
748 fn backward(&self, grad_output: &Tensor<T>, inputs: &[&Tensor<T>]) -> Vec<Option<Tensor<T>>> {
749 vec![Some(grad_output.clone()); inputs.len()]
752 }
753}
754
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `tensor` holds exactly `expected.len()` elements and that
    /// each is within 1e-6 of the matching expected value.
    ///
    /// Reads the elements through `data.iter()` so the check can never be
    /// skipped silently, unlike a conditional `as_slice()` lookup.
    fn assert_all_close(tensor: &Tensor<f32>, expected: &[f32]) {
        let actual: Vec<f32> = tensor.data.iter().copied().collect();
        assert_eq!(actual.len(), expected.len(), "element count mismatch");
        for (got, want) in actual.iter().zip(expected) {
            assert!((got - want).abs() < 1e-6, "got {}, want {}", got, want);
        }
    }

    #[test]
    fn test_dynamic_execution_context_creation() {
        let mut ctx = DynamicExecutionContext::<f32>::new();

        let input1 = Tensor::zeros(&[2, 3]);
        let input2 = Tensor::ones(&[2, 3]);

        let leaf1_id = ctx.add_leaf(input1).unwrap();
        let leaf2_id = ctx.add_leaf(input2).unwrap();

        let add_id = ctx
            .add_operation(DynamicOp::Add, vec![leaf1_id, leaf2_id])
            .unwrap();

        let result = ctx.execute(add_id).unwrap();
        assert_eq!(result.shape(), &[2, 3]);
    }

    #[test]
    fn test_execution_plan() {
        let mut plan = ExecutionPlan::<f32>::new();
        plan.add_operation(0, DynamicOp::Add, vec![]);
        plan.add_operation(1, DynamicOp::ReLU, vec![0]);

        plan.optimize_execution_order();
        assert!(!plan.parallel_groups.is_empty());
        // Operation 1 reads the output of operation 0, so they cannot share
        // a group.
        assert_eq!(plan.parallel_groups, vec![vec![0], vec![1]]);
    }

    #[test]
    fn test_jit_compiler() {
        let mut compiler = JitCompiler::<f32>::new();

        let ops = vec![DynamicOp::Add, DynamicOp::ReLU];
        let compiled = compiler.compile_operations(&ops).unwrap();

        // The second request for the same sequence must come from the cache.
        let compiled2 = compiler.compile_operations(&ops).unwrap();
        assert_eq!(compiler.get_stats().compilations, 1);
        assert_eq!(compiler.get_stats().cache_hits, 1);

        // Compare data addresses only, so vtable identity plays no part.
        let first = std::sync::Arc::as_ptr(&compiled) as *const ();
        let second = std::sync::Arc::as_ptr(&compiled2) as *const ();
        assert_eq!(first, second);
    }

    #[test]
    fn test_relu_operation() {
        let mut ctx = DynamicExecutionContext::<f32>::new();

        let input_data = vec![-1.0, 0.0, 1.0, 2.0];
        let input = Tensor::from_vec(input_data, vec![4]);
        let leaf_id = ctx.add_leaf(input).unwrap();
        let relu_id = ctx.add_operation(DynamicOp::ReLU, vec![leaf_id]).unwrap();

        let result = ctx.execute(relu_id).unwrap();

        assert_all_close(&result, &[0.0, 0.0, 1.0, 2.0]);
    }

    #[test]
    fn test_sigmoid_operation() {
        let mut ctx = DynamicExecutionContext::<f32>::new();

        let input = Tensor::from_vec(vec![0.0], vec![1]);
        let leaf_id = ctx.add_leaf(input).unwrap();
        let sigmoid_id = ctx
            .add_operation(DynamicOp::Sigmoid, vec![leaf_id])
            .unwrap();

        let result = ctx.execute(sigmoid_id).unwrap();

        // sigmoid(0) = 0.5
        assert_all_close(&result, &[0.5]);
    }

    #[test]
    fn test_linear_operation() {
        let mut ctx = DynamicExecutionContext::<f32>::new();

        let input = Tensor::ones(&[2, 3]);
        let weight = Tensor::ones(&[4, 3]);
        let bias = Tensor::zeros(&[4]);

        let input_id = ctx.add_leaf(input).unwrap();
        let weight_id = ctx.add_leaf(weight).unwrap();
        let bias_id = ctx.add_leaf(bias).unwrap();

        let linear_id = ctx
            .add_operation(
                DynamicOp::Linear {
                    in_features: 3,
                    out_features: 4,
                },
                vec![input_id, weight_id, bias_id],
            )
            .unwrap();

        let result = ctx.execute(linear_id).unwrap();
        assert_eq!(result.shape(), &[2, 4]);

        // Each output is a dot product of three ones, plus a zero bias.
        assert_all_close(&result, &[3.0; 8]);
    }
}