1use crate::graph::{ComputationGraph, Conv2dInfo, Edge, Node, NodeId, Operation};
4use crate::JitResult;
5use petgraph::visit::EdgeRef;
6use std::collections::{HashMap, HashSet};
7use torsh_core::Shape;
8
/// Runs a configurable sequence of optimization passes over a computation graph.
pub struct GraphOptimizer {
    // Passes are applied in order by `optimize`; order matters (e.g. constant
    // folding before CSE exposes more merge opportunities).
    passes: Vec<Box<dyn OptimizationPass>>,
}
13
14impl GraphOptimizer {
15 pub fn new() -> Self {
17 Self {
18 passes: vec![
19 Box::new(DeadCodeElimination),
20 Box::new(ConstantFolding),
21 Box::new(CommonSubexpressionElimination),
22 Box::new(AlgebraicSimplification),
23 Box::new(StrengthReduction),
24 Box::new(LayoutOptimization),
25 Box::new(CacheAwareOptimization::default()),
26 Box::new(AutoVectorization),
27 Box::new(AutoParallelization),
28 ],
29 }
30 }
31
32 pub fn with_passes(passes: Vec<Box<dyn OptimizationPass>>) -> Self {
34 Self { passes }
35 }
36
37 pub fn optimize(&self, mut graph: ComputationGraph) -> JitResult<ComputationGraph> {
39 for pass in &self.passes {
40 graph = pass.apply(graph)?;
41 }
42 Ok(graph)
43 }
44}
45
impl Default for GraphOptimizer {
    /// Equivalent to [`GraphOptimizer::new`]: the full default pass pipeline.
    fn default() -> Self {
        Self::new()
    }
}
51
/// A single graph-to-graph optimization step.
///
/// Implementations must be `Send + Sync` so an optimizer can be shared
/// across threads.
pub trait OptimizationPass: Send + Sync {
    /// Human-readable pass name, used for diagnostics.
    fn name(&self) -> &str;

    /// Consumes the graph and returns the (possibly) transformed graph.
    fn apply(&self, graph: ComputationGraph) -> JitResult<ComputationGraph>;
}
60
/// Removes nodes that cannot reach any graph output.
pub struct DeadCodeElimination;
63
64impl OptimizationPass for DeadCodeElimination {
65 fn name(&self) -> &str {
66 "DeadCodeElimination"
67 }
68
69 fn apply(&self, mut graph: ComputationGraph) -> JitResult<ComputationGraph> {
70 let mut reachable = HashSet::new();
72 let mut to_visit = graph.outputs.clone();
73
74 while let Some(node) = to_visit.pop() {
75 if reachable.insert(node) {
76 for pred in graph.predecessors(node) {
78 to_visit.push(pred);
79 }
80 }
81 }
82
83 let all_nodes: Vec<_> = graph.graph.node_indices().collect();
85 for node in all_nodes {
86 if !reachable.contains(&node) {
87 graph.graph.remove_node(node);
88 }
89 }
90
91 Ok(graph)
92 }
93}
94
/// Evaluates operations whose inputs are all compile-time constants.
pub struct ConstantFolding;
97
98impl OptimizationPass for ConstantFolding {
99 fn name(&self) -> &str {
100 "ConstantFolding"
101 }
102
103 fn apply(&self, mut graph: ComputationGraph) -> JitResult<ComputationGraph> {
104 let nodes: Vec<_> = graph.graph.node_indices().collect();
105
106 for node_id in nodes {
107 if let Some(node) = graph.node(node_id).cloned() {
108 if self.can_fold(&graph, node_id, &node) {
109 self.fold_node(&mut graph, node_id, &node)?;
110 }
111 }
112 }
113
114 Ok(graph)
115 }
116}
117
118impl ConstantFolding {
119 fn can_fold(&self, graph: &ComputationGraph, node_id: NodeId, node: &Node) -> bool {
120 match &node.op {
122 Operation::Add | Operation::Sub | Operation::Mul | Operation::Div => {
123 graph.predecessors(node_id).all(|pred_id| {
124 graph
125 .node(pred_id)
126 .map(|n| matches!(&n.op, Operation::Constant(_)))
127 .unwrap_or(false)
128 })
129 }
130 _ => false,
131 }
132 }
133
134 fn fold_node(
135 &self,
136 graph: &mut ComputationGraph,
137 node_id: NodeId,
138 node: &Node,
139 ) -> JitResult<()> {
140 use crate::graph::{ConstantInfo, ConstantValue, Operation};
141
142 let predecessors: Vec<_> = graph
144 .graph
145 .edges_directed(node_id, petgraph::Direction::Incoming)
146 .map(|edge_ref| edge_ref.source())
147 .collect();
148
149 if predecessors.is_empty() {
150 return Ok(());
151 }
152
153 let mut all_constants = true;
155 let mut constant_values = Vec::new();
156
157 for pred_id in &predecessors {
158 if let Some(pred_node) = graph.node(*pred_id) {
159 match &pred_node.op {
160 Operation::Constant(const_info) => {
161 constant_values.push(const_info.value.clone());
162 }
163 _ => {
164 all_constants = false;
165 break;
166 }
167 }
168 } else {
169 all_constants = false;
170 break;
171 }
172 }
173
174 if !all_constants {
175 return Ok(());
176 }
177
178 let folded_value = match (&node.op, constant_values.as_slice()) {
180 (Operation::Add, [ConstantValue::Scalar(a), ConstantValue::Scalar(b)]) => {
182 Some(ConstantValue::Scalar(a + b))
183 }
184 (Operation::Sub, [ConstantValue::Scalar(a), ConstantValue::Scalar(b)]) => {
185 Some(ConstantValue::Scalar(a - b))
186 }
187 (Operation::Mul, [ConstantValue::Scalar(a), ConstantValue::Scalar(b)]) => {
188 Some(ConstantValue::Scalar(a * b))
189 }
190 (Operation::Div, [ConstantValue::Scalar(a), ConstantValue::Scalar(b)]) => {
191 if *b != 0.0 {
192 Some(ConstantValue::Scalar(a / b))
193 } else {
194 None }
196 }
197
198 (Operation::Neg, [ConstantValue::Scalar(a)]) => Some(ConstantValue::Scalar(-a)),
200 (Operation::Abs, [ConstantValue::Scalar(a)]) => Some(ConstantValue::Scalar(a.abs())),
201 (Operation::Sqrt, [ConstantValue::Scalar(a)]) => {
202 if *a >= 0.0 {
203 Some(ConstantValue::Scalar(a.sqrt()))
204 } else {
205 None }
207 }
208 (Operation::Exp, [ConstantValue::Scalar(a)]) => Some(ConstantValue::Scalar(a.exp())),
209 (Operation::Log, [ConstantValue::Scalar(a)]) => {
210 if *a > 0.0 {
211 Some(ConstantValue::Scalar(a.ln()))
212 } else {
213 None }
215 }
216
217 (Operation::Add, [ConstantValue::IntScalar(a), ConstantValue::IntScalar(b)]) => {
219 Some(ConstantValue::IntScalar(a + b))
220 }
221 (Operation::Sub, [ConstantValue::IntScalar(a), ConstantValue::IntScalar(b)]) => {
222 Some(ConstantValue::IntScalar(a - b))
223 }
224 (Operation::Mul, [ConstantValue::IntScalar(a), ConstantValue::IntScalar(b)]) => {
225 Some(ConstantValue::IntScalar(a * b))
226 }
227
228 _ => None, };
230
231 if let Some(folded) = folded_value {
233 if let Some(node_mut) = graph.node_mut(node_id) {
234 node_mut.op = Operation::Constant(ConstantInfo { value: folded });
235
236 let edges_to_remove: Vec<_> = graph
238 .graph
239 .edges_directed(node_id, petgraph::Direction::Incoming)
240 .map(|edge| edge.id())
241 .collect();
242
243 for edge_id in edges_to_remove {
244 graph.graph.remove_edge(edge_id);
245 }
246 }
247 }
248
249 Ok(())
250 }
251}
252
/// Merges structurally identical nodes (same op, same inputs) onto a single
/// representative node.
pub struct CommonSubexpressionElimination;
255
256impl OptimizationPass for CommonSubexpressionElimination {
257 fn name(&self) -> &str {
258 "CommonSubexpressionElimination"
259 }
260
261 fn apply(&self, mut graph: ComputationGraph) -> JitResult<ComputationGraph> {
262 let mut op_map: HashMap<String, Vec<NodeId>> = HashMap::new();
264
265 for (node_id, node) in graph.nodes() {
266 let signature = self.compute_signature(&graph, node_id, node);
267 op_map.entry(signature).or_default().push(node_id);
268 }
269
270 for (_, nodes) in op_map {
272 if nodes.len() > 1 {
273 let keep = nodes[0];
275 for &duplicate in &nodes[1..] {
276 self.redirect_node(&mut graph, duplicate, keep)?;
277 }
278 }
279 }
280
281 Ok(graph)
282 }
283}
284
285impl CommonSubexpressionElimination {
286 fn compute_signature(&self, graph: &ComputationGraph, node_id: NodeId, node: &Node) -> String {
287 let mut sig = format!("{:?}", node.op);
289
290 let mut inputs: Vec<_> = graph.predecessors(node_id).collect();
292 inputs.sort();
293
294 for input in inputs {
295 sig.push_str(&format!("_in{:?}", input));
296 }
297
298 sig
299 }
300
301 fn redirect_node(
302 &self,
303 graph: &mut ComputationGraph,
304 from: NodeId,
305 to: NodeId,
306 ) -> JitResult<()> {
307 let successors: Vec<_> = graph.graph.neighbors(from).collect();
309
310 for succ in successors {
311 if let Some(edge) = graph.graph.find_edge(from, succ) {
313 let edge_data = graph.graph[edge].clone();
314 graph.graph.remove_edge(edge);
315 graph.graph.add_edge(to, succ, edge_data);
316 }
317 }
318
319 graph.graph.remove_node(from);
321
322 Ok(())
323 }
324}
325
/// Applies algebraic identities (x - x = 0, x / x = 1, x^0 = 1, ...).
pub struct AlgebraicSimplification;
328
329impl OptimizationPass for AlgebraicSimplification {
330 fn name(&self) -> &str {
331 "AlgebraicSimplification"
332 }
333
334 fn apply(&self, mut graph: ComputationGraph) -> JitResult<ComputationGraph> {
335 let nodes: Vec<_> = graph.graph.node_indices().collect();
336
337 for node_id in nodes {
338 if let Some(node) = graph.node(node_id).cloned() {
339 if let Some(simplified) = self.simplify(&graph, node_id, &node) {
340 if let Some(node_mut) = graph.node_mut(node_id) {
342 node_mut.op = simplified;
343 }
344 }
345 }
346 }
347
348 Ok(graph)
349 }
350}
351
352impl AlgebraicSimplification {
353 fn simplify(
354 &self,
355 graph: &ComputationGraph,
356 node_id: NodeId,
357 node: &Node,
358 ) -> Option<Operation> {
359 match &node.op {
360 Operation::Mul => {
361 let preds: Vec<_> = graph.predecessors(node_id).collect();
362 if preds.len() == 2 {
363 if let (Some(left), Some(right)) = (graph.node(preds[0]), graph.node(preds[1]))
365 {
366 match (&left.op, &right.op) {
367 (Operation::Constant(c), _) | (_, Operation::Constant(c)) => {
369 if let crate::graph::ConstantValue::Scalar(v) = &c.value {
370 if *v == 0.0 {
371 return Some(Operation::Constant(
372 crate::graph::ConstantInfo {
373 value: crate::graph::ConstantValue::Scalar(0.0),
374 },
375 ));
376 }
377 }
379 }
380 _ if preds[0] == preds[1] => {
382 return Some(Operation::Pow);
383 }
384 _ => {}
385 }
386 }
387 }
388 }
389 Operation::Add => {
390 let preds: Vec<_> = graph.predecessors(node_id).collect();
391 if preds.len() == 2 {
392 if let (Some(left), Some(right)) = (graph.node(preds[0]), graph.node(preds[1]))
393 {
394 match (&left.op, &right.op) {
396 (Operation::Constant(c), _) | (_, Operation::Constant(c)) => {
397 if let crate::graph::ConstantValue::Scalar(v) = &c.value {
398 if *v == 0.0 {
399 }
402 }
403 }
404 _ => {}
405 }
406 }
407 }
408 }
409 Operation::Sub => {
410 let preds: Vec<_> = graph.predecessors(node_id).collect();
411 if preds.len() == 2 && preds[0] == preds[1] {
412 return Some(Operation::Constant(crate::graph::ConstantInfo {
414 value: crate::graph::ConstantValue::Scalar(0.0),
415 }));
416 }
417 }
418 Operation::Div => {
419 let preds: Vec<_> = graph.predecessors(node_id).collect();
420 if preds.len() == 2 {
421 if let Some(right) = graph.node(preds[1]) {
422 if let Operation::Constant(c) = &right.op {
424 if let crate::graph::ConstantValue::Scalar(v) = &c.value {
425 if *v == 1.0 {
426 return None;
428 }
429 }
430 }
431 }
432 if preds[0] == preds[1] {
434 return Some(Operation::Constant(crate::graph::ConstantInfo {
435 value: crate::graph::ConstantValue::Scalar(1.0),
436 }));
437 }
438 }
439 }
440 Operation::Pow => {
441 let preds: Vec<_> = graph.predecessors(node_id).collect();
442 if preds.len() == 2 {
443 if let Some(right) = graph.node(preds[1]) {
444 if let Operation::Constant(c) = &right.op {
445 if let crate::graph::ConstantValue::Scalar(exp) = &c.value {
446 match *exp {
447 0.0 => {
449 return Some(Operation::Constant(
450 crate::graph::ConstantInfo {
451 value: crate::graph::ConstantValue::Scalar(1.0),
452 },
453 ))
454 }
455 1.0 => return None,
457 2.0 => return Some(Operation::Mul),
459 _ => {}
460 }
461 }
462 }
463 }
464 }
465 }
466 Operation::Sqrt => {
467 let preds: Vec<_> = graph.predecessors(node_id).collect();
468 if preds.len() == 1 {
469 if let Some(pred) = graph.node(preds[0]) {
470 if let Operation::Pow = &pred.op {
472 let pow_preds: Vec<_> = graph.predecessors(preds[0]).collect();
473 if pow_preds.len() == 2 {
474 if let Some(exp_node) = graph.node(pow_preds[1]) {
475 if let Operation::Constant(c) = &exp_node.op {
476 if let crate::graph::ConstantValue::Scalar(2.0) = &c.value {
477 return Some(Operation::Abs);
478 }
479 }
480 }
481 }
482 }
483 }
484 }
485 }
486 Operation::Log => {
487 let preds: Vec<_> = graph.predecessors(node_id).collect();
488 if preds.len() == 1 {
489 if let Some(pred) = graph.node(preds[0]) {
490 if let Operation::Exp = &pred.op {
492 return None; }
494 if let Operation::Constant(c) = &pred.op {
496 if let crate::graph::ConstantValue::Scalar(1.0) = &c.value {
497 return Some(Operation::Constant(crate::graph::ConstantInfo {
498 value: crate::graph::ConstantValue::Scalar(0.0),
499 }));
500 }
501 }
502 }
503 }
504 }
505 Operation::Exp => {
506 let preds: Vec<_> = graph.predecessors(node_id).collect();
507 if preds.len() == 1 {
508 if let Some(pred) = graph.node(preds[0]) {
509 if let Operation::Log = &pred.op {
511 return None; }
513 if let Operation::Constant(c) = &pred.op {
515 if let crate::graph::ConstantValue::Scalar(0.0) = &c.value {
516 return Some(Operation::Constant(crate::graph::ConstantInfo {
517 value: crate::graph::ConstantValue::Scalar(1.0),
518 }));
519 }
520 }
521 }
522 }
523 }
524 Operation::Neg => {
526 let preds: Vec<_> = graph.predecessors(node_id).collect();
527 if preds.len() == 1 {
528 if let Some(pred) = graph.node(preds[0]) {
529 if let Operation::Neg = &pred.op {
530 return None; }
532 }
533 }
534 }
535 _ => {}
536 }
537
538 None
539 }
540}
541
/// Replaces expensive operations with cheaper equivalents.
pub struct StrengthReduction;
544
545impl OptimizationPass for StrengthReduction {
546 fn name(&self) -> &str {
547 "StrengthReduction"
548 }
549
550 fn apply(&self, mut graph: ComputationGraph) -> JitResult<ComputationGraph> {
551 let nodes: Vec<_> = graph.graph.node_indices().collect();
552
553 for node_id in nodes {
554 if let Some(node) = graph.node(node_id).cloned() {
555 if let Some(reduced) = self.reduce_strength(&graph, node_id, &node) {
556 if let Some(node_mut) = graph.node_mut(node_id) {
557 node_mut.op = reduced;
558 }
559 }
560 }
561 }
562
563 Ok(graph)
564 }
565}
566
567impl StrengthReduction {
568 fn reduce_strength(
569 &self,
570 graph: &ComputationGraph,
571 node_id: NodeId,
572 node: &Node,
573 ) -> Option<Operation> {
574 match &node.op {
575 Operation::Pow => {
576 let preds: Vec<_> = graph.predecessors(node_id).collect();
578 if preds.len() == 2 {
579 let is_two = graph
580 .node(preds[1])
581 .and_then(|node| match &node.op {
582 Operation::Constant(c) => match &c.value {
583 crate::graph::ConstantValue::Scalar(v) => Some(*v == 2.0),
584 _ => None,
585 },
586 _ => None,
587 })
588 .unwrap_or(false);
589
590 if is_two {
591 return Some(Operation::Mul);
593 }
594 }
595 }
596 Operation::Div => {
597 }
600 _ => {}
601 }
602
603 None
604 }
605}
606
/// Inserts layout conversions (e.g. NCHW -> NHWC transposes) where the
/// heuristics predict better memory-access patterns for conv/matmul ops.
pub struct LayoutOptimization;
609
impl OptimizationPass for LayoutOptimization {
    fn name(&self) -> &str {
        "LayoutOptimization"
    }

    /// Scans for layout-sensitive ops (Conv2d, MatMul/BatchMatMul), records
    /// the layout changes worth making, then applies them in a second phase
    /// so the node set is not mutated while it is being iterated.
    fn apply(&self, mut graph: ComputationGraph) -> JitResult<ComputationGraph> {
        let nodes: Vec<_> = graph.graph.node_indices().collect();
        let mut transpose_insertions = Vec::new();

        // Phase 1: read-only analysis — decide which nodes need a change.
        for node_id in nodes {
            if let Some(node) = graph.node(node_id) {
                match &node.op {
                    Operation::Conv2d(info) => {
                        if self.should_convert_layout_for_conv(info, &node.output_shape) {
                            transpose_insertions.push((node_id, LayoutChange::NCHWtoNHWC));
                        }
                    }
                    Operation::MatMul | Operation::BatchMatMul => {
                        let preds: Vec<_> = graph.predecessors(node_id).collect();
                        if preds.len() == 2 {
                            if let (Some(left), Some(right)) =
                                (graph.node(preds[0]), graph.node(preds[1]))
                            {
                                if self.should_transpose_for_matmul(
                                    &left.output_shape,
                                    &right.output_shape,
                                ) {
                                    transpose_insertions
                                        .push((node_id, LayoutChange::TransposeMatmul));
                                }
                            }
                        }
                    }
                    _ => {}
                }
            }
        }

        // Phase 2: mutate the graph according to the recorded decisions.
        for (node_id, change) in transpose_insertions {
            match change {
                LayoutChange::NCHWtoNHWC => {
                    // Permutation [0, 2, 3, 1] maps NCHW dims to NHWC.
                    self.insert_layout_conversion(&mut graph, node_id, vec![0, 2, 3, 1])?;
                }
                LayoutChange::TransposeMatmul => {
                    self.optimize_matmul_layout(&mut graph, node_id)?;
                }
            }
        }

        Ok(graph)
    }
}
673
/// Kind of layout change recorded during the analysis phase of
/// `LayoutOptimization::apply`.
#[derive(Debug)]
enum LayoutChange {
    /// Insert NCHW -> NHWC transposes before a Conv2d node.
    NCHWtoNHWC,
    /// Rearrange matmul operand layout (currently a stub).
    TransposeMatmul,
}
679
impl LayoutOptimization {
    /// Heuristic: is an NCHW -> NHWC conversion likely to pay off for this
    /// convolution?
    ///
    /// Returns true for depthwise-style convs (groups == in == out channels)
    /// and for small-kernel (<= 3x3) convs over large spatial extents
    /// (H * W > 1024). Assumes `output_shape` is NCHW — dims[2]/dims[3] are
    /// read as H/W; TODO confirm that convention against the graph builder.
    fn should_convert_layout_for_conv(&self, info: &Conv2dInfo, output_shape: &Shape) -> bool {
        if info.groups == info.in_channels && info.in_channels == info.out_channels {
            return true;
        }

        if info.kernel_size.0 <= 3 && info.kernel_size.1 <= 3 {
            if output_shape.ndim() >= 4 {
                let height = output_shape.dims()[2];
                let width = output_shape.dims()[3];
                return height * width > 1024;
            }
        }

        false
    }

    /// Heuristic: would transposing matmul operands improve locality?
    ///
    /// Skips tiny inner dimensions (< 16), then flags strongly non-square
    /// operands: a tall-skinny left matrix (aspect > 4) or a short-wide
    /// right matrix (aspect < 0.25).
    fn should_transpose_for_matmul(&self, left_shape: &Shape, right_shape: &Shape) -> bool {
        if left_shape.ndim() < 2 || right_shape.ndim() < 2 {
            return false;
        }

        // Trailing two dims are the matrix dims (batch dims precede them).
        let left_rows = left_shape.dims()[left_shape.ndim() - 2];
        let left_cols = left_shape.dims()[left_shape.ndim() - 1];
        let right_rows = right_shape.dims()[right_shape.ndim() - 2];

        // Too small for a transpose to be worth the copy.
        if left_cols < 16 || right_rows < 16 {
            return false;
        }

        let left_aspect = left_rows as f32 / left_cols as f32;
        let right_aspect = right_rows as f32 / right_shape.dims()[right_shape.ndim() - 1] as f32;

        left_aspect > 4.0 || right_aspect < 0.25
    }

    /// Splices a Transpose node between `node_id` and each of its
    /// predecessors, permuting with `dims`.
    ///
    /// NOTE(review): this inserts a transpose on EVERY predecessor, which
    /// for a Conv2d includes the weight/bias inputs, not just the activation
    /// — confirm that is intended.
    fn insert_layout_conversion(
        &self,
        graph: &mut ComputationGraph,
        node_id: NodeId,
        dims: Vec<usize>,
    ) -> JitResult<()> {
        let preds: Vec<_> = graph.predecessors(node_id).collect();

        for pred_id in preds {
            if let Some(pred_node) = graph.node(pred_id).cloned() {
                // Build the transpose node, inheriting dtype/device from the
                // predecessor and computing the permuted output shape.
                let mut transpose_node = Node::new(
                    Operation::Transpose { dims: dims.clone() },
                    format!("{}_transpose", pred_node.name),
                );
                transpose_node = transpose_node
                    .with_output_shapes(vec![Some(
                        self.compute_transposed_shape(&pred_node.output_shape, &dims),
                    )])
                    .with_dtypes(vec![pred_node.dtype])
                    .with_device(pred_node.device);
                transpose_node.inputs = vec![pred_id];
                transpose_node.is_output = false;

                let transpose_id = graph.add_node(transpose_node);

                // Rewire pred -> node into pred -> transpose -> node,
                // preserving the original edge data on the final hop.
                if let Some(edge) = graph.graph.find_edge(pred_id, node_id) {
                    let edge_data = graph.graph[edge].clone();
                    graph.graph.remove_edge(edge);
                    graph.add_edge(pred_id, transpose_id, Edge::default());
                    graph.add_edge(transpose_id, node_id, edge_data);
                }
            }
        }

        Ok(())
    }

    /// Rearranges matmul operand layout for better access patterns.
    ///
    /// Currently a stub: the decision is made in `apply`, but no rewiring is
    /// performed yet.
    fn optimize_matmul_layout(
        &self,
        graph: &mut ComputationGraph,
        node_id: NodeId,
    ) -> JitResult<()> {
        let preds: Vec<_> = graph.predecessors(node_id).collect();

        if preds.len() == 2 {
            // TODO: insert operand transposes / select a transposed kernel.
        }

        Ok(())
    }

    /// Applies the permutation `dims` to `shape`: new_dims[i] = old[dims[i]].
    /// Out-of-range permutation entries leave that dimension as 0.
    fn compute_transposed_shape(&self, shape: &Shape, dims: &[usize]) -> Shape {
        let mut new_dims = vec![0; shape.ndim()];
        let old_dims = shape.dims();

        for (i, &dim) in dims.iter().enumerate() {
            if dim < old_dims.len() {
                new_dims[i] = old_dims[dim];
            }
        }

        Shape::new(new_dims)
    }
}
811
/// Tags large elementwise ops with attributes that request SIMD code generation.
pub struct AutoVectorization;
814
815impl OptimizationPass for AutoVectorization {
816 fn name(&self) -> &str {
817 "AutoVectorization"
818 }
819
820 fn apply(&self, mut graph: ComputationGraph) -> JitResult<ComputationGraph> {
821 let nodes: Vec<_> = graph.graph.node_indices().collect();
822
823 for node_id in nodes {
824 if let Some(node) = graph.node(node_id).cloned() {
825 if self.can_vectorize(&graph, node_id, &node) {
826 self.vectorize_node(&mut graph, node_id, &node)?;
827 }
828 }
829 }
830
831 Ok(graph)
832 }
833}
834
835impl AutoVectorization {
836 fn can_vectorize(&self, _graph: &ComputationGraph, _node_id: NodeId, node: &Node) -> bool {
838 match &node.op {
840 Operation::Add
841 | Operation::Sub
842 | Operation::Mul
843 | Operation::Div
844 | Operation::Relu
845 | Operation::Sigmoid
846 | Operation::Tanh
847 | Operation::Silu => {
848 let total_elements = node.output_shape.numel();
850 total_elements >= 1024 }
852 _ => false,
853 }
854 }
855
856 fn vectorize_node(
858 &self,
859 graph: &mut ComputationGraph,
860 node_id: NodeId,
861 _node: &Node,
862 ) -> JitResult<()> {
863 if let Some(node_mut) = graph.node_mut(node_id) {
864 node_mut
866 .attrs
867 .insert("vectorize".to_string(), crate::graph::Attribute::Bool(true));
868 node_mut.attrs.insert(
869 "vector_width".to_string(),
870 crate::graph::Attribute::Int(8), );
872 }
873 Ok(())
874 }
875}
876
/// Groups independent nodes and tags them for concurrent execution.
pub struct AutoParallelization;
879
880impl OptimizationPass for AutoParallelization {
881 fn name(&self) -> &str {
882 "AutoParallelization"
883 }
884
885 fn apply(&self, mut graph: ComputationGraph) -> JitResult<ComputationGraph> {
886 let parallel_groups = self.find_parallel_groups(&graph)?;
888
889 for (group_id, node_ids) in parallel_groups.iter().enumerate() {
891 for &node_id in node_ids {
892 if let Some(node_mut) = graph.node_mut(node_id) {
893 node_mut.attrs.insert(
894 "parallel_group".to_string(),
895 crate::graph::Attribute::Int(group_id as i64),
896 );
897 node_mut.attrs.insert(
898 "can_parallelize".to_string(),
899 crate::graph::Attribute::Bool(true),
900 );
901 }
902 }
903 }
904
905 Ok(graph)
906 }
907}
908
909impl AutoParallelization {
910 fn find_parallel_groups(&self, graph: &ComputationGraph) -> JitResult<Vec<Vec<NodeId>>> {
912 let mut parallel_groups = Vec::new();
913 let mut visited = HashSet::new();
914
915 let topo_order = graph
917 .topological_sort()
918 .map_err(|e| crate::JitError::GraphError(e.to_string()))?;
919
920 for &node_id in &topo_order {
921 if visited.contains(&node_id) {
922 continue;
923 }
924
925 let mut current_group = Vec::new();
927
928 if self.can_execute_now(graph, node_id, &visited) {
930 current_group.push(node_id);
931 visited.insert(node_id);
932
933 for &other_id in &topo_order {
935 if other_id != node_id
936 && !visited.contains(&other_id)
937 && self.can_execute_now(graph, other_id, &visited)
938 && self.can_execute_parallel(graph, node_id, other_id)
939 {
940 current_group.push(other_id);
941 visited.insert(other_id);
942 }
943 }
944 }
945
946 if current_group.len() > 1 {
947 parallel_groups.push(current_group);
948 }
949 }
950
951 Ok(parallel_groups)
952 }
953
954 fn can_execute_now(
956 &self,
957 graph: &ComputationGraph,
958 node_id: NodeId,
959 visited: &HashSet<NodeId>,
960 ) -> bool {
961 graph
962 .predecessors(node_id)
963 .all(|pred| visited.contains(&pred))
964 }
965
966 fn can_execute_parallel(&self, graph: &ComputationGraph, node1: NodeId, node2: NodeId) -> bool {
968 !self.has_dependency_path(graph, node1, node2)
970 && !self.has_dependency_path(graph, node2, node1)
971 }
972
973 fn has_dependency_path(&self, graph: &ComputationGraph, from: NodeId, to: NodeId) -> bool {
975 let mut visited = HashSet::new();
976 let mut stack = vec![from];
977
978 while let Some(current) = stack.pop() {
979 if current == to {
980 return true;
981 }
982
983 if visited.insert(current) {
984 for successor in graph.successors(current) {
985 stack.push(successor);
986 }
987 }
988 }
989
990 false
991 }
992}
993
/// Cache-topology-aware optimization: loop-tiling suggestions, locality
/// grouping, and prefetch hints, parameterized by cache geometry.
pub struct CacheAwareOptimization {
    // Cache line size in bytes (typically 64 on x86). Currently stored but
    // not read by any method — TODO confirm intended use.
    cache_line_size: usize,
    // L1 data cache capacity in bytes; drives `calculate_tile_size`.
    l1_cache_size: usize,
    // L2 cache capacity in bytes. Currently stored but not read.
    l2_cache_size: usize,
}
1003
impl Default for CacheAwareOptimization {
    /// Defaults modeled on a typical x86 core: 64-byte cache lines,
    /// 32 KiB L1 data cache, 256 KiB L2 cache.
    fn default() -> Self {
        Self {
            cache_line_size: 64,
            l1_cache_size: 32 * 1024,
            l2_cache_size: 256 * 1024,
        }
    }
}
1013
impl CacheAwareOptimization {
    /// Builds an optimizer for an explicit cache geometry (all sizes in bytes).
    pub fn new(cache_line_size: usize, l1_cache_size: usize, l2_cache_size: usize) -> Self {
        Self {
            cache_line_size,
            l1_cache_size,
            l2_cache_size,
        }
    }

    /// Picks a tile edge length so that a square tile of `element_size`-byte
    /// elements fits in L1: the largest power of two not exceeding
    /// sqrt(L1 / element_size), clamped to [8, dimension_size].
    ///
    /// Note: when `dimension_size < 8` the final `.min()` returns
    /// `dimension_size`, so the lower bound of 8 only applies to larger dims.
    fn calculate_tile_size(&self, dimension_size: usize, element_size: usize) -> usize {
        let elements_in_l1 = self.l1_cache_size / element_size;
        let sqrt_elements = (elements_in_l1 as f64).sqrt() as usize;

        // Largest power of two <= min(sqrt_elements, dimension_size).
        let mut tile_size = 1;
        while tile_size * 2 <= sqrt_elements && tile_size * 2 <= dimension_size {
            tile_size *= 2;
        }

        tile_size.max(8).min(dimension_size)
    }

    /// Counts nodes that belong to locality groups of size > 1.
    ///
    /// NOTE(review): despite the name, this does not reorder anything yet —
    /// it only sizes the groups found by `analyze_data_access_patterns` and
    /// reports the count. The actual reordering is未 implemented — TODO.
    fn reorder_for_locality(&self, graph: &mut ComputationGraph) -> JitResult<usize> {
        let mut reordered = 0;

        let access_groups = self.analyze_data_access_patterns(graph)?;

        for group in access_groups {
            if group.len() > 1 {
                reordered += group.len();
            }
        }

        Ok(reordered)
    }

    /// Greedily groups nodes whose operations have similar data-access
    /// patterns (see `have_similar_access_pattern`). Each node lands in at
    /// most one group; only groups with more than one member are returned.
    /// Quadratic in node count.
    fn analyze_data_access_patterns(
        &self,
        graph: &ComputationGraph,
    ) -> JitResult<Vec<Vec<NodeId>>> {
        let mut groups = Vec::new();
        let mut visited = HashSet::new();

        for (node_id, node) in graph.nodes() {
            if visited.contains(&node_id) {
                continue;
            }

            let mut group = vec![node_id];
            visited.insert(node_id);

            for (other_id, other_node) in graph.nodes() {
                if visited.contains(&other_id) {
                    continue;
                }

                if self.have_similar_access_pattern(node, other_node) {
                    group.push(other_id);
                    visited.insert(other_id);
                }
            }

            if group.len() > 1 {
                groups.push(group);
            }
        }

        Ok(groups)
    }

    /// Two nodes "access data similarly" when both are matmuls, both are
    /// convolutions, or both are the same elementwise binary op over
    /// identically shaped outputs.
    fn have_similar_access_pattern(&self, node1: &Node, node2: &Node) -> bool {
        match (&node1.op, &node2.op) {
            (Operation::MatMul, Operation::MatMul) => true,
            (Operation::Conv2d(_), Operation::Conv2d(_)) => true,
            (Operation::Add, Operation::Add)
            | (Operation::Sub, Operation::Sub)
            | (Operation::Mul, Operation::Mul)
            | (Operation::Div, Operation::Div) => {
                node1.output_shape.dims() == node2.output_shape.dims()
            }
            _ => false,
        }
    }

    /// Computes and logs tile-size suggestions for MatMul/Conv2d nodes.
    ///
    /// NOTE(review): takes `&mut` but only reads the graph — the suggested
    /// tile sizes are logged, not recorded as node attributes. The element
    /// size is hard-coded to 4 bytes (assumes f32 — TODO confirm dtype).
    fn apply_loop_tiling(&self, graph: &mut ComputationGraph) -> JitResult<usize> {
        let mut tiled = 0;

        for (node_id, node) in graph.nodes() {
            match &node.op {
                Operation::MatMul => {
                    if let Some(shape) = node.output_shapes.first().and_then(|s| s.as_ref()) {
                        let dims = shape.dims();
                        if dims.len() >= 2 {
                            // Trailing two dims are the output matrix M x N.
                            let m = dims[dims.len() - 2];
                            let n = dims[dims.len() - 1];

                            let tile_m = self.calculate_tile_size(m, 4);
                            let tile_n = self.calculate_tile_size(n, 4);

                            log::debug!(
                                "MatMul node {:?}: suggested tiling {}x{}",
                                node_id,
                                tile_m,
                                tile_n
                            );
                            tiled += 1;
                        }
                    }
                }

                Operation::Conv2d(conv_info) => {
                    let tile_size = self.calculate_tile_size(conv_info.kernel_size.0, 4);
                    log::debug!(
                        "Conv2d node {:?}: suggested tile size {}",
                        node_id,
                        tile_size
                    );
                    tiled += 1;
                }

                _ => {}
            }
        }

        Ok(tiled)
    }

    /// Counts the memory-bound nodes (MatMul, Conv2d) that would benefit
    /// from software prefetch. Currently only logs; no attribute is set.
    fn add_prefetch_hints(&self, graph: &ComputationGraph) -> JitResult<usize> {
        let mut hints_added = 0;

        for (node_id, node) in graph.nodes() {
            match &node.op {
                Operation::MatMul | Operation::Conv2d(_) => {
                    log::debug!("Node {:?}: adding software prefetch hints", node_id);
                    hints_added += 1;
                }
                _ => {}
            }
        }

        Ok(hints_added)
    }

    /// Runs the three locality sub-analyses and returns the combined count
    /// of opportunities found (used only for reporting).
    fn optimize_locality(&self, graph: &mut ComputationGraph) -> JitResult<usize> {
        let mut optimized = 0;

        optimized += self.reorder_for_locality(graph)?;
        optimized += self.apply_loop_tiling(graph)?;
        optimized += self.add_prefetch_hints(graph)?;

        Ok(optimized)
    }
}
1196
1197impl OptimizationPass for CacheAwareOptimization {
1198 fn name(&self) -> &str {
1199 "cache_aware"
1200 }
1201
1202 fn apply(&self, mut graph: ComputationGraph) -> JitResult<ComputationGraph> {
1203 let optimizations = self.optimize_locality(&mut graph)?;
1204
1205 log::info!(
1206 "Cache-aware optimization: {} improvements applied",
1207 optimizations
1208 );
1209
1210 Ok(graph)
1211 }
1212}
1213
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::Edge;
    use torsh_core::{DType, DeviceType};

    #[test]
    fn test_optimizer_creation() {
        // The default pipeline registers exactly nine passes.
        let optimizer = GraphOptimizer::new();
        assert_eq!(optimizer.passes.len(), 9);
    }

    #[test]
    fn test_dead_code_elimination() {
        let mut graph = ComputationGraph::new();

        let input = graph.add_node(
            Node::new(Operation::Input, "input".to_string())
                .with_output_shapes(vec![Some(crate::graph::shape_from_slice(&[10]))])
                .with_dtypes(vec![DType::F32])
                .with_device(DeviceType::Cpu),
        );

        // Reachable from `input` but feeds no output — should be eliminated.
        let dead = graph.add_node(
            Node::new(Operation::Relu, "dead".to_string())
                .with_output_shapes(vec![Some(crate::graph::shape_from_slice(&[10]))])
                .with_dtypes(vec![DType::F32])
                .with_device(DeviceType::Cpu),
        );

        let output = graph.add_node(
            Node::new(Operation::Neg, "output".to_string())
                .with_output_shapes(vec![Some(crate::graph::shape_from_slice(&[10]))])
                .with_dtypes(vec![DType::F32])
                .with_device(DeviceType::Cpu),
        );

        graph.add_edge(input, output, Edge::default());
        graph.add_edge(input, dead, Edge::default());
        graph.add_input(input);
        graph.add_output(output);

        let dce = DeadCodeElimination;
        let optimized = dce.apply(graph).unwrap();

        assert_eq!(
            optimized.graph.node_count(),
            2,
            "Should have exactly 2 nodes after DCE"
        );

        let remaining_nodes: Vec<_> = optimized
            .graph
            .node_indices()
            .filter_map(|idx| optimized.graph.node_weight(idx))
            .collect();

        let has_input = remaining_nodes
            .iter()
            .any(|n| matches!(&n.op, Operation::Input));
        let has_neg = remaining_nodes
            .iter()
            .any(|n| matches!(&n.op, Operation::Neg));

        assert!(has_input, "Input node should still exist");
        assert!(has_neg, "Output (neg) node should still exist");

        // The dead node is a Relu. (The old assertion looked for a Sigmoid
        // node, which never existed in this graph, so it passed vacuously
        // and did not actually verify the elimination.)
        let has_relu = remaining_nodes
            .iter()
            .any(|n| matches!(&n.op, Operation::Relu));
        assert!(!has_relu, "Dead (relu) node should be removed");
    }
}