1use crate::ast::{
7 self, BinOp, Block, Expr, FunctionAttrs, Ident, Item, Literal, NumBase, Param, PathSegment,
8 Pattern, Stmt, TypeExpr, TypePath, UnaryOp, Visibility,
9};
10use crate::span::Span;
11use std::collections::{HashMap, HashSet};
12
/// How much work the optimizer does.
///
/// Each level selects the per-function pass pipeline run by
/// `Optimizer::optimize_function`. `Optimizer::optimize_file` additionally applies the
/// accumulator (tail-recursion) rewrite at `Standard` and `Aggressive` only.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OptLevel {
    /// Function bodies are copied unchanged.
    None,
    /// Constant folding followed by dead-code elimination.
    Basic,
    /// One round of folding, inlining, strength reduction, `pass_licm_block`,
    /// `pass_cse_block`, dead-code elimination and branch simplification.
    Standard,
    /// Three rounds of the `Standard` pipeline with loop unrolling added.
    Aggressive,
    /// Same per-function pipeline as `Standard`, but without the file-level
    /// accumulator rewrite.
    Size,
}
31
/// Counters describing what the optimizer changed; all start at zero.
#[derive(Debug, Default, Clone)]
pub struct OptStats {
    /// Operator applications evaluated at compile time by constant folding.
    pub constants_folded: usize,
    /// Statements / tail expressions removed by dead-code elimination.
    pub dead_code_eliminated: usize,
    /// Presumably incremented by the CSE pass (not incremented in the code shown here).
    pub expressions_deduplicated: usize,
    /// Presumably incremented by the inliner (not incremented in the code shown here).
    pub functions_inlined: usize,
    /// Rewrites performed by the strength-reduction pass.
    pub strength_reductions: usize,
    /// `if` / `while` nodes with constant conditions that were collapsed.
    pub branches_simplified: usize,
    /// Presumably incremented by loop passes (not incremented in the code shown here).
    pub loops_optimized: usize,
    /// Functions rewritten into accumulator-style tail recursion by `optimize_file`.
    pub tail_recursion_transforms: usize,
    /// Reserved for the memoization rewrite (not incremented in the code shown here).
    pub memoization_transforms: usize,
}
45
46pub struct Optimizer {
48 level: OptLevel,
49 stats: OptStats,
50 functions: HashMap<String, ast::Function>,
52 recursive_functions: HashSet<String>,
54 cse_counter: usize,
56}
57
58impl Optimizer {
59 pub fn new(level: OptLevel) -> Self {
60 Self {
61 level,
62 stats: OptStats::default(),
63 functions: HashMap::new(),
64 recursive_functions: HashSet::new(),
65 cse_counter: 0,
66 }
67 }
68
69 pub fn stats(&self) -> &OptStats {
71 &self.stats
72 }
73
74 pub fn optimize_file(&mut self, file: &ast::SourceFile) -> ast::SourceFile {
76 for item in &file.items {
78 if let Item::Function(func) = &item.node {
79 self.functions.insert(func.name.name.clone(), func.clone());
80 if self.is_recursive(&func.name.name, func) {
81 self.recursive_functions.insert(func.name.name.clone());
82 }
83 }
84 }
85
86 let mut new_items: Vec<crate::span::Spanned<Item>> = Vec::new();
90 let mut transformed_functions: HashMap<String, String> = HashMap::new();
91
92 if matches!(self.level, OptLevel::Standard | OptLevel::Aggressive) {
93 for item in &file.items {
94 if let Item::Function(func) = &item.node {
95 if let Some((helper_func, wrapper_func)) = self.try_accumulator_transform(func)
96 {
97 new_items.push(crate::span::Spanned {
99 node: Item::Function(helper_func),
100 span: item.span.clone(),
101 });
102 transformed_functions
103 .insert(func.name.name.clone(), wrapper_func.name.name.clone());
104 self.stats.tail_recursion_transforms += 1;
105 }
106 }
107 }
108
109 }
115
116 let items: Vec<_> = file
118 .items
119 .iter()
120 .map(|item| {
121 let node = match &item.node {
122 Item::Function(func) => {
123 if let Some((_, wrapper)) = self.try_accumulator_transform(func) {
125 if matches!(self.level, OptLevel::Standard | OptLevel::Aggressive)
126 && transformed_functions.contains_key(&func.name.name)
127 {
128 Item::Function(self.optimize_function(&wrapper))
129 } else {
130 Item::Function(self.optimize_function(func))
131 }
132 } else {
133 Item::Function(self.optimize_function(func))
134 }
135 }
136 other => other.clone(),
137 };
138 crate::span::Spanned {
139 node,
140 span: item.span.clone(),
141 }
142 })
143 .collect();
144
145 new_items.extend(items);
147
148 ast::SourceFile {
149 attrs: file.attrs.clone(),
150 config: file.config.clone(),
151 items: new_items,
152 }
153 }
154
155 fn try_accumulator_transform(
158 &self,
159 func: &ast::Function,
160 ) -> Option<(ast::Function, ast::Function)> {
161 if func.params.len() != 1 {
163 return None;
164 }
165
166 if !self.recursive_functions.contains(&func.name.name) {
168 return None;
169 }
170
171 let body = func.body.as_ref()?;
172
173 if !self.is_fib_like_pattern(&func.name.name, body) {
177 return None;
178 }
179
180 let param_name = if let Pattern::Ident { name, .. } = &func.params[0].pattern {
182 name.name.clone()
183 } else {
184 return None;
185 };
186
187 let helper_name = format!("{}_tail", func.name.name);
189
190 let helper_func = self.generate_fib_helper(&helper_name, ¶m_name);
192
193 let wrapper_func =
195 self.generate_fib_wrapper(&func.name.name, &helper_name, ¶m_name, func);
196
197 Some((helper_func, wrapper_func))
198 }
199
200 fn is_fib_like_pattern(&self, func_name: &str, body: &Block) -> bool {
202 if body.stmts.is_empty() && body.expr.is_none() {
209 return false;
210 }
211
212 if let Some(expr) = &body.expr {
214 if let Expr::If {
215 else_branch: Some(else_expr),
216 ..
217 } = expr.as_ref()
218 {
219 return self.is_double_recursive_expr(func_name, else_expr);
222 }
223 }
224
225 if body.stmts.len() >= 1 {
227 if let Some(Stmt::Expr(expr) | Stmt::Semi(expr)) = body.stmts.last() {
229 if let Expr::Return(Some(ret_expr)) = expr {
230 return self.is_double_recursive_expr(func_name, ret_expr);
231 }
232 }
233 if let Some(expr) = &body.expr {
234 return self.is_double_recursive_expr(func_name, expr);
235 }
236 }
237
238 false
239 }
240
241 fn is_double_recursive_expr(&self, func_name: &str, expr: &Expr) -> bool {
243 if let Expr::Binary {
244 op: BinOp::Add,
245 left,
246 right,
247 } = expr
248 {
249 let left_is_recursive = self.is_recursive_call_with_decrement(func_name, left);
250 let right_is_recursive = self.is_recursive_call_with_decrement(func_name, right);
251 return left_is_recursive && right_is_recursive;
252 }
253 false
254 }
255
256 fn is_recursive_call_with_decrement(&self, func_name: &str, expr: &Expr) -> bool {
258 if let Expr::Call { func, args } = expr {
259 if let Expr::Path(path) = func.as_ref() {
260 if path.segments.last().map(|s| s.ident.name.as_str()) == Some(func_name) {
261 if args.len() == 1 {
263 if let Expr::Binary { op: BinOp::Sub, .. } = &args[0] {
264 return true;
265 }
266 }
267 }
268 }
269 }
270 false
271 }
272
273 fn generate_fib_helper(&self, name: &str, _param_name: &str) -> ast::Function {
275 let span = Span { start: 0, end: 0 };
276
277 let n_ident = Ident {
282 name: "n".to_string(),
283 evidentiality: None,
284 affect: None,
285 span: span.clone(),
286 };
287 let a_ident = Ident {
288 name: "a".to_string(),
289 evidentiality: None,
290 affect: None,
291 span: span.clone(),
292 };
293 let b_ident = Ident {
294 name: "b".to_string(),
295 evidentiality: None,
296 affect: None,
297 span: span.clone(),
298 };
299
300 let params = vec![
301 Param {
302 pattern: Pattern::Ident {
303 mutable: false,
304 name: n_ident.clone(),
305 evidentiality: None,
306 },
307 ty: TypeExpr::Infer,
308 },
309 Param {
310 pattern: Pattern::Ident {
311 mutable: false,
312 name: a_ident.clone(),
313 evidentiality: None,
314 },
315 ty: TypeExpr::Infer,
316 },
317 Param {
318 pattern: Pattern::Ident {
319 mutable: false,
320 name: b_ident.clone(),
321 evidentiality: None,
322 },
323 ty: TypeExpr::Infer,
324 },
325 ];
326
327 let condition = Expr::Binary {
329 op: BinOp::Le,
330 left: Box::new(Expr::Path(TypePath {
331 segments: vec![PathSegment {
332 ident: n_ident.clone(),
333 generics: None,
334 }],
335 })),
336 right: Box::new(Expr::Literal(Literal::Int {
337 value: "0".to_string(),
338 base: NumBase::Decimal,
339 suffix: None,
340 })),
341 };
342
343 let then_branch = Block {
345 stmts: vec![],
346 expr: Some(Box::new(Expr::Return(Some(Box::new(Expr::Path(
347 TypePath {
348 segments: vec![PathSegment {
349 ident: a_ident.clone(),
350 generics: None,
351 }],
352 },
353 )))))),
354 };
355
356 let recursive_call = Expr::Call {
358 func: Box::new(Expr::Path(TypePath {
359 segments: vec![PathSegment {
360 ident: Ident {
361 name: name.to_string(),
362 evidentiality: None,
363 affect: None,
364 span: span.clone(),
365 },
366 generics: None,
367 }],
368 })),
369 args: vec![
370 Expr::Binary {
372 op: BinOp::Sub,
373 left: Box::new(Expr::Path(TypePath {
374 segments: vec![PathSegment {
375 ident: n_ident.clone(),
376 generics: None,
377 }],
378 })),
379 right: Box::new(Expr::Literal(Literal::Int {
380 value: "1".to_string(),
381 base: NumBase::Decimal,
382 suffix: None,
383 })),
384 },
385 Expr::Path(TypePath {
387 segments: vec![PathSegment {
388 ident: b_ident.clone(),
389 generics: None,
390 }],
391 }),
392 Expr::Binary {
394 op: BinOp::Add,
395 left: Box::new(Expr::Path(TypePath {
396 segments: vec![PathSegment {
397 ident: a_ident.clone(),
398 generics: None,
399 }],
400 })),
401 right: Box::new(Expr::Path(TypePath {
402 segments: vec![PathSegment {
403 ident: b_ident.clone(),
404 generics: None,
405 }],
406 })),
407 },
408 ],
409 };
410
411 let body = Block {
413 stmts: vec![],
414 expr: Some(Box::new(Expr::If {
415 condition: Box::new(condition),
416 then_branch,
417 else_branch: Some(Box::new(Expr::Return(Some(Box::new(recursive_call))))),
418 })),
419 };
420
421 ast::Function {
422 visibility: Visibility::default(),
423 is_async: false,
424 attrs: FunctionAttrs::default(),
425 name: Ident {
426 name: name.to_string(),
427 evidentiality: None,
428 affect: None,
429 span: span.clone(),
430 },
431 aspect: None,
432 generics: None,
433 params,
434 return_type: None,
435 where_clause: None,
436 body: Some(body),
437 }
438 }
439
440 fn generate_fib_wrapper(
442 &self,
443 name: &str,
444 helper_name: &str,
445 param_name: &str,
446 original: &ast::Function,
447 ) -> ast::Function {
448 let span = Span { start: 0, end: 0 };
449
450 let call_helper = Expr::Call {
452 func: Box::new(Expr::Path(TypePath {
453 segments: vec![PathSegment {
454 ident: Ident {
455 name: helper_name.to_string(),
456 evidentiality: None,
457 affect: None,
458 span: span.clone(),
459 },
460 generics: None,
461 }],
462 })),
463 args: vec![
464 Expr::Path(TypePath {
466 segments: vec![PathSegment {
467 ident: Ident {
468 name: param_name.to_string(),
469 evidentiality: None,
470 affect: None,
471 span: span.clone(),
472 },
473 generics: None,
474 }],
475 }),
476 Expr::Literal(Literal::Int {
478 value: "0".to_string(),
479 base: NumBase::Decimal,
480 suffix: None,
481 }),
482 Expr::Literal(Literal::Int {
484 value: "1".to_string(),
485 base: NumBase::Decimal,
486 suffix: None,
487 }),
488 ],
489 };
490
491 let body = Block {
492 stmts: vec![],
493 expr: Some(Box::new(Expr::Return(Some(Box::new(call_helper))))),
494 };
495
496 ast::Function {
497 visibility: original.visibility,
498 is_async: original.is_async,
499 attrs: original.attrs.clone(),
500 name: Ident {
501 name: name.to_string(),
502 evidentiality: None,
503 affect: None,
504 span: span.clone(),
505 },
506 aspect: original.aspect,
507 generics: original.generics.clone(),
508 params: original.params.clone(),
509 return_type: original.return_type.clone(),
510 where_clause: original.where_clause.clone(),
511 body: Some(body),
512 }
513 }
514
515 #[allow(dead_code)]
522 fn try_memoize_transform(
523 &self,
524 func: &ast::Function,
525 ) -> Option<(ast::Function, ast::Function, ast::Function)> {
526 let param_count = func.params.len();
527 if param_count != 1 && param_count != 2 {
528 return None;
529 }
530
531 let span = Span { start: 0, end: 0 };
532 let func_name = &func.name.name;
533 let impl_name = format!("_memo_impl_{}", func_name);
534 let _cache_name = format!("_memo_cache_{}", func_name);
535 let init_name = format!("_memo_init_{}", func_name);
536
537 let param_names: Vec<String> = func
539 .params
540 .iter()
541 .filter_map(|p| {
542 if let Pattern::Ident { name, .. } = &p.pattern {
543 Some(name.name.clone())
544 } else {
545 None
546 }
547 })
548 .collect();
549
550 if param_names.len() != param_count {
551 return None;
552 }
553
554 let impl_func = ast::Function {
556 visibility: Visibility::default(),
557 is_async: func.is_async,
558 attrs: func.attrs.clone(),
559 name: Ident {
560 name: impl_name.clone(),
561 evidentiality: None,
562 affect: None,
563 span: span.clone(),
564 },
565 aspect: func.aspect,
566 generics: func.generics.clone(),
567 params: func.params.clone(),
568 return_type: func.return_type.clone(),
569 where_clause: func.where_clause.clone(),
570 body: func
571 .body
572 .as_ref()
573 .map(|b| self.redirect_calls_in_block(func_name, func_name, b)),
574 };
575
576 let cache_init_body = Block {
579 stmts: vec![],
580 expr: Some(Box::new(Expr::Call {
581 func: Box::new(Expr::Path(TypePath {
582 segments: vec![PathSegment {
583 ident: Ident {
584 name: "sigil_memo_new".to_string(),
585 evidentiality: None,
586 affect: None,
587 span: span.clone(),
588 },
589 generics: None,
590 }],
591 })),
592 args: vec![Expr::Literal(Literal::Int {
593 value: "65536".to_string(),
594 base: NumBase::Decimal,
595 suffix: None,
596 })],
597 })),
598 };
599
600 let cache_init_func = ast::Function {
601 visibility: Visibility::default(),
602 is_async: false,
603 attrs: FunctionAttrs::default(),
604 name: Ident {
605 name: init_name.clone(),
606 evidentiality: None,
607 affect: None,
608 span: span.clone(),
609 },
610 aspect: None,
611 generics: None,
612 params: vec![],
613 return_type: None,
614 where_clause: None,
615 body: Some(cache_init_body),
616 };
617
618 let wrapper_func = self.generate_memo_wrapper(func, &impl_name, ¶m_names);
620
621 Some((impl_func, cache_init_func, wrapper_func))
622 }
623
624 #[allow(dead_code)]
626 fn generate_memo_wrapper(
627 &self,
628 original: &ast::Function,
629 impl_name: &str,
630 param_names: &[String],
631 ) -> ast::Function {
632 let span = Span { start: 0, end: 0 };
633 let param_count = param_names.len();
634
635 let cache_var = Ident {
637 name: "__cache".to_string(),
638 evidentiality: None,
639 affect: None,
640 span: span.clone(),
641 };
642 let result_var = Ident {
643 name: "__result".to_string(),
644 evidentiality: None,
645 affect: None,
646 span: span.clone(),
647 };
648 let cached_var = Ident {
649 name: "__cached".to_string(),
650 evidentiality: None,
651 affect: None,
652 span: span.clone(),
653 };
654
655 let mut stmts = vec![];
656
657 stmts.push(Stmt::Let {
659 pattern: Pattern::Ident {
660 mutable: false,
661 name: cache_var.clone(),
662 evidentiality: None,
663 },
664 ty: None,
665 init: Some(Expr::Call {
666 func: Box::new(Expr::Path(TypePath {
667 segments: vec![PathSegment {
668 ident: Ident {
669 name: "sigil_memo_new".to_string(),
670 evidentiality: None,
671 affect: None,
672 span: span.clone(),
673 },
674 generics: None,
675 }],
676 })),
677 args: vec![Expr::Literal(Literal::Int {
678 value: "65536".to_string(),
679 base: NumBase::Decimal,
680 suffix: None,
681 })],
682 }),
683 });
684
685 let get_fn_name = if param_count == 1 {
687 "sigil_memo_get_1"
688 } else {
689 "sigil_memo_get_2"
690 };
691 let mut get_args = vec![Expr::Path(TypePath {
692 segments: vec![PathSegment {
693 ident: cache_var.clone(),
694 generics: None,
695 }],
696 })];
697 for name in param_names {
698 get_args.push(Expr::Path(TypePath {
699 segments: vec![PathSegment {
700 ident: Ident {
701 name: name.clone(),
702 evidentiality: None,
703 affect: None,
704 span: span.clone(),
705 },
706 generics: None,
707 }],
708 }));
709 }
710
711 stmts.push(Stmt::Let {
712 pattern: Pattern::Ident {
713 mutable: false,
714 name: cached_var.clone(),
715 evidentiality: None,
716 },
717 ty: None,
718 init: Some(Expr::Call {
719 func: Box::new(Expr::Path(TypePath {
720 segments: vec![PathSegment {
721 ident: Ident {
722 name: get_fn_name.to_string(),
723 evidentiality: None,
724 affect: None,
725 span: span.clone(),
726 },
727 generics: None,
728 }],
729 })),
730 args: get_args,
731 }),
732 });
733
734 let cache_check = Expr::If {
737 condition: Box::new(Expr::Binary {
738 op: BinOp::Ne,
739 left: Box::new(Expr::Path(TypePath {
740 segments: vec![PathSegment {
741 ident: cached_var.clone(),
742 generics: None,
743 }],
744 })),
745 right: Box::new(Expr::Unary {
746 op: UnaryOp::Neg,
747 expr: Box::new(Expr::Literal(Literal::Int {
748 value: "9223372036854775807".to_string(),
749 base: NumBase::Decimal,
750 suffix: None,
751 })),
752 }),
753 }),
754 then_branch: Block {
755 stmts: vec![],
756 expr: Some(Box::new(Expr::Return(Some(Box::new(Expr::Path(
757 TypePath {
758 segments: vec![PathSegment {
759 ident: cached_var.clone(),
760 generics: None,
761 }],
762 },
763 )))))),
764 },
765 else_branch: None,
766 };
767 stmts.push(Stmt::Semi(cache_check));
768
769 let mut impl_args = vec![];
771 for name in param_names {
772 impl_args.push(Expr::Path(TypePath {
773 segments: vec![PathSegment {
774 ident: Ident {
775 name: name.clone(),
776 evidentiality: None,
777 affect: None,
778 span: span.clone(),
779 },
780 generics: None,
781 }],
782 }));
783 }
784
785 stmts.push(Stmt::Let {
786 pattern: Pattern::Ident {
787 mutable: false,
788 name: result_var.clone(),
789 evidentiality: None,
790 },
791 ty: None,
792 init: Some(Expr::Call {
793 func: Box::new(Expr::Path(TypePath {
794 segments: vec![PathSegment {
795 ident: Ident {
796 name: impl_name.to_string(),
797 evidentiality: None,
798 affect: None,
799 span: span.clone(),
800 },
801 generics: None,
802 }],
803 })),
804 args: impl_args,
805 }),
806 });
807
808 let set_fn_name = if param_count == 1 {
810 "sigil_memo_set_1"
811 } else {
812 "sigil_memo_set_2"
813 };
814 let mut set_args = vec![Expr::Path(TypePath {
815 segments: vec![PathSegment {
816 ident: cache_var.clone(),
817 generics: None,
818 }],
819 })];
820 for name in param_names {
821 set_args.push(Expr::Path(TypePath {
822 segments: vec![PathSegment {
823 ident: Ident {
824 name: name.clone(),
825 evidentiality: None,
826 affect: None,
827 span: span.clone(),
828 },
829 generics: None,
830 }],
831 }));
832 }
833 set_args.push(Expr::Path(TypePath {
834 segments: vec![PathSegment {
835 ident: result_var.clone(),
836 generics: None,
837 }],
838 }));
839
840 stmts.push(Stmt::Semi(Expr::Call {
841 func: Box::new(Expr::Path(TypePath {
842 segments: vec![PathSegment {
843 ident: Ident {
844 name: set_fn_name.to_string(),
845 evidentiality: None,
846 affect: None,
847 span: span.clone(),
848 },
849 generics: None,
850 }],
851 })),
852 args: set_args,
853 }));
854
855 let body = Block {
857 stmts,
858 expr: Some(Box::new(Expr::Return(Some(Box::new(Expr::Path(
859 TypePath {
860 segments: vec![PathSegment {
861 ident: result_var.clone(),
862 generics: None,
863 }],
864 },
865 )))))),
866 };
867
868 ast::Function {
869 visibility: original.visibility,
870 is_async: original.is_async,
871 attrs: original.attrs.clone(),
872 name: original.name.clone(),
873 aspect: original.aspect,
874 generics: original.generics.clone(),
875 params: original.params.clone(),
876 return_type: original.return_type.clone(),
877 where_clause: original.where_clause.clone(),
878 body: Some(body),
879 }
880 }
881
882 #[allow(dead_code)]
884 fn redirect_calls_in_block(&self, _old_name: &str, _new_name: &str, block: &Block) -> Block {
885 block.clone()
887 }
888
889 fn is_recursive(&self, name: &str, func: &ast::Function) -> bool {
891 if let Some(body) = &func.body {
892 self.block_calls_function(name, body)
893 } else {
894 false
895 }
896 }
897
898 fn block_calls_function(&self, name: &str, block: &Block) -> bool {
899 for stmt in &block.stmts {
900 if self.stmt_calls_function(name, stmt) {
901 return true;
902 }
903 }
904 if let Some(expr) = &block.expr {
905 if self.expr_calls_function(name, expr) {
906 return true;
907 }
908 }
909 false
910 }
911
912 fn stmt_calls_function(&self, name: &str, stmt: &Stmt) -> bool {
913 match stmt {
914 Stmt::Let {
915 init: Some(expr), ..
916 } => self.expr_calls_function(name, expr),
917 Stmt::Expr(expr) | Stmt::Semi(expr) => self.expr_calls_function(name, expr),
918 _ => false,
919 }
920 }
921
922 fn expr_calls_function(&self, name: &str, expr: &Expr) -> bool {
923 match expr {
924 Expr::Call { func, args } => {
925 if let Expr::Path(path) = func.as_ref() {
926 if path.segments.last().map(|s| s.ident.name.as_str()) == Some(name) {
927 return true;
928 }
929 }
930 args.iter().any(|a| self.expr_calls_function(name, a))
931 }
932 Expr::Binary { left, right, .. } => {
933 self.expr_calls_function(name, left) || self.expr_calls_function(name, right)
934 }
935 Expr::Unary { expr, .. } => self.expr_calls_function(name, expr),
936 Expr::If {
937 condition,
938 then_branch,
939 else_branch,
940 } => {
941 self.expr_calls_function(name, condition)
942 || self.block_calls_function(name, then_branch)
943 || else_branch
944 .as_ref()
945 .map(|e| self.expr_calls_function(name, e))
946 .unwrap_or(false)
947 }
948 Expr::While { condition, body } => {
949 self.expr_calls_function(name, condition) || self.block_calls_function(name, body)
950 }
951 Expr::Block(block) => self.block_calls_function(name, block),
952 Expr::Return(Some(e)) => self.expr_calls_function(name, e),
953 _ => false,
954 }
955 }
956
957 fn optimize_function(&mut self, func: &ast::Function) -> ast::Function {
959 self.cse_counter = 0;
961
962 let body = if let Some(body) = &func.body {
963 let optimized = match self.level {
965 OptLevel::None => body.clone(),
966 OptLevel::Basic => {
967 let b = self.pass_constant_fold_block(body);
968 self.pass_dead_code_block(&b)
969 }
970 OptLevel::Standard | OptLevel::Size => {
971 let b = self.pass_constant_fold_block(body);
972 let b = self.pass_inline_block(&b); let b = self.pass_strength_reduce_block(&b);
974 let b = self.pass_licm_block(&b); let b = self.pass_cse_block(&b); let b = self.pass_dead_code_block(&b);
977 self.pass_simplify_branches_block(&b)
978 }
979 OptLevel::Aggressive => {
980 let mut b = body.clone();
982 for _ in 0..3 {
983 b = self.pass_constant_fold_block(&b);
984 b = self.pass_inline_block(&b); b = self.pass_strength_reduce_block(&b);
986 b = self.pass_loop_unroll_block(&b); b = self.pass_licm_block(&b); b = self.pass_cse_block(&b); b = self.pass_dead_code_block(&b);
990 b = self.pass_simplify_branches_block(&b);
991 }
992 b
993 }
994 };
995 Some(optimized)
996 } else {
997 None
998 };
999
1000 ast::Function {
1001 visibility: func.visibility.clone(),
1002 is_async: func.is_async,
1003 attrs: func.attrs.clone(),
1004 name: func.name.clone(),
1005 aspect: func.aspect,
1006 generics: func.generics.clone(),
1007 params: func.params.clone(),
1008 return_type: func.return_type.clone(),
1009 where_clause: func.where_clause.clone(),
1010 body,
1011 }
1012 }
1013
1014 fn pass_constant_fold_block(&mut self, block: &Block) -> Block {
1019 let stmts = block
1020 .stmts
1021 .iter()
1022 .map(|s| self.pass_constant_fold_stmt(s))
1023 .collect();
1024 let expr = block
1025 .expr
1026 .as_ref()
1027 .map(|e| Box::new(self.pass_constant_fold_expr(e)));
1028 Block { stmts, expr }
1029 }
1030
1031 fn pass_constant_fold_stmt(&mut self, stmt: &Stmt) -> Stmt {
1032 match stmt {
1033 Stmt::Let {
1034 pattern, ty, init, ..
1035 } => Stmt::Let {
1036 pattern: pattern.clone(),
1037 ty: ty.clone(),
1038 init: init.as_ref().map(|e| self.pass_constant_fold_expr(e)),
1039 },
1040 Stmt::Expr(expr) => Stmt::Expr(self.pass_constant_fold_expr(expr)),
1041 Stmt::Semi(expr) => Stmt::Semi(self.pass_constant_fold_expr(expr)),
1042 Stmt::Item(item) => Stmt::Item(item.clone()),
1043 }
1044 }
1045
1046 fn pass_constant_fold_expr(&mut self, expr: &Expr) -> Expr {
1047 match expr {
1048 Expr::Binary { op, left, right } => {
1049 let left = Box::new(self.pass_constant_fold_expr(left));
1050 let right = Box::new(self.pass_constant_fold_expr(right));
1051
1052 if let (Some(l), Some(r)) = (self.as_int(&left), self.as_int(&right)) {
1054 if let Some(result) = self.fold_binary(op.clone(), l, r) {
1055 self.stats.constants_folded += 1;
1056 return Expr::Literal(Literal::Int {
1057 value: result.to_string(),
1058 base: NumBase::Decimal,
1059 suffix: None,
1060 });
1061 }
1062 }
1063
1064 Expr::Binary {
1065 op: op.clone(),
1066 left,
1067 right,
1068 }
1069 }
1070 Expr::Unary { op, expr: inner } => {
1071 let inner = Box::new(self.pass_constant_fold_expr(inner));
1072
1073 if let Some(v) = self.as_int(&inner) {
1074 if let Some(result) = self.fold_unary(*op, v) {
1075 self.stats.constants_folded += 1;
1076 return Expr::Literal(Literal::Int {
1077 value: result.to_string(),
1078 base: NumBase::Decimal,
1079 suffix: None,
1080 });
1081 }
1082 }
1083
1084 Expr::Unary {
1085 op: *op,
1086 expr: inner,
1087 }
1088 }
1089 Expr::If {
1090 condition,
1091 then_branch,
1092 else_branch,
1093 } => {
1094 let condition = Box::new(self.pass_constant_fold_expr(condition));
1095 let then_branch = self.pass_constant_fold_block(then_branch);
1096 let else_branch = else_branch
1097 .as_ref()
1098 .map(|e| Box::new(self.pass_constant_fold_expr(e)));
1099
1100 if let Some(cond) = self.as_bool(&condition) {
1102 self.stats.branches_simplified += 1;
1103 if cond {
1104 return Expr::Block(then_branch);
1105 } else if let Some(else_expr) = else_branch {
1106 return *else_expr;
1107 } else {
1108 return Expr::Literal(Literal::Bool(false));
1109 }
1110 }
1111
1112 Expr::If {
1113 condition,
1114 then_branch,
1115 else_branch,
1116 }
1117 }
1118 Expr::While { condition, body } => {
1119 let condition = Box::new(self.pass_constant_fold_expr(condition));
1120 let body = self.pass_constant_fold_block(body);
1121
1122 if let Some(false) = self.as_bool(&condition) {
1124 self.stats.branches_simplified += 1;
1125 return Expr::Block(Block {
1126 stmts: vec![],
1127 expr: None,
1128 });
1129 }
1130
1131 Expr::While { condition, body }
1132 }
1133 Expr::Block(block) => Expr::Block(self.pass_constant_fold_block(block)),
1134 Expr::Call { func, args } => {
1135 let args = args
1136 .iter()
1137 .map(|a| self.pass_constant_fold_expr(a))
1138 .collect();
1139 Expr::Call {
1140 func: func.clone(),
1141 args,
1142 }
1143 }
1144 Expr::Return(e) => Expr::Return(
1145 e.as_ref()
1146 .map(|e| Box::new(self.pass_constant_fold_expr(e))),
1147 ),
1148 Expr::Assign { target, value } => {
1149 let value = Box::new(self.pass_constant_fold_expr(value));
1150 Expr::Assign {
1151 target: target.clone(),
1152 value,
1153 }
1154 }
1155 Expr::Index { expr: e, index } => {
1156 let e = Box::new(self.pass_constant_fold_expr(e));
1157 let index = Box::new(self.pass_constant_fold_expr(index));
1158 Expr::Index { expr: e, index }
1159 }
1160 Expr::Array(elements) => {
1161 let elements = elements
1162 .iter()
1163 .map(|e| self.pass_constant_fold_expr(e))
1164 .collect();
1165 Expr::Array(elements)
1166 }
1167 other => other.clone(),
1168 }
1169 }
1170
1171 fn as_int(&self, expr: &Expr) -> Option<i64> {
1172 match expr {
1173 Expr::Literal(Literal::Int { value, .. }) => value.parse().ok(),
1174 Expr::Literal(Literal::Bool(b)) => Some(if *b { 1 } else { 0 }),
1175 _ => None,
1176 }
1177 }
1178
1179 fn as_bool(&self, expr: &Expr) -> Option<bool> {
1180 match expr {
1181 Expr::Literal(Literal::Bool(b)) => Some(*b),
1182 Expr::Literal(Literal::Int { value, .. }) => value.parse::<i64>().ok().map(|v| v != 0),
1183 _ => None,
1184 }
1185 }
1186
1187 fn fold_binary(&self, op: BinOp, l: i64, r: i64) -> Option<i64> {
1188 match op {
1189 BinOp::Add => Some(l.wrapping_add(r)),
1190 BinOp::Sub => Some(l.wrapping_sub(r)),
1191 BinOp::Mul => Some(l.wrapping_mul(r)),
1192 BinOp::Div if r != 0 => Some(l / r),
1193 BinOp::Rem if r != 0 => Some(l % r),
1194 BinOp::BitAnd => Some(l & r),
1195 BinOp::BitOr => Some(l | r),
1196 BinOp::BitXor => Some(l ^ r),
1197 BinOp::Shl => Some(l << (r & 63)),
1198 BinOp::Shr => Some(l >> (r & 63)),
1199 BinOp::Eq => Some(if l == r { 1 } else { 0 }),
1200 BinOp::Ne => Some(if l != r { 1 } else { 0 }),
1201 BinOp::Lt => Some(if l < r { 1 } else { 0 }),
1202 BinOp::Le => Some(if l <= r { 1 } else { 0 }),
1203 BinOp::Gt => Some(if l > r { 1 } else { 0 }),
1204 BinOp::Ge => Some(if l >= r { 1 } else { 0 }),
1205 BinOp::And => Some(if l != 0 && r != 0 { 1 } else { 0 }),
1206 BinOp::Or => Some(if l != 0 || r != 0 { 1 } else { 0 }),
1207 _ => None,
1208 }
1209 }
1210
1211 fn fold_unary(&self, op: UnaryOp, v: i64) -> Option<i64> {
1212 match op {
1213 UnaryOp::Neg => Some(-v),
1214 UnaryOp::Not => Some(if v == 0 { 1 } else { 0 }),
1215 _ => None,
1216 }
1217 }
1218
1219 fn pass_strength_reduce_block(&mut self, block: &Block) -> Block {
1224 let stmts = block
1225 .stmts
1226 .iter()
1227 .map(|s| self.pass_strength_reduce_stmt(s))
1228 .collect();
1229 let expr = block
1230 .expr
1231 .as_ref()
1232 .map(|e| Box::new(self.pass_strength_reduce_expr(e)));
1233 Block { stmts, expr }
1234 }
1235
1236 fn pass_strength_reduce_stmt(&mut self, stmt: &Stmt) -> Stmt {
1237 match stmt {
1238 Stmt::Let {
1239 pattern, ty, init, ..
1240 } => Stmt::Let {
1241 pattern: pattern.clone(),
1242 ty: ty.clone(),
1243 init: init.as_ref().map(|e| self.pass_strength_reduce_expr(e)),
1244 },
1245 Stmt::Expr(expr) => Stmt::Expr(self.pass_strength_reduce_expr(expr)),
1246 Stmt::Semi(expr) => Stmt::Semi(self.pass_strength_reduce_expr(expr)),
1247 Stmt::Item(item) => Stmt::Item(item.clone()),
1248 }
1249 }
1250
1251 fn pass_strength_reduce_expr(&mut self, expr: &Expr) -> Expr {
1252 match expr {
1253 Expr::Binary { op, left, right } => {
1254 let left = Box::new(self.pass_strength_reduce_expr(left));
1255 let right = Box::new(self.pass_strength_reduce_expr(right));
1256
1257 if *op == BinOp::Mul {
1259 if let Some(n) = self.as_int(&right) {
1260 if n > 0 && (n as u64).is_power_of_two() {
1261 self.stats.strength_reductions += 1;
1262 let shift = (n as u64).trailing_zeros() as i64;
1263 return Expr::Binary {
1264 op: BinOp::Shl,
1265 left,
1266 right: Box::new(Expr::Literal(Literal::Int {
1267 value: shift.to_string(),
1268 base: NumBase::Decimal,
1269 suffix: None,
1270 })),
1271 };
1272 }
1273 }
1274 if let Some(n) = self.as_int(&left) {
1275 if n > 0 && (n as u64).is_power_of_two() {
1276 self.stats.strength_reductions += 1;
1277 let shift = (n as u64).trailing_zeros() as i64;
1278 return Expr::Binary {
1279 op: BinOp::Shl,
1280 left: right,
1281 right: Box::new(Expr::Literal(Literal::Int {
1282 value: shift.to_string(),
1283 base: NumBase::Decimal,
1284 suffix: None,
1285 })),
1286 };
1287 }
1288 }
1289 }
1290
1291 if let Some(n) = self.as_int(&right) {
1293 match (op, n) {
1294 (BinOp::Add | BinOp::Sub | BinOp::BitOr | BinOp::BitXor, 0)
1295 | (BinOp::Mul | BinOp::Div, 1)
1296 | (BinOp::Shl | BinOp::Shr, 0) => {
1297 self.stats.strength_reductions += 1;
1298 return *left;
1299 }
1300 (BinOp::Mul, 0) | (BinOp::BitAnd, 0) => {
1301 self.stats.strength_reductions += 1;
1302 return Expr::Literal(Literal::Int {
1303 value: "0".to_string(),
1304 base: NumBase::Decimal,
1305 suffix: None,
1306 });
1307 }
1308 _ => {}
1309 }
1310 }
1311
1312 if let Some(n) = self.as_int(&left) {
1314 match (op, n) {
1315 (BinOp::Add | BinOp::BitOr | BinOp::BitXor, 0) | (BinOp::Mul, 1) => {
1316 self.stats.strength_reductions += 1;
1317 return *right;
1318 }
1319 (BinOp::Mul, 0) | (BinOp::BitAnd, 0) => {
1320 self.stats.strength_reductions += 1;
1321 return Expr::Literal(Literal::Int {
1322 value: "0".to_string(),
1323 base: NumBase::Decimal,
1324 suffix: None,
1325 });
1326 }
1327 _ => {}
1328 }
1329 }
1330
1331 Expr::Binary {
1332 op: op.clone(),
1333 left,
1334 right,
1335 }
1336 }
1337 Expr::Unary { op, expr: inner } => {
1338 let inner = Box::new(self.pass_strength_reduce_expr(inner));
1339
1340 if *op == UnaryOp::Neg {
1342 if let Expr::Unary {
1343 op: UnaryOp::Neg,
1344 expr: inner2,
1345 } = inner.as_ref()
1346 {
1347 self.stats.strength_reductions += 1;
1348 return *inner2.clone();
1349 }
1350 }
1351
1352 if *op == UnaryOp::Not {
1354 if let Expr::Unary {
1355 op: UnaryOp::Not,
1356 expr: inner2,
1357 } = inner.as_ref()
1358 {
1359 self.stats.strength_reductions += 1;
1360 return *inner2.clone();
1361 }
1362 }
1363
1364 Expr::Unary {
1365 op: *op,
1366 expr: inner,
1367 }
1368 }
1369 Expr::If {
1370 condition,
1371 then_branch,
1372 else_branch,
1373 } => {
1374 let condition = Box::new(self.pass_strength_reduce_expr(condition));
1375 let then_branch = self.pass_strength_reduce_block(then_branch);
1376 let else_branch = else_branch
1377 .as_ref()
1378 .map(|e| Box::new(self.pass_strength_reduce_expr(e)));
1379 Expr::If {
1380 condition,
1381 then_branch,
1382 else_branch,
1383 }
1384 }
1385 Expr::While { condition, body } => {
1386 let condition = Box::new(self.pass_strength_reduce_expr(condition));
1387 let body = self.pass_strength_reduce_block(body);
1388 Expr::While { condition, body }
1389 }
1390 Expr::Block(block) => Expr::Block(self.pass_strength_reduce_block(block)),
1391 Expr::Call { func, args } => {
1392 let args = args
1393 .iter()
1394 .map(|a| self.pass_strength_reduce_expr(a))
1395 .collect();
1396 Expr::Call {
1397 func: func.clone(),
1398 args,
1399 }
1400 }
1401 Expr::Return(e) => Expr::Return(
1402 e.as_ref()
1403 .map(|e| Box::new(self.pass_strength_reduce_expr(e))),
1404 ),
1405 Expr::Assign { target, value } => {
1406 let value = Box::new(self.pass_strength_reduce_expr(value));
1407 Expr::Assign {
1408 target: target.clone(),
1409 value,
1410 }
1411 }
1412 other => other.clone(),
1413 }
1414 }
1415
1416 fn pass_dead_code_block(&mut self, block: &Block) -> Block {
1421 let mut stmts = Vec::new();
1423 let mut found_return = false;
1424
1425 for stmt in &block.stmts {
1426 if found_return {
1427 self.stats.dead_code_eliminated += 1;
1428 continue;
1429 }
1430 let stmt = self.pass_dead_code_stmt(stmt);
1431 if self.stmt_returns(&stmt) {
1432 found_return = true;
1433 }
1434 stmts.push(stmt);
1435 }
1436
1437 let expr = if found_return {
1439 if block.expr.is_some() {
1440 self.stats.dead_code_eliminated += 1;
1441 }
1442 None
1443 } else {
1444 block
1445 .expr
1446 .as_ref()
1447 .map(|e| Box::new(self.pass_dead_code_expr(e)))
1448 };
1449
1450 Block { stmts, expr }
1451 }
1452
1453 fn pass_dead_code_stmt(&mut self, stmt: &Stmt) -> Stmt {
1454 match stmt {
1455 Stmt::Let {
1456 pattern, ty, init, ..
1457 } => Stmt::Let {
1458 pattern: pattern.clone(),
1459 ty: ty.clone(),
1460 init: init.as_ref().map(|e| self.pass_dead_code_expr(e)),
1461 },
1462 Stmt::Expr(expr) => Stmt::Expr(self.pass_dead_code_expr(expr)),
1463 Stmt::Semi(expr) => Stmt::Semi(self.pass_dead_code_expr(expr)),
1464 Stmt::Item(item) => Stmt::Item(item.clone()),
1465 }
1466 }
1467
1468 fn pass_dead_code_expr(&mut self, expr: &Expr) -> Expr {
1469 match expr {
1470 Expr::If {
1471 condition,
1472 then_branch,
1473 else_branch,
1474 } => {
1475 let condition = Box::new(self.pass_dead_code_expr(condition));
1476 let then_branch = self.pass_dead_code_block(then_branch);
1477 let else_branch = else_branch
1478 .as_ref()
1479 .map(|e| Box::new(self.pass_dead_code_expr(e)));
1480 Expr::If {
1481 condition,
1482 then_branch,
1483 else_branch,
1484 }
1485 }
1486 Expr::While { condition, body } => {
1487 let condition = Box::new(self.pass_dead_code_expr(condition));
1488 let body = self.pass_dead_code_block(body);
1489 Expr::While { condition, body }
1490 }
1491 Expr::Block(block) => Expr::Block(self.pass_dead_code_block(block)),
1492 other => other.clone(),
1493 }
1494 }
1495
1496 fn stmt_returns(&self, stmt: &Stmt) -> bool {
1497 match stmt {
1498 Stmt::Expr(expr) | Stmt::Semi(expr) => self.expr_returns(expr),
1499 _ => false,
1500 }
1501 }
1502
1503 fn expr_returns(&self, expr: &Expr) -> bool {
1504 match expr {
1505 Expr::Return(_) => true,
1506 Expr::Block(block) => {
1507 block.stmts.iter().any(|s| self.stmt_returns(s))
1508 || block
1509 .expr
1510 .as_ref()
1511 .map(|e| self.expr_returns(e))
1512 .unwrap_or(false)
1513 }
1514 _ => false,
1515 }
1516 }
1517
1518 fn pass_simplify_branches_block(&mut self, block: &Block) -> Block {
1523 let stmts = block
1524 .stmts
1525 .iter()
1526 .map(|s| self.pass_simplify_branches_stmt(s))
1527 .collect();
1528 let expr = block
1529 .expr
1530 .as_ref()
1531 .map(|e| Box::new(self.pass_simplify_branches_expr(e)));
1532 Block { stmts, expr }
1533 }
1534
1535 fn pass_simplify_branches_stmt(&mut self, stmt: &Stmt) -> Stmt {
1536 match stmt {
1537 Stmt::Let {
1538 pattern, ty, init, ..
1539 } => Stmt::Let {
1540 pattern: pattern.clone(),
1541 ty: ty.clone(),
1542 init: init.as_ref().map(|e| self.pass_simplify_branches_expr(e)),
1543 },
1544 Stmt::Expr(expr) => Stmt::Expr(self.pass_simplify_branches_expr(expr)),
1545 Stmt::Semi(expr) => Stmt::Semi(self.pass_simplify_branches_expr(expr)),
1546 Stmt::Item(item) => Stmt::Item(item.clone()),
1547 }
1548 }
1549
1550 fn pass_simplify_branches_expr(&mut self, expr: &Expr) -> Expr {
1551 match expr {
1552 Expr::If {
1553 condition,
1554 then_branch,
1555 else_branch,
1556 } => {
1557 let condition = Box::new(self.pass_simplify_branches_expr(condition));
1558 let then_branch = self.pass_simplify_branches_block(then_branch);
1559 let else_branch = else_branch
1560 .as_ref()
1561 .map(|e| Box::new(self.pass_simplify_branches_expr(e)));
1562
1563 if let Expr::Unary {
1565 op: UnaryOp::Not,
1566 expr: inner,
1567 } = condition.as_ref()
1568 {
1569 if let Some(else_expr) = &else_branch {
1570 self.stats.branches_simplified += 1;
1571 let new_else = Some(Box::new(Expr::Block(then_branch)));
1572 let new_then = match else_expr.as_ref() {
1573 Expr::Block(b) => b.clone(),
1574 other => Block {
1575 stmts: vec![],
1576 expr: Some(Box::new(other.clone())),
1577 },
1578 };
1579 return Expr::If {
1580 condition: inner.clone(),
1581 then_branch: new_then,
1582 else_branch: new_else,
1583 };
1584 }
1585 }
1586
1587 Expr::If {
1588 condition,
1589 then_branch,
1590 else_branch,
1591 }
1592 }
1593 Expr::While { condition, body } => {
1594 let condition = Box::new(self.pass_simplify_branches_expr(condition));
1595 let body = self.pass_simplify_branches_block(body);
1596 Expr::While { condition, body }
1597 }
1598 Expr::Block(block) => Expr::Block(self.pass_simplify_branches_block(block)),
1599 Expr::Binary { op, left, right } => {
1600 let left = Box::new(self.pass_simplify_branches_expr(left));
1601 let right = Box::new(self.pass_simplify_branches_expr(right));
1602 Expr::Binary {
1603 op: op.clone(),
1604 left,
1605 right,
1606 }
1607 }
1608 Expr::Unary { op, expr: inner } => {
1609 let inner = Box::new(self.pass_simplify_branches_expr(inner));
1610 Expr::Unary {
1611 op: *op,
1612 expr: inner,
1613 }
1614 }
1615 Expr::Call { func, args } => {
1616 let args = args
1617 .iter()
1618 .map(|a| self.pass_simplify_branches_expr(a))
1619 .collect();
1620 Expr::Call {
1621 func: func.clone(),
1622 args,
1623 }
1624 }
1625 Expr::Return(e) => Expr::Return(
1626 e.as_ref()
1627 .map(|e| Box::new(self.pass_simplify_branches_expr(e))),
1628 ),
1629 other => other.clone(),
1630 }
1631 }
1632
1633 fn should_inline(&self, func: &ast::Function) -> bool {
1639 if self.recursive_functions.contains(&func.name.name) {
1641 return false;
1642 }
1643
1644 if let Some(body) = &func.body {
1646 let stmt_count = self.count_stmts_in_block(body);
1647 stmt_count <= 10
1649 } else {
1650 false
1651 }
1652 }
1653
1654 fn count_stmts_in_block(&self, block: &Block) -> usize {
1656 let mut count = block.stmts.len();
1657 if block.expr.is_some() {
1658 count += 1;
1659 }
1660 for stmt in &block.stmts {
1662 count += self.count_stmts_in_stmt(stmt);
1663 }
1664 count
1665 }
1666
1667 fn count_stmts_in_stmt(&self, stmt: &Stmt) -> usize {
1668 match stmt {
1669 Stmt::Expr(e) | Stmt::Semi(e) => self.count_stmts_in_expr(e),
1670 Stmt::Let { init: Some(e), .. } => self.count_stmts_in_expr(e),
1671 _ => 0,
1672 }
1673 }
1674
1675 fn count_stmts_in_expr(&self, expr: &Expr) -> usize {
1676 match expr {
1677 Expr::If {
1678 then_branch,
1679 else_branch,
1680 ..
1681 } => {
1682 let mut count = self.count_stmts_in_block(then_branch);
1683 if let Some(else_expr) = else_branch {
1684 count += self.count_stmts_in_expr(else_expr);
1685 }
1686 count
1687 }
1688 Expr::While { body, .. } => self.count_stmts_in_block(body),
1689 Expr::Block(block) => self.count_stmts_in_block(block),
1690 _ => 0,
1691 }
1692 }
1693
1694 fn inline_call(&mut self, func: &ast::Function, args: &[Expr]) -> Option<Expr> {
1696 let body = func.body.as_ref()?;
1697
1698 let mut param_map: HashMap<String, Expr> = HashMap::new();
1700 for (param, arg) in func.params.iter().zip(args.iter()) {
1701 if let Pattern::Ident { name, .. } = ¶m.pattern {
1702 param_map.insert(name.name.clone(), arg.clone());
1703 }
1704 }
1705
1706 let inlined_body = self.substitute_params_in_block(body, ¶m_map);
1708
1709 self.stats.functions_inlined += 1;
1710
1711 if inlined_body.stmts.is_empty() {
1714 if let Some(expr) = inlined_body.expr {
1715 if let Expr::Return(Some(inner)) = expr.as_ref() {
1717 return Some(inner.as_ref().clone());
1718 }
1719 return Some(*expr);
1720 }
1721 }
1722
1723 Some(Expr::Block(inlined_body))
1724 }
1725
1726 fn substitute_params_in_block(
1728 &self,
1729 block: &Block,
1730 param_map: &HashMap<String, Expr>,
1731 ) -> Block {
1732 let stmts = block
1733 .stmts
1734 .iter()
1735 .map(|s| self.substitute_params_in_stmt(s, param_map))
1736 .collect();
1737 let expr = block
1738 .expr
1739 .as_ref()
1740 .map(|e| Box::new(self.substitute_params_in_expr(e, param_map)));
1741 Block { stmts, expr }
1742 }
1743
1744 fn substitute_params_in_stmt(&self, stmt: &Stmt, param_map: &HashMap<String, Expr>) -> Stmt {
1745 match stmt {
1746 Stmt::Let { pattern, ty, init } => Stmt::Let {
1747 pattern: pattern.clone(),
1748 ty: ty.clone(),
1749 init: init
1750 .as_ref()
1751 .map(|e| self.substitute_params_in_expr(e, param_map)),
1752 },
1753 Stmt::Expr(e) => Stmt::Expr(self.substitute_params_in_expr(e, param_map)),
1754 Stmt::Semi(e) => Stmt::Semi(self.substitute_params_in_expr(e, param_map)),
1755 Stmt::Item(item) => Stmt::Item(item.clone()),
1756 }
1757 }
1758
1759 fn substitute_params_in_expr(&self, expr: &Expr, param_map: &HashMap<String, Expr>) -> Expr {
1760 match expr {
1761 Expr::Path(path) => {
1762 if path.segments.len() == 1 {
1764 let name = &path.segments[0].ident.name;
1765 if let Some(arg) = param_map.get(name) {
1766 return arg.clone();
1767 }
1768 }
1769 expr.clone()
1770 }
1771 Expr::Binary { op, left, right } => Expr::Binary {
1772 op: op.clone(),
1773 left: Box::new(self.substitute_params_in_expr(left, param_map)),
1774 right: Box::new(self.substitute_params_in_expr(right, param_map)),
1775 },
1776 Expr::Unary { op, expr: inner } => Expr::Unary {
1777 op: *op,
1778 expr: Box::new(self.substitute_params_in_expr(inner, param_map)),
1779 },
1780 Expr::If {
1781 condition,
1782 then_branch,
1783 else_branch,
1784 } => Expr::If {
1785 condition: Box::new(self.substitute_params_in_expr(condition, param_map)),
1786 then_branch: self.substitute_params_in_block(then_branch, param_map),
1787 else_branch: else_branch
1788 .as_ref()
1789 .map(|e| Box::new(self.substitute_params_in_expr(e, param_map))),
1790 },
1791 Expr::While { condition, body } => Expr::While {
1792 condition: Box::new(self.substitute_params_in_expr(condition, param_map)),
1793 body: self.substitute_params_in_block(body, param_map),
1794 },
1795 Expr::Block(block) => Expr::Block(self.substitute_params_in_block(block, param_map)),
1796 Expr::Call { func, args } => Expr::Call {
1797 func: Box::new(self.substitute_params_in_expr(func, param_map)),
1798 args: args
1799 .iter()
1800 .map(|a| self.substitute_params_in_expr(a, param_map))
1801 .collect(),
1802 },
1803 Expr::Return(e) => Expr::Return(
1804 e.as_ref()
1805 .map(|e| Box::new(self.substitute_params_in_expr(e, param_map))),
1806 ),
1807 Expr::Assign { target, value } => Expr::Assign {
1808 target: target.clone(),
1809 value: Box::new(self.substitute_params_in_expr(value, param_map)),
1810 },
1811 Expr::Index { expr: e, index } => Expr::Index {
1812 expr: Box::new(self.substitute_params_in_expr(e, param_map)),
1813 index: Box::new(self.substitute_params_in_expr(index, param_map)),
1814 },
1815 Expr::Array(elements) => Expr::Array(
1816 elements
1817 .iter()
1818 .map(|e| self.substitute_params_in_expr(e, param_map))
1819 .collect(),
1820 ),
1821 other => other.clone(),
1822 }
1823 }
1824
1825 fn pass_inline_block(&mut self, block: &Block) -> Block {
1826 let stmts = block
1827 .stmts
1828 .iter()
1829 .map(|s| self.pass_inline_stmt(s))
1830 .collect();
1831 let expr = block
1832 .expr
1833 .as_ref()
1834 .map(|e| Box::new(self.pass_inline_expr(e)));
1835 Block { stmts, expr }
1836 }
1837
1838 fn pass_inline_stmt(&mut self, stmt: &Stmt) -> Stmt {
1839 match stmt {
1840 Stmt::Let { pattern, ty, init } => Stmt::Let {
1841 pattern: pattern.clone(),
1842 ty: ty.clone(),
1843 init: init.as_ref().map(|e| self.pass_inline_expr(e)),
1844 },
1845 Stmt::Expr(e) => Stmt::Expr(self.pass_inline_expr(e)),
1846 Stmt::Semi(e) => Stmt::Semi(self.pass_inline_expr(e)),
1847 Stmt::Item(item) => Stmt::Item(item.clone()),
1848 }
1849 }
1850
1851 fn pass_inline_expr(&mut self, expr: &Expr) -> Expr {
1852 match expr {
1853 Expr::Call { func, args } => {
1854 let args: Vec<Expr> = args.iter().map(|a| self.pass_inline_expr(a)).collect();
1856
1857 if let Expr::Path(path) = func.as_ref() {
1859 if path.segments.len() == 1 {
1860 let func_name = &path.segments[0].ident.name;
1861 if let Some(target_func) = self.functions.get(func_name).cloned() {
1862 if self.should_inline(&target_func)
1863 && args.len() == target_func.params.len()
1864 {
1865 if let Some(inlined) = self.inline_call(&target_func, &args) {
1866 return inlined;
1867 }
1868 }
1869 }
1870 }
1871 }
1872
1873 Expr::Call {
1874 func: func.clone(),
1875 args,
1876 }
1877 }
1878 Expr::Binary { op, left, right } => Expr::Binary {
1879 op: op.clone(),
1880 left: Box::new(self.pass_inline_expr(left)),
1881 right: Box::new(self.pass_inline_expr(right)),
1882 },
1883 Expr::Unary { op, expr: inner } => Expr::Unary {
1884 op: *op,
1885 expr: Box::new(self.pass_inline_expr(inner)),
1886 },
1887 Expr::If {
1888 condition,
1889 then_branch,
1890 else_branch,
1891 } => Expr::If {
1892 condition: Box::new(self.pass_inline_expr(condition)),
1893 then_branch: self.pass_inline_block(then_branch),
1894 else_branch: else_branch
1895 .as_ref()
1896 .map(|e| Box::new(self.pass_inline_expr(e))),
1897 },
1898 Expr::While { condition, body } => Expr::While {
1899 condition: Box::new(self.pass_inline_expr(condition)),
1900 body: self.pass_inline_block(body),
1901 },
1902 Expr::Block(block) => Expr::Block(self.pass_inline_block(block)),
1903 Expr::Return(e) => Expr::Return(e.as_ref().map(|e| Box::new(self.pass_inline_expr(e)))),
1904 Expr::Assign { target, value } => Expr::Assign {
1905 target: target.clone(),
1906 value: Box::new(self.pass_inline_expr(value)),
1907 },
1908 Expr::Index { expr: e, index } => Expr::Index {
1909 expr: Box::new(self.pass_inline_expr(e)),
1910 index: Box::new(self.pass_inline_expr(index)),
1911 },
1912 Expr::Array(elements) => {
1913 Expr::Array(elements.iter().map(|e| self.pass_inline_expr(e)).collect())
1914 }
1915 other => other.clone(),
1916 }
1917 }
1918
1919 fn pass_loop_unroll_block(&mut self, block: &Block) -> Block {
1925 let stmts = block
1926 .stmts
1927 .iter()
1928 .map(|s| self.pass_loop_unroll_stmt(s))
1929 .collect();
1930 let expr = block
1931 .expr
1932 .as_ref()
1933 .map(|e| Box::new(self.pass_loop_unroll_expr(e)));
1934 Block { stmts, expr }
1935 }
1936
1937 fn pass_loop_unroll_stmt(&mut self, stmt: &Stmt) -> Stmt {
1938 match stmt {
1939 Stmt::Let { pattern, ty, init } => Stmt::Let {
1940 pattern: pattern.clone(),
1941 ty: ty.clone(),
1942 init: init.as_ref().map(|e| self.pass_loop_unroll_expr(e)),
1943 },
1944 Stmt::Expr(e) => Stmt::Expr(self.pass_loop_unroll_expr(e)),
1945 Stmt::Semi(e) => Stmt::Semi(self.pass_loop_unroll_expr(e)),
1946 Stmt::Item(item) => Stmt::Item(item.clone()),
1947 }
1948 }
1949
1950 fn pass_loop_unroll_expr(&mut self, expr: &Expr) -> Expr {
1951 match expr {
1952 Expr::While { condition, body } => {
1953 if let Some(unrolled) = self.try_unroll_loop(condition, body) {
1955 self.stats.loops_optimized += 1;
1956 return unrolled;
1957 }
1958 Expr::While {
1960 condition: Box::new(self.pass_loop_unroll_expr(condition)),
1961 body: self.pass_loop_unroll_block(body),
1962 }
1963 }
1964 Expr::If {
1965 condition,
1966 then_branch,
1967 else_branch,
1968 } => Expr::If {
1969 condition: Box::new(self.pass_loop_unroll_expr(condition)),
1970 then_branch: self.pass_loop_unroll_block(then_branch),
1971 else_branch: else_branch
1972 .as_ref()
1973 .map(|e| Box::new(self.pass_loop_unroll_expr(e))),
1974 },
1975 Expr::Block(b) => Expr::Block(self.pass_loop_unroll_block(b)),
1976 Expr::Binary { op, left, right } => Expr::Binary {
1977 op: *op,
1978 left: Box::new(self.pass_loop_unroll_expr(left)),
1979 right: Box::new(self.pass_loop_unroll_expr(right)),
1980 },
1981 Expr::Unary { op, expr: inner } => Expr::Unary {
1982 op: *op,
1983 expr: Box::new(self.pass_loop_unroll_expr(inner)),
1984 },
1985 Expr::Call { func, args } => Expr::Call {
1986 func: func.clone(),
1987 args: args.iter().map(|a| self.pass_loop_unroll_expr(a)).collect(),
1988 },
1989 Expr::Return(e) => {
1990 Expr::Return(e.as_ref().map(|e| Box::new(self.pass_loop_unroll_expr(e))))
1991 }
1992 Expr::Assign { target, value } => Expr::Assign {
1993 target: target.clone(),
1994 value: Box::new(self.pass_loop_unroll_expr(value)),
1995 },
1996 other => other.clone(),
1997 }
1998 }
1999
2000 fn try_unroll_loop(&self, condition: &Expr, body: &Block) -> Option<Expr> {
2003 let (loop_var, upper_bound) = self.extract_loop_bounds(condition)?;
2005
2006 if upper_bound > 8 || upper_bound <= 0 {
2008 return None;
2009 }
2010
2011 if !self.body_has_simple_increment(&loop_var, body) {
2013 return None;
2014 }
2015
2016 let stmt_count = body.stmts.len();
2018 if stmt_count > 5 {
2019 return None;
2020 }
2021
2022 let mut unrolled_stmts: Vec<Stmt> = Vec::new();
2024
2025 for i in 0..upper_bound {
2026 let substituted_body = self.substitute_loop_var_in_block(body, &loop_var, i);
2028
2029 for stmt in &substituted_body.stmts {
2031 if !self.is_increment_stmt(&loop_var, stmt) {
2032 unrolled_stmts.push(stmt.clone());
2033 }
2034 }
2035 }
2036
2037 Some(Expr::Block(Block {
2039 stmts: unrolled_stmts,
2040 expr: None,
2041 }))
2042 }
2043
2044 fn extract_loop_bounds(&self, condition: &Expr) -> Option<(String, i64)> {
2046 if let Expr::Binary {
2047 op: BinOp::Lt,
2048 left,
2049 right,
2050 } = condition
2051 {
2052 if let Expr::Path(path) = left.as_ref() {
2054 if path.segments.len() == 1 {
2055 let var_name = path.segments[0].ident.name.clone();
2056 if let Some(bound) = self.as_int(right) {
2058 return Some((var_name, bound));
2059 }
2060 }
2061 }
2062 }
2063 None
2064 }
2065
2066 fn body_has_simple_increment(&self, loop_var: &str, body: &Block) -> bool {
2068 for stmt in &body.stmts {
2069 if self.is_increment_stmt(loop_var, stmt) {
2070 return true;
2071 }
2072 }
2073 false
2074 }
2075
2076 fn is_increment_stmt(&self, var_name: &str, stmt: &Stmt) -> bool {
2078 match stmt {
2079 Stmt::Semi(Expr::Assign { target, value })
2080 | Stmt::Expr(Expr::Assign { target, value }) => {
2081 if let Expr::Path(path) = target.as_ref() {
2083 if path.segments.len() == 1 && path.segments[0].ident.name == var_name {
2084 if let Expr::Binary {
2086 op: BinOp::Add,
2087 left,
2088 right,
2089 } = value.as_ref()
2090 {
2091 if let Expr::Path(lpath) = left.as_ref() {
2092 if lpath.segments.len() == 1
2093 && lpath.segments[0].ident.name == var_name
2094 {
2095 if let Some(1) = self.as_int(right) {
2096 return true;
2097 }
2098 }
2099 }
2100 }
2101 }
2102 }
2103 false
2104 }
2105 _ => false,
2106 }
2107 }
2108
2109 fn substitute_loop_var_in_block(&self, block: &Block, var_name: &str, value: i64) -> Block {
2111 let stmts = block
2112 .stmts
2113 .iter()
2114 .map(|s| self.substitute_loop_var_in_stmt(s, var_name, value))
2115 .collect();
2116 let expr = block
2117 .expr
2118 .as_ref()
2119 .map(|e| Box::new(self.substitute_loop_var_in_expr(e, var_name, value)));
2120 Block { stmts, expr }
2121 }
2122
2123 fn substitute_loop_var_in_stmt(&self, stmt: &Stmt, var_name: &str, value: i64) -> Stmt {
2124 match stmt {
2125 Stmt::Let { pattern, ty, init } => Stmt::Let {
2126 pattern: pattern.clone(),
2127 ty: ty.clone(),
2128 init: init
2129 .as_ref()
2130 .map(|e| self.substitute_loop_var_in_expr(e, var_name, value)),
2131 },
2132 Stmt::Expr(e) => Stmt::Expr(self.substitute_loop_var_in_expr(e, var_name, value)),
2133 Stmt::Semi(e) => Stmt::Semi(self.substitute_loop_var_in_expr(e, var_name, value)),
2134 Stmt::Item(item) => Stmt::Item(item.clone()),
2135 }
2136 }
2137
2138 fn substitute_loop_var_in_expr(&self, expr: &Expr, var_name: &str, value: i64) -> Expr {
2139 match expr {
2140 Expr::Path(path) => {
2141 if path.segments.len() == 1 && path.segments[0].ident.name == var_name {
2142 return Expr::Literal(Literal::Int {
2143 value: value.to_string(),
2144 base: NumBase::Decimal,
2145 suffix: None,
2146 });
2147 }
2148 expr.clone()
2149 }
2150 Expr::Binary { op, left, right } => Expr::Binary {
2151 op: *op,
2152 left: Box::new(self.substitute_loop_var_in_expr(left, var_name, value)),
2153 right: Box::new(self.substitute_loop_var_in_expr(right, var_name, value)),
2154 },
2155 Expr::Unary { op, expr: inner } => Expr::Unary {
2156 op: *op,
2157 expr: Box::new(self.substitute_loop_var_in_expr(inner, var_name, value)),
2158 },
2159 Expr::Call { func, args } => Expr::Call {
2160 func: Box::new(self.substitute_loop_var_in_expr(func, var_name, value)),
2161 args: args
2162 .iter()
2163 .map(|a| self.substitute_loop_var_in_expr(a, var_name, value))
2164 .collect(),
2165 },
2166 Expr::If {
2167 condition,
2168 then_branch,
2169 else_branch,
2170 } => Expr::If {
2171 condition: Box::new(self.substitute_loop_var_in_expr(condition, var_name, value)),
2172 then_branch: self.substitute_loop_var_in_block(then_branch, var_name, value),
2173 else_branch: else_branch
2174 .as_ref()
2175 .map(|e| Box::new(self.substitute_loop_var_in_expr(e, var_name, value))),
2176 },
2177 Expr::While { condition, body } => Expr::While {
2178 condition: Box::new(self.substitute_loop_var_in_expr(condition, var_name, value)),
2179 body: self.substitute_loop_var_in_block(body, var_name, value),
2180 },
2181 Expr::Block(b) => Expr::Block(self.substitute_loop_var_in_block(b, var_name, value)),
2182 Expr::Return(e) => Expr::Return(
2183 e.as_ref()
2184 .map(|e| Box::new(self.substitute_loop_var_in_expr(e, var_name, value))),
2185 ),
2186 Expr::Assign { target, value: v } => Expr::Assign {
2187 target: Box::new(self.substitute_loop_var_in_expr(target, var_name, value)),
2188 value: Box::new(self.substitute_loop_var_in_expr(v, var_name, value)),
2189 },
2190 Expr::Index { expr: e, index } => Expr::Index {
2191 expr: Box::new(self.substitute_loop_var_in_expr(e, var_name, value)),
2192 index: Box::new(self.substitute_loop_var_in_expr(index, var_name, value)),
2193 },
2194 Expr::Array(elements) => Expr::Array(
2195 elements
2196 .iter()
2197 .map(|e| self.substitute_loop_var_in_expr(e, var_name, value))
2198 .collect(),
2199 ),
2200 other => other.clone(),
2201 }
2202 }
2203
2204 fn pass_licm_block(&mut self, block: &Block) -> Block {
2210 let stmts = block.stmts.iter().map(|s| self.pass_licm_stmt(s)).collect();
2211 let expr = block
2212 .expr
2213 .as_ref()
2214 .map(|e| Box::new(self.pass_licm_expr(e)));
2215 Block { stmts, expr }
2216 }
2217
2218 fn pass_licm_stmt(&mut self, stmt: &Stmt) -> Stmt {
2219 match stmt {
2220 Stmt::Let { pattern, ty, init } => Stmt::Let {
2221 pattern: pattern.clone(),
2222 ty: ty.clone(),
2223 init: init.as_ref().map(|e| self.pass_licm_expr(e)),
2224 },
2225 Stmt::Expr(e) => Stmt::Expr(self.pass_licm_expr(e)),
2226 Stmt::Semi(e) => Stmt::Semi(self.pass_licm_expr(e)),
2227 Stmt::Item(item) => Stmt::Item(item.clone()),
2228 }
2229 }
2230
2231 fn pass_licm_expr(&mut self, expr: &Expr) -> Expr {
2232 match expr {
2233 Expr::While { condition, body } => {
2234 let mut modified_vars = HashSet::new();
2236 self.collect_modified_vars_block(body, &mut modified_vars);
2237
2238 self.collect_modified_vars_expr(condition, &mut modified_vars);
2240
2241 let mut invariant_exprs: Vec<(String, Expr)> = Vec::new();
2243 self.find_loop_invariants(body, &modified_vars, &mut invariant_exprs);
2244
2245 if invariant_exprs.is_empty() {
2246 return Expr::While {
2248 condition: Box::new(self.pass_licm_expr(condition)),
2249 body: self.pass_licm_block(body),
2250 };
2251 }
2252
2253 let mut pre_loop_stmts: Vec<Stmt> = Vec::new();
2255 let mut substitution_map: HashMap<String, String> = HashMap::new();
2256
2257 for (original_key, invariant_expr) in &invariant_exprs {
2258 let var_name = format!("__licm_{}", self.cse_counter);
2259 self.cse_counter += 1;
2260
2261 pre_loop_stmts.push(make_cse_let(&var_name, invariant_expr.clone()));
2262 substitution_map.insert(original_key.clone(), var_name);
2263 self.stats.loops_optimized += 1;
2264 }
2265
2266 let new_body =
2268 self.replace_invariants_in_block(body, &invariant_exprs, &substitution_map);
2269
2270 let new_while = Expr::While {
2272 condition: Box::new(self.pass_licm_expr(condition)),
2273 body: self.pass_licm_block(&new_body),
2274 };
2275
2276 pre_loop_stmts.push(Stmt::Expr(new_while));
2278 Expr::Block(Block {
2279 stmts: pre_loop_stmts,
2280 expr: None,
2281 })
2282 }
2283 Expr::If {
2284 condition,
2285 then_branch,
2286 else_branch,
2287 } => Expr::If {
2288 condition: Box::new(self.pass_licm_expr(condition)),
2289 then_branch: self.pass_licm_block(then_branch),
2290 else_branch: else_branch
2291 .as_ref()
2292 .map(|e| Box::new(self.pass_licm_expr(e))),
2293 },
2294 Expr::Block(b) => Expr::Block(self.pass_licm_block(b)),
2295 Expr::Binary { op, left, right } => Expr::Binary {
2296 op: *op,
2297 left: Box::new(self.pass_licm_expr(left)),
2298 right: Box::new(self.pass_licm_expr(right)),
2299 },
2300 Expr::Unary { op, expr: inner } => Expr::Unary {
2301 op: *op,
2302 expr: Box::new(self.pass_licm_expr(inner)),
2303 },
2304 Expr::Call { func, args } => Expr::Call {
2305 func: func.clone(),
2306 args: args.iter().map(|a| self.pass_licm_expr(a)).collect(),
2307 },
2308 Expr::Return(e) => Expr::Return(e.as_ref().map(|e| Box::new(self.pass_licm_expr(e)))),
2309 Expr::Assign { target, value } => Expr::Assign {
2310 target: target.clone(),
2311 value: Box::new(self.pass_licm_expr(value)),
2312 },
2313 other => other.clone(),
2314 }
2315 }
2316
2317 fn collect_modified_vars_block(&self, block: &Block, modified: &mut HashSet<String>) {
2319 for stmt in &block.stmts {
2320 self.collect_modified_vars_stmt(stmt, modified);
2321 }
2322 if let Some(expr) = &block.expr {
2323 self.collect_modified_vars_expr(expr, modified);
2324 }
2325 }
2326
2327 fn collect_modified_vars_stmt(&self, stmt: &Stmt, modified: &mut HashSet<String>) {
2328 match stmt {
2329 Stmt::Let { pattern, init, .. } => {
2330 if let Pattern::Ident { name, .. } = pattern {
2332 modified.insert(name.name.clone());
2333 }
2334 if let Some(e) = init {
2335 self.collect_modified_vars_expr(e, modified);
2336 }
2337 }
2338 Stmt::Expr(e) | Stmt::Semi(e) => self.collect_modified_vars_expr(e, modified),
2339 _ => {}
2340 }
2341 }
2342
2343 fn collect_modified_vars_expr(&self, expr: &Expr, modified: &mut HashSet<String>) {
2344 match expr {
2345 Expr::Assign { target, value } => {
2346 if let Expr::Path(path) = target.as_ref() {
2347 if path.segments.len() == 1 {
2348 modified.insert(path.segments[0].ident.name.clone());
2349 }
2350 }
2351 self.collect_modified_vars_expr(value, modified);
2352 }
2353 Expr::Binary { left, right, .. } => {
2354 self.collect_modified_vars_expr(left, modified);
2355 self.collect_modified_vars_expr(right, modified);
2356 }
2357 Expr::Unary { expr: inner, .. } => {
2358 self.collect_modified_vars_expr(inner, modified);
2359 }
2360 Expr::If {
2361 condition,
2362 then_branch,
2363 else_branch,
2364 } => {
2365 self.collect_modified_vars_expr(condition, modified);
2366 self.collect_modified_vars_block(then_branch, modified);
2367 if let Some(e) = else_branch {
2368 self.collect_modified_vars_expr(e, modified);
2369 }
2370 }
2371 Expr::While { condition, body } => {
2372 self.collect_modified_vars_expr(condition, modified);
2373 self.collect_modified_vars_block(body, modified);
2374 }
2375 Expr::Block(b) => self.collect_modified_vars_block(b, modified),
2376 Expr::Call { args, .. } => {
2377 for arg in args {
2378 self.collect_modified_vars_expr(arg, modified);
2379 }
2380 }
2381 Expr::Return(Some(e)) => self.collect_modified_vars_expr(e, modified),
2382 _ => {}
2383 }
2384 }
2385
2386 fn find_loop_invariants(
2388 &self,
2389 block: &Block,
2390 modified: &HashSet<String>,
2391 out: &mut Vec<(String, Expr)>,
2392 ) {
2393 for stmt in &block.stmts {
2394 self.find_loop_invariants_stmt(stmt, modified, out);
2395 }
2396 if let Some(expr) = &block.expr {
2397 self.find_loop_invariants_expr(expr, modified, out);
2398 }
2399 }
2400
2401 fn find_loop_invariants_stmt(
2402 &self,
2403 stmt: &Stmt,
2404 modified: &HashSet<String>,
2405 out: &mut Vec<(String, Expr)>,
2406 ) {
2407 match stmt {
2408 Stmt::Let { init: Some(e), .. } => self.find_loop_invariants_expr(e, modified, out),
2409 Stmt::Expr(e) | Stmt::Semi(e) => self.find_loop_invariants_expr(e, modified, out),
2410 _ => {}
2411 }
2412 }
2413
2414 fn find_loop_invariants_expr(
2415 &self,
2416 expr: &Expr,
2417 modified: &HashSet<String>,
2418 out: &mut Vec<(String, Expr)>,
2419 ) {
2420 match expr {
2422 Expr::Binary { left, right, .. } => {
2423 self.find_loop_invariants_expr(left, modified, out);
2424 self.find_loop_invariants_expr(right, modified, out);
2425 }
2426 Expr::Unary { expr: inner, .. } => {
2427 self.find_loop_invariants_expr(inner, modified, out);
2428 }
2429 Expr::Call { args, .. } => {
2430 for arg in args {
2431 self.find_loop_invariants_expr(arg, modified, out);
2432 }
2433 }
2434 Expr::Index { expr: e, index } => {
2435 self.find_loop_invariants_expr(e, modified, out);
2436 self.find_loop_invariants_expr(index, modified, out);
2437 }
2438 _ => {}
2439 }
2440
2441 if self.is_loop_invariant(expr, modified) && is_cse_worthy(expr) && is_pure_expr(expr) {
2443 let key = format!("{:?}", expr_hash(expr));
2444 if !out.iter().any(|(k, _)| k == &key) {
2446 out.push((key, expr.clone()));
2447 }
2448 }
2449 }
2450
2451 fn is_loop_invariant(&self, expr: &Expr, modified: &HashSet<String>) -> bool {
2453 match expr {
2454 Expr::Literal(_) => true,
2455 Expr::Path(path) => {
2456 if path.segments.len() == 1 {
2457 !modified.contains(&path.segments[0].ident.name)
2458 } else {
2459 true }
2461 }
2462 Expr::Binary { left, right, .. } => {
2463 self.is_loop_invariant(left, modified) && self.is_loop_invariant(right, modified)
2464 }
2465 Expr::Unary { expr: inner, .. } => self.is_loop_invariant(inner, modified),
2466 Expr::Index { expr: e, index } => {
2467 self.is_loop_invariant(e, modified) && self.is_loop_invariant(index, modified)
2468 }
2469 Expr::Call { .. } => false,
2471 _ => false,
2473 }
2474 }
2475
2476 fn replace_invariants_in_block(
2478 &self,
2479 block: &Block,
2480 invariants: &[(String, Expr)],
2481 subs: &HashMap<String, String>,
2482 ) -> Block {
2483 let stmts = block
2484 .stmts
2485 .iter()
2486 .map(|s| self.replace_invariants_in_stmt(s, invariants, subs))
2487 .collect();
2488 let expr = block
2489 .expr
2490 .as_ref()
2491 .map(|e| Box::new(self.replace_invariants_in_expr(e, invariants, subs)));
2492 Block { stmts, expr }
2493 }
2494
2495 fn replace_invariants_in_stmt(
2496 &self,
2497 stmt: &Stmt,
2498 invariants: &[(String, Expr)],
2499 subs: &HashMap<String, String>,
2500 ) -> Stmt {
2501 match stmt {
2502 Stmt::Let { pattern, ty, init } => Stmt::Let {
2503 pattern: pattern.clone(),
2504 ty: ty.clone(),
2505 init: init
2506 .as_ref()
2507 .map(|e| self.replace_invariants_in_expr(e, invariants, subs)),
2508 },
2509 Stmt::Expr(e) => Stmt::Expr(self.replace_invariants_in_expr(e, invariants, subs)),
2510 Stmt::Semi(e) => Stmt::Semi(self.replace_invariants_in_expr(e, invariants, subs)),
2511 Stmt::Item(item) => Stmt::Item(item.clone()),
2512 }
2513 }
2514
2515 fn replace_invariants_in_expr(
2516 &self,
2517 expr: &Expr,
2518 invariants: &[(String, Expr)],
2519 subs: &HashMap<String, String>,
2520 ) -> Expr {
2521 let key = format!("{:?}", expr_hash(expr));
2523 for (inv_key, inv_expr) in invariants {
2524 if &key == inv_key && expr_eq(expr, inv_expr) {
2525 if let Some(var_name) = subs.get(inv_key) {
2526 return Expr::Path(TypePath {
2527 segments: vec![PathSegment {
2528 ident: Ident {
2529 name: var_name.clone(),
2530 evidentiality: None,
2531 affect: None,
2532 span: Span { start: 0, end: 0 },
2533 },
2534 generics: None,
2535 }],
2536 });
2537 }
2538 }
2539 }
2540
2541 match expr {
2543 Expr::Binary { op, left, right } => Expr::Binary {
2544 op: *op,
2545 left: Box::new(self.replace_invariants_in_expr(left, invariants, subs)),
2546 right: Box::new(self.replace_invariants_in_expr(right, invariants, subs)),
2547 },
2548 Expr::Unary { op, expr: inner } => Expr::Unary {
2549 op: *op,
2550 expr: Box::new(self.replace_invariants_in_expr(inner, invariants, subs)),
2551 },
2552 Expr::Call { func, args } => Expr::Call {
2553 func: func.clone(),
2554 args: args
2555 .iter()
2556 .map(|a| self.replace_invariants_in_expr(a, invariants, subs))
2557 .collect(),
2558 },
2559 Expr::If {
2560 condition,
2561 then_branch,
2562 else_branch,
2563 } => Expr::If {
2564 condition: Box::new(self.replace_invariants_in_expr(condition, invariants, subs)),
2565 then_branch: self.replace_invariants_in_block(then_branch, invariants, subs),
2566 else_branch: else_branch
2567 .as_ref()
2568 .map(|e| Box::new(self.replace_invariants_in_expr(e, invariants, subs))),
2569 },
2570 Expr::While { condition, body } => Expr::While {
2571 condition: Box::new(self.replace_invariants_in_expr(condition, invariants, subs)),
2572 body: self.replace_invariants_in_block(body, invariants, subs),
2573 },
2574 Expr::Block(b) => Expr::Block(self.replace_invariants_in_block(b, invariants, subs)),
2575 Expr::Return(e) => Expr::Return(
2576 e.as_ref()
2577 .map(|e| Box::new(self.replace_invariants_in_expr(e, invariants, subs))),
2578 ),
2579 Expr::Assign { target, value } => Expr::Assign {
2580 target: target.clone(),
2581 value: Box::new(self.replace_invariants_in_expr(value, invariants, subs)),
2582 },
2583 Expr::Index { expr: e, index } => Expr::Index {
2584 expr: Box::new(self.replace_invariants_in_expr(e, invariants, subs)),
2585 index: Box::new(self.replace_invariants_in_expr(index, invariants, subs)),
2586 },
2587 other => other.clone(),
2588 }
2589 }
2590
2591 fn pass_cse_block(&mut self, block: &Block) -> Block {
2596 let mut collected = Vec::new();
2598 collect_exprs_from_block(block, &mut collected);
2599
2600 let mut expr_counts: HashMap<u64, Vec<Expr>> = HashMap::new();
2602 for ce in &collected {
2603 let entry = expr_counts.entry(ce.hash).or_insert_with(Vec::new);
2604 let found = entry.iter().any(|e| expr_eq(e, &ce.expr));
2606 if !found {
2607 entry.push(ce.expr.clone());
2608 }
2609 }
2610
2611 let mut occurrence_counts: Vec<(Expr, usize)> = Vec::new();
2613 for ce in &collected {
2614 let existing = occurrence_counts
2616 .iter_mut()
2617 .find(|(e, _)| expr_eq(e, &ce.expr));
2618 if let Some((_, count)) = existing {
2619 *count += 1;
2620 } else {
2621 occurrence_counts.push((ce.expr.clone(), 1));
2622 }
2623 }
2624
2625 let candidates: Vec<Expr> = occurrence_counts
2627 .into_iter()
2628 .filter(|(_, count)| *count >= 2)
2629 .map(|(expr, _)| expr)
2630 .collect();
2631
2632 if candidates.is_empty() {
2633 return self.pass_cse_nested(block);
2635 }
2636
2637 let mut result_block = block.clone();
2639 let mut new_lets: Vec<Stmt> = Vec::new();
2640
2641 for expr in candidates {
2642 let var_name = format!("__cse_{}", self.cse_counter);
2643 self.cse_counter += 1;
2644
2645 new_lets.push(make_cse_let(&var_name, expr.clone()));
2647
2648 result_block = replace_in_block(&result_block, &expr, &var_name);
2650
2651 self.stats.expressions_deduplicated += 1;
2652 }
2653
2654 let mut final_stmts = new_lets;
2656 final_stmts.extend(result_block.stmts);
2657
2658 let result = Block {
2660 stmts: final_stmts,
2661 expr: result_block.expr,
2662 };
2663 self.pass_cse_nested(&result)
2664 }
2665
2666 fn pass_cse_nested(&mut self, block: &Block) -> Block {
2668 let stmts = block
2669 .stmts
2670 .iter()
2671 .map(|stmt| self.pass_cse_stmt(stmt))
2672 .collect();
2673 let expr = block.expr.as_ref().map(|e| Box::new(self.pass_cse_expr(e)));
2674 Block { stmts, expr }
2675 }
2676
2677 fn pass_cse_stmt(&mut self, stmt: &Stmt) -> Stmt {
2678 match stmt {
2679 Stmt::Let { pattern, ty, init } => Stmt::Let {
2680 pattern: pattern.clone(),
2681 ty: ty.clone(),
2682 init: init.as_ref().map(|e| self.pass_cse_expr(e)),
2683 },
2684 Stmt::Expr(e) => Stmt::Expr(self.pass_cse_expr(e)),
2685 Stmt::Semi(e) => Stmt::Semi(self.pass_cse_expr(e)),
2686 Stmt::Item(item) => Stmt::Item(item.clone()),
2687 }
2688 }
2689
2690 fn pass_cse_expr(&mut self, expr: &Expr) -> Expr {
2691 match expr {
2692 Expr::If {
2693 condition,
2694 then_branch,
2695 else_branch,
2696 } => Expr::If {
2697 condition: Box::new(self.pass_cse_expr(condition)),
2698 then_branch: self.pass_cse_block(then_branch),
2699 else_branch: else_branch
2700 .as_ref()
2701 .map(|e| Box::new(self.pass_cse_expr(e))),
2702 },
2703 Expr::While { condition, body } => Expr::While {
2704 condition: Box::new(self.pass_cse_expr(condition)),
2705 body: self.pass_cse_block(body),
2706 },
2707 Expr::Block(b) => Expr::Block(self.pass_cse_block(b)),
2708 Expr::Binary { op, left, right } => Expr::Binary {
2709 op: *op,
2710 left: Box::new(self.pass_cse_expr(left)),
2711 right: Box::new(self.pass_cse_expr(right)),
2712 },
2713 Expr::Unary { op, expr: inner } => Expr::Unary {
2714 op: *op,
2715 expr: Box::new(self.pass_cse_expr(inner)),
2716 },
2717 Expr::Call { func, args } => Expr::Call {
2718 func: func.clone(),
2719 args: args.iter().map(|a| self.pass_cse_expr(a)).collect(),
2720 },
2721 Expr::Return(e) => Expr::Return(e.as_ref().map(|e| Box::new(self.pass_cse_expr(e)))),
2722 Expr::Assign { target, value } => Expr::Assign {
2723 target: target.clone(),
2724 value: Box::new(self.pass_cse_expr(value)),
2725 },
2726 other => other.clone(),
2727 }
2728 }
2729}
2730
2731fn expr_hash(expr: &Expr) -> u64 {
2737 use std::collections::hash_map::DefaultHasher;
2738 use std::hash::Hasher;
2739
2740 let mut hasher = DefaultHasher::new();
2741 expr_hash_recursive(expr, &mut hasher);
2742 hasher.finish()
2743}
2744
2745fn expr_hash_recursive<H: std::hash::Hasher>(expr: &Expr, hasher: &mut H) {
2746 use std::hash::Hash;
2747
2748 std::mem::discriminant(expr).hash(hasher);
2749
2750 match expr {
2751 Expr::Literal(lit) => match lit {
2752 Literal::Int { value, .. } => value.hash(hasher),
2753 Literal::Float { value, .. } => value.hash(hasher),
2754 Literal::String(s) => s.hash(hasher),
2755 Literal::Char(c) => c.hash(hasher),
2756 Literal::Bool(b) => b.hash(hasher),
2757 _ => {}
2758 },
2759 Expr::Path(path) => {
2760 for seg in &path.segments {
2761 seg.ident.name.hash(hasher);
2762 }
2763 }
2764 Expr::Binary { op, left, right } => {
2765 std::mem::discriminant(op).hash(hasher);
2766 expr_hash_recursive(left, hasher);
2767 expr_hash_recursive(right, hasher);
2768 }
2769 Expr::Unary { op, expr } => {
2770 std::mem::discriminant(op).hash(hasher);
2771 expr_hash_recursive(expr, hasher);
2772 }
2773 Expr::Call { func, args } => {
2774 expr_hash_recursive(func, hasher);
2775 args.len().hash(hasher);
2776 for arg in args {
2777 expr_hash_recursive(arg, hasher);
2778 }
2779 }
2780 Expr::Index { expr, index } => {
2781 expr_hash_recursive(expr, hasher);
2782 expr_hash_recursive(index, hasher);
2783 }
2784 _ => {}
2785 }
2786}
2787
2788fn is_pure_expr(expr: &Expr) -> bool {
2790 match expr {
2791 Expr::Literal(_) => true,
2792 Expr::Path(_) => true,
2793 Expr::Binary { left, right, .. } => is_pure_expr(left) && is_pure_expr(right),
2794 Expr::Unary { expr, .. } => is_pure_expr(expr),
2795 Expr::If {
2796 condition,
2797 then_branch,
2798 else_branch,
2799 } => {
2800 is_pure_expr(condition)
2801 && then_branch.stmts.is_empty()
2802 && then_branch
2803 .expr
2804 .as_ref()
2805 .map(|e| is_pure_expr(e))
2806 .unwrap_or(true)
2807 && else_branch
2808 .as_ref()
2809 .map(|e| is_pure_expr(e))
2810 .unwrap_or(true)
2811 }
2812 Expr::Index { expr, index } => is_pure_expr(expr) && is_pure_expr(index),
2813 Expr::Array(elements) => elements.iter().all(is_pure_expr),
2814 Expr::Call { .. } => false,
2816 Expr::Assign { .. } => false,
2817 Expr::Return(_) => false,
2818 _ => false,
2819 }
2820}
2821
2822fn is_cse_worthy(expr: &Expr) -> bool {
2824 match expr {
2825 Expr::Literal(_) => false,
2827 Expr::Path(_) => false,
2828 Expr::Binary { .. } => true,
2830 Expr::Unary { .. } => true,
2832 Expr::Call { .. } => false,
2834 Expr::Index { .. } => true,
2836 _ => false,
2837 }
2838}
2839
2840fn expr_eq(a: &Expr, b: &Expr) -> bool {
2842 match (a, b) {
2843 (Expr::Literal(la), Expr::Literal(lb)) => match (la, lb) {
2844 (Literal::Int { value: va, .. }, Literal::Int { value: vb, .. }) => va == vb,
2845 (Literal::Float { value: va, .. }, Literal::Float { value: vb, .. }) => va == vb,
2846 (Literal::String(sa), Literal::String(sb)) => sa == sb,
2847 (Literal::Char(ca), Literal::Char(cb)) => ca == cb,
2848 (Literal::Bool(ba), Literal::Bool(bb)) => ba == bb,
2849 _ => false,
2850 },
2851 (Expr::Path(pa), Expr::Path(pb)) => {
2852 pa.segments.len() == pb.segments.len()
2853 && pa
2854 .segments
2855 .iter()
2856 .zip(&pb.segments)
2857 .all(|(sa, sb)| sa.ident.name == sb.ident.name)
2858 }
2859 (
2860 Expr::Binary {
2861 op: oa,
2862 left: la,
2863 right: ra,
2864 },
2865 Expr::Binary {
2866 op: ob,
2867 left: lb,
2868 right: rb,
2869 },
2870 ) => oa == ob && expr_eq(la, lb) && expr_eq(ra, rb),
2871 (Expr::Unary { op: oa, expr: ea }, Expr::Unary { op: ob, expr: eb }) => {
2872 oa == ob && expr_eq(ea, eb)
2873 }
2874 (
2875 Expr::Index {
2876 expr: ea,
2877 index: ia,
2878 },
2879 Expr::Index {
2880 expr: eb,
2881 index: ib,
2882 },
2883 ) => expr_eq(ea, eb) && expr_eq(ia, ib),
2884 (Expr::Call { func: fa, args: aa }, Expr::Call { func: fb, args: ab }) => {
2885 expr_eq(fa, fb) && aa.len() == ab.len() && aa.iter().zip(ab).all(|(a, b)| expr_eq(a, b))
2886 }
2887 _ => false,
2888 }
2889}
2890
/// A CSE candidate recorded by `collect_exprs_from_expr`.
#[derive(Clone)]
struct CollectedExpr {
    /// A clone of the candidate expression.
    expr: Expr,
    /// `expr_hash` of `expr`, cached so candidates can be bucketed by hash
    /// before the more expensive `expr_eq` comparison.
    hash: u64,
}
2897
2898fn collect_exprs_from_expr(expr: &Expr, out: &mut Vec<CollectedExpr>) {
2900 match expr {
2902 Expr::Binary { left, right, .. } => {
2903 collect_exprs_from_expr(left, out);
2904 collect_exprs_from_expr(right, out);
2905 }
2906 Expr::Unary { expr: inner, .. } => {
2907 collect_exprs_from_expr(inner, out);
2908 }
2909 Expr::Index { expr: e, index } => {
2910 collect_exprs_from_expr(e, out);
2911 collect_exprs_from_expr(index, out);
2912 }
2913 Expr::Call { func, args } => {
2914 collect_exprs_from_expr(func, out);
2915 for arg in args {
2916 collect_exprs_from_expr(arg, out);
2917 }
2918 }
2919 Expr::If {
2920 condition,
2921 then_branch,
2922 else_branch,
2923 } => {
2924 collect_exprs_from_expr(condition, out);
2925 collect_exprs_from_block(then_branch, out);
2926 if let Some(else_expr) = else_branch {
2927 collect_exprs_from_expr(else_expr, out);
2928 }
2929 }
2930 Expr::While { condition, body } => {
2931 collect_exprs_from_expr(condition, out);
2932 collect_exprs_from_block(body, out);
2933 }
2934 Expr::Block(block) => {
2935 collect_exprs_from_block(block, out);
2936 }
2937 Expr::Return(Some(e)) => {
2938 collect_exprs_from_expr(e, out);
2939 }
2940 Expr::Assign { value, .. } => {
2941 collect_exprs_from_expr(value, out);
2942 }
2943 Expr::Array(elements) => {
2944 for e in elements {
2945 collect_exprs_from_expr(e, out);
2946 }
2947 }
2948 _ => {}
2949 }
2950
2951 if is_cse_worthy(expr) && is_pure_expr(expr) {
2953 out.push(CollectedExpr {
2954 expr: expr.clone(),
2955 hash: expr_hash(expr),
2956 });
2957 }
2958}
2959
2960fn collect_exprs_from_block(block: &Block, out: &mut Vec<CollectedExpr>) {
2962 for stmt in &block.stmts {
2963 match stmt {
2964 Stmt::Let { init: Some(e), .. } => collect_exprs_from_expr(e, out),
2965 Stmt::Expr(e) | Stmt::Semi(e) => collect_exprs_from_expr(e, out),
2966 _ => {}
2967 }
2968 }
2969 if let Some(e) = &block.expr {
2970 collect_exprs_from_expr(e, out);
2971 }
2972}
2973
2974fn replace_in_expr(expr: &Expr, target: &Expr, var_name: &str) -> Expr {
2976 if expr_eq(expr, target) {
2978 return Expr::Path(TypePath {
2979 segments: vec![PathSegment {
2980 ident: Ident {
2981 name: var_name.to_string(),
2982 evidentiality: None,
2983 affect: None,
2984 span: Span { start: 0, end: 0 },
2985 },
2986 generics: None,
2987 }],
2988 });
2989 }
2990
2991 match expr {
2993 Expr::Binary { op, left, right } => Expr::Binary {
2994 op: *op,
2995 left: Box::new(replace_in_expr(left, target, var_name)),
2996 right: Box::new(replace_in_expr(right, target, var_name)),
2997 },
2998 Expr::Unary { op, expr: inner } => Expr::Unary {
2999 op: *op,
3000 expr: Box::new(replace_in_expr(inner, target, var_name)),
3001 },
3002 Expr::Index { expr: e, index } => Expr::Index {
3003 expr: Box::new(replace_in_expr(e, target, var_name)),
3004 index: Box::new(replace_in_expr(index, target, var_name)),
3005 },
3006 Expr::Call { func, args } => Expr::Call {
3007 func: Box::new(replace_in_expr(func, target, var_name)),
3008 args: args
3009 .iter()
3010 .map(|a| replace_in_expr(a, target, var_name))
3011 .collect(),
3012 },
3013 Expr::If {
3014 condition,
3015 then_branch,
3016 else_branch,
3017 } => Expr::If {
3018 condition: Box::new(replace_in_expr(condition, target, var_name)),
3019 then_branch: replace_in_block(then_branch, target, var_name),
3020 else_branch: else_branch
3021 .as_ref()
3022 .map(|e| Box::new(replace_in_expr(e, target, var_name))),
3023 },
3024 Expr::While { condition, body } => Expr::While {
3025 condition: Box::new(replace_in_expr(condition, target, var_name)),
3026 body: replace_in_block(body, target, var_name),
3027 },
3028 Expr::Block(block) => Expr::Block(replace_in_block(block, target, var_name)),
3029 Expr::Return(e) => Expr::Return(
3030 e.as_ref()
3031 .map(|e| Box::new(replace_in_expr(e, target, var_name))),
3032 ),
3033 Expr::Assign { target: t, value } => Expr::Assign {
3034 target: t.clone(),
3035 value: Box::new(replace_in_expr(value, target, var_name)),
3036 },
3037 Expr::Array(elements) => Expr::Array(
3038 elements
3039 .iter()
3040 .map(|e| replace_in_expr(e, target, var_name))
3041 .collect(),
3042 ),
3043 other => other.clone(),
3044 }
3045}
3046
3047fn replace_in_block(block: &Block, target: &Expr, var_name: &str) -> Block {
3049 let stmts = block
3050 .stmts
3051 .iter()
3052 .map(|stmt| match stmt {
3053 Stmt::Let { pattern, ty, init } => Stmt::Let {
3054 pattern: pattern.clone(),
3055 ty: ty.clone(),
3056 init: init.as_ref().map(|e| replace_in_expr(e, target, var_name)),
3057 },
3058 Stmt::Expr(e) => Stmt::Expr(replace_in_expr(e, target, var_name)),
3059 Stmt::Semi(e) => Stmt::Semi(replace_in_expr(e, target, var_name)),
3060 Stmt::Item(item) => Stmt::Item(item.clone()),
3061 })
3062 .collect();
3063
3064 let expr = block
3065 .expr
3066 .as_ref()
3067 .map(|e| Box::new(replace_in_expr(e, target, var_name)));
3068
3069 Block { stmts, expr }
3070}
3071
3072fn make_cse_let(var_name: &str, expr: Expr) -> Stmt {
3074 Stmt::Let {
3075 pattern: Pattern::Ident {
3076 mutable: false,
3077 name: Ident {
3078 name: var_name.to_string(),
3079 evidentiality: None,
3080 affect: None,
3081 span: Span { start: 0, end: 0 },
3082 },
3083 evidentiality: None,
3084 },
3085 ty: None,
3086 init: Some(expr),
3087 }
3088}
3089
3090pub fn optimize(file: &ast::SourceFile, level: OptLevel) -> (ast::SourceFile, OptStats) {
3096 let mut optimizer = Optimizer::new(level);
3097 let optimized = optimizer.optimize_file(file);
3098 (optimized, optimizer.stats)
3099}
3100
#[cfg(test)]
mod tests {
    use super::*;

    /// Decimal integer literal with no suffix.
    fn int_lit(v: i64) -> Expr {
        Expr::Literal(Literal::Int {
            value: v.to_string(),
            base: NumBase::Decimal,
            suffix: None,
        })
    }

    /// Single-segment path expression naming `name`.
    fn var(name: &str) -> Expr {
        Expr::Path(TypePath {
            segments: vec![PathSegment {
                ident: Ident {
                    name: name.to_string(),
                    evidentiality: None,
                    affect: None,
                    span: Span { start: 0, end: 0 },
                },
                generics: None,
            }],
        })
    }

    /// `left + right`.
    fn add(left: Expr, right: Expr) -> Expr {
        Expr::Binary {
            op: BinOp::Add,
            left: Box::new(left),
            right: Box::new(right),
        }
    }

    /// `left * right`.
    fn mul(left: Expr, right: Expr) -> Expr {
        Expr::Binary {
            op: BinOp::Mul,
            left: Box::new(left),
            right: Box::new(right),
        }
    }

    /// Immutable `let` binding `name` to `init`, with no type annotation.
    fn let_stmt(name: &str, init: Expr) -> Stmt {
        Stmt::Let {
            pattern: Pattern::Ident {
                mutable: false,
                name: Ident {
                    name: name.to_string(),
                    evidentiality: None,
                    affect: None,
                    span: Span { start: 0, end: 0 },
                },
                evidentiality: None,
            },
            ty: None,
            init: Some(init),
        }
    }

    #[test]
    fn test_expr_hash_equal() {
        let e1 = add(var("a"), var("b"));
        let e2 = add(var("a"), var("b"));
        assert_eq!(expr_hash(&e1), expr_hash(&e2));
    }

    #[test]
    fn test_expr_hash_different() {
        let e1 = add(var("a"), var("b"));
        let e2 = add(var("a"), var("c"));
        assert_ne!(expr_hash(&e1), expr_hash(&e2));
    }

    #[test]
    fn test_expr_eq() {
        let e1 = add(var("a"), var("b"));
        let e2 = add(var("a"), var("b"));
        let e3 = add(var("a"), var("c"));

        assert!(expr_eq(&e1, &e2));
        assert!(!expr_eq(&e1, &e3));
    }

    #[test]
    fn test_expr_eq_distinguishes_operators() {
        assert!(!expr_eq(
            &add(var("a"), var("b")),
            &mul(var("a"), var("b"))
        ));
    }

    #[test]
    fn test_is_pure_expr() {
        assert!(is_pure_expr(&int_lit(42)));
        assert!(is_pure_expr(&var("x")));
        assert!(is_pure_expr(&add(var("a"), var("b"))));

        let call = Expr::Call {
            func: Box::new(var("print")),
            args: vec![int_lit(42)],
        };
        assert!(!is_pure_expr(&call));
    }

    #[test]
    fn test_is_cse_worthy() {
        assert!(!is_cse_worthy(&int_lit(42)));
        assert!(!is_cse_worthy(&var("x")));
        assert!(is_cse_worthy(&add(var("a"), var("b"))));
    }

    #[test]
    fn test_replace_in_expr_nested() {
        let target = add(var("a"), var("b"));
        let replaced = replace_in_expr(&mul(target.clone(), int_lit(2)), &target, "t");
        assert!(expr_eq(&replaced, &mul(var("t"), int_lit(2))));
    }

    #[test]
    fn test_cse_basic() {
        let a_plus_b = add(var("a"), var("b"));

        let block = Block {
            stmts: vec![
                let_stmt("x", a_plus_b.clone()),
                let_stmt("y", mul(a_plus_b, int_lit(2))),
            ],
            expr: None,
        };

        let mut optimizer = Optimizer::new(OptLevel::Standard);
        let result = optimizer.pass_cse_block(&block);

        assert_eq!(result.stmts.len(), 3);
        assert_eq!(optimizer.stats.expressions_deduplicated, 1);

        if let Stmt::Let {
            pattern: Pattern::Ident { name, .. },
            ..
        } = &result.stmts[0]
        {
            assert_eq!(name.name, "__cse_0");
        } else {
            panic!("Expected CSE let binding");
        }
    }

    #[test]
    fn test_cse_no_duplicates() {
        let block = Block {
            stmts: vec![
                let_stmt("x", add(var("a"), var("b"))),
                let_stmt("y", add(var("c"), var("d"))),
            ],
            expr: None,
        };

        let mut optimizer = Optimizer::new(OptLevel::Standard);
        let result = optimizer.pass_cse_block(&block);

        assert_eq!(result.stmts.len(), 2);
        assert_eq!(optimizer.stats.expressions_deduplicated, 0);
    }
}