1use super::core::{
7 ContinueResult, DebugState, DebugStatistics, DebugValue, DebuggerConfig,
8 DisassemblyInstruction, DisassemblyView, EvaluationResult, ExecutionLocation, ExecutionState,
9 ExecutionStep, InspectionResult, InspectionTarget, InstructionExecutionResult, MemoryView,
10 NodeExecutionResult, NodeMetadata, StepResult, TypeInfo,
11};
12use super::execution::DebugExecutionEngine;
13use super::state::{CallStack, MemoryState};
14use crate::{ir::IrModule, ComputationGraph, JitError, JitResult, NodeId};
15use std::collections::HashMap;
16use std::time::SystemTime;
17use torsh_core::{DType, Shape};
18
19pub struct DebugSession {
21 graph: Option<ComputationGraph>,
22 ir_module: Option<IrModule>,
23 current_location: ExecutionLocation,
24 execution_state: ExecutionState,
25 execution_trace: Vec<ExecutionStep>,
26 call_stack: CallStack,
27 variable_bindings: HashMap<String, DebugValue>,
28 memory_state: MemoryState,
29 statistics: DebugStatistics,
30 config: DebuggerConfig,
31 execution_engine: DebugExecutionEngine,
32}
33
34impl DebugSession {
35 pub fn new(graph: ComputationGraph, config: DebuggerConfig) -> Self {
41 let initial_location = ExecutionLocation::GraphNode(
42 graph
43 .nodes()
44 .next()
45 .map(|(id, _)| id)
46 .unwrap_or(NodeId::new(0)),
47 );
48
49 let execution_engine = DebugExecutionEngine::new(config.clone());
50
51 Self {
52 graph: Some(graph),
53 ir_module: None,
54 current_location: initial_location,
55 execution_state: ExecutionState::new(),
56 execution_trace: Vec::new(),
57 call_stack: CallStack::new(),
58 variable_bindings: HashMap::new(),
59 memory_state: MemoryState::new(),
60 statistics: DebugStatistics::new(),
61 config,
62 execution_engine,
63 }
64 }
65
66 pub fn from_ir(ir_module: IrModule, config: DebuggerConfig) -> Self {
72 let initial_location = ExecutionLocation::Instruction {
73 function: "main".to_string(),
74 instruction_index: 0,
75 };
76
77 let execution_engine = DebugExecutionEngine::new(config.clone());
78
79 Self {
80 graph: None,
81 ir_module: Some(ir_module),
82 current_location: initial_location,
83 execution_state: ExecutionState::new(),
84 execution_trace: Vec::new(),
85 call_stack: CallStack::new(),
86 variable_bindings: HashMap::new(),
87 memory_state: MemoryState::new(),
88 statistics: DebugStatistics::new(),
89 config,
90 execution_engine,
91 }
92 }
93
94 pub fn step(&mut self) -> JitResult<StepResult> {
99 if self.is_execution_complete() {
100 return Err(JitError::AnalysisError(
101 "Execution already completed".to_string(),
102 ));
103 }
104
105 let step_start = std::time::Instant::now();
106
107 let result = match self.current_location.clone() {
108 ExecutionLocation::GraphNode(node_id) => self.step_graph_node(node_id),
109 ExecutionLocation::Instruction {
110 function,
111 instruction_index,
112 } => self.step_ir_instruction(&function, instruction_index),
113 ExecutionLocation::Completed => {
114 return Err(JitError::AnalysisError(
115 "Execution already completed".to_string(),
116 ));
117 }
118 };
119
120 let step_duration = step_start.elapsed();
121 self.statistics.total_steps += 1;
122 self.statistics.total_execution_time += step_duration;
123
124 match result {
125 Ok(_) => {
126 if self.is_execution_complete() {
127 Ok(StepResult::Completed)
128 } else {
129 Ok(StepResult::Success)
130 }
131 }
132 Err(e) => Err(e),
133 }
134 }
135
136 pub fn step_over(&mut self) -> JitResult<StepResult> {
141 let current_call_depth = self.call_stack.depth();
142
143 loop {
144 let result = self.step()?;
145
146 match result {
147 StepResult::Completed => return Ok(StepResult::Completed),
148 StepResult::Success => {
149 if self.call_stack.depth() <= current_call_depth {
151 break;
152 }
153 }
154 }
155 }
156
157 Ok(StepResult::Success)
158 }
159
160 pub fn step_into(&mut self) -> JitResult<StepResult> {
165 self.step()
167 }
168
169 pub fn step_out(&mut self) -> JitResult<StepResult> {
174 let target_depth = self.call_stack.depth().saturating_sub(1);
175
176 while self.call_stack.depth() > target_depth {
177 let result = self.step()?;
178 if matches!(result, StepResult::Completed) {
179 return Ok(StepResult::Completed);
180 }
181 }
182
183 Ok(StepResult::Success)
184 }
185
186 pub fn continue_execution(&mut self) -> JitResult<ContinueResult> {
191 loop {
192 if self.should_break_at_current_location() {
194 return Ok(ContinueResult::Breakpoint);
195 }
196
197 match self.step() {
199 Ok(StepResult::Success) => {
200 if self.is_execution_complete() {
202 return Ok(ContinueResult::Completed);
203 }
204 }
205 Ok(StepResult::Completed) => {
206 return Ok(ContinueResult::Completed);
207 }
208 Err(e) => return Err(e),
209 }
210 }
211 }
212
213 fn step_graph_node(&mut self, node_id: NodeId) -> JitResult<()> {
215 if let Some(graph) = &self.graph {
216 if let Some(node) = graph.node(node_id) {
217 let step = ExecutionStep {
219 location: self.current_location.clone(),
220 timestamp: SystemTime::now(),
221 operation: node.operation_type().to_string(),
222 inputs: self.get_node_inputs(node_id),
223 outputs: Vec::new(), state_changes: HashMap::new(),
225 };
226
227 let result = self
229 .execution_engine
230 .execute_node_debug(node, graph, node_id)?;
231
232 self.variable_bindings.insert(
234 format!("node_{:?}", node_id),
235 DebugValue::Tensor {
236 data: result.data.clone(),
237 shape: result.shape.clone(),
238 dtype: result.dtype,
239 },
240 );
241
242 self.advance_to_next_graph_node(node_id)?;
244
245 let mut completed_step = step;
247 completed_step.outputs = vec![result];
248 self.execution_trace.push(completed_step);
249 }
250 }
251
252 Ok(())
253 }
254
255 fn step_ir_instruction(&mut self, function: &str, instruction_index: usize) -> JitResult<()> {
257 if let Some(ir_module) = self.ir_module.clone() {
258 if let Some(func) = ir_module.get_function(function) {
259 if let Some(instruction) = func.instructions().get(instruction_index) {
260 let step = ExecutionStep {
262 location: self.current_location.clone(),
263 timestamp: SystemTime::now(),
264 operation: format!("{:?}", instruction),
265 inputs: self.get_instruction_inputs(instruction),
266 outputs: Vec::new(),
267 state_changes: HashMap::new(),
268 };
269
270 let result = self.execution_engine.execute_instruction_debug(
272 instruction,
273 &ir_module,
274 &mut self.execution_state,
275 )?;
276
277 self.update_execution_state_from_instruction(instruction, result)?;
279
280 self.advance_to_next_instruction(function, instruction_index)?;
282
283 self.execution_trace.push(step);
285 }
286 }
287 }
288
289 Ok(())
290 }
291
292 fn get_node_inputs(&self, node_id: NodeId) -> Vec<NodeExecutionResult> {
294 if let Some(graph) = &self.graph {
296 let inputs = graph.get_node_inputs(node_id);
297 inputs
298 .iter()
299 .map(|&input_id| {
300 let var_name = format!("node_{:?}", input_id);
302 if let Some(DebugValue::Tensor { data, shape, dtype }) =
303 self.variable_bindings.get(&var_name)
304 {
305 NodeExecutionResult {
306 data: data.clone(),
307 shape: shape.clone(),
308 dtype: *dtype,
309 }
310 } else {
311 NodeExecutionResult {
313 data: vec![1.0, 2.0, 3.0],
314 shape: Shape::new(vec![3]),
315 dtype: DType::F32,
316 }
317 }
318 })
319 .collect()
320 } else {
321 Vec::new()
322 }
323 }
324
325 fn get_instruction_inputs(
327 &self,
328 instruction: &crate::ir::Instruction,
329 ) -> Vec<NodeExecutionResult> {
330 Vec::new()
332 }
333
334 fn advance_to_next_graph_node(&mut self, current_node: NodeId) -> JitResult<()> {
336 if let Some(graph) = &self.graph {
337 let successors: Vec<_> = graph
339 .edges()
340 .filter(|(source, _target, _edge)| *source == current_node)
341 .map(|(_source, target, _edge)| target)
342 .collect();
343
344 if let Some(next_node) = successors.first() {
345 self.current_location = ExecutionLocation::GraphNode(*next_node);
346 } else {
347 self.current_location = ExecutionLocation::Completed;
349 }
350 }
351 Ok(())
352 }
353
354 fn advance_to_next_instruction(
356 &mut self,
357 function: &str,
358 current_index: usize,
359 ) -> JitResult<()> {
360 if let Some(ir_module) = &self.ir_module {
361 if let Some(func) = ir_module.get_function(function) {
362 let next_index = current_index + 1;
363 if next_index < func.instructions().len() {
364 self.current_location = ExecutionLocation::Instruction {
365 function: function.to_string(),
366 instruction_index: next_index,
367 };
368 } else {
369 if !self.call_stack.is_empty() {
371 self.current_location = self.call_stack.pop();
372 } else {
373 self.current_location = ExecutionLocation::Completed;
374 }
375 }
376 }
377 }
378 Ok(())
379 }
380
381 fn update_execution_state_from_instruction(
383 &mut self,
384 instruction: &crate::ir::Instruction,
385 result: InstructionExecutionResult,
386 ) -> JitResult<()> {
387 match result {
388 InstructionExecutionResult::Value(value) => {
389 let var_name = format!("temp_{}", self.execution_trace.len());
391 self.variable_bindings.insert(var_name, value);
392 }
393 InstructionExecutionResult::SideEffect => {
394 }
397 InstructionExecutionResult::Return => {
398 if !self.call_stack.is_empty() {
400 let return_location = self.call_stack.pop();
401 self.current_location = return_location;
402 } else {
403 self.current_location = ExecutionLocation::Completed;
404 }
405 }
406 InstructionExecutionResult::NoOp => {
407 }
409 }
410 Ok(())
411 }
412
413 fn should_break_at_current_location(&self) -> bool {
415 false
418 }
419
420 pub fn is_execution_complete(&self) -> bool {
422 matches!(self.current_location, ExecutionLocation::Completed)
423 }
424
425 pub fn inspect_target(&self, target: &InspectionTarget) -> JitResult<InspectionResult> {
433 match target {
434 InspectionTarget::Variable(name) => {
435 if let Some(value) = self.variable_bindings.get(name) {
436 Ok(InspectionResult::Variable {
437 name: name.clone(),
438 value: value.clone(),
439 type_info: self.get_type_info_for_value(value),
440 })
441 } else {
442 Err(JitError::RuntimeError(format!(
443 "Variable '{}' not found",
444 name
445 )))
446 }
447 }
448 InspectionTarget::Node(node_id) => {
449 let var_name = format!("node_{:?}", node_id);
450 if let Some(value) = self.variable_bindings.get(&var_name) {
451 Ok(InspectionResult::Node {
452 node_id: *node_id,
453 value: value.clone(),
454 metadata: self.get_node_metadata(*node_id),
455 })
456 } else {
457 Err(JitError::RuntimeError(format!(
458 "Node {:?} not executed yet",
459 node_id
460 )))
461 }
462 }
463 InspectionTarget::Memory(address) => {
464 let memory_content = self.memory_state.read_memory(*address, 16)?;
465 Ok(InspectionResult::Memory {
466 address: *address,
467 content: memory_content,
468 size: 16,
469 })
470 }
471 }
472 }
473
474 fn get_type_info_for_value(&self, value: &DebugValue) -> TypeInfo {
476 match value {
477 DebugValue::Scalar(_) => TypeInfo {
478 type_name: "f64".to_string(),
479 size_bytes: 8,
480 alignment: 8,
481 },
482 DebugValue::Integer(_) => TypeInfo {
483 type_name: "i64".to_string(),
484 size_bytes: 8,
485 alignment: 8,
486 },
487 DebugValue::Boolean(_) => TypeInfo {
488 type_name: "bool".to_string(),
489 size_bytes: 1,
490 alignment: 1,
491 },
492 DebugValue::Tensor { dtype, shape, .. } => TypeInfo {
493 type_name: format!("Tensor<{:?}>", dtype),
494 size_bytes: shape.size(0).unwrap_or(1) * dtype.size_bytes(),
495 alignment: dtype.size_bytes(),
496 },
497 }
498 }
499
500 fn get_node_metadata(&self, node_id: NodeId) -> NodeMetadata {
502 if let Some(graph) = &self.graph {
503 if let Some(node) = graph.node(node_id) {
504 let input_count = graph.get_node_inputs(node_id).len();
505 return NodeMetadata {
506 operation: node.operation_type().to_string(),
507 input_count,
508 output_shape: node.output_shape.clone(),
509 dtype: node.dtype,
510 };
511 }
512 }
513
514 NodeMetadata {
515 operation: "unknown".to_string(),
516 input_count: 0,
517 output_shape: Shape::new(vec![]),
518 dtype: DType::F32,
519 }
520 }
521
522 pub fn evaluate_expression(&self, expression: &str) -> JitResult<EvaluationResult> {
530 if let Some(value) = self.variable_bindings.get(expression) {
532 Ok(EvaluationResult {
533 expression: expression.to_string(),
534 result: value.clone(),
535 success: true,
536 error_message: None,
537 })
538 } else if expression.starts_with("node_") {
539 if let Ok(node_index) = expression[5..].parse::<usize>() {
541 let var_name = format!("node_{}", node_index);
542 if let Some(value) = self.variable_bindings.get(&var_name) {
543 Ok(EvaluationResult {
544 expression: expression.to_string(),
545 result: value.clone(),
546 success: true,
547 error_message: None,
548 })
549 } else {
550 Ok(EvaluationResult {
551 expression: expression.to_string(),
552 result: DebugValue::Scalar(0.0),
553 success: false,
554 error_message: Some("Node not executed yet".to_string()),
555 })
556 }
557 } else {
558 Ok(EvaluationResult {
559 expression: expression.to_string(),
560 result: DebugValue::Scalar(0.0),
561 success: false,
562 error_message: Some("Invalid node reference".to_string()),
563 })
564 }
565 } else {
566 if let Ok(value) = expression.parse::<i64>() {
568 Ok(EvaluationResult {
569 expression: expression.to_string(),
570 result: DebugValue::Integer(value),
571 success: true,
572 error_message: None,
573 })
574 } else if let Ok(value) = expression.parse::<f64>() {
575 Ok(EvaluationResult {
576 expression: expression.to_string(),
577 result: DebugValue::Scalar(value),
578 success: true,
579 error_message: None,
580 })
581 } else if let Ok(value) = expression.parse::<bool>() {
582 Ok(EvaluationResult {
583 expression: expression.to_string(),
584 result: DebugValue::Boolean(value),
585 success: true,
586 error_message: None,
587 })
588 } else {
589 Ok(EvaluationResult {
590 expression: expression.to_string(),
591 result: DebugValue::Scalar(0.0),
592 success: false,
593 error_message: Some("Expression not found or invalid".to_string()),
594 })
595 }
596 }
597 }
598
599 pub fn get_current_state(&self) -> DebugState {
601 DebugState {
602 location: self.current_location.clone(),
603 call_stack: self.call_stack.clone(),
604 variables: self.variable_bindings.clone(),
605 execution_step: self.execution_trace.len(),
606 is_running: !self.is_execution_complete(),
607 }
608 }
609
610 pub fn get_execution_trace(&self) -> Vec<ExecutionStep> {
612 self.execution_trace.clone()
613 }
614
615 pub fn get_call_stack(&self) -> CallStack {
617 self.call_stack.clone()
618 }
619
620 pub fn get_local_variables(&self) -> HashMap<String, DebugValue> {
622 self.variable_bindings.clone()
624 }
625
626 pub fn get_memory_view(&self, address: u64) -> JitResult<MemoryView> {
628 let content = self.memory_state.read_memory(address, 64)?;
629 Ok(MemoryView {
630 start_address: address,
631 content,
632 size: 64,
633 })
634 }
635
636 pub fn disassemble_at(&self, location: ExecutionLocation) -> JitResult<DisassemblyView> {
638 match location {
639 ExecutionLocation::GraphNode(node_id) => {
640 if let Some(graph) = &self.graph {
641 if let Some(node) = graph.node(node_id) {
642 Ok(DisassemblyView {
643 location,
644 instructions: vec![DisassemblyInstruction {
645 address: node_id.index() as u64,
646 opcode: node.operation_type().to_string(),
647 operands: format!(
648 "inputs: {}",
649 graph.get_node_inputs(node_id).len()
650 ),
651 comment: Some(format!("Output shape: {:?}", node.output_shape)),
652 }],
653 })
654 } else {
655 Err(JitError::RuntimeError("Node not found".to_string()))
656 }
657 } else {
658 Err(JitError::RuntimeError("No graph available".to_string()))
659 }
660 }
661 ExecutionLocation::Instruction {
662 ref function,
663 instruction_index,
664 } => {
665 if let Some(ir_module) = &self.ir_module {
666 if let Some(func) = ir_module.get_function(function) {
667 if let Some(instruction) = func.instructions().get(instruction_index) {
668 Ok(DisassemblyView {
669 location,
670 instructions: vec![DisassemblyInstruction {
671 address: instruction_index as u64,
672 opcode: format!("{:?}", instruction.opcode),
673 operands: format!("operands: {}", instruction.operands.len()),
674 comment: Some(format!("result: {:?}", instruction.result)),
675 }],
676 })
677 } else {
678 Err(JitError::RuntimeError("Instruction not found".to_string()))
679 }
680 } else {
681 Err(JitError::RuntimeError("Function not found".to_string()))
682 }
683 } else {
684 Err(JitError::RuntimeError("No IR module available".to_string()))
685 }
686 }
687 ExecutionLocation::Completed => {
688 Err(JitError::RuntimeError("Execution completed".to_string()))
689 }
690 }
691 }
692
693 pub fn get_statistics(&self) -> DebugStatistics {
695 let mut stats = self.statistics.clone();
697 let engine_stats = self.execution_engine.get_statistics();
698
699 stats.total_steps = stats.total_steps.max(engine_stats.total_steps);
700 stats.total_execution_time = stats
701 .total_execution_time
702 .max(engine_stats.total_execution_time);
703
704 stats
705 }
706
707 pub fn reset(&mut self) {
709 self.execution_trace.clear();
710 self.call_stack.clear();
711 self.variable_bindings.clear();
712 self.memory_state.clear();
713 self.statistics = DebugStatistics::new();
714 self.execution_state = ExecutionState::new();
715
716 if let Some(graph) = &self.graph {
718 self.current_location = ExecutionLocation::GraphNode(
719 graph
720 .nodes()
721 .next()
722 .map(|(id, _)| id)
723 .unwrap_or(NodeId::new(0)),
724 );
725 } else if self.ir_module.is_some() {
726 self.current_location = ExecutionLocation::Instruction {
727 function: "main".to_string(),
728 instruction_index: 0,
729 };
730 } else {
731 self.current_location = ExecutionLocation::Completed;
732 }
733 }
734
735 pub fn config(&self) -> &DebuggerConfig {
737 &self.config
738 }
739
740 pub fn update_config(&mut self, config: DebuggerConfig) {
742 self.config = config.clone();
743 self.execution_engine.update_config(config);
744 }
745
746 pub fn set_variable(&mut self, name: String, value: DebugValue) {
748 self.variable_bindings.insert(name, value);
749 }
750
751 pub fn get_variable(&self, name: &str) -> Option<&DebugValue> {
753 self.variable_bindings.get(name)
754 }
755
756 pub fn get_execution_engine_statistics(
758 &self,
759 ) -> &std::collections::HashMap<String, super::execution::OperationStatistics> {
760 self.execution_engine.get_operation_statistics()
761 }
762
763 pub fn memory_state(&self) -> &MemoryState {
765 &self.memory_state
766 }
767
768 pub fn memory_state_mut(&mut self) -> &mut MemoryState {
770 &mut self.memory_state
771 }
772}
773
774impl super::watch::ExpressionEvaluator for DebugSession {
776 fn evaluate_expression(&self, expression: &str) -> JitResult<EvaluationResult> {
777 self.evaluate_expression(expression)
778 }
779}
780
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a session over an empty graph with the default configuration.
    fn create_test_session() -> DebugSession {
        let config = DebuggerConfig::default();
        let graph = ComputationGraph::new();
        DebugSession::new(graph, config)
    }

    #[test]
    fn test_session_creation() {
        let session = create_test_session();
        assert!(!session.is_execution_complete());
        assert_eq!(session.execution_trace.len(), 0);
        assert_eq!(session.variable_bindings.len(), 0);
    }

    #[test]
    fn test_variable_management() {
        let mut session = create_test_session();

        let value = DebugValue::Scalar(42.0);
        session.set_variable("test_var".to_string(), value.clone());

        assert_eq!(session.get_variable("test_var"), Some(&value));
        assert_eq!(session.get_variable("nonexistent"), None);
    }

    #[test]
    fn test_expression_evaluation() {
        let mut session = create_test_session();

        // Literals: floats, integers and booleans.
        let result = session.evaluate_expression("42.5").unwrap();
        assert!(result.success);
        assert!(matches!(result.result, DebugValue::Scalar(42.5)));

        let result = session.evaluate_expression("123").unwrap();
        assert!(result.success);
        assert!(matches!(result.result, DebugValue::Integer(123)));

        let result = session.evaluate_expression("true").unwrap();
        assert!(result.success);
        assert!(matches!(result.result, DebugValue::Boolean(true)));

        // Bound variables. 2.5 is used instead of 3.14, which clippy's
        // deny-by-default `approx_constant` lint rejects as an approximate PI.
        session.set_variable("x".to_string(), DebugValue::Scalar(2.5));
        let result = session.evaluate_expression("x").unwrap();
        assert!(result.success);
        assert!(matches!(result.result, DebugValue::Scalar(2.5)));

        // Unknown names are reported in the result, not as an `Err`.
        let result = session.evaluate_expression("unknown_var").unwrap();
        assert!(!result.success);
        assert!(result.error_message.is_some());
    }

    #[test]
    fn test_execution_state() {
        let session = create_test_session();
        let state = session.get_current_state();

        assert!(state.is_running);
        assert_eq!(state.execution_step, 0);
        assert!(state.call_stack.is_empty());
        assert!(state.variables.is_empty());
    }

    #[test]
    fn test_memory_operations() {
        let mut session = create_test_session();

        session
            .memory_state_mut()
            .write_memory(0x1000, &[1, 2, 3, 4])
            .unwrap();

        let memory_view = session.get_memory_view(0x1000).unwrap();
        assert_eq!(memory_view.start_address, 0x1000);
        assert_eq!(&memory_view.content[0..4], &[1, 2, 3, 4]);
    }

    #[test]
    fn test_session_reset() {
        let mut session = create_test_session();

        session.set_variable("test".to_string(), DebugValue::Scalar(1.0));
        session.statistics.total_steps = 10;

        session.reset();

        assert_eq!(session.variable_bindings.len(), 0);
        assert_eq!(session.execution_trace.len(), 0);
        assert_eq!(session.statistics.total_steps, 0);
        assert!(session.call_stack.is_empty());
    }

    #[test]
    fn test_configuration_update() {
        let mut session = create_test_session();

        let mut new_config = DebuggerConfig::default();
        new_config.max_trace_length = 5000;

        session.update_config(new_config.clone());
        assert_eq!(session.config().max_trace_length, 5000);
    }

    #[test]
    fn test_inspection_targets() {
        let mut session = create_test_session();

        session.set_variable("test_var".to_string(), DebugValue::Scalar(42.0));
        let result = session.inspect_target(&InspectionTarget::Variable("test_var".to_string()));
        assert!(result.is_ok());

        session
            .memory_state_mut()
            .write_memory(0x1000, &[1, 2, 3, 4])
            .unwrap();
        let result = session.inspect_target(&InspectionTarget::Memory(0x1000));
        assert!(result.is_ok());

        let result = session.inspect_target(&InspectionTarget::Variable("unknown".to_string()));
        assert!(result.is_err());
    }
}