Skip to main content

torsh_jit/debugger/
execution.rs

1//! Debug execution engine for JIT debugging
2//!
3//! This module provides the execution engine specifically designed for debugging,
4//! with instrumentation and step-by-step execution capabilities.
5
6use super::core::{
7    DebugStatistics, DebugValue, DebuggerConfig, ExecutionState, InstructionExecutionResult,
8    NodeExecutionResult,
9};
10use crate::{
11    graph::Node,
12    ir::{Instruction, IrModule, IrOpcode},
13    ComputationGraph, JitError, JitResult, NodeId,
14};
15use std::collections::HashMap;
16use std::time::{Duration, Instant};
17use torsh_core::{DType, Shape};
18
19/// Debug execution engine
20///
21/// Provides specialized execution capabilities for debugging including
22/// instrumentation, step-by-step execution, and state tracking.
23pub struct DebugExecutionEngine {
24    config: DebuggerConfig,
25    execution_count: usize,
26    total_execution_time: Duration,
27    instruction_timings: HashMap<String, Vec<Duration>>,
28    operation_stats: HashMap<String, OperationStatistics>,
29}
30
/// Aggregated timing statistics for a single operation name.
///
/// All fields summarise the timing samples recorded for that operation since
/// the engine's statistics were last reset. `Default` yields the empty
/// summary (zero count, zero durations).
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct OperationStatistics {
    /// Number of recorded executions.
    pub count: usize,
    /// Sum of all recorded durations.
    pub total_time: Duration,
    /// Mean duration, i.e. `total_time / count`.
    pub average_time: Duration,
    /// Shortest recorded duration.
    pub min_time: Duration,
    /// Longest recorded duration.
    pub max_time: Duration,
}
40
41impl DebugExecutionEngine {
42    /// Create a new debug execution engine
43    ///
44    /// # Arguments
45    /// * `config` - Configuration for the debug execution engine
46    pub fn new(config: DebuggerConfig) -> Self {
47        Self {
48            config,
49            execution_count: 0,
50            total_execution_time: Duration::new(0, 0),
51            instruction_timings: HashMap::new(),
52            operation_stats: HashMap::new(),
53        }
54    }
55
56    /// Execute a graph node with debugging instrumentation
57    ///
58    /// # Arguments
59    /// * `node` - The node to execute
60    /// * `graph` - The computation graph containing the node
61    /// * `node_id` - The ID of the node being executed
62    ///
63    /// # Returns
64    /// The result of node execution with timing information
65    pub fn execute_node_debug(
66        &mut self,
67        node: &Node,
68        graph: &ComputationGraph,
69        node_id: NodeId,
70    ) -> JitResult<NodeExecutionResult> {
71        let start_time = Instant::now();
72
73        // Execute the actual node operation
74        let result = self.execute_node_operation(node, graph, node_id)?;
75
76        let execution_time = start_time.elapsed();
77        self.record_operation_timing(node.op.as_str(), execution_time);
78        self.execution_count += 1;
79        self.total_execution_time += execution_time;
80
81        Ok(result)
82    }
83
84    /// Execute an IR instruction with debugging instrumentation
85    ///
86    /// # Arguments
87    /// * `instruction` - The instruction to execute
88    /// * `ir_module` - The IR module containing the instruction
89    /// * `execution_state` - Current execution state
90    ///
91    /// # Returns
92    /// The result of instruction execution
93    pub fn execute_instruction_debug(
94        &mut self,
95        instruction: &Instruction,
96        ir_module: &IrModule,
97        execution_state: &mut ExecutionState,
98    ) -> JitResult<InstructionExecutionResult> {
99        let start_time = Instant::now();
100
101        // Execute the actual instruction
102        let result = self.execute_ir_instruction(instruction, ir_module, execution_state)?;
103
104        let execution_time = start_time.elapsed();
105        let instruction_name = format!("{:?}", instruction.opcode);
106        self.record_operation_timing(&instruction_name, execution_time);
107        self.execution_count += 1;
108        self.total_execution_time += execution_time;
109
110        Ok(result)
111    }
112
113    /// Execute a node operation (implementation)
114    fn execute_node_operation(
115        &self,
116        node: &Node,
117        graph: &ComputationGraph,
118        node_id: NodeId,
119    ) -> JitResult<NodeExecutionResult> {
120        // Get input values from predecessor nodes
121        let inputs = self.get_node_inputs(graph, node_id)?;
122
123        // Execute based on operation type
124        match node.op.as_str() {
125            "add" => self.execute_add_operation(&inputs),
126            "mul" => self.execute_multiply_operation(&inputs),
127            "sub" => self.execute_subtract_operation(&inputs),
128            "div" => self.execute_divide_operation(&inputs),
129            "relu" => self.execute_relu_operation(&inputs),
130            "sigmoid" => self.execute_sigmoid_operation(&inputs),
131            "tanh" => self.execute_tanh_operation(&inputs),
132            "matmul" => self.execute_matmul_operation(&inputs),
133            "reshape" => self.execute_reshape_operation(&inputs, &node.attrs),
134            "transpose" => self.execute_transpose_operation(&inputs),
135            "concat" => self.execute_concat_operation(&inputs, &node.attrs),
136            "split" => self.execute_split_operation(&inputs, &node.attrs),
137            _ => {
138                // Default implementation for unknown operations
139                Ok(NodeExecutionResult {
140                    data: vec![0.0],
141                    shape: Shape::new(vec![1]),
142                    dtype: DType::F32,
143                })
144            }
145        }
146    }
147
148    /// Execute an IR instruction (implementation)
149    fn execute_ir_instruction(
150        &self,
151        instruction: &Instruction,
152        ir_module: &IrModule,
153        execution_state: &mut ExecutionState,
154    ) -> JitResult<InstructionExecutionResult> {
155        match instruction.opcode {
156            IrOpcode::Add => {
157                let result = self.execute_ir_add(instruction, execution_state)?;
158                Ok(InstructionExecutionResult::Value(result))
159            }
160            IrOpcode::Mul => {
161                let result = self.execute_ir_multiply(instruction, execution_state)?;
162                Ok(InstructionExecutionResult::Value(result))
163            }
164            IrOpcode::Sub => {
165                let result = self.execute_ir_subtract(instruction, execution_state)?;
166                Ok(InstructionExecutionResult::Value(result))
167            }
168            IrOpcode::Div => {
169                let result = self.execute_ir_divide(instruction, execution_state)?;
170                Ok(InstructionExecutionResult::Value(result))
171            }
172            IrOpcode::Const => {
173                let result = self.execute_ir_const(instruction)?;
174                Ok(InstructionExecutionResult::Value(result))
175            }
176            IrOpcode::Load => {
177                let result = self.execute_ir_load(instruction, execution_state)?;
178                Ok(InstructionExecutionResult::Value(result))
179            }
180            IrOpcode::Store => {
181                self.execute_ir_store(instruction, execution_state)?;
182                Ok(InstructionExecutionResult::SideEffect)
183            }
184            IrOpcode::Return => Ok(InstructionExecutionResult::Return),
185            IrOpcode::Call => {
186                let result = self.execute_ir_call(instruction, ir_module, execution_state)?;
187                Ok(InstructionExecutionResult::Value(result))
188            }
189            _ => Ok(InstructionExecutionResult::NoOp),
190        }
191    }
192
193    // Node operation implementations
194
195    fn execute_add_operation(
196        &self,
197        inputs: &[NodeExecutionResult],
198    ) -> JitResult<NodeExecutionResult> {
199        if inputs.len() != 2 {
200            return Err(JitError::RuntimeError(
201                "Add operation requires exactly 2 inputs".to_string(),
202            ));
203        }
204
205        let a = &inputs[0];
206        let b = &inputs[1];
207
208        if a.shape != b.shape {
209            return Err(JitError::RuntimeError(
210                "Shape mismatch in add operation".to_string(),
211            ));
212        }
213
214        let result_data: Vec<f32> = a
215            .data
216            .iter()
217            .zip(b.data.iter())
218            .map(|(&x, &y)| x + y)
219            .collect();
220
221        Ok(NodeExecutionResult {
222            data: result_data,
223            shape: a.shape.clone(),
224            dtype: a.dtype,
225        })
226    }
227
228    fn execute_multiply_operation(
229        &self,
230        inputs: &[NodeExecutionResult],
231    ) -> JitResult<NodeExecutionResult> {
232        if inputs.len() != 2 {
233            return Err(JitError::RuntimeError(
234                "Multiply operation requires exactly 2 inputs".to_string(),
235            ));
236        }
237
238        let a = &inputs[0];
239        let b = &inputs[1];
240
241        if a.shape != b.shape {
242            return Err(JitError::RuntimeError(
243                "Shape mismatch in multiply operation".to_string(),
244            ));
245        }
246
247        let result_data: Vec<f32> = a
248            .data
249            .iter()
250            .zip(b.data.iter())
251            .map(|(&x, &y)| x * y)
252            .collect();
253
254        Ok(NodeExecutionResult {
255            data: result_data,
256            shape: a.shape.clone(),
257            dtype: a.dtype,
258        })
259    }
260
261    fn execute_subtract_operation(
262        &self,
263        inputs: &[NodeExecutionResult],
264    ) -> JitResult<NodeExecutionResult> {
265        if inputs.len() != 2 {
266            return Err(JitError::RuntimeError(
267                "Subtract operation requires exactly 2 inputs".to_string(),
268            ));
269        }
270
271        let a = &inputs[0];
272        let b = &inputs[1];
273
274        if a.shape != b.shape {
275            return Err(JitError::RuntimeError(
276                "Shape mismatch in subtract operation".to_string(),
277            ));
278        }
279
280        let result_data: Vec<f32> = a
281            .data
282            .iter()
283            .zip(b.data.iter())
284            .map(|(&x, &y)| x - y)
285            .collect();
286
287        Ok(NodeExecutionResult {
288            data: result_data,
289            shape: a.shape.clone(),
290            dtype: a.dtype,
291        })
292    }
293
294    fn execute_divide_operation(
295        &self,
296        inputs: &[NodeExecutionResult],
297    ) -> JitResult<NodeExecutionResult> {
298        if inputs.len() != 2 {
299            return Err(JitError::RuntimeError(
300                "Divide operation requires exactly 2 inputs".to_string(),
301            ));
302        }
303
304        let a = &inputs[0];
305        let b = &inputs[1];
306
307        if a.shape != b.shape {
308            return Err(JitError::RuntimeError(
309                "Shape mismatch in divide operation".to_string(),
310            ));
311        }
312
313        let result_data: Vec<f32> = a
314            .data
315            .iter()
316            .zip(b.data.iter())
317            .map(|(&x, &y)| {
318                if y.abs() < f32::EPSILON {
319                    f32::INFINITY
320                } else {
321                    x / y
322                }
323            })
324            .collect();
325
326        Ok(NodeExecutionResult {
327            data: result_data,
328            shape: a.shape.clone(),
329            dtype: a.dtype,
330        })
331    }
332
333    fn execute_relu_operation(
334        &self,
335        inputs: &[NodeExecutionResult],
336    ) -> JitResult<NodeExecutionResult> {
337        if inputs.len() != 1 {
338            return Err(JitError::RuntimeError(
339                "ReLU operation requires exactly 1 input".to_string(),
340            ));
341        }
342
343        let input = &inputs[0];
344        let result_data: Vec<f32> = input.data.iter().map(|&x| x.max(0.0)).collect();
345
346        Ok(NodeExecutionResult {
347            data: result_data,
348            shape: input.shape.clone(),
349            dtype: input.dtype,
350        })
351    }
352
353    fn execute_sigmoid_operation(
354        &self,
355        inputs: &[NodeExecutionResult],
356    ) -> JitResult<NodeExecutionResult> {
357        if inputs.len() != 1 {
358            return Err(JitError::RuntimeError(
359                "Sigmoid operation requires exactly 1 input".to_string(),
360            ));
361        }
362
363        let input = &inputs[0];
364        let result_data: Vec<f32> = input
365            .data
366            .iter()
367            .map(|&x| 1.0 / (1.0 + (-x).exp()))
368            .collect();
369
370        Ok(NodeExecutionResult {
371            data: result_data,
372            shape: input.shape.clone(),
373            dtype: input.dtype,
374        })
375    }
376
377    fn execute_tanh_operation(
378        &self,
379        inputs: &[NodeExecutionResult],
380    ) -> JitResult<NodeExecutionResult> {
381        if inputs.len() != 1 {
382            return Err(JitError::RuntimeError(
383                "Tanh operation requires exactly 1 input".to_string(),
384            ));
385        }
386
387        let input = &inputs[0];
388        let result_data: Vec<f32> = input.data.iter().map(|&x| x.tanh()).collect();
389
390        Ok(NodeExecutionResult {
391            data: result_data,
392            shape: input.shape.clone(),
393            dtype: input.dtype,
394        })
395    }
396
397    fn execute_matmul_operation(
398        &self,
399        inputs: &[NodeExecutionResult],
400    ) -> JitResult<NodeExecutionResult> {
401        if inputs.len() != 2 {
402            return Err(JitError::RuntimeError(
403                "MatMul operation requires exactly 2 inputs".to_string(),
404            ));
405        }
406
407        // Simplified matrix multiplication - in practice this would be more complex
408        let a = &inputs[0];
409        let b = &inputs[1];
410
411        // For simplicity, assume both are 2D matrices
412        if a.shape.ndim() != 2 || b.shape.ndim() != 2 {
413            return Err(JitError::RuntimeError(
414                "MatMul requires 2D matrices".to_string(),
415            ));
416        }
417
418        let (m, k) = (a.shape.dims()[0], a.shape.dims()[1]);
419        let (k2, n) = (b.shape.dims()[0], b.shape.dims()[1]);
420
421        if k != k2 {
422            return Err(JitError::RuntimeError(
423                "Matrix dimension mismatch".to_string(),
424            ));
425        }
426
427        let mut result_data = vec![0.0; m * n];
428
429        for i in 0..m {
430            for j in 0..n {
431                for l in 0..k {
432                    result_data[i * n + j] += a.data[i * k + l] * b.data[l * n + j];
433                }
434            }
435        }
436
437        Ok(NodeExecutionResult {
438            data: result_data,
439            shape: Shape::new(vec![m, n]),
440            dtype: a.dtype,
441        })
442    }
443
444    fn execute_reshape_operation(
445        &self,
446        inputs: &[NodeExecutionResult],
447        attributes: &HashMap<String, crate::graph::Attribute>,
448    ) -> JitResult<NodeExecutionResult> {
449        if inputs.len() != 1 {
450            return Err(JitError::RuntimeError(
451                "Reshape operation requires exactly 1 input".to_string(),
452            ));
453        }
454
455        let input = &inputs[0];
456
457        // Parse new shape from attributes
458        let shape_attr = attributes.get("shape").ok_or_else(|| {
459            JitError::RuntimeError("Reshape operation missing shape attribute".to_string())
460        })?;
461
462        // Extract string value from Attribute enum
463        let new_shape_str = match shape_attr {
464            crate::graph::Attribute::String(s) => s,
465            _ => {
466                return Err(JitError::RuntimeError(
467                    "Reshape shape attribute must be a string".to_string(),
468                ))
469            }
470        };
471
472        // Simplified shape parsing - in practice this would be more robust
473        let new_dims: Result<Vec<usize>, _> = new_shape_str
474            .trim_matches(['[', ']'])
475            .split(',')
476            .map(|s| s.trim().parse())
477            .collect();
478
479        let new_dims =
480            new_dims.map_err(|_| JitError::RuntimeError("Invalid shape format".to_string()))?;
481        let new_shape = Shape::new(new_dims);
482
483        // Verify that total elements remain the same
484        if input.shape.numel() != new_shape.numel() {
485            return Err(JitError::RuntimeError(
486                "Reshape: total elements must remain constant".to_string(),
487            ));
488        }
489
490        Ok(NodeExecutionResult {
491            data: input.data.clone(),
492            shape: new_shape,
493            dtype: input.dtype,
494        })
495    }
496
497    fn execute_transpose_operation(
498        &self,
499        inputs: &[NodeExecutionResult],
500    ) -> JitResult<NodeExecutionResult> {
501        if inputs.len() != 1 {
502            return Err(JitError::RuntimeError(
503                "Transpose operation requires exactly 1 input".to_string(),
504            ));
505        }
506
507        let input = &inputs[0];
508
509        // Simplified transpose for 2D matrices
510        if input.shape.ndim() != 2 {
511            return Err(JitError::RuntimeError(
512                "Transpose currently supports only 2D matrices".to_string(),
513            ));
514        }
515
516        let (rows, cols) = (input.shape.dims()[0], input.shape.dims()[1]);
517        let mut result_data = vec![0.0; rows * cols];
518
519        for i in 0..rows {
520            for j in 0..cols {
521                result_data[j * rows + i] = input.data[i * cols + j];
522            }
523        }
524
525        Ok(NodeExecutionResult {
526            data: result_data,
527            shape: Shape::new(vec![cols, rows]),
528            dtype: input.dtype,
529        })
530    }
531
532    fn execute_concat_operation(
533        &self,
534        inputs: &[NodeExecutionResult],
535        attributes: &HashMap<String, crate::graph::Attribute>,
536    ) -> JitResult<NodeExecutionResult> {
537        if inputs.is_empty() {
538            return Err(JitError::RuntimeError(
539                "Concat operation requires at least 1 input".to_string(),
540            ));
541        }
542
543        // Parse axis from attributes
544        let axis = attributes
545            .get("axis")
546            .and_then(|attr| match attr {
547                crate::graph::Attribute::String(s) => s.parse::<usize>().ok(),
548                crate::graph::Attribute::Int(i) => Some(*i as usize),
549                _ => None,
550            })
551            .unwrap_or(0);
552
553        // Simplified concatenation along axis 0
554        if axis != 0 {
555            return Err(JitError::RuntimeError(
556                "Concat currently supports only axis 0".to_string(),
557            ));
558        }
559
560        let first_input = &inputs[0];
561        let mut total_size = first_input.shape.dims()[0];
562        let mut result_data = first_input.data.clone();
563
564        for input in &inputs[1..] {
565            if input.shape.ndim() != first_input.shape.ndim() {
566                return Err(JitError::RuntimeError(
567                    "All inputs must have same number of dimensions".to_string(),
568                ));
569            }
570
571            // Check that all dimensions except axis 0 match
572            for (i, (&dim1, &dim2)) in first_input.shape.dims()[1..]
573                .iter()
574                .zip(input.shape.dims()[1..].iter())
575                .enumerate()
576            {
577                if dim1 != dim2 {
578                    return Err(JitError::RuntimeError(format!(
579                        "Dimension mismatch at axis {}",
580                        i + 1
581                    )));
582                }
583            }
584
585            total_size += input.shape.dims()[0];
586            result_data.extend_from_slice(&input.data);
587        }
588
589        let mut new_dims = first_input.shape.dims().to_vec();
590        new_dims[0] = total_size;
591
592        Ok(NodeExecutionResult {
593            data: result_data,
594            shape: Shape::new(new_dims),
595            dtype: first_input.dtype,
596        })
597    }
598
599    fn execute_split_operation(
600        &self,
601        inputs: &[NodeExecutionResult],
602        attributes: &HashMap<String, crate::graph::Attribute>,
603    ) -> JitResult<NodeExecutionResult> {
604        if inputs.len() != 1 {
605            return Err(JitError::RuntimeError(
606                "Split operation requires exactly 1 input".to_string(),
607            ));
608        }
609
610        // For simplicity, return the first split only
611        // In practice, this would return multiple outputs
612        Ok(inputs[0].clone())
613    }
614
615    // IR instruction implementations
616
617    fn execute_ir_add(
618        &self,
619        instruction: &Instruction,
620        execution_state: &ExecutionState,
621    ) -> JitResult<DebugValue> {
622        // Simplified IR add - would access operands from instruction
623        Ok(DebugValue::Scalar(42.0))
624    }
625
626    fn execute_ir_multiply(
627        &self,
628        instruction: &Instruction,
629        execution_state: &ExecutionState,
630    ) -> JitResult<DebugValue> {
631        // Simplified IR multiply
632        Ok(DebugValue::Scalar(84.0))
633    }
634
635    fn execute_ir_subtract(
636        &self,
637        instruction: &Instruction,
638        execution_state: &ExecutionState,
639    ) -> JitResult<DebugValue> {
640        // Simplified IR subtract
641        Ok(DebugValue::Scalar(21.0))
642    }
643
644    fn execute_ir_divide(
645        &self,
646        instruction: &Instruction,
647        execution_state: &ExecutionState,
648    ) -> JitResult<DebugValue> {
649        // Simplified IR divide
650        Ok(DebugValue::Scalar(2.0))
651    }
652
653    fn execute_ir_const(&self, instruction: &Instruction) -> JitResult<DebugValue> {
654        // In a real implementation, we'd extract the constant value from instruction attributes
655        Ok(DebugValue::Scalar(1.0))
656    }
657
658    fn execute_ir_load(
659        &self,
660        instruction: &Instruction,
661        execution_state: &ExecutionState,
662    ) -> JitResult<DebugValue> {
663        // Load from memory - simplified
664        Ok(DebugValue::Scalar(std::f64::consts::PI))
665    }
666
667    fn execute_ir_store(
668        &self,
669        instruction: &Instruction,
670        execution_state: &mut ExecutionState,
671    ) -> JitResult<()> {
672        // Store to memory - simplified
673        Ok(())
674    }
675
676    fn execute_ir_call(
677        &self,
678        instruction: &Instruction,
679        ir_module: &IrModule,
680        execution_state: &ExecutionState,
681    ) -> JitResult<DebugValue> {
682        // Function call - simplified
683        Ok(DebugValue::Scalar(100.0))
684    }
685
686    // Helper methods
687
688    fn get_node_inputs(
689        &self,
690        graph: &ComputationGraph,
691        node_id: NodeId,
692    ) -> JitResult<Vec<NodeExecutionResult>> {
693        // Simplified - in practice this would get actual computed values from predecessor nodes
694        Ok(vec![NodeExecutionResult {
695            data: vec![1.0, 2.0, 3.0],
696            shape: Shape::new(vec![3]),
697            dtype: DType::F32,
698        }])
699    }
700
701    fn record_operation_timing(&mut self, operation: &str, duration: Duration) {
702        self.instruction_timings
703            .entry(operation.to_string())
704            .or_insert_with(Vec::new)
705            .push(duration);
706
707        // Update operation statistics
708        let timings = &self.instruction_timings[operation];
709        let count = timings.len();
710        let total_time: Duration = timings.iter().sum();
711        let average_time = total_time / count as u32;
712        let min_time = *timings.iter().min().expect("timings should not be empty");
713        let max_time = *timings.iter().max().expect("timings should not be empty");
714
715        self.operation_stats.insert(
716            operation.to_string(),
717            OperationStatistics {
718                count,
719                total_time,
720                average_time,
721                min_time,
722                max_time,
723            },
724        );
725    }
726
727    /// Get execution statistics
728    pub fn get_statistics(&self) -> DebugStatistics {
729        DebugStatistics {
730            total_steps: self.execution_count,
731            total_execution_time: self.total_execution_time,
732            breakpoints_hit: 0,   // Would be tracked separately
733            watches_triggered: 0, // Would be tracked separately
734        }
735    }
736
737    /// Get detailed operation statistics
738    pub fn get_operation_statistics(&self) -> &HashMap<String, OperationStatistics> {
739        &self.operation_stats
740    }
741
742    /// Get timing information for a specific operation
743    pub fn get_operation_timings(&self, operation: &str) -> Option<&Vec<Duration>> {
744        self.instruction_timings.get(operation)
745    }
746
747    /// Reset all statistics and timing information
748    pub fn reset_statistics(&mut self) {
749        self.execution_count = 0;
750        self.total_execution_time = Duration::new(0, 0);
751        self.instruction_timings.clear();
752        self.operation_stats.clear();
753    }
754
755    /// Get the configuration
756    pub fn config(&self) -> &DebuggerConfig {
757        &self.config
758    }
759
760    /// Update the configuration
761    pub fn update_config(&mut self, config: DebuggerConfig) {
762        self.config = config;
763    }
764}
765
#[cfg(test)]
mod tests {
    use super::*;

    /// Build an F32 tensor result from raw data and dimensions.
    fn tensor(data: Vec<f32>, dims: Vec<usize>) -> NodeExecutionResult {
        NodeExecutionResult {
            data,
            shape: Shape::new(dims),
            dtype: DType::F32,
        }
    }

    /// Fresh engine with the default debugger configuration.
    fn new_engine() -> DebugExecutionEngine {
        DebugExecutionEngine::new(DebuggerConfig::default())
    }

    #[test]
    fn test_debug_execution_engine_creation() {
        let engine = new_engine();

        assert_eq!(engine.execution_count, 0);
        assert_eq!(engine.total_execution_time, Duration::new(0, 0));
        assert!(engine.instruction_timings.is_empty());
        assert!(engine.operation_stats.is_empty());
    }

    #[test]
    fn test_operation_timing_recording() {
        let mut engine = new_engine();

        let duration = Duration::from_millis(10);
        engine.record_operation_timing("add", duration);

        assert_eq!(engine.instruction_timings.get("add").unwrap().len(), 1);
        assert!(engine.operation_stats.contains_key("add"));

        let stats = &engine.operation_stats["add"];
        assert_eq!(stats.count, 1);
        assert_eq!(stats.total_time, duration);
        assert_eq!(stats.min_time, duration);
        assert_eq!(stats.max_time, duration);
    }

    #[test]
    fn test_operation_statistics_aggregate_multiple_samples() {
        let mut engine = new_engine();

        for &ms in &[10u64, 30, 20] {
            engine.record_operation_timing("matmul", Duration::from_millis(ms));
        }
        engine.record_operation_timing("add", Duration::from_millis(5));

        let stats = &engine.operation_stats["matmul"];
        assert_eq!(stats.count, 3);
        assert_eq!(stats.total_time, Duration::from_millis(60));
        assert_eq!(stats.average_time, Duration::from_millis(20));
        assert_eq!(stats.min_time, Duration::from_millis(10));
        assert_eq!(stats.max_time, Duration::from_millis(30));
        assert_eq!(engine.instruction_timings["matmul"].len(), 3);

        // Different operation names are tracked independently.
        assert_eq!(engine.operation_stats["add"].count, 1);
        assert_eq!(engine.operation_stats["add"].average_time, Duration::from_millis(5));
    }

    #[test]
    fn test_add_operation() {
        let engine = new_engine();

        let result = engine
            .execute_add_operation(&[
                tensor(vec![1.0, 2.0, 3.0], vec![3]),
                tensor(vec![4.0, 5.0, 6.0], vec![3]),
            ])
            .unwrap();
        assert_eq!(result.data, vec![5.0, 7.0, 9.0]);
        assert_eq!(result.shape.dims(), &[3]);
    }

    #[test]
    fn test_multiply_and_subtract_operations() {
        let engine = new_engine();

        let product = engine
            .execute_multiply_operation(&[
                tensor(vec![1.0, 2.0, 3.0], vec![3]),
                tensor(vec![4.0, 5.0, 6.0], vec![3]),
            ])
            .unwrap();
        assert_eq!(product.data, vec![4.0, 10.0, 18.0]);

        let difference = engine
            .execute_subtract_operation(&[
                tensor(vec![4.0, 5.0, 6.0], vec![3]),
                tensor(vec![1.0, 2.0, 3.0], vec![3]),
            ])
            .unwrap();
        assert_eq!(difference.data, vec![3.0, 3.0, 3.0]);
    }

    #[test]
    fn test_divide_operation() {
        let engine = new_engine();

        let result = engine
            .execute_divide_operation(&[
                tensor(vec![6.0, 1.0, -3.0], vec![3]),
                tensor(vec![3.0, 0.0, 2.0], vec![3]),
            ])
            .unwrap();
        assert_eq!(result.data, vec![2.0, f32::INFINITY, -1.5]);
    }

    #[test]
    fn test_binary_shape_mismatch_is_rejected() {
        let engine = new_engine();

        let outcome = engine.execute_divide_operation(&[
            tensor(vec![1.0, 2.0], vec![2]),
            tensor(vec![1.0, 2.0, 3.0], vec![3]),
        ]);
        match outcome {
            Err(JitError::RuntimeError(msg)) => {
                assert_eq!(msg, "Shape mismatch in divide operation")
            }
            _ => panic!("expected a RuntimeError for mismatched shapes"),
        }
    }

    #[test]
    fn test_relu_operation() {
        let engine = new_engine();

        let input = tensor(vec![-1.0, 0.0, 1.0, -2.0, 3.0], vec![5]);
        let result = engine.execute_relu_operation(&[input]).unwrap();
        assert_eq!(result.data, vec![0.0, 0.0, 1.0, 0.0, 3.0]);
    }

    #[test]
    fn test_unary_arity_is_checked() {
        let engine = new_engine();

        let outcome = engine.execute_relu_operation(&[
            tensor(vec![1.0], vec![1]),
            tensor(vec![2.0], vec![1]),
        ]);
        match outcome {
            Err(JitError::RuntimeError(msg)) => {
                assert_eq!(msg, "ReLU operation requires exactly 1 input")
            }
            _ => panic!("expected a RuntimeError for two ReLU inputs"),
        }
    }

    #[test]
    fn test_sigmoid_and_tanh_at_zero() {
        let engine = new_engine();

        let sigmoid = engine
            .execute_sigmoid_operation(&[tensor(vec![0.0], vec![1])])
            .unwrap();
        assert_eq!(sigmoid.data, vec![0.5]);

        let tanh = engine
            .execute_tanh_operation(&[tensor(vec![0.0], vec![1])])
            .unwrap();
        assert_eq!(tanh.data, vec![0.0]);
    }

    #[test]
    fn test_matmul_operation() {
        let engine = new_engine();

        let result = engine
            .execute_matmul_operation(&[
                tensor(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]),
                tensor(vec![5.0, 6.0, 7.0, 8.0], vec![2, 2]),
            ])
            .unwrap();
        assert_eq!(result.data, vec![19.0, 22.0, 43.0, 50.0]);
        assert_eq!(result.shape.dims(), &[2, 2]);

        let mismatched = engine.execute_matmul_operation(&[
            tensor(vec![1.0; 6], vec![2, 3]),
            tensor(vec![1.0; 4], vec![2, 2]),
        ]);
        assert!(mismatched.is_err());
    }

    #[test]
    fn test_transpose_operation() {
        let engine = new_engine();

        let result = engine
            .execute_transpose_operation(&[tensor(
                vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
                vec![2, 3],
            )])
            .unwrap();
        assert_eq!(result.data, vec![1.0, 4.0, 2.0, 5.0, 3.0, 6.0]);
        assert_eq!(result.shape.dims(), &[3, 2]);
    }

    #[test]
    fn test_concat_operation_along_axis_zero() {
        let engine = new_engine();
        let attributes = HashMap::new();

        let result = engine
            .execute_concat_operation(
                &[
                    tensor(vec![1.0, 2.0], vec![2]),
                    tensor(vec![3.0], vec![1]),
                ],
                &attributes,
            )
            .unwrap();
        assert_eq!(result.data, vec![1.0, 2.0, 3.0]);
        assert_eq!(result.shape.dims(), &[3]);
    }

    #[test]
    fn test_statistics_reset() {
        let mut engine = new_engine();

        engine.record_operation_timing("test", Duration::from_millis(10));
        engine.execution_count = 5;
        engine.total_execution_time = Duration::from_millis(50);

        engine.reset_statistics();

        assert_eq!(engine.execution_count, 0);
        assert_eq!(engine.total_execution_time, Duration::new(0, 0));
        assert!(engine.instruction_timings.is_empty());
        assert!(engine.operation_stats.is_empty());
    }
}