1use crate::ast::{
7 self, BinOp, Block, Expr, FunctionAttrs, Ident, Item, Literal, NumBase, Param, PathSegment,
8 Pattern, Stmt, TypeExpr, TypePath, UnaryOp, Visibility,
9};
10use crate::span::Span;
11use std::collections::{HashMap, HashSet};
12
/// How much work the optimizer performs on each function body.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum OptLevel {
    /// Function bodies are copied through unchanged.
    None,
    /// Constant folding followed by dead-code elimination.
    Basic,
    /// One round of folding, inlining, strength reduction, LICM, CSE,
    /// dead-code elimination and branch simplification. Also enables the
    /// accumulator (tail-recursion) rewrite of Fibonacci-shaped functions.
    Standard,
    /// Three rounds of the `Standard` pipeline with loop unrolling added.
    /// Also enables the accumulator rewrite.
    Aggressive,
    /// Runs the same per-function pipeline as `Standard`, but without the
    /// accumulator rewrite (which adds a helper function to the file).
    Size,
}
31
/// Counters describing what the optimizer changed.
///
/// Every counter starts at zero (`Default`) and is only ever incremented.
#[derive(Clone, Debug, Default)]
pub struct OptStats {
    /// Constant binary/unary expressions replaced by a literal.
    pub constants_folded: usize,
    /// Statements or tail expressions dropped because they follow a return.
    pub dead_code_eliminated: usize,
    /// NOTE(review): presumably bumped by the CSE pass; not updated in the
    /// part of the file reviewed here.
    pub expressions_deduplicated: usize,
    /// NOTE(review): presumably bumped by the inlining pass; confirm there.
    pub functions_inlined: usize,
    /// Algebraic identities and multiply-to-shift rewrites applied.
    pub strength_reductions: usize,
    /// `if`/`while` expressions removed because their condition was constant.
    pub branches_simplified: usize,
    /// NOTE(review): presumably bumped by the loop passes; confirm there.
    pub loops_optimized: usize,
    /// Functions rewritten into an accumulator-passing helper plus wrapper.
    pub tail_recursion_transforms: usize,
    /// NOTE(review): presumably bumped when memoization is applied; the
    /// memoization code visible here is still marked as dead code.
    pub memoization_transforms: usize,
}
45
/// AST-to-AST optimizer: takes a parsed source file and produces a rewritten
/// copy, recording what it changed in an `OptStats`.
pub struct Optimizer {
    /// Optimization level chosen at construction; selects which passes run.
    level: OptLevel,
    /// Running counters of the rewrites performed so far.
    stats: OptStats,
    /// Every function seen by `optimize_file`, keyed by its name.
    functions: HashMap<String, ast::Function>,
    /// Names of functions whose body contains a call to themselves.
    recursive_functions: HashSet<String>,
    /// Reset to zero at the start of each function's optimization.
    /// NOTE(review): presumably numbers CSE temporaries; confirm in the CSE pass.
    cse_counter: usize,
}
57
58impl Optimizer {
59 pub fn new(level: OptLevel) -> Self {
60 Self {
61 level,
62 stats: OptStats::default(),
63 functions: HashMap::new(),
64 recursive_functions: HashSet::new(),
65 cse_counter: 0,
66 }
67 }
68
/// Returns the counters accumulated by the optimizations run so far.
pub fn stats(&self) -> &OptStats {
    &self.stats
}
73
74 pub fn optimize_file(&mut self, file: &ast::SourceFile) -> ast::SourceFile {
76 for item in &file.items {
78 if let Item::Function(func) = &item.node {
79 self.functions.insert(func.name.name.clone(), func.clone());
80 if self.is_recursive(&func.name.name, func) {
81 self.recursive_functions.insert(func.name.name.clone());
82 }
83 }
84 }
85
86 let mut new_items: Vec<crate::span::Spanned<Item>> = Vec::new();
90 let mut transformed_functions: HashMap<String, String> = HashMap::new();
91
92 if matches!(self.level, OptLevel::Standard | OptLevel::Aggressive) {
93 for item in &file.items {
94 if let Item::Function(func) = &item.node {
95 if let Some((helper_func, wrapper_func)) = self.try_accumulator_transform(func)
96 {
97 new_items.push(crate::span::Spanned {
99 node: Item::Function(helper_func),
100 span: item.span.clone(),
101 });
102 transformed_functions
103 .insert(func.name.name.clone(), wrapper_func.name.name.clone());
104 self.stats.tail_recursion_transforms += 1;
105 }
106 }
107 }
108
109 }
115
116 let items: Vec<_> = file
118 .items
119 .iter()
120 .map(|item| {
121 let node = match &item.node {
122 Item::Function(func) => {
123 if let Some((_, wrapper)) = self.try_accumulator_transform(func) {
125 if matches!(self.level, OptLevel::Standard | OptLevel::Aggressive)
126 && transformed_functions.contains_key(&func.name.name)
127 {
128 Item::Function(self.optimize_function(&wrapper))
129 } else {
130 Item::Function(self.optimize_function(func))
131 }
132 } else {
133 Item::Function(self.optimize_function(func))
134 }
135 }
136 other => other.clone(),
137 };
138 crate::span::Spanned {
139 node,
140 span: item.span.clone(),
141 }
142 })
143 .collect();
144
145 new_items.extend(items);
147
148 ast::SourceFile {
149 attrs: file.attrs.clone(),
150 config: file.config.clone(),
151 items: new_items,
152 }
153 }
154
155 fn try_accumulator_transform(
158 &self,
159 func: &ast::Function,
160 ) -> Option<(ast::Function, ast::Function)> {
161 if func.params.len() != 1 {
163 return None;
164 }
165
166 if !self.recursive_functions.contains(&func.name.name) {
168 return None;
169 }
170
171 let body = func.body.as_ref()?;
172
173 if !self.is_fib_like_pattern(&func.name.name, body) {
177 return None;
178 }
179
180 let param_name = if let Pattern::Ident { name, .. } = &func.params[0].pattern {
182 name.name.clone()
183 } else {
184 return None;
185 };
186
187 let helper_name = format!("{}_tail", func.name.name);
189
190 let helper_func = self.generate_fib_helper(&helper_name, ¶m_name);
192
193 let wrapper_func =
195 self.generate_fib_wrapper(&func.name.name, &helper_name, ¶m_name, func);
196
197 Some((helper_func, wrapper_func))
198 }
199
200 fn is_fib_like_pattern(&self, func_name: &str, body: &Block) -> bool {
202 if body.stmts.is_empty() && body.expr.is_none() {
209 return false;
210 }
211
212 if let Some(expr) = &body.expr {
214 if let Expr::If {
215 else_branch: Some(else_expr),
216 ..
217 } = expr.as_ref()
218 {
219 return self.is_double_recursive_expr(func_name, else_expr);
222 }
223 }
224
225 if body.stmts.len() >= 1 {
227 if let Some(Stmt::Expr(expr) | Stmt::Semi(expr)) = body.stmts.last() {
229 if let Expr::Return(Some(ret_expr)) = expr {
230 return self.is_double_recursive_expr(func_name, ret_expr);
231 }
232 }
233 if let Some(expr) = &body.expr {
234 return self.is_double_recursive_expr(func_name, expr);
235 }
236 }
237
238 false
239 }
240
241 fn is_double_recursive_expr(&self, func_name: &str, expr: &Expr) -> bool {
243 if let Expr::Binary {
244 op: BinOp::Add,
245 left,
246 right,
247 } = expr
248 {
249 let left_is_recursive = self.is_recursive_call_with_decrement(func_name, left);
250 let right_is_recursive = self.is_recursive_call_with_decrement(func_name, right);
251 return left_is_recursive && right_is_recursive;
252 }
253 false
254 }
255
256 fn is_recursive_call_with_decrement(&self, func_name: &str, expr: &Expr) -> bool {
258 if let Expr::Call { func, args } = expr {
259 if let Expr::Path(path) = func.as_ref() {
260 if path.segments.last().map(|s| s.ident.name.as_str()) == Some(func_name) {
261 if args.len() == 1 {
263 if let Expr::Binary { op: BinOp::Sub, .. } = &args[0] {
264 return true;
265 }
266 }
267 }
268 }
269 }
270 false
271 }
272
273 fn generate_fib_helper(&self, name: &str, _param_name: &str) -> ast::Function {
275 let span = Span { start: 0, end: 0 };
276
277 let n_ident = Ident {
282 name: "n".to_string(),
283 evidentiality: None,
284 affect: None,
285 span: span.clone(),
286 };
287 let a_ident = Ident {
288 name: "a".to_string(),
289 evidentiality: None,
290 affect: None,
291 span: span.clone(),
292 };
293 let b_ident = Ident {
294 name: "b".to_string(),
295 evidentiality: None,
296 affect: None,
297 span: span.clone(),
298 };
299
300 let params = vec![
301 Param {
302 pattern: Pattern::Ident {
303 mutable: false,
304 name: n_ident.clone(),
305 evidentiality: None,
306 },
307 ty: TypeExpr::Infer,
308 },
309 Param {
310 pattern: Pattern::Ident {
311 mutable: false,
312 name: a_ident.clone(),
313 evidentiality: None,
314 },
315 ty: TypeExpr::Infer,
316 },
317 Param {
318 pattern: Pattern::Ident {
319 mutable: false,
320 name: b_ident.clone(),
321 evidentiality: None,
322 },
323 ty: TypeExpr::Infer,
324 },
325 ];
326
327 let condition = Expr::Binary {
329 op: BinOp::Le,
330 left: Box::new(Expr::Path(TypePath {
331 segments: vec![PathSegment {
332 ident: n_ident.clone(),
333 generics: None,
334 }],
335 })),
336 right: Box::new(Expr::Literal(Literal::Int {
337 value: "0".to_string(),
338 base: NumBase::Decimal,
339 suffix: None,
340 })),
341 };
342
343 let then_branch = Block {
345 stmts: vec![],
346 expr: Some(Box::new(Expr::Return(Some(Box::new(Expr::Path(
347 TypePath {
348 segments: vec![PathSegment {
349 ident: a_ident.clone(),
350 generics: None,
351 }],
352 },
353 )))))),
354 };
355
356 let recursive_call = Expr::Call {
358 func: Box::new(Expr::Path(TypePath {
359 segments: vec![PathSegment {
360 ident: Ident {
361 name: name.to_string(),
362 evidentiality: None,
363 affect: None,
364 span: span.clone(),
365 },
366 generics: None,
367 }],
368 })),
369 args: vec![
370 Expr::Binary {
372 op: BinOp::Sub,
373 left: Box::new(Expr::Path(TypePath {
374 segments: vec![PathSegment {
375 ident: n_ident.clone(),
376 generics: None,
377 }],
378 })),
379 right: Box::new(Expr::Literal(Literal::Int {
380 value: "1".to_string(),
381 base: NumBase::Decimal,
382 suffix: None,
383 })),
384 },
385 Expr::Path(TypePath {
387 segments: vec![PathSegment {
388 ident: b_ident.clone(),
389 generics: None,
390 }],
391 }),
392 Expr::Binary {
394 op: BinOp::Add,
395 left: Box::new(Expr::Path(TypePath {
396 segments: vec![PathSegment {
397 ident: a_ident.clone(),
398 generics: None,
399 }],
400 })),
401 right: Box::new(Expr::Path(TypePath {
402 segments: vec![PathSegment {
403 ident: b_ident.clone(),
404 generics: None,
405 }],
406 })),
407 },
408 ],
409 };
410
411 let body = Block {
413 stmts: vec![],
414 expr: Some(Box::new(Expr::If {
415 condition: Box::new(condition),
416 then_branch,
417 else_branch: Some(Box::new(Expr::Return(Some(Box::new(recursive_call))))),
418 })),
419 };
420
421 ast::Function {
422 visibility: Visibility::default(),
423 is_async: false,
424 is_const: false,
425 is_unsafe: false,
426 attrs: FunctionAttrs::default(),
427 name: Ident {
428 name: name.to_string(),
429 evidentiality: None,
430 affect: None,
431 span: span.clone(),
432 },
433 aspect: None,
434 generics: None,
435 params,
436 return_type: None,
437 where_clause: None,
438 body: Some(body),
439 }
440 }
441
442 fn generate_fib_wrapper(
444 &self,
445 name: &str,
446 helper_name: &str,
447 param_name: &str,
448 original: &ast::Function,
449 ) -> ast::Function {
450 let span = Span { start: 0, end: 0 };
451
452 let call_helper = Expr::Call {
454 func: Box::new(Expr::Path(TypePath {
455 segments: vec![PathSegment {
456 ident: Ident {
457 name: helper_name.to_string(),
458 evidentiality: None,
459 affect: None,
460 span: span.clone(),
461 },
462 generics: None,
463 }],
464 })),
465 args: vec![
466 Expr::Path(TypePath {
468 segments: vec![PathSegment {
469 ident: Ident {
470 name: param_name.to_string(),
471 evidentiality: None,
472 affect: None,
473 span: span.clone(),
474 },
475 generics: None,
476 }],
477 }),
478 Expr::Literal(Literal::Int {
480 value: "0".to_string(),
481 base: NumBase::Decimal,
482 suffix: None,
483 }),
484 Expr::Literal(Literal::Int {
486 value: "1".to_string(),
487 base: NumBase::Decimal,
488 suffix: None,
489 }),
490 ],
491 };
492
493 let body = Block {
494 stmts: vec![],
495 expr: Some(Box::new(Expr::Return(Some(Box::new(call_helper))))),
496 };
497
498 ast::Function {
499 visibility: original.visibility,
500 is_async: original.is_async,
501 is_const: original.is_const,
502 is_unsafe: original.is_unsafe,
503 attrs: original.attrs.clone(),
504 name: Ident {
505 name: name.to_string(),
506 evidentiality: None,
507 affect: None,
508 span: span.clone(),
509 },
510 aspect: original.aspect,
511 generics: original.generics.clone(),
512 params: original.params.clone(),
513 return_type: original.return_type.clone(),
514 where_clause: original.where_clause.clone(),
515 body: Some(body),
516 }
517 }
518
519 #[allow(dead_code)]
526 fn try_memoize_transform(
527 &self,
528 func: &ast::Function,
529 ) -> Option<(ast::Function, ast::Function, ast::Function)> {
530 let param_count = func.params.len();
531 if param_count != 1 && param_count != 2 {
532 return None;
533 }
534
535 let span = Span { start: 0, end: 0 };
536 let func_name = &func.name.name;
537 let impl_name = format!("_memo_impl_{}", func_name);
538 let _cache_name = format!("_memo_cache_{}", func_name);
539 let init_name = format!("_memo_init_{}", func_name);
540
541 let param_names: Vec<String> = func
543 .params
544 .iter()
545 .filter_map(|p| {
546 if let Pattern::Ident { name, .. } = &p.pattern {
547 Some(name.name.clone())
548 } else {
549 None
550 }
551 })
552 .collect();
553
554 if param_names.len() != param_count {
555 return None;
556 }
557
558 let impl_func = ast::Function {
560 visibility: Visibility::default(),
561 is_async: func.is_async,
562 is_const: func.is_const,
563 is_unsafe: func.is_unsafe,
564 attrs: func.attrs.clone(),
565 name: Ident {
566 name: impl_name.clone(),
567 evidentiality: None,
568 affect: None,
569 span: span.clone(),
570 },
571 aspect: func.aspect,
572 generics: func.generics.clone(),
573 params: func.params.clone(),
574 return_type: func.return_type.clone(),
575 where_clause: func.where_clause.clone(),
576 body: func
577 .body
578 .as_ref()
579 .map(|b| self.redirect_calls_in_block(func_name, func_name, b)),
580 };
581
582 let cache_init_body = Block {
585 stmts: vec![],
586 expr: Some(Box::new(Expr::Call {
587 func: Box::new(Expr::Path(TypePath {
588 segments: vec![PathSegment {
589 ident: Ident {
590 name: "sigil_memo_new".to_string(),
591 evidentiality: None,
592 affect: None,
593 span: span.clone(),
594 },
595 generics: None,
596 }],
597 })),
598 args: vec![Expr::Literal(Literal::Int {
599 value: "65536".to_string(),
600 base: NumBase::Decimal,
601 suffix: None,
602 })],
603 })),
604 };
605
606 let cache_init_func = ast::Function {
607 visibility: Visibility::default(),
608 is_async: false,
609 is_const: false,
610 is_unsafe: false,
611 attrs: FunctionAttrs::default(),
612 name: Ident {
613 name: init_name.clone(),
614 evidentiality: None,
615 affect: None,
616 span: span.clone(),
617 },
618 aspect: None,
619 generics: None,
620 params: vec![],
621 return_type: None,
622 where_clause: None,
623 body: Some(cache_init_body),
624 };
625
626 let wrapper_func = self.generate_memo_wrapper(func, &impl_name, ¶m_names);
628
629 Some((impl_func, cache_init_func, wrapper_func))
630 }
631
632 #[allow(dead_code)]
634 fn generate_memo_wrapper(
635 &self,
636 original: &ast::Function,
637 impl_name: &str,
638 param_names: &[String],
639 ) -> ast::Function {
640 let span = Span { start: 0, end: 0 };
641 let param_count = param_names.len();
642
643 let cache_var = Ident {
645 name: "__cache".to_string(),
646 evidentiality: None,
647 affect: None,
648 span: span.clone(),
649 };
650 let result_var = Ident {
651 name: "__result".to_string(),
652 evidentiality: None,
653 affect: None,
654 span: span.clone(),
655 };
656 let cached_var = Ident {
657 name: "__cached".to_string(),
658 evidentiality: None,
659 affect: None,
660 span: span.clone(),
661 };
662
663 let mut stmts = vec![];
664
665 stmts.push(Stmt::Let {
667 pattern: Pattern::Ident {
668 mutable: false,
669 name: cache_var.clone(),
670 evidentiality: None,
671 },
672 ty: None,
673 init: Some(Expr::Call {
674 func: Box::new(Expr::Path(TypePath {
675 segments: vec![PathSegment {
676 ident: Ident {
677 name: "sigil_memo_new".to_string(),
678 evidentiality: None,
679 affect: None,
680 span: span.clone(),
681 },
682 generics: None,
683 }],
684 })),
685 args: vec![Expr::Literal(Literal::Int {
686 value: "65536".to_string(),
687 base: NumBase::Decimal,
688 suffix: None,
689 })],
690 }),
691 });
692
693 let get_fn_name = if param_count == 1 {
695 "sigil_memo_get_1"
696 } else {
697 "sigil_memo_get_2"
698 };
699 let mut get_args = vec![Expr::Path(TypePath {
700 segments: vec![PathSegment {
701 ident: cache_var.clone(),
702 generics: None,
703 }],
704 })];
705 for name in param_names {
706 get_args.push(Expr::Path(TypePath {
707 segments: vec![PathSegment {
708 ident: Ident {
709 name: name.clone(),
710 evidentiality: None,
711 affect: None,
712 span: span.clone(),
713 },
714 generics: None,
715 }],
716 }));
717 }
718
719 stmts.push(Stmt::Let {
720 pattern: Pattern::Ident {
721 mutable: false,
722 name: cached_var.clone(),
723 evidentiality: None,
724 },
725 ty: None,
726 init: Some(Expr::Call {
727 func: Box::new(Expr::Path(TypePath {
728 segments: vec![PathSegment {
729 ident: Ident {
730 name: get_fn_name.to_string(),
731 evidentiality: None,
732 affect: None,
733 span: span.clone(),
734 },
735 generics: None,
736 }],
737 })),
738 args: get_args,
739 }),
740 });
741
742 let cache_check = Expr::If {
745 condition: Box::new(Expr::Binary {
746 op: BinOp::Ne,
747 left: Box::new(Expr::Path(TypePath {
748 segments: vec![PathSegment {
749 ident: cached_var.clone(),
750 generics: None,
751 }],
752 })),
753 right: Box::new(Expr::Unary {
754 op: UnaryOp::Neg,
755 expr: Box::new(Expr::Literal(Literal::Int {
756 value: "9223372036854775807".to_string(),
757 base: NumBase::Decimal,
758 suffix: None,
759 })),
760 }),
761 }),
762 then_branch: Block {
763 stmts: vec![],
764 expr: Some(Box::new(Expr::Return(Some(Box::new(Expr::Path(
765 TypePath {
766 segments: vec![PathSegment {
767 ident: cached_var.clone(),
768 generics: None,
769 }],
770 },
771 )))))),
772 },
773 else_branch: None,
774 };
775 stmts.push(Stmt::Semi(cache_check));
776
777 let mut impl_args = vec![];
779 for name in param_names {
780 impl_args.push(Expr::Path(TypePath {
781 segments: vec![PathSegment {
782 ident: Ident {
783 name: name.clone(),
784 evidentiality: None,
785 affect: None,
786 span: span.clone(),
787 },
788 generics: None,
789 }],
790 }));
791 }
792
793 stmts.push(Stmt::Let {
794 pattern: Pattern::Ident {
795 mutable: false,
796 name: result_var.clone(),
797 evidentiality: None,
798 },
799 ty: None,
800 init: Some(Expr::Call {
801 func: Box::new(Expr::Path(TypePath {
802 segments: vec![PathSegment {
803 ident: Ident {
804 name: impl_name.to_string(),
805 evidentiality: None,
806 affect: None,
807 span: span.clone(),
808 },
809 generics: None,
810 }],
811 })),
812 args: impl_args,
813 }),
814 });
815
816 let set_fn_name = if param_count == 1 {
818 "sigil_memo_set_1"
819 } else {
820 "sigil_memo_set_2"
821 };
822 let mut set_args = vec![Expr::Path(TypePath {
823 segments: vec![PathSegment {
824 ident: cache_var.clone(),
825 generics: None,
826 }],
827 })];
828 for name in param_names {
829 set_args.push(Expr::Path(TypePath {
830 segments: vec![PathSegment {
831 ident: Ident {
832 name: name.clone(),
833 evidentiality: None,
834 affect: None,
835 span: span.clone(),
836 },
837 generics: None,
838 }],
839 }));
840 }
841 set_args.push(Expr::Path(TypePath {
842 segments: vec![PathSegment {
843 ident: result_var.clone(),
844 generics: None,
845 }],
846 }));
847
848 stmts.push(Stmt::Semi(Expr::Call {
849 func: Box::new(Expr::Path(TypePath {
850 segments: vec![PathSegment {
851 ident: Ident {
852 name: set_fn_name.to_string(),
853 evidentiality: None,
854 affect: None,
855 span: span.clone(),
856 },
857 generics: None,
858 }],
859 })),
860 args: set_args,
861 }));
862
863 let body = Block {
865 stmts,
866 expr: Some(Box::new(Expr::Return(Some(Box::new(Expr::Path(
867 TypePath {
868 segments: vec![PathSegment {
869 ident: result_var.clone(),
870 generics: None,
871 }],
872 },
873 )))))),
874 };
875
876 ast::Function {
877 visibility: original.visibility,
878 is_async: original.is_async,
879 is_const: original.is_const,
880 is_unsafe: original.is_unsafe,
881 attrs: original.attrs.clone(),
882 name: original.name.clone(),
883 aspect: original.aspect,
884 generics: original.generics.clone(),
885 params: original.params.clone(),
886 return_type: original.return_type.clone(),
887 where_clause: original.where_clause.clone(),
888 body: Some(body),
889 }
890 }
891
/// Placeholder for rewriting calls to `_old_name` into calls to `_new_name`.
///
/// Currently performs no rewriting: both names are ignored and the block is
/// returned as an unchanged clone.
#[allow(dead_code)]
fn redirect_calls_in_block(&self, _old_name: &str, _new_name: &str, block: &Block) -> Block {
    block.clone()
}
898
899 fn is_recursive(&self, name: &str, func: &ast::Function) -> bool {
901 if let Some(body) = &func.body {
902 self.block_calls_function(name, body)
903 } else {
904 false
905 }
906 }
907
908 fn block_calls_function(&self, name: &str, block: &Block) -> bool {
909 for stmt in &block.stmts {
910 if self.stmt_calls_function(name, stmt) {
911 return true;
912 }
913 }
914 if let Some(expr) = &block.expr {
915 if self.expr_calls_function(name, expr) {
916 return true;
917 }
918 }
919 false
920 }
921
922 fn stmt_calls_function(&self, name: &str, stmt: &Stmt) -> bool {
923 match stmt {
924 Stmt::Let {
925 init: Some(expr), ..
926 } => self.expr_calls_function(name, expr),
927 Stmt::Expr(expr) | Stmt::Semi(expr) => self.expr_calls_function(name, expr),
928 _ => false,
929 }
930 }
931
932 fn expr_calls_function(&self, name: &str, expr: &Expr) -> bool {
933 match expr {
934 Expr::Call { func, args } => {
935 if let Expr::Path(path) = func.as_ref() {
936 if path.segments.last().map(|s| s.ident.name.as_str()) == Some(name) {
937 return true;
938 }
939 }
940 args.iter().any(|a| self.expr_calls_function(name, a))
941 }
942 Expr::Binary { left, right, .. } => {
943 self.expr_calls_function(name, left) || self.expr_calls_function(name, right)
944 }
945 Expr::Unary { expr, .. } => self.expr_calls_function(name, expr),
946 Expr::If {
947 condition,
948 then_branch,
949 else_branch,
950 } => {
951 self.expr_calls_function(name, condition)
952 || self.block_calls_function(name, then_branch)
953 || else_branch
954 .as_ref()
955 .map(|e| self.expr_calls_function(name, e))
956 .unwrap_or(false)
957 }
958 Expr::While {
959 label,
960 condition,
961 body,
962 } => self.expr_calls_function(name, condition) || self.block_calls_function(name, body),
963 Expr::Block(block) => self.block_calls_function(name, block),
964 Expr::Return(Some(e)) => self.expr_calls_function(name, e),
965 _ => false,
966 }
967 }
968
969 fn optimize_function(&mut self, func: &ast::Function) -> ast::Function {
971 self.cse_counter = 0;
973
974 let body = if let Some(body) = &func.body {
975 let optimized = match self.level {
977 OptLevel::None => body.clone(),
978 OptLevel::Basic => {
979 let b = self.pass_constant_fold_block(body);
980 self.pass_dead_code_block(&b)
981 }
982 OptLevel::Standard | OptLevel::Size => {
983 let b = self.pass_constant_fold_block(body);
984 let b = self.pass_inline_block(&b); let b = self.pass_strength_reduce_block(&b);
986 let b = self.pass_licm_block(&b); let b = self.pass_cse_block(&b); let b = self.pass_dead_code_block(&b);
989 self.pass_simplify_branches_block(&b)
990 }
991 OptLevel::Aggressive => {
992 let mut b = body.clone();
994 for _ in 0..3 {
995 b = self.pass_constant_fold_block(&b);
996 b = self.pass_inline_block(&b); b = self.pass_strength_reduce_block(&b);
998 b = self.pass_loop_unroll_block(&b); b = self.pass_licm_block(&b); b = self.pass_cse_block(&b); b = self.pass_dead_code_block(&b);
1002 b = self.pass_simplify_branches_block(&b);
1003 }
1004 b
1005 }
1006 };
1007 Some(optimized)
1008 } else {
1009 None
1010 };
1011
1012 ast::Function {
1013 visibility: func.visibility.clone(),
1014 is_async: func.is_async,
1015 is_const: func.is_const,
1016 is_unsafe: func.is_unsafe,
1017 attrs: func.attrs.clone(),
1018 name: func.name.clone(),
1019 aspect: func.aspect,
1020 generics: func.generics.clone(),
1021 params: func.params.clone(),
1022 return_type: func.return_type.clone(),
1023 where_clause: func.where_clause.clone(),
1024 body,
1025 }
1026 }
1027
1028 fn pass_constant_fold_block(&mut self, block: &Block) -> Block {
1033 let stmts = block
1034 .stmts
1035 .iter()
1036 .map(|s| self.pass_constant_fold_stmt(s))
1037 .collect();
1038 let expr = block
1039 .expr
1040 .as_ref()
1041 .map(|e| Box::new(self.pass_constant_fold_expr(e)));
1042 Block { stmts, expr }
1043 }
1044
1045 fn pass_constant_fold_stmt(&mut self, stmt: &Stmt) -> Stmt {
1046 match stmt {
1047 Stmt::Let {
1048 pattern, ty, init, ..
1049 } => Stmt::Let {
1050 pattern: pattern.clone(),
1051 ty: ty.clone(),
1052 init: init.as_ref().map(|e| self.pass_constant_fold_expr(e)),
1053 },
1054 Stmt::LetElse {
1055 pattern,
1056 ty,
1057 init,
1058 else_branch,
1059 } => Stmt::LetElse {
1060 pattern: pattern.clone(),
1061 ty: ty.clone(),
1062 init: self.pass_constant_fold_expr(init),
1063 else_branch: Box::new(self.pass_constant_fold_expr(else_branch)),
1064 },
1065 Stmt::Expr(expr) => Stmt::Expr(self.pass_constant_fold_expr(expr)),
1066 Stmt::Semi(expr) => Stmt::Semi(self.pass_constant_fold_expr(expr)),
1067 Stmt::Item(item) => Stmt::Item(item.clone()),
1068 }
1069 }
1070
1071 fn pass_constant_fold_expr(&mut self, expr: &Expr) -> Expr {
1072 match expr {
1073 Expr::Binary { op, left, right } => {
1074 let left = Box::new(self.pass_constant_fold_expr(left));
1075 let right = Box::new(self.pass_constant_fold_expr(right));
1076
1077 if let (Some(l), Some(r)) = (self.as_int(&left), self.as_int(&right)) {
1079 if let Some(result) = self.fold_binary(op.clone(), l, r) {
1080 self.stats.constants_folded += 1;
1081 return Expr::Literal(Literal::Int {
1082 value: result.to_string(),
1083 base: NumBase::Decimal,
1084 suffix: None,
1085 });
1086 }
1087 }
1088
1089 Expr::Binary {
1090 op: op.clone(),
1091 left,
1092 right,
1093 }
1094 }
1095 Expr::Unary { op, expr: inner } => {
1096 let inner = Box::new(self.pass_constant_fold_expr(inner));
1097
1098 if let Some(v) = self.as_int(&inner) {
1099 if let Some(result) = self.fold_unary(*op, v) {
1100 self.stats.constants_folded += 1;
1101 return Expr::Literal(Literal::Int {
1102 value: result.to_string(),
1103 base: NumBase::Decimal,
1104 suffix: None,
1105 });
1106 }
1107 }
1108
1109 Expr::Unary {
1110 op: *op,
1111 expr: inner,
1112 }
1113 }
1114 Expr::If {
1115 condition,
1116 then_branch,
1117 else_branch,
1118 } => {
1119 let condition = Box::new(self.pass_constant_fold_expr(condition));
1120 let then_branch = self.pass_constant_fold_block(then_branch);
1121 let else_branch = else_branch
1122 .as_ref()
1123 .map(|e| Box::new(self.pass_constant_fold_expr(e)));
1124
1125 if let Some(cond) = self.as_bool(&condition) {
1127 self.stats.branches_simplified += 1;
1128 if cond {
1129 return Expr::Block(then_branch);
1130 } else if let Some(else_expr) = else_branch {
1131 return *else_expr;
1132 } else {
1133 return Expr::Literal(Literal::Bool(false));
1134 }
1135 }
1136
1137 Expr::If {
1138 condition,
1139 then_branch,
1140 else_branch,
1141 }
1142 }
1143 Expr::While {
1144 label,
1145 condition,
1146 body,
1147 } => {
1148 let condition = Box::new(self.pass_constant_fold_expr(condition));
1149 let body = self.pass_constant_fold_block(body);
1150
1151 if let Some(false) = self.as_bool(&condition) {
1153 self.stats.branches_simplified += 1;
1154 return Expr::Block(Block {
1155 stmts: vec![],
1156 expr: None,
1157 });
1158 }
1159
1160 Expr::While {
1161 label: label.clone(),
1162 condition,
1163 body,
1164 }
1165 }
1166 Expr::Block(block) => Expr::Block(self.pass_constant_fold_block(block)),
1167 Expr::Call { func, args } => {
1168 let args = args
1169 .iter()
1170 .map(|a| self.pass_constant_fold_expr(a))
1171 .collect();
1172 Expr::Call {
1173 func: func.clone(),
1174 args,
1175 }
1176 }
1177 Expr::Return(e) => Expr::Return(
1178 e.as_ref()
1179 .map(|e| Box::new(self.pass_constant_fold_expr(e))),
1180 ),
1181 Expr::Assign { target, value } => {
1182 let value = Box::new(self.pass_constant_fold_expr(value));
1183 Expr::Assign {
1184 target: target.clone(),
1185 value,
1186 }
1187 }
1188 Expr::Index { expr: e, index } => {
1189 let e = Box::new(self.pass_constant_fold_expr(e));
1190 let index = Box::new(self.pass_constant_fold_expr(index));
1191 Expr::Index { expr: e, index }
1192 }
1193 Expr::Array(elements) => {
1194 let elements = elements
1195 .iter()
1196 .map(|e| self.pass_constant_fold_expr(e))
1197 .collect();
1198 Expr::Array(elements)
1199 }
1200 other => other.clone(),
1201 }
1202 }
1203
1204 fn as_int(&self, expr: &Expr) -> Option<i64> {
1205 match expr {
1206 Expr::Literal(Literal::Int { value, .. }) => value.parse().ok(),
1207 Expr::Literal(Literal::Bool(b)) => Some(if *b { 1 } else { 0 }),
1208 _ => None,
1209 }
1210 }
1211
1212 fn as_bool(&self, expr: &Expr) -> Option<bool> {
1213 match expr {
1214 Expr::Literal(Literal::Bool(b)) => Some(*b),
1215 Expr::Literal(Literal::Int { value, .. }) => value.parse::<i64>().ok().map(|v| v != 0),
1216 _ => None,
1217 }
1218 }
1219
1220 fn fold_binary(&self, op: BinOp, l: i64, r: i64) -> Option<i64> {
1221 match op {
1222 BinOp::Add => Some(l.wrapping_add(r)),
1223 BinOp::Sub => Some(l.wrapping_sub(r)),
1224 BinOp::Mul => Some(l.wrapping_mul(r)),
1225 BinOp::Div if r != 0 => Some(l / r),
1226 BinOp::Rem if r != 0 => Some(l % r),
1227 BinOp::BitAnd => Some(l & r),
1228 BinOp::BitOr => Some(l | r),
1229 BinOp::BitXor => Some(l ^ r),
1230 BinOp::Shl => Some(l << (r & 63)),
1231 BinOp::Shr => Some(l >> (r & 63)),
1232 BinOp::Eq => Some(if l == r { 1 } else { 0 }),
1233 BinOp::Ne => Some(if l != r { 1 } else { 0 }),
1234 BinOp::Lt => Some(if l < r { 1 } else { 0 }),
1235 BinOp::Le => Some(if l <= r { 1 } else { 0 }),
1236 BinOp::Gt => Some(if l > r { 1 } else { 0 }),
1237 BinOp::Ge => Some(if l >= r { 1 } else { 0 }),
1238 BinOp::And => Some(if l != 0 && r != 0 { 1 } else { 0 }),
1239 BinOp::Or => Some(if l != 0 || r != 0 { 1 } else { 0 }),
1240 _ => None,
1241 }
1242 }
1243
1244 fn fold_unary(&self, op: UnaryOp, v: i64) -> Option<i64> {
1245 match op {
1246 UnaryOp::Neg => Some(-v),
1247 UnaryOp::Not => Some(if v == 0 { 1 } else { 0 }),
1248 _ => None,
1249 }
1250 }
1251
1252 fn pass_strength_reduce_block(&mut self, block: &Block) -> Block {
1257 let stmts = block
1258 .stmts
1259 .iter()
1260 .map(|s| self.pass_strength_reduce_stmt(s))
1261 .collect();
1262 let expr = block
1263 .expr
1264 .as_ref()
1265 .map(|e| Box::new(self.pass_strength_reduce_expr(e)));
1266 Block { stmts, expr }
1267 }
1268
1269 fn pass_strength_reduce_stmt(&mut self, stmt: &Stmt) -> Stmt {
1270 match stmt {
1271 Stmt::Let {
1272 pattern, ty, init, ..
1273 } => Stmt::Let {
1274 pattern: pattern.clone(),
1275 ty: ty.clone(),
1276 init: init.as_ref().map(|e| self.pass_strength_reduce_expr(e)),
1277 },
1278 Stmt::LetElse {
1279 pattern,
1280 ty,
1281 init,
1282 else_branch,
1283 } => Stmt::LetElse {
1284 pattern: pattern.clone(),
1285 ty: ty.clone(),
1286 init: self.pass_strength_reduce_expr(init),
1287 else_branch: Box::new(self.pass_strength_reduce_expr(else_branch)),
1288 },
1289 Stmt::Expr(expr) => Stmt::Expr(self.pass_strength_reduce_expr(expr)),
1290 Stmt::Semi(expr) => Stmt::Semi(self.pass_strength_reduce_expr(expr)),
1291 Stmt::Item(item) => Stmt::Item(item.clone()),
1292 }
1293 }
1294
1295 fn pass_strength_reduce_expr(&mut self, expr: &Expr) -> Expr {
1296 match expr {
1297 Expr::Binary { op, left, right } => {
1298 let left = Box::new(self.pass_strength_reduce_expr(left));
1299 let right = Box::new(self.pass_strength_reduce_expr(right));
1300
1301 if *op == BinOp::Mul {
1303 if let Some(n) = self.as_int(&right) {
1304 if n > 0 && (n as u64).is_power_of_two() {
1305 self.stats.strength_reductions += 1;
1306 let shift = (n as u64).trailing_zeros() as i64;
1307 return Expr::Binary {
1308 op: BinOp::Shl,
1309 left,
1310 right: Box::new(Expr::Literal(Literal::Int {
1311 value: shift.to_string(),
1312 base: NumBase::Decimal,
1313 suffix: None,
1314 })),
1315 };
1316 }
1317 }
1318 if let Some(n) = self.as_int(&left) {
1319 if n > 0 && (n as u64).is_power_of_two() {
1320 self.stats.strength_reductions += 1;
1321 let shift = (n as u64).trailing_zeros() as i64;
1322 return Expr::Binary {
1323 op: BinOp::Shl,
1324 left: right,
1325 right: Box::new(Expr::Literal(Literal::Int {
1326 value: shift.to_string(),
1327 base: NumBase::Decimal,
1328 suffix: None,
1329 })),
1330 };
1331 }
1332 }
1333 }
1334
1335 if let Some(n) = self.as_int(&right) {
1337 match (op, n) {
1338 (BinOp::Add | BinOp::Sub | BinOp::BitOr | BinOp::BitXor, 0)
1339 | (BinOp::Mul | BinOp::Div, 1)
1340 | (BinOp::Shl | BinOp::Shr, 0) => {
1341 self.stats.strength_reductions += 1;
1342 return *left;
1343 }
1344 (BinOp::Mul, 0) | (BinOp::BitAnd, 0) => {
1345 self.stats.strength_reductions += 1;
1346 return Expr::Literal(Literal::Int {
1347 value: "0".to_string(),
1348 base: NumBase::Decimal,
1349 suffix: None,
1350 });
1351 }
1352 _ => {}
1353 }
1354 }
1355
1356 if let Some(n) = self.as_int(&left) {
1358 match (op, n) {
1359 (BinOp::Add | BinOp::BitOr | BinOp::BitXor, 0) | (BinOp::Mul, 1) => {
1360 self.stats.strength_reductions += 1;
1361 return *right;
1362 }
1363 (BinOp::Mul, 0) | (BinOp::BitAnd, 0) => {
1364 self.stats.strength_reductions += 1;
1365 return Expr::Literal(Literal::Int {
1366 value: "0".to_string(),
1367 base: NumBase::Decimal,
1368 suffix: None,
1369 });
1370 }
1371 _ => {}
1372 }
1373 }
1374
1375 Expr::Binary {
1376 op: op.clone(),
1377 left,
1378 right,
1379 }
1380 }
1381 Expr::Unary { op, expr: inner } => {
1382 let inner = Box::new(self.pass_strength_reduce_expr(inner));
1383
1384 if *op == UnaryOp::Neg {
1386 if let Expr::Unary {
1387 op: UnaryOp::Neg,
1388 expr: inner2,
1389 } = inner.as_ref()
1390 {
1391 self.stats.strength_reductions += 1;
1392 return *inner2.clone();
1393 }
1394 }
1395
1396 if *op == UnaryOp::Not {
1398 if let Expr::Unary {
1399 op: UnaryOp::Not,
1400 expr: inner2,
1401 } = inner.as_ref()
1402 {
1403 self.stats.strength_reductions += 1;
1404 return *inner2.clone();
1405 }
1406 }
1407
1408 Expr::Unary {
1409 op: *op,
1410 expr: inner,
1411 }
1412 }
1413 Expr::If {
1414 condition,
1415 then_branch,
1416 else_branch,
1417 } => {
1418 let condition = Box::new(self.pass_strength_reduce_expr(condition));
1419 let then_branch = self.pass_strength_reduce_block(then_branch);
1420 let else_branch = else_branch
1421 .as_ref()
1422 .map(|e| Box::new(self.pass_strength_reduce_expr(e)));
1423 Expr::If {
1424 condition,
1425 then_branch,
1426 else_branch,
1427 }
1428 }
1429 Expr::While {
1430 label,
1431 condition,
1432 body,
1433 } => {
1434 let condition = Box::new(self.pass_strength_reduce_expr(condition));
1435 let body = self.pass_strength_reduce_block(body);
1436 Expr::While {
1437 label: label.clone(),
1438 condition,
1439 body,
1440 }
1441 }
1442 Expr::Block(block) => Expr::Block(self.pass_strength_reduce_block(block)),
1443 Expr::Call { func, args } => {
1444 let args = args
1445 .iter()
1446 .map(|a| self.pass_strength_reduce_expr(a))
1447 .collect();
1448 Expr::Call {
1449 func: func.clone(),
1450 args,
1451 }
1452 }
1453 Expr::Return(e) => Expr::Return(
1454 e.as_ref()
1455 .map(|e| Box::new(self.pass_strength_reduce_expr(e))),
1456 ),
1457 Expr::Assign { target, value } => {
1458 let value = Box::new(self.pass_strength_reduce_expr(value));
1459 Expr::Assign {
1460 target: target.clone(),
1461 value,
1462 }
1463 }
1464 other => other.clone(),
1465 }
1466 }
1467
1468 fn pass_dead_code_block(&mut self, block: &Block) -> Block {
1473 let mut stmts = Vec::new();
1475 let mut found_return = false;
1476
1477 for stmt in &block.stmts {
1478 if found_return {
1479 self.stats.dead_code_eliminated += 1;
1480 continue;
1481 }
1482 let stmt = self.pass_dead_code_stmt(stmt);
1483 if self.stmt_returns(&stmt) {
1484 found_return = true;
1485 }
1486 stmts.push(stmt);
1487 }
1488
1489 let expr = if found_return {
1491 if block.expr.is_some() {
1492 self.stats.dead_code_eliminated += 1;
1493 }
1494 None
1495 } else {
1496 block
1497 .expr
1498 .as_ref()
1499 .map(|e| Box::new(self.pass_dead_code_expr(e)))
1500 };
1501
1502 Block { stmts, expr }
1503 }
1504
1505 fn pass_dead_code_stmt(&mut self, stmt: &Stmt) -> Stmt {
1506 match stmt {
1507 Stmt::Let {
1508 pattern, ty, init, ..
1509 } => Stmt::Let {
1510 pattern: pattern.clone(),
1511 ty: ty.clone(),
1512 init: init.as_ref().map(|e| self.pass_dead_code_expr(e)),
1513 },
1514 Stmt::LetElse {
1515 pattern,
1516 ty,
1517 init,
1518 else_branch,
1519 } => Stmt::LetElse {
1520 pattern: pattern.clone(),
1521 ty: ty.clone(),
1522 init: self.pass_dead_code_expr(init),
1523 else_branch: Box::new(self.pass_dead_code_expr(else_branch)),
1524 },
1525 Stmt::Expr(expr) => Stmt::Expr(self.pass_dead_code_expr(expr)),
1526 Stmt::Semi(expr) => Stmt::Semi(self.pass_dead_code_expr(expr)),
1527 Stmt::Item(item) => Stmt::Item(item.clone()),
1528 }
1529 }
1530
1531 fn pass_dead_code_expr(&mut self, expr: &Expr) -> Expr {
1532 match expr {
1533 Expr::If {
1534 condition,
1535 then_branch,
1536 else_branch,
1537 } => {
1538 let condition = Box::new(self.pass_dead_code_expr(condition));
1539 let then_branch = self.pass_dead_code_block(then_branch);
1540 let else_branch = else_branch
1541 .as_ref()
1542 .map(|e| Box::new(self.pass_dead_code_expr(e)));
1543 Expr::If {
1544 condition,
1545 then_branch,
1546 else_branch,
1547 }
1548 }
1549 Expr::While {
1550 label,
1551 condition,
1552 body,
1553 } => {
1554 let condition = Box::new(self.pass_dead_code_expr(condition));
1555 let body = self.pass_dead_code_block(body);
1556 Expr::While {
1557 label: label.clone(),
1558 condition,
1559 body,
1560 }
1561 }
1562 Expr::Block(block) => Expr::Block(self.pass_dead_code_block(block)),
1563 other => other.clone(),
1564 }
1565 }
1566
1567 fn stmt_returns(&self, stmt: &Stmt) -> bool {
1568 match stmt {
1569 Stmt::Expr(expr) | Stmt::Semi(expr) => self.expr_returns(expr),
1570 _ => false,
1571 }
1572 }
1573
1574 fn expr_returns(&self, expr: &Expr) -> bool {
1575 match expr {
1576 Expr::Return(_) => true,
1577 Expr::Block(block) => {
1578 block.stmts.iter().any(|s| self.stmt_returns(s))
1579 || block
1580 .expr
1581 .as_ref()
1582 .map(|e| self.expr_returns(e))
1583 .unwrap_or(false)
1584 }
1585 _ => false,
1586 }
1587 }
1588
1589 fn pass_simplify_branches_block(&mut self, block: &Block) -> Block {
1594 let stmts = block
1595 .stmts
1596 .iter()
1597 .map(|s| self.pass_simplify_branches_stmt(s))
1598 .collect();
1599 let expr = block
1600 .expr
1601 .as_ref()
1602 .map(|e| Box::new(self.pass_simplify_branches_expr(e)));
1603 Block { stmts, expr }
1604 }
1605
1606 fn pass_simplify_branches_stmt(&mut self, stmt: &Stmt) -> Stmt {
1607 match stmt {
1608 Stmt::Let {
1609 pattern, ty, init, ..
1610 } => Stmt::Let {
1611 pattern: pattern.clone(),
1612 ty: ty.clone(),
1613 init: init.as_ref().map(|e| self.pass_simplify_branches_expr(e)),
1614 },
1615 Stmt::LetElse {
1616 pattern,
1617 ty,
1618 init,
1619 else_branch,
1620 } => Stmt::LetElse {
1621 pattern: pattern.clone(),
1622 ty: ty.clone(),
1623 init: self.pass_simplify_branches_expr(init),
1624 else_branch: Box::new(self.pass_simplify_branches_expr(else_branch)),
1625 },
1626 Stmt::Expr(expr) => Stmt::Expr(self.pass_simplify_branches_expr(expr)),
1627 Stmt::Semi(expr) => Stmt::Semi(self.pass_simplify_branches_expr(expr)),
1628 Stmt::Item(item) => Stmt::Item(item.clone()),
1629 }
1630 }
1631
1632 fn pass_simplify_branches_expr(&mut self, expr: &Expr) -> Expr {
1633 match expr {
1634 Expr::If {
1635 condition,
1636 then_branch,
1637 else_branch,
1638 } => {
1639 let condition = Box::new(self.pass_simplify_branches_expr(condition));
1640 let then_branch = self.pass_simplify_branches_block(then_branch);
1641 let else_branch = else_branch
1642 .as_ref()
1643 .map(|e| Box::new(self.pass_simplify_branches_expr(e)));
1644
1645 if let Expr::Unary {
1647 op: UnaryOp::Not,
1648 expr: inner,
1649 } = condition.as_ref()
1650 {
1651 if let Some(else_expr) = &else_branch {
1652 self.stats.branches_simplified += 1;
1653 let new_else = Some(Box::new(Expr::Block(then_branch)));
1654 let new_then = match else_expr.as_ref() {
1655 Expr::Block(b) => b.clone(),
1656 other => Block {
1657 stmts: vec![],
1658 expr: Some(Box::new(other.clone())),
1659 },
1660 };
1661 return Expr::If {
1662 condition: inner.clone(),
1663 then_branch: new_then,
1664 else_branch: new_else,
1665 };
1666 }
1667 }
1668
1669 Expr::If {
1670 condition,
1671 then_branch,
1672 else_branch,
1673 }
1674 }
1675 Expr::While {
1676 label,
1677 condition,
1678 body,
1679 } => {
1680 let condition = Box::new(self.pass_simplify_branches_expr(condition));
1681 let body = self.pass_simplify_branches_block(body);
1682 Expr::While {
1683 label: label.clone(),
1684 condition,
1685 body,
1686 }
1687 }
1688 Expr::Block(block) => Expr::Block(self.pass_simplify_branches_block(block)),
1689 Expr::Binary { op, left, right } => {
1690 let left = Box::new(self.pass_simplify_branches_expr(left));
1691 let right = Box::new(self.pass_simplify_branches_expr(right));
1692 Expr::Binary {
1693 op: op.clone(),
1694 left,
1695 right,
1696 }
1697 }
1698 Expr::Unary { op, expr: inner } => {
1699 let inner = Box::new(self.pass_simplify_branches_expr(inner));
1700 Expr::Unary {
1701 op: *op,
1702 expr: inner,
1703 }
1704 }
1705 Expr::Call { func, args } => {
1706 let args = args
1707 .iter()
1708 .map(|a| self.pass_simplify_branches_expr(a))
1709 .collect();
1710 Expr::Call {
1711 func: func.clone(),
1712 args,
1713 }
1714 }
1715 Expr::Return(e) => Expr::Return(
1716 e.as_ref()
1717 .map(|e| Box::new(self.pass_simplify_branches_expr(e))),
1718 ),
1719 other => other.clone(),
1720 }
1721 }
1722
1723 fn should_inline(&self, func: &ast::Function) -> bool {
1729 if self.recursive_functions.contains(&func.name.name) {
1731 return false;
1732 }
1733
1734 if let Some(body) = &func.body {
1736 let stmt_count = self.count_stmts_in_block(body);
1737 stmt_count <= 10
1739 } else {
1740 false
1741 }
1742 }
1743
1744 fn count_stmts_in_block(&self, block: &Block) -> usize {
1746 let mut count = block.stmts.len();
1747 if block.expr.is_some() {
1748 count += 1;
1749 }
1750 for stmt in &block.stmts {
1752 count += self.count_stmts_in_stmt(stmt);
1753 }
1754 count
1755 }
1756
1757 fn count_stmts_in_stmt(&self, stmt: &Stmt) -> usize {
1758 match stmt {
1759 Stmt::Expr(e) | Stmt::Semi(e) => self.count_stmts_in_expr(e),
1760 Stmt::Let { init: Some(e), .. } => self.count_stmts_in_expr(e),
1761 _ => 0,
1762 }
1763 }
1764
1765 fn count_stmts_in_expr(&self, expr: &Expr) -> usize {
1766 match expr {
1767 Expr::If {
1768 then_branch,
1769 else_branch,
1770 ..
1771 } => {
1772 let mut count = self.count_stmts_in_block(then_branch);
1773 if let Some(else_expr) = else_branch {
1774 count += self.count_stmts_in_expr(else_expr);
1775 }
1776 count
1777 }
1778 Expr::While { body, .. } => self.count_stmts_in_block(body),
1779 Expr::Block(block) => self.count_stmts_in_block(block),
1780 _ => 0,
1781 }
1782 }
1783
1784 fn inline_call(&mut self, func: &ast::Function, args: &[Expr]) -> Option<Expr> {
1786 let body = func.body.as_ref()?;
1787
1788 let mut param_map: HashMap<String, Expr> = HashMap::new();
1790 for (param, arg) in func.params.iter().zip(args.iter()) {
1791 if let Pattern::Ident { name, .. } = ¶m.pattern {
1792 param_map.insert(name.name.clone(), arg.clone());
1793 }
1794 }
1795
1796 let inlined_body = self.substitute_params_in_block(body, ¶m_map);
1798
1799 self.stats.functions_inlined += 1;
1800
1801 if inlined_body.stmts.is_empty() {
1804 if let Some(expr) = inlined_body.expr {
1805 if let Expr::Return(Some(inner)) = expr.as_ref() {
1807 return Some(inner.as_ref().clone());
1808 }
1809 return Some(*expr);
1810 }
1811 }
1812
1813 Some(Expr::Block(inlined_body))
1814 }
1815
1816 fn substitute_params_in_block(
1818 &self,
1819 block: &Block,
1820 param_map: &HashMap<String, Expr>,
1821 ) -> Block {
1822 let stmts = block
1823 .stmts
1824 .iter()
1825 .map(|s| self.substitute_params_in_stmt(s, param_map))
1826 .collect();
1827 let expr = block
1828 .expr
1829 .as_ref()
1830 .map(|e| Box::new(self.substitute_params_in_expr(e, param_map)));
1831 Block { stmts, expr }
1832 }
1833
1834 fn substitute_params_in_stmt(&self, stmt: &Stmt, param_map: &HashMap<String, Expr>) -> Stmt {
1835 match stmt {
1836 Stmt::Let { pattern, ty, init } => Stmt::Let {
1837 pattern: pattern.clone(),
1838 ty: ty.clone(),
1839 init: init
1840 .as_ref()
1841 .map(|e| self.substitute_params_in_expr(e, param_map)),
1842 },
1843 Stmt::LetElse {
1844 pattern,
1845 ty,
1846 init,
1847 else_branch,
1848 } => Stmt::LetElse {
1849 pattern: pattern.clone(),
1850 ty: ty.clone(),
1851 init: self.substitute_params_in_expr(init, param_map),
1852 else_branch: Box::new(self.substitute_params_in_expr(else_branch, param_map)),
1853 },
1854 Stmt::Expr(e) => Stmt::Expr(self.substitute_params_in_expr(e, param_map)),
1855 Stmt::Semi(e) => Stmt::Semi(self.substitute_params_in_expr(e, param_map)),
1856 Stmt::Item(item) => Stmt::Item(item.clone()),
1857 }
1858 }
1859
1860 fn substitute_params_in_expr(&self, expr: &Expr, param_map: &HashMap<String, Expr>) -> Expr {
1861 match expr {
1862 Expr::Path(path) => {
1863 if path.segments.len() == 1 {
1865 let name = &path.segments[0].ident.name;
1866 if let Some(arg) = param_map.get(name) {
1867 return arg.clone();
1868 }
1869 }
1870 expr.clone()
1871 }
1872 Expr::Binary { op, left, right } => Expr::Binary {
1873 op: op.clone(),
1874 left: Box::new(self.substitute_params_in_expr(left, param_map)),
1875 right: Box::new(self.substitute_params_in_expr(right, param_map)),
1876 },
1877 Expr::Unary { op, expr: inner } => Expr::Unary {
1878 op: *op,
1879 expr: Box::new(self.substitute_params_in_expr(inner, param_map)),
1880 },
1881 Expr::If {
1882 condition,
1883 then_branch,
1884 else_branch,
1885 } => Expr::If {
1886 condition: Box::new(self.substitute_params_in_expr(condition, param_map)),
1887 then_branch: self.substitute_params_in_block(then_branch, param_map),
1888 else_branch: else_branch
1889 .as_ref()
1890 .map(|e| Box::new(self.substitute_params_in_expr(e, param_map))),
1891 },
1892 Expr::While {
1893 label,
1894 condition,
1895 body,
1896 } => Expr::While {
1897 label: label.clone(),
1898 condition: Box::new(self.substitute_params_in_expr(condition, param_map)),
1899 body: self.substitute_params_in_block(body, param_map),
1900 },
1901 Expr::Block(block) => Expr::Block(self.substitute_params_in_block(block, param_map)),
1902 Expr::Call { func, args } => Expr::Call {
1903 func: Box::new(self.substitute_params_in_expr(func, param_map)),
1904 args: args
1905 .iter()
1906 .map(|a| self.substitute_params_in_expr(a, param_map))
1907 .collect(),
1908 },
1909 Expr::Return(e) => Expr::Return(
1910 e.as_ref()
1911 .map(|e| Box::new(self.substitute_params_in_expr(e, param_map))),
1912 ),
1913 Expr::Assign { target, value } => Expr::Assign {
1914 target: target.clone(),
1915 value: Box::new(self.substitute_params_in_expr(value, param_map)),
1916 },
1917 Expr::Index { expr: e, index } => Expr::Index {
1918 expr: Box::new(self.substitute_params_in_expr(e, param_map)),
1919 index: Box::new(self.substitute_params_in_expr(index, param_map)),
1920 },
1921 Expr::Array(elements) => Expr::Array(
1922 elements
1923 .iter()
1924 .map(|e| self.substitute_params_in_expr(e, param_map))
1925 .collect(),
1926 ),
1927 other => other.clone(),
1928 }
1929 }
1930
1931 fn pass_inline_block(&mut self, block: &Block) -> Block {
1932 let stmts = block
1933 .stmts
1934 .iter()
1935 .map(|s| self.pass_inline_stmt(s))
1936 .collect();
1937 let expr = block
1938 .expr
1939 .as_ref()
1940 .map(|e| Box::new(self.pass_inline_expr(e)));
1941 Block { stmts, expr }
1942 }
1943
1944 fn pass_inline_stmt(&mut self, stmt: &Stmt) -> Stmt {
1945 match stmt {
1946 Stmt::Let { pattern, ty, init } => Stmt::Let {
1947 pattern: pattern.clone(),
1948 ty: ty.clone(),
1949 init: init.as_ref().map(|e| self.pass_inline_expr(e)),
1950 },
1951 Stmt::LetElse {
1952 pattern,
1953 ty,
1954 init,
1955 else_branch,
1956 } => Stmt::LetElse {
1957 pattern: pattern.clone(),
1958 ty: ty.clone(),
1959 init: self.pass_inline_expr(init),
1960 else_branch: Box::new(self.pass_inline_expr(else_branch)),
1961 },
1962 Stmt::Expr(e) => Stmt::Expr(self.pass_inline_expr(e)),
1963 Stmt::Semi(e) => Stmt::Semi(self.pass_inline_expr(e)),
1964 Stmt::Item(item) => Stmt::Item(item.clone()),
1965 }
1966 }
1967
1968 fn pass_inline_expr(&mut self, expr: &Expr) -> Expr {
1969 match expr {
1970 Expr::Call { func, args } => {
1971 let args: Vec<Expr> = args.iter().map(|a| self.pass_inline_expr(a)).collect();
1973
1974 if let Expr::Path(path) = func.as_ref() {
1976 if path.segments.len() == 1 {
1977 let func_name = &path.segments[0].ident.name;
1978 if let Some(target_func) = self.functions.get(func_name).cloned() {
1979 if self.should_inline(&target_func)
1980 && args.len() == target_func.params.len()
1981 {
1982 if let Some(inlined) = self.inline_call(&target_func, &args) {
1983 return inlined;
1984 }
1985 }
1986 }
1987 }
1988 }
1989
1990 Expr::Call {
1991 func: func.clone(),
1992 args,
1993 }
1994 }
1995 Expr::Binary { op, left, right } => Expr::Binary {
1996 op: op.clone(),
1997 left: Box::new(self.pass_inline_expr(left)),
1998 right: Box::new(self.pass_inline_expr(right)),
1999 },
2000 Expr::Unary { op, expr: inner } => Expr::Unary {
2001 op: *op,
2002 expr: Box::new(self.pass_inline_expr(inner)),
2003 },
2004 Expr::If {
2005 condition,
2006 then_branch,
2007 else_branch,
2008 } => Expr::If {
2009 condition: Box::new(self.pass_inline_expr(condition)),
2010 then_branch: self.pass_inline_block(then_branch),
2011 else_branch: else_branch
2012 .as_ref()
2013 .map(|e| Box::new(self.pass_inline_expr(e))),
2014 },
2015 Expr::While {
2016 label,
2017 condition,
2018 body,
2019 } => Expr::While {
2020 label: label.clone(),
2021 condition: Box::new(self.pass_inline_expr(condition)),
2022 body: self.pass_inline_block(body),
2023 },
2024 Expr::Block(block) => Expr::Block(self.pass_inline_block(block)),
2025 Expr::Return(e) => Expr::Return(e.as_ref().map(|e| Box::new(self.pass_inline_expr(e)))),
2026 Expr::Assign { target, value } => Expr::Assign {
2027 target: target.clone(),
2028 value: Box::new(self.pass_inline_expr(value)),
2029 },
2030 Expr::Index { expr: e, index } => Expr::Index {
2031 expr: Box::new(self.pass_inline_expr(e)),
2032 index: Box::new(self.pass_inline_expr(index)),
2033 },
2034 Expr::Array(elements) => {
2035 Expr::Array(elements.iter().map(|e| self.pass_inline_expr(e)).collect())
2036 }
2037 other => other.clone(),
2038 }
2039 }
2040
2041 fn pass_loop_unroll_block(&mut self, block: &Block) -> Block {
2047 let stmts = block
2048 .stmts
2049 .iter()
2050 .map(|s| self.pass_loop_unroll_stmt(s))
2051 .collect();
2052 let expr = block
2053 .expr
2054 .as_ref()
2055 .map(|e| Box::new(self.pass_loop_unroll_expr(e)));
2056 Block { stmts, expr }
2057 }
2058
2059 fn pass_loop_unroll_stmt(&mut self, stmt: &Stmt) -> Stmt {
2060 match stmt {
2061 Stmt::Let { pattern, ty, init } => Stmt::Let {
2062 pattern: pattern.clone(),
2063 ty: ty.clone(),
2064 init: init.as_ref().map(|e| self.pass_loop_unroll_expr(e)),
2065 },
2066 Stmt::LetElse {
2067 pattern,
2068 ty,
2069 init,
2070 else_branch,
2071 } => Stmt::LetElse {
2072 pattern: pattern.clone(),
2073 ty: ty.clone(),
2074 init: self.pass_loop_unroll_expr(init),
2075 else_branch: Box::new(self.pass_loop_unroll_expr(else_branch)),
2076 },
2077 Stmt::Expr(e) => Stmt::Expr(self.pass_loop_unroll_expr(e)),
2078 Stmt::Semi(e) => Stmt::Semi(self.pass_loop_unroll_expr(e)),
2079 Stmt::Item(item) => Stmt::Item(item.clone()),
2080 }
2081 }
2082
2083 fn pass_loop_unroll_expr(&mut self, expr: &Expr) -> Expr {
2084 match expr {
2085 Expr::While {
2086 label,
2087 condition,
2088 body,
2089 } => {
2090 if let Some(unrolled) = self.try_unroll_loop(condition, body) {
2092 self.stats.loops_optimized += 1;
2093 return unrolled;
2094 }
2095 Expr::While {
2097 label: label.clone(),
2098 condition: Box::new(self.pass_loop_unroll_expr(condition)),
2099 body: self.pass_loop_unroll_block(body),
2100 }
2101 }
2102 Expr::If {
2103 condition,
2104 then_branch,
2105 else_branch,
2106 } => Expr::If {
2107 condition: Box::new(self.pass_loop_unroll_expr(condition)),
2108 then_branch: self.pass_loop_unroll_block(then_branch),
2109 else_branch: else_branch
2110 .as_ref()
2111 .map(|e| Box::new(self.pass_loop_unroll_expr(e))),
2112 },
2113 Expr::Block(b) => Expr::Block(self.pass_loop_unroll_block(b)),
2114 Expr::Binary { op, left, right } => Expr::Binary {
2115 op: *op,
2116 left: Box::new(self.pass_loop_unroll_expr(left)),
2117 right: Box::new(self.pass_loop_unroll_expr(right)),
2118 },
2119 Expr::Unary { op, expr: inner } => Expr::Unary {
2120 op: *op,
2121 expr: Box::new(self.pass_loop_unroll_expr(inner)),
2122 },
2123 Expr::Call { func, args } => Expr::Call {
2124 func: func.clone(),
2125 args: args.iter().map(|a| self.pass_loop_unroll_expr(a)).collect(),
2126 },
2127 Expr::Return(e) => {
2128 Expr::Return(e.as_ref().map(|e| Box::new(self.pass_loop_unroll_expr(e))))
2129 }
2130 Expr::Assign { target, value } => Expr::Assign {
2131 target: target.clone(),
2132 value: Box::new(self.pass_loop_unroll_expr(value)),
2133 },
2134 other => other.clone(),
2135 }
2136 }
2137
2138 fn try_unroll_loop(&self, condition: &Expr, body: &Block) -> Option<Expr> {
2141 let (loop_var, upper_bound) = self.extract_loop_bounds(condition)?;
2143
2144 if upper_bound > 8 || upper_bound <= 0 {
2146 return None;
2147 }
2148
2149 if !self.body_has_simple_increment(&loop_var, body) {
2151 return None;
2152 }
2153
2154 let stmt_count = body.stmts.len();
2156 if stmt_count > 5 {
2157 return None;
2158 }
2159
2160 let mut unrolled_stmts: Vec<Stmt> = Vec::new();
2162
2163 for i in 0..upper_bound {
2164 let substituted_body = self.substitute_loop_var_in_block(body, &loop_var, i);
2166
2167 for stmt in &substituted_body.stmts {
2169 if !self.is_increment_stmt(&loop_var, stmt) {
2170 unrolled_stmts.push(stmt.clone());
2171 }
2172 }
2173 }
2174
2175 Some(Expr::Block(Block {
2177 stmts: unrolled_stmts,
2178 expr: None,
2179 }))
2180 }
2181
2182 fn extract_loop_bounds(&self, condition: &Expr) -> Option<(String, i64)> {
2184 if let Expr::Binary {
2185 op: BinOp::Lt,
2186 left,
2187 right,
2188 } = condition
2189 {
2190 if let Expr::Path(path) = left.as_ref() {
2192 if path.segments.len() == 1 {
2193 let var_name = path.segments[0].ident.name.clone();
2194 if let Some(bound) = self.as_int(right) {
2196 return Some((var_name, bound));
2197 }
2198 }
2199 }
2200 }
2201 None
2202 }
2203
2204 fn body_has_simple_increment(&self, loop_var: &str, body: &Block) -> bool {
2206 for stmt in &body.stmts {
2207 if self.is_increment_stmt(loop_var, stmt) {
2208 return true;
2209 }
2210 }
2211 false
2212 }
2213
2214 fn is_increment_stmt(&self, var_name: &str, stmt: &Stmt) -> bool {
2216 match stmt {
2217 Stmt::Semi(Expr::Assign { target, value })
2218 | Stmt::Expr(Expr::Assign { target, value }) => {
2219 if let Expr::Path(path) = target.as_ref() {
2221 if path.segments.len() == 1 && path.segments[0].ident.name == var_name {
2222 if let Expr::Binary {
2224 op: BinOp::Add,
2225 left,
2226 right,
2227 } = value.as_ref()
2228 {
2229 if let Expr::Path(lpath) = left.as_ref() {
2230 if lpath.segments.len() == 1
2231 && lpath.segments[0].ident.name == var_name
2232 {
2233 if let Some(1) = self.as_int(right) {
2234 return true;
2235 }
2236 }
2237 }
2238 }
2239 }
2240 }
2241 false
2242 }
2243 _ => false,
2244 }
2245 }
2246
2247 fn substitute_loop_var_in_block(&self, block: &Block, var_name: &str, value: i64) -> Block {
2249 let stmts = block
2250 .stmts
2251 .iter()
2252 .map(|s| self.substitute_loop_var_in_stmt(s, var_name, value))
2253 .collect();
2254 let expr = block
2255 .expr
2256 .as_ref()
2257 .map(|e| Box::new(self.substitute_loop_var_in_expr(e, var_name, value)));
2258 Block { stmts, expr }
2259 }
2260
2261 fn substitute_loop_var_in_stmt(&self, stmt: &Stmt, var_name: &str, value: i64) -> Stmt {
2262 match stmt {
2263 Stmt::Let { pattern, ty, init } => Stmt::Let {
2264 pattern: pattern.clone(),
2265 ty: ty.clone(),
2266 init: init
2267 .as_ref()
2268 .map(|e| self.substitute_loop_var_in_expr(e, var_name, value)),
2269 },
2270 Stmt::LetElse {
2271 pattern,
2272 ty,
2273 init,
2274 else_branch,
2275 } => Stmt::LetElse {
2276 pattern: pattern.clone(),
2277 ty: ty.clone(),
2278 init: self.substitute_loop_var_in_expr(init, var_name, value),
2279 else_branch: Box::new(self.substitute_loop_var_in_expr(
2280 else_branch,
2281 var_name,
2282 value,
2283 )),
2284 },
2285 Stmt::Expr(e) => Stmt::Expr(self.substitute_loop_var_in_expr(e, var_name, value)),
2286 Stmt::Semi(e) => Stmt::Semi(self.substitute_loop_var_in_expr(e, var_name, value)),
2287 Stmt::Item(item) => Stmt::Item(item.clone()),
2288 }
2289 }
2290
2291 fn substitute_loop_var_in_expr(&self, expr: &Expr, var_name: &str, value: i64) -> Expr {
2292 match expr {
2293 Expr::Path(path) => {
2294 if path.segments.len() == 1 && path.segments[0].ident.name == var_name {
2295 return Expr::Literal(Literal::Int {
2296 value: value.to_string(),
2297 base: NumBase::Decimal,
2298 suffix: None,
2299 });
2300 }
2301 expr.clone()
2302 }
2303 Expr::Binary { op, left, right } => Expr::Binary {
2304 op: *op,
2305 left: Box::new(self.substitute_loop_var_in_expr(left, var_name, value)),
2306 right: Box::new(self.substitute_loop_var_in_expr(right, var_name, value)),
2307 },
2308 Expr::Unary { op, expr: inner } => Expr::Unary {
2309 op: *op,
2310 expr: Box::new(self.substitute_loop_var_in_expr(inner, var_name, value)),
2311 },
2312 Expr::Call { func, args } => Expr::Call {
2313 func: Box::new(self.substitute_loop_var_in_expr(func, var_name, value)),
2314 args: args
2315 .iter()
2316 .map(|a| self.substitute_loop_var_in_expr(a, var_name, value))
2317 .collect(),
2318 },
2319 Expr::If {
2320 condition,
2321 then_branch,
2322 else_branch,
2323 } => Expr::If {
2324 condition: Box::new(self.substitute_loop_var_in_expr(condition, var_name, value)),
2325 then_branch: self.substitute_loop_var_in_block(then_branch, var_name, value),
2326 else_branch: else_branch
2327 .as_ref()
2328 .map(|e| Box::new(self.substitute_loop_var_in_expr(e, var_name, value))),
2329 },
2330 Expr::While {
2331 label,
2332 condition,
2333 body,
2334 } => Expr::While {
2335 label: label.clone(),
2336 condition: Box::new(self.substitute_loop_var_in_expr(condition, var_name, value)),
2337 body: self.substitute_loop_var_in_block(body, var_name, value),
2338 },
2339 Expr::Block(b) => Expr::Block(self.substitute_loop_var_in_block(b, var_name, value)),
2340 Expr::Return(e) => Expr::Return(
2341 e.as_ref()
2342 .map(|e| Box::new(self.substitute_loop_var_in_expr(e, var_name, value))),
2343 ),
2344 Expr::Assign { target, value: v } => Expr::Assign {
2345 target: Box::new(self.substitute_loop_var_in_expr(target, var_name, value)),
2346 value: Box::new(self.substitute_loop_var_in_expr(v, var_name, value)),
2347 },
2348 Expr::Index { expr: e, index } => Expr::Index {
2349 expr: Box::new(self.substitute_loop_var_in_expr(e, var_name, value)),
2350 index: Box::new(self.substitute_loop_var_in_expr(index, var_name, value)),
2351 },
2352 Expr::Array(elements) => Expr::Array(
2353 elements
2354 .iter()
2355 .map(|e| self.substitute_loop_var_in_expr(e, var_name, value))
2356 .collect(),
2357 ),
2358 other => other.clone(),
2359 }
2360 }
2361
2362 fn pass_licm_block(&mut self, block: &Block) -> Block {
2368 let stmts = block.stmts.iter().map(|s| self.pass_licm_stmt(s)).collect();
2369 let expr = block
2370 .expr
2371 .as_ref()
2372 .map(|e| Box::new(self.pass_licm_expr(e)));
2373 Block { stmts, expr }
2374 }
2375
2376 fn pass_licm_stmt(&mut self, stmt: &Stmt) -> Stmt {
2377 match stmt {
2378 Stmt::Let { pattern, ty, init } => Stmt::Let {
2379 pattern: pattern.clone(),
2380 ty: ty.clone(),
2381 init: init.as_ref().map(|e| self.pass_licm_expr(e)),
2382 },
2383 Stmt::LetElse {
2384 pattern,
2385 ty,
2386 init,
2387 else_branch,
2388 } => Stmt::LetElse {
2389 pattern: pattern.clone(),
2390 ty: ty.clone(),
2391 init: self.pass_licm_expr(init),
2392 else_branch: Box::new(self.pass_licm_expr(else_branch)),
2393 },
2394 Stmt::Expr(e) => Stmt::Expr(self.pass_licm_expr(e)),
2395 Stmt::Semi(e) => Stmt::Semi(self.pass_licm_expr(e)),
2396 Stmt::Item(item) => Stmt::Item(item.clone()),
2397 }
2398 }
2399
2400 fn pass_licm_expr(&mut self, expr: &Expr) -> Expr {
2401 match expr {
2402 Expr::While {
2403 label,
2404 condition,
2405 body,
2406 } => {
2407 let mut modified_vars = HashSet::new();
2409 self.collect_modified_vars_block(body, &mut modified_vars);
2410
2411 self.collect_modified_vars_expr(condition, &mut modified_vars);
2413
2414 let mut invariant_exprs: Vec<(String, Expr)> = Vec::new();
2416 self.find_loop_invariants(body, &modified_vars, &mut invariant_exprs);
2417
2418 if invariant_exprs.is_empty() {
2419 return Expr::While {
2421 label: label.clone(),
2422 condition: Box::new(self.pass_licm_expr(condition)),
2423 body: self.pass_licm_block(body),
2424 };
2425 }
2426
2427 let mut pre_loop_stmts: Vec<Stmt> = Vec::new();
2429 let mut substitution_map: HashMap<String, String> = HashMap::new();
2430
2431 for (original_key, invariant_expr) in &invariant_exprs {
2432 let var_name = format!("__licm_{}", self.cse_counter);
2433 self.cse_counter += 1;
2434
2435 pre_loop_stmts.push(make_cse_let(&var_name, invariant_expr.clone()));
2436 substitution_map.insert(original_key.clone(), var_name);
2437 self.stats.loops_optimized += 1;
2438 }
2439
2440 let new_body =
2442 self.replace_invariants_in_block(body, &invariant_exprs, &substitution_map);
2443
2444 let new_while = Expr::While {
2446 label: label.clone(),
2447 condition: Box::new(self.pass_licm_expr(condition)),
2448 body: self.pass_licm_block(&new_body),
2449 };
2450
2451 pre_loop_stmts.push(Stmt::Expr(new_while));
2453 Expr::Block(Block {
2454 stmts: pre_loop_stmts,
2455 expr: None,
2456 })
2457 }
2458 Expr::If {
2459 condition,
2460 then_branch,
2461 else_branch,
2462 } => Expr::If {
2463 condition: Box::new(self.pass_licm_expr(condition)),
2464 then_branch: self.pass_licm_block(then_branch),
2465 else_branch: else_branch
2466 .as_ref()
2467 .map(|e| Box::new(self.pass_licm_expr(e))),
2468 },
2469 Expr::Block(b) => Expr::Block(self.pass_licm_block(b)),
2470 Expr::Binary { op, left, right } => Expr::Binary {
2471 op: *op,
2472 left: Box::new(self.pass_licm_expr(left)),
2473 right: Box::new(self.pass_licm_expr(right)),
2474 },
2475 Expr::Unary { op, expr: inner } => Expr::Unary {
2476 op: *op,
2477 expr: Box::new(self.pass_licm_expr(inner)),
2478 },
2479 Expr::Call { func, args } => Expr::Call {
2480 func: func.clone(),
2481 args: args.iter().map(|a| self.pass_licm_expr(a)).collect(),
2482 },
2483 Expr::Return(e) => Expr::Return(e.as_ref().map(|e| Box::new(self.pass_licm_expr(e)))),
2484 Expr::Assign { target, value } => Expr::Assign {
2485 target: target.clone(),
2486 value: Box::new(self.pass_licm_expr(value)),
2487 },
2488 other => other.clone(),
2489 }
2490 }
2491
2492 fn collect_modified_vars_block(&self, block: &Block, modified: &mut HashSet<String>) {
2494 for stmt in &block.stmts {
2495 self.collect_modified_vars_stmt(stmt, modified);
2496 }
2497 if let Some(expr) = &block.expr {
2498 self.collect_modified_vars_expr(expr, modified);
2499 }
2500 }
2501
2502 fn collect_modified_vars_stmt(&self, stmt: &Stmt, modified: &mut HashSet<String>) {
2503 match stmt {
2504 Stmt::Let { pattern, init, .. } => {
2505 if let Pattern::Ident { name, .. } = pattern {
2507 modified.insert(name.name.clone());
2508 }
2509 if let Some(e) = init {
2510 self.collect_modified_vars_expr(e, modified);
2511 }
2512 }
2513 Stmt::Expr(e) | Stmt::Semi(e) => self.collect_modified_vars_expr(e, modified),
2514 _ => {}
2515 }
2516 }
2517
2518 fn collect_modified_vars_expr(&self, expr: &Expr, modified: &mut HashSet<String>) {
2519 match expr {
2520 Expr::Assign { target, value } => {
2521 if let Expr::Path(path) = target.as_ref() {
2522 if path.segments.len() == 1 {
2523 modified.insert(path.segments[0].ident.name.clone());
2524 }
2525 }
2526 self.collect_modified_vars_expr(value, modified);
2527 }
2528 Expr::Binary { left, right, .. } => {
2529 self.collect_modified_vars_expr(left, modified);
2530 self.collect_modified_vars_expr(right, modified);
2531 }
2532 Expr::Unary { expr: inner, .. } => {
2533 self.collect_modified_vars_expr(inner, modified);
2534 }
2535 Expr::If {
2536 condition,
2537 then_branch,
2538 else_branch,
2539 } => {
2540 self.collect_modified_vars_expr(condition, modified);
2541 self.collect_modified_vars_block(then_branch, modified);
2542 if let Some(e) = else_branch {
2543 self.collect_modified_vars_expr(e, modified);
2544 }
2545 }
2546 Expr::While {
2547 label,
2548 condition,
2549 body,
2550 } => {
2551 self.collect_modified_vars_expr(condition, modified);
2552 self.collect_modified_vars_block(body, modified);
2553 }
2554 Expr::Block(b) => self.collect_modified_vars_block(b, modified),
2555 Expr::Call { args, .. } => {
2556 for arg in args {
2557 self.collect_modified_vars_expr(arg, modified);
2558 }
2559 }
2560 Expr::Return(Some(e)) => self.collect_modified_vars_expr(e, modified),
2561 _ => {}
2562 }
2563 }
2564
2565 fn find_loop_invariants(
2567 &self,
2568 block: &Block,
2569 modified: &HashSet<String>,
2570 out: &mut Vec<(String, Expr)>,
2571 ) {
2572 for stmt in &block.stmts {
2573 self.find_loop_invariants_stmt(stmt, modified, out);
2574 }
2575 if let Some(expr) = &block.expr {
2576 self.find_loop_invariants_expr(expr, modified, out);
2577 }
2578 }
2579
2580 fn find_loop_invariants_stmt(
2581 &self,
2582 stmt: &Stmt,
2583 modified: &HashSet<String>,
2584 out: &mut Vec<(String, Expr)>,
2585 ) {
2586 match stmt {
2587 Stmt::Let { init: Some(e), .. } => self.find_loop_invariants_expr(e, modified, out),
2588 Stmt::Expr(e) | Stmt::Semi(e) => self.find_loop_invariants_expr(e, modified, out),
2589 _ => {}
2590 }
2591 }
2592
2593 fn find_loop_invariants_expr(
2594 &self,
2595 expr: &Expr,
2596 modified: &HashSet<String>,
2597 out: &mut Vec<(String, Expr)>,
2598 ) {
2599 match expr {
2601 Expr::Binary { left, right, .. } => {
2602 self.find_loop_invariants_expr(left, modified, out);
2603 self.find_loop_invariants_expr(right, modified, out);
2604 }
2605 Expr::Unary { expr: inner, .. } => {
2606 self.find_loop_invariants_expr(inner, modified, out);
2607 }
2608 Expr::Call { args, .. } => {
2609 for arg in args {
2610 self.find_loop_invariants_expr(arg, modified, out);
2611 }
2612 }
2613 Expr::Index { expr: e, index } => {
2614 self.find_loop_invariants_expr(e, modified, out);
2615 self.find_loop_invariants_expr(index, modified, out);
2616 }
2617 _ => {}
2618 }
2619
2620 if self.is_loop_invariant(expr, modified) && is_cse_worthy(expr) && is_pure_expr(expr) {
2622 let key = format!("{:?}", expr_hash(expr));
2623 if !out.iter().any(|(k, _)| k == &key) {
2625 out.push((key, expr.clone()));
2626 }
2627 }
2628 }
2629
2630 fn is_loop_invariant(&self, expr: &Expr, modified: &HashSet<String>) -> bool {
2632 match expr {
2633 Expr::Literal(_) => true,
2634 Expr::Path(path) => {
2635 if path.segments.len() == 1 {
2636 !modified.contains(&path.segments[0].ident.name)
2637 } else {
2638 true }
2640 }
2641 Expr::Binary { left, right, .. } => {
2642 self.is_loop_invariant(left, modified) && self.is_loop_invariant(right, modified)
2643 }
2644 Expr::Unary { expr: inner, .. } => self.is_loop_invariant(inner, modified),
2645 Expr::Index { expr: e, index } => {
2646 self.is_loop_invariant(e, modified) && self.is_loop_invariant(index, modified)
2647 }
2648 Expr::Call { .. } => false,
2650 _ => false,
2652 }
2653 }
2654
2655 fn replace_invariants_in_block(
2657 &self,
2658 block: &Block,
2659 invariants: &[(String, Expr)],
2660 subs: &HashMap<String, String>,
2661 ) -> Block {
2662 let stmts = block
2663 .stmts
2664 .iter()
2665 .map(|s| self.replace_invariants_in_stmt(s, invariants, subs))
2666 .collect();
2667 let expr = block
2668 .expr
2669 .as_ref()
2670 .map(|e| Box::new(self.replace_invariants_in_expr(e, invariants, subs)));
2671 Block { stmts, expr }
2672 }
2673
2674 fn replace_invariants_in_stmt(
2675 &self,
2676 stmt: &Stmt,
2677 invariants: &[(String, Expr)],
2678 subs: &HashMap<String, String>,
2679 ) -> Stmt {
2680 match stmt {
2681 Stmt::Let { pattern, ty, init } => Stmt::Let {
2682 pattern: pattern.clone(),
2683 ty: ty.clone(),
2684 init: init
2685 .as_ref()
2686 .map(|e| self.replace_invariants_in_expr(e, invariants, subs)),
2687 },
2688 Stmt::LetElse {
2689 pattern,
2690 ty,
2691 init,
2692 else_branch,
2693 } => Stmt::LetElse {
2694 pattern: pattern.clone(),
2695 ty: ty.clone(),
2696 init: self.replace_invariants_in_expr(init, invariants, subs),
2697 else_branch: Box::new(self.replace_invariants_in_expr(
2698 else_branch,
2699 invariants,
2700 subs,
2701 )),
2702 },
2703 Stmt::Expr(e) => Stmt::Expr(self.replace_invariants_in_expr(e, invariants, subs)),
2704 Stmt::Semi(e) => Stmt::Semi(self.replace_invariants_in_expr(e, invariants, subs)),
2705 Stmt::Item(item) => Stmt::Item(item.clone()),
2706 }
2707 }
2708
2709 fn replace_invariants_in_expr(
2710 &self,
2711 expr: &Expr,
2712 invariants: &[(String, Expr)],
2713 subs: &HashMap<String, String>,
2714 ) -> Expr {
2715 let key = format!("{:?}", expr_hash(expr));
2717 for (inv_key, inv_expr) in invariants {
2718 if &key == inv_key && expr_eq(expr, inv_expr) {
2719 if let Some(var_name) = subs.get(inv_key) {
2720 return Expr::Path(TypePath {
2721 segments: vec![PathSegment {
2722 ident: Ident {
2723 name: var_name.clone(),
2724 evidentiality: None,
2725 affect: None,
2726 span: Span { start: 0, end: 0 },
2727 },
2728 generics: None,
2729 }],
2730 });
2731 }
2732 }
2733 }
2734
2735 match expr {
2737 Expr::Binary { op, left, right } => Expr::Binary {
2738 op: *op,
2739 left: Box::new(self.replace_invariants_in_expr(left, invariants, subs)),
2740 right: Box::new(self.replace_invariants_in_expr(right, invariants, subs)),
2741 },
2742 Expr::Unary { op, expr: inner } => Expr::Unary {
2743 op: *op,
2744 expr: Box::new(self.replace_invariants_in_expr(inner, invariants, subs)),
2745 },
2746 Expr::Call { func, args } => Expr::Call {
2747 func: func.clone(),
2748 args: args
2749 .iter()
2750 .map(|a| self.replace_invariants_in_expr(a, invariants, subs))
2751 .collect(),
2752 },
2753 Expr::If {
2754 condition,
2755 then_branch,
2756 else_branch,
2757 } => Expr::If {
2758 condition: Box::new(self.replace_invariants_in_expr(condition, invariants, subs)),
2759 then_branch: self.replace_invariants_in_block(then_branch, invariants, subs),
2760 else_branch: else_branch
2761 .as_ref()
2762 .map(|e| Box::new(self.replace_invariants_in_expr(e, invariants, subs))),
2763 },
2764 Expr::While {
2765 label,
2766 condition,
2767 body,
2768 } => Expr::While {
2769 label: label.clone(),
2770 condition: Box::new(self.replace_invariants_in_expr(condition, invariants, subs)),
2771 body: self.replace_invariants_in_block(body, invariants, subs),
2772 },
2773 Expr::Block(b) => Expr::Block(self.replace_invariants_in_block(b, invariants, subs)),
2774 Expr::Return(e) => Expr::Return(
2775 e.as_ref()
2776 .map(|e| Box::new(self.replace_invariants_in_expr(e, invariants, subs))),
2777 ),
2778 Expr::Assign { target, value } => Expr::Assign {
2779 target: target.clone(),
2780 value: Box::new(self.replace_invariants_in_expr(value, invariants, subs)),
2781 },
2782 Expr::Index { expr: e, index } => Expr::Index {
2783 expr: Box::new(self.replace_invariants_in_expr(e, invariants, subs)),
2784 index: Box::new(self.replace_invariants_in_expr(index, invariants, subs)),
2785 },
2786 other => other.clone(),
2787 }
2788 }
2789
2790 fn pass_cse_block(&mut self, block: &Block) -> Block {
2795 let mut collected = Vec::new();
2797 collect_exprs_from_block(block, &mut collected);
2798
2799 let mut expr_counts: HashMap<u64, Vec<Expr>> = HashMap::new();
2801 for ce in &collected {
2802 let entry = expr_counts.entry(ce.hash).or_insert_with(Vec::new);
2803 let found = entry.iter().any(|e| expr_eq(e, &ce.expr));
2805 if !found {
2806 entry.push(ce.expr.clone());
2807 }
2808 }
2809
2810 let mut occurrence_counts: Vec<(Expr, usize)> = Vec::new();
2812 for ce in &collected {
2813 let existing = occurrence_counts
2815 .iter_mut()
2816 .find(|(e, _)| expr_eq(e, &ce.expr));
2817 if let Some((_, count)) = existing {
2818 *count += 1;
2819 } else {
2820 occurrence_counts.push((ce.expr.clone(), 1));
2821 }
2822 }
2823
2824 let candidates: Vec<Expr> = occurrence_counts
2826 .into_iter()
2827 .filter(|(_, count)| *count >= 2)
2828 .map(|(expr, _)| expr)
2829 .collect();
2830
2831 if candidates.is_empty() {
2832 return self.pass_cse_nested(block);
2834 }
2835
2836 let mut result_block = block.clone();
2838 let mut new_lets: Vec<Stmt> = Vec::new();
2839
2840 for expr in candidates {
2841 let var_name = format!("__cse_{}", self.cse_counter);
2842 self.cse_counter += 1;
2843
2844 new_lets.push(make_cse_let(&var_name, expr.clone()));
2846
2847 result_block = replace_in_block(&result_block, &expr, &var_name);
2849
2850 self.stats.expressions_deduplicated += 1;
2851 }
2852
2853 let mut final_stmts = new_lets;
2855 final_stmts.extend(result_block.stmts);
2856
2857 let result = Block {
2859 stmts: final_stmts,
2860 expr: result_block.expr,
2861 };
2862 self.pass_cse_nested(&result)
2863 }
2864
2865 fn pass_cse_nested(&mut self, block: &Block) -> Block {
2867 let stmts = block
2868 .stmts
2869 .iter()
2870 .map(|stmt| self.pass_cse_stmt(stmt))
2871 .collect();
2872 let expr = block.expr.as_ref().map(|e| Box::new(self.pass_cse_expr(e)));
2873 Block { stmts, expr }
2874 }
2875
2876 fn pass_cse_stmt(&mut self, stmt: &Stmt) -> Stmt {
2877 match stmt {
2878 Stmt::Let { pattern, ty, init } => Stmt::Let {
2879 pattern: pattern.clone(),
2880 ty: ty.clone(),
2881 init: init.as_ref().map(|e| self.pass_cse_expr(e)),
2882 },
2883 Stmt::LetElse {
2884 pattern,
2885 ty,
2886 init,
2887 else_branch,
2888 } => Stmt::LetElse {
2889 pattern: pattern.clone(),
2890 ty: ty.clone(),
2891 init: self.pass_cse_expr(init),
2892 else_branch: Box::new(self.pass_cse_expr(else_branch)),
2893 },
2894 Stmt::Expr(e) => Stmt::Expr(self.pass_cse_expr(e)),
2895 Stmt::Semi(e) => Stmt::Semi(self.pass_cse_expr(e)),
2896 Stmt::Item(item) => Stmt::Item(item.clone()),
2897 }
2898 }
2899
2900 fn pass_cse_expr(&mut self, expr: &Expr) -> Expr {
2901 match expr {
2902 Expr::If {
2903 condition,
2904 then_branch,
2905 else_branch,
2906 } => Expr::If {
2907 condition: Box::new(self.pass_cse_expr(condition)),
2908 then_branch: self.pass_cse_block(then_branch),
2909 else_branch: else_branch
2910 .as_ref()
2911 .map(|e| Box::new(self.pass_cse_expr(e))),
2912 },
2913 Expr::While {
2914 label,
2915 condition,
2916 body,
2917 } => Expr::While {
2918 label: label.clone(),
2919 condition: Box::new(self.pass_cse_expr(condition)),
2920 body: self.pass_cse_block(body),
2921 },
2922 Expr::Block(b) => Expr::Block(self.pass_cse_block(b)),
2923 Expr::Binary { op, left, right } => Expr::Binary {
2924 op: *op,
2925 left: Box::new(self.pass_cse_expr(left)),
2926 right: Box::new(self.pass_cse_expr(right)),
2927 },
2928 Expr::Unary { op, expr: inner } => Expr::Unary {
2929 op: *op,
2930 expr: Box::new(self.pass_cse_expr(inner)),
2931 },
2932 Expr::Call { func, args } => Expr::Call {
2933 func: func.clone(),
2934 args: args.iter().map(|a| self.pass_cse_expr(a)).collect(),
2935 },
2936 Expr::Return(e) => Expr::Return(e.as_ref().map(|e| Box::new(self.pass_cse_expr(e)))),
2937 Expr::Assign { target, value } => Expr::Assign {
2938 target: target.clone(),
2939 value: Box::new(self.pass_cse_expr(value)),
2940 },
2941 other => other.clone(),
2942 }
2943 }
2944}
2945
2946fn expr_hash(expr: &Expr) -> u64 {
2952 use std::collections::hash_map::DefaultHasher;
2953 use std::hash::Hasher;
2954
2955 let mut hasher = DefaultHasher::new();
2956 expr_hash_recursive(expr, &mut hasher);
2957 hasher.finish()
2958}
2959
2960fn expr_hash_recursive<H: std::hash::Hasher>(expr: &Expr, hasher: &mut H) {
2961 use std::hash::Hash;
2962
2963 std::mem::discriminant(expr).hash(hasher);
2964
2965 match expr {
2966 Expr::Literal(lit) => match lit {
2967 Literal::Int { value, .. } => value.hash(hasher),
2968 Literal::Float { value, .. } => value.hash(hasher),
2969 Literal::String(s) => s.hash(hasher),
2970 Literal::Char(c) => c.hash(hasher),
2971 Literal::Bool(b) => b.hash(hasher),
2972 _ => {}
2973 },
2974 Expr::Path(path) => {
2975 for seg in &path.segments {
2976 seg.ident.name.hash(hasher);
2977 }
2978 }
2979 Expr::Binary { op, left, right } => {
2980 std::mem::discriminant(op).hash(hasher);
2981 expr_hash_recursive(left, hasher);
2982 expr_hash_recursive(right, hasher);
2983 }
2984 Expr::Unary { op, expr } => {
2985 std::mem::discriminant(op).hash(hasher);
2986 expr_hash_recursive(expr, hasher);
2987 }
2988 Expr::Call { func, args } => {
2989 expr_hash_recursive(func, hasher);
2990 args.len().hash(hasher);
2991 for arg in args {
2992 expr_hash_recursive(arg, hasher);
2993 }
2994 }
2995 Expr::Index { expr, index } => {
2996 expr_hash_recursive(expr, hasher);
2997 expr_hash_recursive(index, hasher);
2998 }
2999 _ => {}
3000 }
3001}
3002
3003fn is_pure_expr(expr: &Expr) -> bool {
3005 match expr {
3006 Expr::Literal(_) => true,
3007 Expr::Path(_) => true,
3008 Expr::Binary { left, right, .. } => is_pure_expr(left) && is_pure_expr(right),
3009 Expr::Unary { expr, .. } => is_pure_expr(expr),
3010 Expr::If {
3011 condition,
3012 then_branch,
3013 else_branch,
3014 } => {
3015 is_pure_expr(condition)
3016 && then_branch.stmts.is_empty()
3017 && then_branch
3018 .expr
3019 .as_ref()
3020 .map(|e| is_pure_expr(e))
3021 .unwrap_or(true)
3022 && else_branch
3023 .as_ref()
3024 .map(|e| is_pure_expr(e))
3025 .unwrap_or(true)
3026 }
3027 Expr::Index { expr, index } => is_pure_expr(expr) && is_pure_expr(index),
3028 Expr::Array(elements) => elements.iter().all(is_pure_expr),
3029 Expr::Call { .. } => false,
3031 Expr::Assign { .. } => false,
3032 Expr::Return(_) => false,
3033 _ => false,
3034 }
3035}
3036
3037fn is_cse_worthy(expr: &Expr) -> bool {
3039 match expr {
3040 Expr::Literal(_) => false,
3042 Expr::Path(_) => false,
3043 Expr::Binary { .. } => true,
3045 Expr::Unary { .. } => true,
3047 Expr::Call { .. } => false,
3049 Expr::Index { .. } => true,
3051 _ => false,
3052 }
3053}
3054
3055fn expr_eq(a: &Expr, b: &Expr) -> bool {
3057 match (a, b) {
3058 (Expr::Literal(la), Expr::Literal(lb)) => match (la, lb) {
3059 (Literal::Int { value: va, .. }, Literal::Int { value: vb, .. }) => va == vb,
3060 (Literal::Float { value: va, .. }, Literal::Float { value: vb, .. }) => va == vb,
3061 (Literal::String(sa), Literal::String(sb)) => sa == sb,
3062 (Literal::Char(ca), Literal::Char(cb)) => ca == cb,
3063 (Literal::Bool(ba), Literal::Bool(bb)) => ba == bb,
3064 _ => false,
3065 },
3066 (Expr::Path(pa), Expr::Path(pb)) => {
3067 pa.segments.len() == pb.segments.len()
3068 && pa
3069 .segments
3070 .iter()
3071 .zip(&pb.segments)
3072 .all(|(sa, sb)| sa.ident.name == sb.ident.name)
3073 }
3074 (
3075 Expr::Binary {
3076 op: oa,
3077 left: la,
3078 right: ra,
3079 },
3080 Expr::Binary {
3081 op: ob,
3082 left: lb,
3083 right: rb,
3084 },
3085 ) => oa == ob && expr_eq(la, lb) && expr_eq(ra, rb),
3086 (Expr::Unary { op: oa, expr: ea }, Expr::Unary { op: ob, expr: eb }) => {
3087 oa == ob && expr_eq(ea, eb)
3088 }
3089 (
3090 Expr::Index {
3091 expr: ea,
3092 index: ia,
3093 },
3094 Expr::Index {
3095 expr: eb,
3096 index: ib,
3097 },
3098 ) => expr_eq(ea, eb) && expr_eq(ia, ib),
3099 (Expr::Call { func: fa, args: aa }, Expr::Call { func: fb, args: ab }) => {
3100 expr_eq(fa, fb) && aa.len() == ab.len() && aa.iter().zip(ab).all(|(a, b)| expr_eq(a, b))
3101 }
3102 _ => false,
3103 }
3104}
3105
/// A CSE candidate gathered while walking a block, paired with its structural
/// hash so that later grouping does not need to re-hash the expression.
#[derive(Clone)]
struct CollectedExpr {
    /// Owned copy of the candidate expression.
    expr: Expr,
    /// Value of `expr_hash` for `expr`, computed at collection time.
    hash: u64,
}
3112
3113fn collect_exprs_from_expr(expr: &Expr, out: &mut Vec<CollectedExpr>) {
3115 match expr {
3117 Expr::Binary { left, right, .. } => {
3118 collect_exprs_from_expr(left, out);
3119 collect_exprs_from_expr(right, out);
3120 }
3121 Expr::Unary { expr: inner, .. } => {
3122 collect_exprs_from_expr(inner, out);
3123 }
3124 Expr::Index { expr: e, index } => {
3125 collect_exprs_from_expr(e, out);
3126 collect_exprs_from_expr(index, out);
3127 }
3128 Expr::Call { func, args } => {
3129 collect_exprs_from_expr(func, out);
3130 for arg in args {
3131 collect_exprs_from_expr(arg, out);
3132 }
3133 }
3134 Expr::If {
3135 condition,
3136 then_branch,
3137 else_branch,
3138 } => {
3139 collect_exprs_from_expr(condition, out);
3140 collect_exprs_from_block(then_branch, out);
3141 if let Some(else_expr) = else_branch {
3142 collect_exprs_from_expr(else_expr, out);
3143 }
3144 }
3145 Expr::While {
3146 label,
3147 condition,
3148 body,
3149 } => {
3150 collect_exprs_from_expr(condition, out);
3151 collect_exprs_from_block(body, out);
3152 }
3153 Expr::Block(block) => {
3154 collect_exprs_from_block(block, out);
3155 }
3156 Expr::Return(Some(e)) => {
3157 collect_exprs_from_expr(e, out);
3158 }
3159 Expr::Assign { value, .. } => {
3160 collect_exprs_from_expr(value, out);
3161 }
3162 Expr::Array(elements) => {
3163 for e in elements {
3164 collect_exprs_from_expr(e, out);
3165 }
3166 }
3167 _ => {}
3168 }
3169
3170 if is_cse_worthy(expr) && is_pure_expr(expr) {
3172 out.push(CollectedExpr {
3173 expr: expr.clone(),
3174 hash: expr_hash(expr),
3175 });
3176 }
3177}
3178
3179fn collect_exprs_from_block(block: &Block, out: &mut Vec<CollectedExpr>) {
3181 for stmt in &block.stmts {
3182 match stmt {
3183 Stmt::Let { init: Some(e), .. } => collect_exprs_from_expr(e, out),
3184 Stmt::Expr(e) | Stmt::Semi(e) => collect_exprs_from_expr(e, out),
3185 _ => {}
3186 }
3187 }
3188 if let Some(e) = &block.expr {
3189 collect_exprs_from_expr(e, out);
3190 }
3191}
3192
3193fn replace_in_expr(expr: &Expr, target: &Expr, var_name: &str) -> Expr {
3195 if expr_eq(expr, target) {
3197 return Expr::Path(TypePath {
3198 segments: vec![PathSegment {
3199 ident: Ident {
3200 name: var_name.to_string(),
3201 evidentiality: None,
3202 affect: None,
3203 span: Span { start: 0, end: 0 },
3204 },
3205 generics: None,
3206 }],
3207 });
3208 }
3209
3210 match expr {
3212 Expr::Binary { op, left, right } => Expr::Binary {
3213 op: *op,
3214 left: Box::new(replace_in_expr(left, target, var_name)),
3215 right: Box::new(replace_in_expr(right, target, var_name)),
3216 },
3217 Expr::Unary { op, expr: inner } => Expr::Unary {
3218 op: *op,
3219 expr: Box::new(replace_in_expr(inner, target, var_name)),
3220 },
3221 Expr::Index { expr: e, index } => Expr::Index {
3222 expr: Box::new(replace_in_expr(e, target, var_name)),
3223 index: Box::new(replace_in_expr(index, target, var_name)),
3224 },
3225 Expr::Call { func, args } => Expr::Call {
3226 func: Box::new(replace_in_expr(func, target, var_name)),
3227 args: args
3228 .iter()
3229 .map(|a| replace_in_expr(a, target, var_name))
3230 .collect(),
3231 },
3232 Expr::If {
3233 condition,
3234 then_branch,
3235 else_branch,
3236 } => Expr::If {
3237 condition: Box::new(replace_in_expr(condition, target, var_name)),
3238 then_branch: replace_in_block(then_branch, target, var_name),
3239 else_branch: else_branch
3240 .as_ref()
3241 .map(|e| Box::new(replace_in_expr(e, target, var_name))),
3242 },
3243 Expr::While {
3244 label,
3245 condition,
3246 body,
3247 } => Expr::While {
3248 label: label.clone(),
3249 condition: Box::new(replace_in_expr(condition, target, var_name)),
3250 body: replace_in_block(body, target, var_name),
3251 },
3252 Expr::Block(block) => Expr::Block(replace_in_block(block, target, var_name)),
3253 Expr::Return(e) => Expr::Return(
3254 e.as_ref()
3255 .map(|e| Box::new(replace_in_expr(e, target, var_name))),
3256 ),
3257 Expr::Assign { target: t, value } => Expr::Assign {
3258 target: t.clone(),
3259 value: Box::new(replace_in_expr(value, target, var_name)),
3260 },
3261 Expr::Array(elements) => Expr::Array(
3262 elements
3263 .iter()
3264 .map(|e| replace_in_expr(e, target, var_name))
3265 .collect(),
3266 ),
3267 other => other.clone(),
3268 }
3269}
3270
3271fn replace_in_block(block: &Block, target: &Expr, var_name: &str) -> Block {
3273 let stmts = block
3274 .stmts
3275 .iter()
3276 .map(|stmt| match stmt {
3277 Stmt::Let { pattern, ty, init } => Stmt::Let {
3278 pattern: pattern.clone(),
3279 ty: ty.clone(),
3280 init: init.as_ref().map(|e| replace_in_expr(e, target, var_name)),
3281 },
3282 Stmt::LetElse {
3283 pattern,
3284 ty,
3285 init,
3286 else_branch,
3287 } => Stmt::LetElse {
3288 pattern: pattern.clone(),
3289 ty: ty.clone(),
3290 init: replace_in_expr(init, target, var_name),
3291 else_branch: Box::new(replace_in_expr(else_branch, target, var_name)),
3292 },
3293 Stmt::Expr(e) => Stmt::Expr(replace_in_expr(e, target, var_name)),
3294 Stmt::Semi(e) => Stmt::Semi(replace_in_expr(e, target, var_name)),
3295 Stmt::Item(item) => Stmt::Item(item.clone()),
3296 })
3297 .collect();
3298
3299 let expr = block
3300 .expr
3301 .as_ref()
3302 .map(|e| Box::new(replace_in_expr(e, target, var_name)));
3303
3304 Block { stmts, expr }
3305}
3306
3307fn make_cse_let(var_name: &str, expr: Expr) -> Stmt {
3309 Stmt::Let {
3310 pattern: Pattern::Ident {
3311 mutable: false,
3312 name: Ident {
3313 name: var_name.to_string(),
3314 evidentiality: None,
3315 affect: None,
3316 span: Span { start: 0, end: 0 },
3317 },
3318 evidentiality: None,
3319 },
3320 ty: None,
3321 init: Some(expr),
3322 }
3323}
3324
3325pub fn optimize(file: &ast::SourceFile, level: OptLevel) -> (ast::SourceFile, OptStats) {
3331 let mut optimizer = Optimizer::new(level);
3332 let optimized = optimizer.optimize_file(file);
3333 (optimized, optimizer.stats)
3334}
3335
#[cfg(test)]
mod tests {
    use super::*;

    /// Identifier with no annotations and a zero span.
    fn ident(name: &str) -> Ident {
        Ident {
            name: name.to_string(),
            evidentiality: None,
            affect: None,
            span: Span { start: 0, end: 0 },
        }
    }

    /// Decimal integer literal without a suffix.
    fn int_lit(v: i64) -> Expr {
        Expr::Literal(Literal::Int {
            value: v.to_string(),
            base: NumBase::Decimal,
            suffix: None,
        })
    }

    /// Single-segment path expression, i.e. a plain variable reference.
    fn var(name: &str) -> Expr {
        Expr::Path(TypePath {
            segments: vec![PathSegment {
                ident: ident(name),
                generics: None,
            }],
        })
    }

    /// Binary expression with the given operator.
    fn binary(op: BinOp, left: Expr, right: Expr) -> Expr {
        Expr::Binary {
            op,
            left: Box::new(left),
            right: Box::new(right),
        }
    }

    fn add(left: Expr, right: Expr) -> Expr {
        binary(BinOp::Add, left, right)
    }

    fn mul(left: Expr, right: Expr) -> Expr {
        binary(BinOp::Mul, left, right)
    }

    /// `let <name> = <init>;` with an immutable, untyped binding.
    fn let_binding(name: &str, init: Expr) -> Stmt {
        Stmt::Let {
            pattern: Pattern::Ident {
                mutable: false,
                name: ident(name),
                evidentiality: None,
            },
            ty: None,
            init: Some(init),
        }
    }

    /// Block made of the given statements with no trailing expression.
    fn block_of(stmts: Vec<Stmt>) -> Block {
        Block { stmts, expr: None }
    }

    #[test]
    fn test_expr_hash_equal() {
        let first = add(var("a"), var("b"));
        let second = add(var("a"), var("b"));
        assert_eq!(expr_hash(&first), expr_hash(&second));
    }

    #[test]
    fn test_expr_hash_different() {
        let first = add(var("a"), var("b"));
        let second = add(var("a"), var("c"));
        assert_ne!(expr_hash(&first), expr_hash(&second));
    }

    #[test]
    fn test_expr_eq() {
        let base = add(var("a"), var("b"));
        let same = add(var("a"), var("b"));
        let other = add(var("a"), var("c"));

        assert!(expr_eq(&base, &same));
        assert!(!expr_eq(&base, &other));
    }

    #[test]
    fn test_is_pure_expr() {
        assert!(is_pure_expr(&int_lit(42)));
        assert!(is_pure_expr(&var("x")));
        assert!(is_pure_expr(&add(var("a"), var("b"))));

        // Calls are never considered pure.
        let call = Expr::Call {
            func: Box::new(var("print")),
            args: vec![int_lit(42)],
        };
        assert!(!is_pure_expr(&call));
    }

    #[test]
    fn test_is_cse_worthy() {
        assert!(!is_cse_worthy(&int_lit(42)));
        assert!(!is_cse_worthy(&var("x")));
        assert!(is_cse_worthy(&add(var("a"), var("b"))));
    }

    #[test]
    fn test_cse_basic() {
        // let x = a + b;
        // let y = (a + b) * 2;
        let shared = add(var("a"), var("b"));
        let block = block_of(vec![
            let_binding("x", shared.clone()),
            let_binding("y", mul(shared.clone(), int_lit(2))),
        ]);

        let mut optimizer = Optimizer::new(OptLevel::Standard);
        let result = optimizer.pass_cse_block(&block);

        // One new binding is prepended to the two original statements.
        assert_eq!(result.stmts.len(), 3);
        assert_eq!(optimizer.stats.expressions_deduplicated, 1);

        match &result.stmts[0] {
            Stmt::Let {
                pattern: Pattern::Ident { name, .. },
                ..
            } => assert_eq!(name.name, "__cse_0"),
            _ => panic!("Expected CSE let binding"),
        }
    }

    #[test]
    fn test_cse_no_duplicates() {
        // let x = a + b;
        // let y = c + d;
        let block = block_of(vec![
            let_binding("x", add(var("a"), var("b"))),
            let_binding("y", add(var("c"), var("d"))),
        ]);

        let mut optimizer = Optimizer::new(OptLevel::Standard);
        let result = optimizer.pass_cse_block(&block);

        // Nothing is shared, so the block is left as it was.
        assert_eq!(result.stmts.len(), 2);
        assert_eq!(optimizer.stats.expressions_deduplicated, 0);
    }
}