1use super::core::{
7 DebugStatistics, DebugValue, DebuggerConfig, ExecutionState, InstructionExecutionResult,
8 NodeExecutionResult,
9};
10use crate::{
11 graph::Node,
12 ir::{Instruction, IrModule, IrOpcode},
13 ComputationGraph, JitError, JitResult, NodeId,
14};
15use std::collections::HashMap;
16use std::time::{Duration, Instant};
17use torsh_core::{DType, Shape};
18
/// Executes graph nodes and IR instructions under the debugger, timing
/// every operation and accumulating per-operation statistics.
pub struct DebugExecutionEngine {
    // Debugger configuration supplied at construction; exposed via `config()`.
    config: DebuggerConfig,
    // Total number of node/instruction executions performed so far.
    execution_count: usize,
    // Wall-clock time accumulated across all executions.
    total_execution_time: Duration,
    // Raw timing samples keyed by operation/opcode name.
    instruction_timings: HashMap<String, Vec<Duration>>,
    // Aggregated statistics derived from `instruction_timings`.
    operation_stats: HashMap<String, OperationStatistics>,
}
30
/// Aggregate timing statistics for a single operation name.
#[derive(Debug, Clone)]
pub struct OperationStatistics {
    /// Number of timing samples recorded.
    pub count: usize,
    /// Sum of all recorded durations.
    pub total_time: Duration,
    /// Mean duration (`total_time / count`).
    pub average_time: Duration,
    /// Fastest recorded duration.
    pub min_time: Duration,
    /// Slowest recorded duration.
    pub max_time: Duration,
}
40
41impl DebugExecutionEngine {
42 pub fn new(config: DebuggerConfig) -> Self {
47 Self {
48 config,
49 execution_count: 0,
50 total_execution_time: Duration::new(0, 0),
51 instruction_timings: HashMap::new(),
52 operation_stats: HashMap::new(),
53 }
54 }
55
    /// Executes a single graph node under the debugger, recording its
    /// wall-clock time into the per-operation statistics keyed by the
    /// node's op name.
    ///
    /// # Errors
    /// Propagates any `JitError` from the underlying operation.
    /// NOTE(review): failed executions are not timed or counted — the `?`
    /// returns before the bookkeeping below runs.
    pub fn execute_node_debug(
        &mut self,
        node: &Node,
        graph: &ComputationGraph,
        node_id: NodeId,
    ) -> JitResult<NodeExecutionResult> {
        let start_time = Instant::now();

        let result = self.execute_node_operation(node, graph, node_id)?;

        let execution_time = start_time.elapsed();
        self.record_operation_timing(node.op.as_str(), execution_time);
        self.execution_count += 1;
        self.total_execution_time += execution_time;

        Ok(result)
    }
83
    /// Executes a single IR instruction under the debugger, recording its
    /// wall-clock time keyed by the `Debug`-formatted opcode name.
    ///
    /// # Errors
    /// Propagates any `JitError` from the underlying instruction.
    /// NOTE(review): failed executions are not timed or counted — the `?`
    /// returns before the bookkeeping below runs.
    pub fn execute_instruction_debug(
        &mut self,
        instruction: &Instruction,
        ir_module: &IrModule,
        execution_state: &mut ExecutionState,
    ) -> JitResult<InstructionExecutionResult> {
        let start_time = Instant::now();

        let result = self.execute_ir_instruction(instruction, ir_module, execution_state)?;

        let execution_time = start_time.elapsed();
        // e.g. `IrOpcode::Add` is recorded under the key "Add".
        let instruction_name = format!("{:?}", instruction.opcode);
        self.record_operation_timing(&instruction_name, execution_time);
        self.execution_count += 1;
        self.total_execution_time += execution_time;

        Ok(result)
    }
112
    /// Dispatches a node to the concrete operation implementation selected
    /// by its op name.
    ///
    /// Unknown op names do not error: they yield a scalar-zero placeholder
    /// result (presumably intentional leniency so a debug session can
    /// continue past unimplemented ops — TODO confirm this is preferred
    /// over a `RuntimeError`).
    fn execute_node_operation(
        &self,
        node: &Node,
        graph: &ComputationGraph,
        node_id: NodeId,
    ) -> JitResult<NodeExecutionResult> {
        let inputs = self.get_node_inputs(graph, node_id)?;

        match node.op.as_str() {
            "add" => self.execute_add_operation(&inputs),
            "mul" => self.execute_multiply_operation(&inputs),
            "sub" => self.execute_subtract_operation(&inputs),
            "div" => self.execute_divide_operation(&inputs),
            "relu" => self.execute_relu_operation(&inputs),
            "sigmoid" => self.execute_sigmoid_operation(&inputs),
            "tanh" => self.execute_tanh_operation(&inputs),
            "matmul" => self.execute_matmul_operation(&inputs),
            "reshape" => self.execute_reshape_operation(&inputs, &node.attrs),
            "transpose" => self.execute_transpose_operation(&inputs),
            "concat" => self.execute_concat_operation(&inputs, &node.attrs),
            "split" => self.execute_split_operation(&inputs, &node.attrs),
            _ => {
                // Fallback placeholder for unimplemented operations.
                Ok(NodeExecutionResult {
                    data: vec![0.0],
                    shape: Shape::new(vec![1]),
                    dtype: DType::F32,
                })
            }
        }
    }
147
    /// Dispatches an IR instruction to its handler and classifies the
    /// outcome as a value, side effect, return, or no-op.
    ///
    /// Unhandled opcodes fall through to `NoOp` rather than erroring.
    fn execute_ir_instruction(
        &self,
        instruction: &Instruction,
        ir_module: &IrModule,
        execution_state: &mut ExecutionState,
    ) -> JitResult<InstructionExecutionResult> {
        match instruction.opcode {
            IrOpcode::Add => {
                let result = self.execute_ir_add(instruction, execution_state)?;
                Ok(InstructionExecutionResult::Value(result))
            }
            IrOpcode::Mul => {
                let result = self.execute_ir_multiply(instruction, execution_state)?;
                Ok(InstructionExecutionResult::Value(result))
            }
            IrOpcode::Sub => {
                let result = self.execute_ir_subtract(instruction, execution_state)?;
                Ok(InstructionExecutionResult::Value(result))
            }
            IrOpcode::Div => {
                let result = self.execute_ir_divide(instruction, execution_state)?;
                Ok(InstructionExecutionResult::Value(result))
            }
            IrOpcode::Const => {
                let result = self.execute_ir_const(instruction)?;
                Ok(InstructionExecutionResult::Value(result))
            }
            IrOpcode::Load => {
                let result = self.execute_ir_load(instruction, execution_state)?;
                Ok(InstructionExecutionResult::Value(result))
            }
            IrOpcode::Store => {
                // Store produces no value; only mutates execution state.
                self.execute_ir_store(instruction, execution_state)?;
                Ok(InstructionExecutionResult::SideEffect)
            }
            IrOpcode::Return => Ok(InstructionExecutionResult::Return),
            IrOpcode::Call => {
                let result = self.execute_ir_call(instruction, ir_module, execution_state)?;
                Ok(InstructionExecutionResult::Value(result))
            }
            _ => Ok(InstructionExecutionResult::NoOp),
        }
    }
192
193 fn execute_add_operation(
196 &self,
197 inputs: &[NodeExecutionResult],
198 ) -> JitResult<NodeExecutionResult> {
199 if inputs.len() != 2 {
200 return Err(JitError::RuntimeError(
201 "Add operation requires exactly 2 inputs".to_string(),
202 ));
203 }
204
205 let a = &inputs[0];
206 let b = &inputs[1];
207
208 if a.shape != b.shape {
209 return Err(JitError::RuntimeError(
210 "Shape mismatch in add operation".to_string(),
211 ));
212 }
213
214 let result_data: Vec<f32> = a
215 .data
216 .iter()
217 .zip(b.data.iter())
218 .map(|(&x, &y)| x + y)
219 .collect();
220
221 Ok(NodeExecutionResult {
222 data: result_data,
223 shape: a.shape.clone(),
224 dtype: a.dtype,
225 })
226 }
227
228 fn execute_multiply_operation(
229 &self,
230 inputs: &[NodeExecutionResult],
231 ) -> JitResult<NodeExecutionResult> {
232 if inputs.len() != 2 {
233 return Err(JitError::RuntimeError(
234 "Multiply operation requires exactly 2 inputs".to_string(),
235 ));
236 }
237
238 let a = &inputs[0];
239 let b = &inputs[1];
240
241 if a.shape != b.shape {
242 return Err(JitError::RuntimeError(
243 "Shape mismatch in multiply operation".to_string(),
244 ));
245 }
246
247 let result_data: Vec<f32> = a
248 .data
249 .iter()
250 .zip(b.data.iter())
251 .map(|(&x, &y)| x * y)
252 .collect();
253
254 Ok(NodeExecutionResult {
255 data: result_data,
256 shape: a.shape.clone(),
257 dtype: a.dtype,
258 })
259 }
260
261 fn execute_subtract_operation(
262 &self,
263 inputs: &[NodeExecutionResult],
264 ) -> JitResult<NodeExecutionResult> {
265 if inputs.len() != 2 {
266 return Err(JitError::RuntimeError(
267 "Subtract operation requires exactly 2 inputs".to_string(),
268 ));
269 }
270
271 let a = &inputs[0];
272 let b = &inputs[1];
273
274 if a.shape != b.shape {
275 return Err(JitError::RuntimeError(
276 "Shape mismatch in subtract operation".to_string(),
277 ));
278 }
279
280 let result_data: Vec<f32> = a
281 .data
282 .iter()
283 .zip(b.data.iter())
284 .map(|(&x, &y)| x - y)
285 .collect();
286
287 Ok(NodeExecutionResult {
288 data: result_data,
289 shape: a.shape.clone(),
290 dtype: a.dtype,
291 })
292 }
293
294 fn execute_divide_operation(
295 &self,
296 inputs: &[NodeExecutionResult],
297 ) -> JitResult<NodeExecutionResult> {
298 if inputs.len() != 2 {
299 return Err(JitError::RuntimeError(
300 "Divide operation requires exactly 2 inputs".to_string(),
301 ));
302 }
303
304 let a = &inputs[0];
305 let b = &inputs[1];
306
307 if a.shape != b.shape {
308 return Err(JitError::RuntimeError(
309 "Shape mismatch in divide operation".to_string(),
310 ));
311 }
312
313 let result_data: Vec<f32> = a
314 .data
315 .iter()
316 .zip(b.data.iter())
317 .map(|(&x, &y)| {
318 if y.abs() < f32::EPSILON {
319 f32::INFINITY
320 } else {
321 x / y
322 }
323 })
324 .collect();
325
326 Ok(NodeExecutionResult {
327 data: result_data,
328 shape: a.shape.clone(),
329 dtype: a.dtype,
330 })
331 }
332
333 fn execute_relu_operation(
334 &self,
335 inputs: &[NodeExecutionResult],
336 ) -> JitResult<NodeExecutionResult> {
337 if inputs.len() != 1 {
338 return Err(JitError::RuntimeError(
339 "ReLU operation requires exactly 1 input".to_string(),
340 ));
341 }
342
343 let input = &inputs[0];
344 let result_data: Vec<f32> = input.data.iter().map(|&x| x.max(0.0)).collect();
345
346 Ok(NodeExecutionResult {
347 data: result_data,
348 shape: input.shape.clone(),
349 dtype: input.dtype,
350 })
351 }
352
353 fn execute_sigmoid_operation(
354 &self,
355 inputs: &[NodeExecutionResult],
356 ) -> JitResult<NodeExecutionResult> {
357 if inputs.len() != 1 {
358 return Err(JitError::RuntimeError(
359 "Sigmoid operation requires exactly 1 input".to_string(),
360 ));
361 }
362
363 let input = &inputs[0];
364 let result_data: Vec<f32> = input
365 .data
366 .iter()
367 .map(|&x| 1.0 / (1.0 + (-x).exp()))
368 .collect();
369
370 Ok(NodeExecutionResult {
371 data: result_data,
372 shape: input.shape.clone(),
373 dtype: input.dtype,
374 })
375 }
376
377 fn execute_tanh_operation(
378 &self,
379 inputs: &[NodeExecutionResult],
380 ) -> JitResult<NodeExecutionResult> {
381 if inputs.len() != 1 {
382 return Err(JitError::RuntimeError(
383 "Tanh operation requires exactly 1 input".to_string(),
384 ));
385 }
386
387 let input = &inputs[0];
388 let result_data: Vec<f32> = input.data.iter().map(|&x| x.tanh()).collect();
389
390 Ok(NodeExecutionResult {
391 data: result_data,
392 shape: input.shape.clone(),
393 dtype: input.dtype,
394 })
395 }
396
397 fn execute_matmul_operation(
398 &self,
399 inputs: &[NodeExecutionResult],
400 ) -> JitResult<NodeExecutionResult> {
401 if inputs.len() != 2 {
402 return Err(JitError::RuntimeError(
403 "MatMul operation requires exactly 2 inputs".to_string(),
404 ));
405 }
406
407 let a = &inputs[0];
409 let b = &inputs[1];
410
411 if a.shape.ndim() != 2 || b.shape.ndim() != 2 {
413 return Err(JitError::RuntimeError(
414 "MatMul requires 2D matrices".to_string(),
415 ));
416 }
417
418 let (m, k) = (a.shape.dims()[0], a.shape.dims()[1]);
419 let (k2, n) = (b.shape.dims()[0], b.shape.dims()[1]);
420
421 if k != k2 {
422 return Err(JitError::RuntimeError(
423 "Matrix dimension mismatch".to_string(),
424 ));
425 }
426
427 let mut result_data = vec![0.0; m * n];
428
429 for i in 0..m {
430 for j in 0..n {
431 for l in 0..k {
432 result_data[i * n + j] += a.data[i * k + l] * b.data[l * n + j];
433 }
434 }
435 }
436
437 Ok(NodeExecutionResult {
438 data: result_data,
439 shape: Shape::new(vec![m, n]),
440 dtype: a.dtype,
441 })
442 }
443
    /// Reshapes an input to the dimensions given by the node's `shape`
    /// attribute, which must be a string like "[2, 3]" (bracket-wrapped,
    /// comma-separated unsigned integers). The element count must match.
    ///
    /// NOTE(review): negative/inferred dimensions (e.g. -1) are not
    /// supported by this parser — confirm whether callers need them.
    fn execute_reshape_operation(
        &self,
        inputs: &[NodeExecutionResult],
        attributes: &HashMap<String, crate::graph::Attribute>,
    ) -> JitResult<NodeExecutionResult> {
        if inputs.len() != 1 {
            return Err(JitError::RuntimeError(
                "Reshape operation requires exactly 1 input".to_string(),
            ));
        }

        let input = &inputs[0];

        let shape_attr = attributes.get("shape").ok_or_else(|| {
            JitError::RuntimeError("Reshape operation missing shape attribute".to_string())
        })?;

        let new_shape_str = match shape_attr {
            crate::graph::Attribute::String(s) => s,
            _ => {
                return Err(JitError::RuntimeError(
                    "Reshape shape attribute must be a string".to_string(),
                ))
            }
        };

        // Strip surrounding brackets, then parse each comma-separated field.
        let new_dims: Result<Vec<usize>, _> = new_shape_str
            .trim_matches(['[', ']'])
            .split(',')
            .map(|s| s.trim().parse())
            .collect();

        let new_dims =
            new_dims.map_err(|_| JitError::RuntimeError("Invalid shape format".to_string()))?;
        let new_shape = Shape::new(new_dims);

        // Reshape is a metadata-only change: data is untouched, so the
        // total element count must be preserved.
        if input.shape.numel() != new_shape.numel() {
            return Err(JitError::RuntimeError(
                "Reshape: total elements must remain constant".to_string(),
            ));
        }

        Ok(NodeExecutionResult {
            data: input.data.clone(),
            shape: new_shape,
            dtype: input.dtype,
        })
    }
496
497 fn execute_transpose_operation(
498 &self,
499 inputs: &[NodeExecutionResult],
500 ) -> JitResult<NodeExecutionResult> {
501 if inputs.len() != 1 {
502 return Err(JitError::RuntimeError(
503 "Transpose operation requires exactly 1 input".to_string(),
504 ));
505 }
506
507 let input = &inputs[0];
508
509 if input.shape.ndim() != 2 {
511 return Err(JitError::RuntimeError(
512 "Transpose currently supports only 2D matrices".to_string(),
513 ));
514 }
515
516 let (rows, cols) = (input.shape.dims()[0], input.shape.dims()[1]);
517 let mut result_data = vec![0.0; rows * cols];
518
519 for i in 0..rows {
520 for j in 0..cols {
521 result_data[j * rows + i] = input.data[i * cols + j];
522 }
523 }
524
525 Ok(NodeExecutionResult {
526 data: result_data,
527 shape: Shape::new(vec![cols, rows]),
528 dtype: input.dtype,
529 })
530 }
531
    /// Concatenates inputs along axis 0 (the only supported axis). All
    /// inputs must agree on rank and on every dimension except axis 0.
    ///
    /// The `axis` attribute may be a numeric string or an `Int`; anything
    /// else (or a missing attribute) defaults to 0.
    fn execute_concat_operation(
        &self,
        inputs: &[NodeExecutionResult],
        attributes: &HashMap<String, crate::graph::Attribute>,
    ) -> JitResult<NodeExecutionResult> {
        if inputs.is_empty() {
            return Err(JitError::RuntimeError(
                "Concat operation requires at least 1 input".to_string(),
            ));
        }

        let axis = attributes
            .get("axis")
            .and_then(|attr| match attr {
                crate::graph::Attribute::String(s) => s.parse::<usize>().ok(),
                crate::graph::Attribute::Int(i) => Some(*i as usize),
                _ => None,
            })
            .unwrap_or(0);

        if axis != 0 {
            return Err(JitError::RuntimeError(
                "Concat currently supports only axis 0".to_string(),
            ));
        }

        let first_input = &inputs[0];
        let mut total_size = first_input.shape.dims()[0];
        let mut result_data = first_input.data.clone();

        for input in &inputs[1..] {
            if input.shape.ndim() != first_input.shape.ndim() {
                return Err(JitError::RuntimeError(
                    "All inputs must have same number of dimensions".to_string(),
                ));
            }

            // Every dimension other than axis 0 must match the first input.
            for (i, (&dim1, &dim2)) in first_input.shape.dims()[1..]
                .iter()
                .zip(input.shape.dims()[1..].iter())
                .enumerate()
            {
                if dim1 != dim2 {
                    return Err(JitError::RuntimeError(format!(
                        "Dimension mismatch at axis {}",
                        i + 1
                    )));
                }
            }

            total_size += input.shape.dims()[0];
            result_data.extend_from_slice(&input.data);
        }

        // Output shape matches the first input except for the summed axis 0.
        let mut new_dims = first_input.shape.dims().to_vec();
        new_dims[0] = total_size;

        Ok(NodeExecutionResult {
            data: result_data,
            shape: Shape::new(new_dims),
            dtype: first_input.dtype,
        })
    }
598
    /// Placeholder split: validates the input count but currently returns
    /// the input unchanged.
    ///
    /// NOTE(review): `attributes` (split sizes/axis) is ignored — real
    /// splitting is not implemented yet.
    fn execute_split_operation(
        &self,
        inputs: &[NodeExecutionResult],
        attributes: &HashMap<String, crate::graph::Attribute>,
    ) -> JitResult<NodeExecutionResult> {
        if inputs.len() != 1 {
            return Err(JitError::RuntimeError(
                "Split operation requires exactly 1 input".to_string(),
            ));
        }

        Ok(inputs[0].clone())
    }
614
    /// Placeholder IR `Add` evaluation — returns a fixed scalar.
    /// NOTE(review): `instruction`/`execution_state` are unused; operand
    /// lookup is not implemented yet.
    fn execute_ir_add(
        &self,
        instruction: &Instruction,
        execution_state: &ExecutionState,
    ) -> JitResult<DebugValue> {
        Ok(DebugValue::Scalar(42.0))
    }
625
    /// Placeholder IR `Mul` evaluation — returns a fixed scalar.
    /// NOTE(review): parameters are unused pending real operand lookup.
    fn execute_ir_multiply(
        &self,
        instruction: &Instruction,
        execution_state: &ExecutionState,
    ) -> JitResult<DebugValue> {
        Ok(DebugValue::Scalar(84.0))
    }
634
    /// Placeholder IR `Sub` evaluation — returns a fixed scalar.
    /// NOTE(review): parameters are unused pending real operand lookup.
    fn execute_ir_subtract(
        &self,
        instruction: &Instruction,
        execution_state: &ExecutionState,
    ) -> JitResult<DebugValue> {
        Ok(DebugValue::Scalar(21.0))
    }
643
    /// Placeholder IR `Div` evaluation — returns a fixed scalar.
    /// NOTE(review): parameters are unused pending real operand lookup.
    fn execute_ir_divide(
        &self,
        instruction: &Instruction,
        execution_state: &ExecutionState,
    ) -> JitResult<DebugValue> {
        Ok(DebugValue::Scalar(2.0))
    }
652
    /// Placeholder IR `Const` evaluation — returns a fixed scalar instead
    /// of the instruction's actual immediate value (TODO).
    fn execute_ir_const(&self, instruction: &Instruction) -> JitResult<DebugValue> {
        Ok(DebugValue::Scalar(1.0))
    }
657
    /// Placeholder IR `Load` evaluation — returns a fixed scalar (π)
    /// instead of reading from `execution_state` (TODO).
    fn execute_ir_load(
        &self,
        instruction: &Instruction,
        execution_state: &ExecutionState,
    ) -> JitResult<DebugValue> {
        Ok(DebugValue::Scalar(std::f64::consts::PI))
    }
666
    /// Placeholder IR `Store` — currently a no-op that does not mutate
    /// `execution_state` (TODO).
    fn execute_ir_store(
        &self,
        instruction: &Instruction,
        execution_state: &mut ExecutionState,
    ) -> JitResult<()> {
        Ok(())
    }
675
    /// Placeholder IR `Call` evaluation — returns a fixed scalar instead
    /// of invoking the callee from `ir_module` (TODO).
    fn execute_ir_call(
        &self,
        instruction: &Instruction,
        ir_module: &IrModule,
        execution_state: &ExecutionState,
    ) -> JitResult<DebugValue> {
        Ok(DebugValue::Scalar(100.0))
    }
685
    /// Placeholder input resolution — always returns a single mock tensor
    /// `[1.0, 2.0, 3.0]` rather than gathering the outputs of `node_id`'s
    /// predecessors from `graph` (TODO).
    fn get_node_inputs(
        &self,
        graph: &ComputationGraph,
        node_id: NodeId,
    ) -> JitResult<Vec<NodeExecutionResult>> {
        Ok(vec![NodeExecutionResult {
            data: vec![1.0, 2.0, 3.0],
            shape: Shape::new(vec![3]),
            dtype: DType::F32,
        }])
    }
700
701 fn record_operation_timing(&mut self, operation: &str, duration: Duration) {
702 self.instruction_timings
703 .entry(operation.to_string())
704 .or_insert_with(Vec::new)
705 .push(duration);
706
707 let timings = &self.instruction_timings[operation];
709 let count = timings.len();
710 let total_time: Duration = timings.iter().sum();
711 let average_time = total_time / count as u32;
712 let min_time = *timings.iter().min().expect("timings should not be empty");
713 let max_time = *timings.iter().max().expect("timings should not be empty");
714
715 self.operation_stats.insert(
716 operation.to_string(),
717 OperationStatistics {
718 count,
719 total_time,
720 average_time,
721 min_time,
722 max_time,
723 },
724 );
725 }
726
727 pub fn get_statistics(&self) -> DebugStatistics {
729 DebugStatistics {
730 total_steps: self.execution_count,
731 total_execution_time: self.total_execution_time,
732 breakpoints_hit: 0, watches_triggered: 0, }
735 }
736
    /// Returns the aggregated per-operation statistics, keyed by
    /// operation/opcode name.
    pub fn get_operation_statistics(&self) -> &HashMap<String, OperationStatistics> {
        &self.operation_stats
    }
741
    /// Returns every raw timing sample recorded for `operation`, or `None`
    /// if it has never been executed.
    pub fn get_operation_timings(&self, operation: &str) -> Option<&Vec<Duration>> {
        self.instruction_timings.get(operation)
    }
746
747 pub fn reset_statistics(&mut self) {
749 self.execution_count = 0;
750 self.total_execution_time = Duration::new(0, 0);
751 self.instruction_timings.clear();
752 self.operation_stats.clear();
753 }
754
    /// Returns the current debugger configuration.
    pub fn config(&self) -> &DebuggerConfig {
        &self.config
    }
759
    /// Replaces the debugger configuration; existing statistics are kept.
    pub fn update_config(&mut self, config: DebuggerConfig) {
        self.config = config;
    }
764}
765
#[cfg(test)]
mod tests {
    use super::*;

    // A new engine starts with zeroed counters and empty timing tables.
    #[test]
    fn test_debug_execution_engine_creation() {
        let config = DebuggerConfig::default();
        let engine = DebugExecutionEngine::new(config);

        assert_eq!(engine.execution_count, 0);
        assert_eq!(engine.total_execution_time, Duration::new(0, 0));
        assert!(engine.instruction_timings.is_empty());
        assert!(engine.operation_stats.is_empty());
    }

    // A single recorded sample populates both the raw timing list and the
    // aggregated statistics (count/total/min/max all equal that sample).
    #[test]
    fn test_operation_timing_recording() {
        let config = DebuggerConfig::default();
        let mut engine = DebugExecutionEngine::new(config);

        let duration = Duration::from_millis(10);
        engine.record_operation_timing("add", duration);

        assert_eq!(engine.instruction_timings.get("add").unwrap().len(), 1);
        assert!(engine.operation_stats.contains_key("add"));

        let stats = &engine.operation_stats["add"];
        assert_eq!(stats.count, 1);
        assert_eq!(stats.total_time, duration);
        assert_eq!(stats.min_time, duration);
        assert_eq!(stats.max_time, duration);
    }

    // Element-wise add of two same-shaped vectors.
    #[test]
    fn test_add_operation() {
        let config = DebuggerConfig::default();
        let engine = DebugExecutionEngine::new(config);

        let input1 = NodeExecutionResult {
            data: vec![1.0, 2.0, 3.0],
            shape: Shape::new(vec![3]),
            dtype: DType::F32,
        };

        let input2 = NodeExecutionResult {
            data: vec![4.0, 5.0, 6.0],
            shape: Shape::new(vec![3]),
            dtype: DType::F32,
        };

        let result = engine.execute_add_operation(&[input1, input2]).unwrap();
        assert_eq!(result.data, vec![5.0, 7.0, 9.0]);
        assert_eq!(result.shape.dims(), &[3]);
    }

    // ReLU zeroes negatives and passes non-negatives through unchanged.
    #[test]
    fn test_relu_operation() {
        let config = DebuggerConfig::default();
        let engine = DebugExecutionEngine::new(config);

        let input = NodeExecutionResult {
            data: vec![-1.0, 0.0, 1.0, -2.0, 3.0],
            shape: Shape::new(vec![5]),
            dtype: DType::F32,
        };

        let result = engine.execute_relu_operation(&[input]).unwrap();
        assert_eq!(result.data, vec![0.0, 0.0, 1.0, 0.0, 3.0]);
    }

    // reset_statistics restores the freshly-constructed state.
    #[test]
    fn test_statistics_reset() {
        let config = DebuggerConfig::default();
        let mut engine = DebugExecutionEngine::new(config);

        engine.record_operation_timing("test", Duration::from_millis(10));
        engine.execution_count = 5;
        engine.total_execution_time = Duration::from_millis(50);

        engine.reset_statistics();

        assert_eq!(engine.execution_count, 0);
        assert_eq!(engine.total_execution_time, Duration::new(0, 0));
        assert!(engine.instruction_timings.is_empty());
        assert!(engine.operation_stats.is_empty());
    }
}