1use crate::{ir::IrModule, ComputationGraph, IrFunction, JitError, JitResult, NodeId};
10use std::collections::{HashMap, HashSet, VecDeque};
11use torsh_core::{DType, Shape};
12
/// Drives partial evaluation: owns one instance of each optimization pass
/// and applies them (per `config`) to computation graphs and IR modules.
pub struct PartialEvaluator {
    /// Toggles and tuning knobs for the individual passes.
    config: PartialEvalConfig,
    /// Folds nodes whose inputs are compile-time constants.
    constant_folder: ConstantFolder,
    /// Creates specialized function versions for known parameter values.
    specializer: FunctionSpecializer,
    /// Removes nodes unreachable from any graph output.
    dead_code_eliminator: DeadCodeEliminator,
    /// Unrolls small loops with statically known trip counts.
    loop_optimizer: LoopOptimizer,
    /// Collects shape/type facts consumed by the other passes.
    symbolic_executor: SymbolicExecutor,
}
22
23impl PartialEvaluator {
24 pub fn new(config: PartialEvalConfig) -> Self {
26 Self {
27 constant_folder: ConstantFolder::new(),
28 specializer: FunctionSpecializer::new(),
29 dead_code_eliminator: DeadCodeEliminator::new(),
30 loop_optimizer: LoopOptimizer::new(),
31 symbolic_executor: SymbolicExecutor::new(),
32 config,
33 }
34 }
35
36 pub fn evaluate_graph(&mut self, graph: &ComputationGraph) -> JitResult<OptimizedGraph> {
38 let mut working_graph = graph.clone();
39 let mut statistics = EvaluationStatistics::new();
40
41 let symbolic_info = self.symbolic_executor.execute(&working_graph)?;
43 statistics.symbolic_execution_time = symbolic_info.execution_time;
44
45 if self.config.enable_constant_folding {
47 let fold_result = self.constant_folder.fold_constants(&mut working_graph)?;
48 statistics.constants_folded = fold_result.constants_folded;
49 statistics.constant_folding_time = fold_result.execution_time;
50 }
51
52 if self.config.enable_specialization {
54 let spec_result = self
55 .specializer
56 .specialize_functions(&mut working_graph, &symbolic_info)?;
57 statistics.functions_specialized = spec_result.functions_specialized;
58 statistics.specialization_time = spec_result.execution_time;
59 }
60
61 if self.config.enable_dead_code_elimination {
63 let dce_result = self.dead_code_eliminator.eliminate(&mut working_graph)?;
64 statistics.dead_nodes_removed = dce_result.nodes_removed;
65 statistics.dead_code_elimination_time = dce_result.execution_time;
66 }
67
68 if self.config.enable_loop_optimization {
70 let loop_result = self.loop_optimizer.optimize_loops(&mut working_graph)?;
71 statistics.loops_optimized = loop_result.loops_optimized;
72 statistics.loop_optimization_time = loop_result.execution_time;
73 }
74
75 Ok(OptimizedGraph {
76 graph: working_graph,
77 statistics,
78 optimizations_applied: self.get_applied_optimizations(),
79 })
80 }
81
82 pub fn evaluate_ir(&mut self, ir_module: &IrModule) -> JitResult<OptimizedIrModule> {
84 let mut working_module = ir_module.clone();
85 let mut statistics = IrEvaluationStatistics::new();
86
87 let func_result = self.evaluate_function(&mut working_module)?;
89 statistics.merge(func_result);
90
91 self.optimize_module(&mut working_module)?;
93
94 Ok(OptimizedIrModule {
95 module: working_module,
96 statistics,
97 })
98 }
99
100 fn evaluate_function(
102 &mut self,
103 function: &mut IrFunction,
104 ) -> JitResult<IrEvaluationStatistics> {
105 let mut stats = IrEvaluationStatistics::new();
106
107 let deps = self.build_dependency_graph(function)?;
109
110 let data_flow = self.analyze_data_flow(function, &deps)?;
112
113 let const_result = self.propagate_constants(function, &data_flow)?;
115 stats.constants_propagated = const_result.constants_propagated;
116
117 let dead_result = self.eliminate_dead_instructions(function, &deps)?;
119 stats.dead_instructions_removed = dead_result.instructions_removed;
120
121 let strength_result = self.perform_strength_reduction(function)?;
123 stats.strength_reductions = strength_result.reductions_applied;
124
125 Ok(stats)
126 }
127
128 fn build_dependency_graph(&self, function: &IrFunction) -> JitResult<DependencyGraph> {
130 let mut deps = DependencyGraph::new();
131
132 for (idx, instruction) in function.instructions().enumerate() {
133 let inst_id = InstructionId(idx);
134 deps.add_instruction(inst_id, instruction.clone());
135
136 for operand in instruction.operands() {
138 let dep_id = InstructionId(operand.0 as usize);
141 deps.add_dependency(inst_id, dep_id);
142 }
143 }
144
145 Ok(deps)
146 }
147
148 fn analyze_data_flow(
150 &self,
151 function: &IrFunction,
152 deps: &DependencyGraph,
153 ) -> JitResult<DataFlowInfo> {
154 let mut data_flow = DataFlowInfo::new();
155
156 let mut reaching_defs: HashMap<InstructionId, HashSet<InstructionId>> = HashMap::new();
158 for (inst_id, instruction) in deps.instructions() {
159 let mut defs = HashSet::new();
160
161 for dep_id in deps.dependencies(inst_id) {
163 if let Some(dep_defs) = reaching_defs.get(dep_id) {
164 defs.extend(dep_defs.iter().cloned());
165 }
166 }
167
168 if instruction.produces_value() {
170 defs.insert(*inst_id);
171 }
172
173 reaching_defs.insert(*inst_id, defs);
174 }
175
176 data_flow.reaching_definitions = reaching_defs;
177
178 let mut live_vars: HashMap<InstructionId, HashSet<InstructionId>> = HashMap::new();
180 let instructions: Vec<_> = deps.instructions().collect();
181
182 for (inst_id, instruction) in instructions.iter().rev() {
183 let mut live = HashSet::new();
184
185 for user_id in deps.users(inst_id) {
187 if let Some(user_live) = live_vars.get(user_id) {
188 live.extend(user_live.iter().cloned());
189 }
190 }
191
192 if instruction.produces_value() {
194 live.remove(inst_id);
195 }
196
197 for operand in instruction.operands() {
199 let op_id = InstructionId(operand.0 as usize);
201 live.insert(op_id);
202 }
203
204 live_vars.insert(**inst_id, live);
205 }
206
207 data_flow.live_variables = live_vars;
208
209 Ok(data_flow)
210 }
211
212 fn propagate_constants(
214 &mut self,
215 function: &mut IrFunction,
216 data_flow: &DataFlowInfo,
217 ) -> JitResult<ConstantPropagationResult> {
218 let mut constants_propagated = 0;
219 let mut constant_values: HashMap<crate::ir::IrValue, ConstantValue> = HashMap::new();
220
221 for (idx, instruction) in function.instructions_mut().enumerate() {
222 let inst_id = crate::ir::IrValue(idx as u32);
223
224 let mut all_constant = true;
226 let mut operand_values = Vec::new();
227
228 for operand in instruction.operands() {
229 if let Some(const_val) = constant_values.get(operand) {
231 operand_values.push(Some(const_val.clone()));
232 } else {
233 operand_values.push(None);
234 all_constant = false;
235 }
236 }
237
238 if all_constant && self.can_evaluate_at_compile_time(instruction) {
240 if let Some(result) = self.evaluate_instruction(instruction, &operand_values)? {
241 constant_values.insert(inst_id, result.clone());
242
243 constants_propagated += 1;
247 }
248 }
249 }
250
251 Ok(ConstantPropagationResult {
252 constants_propagated,
253 execution_time: std::time::Duration::from_millis(1), })
255 }
256
257 fn can_evaluate_at_compile_time(&self, instruction: &crate::ir::Instruction) -> bool {
259 use crate::ir::IrOpcode;
260 match instruction.opcode {
261 IrOpcode::Add
262 | IrOpcode::Sub
263 | IrOpcode::Mul
264 | IrOpcode::Div
265 | IrOpcode::Neg
266 | IrOpcode::Sqrt
267 | IrOpcode::Exp
268 | IrOpcode::Log => true,
269 _ => false,
270 }
271 }
272
273 fn evaluate_instruction(
275 &self,
276 instruction: &crate::ir::Instruction,
277 operands: &[Option<ConstantValue>],
278 ) -> JitResult<Option<ConstantValue>> {
279 use crate::ir::IrOpcode;
280 match instruction.opcode {
281 IrOpcode::Add => {
282 if let (Some(Some(a)), Some(Some(b))) = (operands.get(0), operands.get(1)) {
283 Ok(Some(self.add_constants(a, b)?))
284 } else {
285 Ok(None)
286 }
287 }
288 IrOpcode::Sub => {
289 if let (Some(Some(a)), Some(Some(b))) = (operands.get(0), operands.get(1)) {
290 Ok(Some(self.sub_constants(a, b)?))
291 } else {
292 Ok(None)
293 }
294 }
295 IrOpcode::Mul => {
296 if let (Some(Some(a)), Some(Some(b))) = (operands.get(0), operands.get(1)) {
297 Ok(Some(self.mul_constants(a, b)?))
298 } else {
299 Ok(None)
300 }
301 }
302 IrOpcode::Div => {
303 if let (Some(Some(a)), Some(Some(b))) = (operands.get(0), operands.get(1)) {
304 Ok(Some(self.div_constants(a, b)?))
305 } else {
306 Ok(None)
307 }
308 }
309 IrOpcode::Neg => {
310 if let Some(Some(a)) = operands.get(0) {
311 Ok(Some(self.neg_constant(a)?))
312 } else {
313 Ok(None)
314 }
315 }
316 _ => Ok(None),
317 }
318 }
319
320 fn add_constants(&self, a: &ConstantValue, b: &ConstantValue) -> JitResult<ConstantValue> {
322 match (a, b) {
323 (ConstantValue::Float32(a), ConstantValue::Float32(b)) => {
324 Ok(ConstantValue::Float32(a + b))
325 }
326 (ConstantValue::Float64(a), ConstantValue::Float64(b)) => {
327 Ok(ConstantValue::Float64(a + b))
328 }
329 (ConstantValue::Int32(a), ConstantValue::Int32(b)) => Ok(ConstantValue::Int32(a + b)),
330 (ConstantValue::Int64(a), ConstantValue::Int64(b)) => Ok(ConstantValue::Int64(a + b)),
331 _ => Err(JitError::CompilationError(
332 "Incompatible types for addition".to_string(),
333 )),
334 }
335 }
336
337 fn sub_constants(&self, a: &ConstantValue, b: &ConstantValue) -> JitResult<ConstantValue> {
339 match (a, b) {
340 (ConstantValue::Float32(a), ConstantValue::Float32(b)) => {
341 Ok(ConstantValue::Float32(a - b))
342 }
343 (ConstantValue::Float64(a), ConstantValue::Float64(b)) => {
344 Ok(ConstantValue::Float64(a - b))
345 }
346 (ConstantValue::Int32(a), ConstantValue::Int32(b)) => Ok(ConstantValue::Int32(a - b)),
347 (ConstantValue::Int64(a), ConstantValue::Int64(b)) => Ok(ConstantValue::Int64(a - b)),
348 _ => Err(JitError::CompilationError(
349 "Incompatible types for subtraction".to_string(),
350 )),
351 }
352 }
353
354 fn mul_constants(&self, a: &ConstantValue, b: &ConstantValue) -> JitResult<ConstantValue> {
356 match (a, b) {
357 (ConstantValue::Float32(a), ConstantValue::Float32(b)) => {
358 Ok(ConstantValue::Float32(a * b))
359 }
360 (ConstantValue::Float64(a), ConstantValue::Float64(b)) => {
361 Ok(ConstantValue::Float64(a * b))
362 }
363 (ConstantValue::Int32(a), ConstantValue::Int32(b)) => Ok(ConstantValue::Int32(a * b)),
364 (ConstantValue::Int64(a), ConstantValue::Int64(b)) => Ok(ConstantValue::Int64(a * b)),
365 _ => Err(JitError::CompilationError(
366 "Incompatible types for multiplication".to_string(),
367 )),
368 }
369 }
370
371 fn div_constants(&self, a: &ConstantValue, b: &ConstantValue) -> JitResult<ConstantValue> {
373 match (a, b) {
374 (ConstantValue::Float32(a), ConstantValue::Float32(b)) => {
375 if *b == 0.0 {
376 Err(JitError::CompilationError("Division by zero".to_string()))
377 } else {
378 Ok(ConstantValue::Float32(a / b))
379 }
380 }
381 (ConstantValue::Float64(a), ConstantValue::Float64(b)) => {
382 if *b == 0.0 {
383 Err(JitError::CompilationError("Division by zero".to_string()))
384 } else {
385 Ok(ConstantValue::Float64(a / b))
386 }
387 }
388 (ConstantValue::Int32(a), ConstantValue::Int32(b)) => {
389 if *b == 0 {
390 Err(JitError::CompilationError("Division by zero".to_string()))
391 } else {
392 Ok(ConstantValue::Int32(a / b))
393 }
394 }
395 (ConstantValue::Int64(a), ConstantValue::Int64(b)) => {
396 if *b == 0 {
397 Err(JitError::CompilationError("Division by zero".to_string()))
398 } else {
399 Ok(ConstantValue::Int64(a / b))
400 }
401 }
402 _ => Err(JitError::CompilationError(
403 "Incompatible types for division".to_string(),
404 )),
405 }
406 }
407
408 fn neg_constant(&self, a: &ConstantValue) -> JitResult<ConstantValue> {
410 match a {
411 ConstantValue::Float32(a) => Ok(ConstantValue::Float32(-a)),
412 ConstantValue::Float64(a) => Ok(ConstantValue::Float64(-a)),
413 ConstantValue::Int32(a) => Ok(ConstantValue::Int32(-a)),
414 ConstantValue::Int64(a) => Ok(ConstantValue::Int64(-a)),
415 _ => Err(JitError::CompilationError(
416 "Cannot negate this constant type".to_string(),
417 )),
418 }
419 }
420
421 fn eliminate_dead_instructions(
423 &mut self,
424 function: &mut IrFunction,
425 deps: &DependencyGraph,
426 ) -> JitResult<DeadInstructionResult> {
427 let mut instructions_removed = 0;
428 let mut to_remove = HashSet::new();
429
430 for (inst_id, _) in deps.instructions() {
432 if deps.users(inst_id).is_empty()
433 && !self.has_side_effects(deps.get_instruction(inst_id))
434 {
435 to_remove.insert(*inst_id);
436 }
437 }
438
439 function.retain_instructions(|idx, _| {
441 let inst_id = InstructionId(idx);
442 if to_remove.contains(&inst_id) {
443 instructions_removed += 1;
444 false
445 } else {
446 true
447 }
448 });
449
450 Ok(DeadInstructionResult {
451 instructions_removed,
452 execution_time: std::time::Duration::from_millis(1), })
454 }
455
456 fn has_side_effects(&self, instruction: &crate::ir::Instruction) -> bool {
458 use crate::ir::IrOpcode;
459 match instruction.opcode {
460 IrOpcode::Store | IrOpcode::Call => true,
461 _ => false,
462 }
463 }
464
465 fn perform_strength_reduction(
467 &mut self,
468 ir_module: &mut crate::ir::IrModule,
469 ) -> JitResult<StrengthReductionResult> {
470 let mut reductions_applied = 0;
471
472 for (_block_id, block) in ir_module.blocks.iter_mut() {
474 for instruction in &mut block.instructions {
475 use crate::ir::IrOpcode;
476 match instruction.opcode {
477 IrOpcode::Mul => {
479 reductions_applied += 1;
483 }
484 IrOpcode::Div => {
486 reductions_applied += 1;
488 }
489 _ => {}
490 }
491 }
492 }
493
494 Ok(StrengthReductionResult {
495 reductions_applied,
496 execution_time: std::time::Duration::from_millis(1), })
498 }
499
500 fn optimize_module(&mut self, module: &mut IrModule) -> JitResult<()> {
502 let _ = module.remove_unused_functions();
504
505 if self.config.enable_inlining {
507 module.inline_small_functions()?;
508 }
509
510 Ok(())
511 }
512
513 fn get_applied_optimizations(&self) -> Vec<OptimizationType> {
515 let mut optimizations = Vec::new();
516
517 if self.config.enable_constant_folding {
518 optimizations.push(OptimizationType::ConstantFolding);
519 }
520 if self.config.enable_specialization {
521 optimizations.push(OptimizationType::FunctionSpecialization);
522 }
523 if self.config.enable_dead_code_elimination {
524 optimizations.push(OptimizationType::DeadCodeElimination);
525 }
526 if self.config.enable_loop_optimization {
527 optimizations.push(OptimizationType::LoopOptimization);
528 }
529
530 optimizations
531 }
532}
533
/// Feature switches and tuning knobs for [`PartialEvaluator`].
#[derive(Debug, Clone)]
pub struct PartialEvalConfig {
    /// Fold operations over compile-time constant inputs.
    pub enable_constant_folding: bool,
    /// Generate specialized function versions for known parameters.
    pub enable_specialization: bool,
    /// Remove nodes unreachable from graph outputs.
    pub enable_dead_code_elimination: bool,
    /// Unroll small loops with known trip counts.
    pub enable_loop_optimization: bool,
    /// Inline small functions during module optimization.
    pub enable_inlining: bool,
    /// Size cutoff below which a function is considered "small" for inlining.
    pub inline_threshold: usize,
    /// Maximum trip count a loop may have and still be unrolled.
    pub max_unroll_iterations: usize,
    /// Opt into more expensive/risky transformations.
    pub aggressive_optimization: bool,
}

impl Default for PartialEvalConfig {
    /// Everything on except aggressive mode; threshold 50, unroll limit 8.
    fn default() -> Self {
        Self {
            enable_constant_folding: true,
            enable_specialization: true,
            enable_dead_code_elimination: true,
            enable_loop_optimization: true,
            enable_inlining: true,
            inline_threshold: 50,
            max_unroll_iterations: 8,
            aggressive_optimization: false,
        }
    }
}
561
/// Graph-level constant-folding pass.
pub struct ConstantFolder {
    /// Nesting depth of the current evaluation.
    /// NOTE(review): initialized to 0 but never read or updated — confirm
    /// whether a recursion limit was intended here.
    evaluation_depth: usize,
}
566
567impl ConstantFolder {
568 pub fn new() -> Self {
569 Self {
570 evaluation_depth: 0,
571 }
572 }
573
574 pub fn fold_constants(
575 &mut self,
576 graph: &mut ComputationGraph,
577 ) -> JitResult<ConstantFoldingResult> {
578 let mut constants_folded = 0;
579 let start_time = std::time::Instant::now();
580
581 let mut constant_nodes = HashMap::new();
583 for (node_id, node) in graph.nodes() {
584 if self.is_constant_node(node) {
585 constant_nodes.insert(node_id, self.extract_constant_value(node)?);
586 }
587 }
588
589 let mut changed = true;
591 while changed {
592 changed = false;
593
594 let node_ids: Vec<_> = graph.nodes().map(|(id, _)| id).collect();
595 for node_id in node_ids {
596 if let Some(node) = graph.node(node_id).cloned() {
597 if !constant_nodes.contains_key(&node_id)
598 && self.can_fold_node(&node, &constant_nodes)
599 {
600 if let Ok(value) = self.evaluate_node(&node, &constant_nodes) {
601 constant_nodes.insert(node_id, value);
602 constants_folded += 1;
603 changed = true;
604 }
605 }
606 }
607 }
608 }
609
610 Ok(ConstantFoldingResult {
611 constants_folded,
612 execution_time: start_time.elapsed(),
613 })
614 }
615
616 fn is_constant_node(&self, node: &crate::graph::Node) -> bool {
617 matches!(node.op, crate::graph::Operation::Input)
619 }
620
621 fn extract_constant_value(&self, node: &crate::graph::Node) -> JitResult<ConstantValue> {
622 Ok(ConstantValue::Float32(0.0))
625 }
626
627 fn can_fold_node(
628 &self,
629 node: &crate::graph::Node,
630 constants: &HashMap<NodeId, ConstantValue>,
631 ) -> bool {
632 for input in &node.inputs {
634 if !constants.contains_key(input) {
635 return false;
636 }
637 }
638 true
639 }
640
641 fn evaluate_node(
642 &self,
643 node: &crate::graph::Node,
644 constants: &HashMap<NodeId, ConstantValue>,
645 ) -> JitResult<ConstantValue> {
646 Ok(ConstantValue::Float32(1.0))
649 }
650}
651
/// Creates specialized versions of functions for known parameter values.
pub struct FunctionSpecializer {
    /// Specialized variants keyed by the original function name.
    /// NOTE(review): populated nowhere in this file and never read —
    /// confirm whether caching/reuse is still to be implemented.
    specializations: HashMap<String, Vec<SpecializedFunction>>,
}
656
657impl FunctionSpecializer {
658 pub fn new() -> Self {
659 Self {
660 specializations: HashMap::new(),
661 }
662 }
663
664 pub fn specialize_functions(
665 &mut self,
666 graph: &mut ComputationGraph,
667 symbolic_info: &SymbolicExecutionInfo,
668 ) -> JitResult<SpecializationResult> {
669 let mut functions_specialized = 0;
670 let start_time = std::time::Instant::now();
671
672 for (node_id, node) in graph.nodes() {
674 if let Some(spec_params) = self.identify_specialization_opportunity(node, symbolic_info)
675 {
676 if self.should_specialize(node, &spec_params) {
677 self.create_specialized_version(node, spec_params)?;
678 functions_specialized += 1;
679 }
680 }
681 }
682
683 Ok(SpecializationResult {
684 functions_specialized,
685 execution_time: start_time.elapsed(),
686 })
687 }
688
689 fn identify_specialization_opportunity(
690 &self,
691 node: &crate::graph::Node,
692 symbolic_info: &SymbolicExecutionInfo,
693 ) -> Option<SpecializationParameters> {
694 None }
697
698 fn should_specialize(
699 &self,
700 node: &crate::graph::Node,
701 params: &SpecializationParameters,
702 ) -> bool {
703 true }
706
707 fn create_specialized_version(
708 &mut self,
709 node: &crate::graph::Node,
710 params: SpecializationParameters,
711 ) -> JitResult<()> {
712 Ok(()) }
715}
716
/// Stateless pass that removes graph nodes unreachable from any output.
pub struct DeadCodeEliminator;
719
720impl DeadCodeEliminator {
721 pub fn new() -> Self {
722 Self
723 }
724
725 pub fn eliminate(
726 &mut self,
727 graph: &mut ComputationGraph,
728 ) -> JitResult<DeadCodeEliminationResult> {
729 let mut nodes_removed = 0;
730 let start_time = std::time::Instant::now();
731
732 let mut reachable = HashSet::new();
734 let mut queue = VecDeque::new();
735
736 for (node_id, node) in graph.nodes() {
738 if node.is_output {
739 queue.push_back(node_id);
740 reachable.insert(node_id);
741 }
742 }
743
744 while let Some(node_id) = queue.pop_front() {
746 if let Some(node) = graph.node(node_id) {
747 for input_id in &node.inputs {
748 if !reachable.contains(input_id) {
749 reachable.insert(*input_id);
750 queue.push_back(*input_id);
751 }
752 }
753 }
754 }
755
756 let all_nodes: Vec<_> = graph.nodes().map(|(id, _)| id).collect();
758 for node_id in all_nodes {
759 if !reachable.contains(&node_id) {
760 let _ = graph.remove_node(node_id);
761 nodes_removed += 1;
762 }
763 }
764
765 Ok(DeadCodeEliminationResult {
766 nodes_removed,
767 execution_time: start_time.elapsed(),
768 })
769 }
770}
771
/// Stateless pass that unrolls small loops with known trip counts.
pub struct LoopOptimizer;
774
775impl LoopOptimizer {
776 pub fn new() -> Self {
777 Self
778 }
779
780 pub fn optimize_loops(
781 &mut self,
782 graph: &mut ComputationGraph,
783 ) -> JitResult<LoopOptimizationResult> {
784 let mut loops_optimized = 0;
785 let start_time = std::time::Instant::now();
786
787 let loops = self.detect_loops(graph)?;
789
790 for loop_info in loops {
792 if self.should_unroll(&loop_info) {
793 self.unroll_loop(graph, &loop_info)?;
794 loops_optimized += 1;
795 }
796 }
797
798 Ok(LoopOptimizationResult {
799 loops_optimized,
800 execution_time: start_time.elapsed(),
801 })
802 }
803
804 fn detect_loops(&self, graph: &ComputationGraph) -> JitResult<Vec<LoopInfo>> {
805 Ok(Vec::new()) }
808
809 fn should_unroll(&self, loop_info: &LoopInfo) -> bool {
810 loop_info.iteration_count.is_some()
812 && loop_info
813 .iteration_count
814 .expect("iteration count should be Some based on check")
815 <= 8
816 }
817
818 fn unroll_loop(&mut self, graph: &mut ComputationGraph, loop_info: &LoopInfo) -> JitResult<()> {
819 Ok(()) }
822}
823
/// Stateless pass that collects per-node shape/type facts for later passes.
pub struct SymbolicExecutor;
826
827impl SymbolicExecutor {
828 pub fn new() -> Self {
829 Self
830 }
831
832 pub fn execute(&mut self, graph: &ComputationGraph) -> JitResult<SymbolicExecutionInfo> {
833 let start_time = std::time::Instant::now();
834
835 let mut info = SymbolicExecutionInfo {
837 constant_values: HashMap::new(),
838 shape_information: HashMap::new(),
839 type_information: HashMap::new(),
840 execution_time: std::time::Duration::from_millis(0),
841 };
842
843 for (node_id, node) in graph.nodes() {
845 if let Some(shape) = self.infer_symbolic_shape(node) {
847 info.shape_information.insert(node_id, shape);
848 }
849
850 if let Some(dtype) = self.infer_symbolic_type(node) {
851 info.type_information.insert(node_id, dtype);
852 }
853 }
854
855 info.execution_time = start_time.elapsed();
856 Ok(info)
857 }
858
859 fn infer_symbolic_shape(&self, node: &crate::graph::Node) -> Option<SymbolicShape> {
860 None }
863
864 fn infer_symbolic_type(&self, node: &crate::graph::Node) -> Option<DType> {
865 Some(node.dtype)
867 }
868}
869
/// Position-based identifier of an instruction within a function; used as
/// the node key in `DependencyGraph` and the data-flow maps.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct InstructionId(usize);
874
/// Def/use graph over a function's instructions.
#[derive(Debug, Clone)]
pub struct DependencyGraph {
    /// Instruction bodies keyed by their position-derived id.
    instructions: HashMap<InstructionId, crate::ir::Instruction>,
    /// Forward edges: the instructions each instruction reads from.
    dependencies: HashMap<InstructionId, Vec<InstructionId>>,
    /// Reverse edges: the instructions consuming each instruction's result.
    users: HashMap<InstructionId, Vec<InstructionId>>,
}
881
882impl DependencyGraph {
883 pub fn new() -> Self {
884 Self {
885 instructions: HashMap::new(),
886 dependencies: HashMap::new(),
887 users: HashMap::new(),
888 }
889 }
890
891 pub fn add_instruction(&mut self, id: InstructionId, instruction: crate::ir::Instruction) {
892 self.instructions.insert(id, instruction);
893 self.dependencies.insert(id, Vec::new());
894 self.users.insert(id, Vec::new());
895 }
896
897 pub fn add_dependency(&mut self, user: InstructionId, dep: InstructionId) {
898 self.dependencies.entry(user).or_default().push(dep);
899 self.users.entry(dep).or_default().push(user);
900 }
901
902 pub fn instructions(&self) -> impl Iterator<Item = (&InstructionId, &crate::ir::Instruction)> {
903 self.instructions.iter()
904 }
905
906 pub fn dependencies(&self, id: &InstructionId) -> &[InstructionId] {
907 self.dependencies
908 .get(id)
909 .map(|v| v.as_slice())
910 .unwrap_or(&[])
911 }
912
913 pub fn users(&self, id: &InstructionId) -> &[InstructionId] {
914 self.users.get(id).map(|v| v.as_slice()).unwrap_or(&[])
915 }
916
917 pub fn get_instruction(&self, id: &InstructionId) -> &crate::ir::Instruction {
918 &self.instructions[id]
919 }
920}
921
/// Results of the intra-function data-flow analysis.
#[derive(Debug)]
pub struct DataFlowInfo {
    /// For each instruction, the definitions that reach it.
    pub reaching_definitions: HashMap<InstructionId, HashSet<InstructionId>>,
    /// For each instruction, the values live at that point.
    pub live_variables: HashMap<InstructionId, HashSet<InstructionId>>,
}
927
928impl DataFlowInfo {
929 pub fn new() -> Self {
930 Self {
931 reaching_definitions: HashMap::new(),
932 live_variables: HashMap::new(),
933 }
934 }
935}
936
/// A scalar value known at compile time.
///
/// All variants wrap `Copy` scalars, so the enum itself derives `Copy`;
/// `PartialEq` enables direct comparisons in folding and tests (note that
/// float comparison follows IEEE semantics, so `NaN != NaN`).
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum ConstantValue {
    Float32(f32),
    Float64(f64),
    Int32(i32),
    Int64(i64),
    Boolean(bool),
}
945
/// Facts gathered by [`SymbolicExecutor`] for consumption by later passes.
#[derive(Debug)]
pub struct SymbolicExecutionInfo {
    /// Nodes whose value is known at compile time.
    pub constant_values: HashMap<NodeId, ConstantValue>,
    /// Nodes with an inferred (possibly partially symbolic) shape.
    pub shape_information: HashMap<NodeId, SymbolicShape>,
    /// Nodes with a known element dtype.
    pub type_information: HashMap<NodeId, DType>,
    /// Wall-clock time spent in symbolic execution.
    pub execution_time: std::time::Duration,
}
953
/// A tensor shape whose dimensions may be constants or symbols.
#[derive(Debug)]
pub struct SymbolicShape {
    /// Per-axis dimension descriptions, outermost first.
    pub dimensions: Vec<SymbolicDimension>,
}
958
/// One axis of a [`SymbolicShape`].
#[derive(Debug)]
pub enum SymbolicDimension {
    /// A statically known extent.
    Constant(usize),
    /// A named free variable (e.g. a batch size).
    Variable(String),
    /// An extent derived from other dimensions, kept as source text.
    Expression(String),
}
965
/// The compile-time-known inputs a specialized function is built around.
#[derive(Debug)]
pub struct SpecializationParameters {
    /// Parameters fixed to constant values, by name.
    pub constant_params: HashMap<String, ConstantValue>,
    /// Parameters fixed to concrete shapes, by name.
    pub shape_params: HashMap<String, Shape>,
    /// Parameters fixed to concrete dtypes, by name.
    pub type_params: HashMap<String, DType>,
}
972
/// A specialized variant of a function, tracked by [`FunctionSpecializer`].
#[derive(Debug)]
pub struct SpecializedFunction {
    /// Name of the generic function this was derived from.
    pub original_name: String,
    /// Name under which the specialized version is registered.
    pub specialized_name: String,
    /// The fixed parameter values this variant assumes.
    pub parameters: SpecializationParameters,
    /// Predicted speedup over the generic version (ratio, 1.0 = none).
    pub estimated_speedup: f64,
}
980
/// A natural loop discovered in the computation graph.
#[derive(Debug)]
pub struct LoopInfo {
    /// The loop's entry node.
    pub header_node: NodeId,
    /// Edges from inside the loop back to the header.
    pub back_edges: Vec<(NodeId, NodeId)>,
    /// Trip count, when statically known (required for unrolling).
    pub iteration_count: Option<usize>,
    /// Nodes acting as induction variables.
    pub induction_variables: Vec<NodeId>,
}
988
/// Result of [`PartialEvaluator::evaluate_graph`].
#[derive(Debug)]
pub struct OptimizedGraph {
    /// The rewritten computation graph.
    pub graph: ComputationGraph,
    /// Per-pass counters and timings.
    pub statistics: EvaluationStatistics,
    /// Which passes were enabled for this run.
    pub optimizations_applied: Vec<OptimizationType>,
}
997
/// Result of [`PartialEvaluator::evaluate_ir`].
#[derive(Debug)]
pub struct OptimizedIrModule {
    /// The rewritten IR module.
    pub module: IrModule,
    /// Accumulated IR-level counters.
    pub statistics: IrEvaluationStatistics,
}
1003
/// Counters and timings for the graph-level optimization passes.
#[derive(Debug, Default)]
pub struct EvaluationStatistics {
    pub constants_folded: usize,
    pub functions_specialized: usize,
    pub dead_nodes_removed: usize,
    pub loops_optimized: usize,
    pub constant_folding_time: std::time::Duration,
    pub specialization_time: std::time::Duration,
    pub dead_code_elimination_time: std::time::Duration,
    pub loop_optimization_time: std::time::Duration,
    pub symbolic_execution_time: std::time::Duration,
}

impl EvaluationStatistics {
    /// All-zero statistics.
    pub fn new() -> Self {
        Self::default()
    }

    /// Accumulates `other` into `self`, field by field. Destructuring
    /// guarantees no field is silently left out if the struct grows.
    pub fn merge(&mut self, other: Self) {
        let Self {
            constants_folded,
            functions_specialized,
            dead_nodes_removed,
            loops_optimized,
            constant_folding_time,
            specialization_time,
            dead_code_elimination_time,
            loop_optimization_time,
            symbolic_execution_time,
        } = other;
        self.constants_folded += constants_folded;
        self.functions_specialized += functions_specialized;
        self.dead_nodes_removed += dead_nodes_removed;
        self.loops_optimized += loops_optimized;
        self.constant_folding_time += constant_folding_time;
        self.specialization_time += specialization_time;
        self.dead_code_elimination_time += dead_code_elimination_time;
        self.loop_optimization_time += loop_optimization_time;
        self.symbolic_execution_time += symbolic_execution_time;
    }
}
1034
/// Counters for the IR-level optimization passes.
#[derive(Debug, Default)]
pub struct IrEvaluationStatistics {
    pub constants_propagated: usize,
    pub dead_instructions_removed: usize,
    pub strength_reductions: usize,
}

impl IrEvaluationStatistics {
    /// All-zero statistics.
    pub fn new() -> Self {
        Self::default()
    }

    /// Accumulates `other` into `self`, field by field. Destructuring
    /// guarantees no field is silently left out if the struct grows.
    pub fn merge(&mut self, other: Self) {
        let Self {
            constants_propagated,
            dead_instructions_removed,
            strength_reductions,
        } = other;
        self.constants_propagated += constants_propagated;
        self.dead_instructions_removed += dead_instructions_removed;
        self.strength_reductions += strength_reductions;
    }
}
1053
/// Outcome of one [`ConstantFolder::fold_constants`] run.
#[derive(Debug)]
pub struct ConstantFoldingResult {
    /// Number of nodes folded to constants.
    pub constants_folded: usize,
    /// Wall-clock time spent in the pass.
    pub execution_time: std::time::Duration,
}
1059
/// Outcome of one [`FunctionSpecializer::specialize_functions`] run.
#[derive(Debug)]
pub struct SpecializationResult {
    /// Number of specialized versions created.
    pub functions_specialized: usize,
    /// Wall-clock time spent in the pass.
    pub execution_time: std::time::Duration,
}
1065
/// Outcome of one [`DeadCodeEliminator::eliminate`] run.
#[derive(Debug)]
pub struct DeadCodeEliminationResult {
    /// Number of unreachable nodes removed from the graph.
    pub nodes_removed: usize,
    /// Wall-clock time spent in the pass.
    pub execution_time: std::time::Duration,
}
1071
/// Outcome of one [`LoopOptimizer::optimize_loops`] run.
#[derive(Debug)]
pub struct LoopOptimizationResult {
    /// Number of loops unrolled.
    pub loops_optimized: usize,
    /// Wall-clock time spent in the pass.
    pub execution_time: std::time::Duration,
}
1077
/// Outcome of the IR-level constant-propagation pass.
#[derive(Debug)]
pub struct ConstantPropagationResult {
    /// Number of instruction results evaluated at compile time.
    pub constants_propagated: usize,
    /// Wall-clock time spent in the pass.
    pub execution_time: std::time::Duration,
}
1083
/// Outcome of the IR-level dead-instruction elimination pass.
#[derive(Debug)]
pub struct DeadInstructionResult {
    /// Number of instructions removed from the function.
    pub instructions_removed: usize,
    /// Wall-clock time spent in the pass.
    pub execution_time: std::time::Duration,
}
1089
/// Outcome of the IR-level strength-reduction pass.
#[derive(Debug)]
pub struct StrengthReductionResult {
    /// Number of reduction opportunities counted.
    pub reductions_applied: usize,
    /// Wall-clock time spent in the pass.
    pub execution_time: std::time::Duration,
}
1095
/// Identifies one optimization pass; reported in [`OptimizedGraph`].
///
/// A fieldless enum, so `Copy`; `PartialEq`/`Eq` allow callers to check
/// which passes ran without matching.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OptimizationType {
    ConstantFolding,
    FunctionSpecialization,
    DeadCodeElimination,
    LoopOptimization,
    ConstantPropagation,
    StrengthReduction,
}
1105
#[cfg(test)]
mod tests {
    use super::*;

    // All four graph-level passes are enabled by default.
    #[test]
    fn test_partial_eval_config() {
        let config = PartialEvalConfig::default();
        assert!(config.enable_constant_folding);
        assert!(config.enable_specialization);
        assert!(config.enable_dead_code_elimination);
        assert!(config.enable_loop_optimization);
    }

    // Compile-time addition of two Float32 constants.
    #[test]
    fn test_constant_value_operations() {
        let evaluator = PartialEvaluator::new(PartialEvalConfig::default());

        let a = ConstantValue::Float32(2.0);
        let b = ConstantValue::Float32(3.0);

        let result = evaluator.add_constants(&a, &b).unwrap();
        if let ConstantValue::Float32(val) = result {
            assert_eq!(val, 5.0);
        } else {
            panic!("Expected Float32 result");
        }
    }

    // A single dependency edge must be visible from both directions:
    // forward (dependencies) and reverse (users).
    #[test]
    fn test_dependency_graph() {
        let mut deps = DependencyGraph::new();
        let inst1 = InstructionId(0);
        let inst2 = InstructionId(1);

        use crate::ir::{Instruction, IrOpcode, IrValue};
        use std::collections::HashMap;
        let inst1_instruction = Instruction {
            result: Some(IrValue(0)),
            opcode: IrOpcode::Const,
            operands: vec![],
            attrs: HashMap::new(),
        };
        let inst2_instruction = Instruction {
            result: Some(IrValue(1)),
            opcode: IrOpcode::Const,
            operands: vec![],
            attrs: HashMap::new(),
        };
        deps.add_instruction(inst1, inst1_instruction);
        deps.add_instruction(inst2, inst2_instruction);
        deps.add_dependency(inst2, inst1);

        assert_eq!(deps.dependencies(&inst2), &[inst1]);
        assert_eq!(deps.users(&inst1), &[inst2]);
    }

    // With the default config, all four graph-level passes are reported.
    #[test]
    fn test_partial_evaluator_creation() {
        let config = PartialEvalConfig::default();
        let evaluator = PartialEvaluator::new(config);

        assert_eq!(evaluator.get_applied_optimizations().len(), 4);
    }
}
1170}