1use crate::ast::{
7 self, BinOp, Block, Expr, FunctionAttrs, Ident, Item, Literal, NumBase, Param, PathSegment,
8 Pattern, Stmt, TypeExpr, TypePath, UnaryOp, Visibility,
9};
10use crate::span::Span;
11use std::collections::{HashMap, HashSet};
12
/// How aggressively [`Optimizer`] rewrites function bodies.
///
/// The concrete pass pipeline for each level is chosen in `Optimizer::optimize_function`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OptLevel {
    /// Function bodies are copied unchanged.
    None,
    /// Constant folding followed by dead-code elimination.
    Basic,
    /// The full single-sweep pass pipeline; also enables the accumulator (tail-recursion)
    /// transform in `Optimizer::optimize_file`.
    Standard,
    /// The full pipeline plus loop unrolling, repeated three times; also enables the
    /// accumulator transform.
    Aggressive,
    /// Currently runs the same pass pipeline as `Standard`, but without the accumulator
    /// transform.
    Size,
}
31
/// Counters describing what the optimizer changed, exposed through `Optimizer::stats`.
#[derive(Debug, Default, Clone)]
pub struct OptStats {
    /// Operator expressions over literals replaced by a single literal.
    pub constants_folded: usize,
    /// Statements / tail expressions dropped because they follow a `return`.
    pub dead_code_eliminated: usize,
    /// NOTE(review): updated outside this code; presumably counts CSE rewrites.
    pub expressions_deduplicated: usize,
    /// NOTE(review): updated outside this code; presumably counts inlined calls.
    pub functions_inlined: usize,
    /// Operations replaced by cheaper equivalents or removed via algebraic identities.
    pub strength_reductions: usize,
    /// `if` / `while` expressions with a constant condition that were collapsed.
    pub branches_simplified: usize,
    /// NOTE(review): updated outside this code; presumably counts loop transformations.
    pub loops_optimized: usize,
    /// Functions rewritten by the accumulator (tail-recursive) transform.
    pub tail_recursion_transforms: usize,
    /// NOTE(review): updated outside this code; presumably counts memoized functions.
    pub memoization_transforms: usize,
}
45
46pub struct Optimizer {
48 level: OptLevel,
49 stats: OptStats,
50 functions: HashMap<String, ast::Function>,
52 recursive_functions: HashSet<String>,
54 cse_counter: usize,
56}
57
58impl Optimizer {
59 pub fn new(level: OptLevel) -> Self {
60 Self {
61 level,
62 stats: OptStats::default(),
63 functions: HashMap::new(),
64 recursive_functions: HashSet::new(),
65 cse_counter: 0,
66 }
67 }
68
69 pub fn stats(&self) -> &OptStats {
71 &self.stats
72 }
73
74 pub fn optimize_file(&mut self, file: &ast::SourceFile) -> ast::SourceFile {
76 for item in &file.items {
78 if let Item::Function(func) = &item.node {
79 self.functions.insert(func.name.name.clone(), func.clone());
80 if self.is_recursive(&func.name.name, func) {
81 self.recursive_functions.insert(func.name.name.clone());
82 }
83 }
84 }
85
86 let mut new_items: Vec<crate::span::Spanned<Item>> = Vec::new();
90 let mut transformed_functions: HashMap<String, String> = HashMap::new();
91
92 if matches!(self.level, OptLevel::Standard | OptLevel::Aggressive) {
93 for item in &file.items {
94 if let Item::Function(func) = &item.node {
95 if let Some((helper_func, wrapper_func)) = self.try_accumulator_transform(func)
96 {
97 new_items.push(crate::span::Spanned {
99 node: Item::Function(helper_func),
100 span: item.span.clone(),
101 });
102 transformed_functions
103 .insert(func.name.name.clone(), wrapper_func.name.name.clone());
104 self.stats.tail_recursion_transforms += 1;
105 }
106 }
107 }
108
109 }
115
116 let items: Vec<_> = file
118 .items
119 .iter()
120 .map(|item| {
121 let node = match &item.node {
122 Item::Function(func) => {
123 if let Some((_, wrapper)) = self.try_accumulator_transform(func) {
125 if matches!(self.level, OptLevel::Standard | OptLevel::Aggressive)
126 && transformed_functions.contains_key(&func.name.name)
127 {
128 Item::Function(self.optimize_function(&wrapper))
129 } else {
130 Item::Function(self.optimize_function(func))
131 }
132 } else {
133 Item::Function(self.optimize_function(func))
134 }
135 }
136 other => other.clone(),
137 };
138 crate::span::Spanned {
139 node,
140 span: item.span.clone(),
141 }
142 })
143 .collect();
144
145 new_items.extend(items);
147
148 ast::SourceFile {
149 attrs: file.attrs.clone(),
150 config: file.config.clone(),
151 items: new_items,
152 }
153 }
154
155 fn try_accumulator_transform(
158 &self,
159 func: &ast::Function,
160 ) -> Option<(ast::Function, ast::Function)> {
161 if func.params.len() != 1 {
163 return None;
164 }
165
166 if !self.recursive_functions.contains(&func.name.name) {
168 return None;
169 }
170
171 let body = func.body.as_ref()?;
172
173 if !self.is_fib_like_pattern(&func.name.name, body) {
177 return None;
178 }
179
180 let param_name = if let Pattern::Ident { name, .. } = &func.params[0].pattern {
182 name.name.clone()
183 } else {
184 return None;
185 };
186
187 let helper_name = format!("{}_tail", func.name.name);
189
190 let helper_func = self.generate_fib_helper(&helper_name, ¶m_name);
192
193 let wrapper_func =
195 self.generate_fib_wrapper(&func.name.name, &helper_name, ¶m_name, func);
196
197 Some((helper_func, wrapper_func))
198 }
199
200 fn is_fib_like_pattern(&self, func_name: &str, body: &Block) -> bool {
202 if body.stmts.is_empty() && body.expr.is_none() {
209 return false;
210 }
211
212 if let Some(expr) = &body.expr {
214 if let Expr::If {
215 else_branch: Some(else_expr),
216 ..
217 } = expr.as_ref()
218 {
219 return self.is_double_recursive_expr(func_name, else_expr);
222 }
223 }
224
225 if body.stmts.len() >= 1 {
227 if let Some(Stmt::Expr(expr) | Stmt::Semi(expr)) = body.stmts.last() {
229 if let Expr::Return(Some(ret_expr)) = expr {
230 return self.is_double_recursive_expr(func_name, ret_expr);
231 }
232 }
233 if let Some(expr) = &body.expr {
234 return self.is_double_recursive_expr(func_name, expr);
235 }
236 }
237
238 false
239 }
240
241 fn is_double_recursive_expr(&self, func_name: &str, expr: &Expr) -> bool {
243 if let Expr::Binary {
244 op: BinOp::Add,
245 left,
246 right,
247 } = expr
248 {
249 let left_is_recursive = self.is_recursive_call_with_decrement(func_name, left);
250 let right_is_recursive = self.is_recursive_call_with_decrement(func_name, right);
251 return left_is_recursive && right_is_recursive;
252 }
253 false
254 }
255
256 fn is_recursive_call_with_decrement(&self, func_name: &str, expr: &Expr) -> bool {
258 if let Expr::Call { func, args } = expr {
259 if let Expr::Path(path) = func.as_ref() {
260 if path.segments.last().map(|s| s.ident.name.as_str()) == Some(func_name) {
261 if args.len() == 1 {
263 if let Expr::Binary { op: BinOp::Sub, .. } = &args[0] {
264 return true;
265 }
266 }
267 }
268 }
269 }
270 false
271 }
272
273 fn generate_fib_helper(&self, name: &str, _param_name: &str) -> ast::Function {
275 let span = Span { start: 0, end: 0 };
276
277 let n_ident = Ident {
282 name: "n".to_string(),
283 evidentiality: None,
284 affect: None,
285 span: span.clone(),
286 };
287 let a_ident = Ident {
288 name: "a".to_string(),
289 evidentiality: None,
290 affect: None,
291 span: span.clone(),
292 };
293 let b_ident = Ident {
294 name: "b".to_string(),
295 evidentiality: None,
296 affect: None,
297 span: span.clone(),
298 };
299
300 let params = vec![
301 Param {
302 pattern: Pattern::Ident {
303 mutable: false,
304 name: n_ident.clone(),
305 evidentiality: None,
306 },
307 ty: TypeExpr::Infer,
308 },
309 Param {
310 pattern: Pattern::Ident {
311 mutable: false,
312 name: a_ident.clone(),
313 evidentiality: None,
314 },
315 ty: TypeExpr::Infer,
316 },
317 Param {
318 pattern: Pattern::Ident {
319 mutable: false,
320 name: b_ident.clone(),
321 evidentiality: None,
322 },
323 ty: TypeExpr::Infer,
324 },
325 ];
326
327 let condition = Expr::Binary {
329 op: BinOp::Le,
330 left: Box::new(Expr::Path(TypePath {
331 segments: vec![PathSegment {
332 ident: n_ident.clone(),
333 generics: None,
334 }],
335 })),
336 right: Box::new(Expr::Literal(Literal::Int {
337 value: "0".to_string(),
338 base: NumBase::Decimal,
339 suffix: None,
340 })),
341 };
342
343 let then_branch = Block {
345 stmts: vec![],
346 expr: Some(Box::new(Expr::Return(Some(Box::new(Expr::Path(
347 TypePath {
348 segments: vec![PathSegment {
349 ident: a_ident.clone(),
350 generics: None,
351 }],
352 },
353 )))))),
354 };
355
356 let recursive_call = Expr::Call {
358 func: Box::new(Expr::Path(TypePath {
359 segments: vec![PathSegment {
360 ident: Ident {
361 name: name.to_string(),
362 evidentiality: None,
363 affect: None,
364 span: span.clone(),
365 },
366 generics: None,
367 }],
368 })),
369 args: vec![
370 Expr::Binary {
372 op: BinOp::Sub,
373 left: Box::new(Expr::Path(TypePath {
374 segments: vec![PathSegment {
375 ident: n_ident.clone(),
376 generics: None,
377 }],
378 })),
379 right: Box::new(Expr::Literal(Literal::Int {
380 value: "1".to_string(),
381 base: NumBase::Decimal,
382 suffix: None,
383 })),
384 },
385 Expr::Path(TypePath {
387 segments: vec![PathSegment {
388 ident: b_ident.clone(),
389 generics: None,
390 }],
391 }),
392 Expr::Binary {
394 op: BinOp::Add,
395 left: Box::new(Expr::Path(TypePath {
396 segments: vec![PathSegment {
397 ident: a_ident.clone(),
398 generics: None,
399 }],
400 })),
401 right: Box::new(Expr::Path(TypePath {
402 segments: vec![PathSegment {
403 ident: b_ident.clone(),
404 generics: None,
405 }],
406 })),
407 },
408 ],
409 };
410
411 let body = Block {
413 stmts: vec![],
414 expr: Some(Box::new(Expr::If {
415 condition: Box::new(condition),
416 then_branch,
417 else_branch: Some(Box::new(Expr::Return(Some(Box::new(recursive_call))))),
418 })),
419 };
420
421 ast::Function {
422 doc_comments: vec![],
423 visibility: Visibility::default(),
424 is_async: false,
425 is_const: false,
426 is_unsafe: false,
427 attrs: FunctionAttrs::default(),
428 name: Ident {
429 name: name.to_string(),
430 evidentiality: None,
431 affect: None,
432 span: span.clone(),
433 },
434 aspect: None,
435 generics: None,
436 params,
437 return_type: None,
438 where_clause: None,
439 body: Some(body),
440 }
441 }
442
443 fn generate_fib_wrapper(
445 &self,
446 name: &str,
447 helper_name: &str,
448 param_name: &str,
449 original: &ast::Function,
450 ) -> ast::Function {
451 let span = Span { start: 0, end: 0 };
452
453 let call_helper = Expr::Call {
455 func: Box::new(Expr::Path(TypePath {
456 segments: vec![PathSegment {
457 ident: Ident {
458 name: helper_name.to_string(),
459 evidentiality: None,
460 affect: None,
461 span: span.clone(),
462 },
463 generics: None,
464 }],
465 })),
466 args: vec![
467 Expr::Path(TypePath {
469 segments: vec![PathSegment {
470 ident: Ident {
471 name: param_name.to_string(),
472 evidentiality: None,
473 affect: None,
474 span: span.clone(),
475 },
476 generics: None,
477 }],
478 }),
479 Expr::Literal(Literal::Int {
481 value: "0".to_string(),
482 base: NumBase::Decimal,
483 suffix: None,
484 }),
485 Expr::Literal(Literal::Int {
487 value: "1".to_string(),
488 base: NumBase::Decimal,
489 suffix: None,
490 }),
491 ],
492 };
493
494 let body = Block {
495 stmts: vec![],
496 expr: Some(Box::new(Expr::Return(Some(Box::new(call_helper))))),
497 };
498
499 ast::Function {
500 doc_comments: original.doc_comments.clone(),
501 visibility: original.visibility,
502 is_async: original.is_async,
503 is_const: original.is_const,
504 is_unsafe: original.is_unsafe,
505 attrs: original.attrs.clone(),
506 name: Ident {
507 name: name.to_string(),
508 evidentiality: None,
509 affect: None,
510 span: span.clone(),
511 },
512 aspect: original.aspect,
513 generics: original.generics.clone(),
514 params: original.params.clone(),
515 return_type: original.return_type.clone(),
516 where_clause: original.where_clause.clone(),
517 body: Some(body),
518 }
519 }
520
521 #[allow(dead_code)]
528 fn try_memoize_transform(
529 &self,
530 func: &ast::Function,
531 ) -> Option<(ast::Function, ast::Function, ast::Function)> {
532 let param_count = func.params.len();
533 if param_count != 1 && param_count != 2 {
534 return None;
535 }
536
537 let span = Span { start: 0, end: 0 };
538 let func_name = &func.name.name;
539 let impl_name = format!("_memo_impl_{}", func_name);
540 let _cache_name = format!("_memo_cache_{}", func_name);
541 let init_name = format!("_memo_init_{}", func_name);
542
543 let param_names: Vec<String> = func
545 .params
546 .iter()
547 .filter_map(|p| {
548 if let Pattern::Ident { name, .. } = &p.pattern {
549 Some(name.name.clone())
550 } else {
551 None
552 }
553 })
554 .collect();
555
556 if param_names.len() != param_count {
557 return None;
558 }
559
560 let impl_func = ast::Function {
562 doc_comments: vec![],
563 visibility: Visibility::default(),
564 is_async: func.is_async,
565 is_const: func.is_const,
566 is_unsafe: func.is_unsafe,
567 attrs: func.attrs.clone(),
568 name: Ident {
569 name: impl_name.clone(),
570 evidentiality: None,
571 affect: None,
572 span: span.clone(),
573 },
574 aspect: func.aspect,
575 generics: func.generics.clone(),
576 params: func.params.clone(),
577 return_type: func.return_type.clone(),
578 where_clause: func.where_clause.clone(),
579 body: func
580 .body
581 .as_ref()
582 .map(|b| self.redirect_calls_in_block(func_name, func_name, b)),
583 };
584
585 let cache_init_body = Block {
588 stmts: vec![],
589 expr: Some(Box::new(Expr::Call {
590 func: Box::new(Expr::Path(TypePath {
591 segments: vec![PathSegment {
592 ident: Ident {
593 name: "sigil_memo_new".to_string(),
594 evidentiality: None,
595 affect: None,
596 span: span.clone(),
597 },
598 generics: None,
599 }],
600 })),
601 args: vec![Expr::Literal(Literal::Int {
602 value: "65536".to_string(),
603 base: NumBase::Decimal,
604 suffix: None,
605 })],
606 })),
607 };
608
609 let cache_init_func = ast::Function {
610 doc_comments: vec![],
611 visibility: Visibility::default(),
612 is_async: false,
613 is_const: false,
614 is_unsafe: false,
615 attrs: FunctionAttrs::default(),
616 name: Ident {
617 name: init_name.clone(),
618 evidentiality: None,
619 affect: None,
620 span: span.clone(),
621 },
622 aspect: None,
623 generics: None,
624 params: vec![],
625 return_type: None,
626 where_clause: None,
627 body: Some(cache_init_body),
628 };
629
630 let wrapper_func = self.generate_memo_wrapper(func, &impl_name, ¶m_names);
632
633 Some((impl_func, cache_init_func, wrapper_func))
634 }
635
636 #[allow(dead_code)]
638 fn generate_memo_wrapper(
639 &self,
640 original: &ast::Function,
641 impl_name: &str,
642 param_names: &[String],
643 ) -> ast::Function {
644 let span = Span { start: 0, end: 0 };
645 let param_count = param_names.len();
646
647 let cache_var = Ident {
649 name: "__cache".to_string(),
650 evidentiality: None,
651 affect: None,
652 span: span.clone(),
653 };
654 let result_var = Ident {
655 name: "__result".to_string(),
656 evidentiality: None,
657 affect: None,
658 span: span.clone(),
659 };
660 let cached_var = Ident {
661 name: "__cached".to_string(),
662 evidentiality: None,
663 affect: None,
664 span: span.clone(),
665 };
666
667 let mut stmts = vec![];
668
669 stmts.push(Stmt::Let {
671 pattern: Pattern::Ident {
672 mutable: false,
673 name: cache_var.clone(),
674 evidentiality: None,
675 },
676 ty: None,
677 init: Some(Expr::Call {
678 func: Box::new(Expr::Path(TypePath {
679 segments: vec![PathSegment {
680 ident: Ident {
681 name: "sigil_memo_new".to_string(),
682 evidentiality: None,
683 affect: None,
684 span: span.clone(),
685 },
686 generics: None,
687 }],
688 })),
689 args: vec![Expr::Literal(Literal::Int {
690 value: "65536".to_string(),
691 base: NumBase::Decimal,
692 suffix: None,
693 })],
694 }),
695 });
696
697 let get_fn_name = if param_count == 1 {
699 "sigil_memo_get_1"
700 } else {
701 "sigil_memo_get_2"
702 };
703 let mut get_args = vec![Expr::Path(TypePath {
704 segments: vec![PathSegment {
705 ident: cache_var.clone(),
706 generics: None,
707 }],
708 })];
709 for name in param_names {
710 get_args.push(Expr::Path(TypePath {
711 segments: vec![PathSegment {
712 ident: Ident {
713 name: name.clone(),
714 evidentiality: None,
715 affect: None,
716 span: span.clone(),
717 },
718 generics: None,
719 }],
720 }));
721 }
722
723 stmts.push(Stmt::Let {
724 pattern: Pattern::Ident {
725 mutable: false,
726 name: cached_var.clone(),
727 evidentiality: None,
728 },
729 ty: None,
730 init: Some(Expr::Call {
731 func: Box::new(Expr::Path(TypePath {
732 segments: vec![PathSegment {
733 ident: Ident {
734 name: get_fn_name.to_string(),
735 evidentiality: None,
736 affect: None,
737 span: span.clone(),
738 },
739 generics: None,
740 }],
741 })),
742 args: get_args,
743 }),
744 });
745
746 let cache_check = Expr::If {
749 condition: Box::new(Expr::Binary {
750 op: BinOp::Ne,
751 left: Box::new(Expr::Path(TypePath {
752 segments: vec![PathSegment {
753 ident: cached_var.clone(),
754 generics: None,
755 }],
756 })),
757 right: Box::new(Expr::Unary {
758 op: UnaryOp::Neg,
759 expr: Box::new(Expr::Literal(Literal::Int {
760 value: "9223372036854775807".to_string(),
761 base: NumBase::Decimal,
762 suffix: None,
763 })),
764 }),
765 }),
766 then_branch: Block {
767 stmts: vec![],
768 expr: Some(Box::new(Expr::Return(Some(Box::new(Expr::Path(
769 TypePath {
770 segments: vec![PathSegment {
771 ident: cached_var.clone(),
772 generics: None,
773 }],
774 },
775 )))))),
776 },
777 else_branch: None,
778 };
779 stmts.push(Stmt::Semi(cache_check));
780
781 let mut impl_args = vec![];
783 for name in param_names {
784 impl_args.push(Expr::Path(TypePath {
785 segments: vec![PathSegment {
786 ident: Ident {
787 name: name.clone(),
788 evidentiality: None,
789 affect: None,
790 span: span.clone(),
791 },
792 generics: None,
793 }],
794 }));
795 }
796
797 stmts.push(Stmt::Let {
798 pattern: Pattern::Ident {
799 mutable: false,
800 name: result_var.clone(),
801 evidentiality: None,
802 },
803 ty: None,
804 init: Some(Expr::Call {
805 func: Box::new(Expr::Path(TypePath {
806 segments: vec![PathSegment {
807 ident: Ident {
808 name: impl_name.to_string(),
809 evidentiality: None,
810 affect: None,
811 span: span.clone(),
812 },
813 generics: None,
814 }],
815 })),
816 args: impl_args,
817 }),
818 });
819
820 let set_fn_name = if param_count == 1 {
822 "sigil_memo_set_1"
823 } else {
824 "sigil_memo_set_2"
825 };
826 let mut set_args = vec![Expr::Path(TypePath {
827 segments: vec![PathSegment {
828 ident: cache_var.clone(),
829 generics: None,
830 }],
831 })];
832 for name in param_names {
833 set_args.push(Expr::Path(TypePath {
834 segments: vec![PathSegment {
835 ident: Ident {
836 name: name.clone(),
837 evidentiality: None,
838 affect: None,
839 span: span.clone(),
840 },
841 generics: None,
842 }],
843 }));
844 }
845 set_args.push(Expr::Path(TypePath {
846 segments: vec![PathSegment {
847 ident: result_var.clone(),
848 generics: None,
849 }],
850 }));
851
852 stmts.push(Stmt::Semi(Expr::Call {
853 func: Box::new(Expr::Path(TypePath {
854 segments: vec![PathSegment {
855 ident: Ident {
856 name: set_fn_name.to_string(),
857 evidentiality: None,
858 affect: None,
859 span: span.clone(),
860 },
861 generics: None,
862 }],
863 })),
864 args: set_args,
865 }));
866
867 let body = Block {
869 stmts,
870 expr: Some(Box::new(Expr::Return(Some(Box::new(Expr::Path(
871 TypePath {
872 segments: vec![PathSegment {
873 ident: result_var.clone(),
874 generics: None,
875 }],
876 },
877 )))))),
878 };
879
880 ast::Function {
881 doc_comments: original.doc_comments.clone(),
882 visibility: original.visibility,
883 is_async: original.is_async,
884 is_const: original.is_const,
885 is_unsafe: original.is_unsafe,
886 attrs: original.attrs.clone(),
887 name: original.name.clone(),
888 aspect: original.aspect,
889 generics: original.generics.clone(),
890 params: original.params.clone(),
891 return_type: original.return_type.clone(),
892 where_clause: original.where_clause.clone(),
893 body: Some(body),
894 }
895 }
896
897 #[allow(dead_code)]
899 fn redirect_calls_in_block(&self, _old_name: &str, _new_name: &str, block: &Block) -> Block {
900 block.clone()
902 }
903
904 fn is_recursive(&self, name: &str, func: &ast::Function) -> bool {
906 if let Some(body) = &func.body {
907 self.block_calls_function(name, body)
908 } else {
909 false
910 }
911 }
912
913 fn block_calls_function(&self, name: &str, block: &Block) -> bool {
914 for stmt in &block.stmts {
915 if self.stmt_calls_function(name, stmt) {
916 return true;
917 }
918 }
919 if let Some(expr) = &block.expr {
920 if self.expr_calls_function(name, expr) {
921 return true;
922 }
923 }
924 false
925 }
926
927 fn stmt_calls_function(&self, name: &str, stmt: &Stmt) -> bool {
928 match stmt {
929 Stmt::Let {
930 init: Some(expr), ..
931 } => self.expr_calls_function(name, expr),
932 Stmt::Expr(expr) | Stmt::Semi(expr) => self.expr_calls_function(name, expr),
933 _ => false,
934 }
935 }
936
937 fn expr_calls_function(&self, name: &str, expr: &Expr) -> bool {
938 match expr {
939 Expr::Call { func, args } => {
940 if let Expr::Path(path) = func.as_ref() {
941 if path.segments.last().map(|s| s.ident.name.as_str()) == Some(name) {
942 return true;
943 }
944 }
945 args.iter().any(|a| self.expr_calls_function(name, a))
946 }
947 Expr::Binary { left, right, .. } => {
948 self.expr_calls_function(name, left) || self.expr_calls_function(name, right)
949 }
950 Expr::Unary { expr, .. } => self.expr_calls_function(name, expr),
951 Expr::If {
952 condition,
953 then_branch,
954 else_branch,
955 } => {
956 self.expr_calls_function(name, condition)
957 || self.block_calls_function(name, then_branch)
958 || else_branch
959 .as_ref()
960 .map(|e| self.expr_calls_function(name, e))
961 .unwrap_or(false)
962 }
963 Expr::While { label, condition, body } => {
964 self.expr_calls_function(name, condition) || self.block_calls_function(name, body)
965 }
966 Expr::Block(block) => self.block_calls_function(name, block),
967 Expr::Return(Some(e)) => self.expr_calls_function(name, e),
968 _ => false,
969 }
970 }
971
972 fn optimize_function(&mut self, func: &ast::Function) -> ast::Function {
974 self.cse_counter = 0;
976
977 let body = if let Some(body) = &func.body {
978 let optimized = match self.level {
980 OptLevel::None => body.clone(),
981 OptLevel::Basic => {
982 let b = self.pass_constant_fold_block(body);
983 self.pass_dead_code_block(&b)
984 }
985 OptLevel::Standard | OptLevel::Size => {
986 let b = self.pass_constant_fold_block(body);
987 let b = self.pass_inline_block(&b); let b = self.pass_strength_reduce_block(&b);
989 let b = self.pass_licm_block(&b); let b = self.pass_cse_block(&b); let b = self.pass_dead_code_block(&b);
992 self.pass_simplify_branches_block(&b)
993 }
994 OptLevel::Aggressive => {
995 let mut b = body.clone();
997 for _ in 0..3 {
998 b = self.pass_constant_fold_block(&b);
999 b = self.pass_inline_block(&b); b = self.pass_strength_reduce_block(&b);
1001 b = self.pass_loop_unroll_block(&b); b = self.pass_licm_block(&b); b = self.pass_cse_block(&b); b = self.pass_dead_code_block(&b);
1005 b = self.pass_simplify_branches_block(&b);
1006 }
1007 b
1008 }
1009 };
1010 Some(optimized)
1011 } else {
1012 None
1013 };
1014
1015 ast::Function {
1016 doc_comments: func.doc_comments.clone(),
1017 visibility: func.visibility.clone(),
1018 is_async: func.is_async,
1019 is_const: func.is_const,
1020 is_unsafe: func.is_unsafe,
1021 attrs: func.attrs.clone(),
1022 name: func.name.clone(),
1023 aspect: func.aspect,
1024 generics: func.generics.clone(),
1025 params: func.params.clone(),
1026 return_type: func.return_type.clone(),
1027 where_clause: func.where_clause.clone(),
1028 body,
1029 }
1030 }
1031
1032 fn pass_constant_fold_block(&mut self, block: &Block) -> Block {
1037 let stmts = block
1038 .stmts
1039 .iter()
1040 .map(|s| self.pass_constant_fold_stmt(s))
1041 .collect();
1042 let expr = block
1043 .expr
1044 .as_ref()
1045 .map(|e| Box::new(self.pass_constant_fold_expr(e)));
1046 Block { stmts, expr }
1047 }
1048
1049 fn pass_constant_fold_stmt(&mut self, stmt: &Stmt) -> Stmt {
1050 match stmt {
1051 Stmt::Let {
1052 pattern, ty, init, ..
1053 } => Stmt::Let {
1054 pattern: pattern.clone(),
1055 ty: ty.clone(),
1056 init: init.as_ref().map(|e| self.pass_constant_fold_expr(e)),
1057 },
1058 Stmt::LetElse { pattern, ty, init, else_branch } => Stmt::LetElse {
1059 pattern: pattern.clone(),
1060 ty: ty.clone(),
1061 init: self.pass_constant_fold_expr(init),
1062 else_branch: Box::new(self.pass_constant_fold_expr(else_branch)),
1063 },
1064 Stmt::Expr(expr) => Stmt::Expr(self.pass_constant_fold_expr(expr)),
1065 Stmt::Semi(expr) => Stmt::Semi(self.pass_constant_fold_expr(expr)),
1066 Stmt::Item(item) => Stmt::Item(item.clone()),
1067 }
1068 }
1069
1070 fn pass_constant_fold_expr(&mut self, expr: &Expr) -> Expr {
1071 match expr {
1072 Expr::Binary { op, left, right } => {
1073 let left = Box::new(self.pass_constant_fold_expr(left));
1074 let right = Box::new(self.pass_constant_fold_expr(right));
1075
1076 if let (Some(l), Some(r)) = (self.as_int(&left), self.as_int(&right)) {
1078 if let Some(result) = self.fold_binary(op.clone(), l, r) {
1079 self.stats.constants_folded += 1;
1080 return Expr::Literal(Literal::Int {
1081 value: result.to_string(),
1082 base: NumBase::Decimal,
1083 suffix: None,
1084 });
1085 }
1086 }
1087
1088 Expr::Binary {
1089 op: op.clone(),
1090 left,
1091 right,
1092 }
1093 }
1094 Expr::Unary { op, expr: inner } => {
1095 let inner = Box::new(self.pass_constant_fold_expr(inner));
1096
1097 if let Some(v) = self.as_int(&inner) {
1098 if let Some(result) = self.fold_unary(*op, v) {
1099 self.stats.constants_folded += 1;
1100 return Expr::Literal(Literal::Int {
1101 value: result.to_string(),
1102 base: NumBase::Decimal,
1103 suffix: None,
1104 });
1105 }
1106 }
1107
1108 Expr::Unary {
1109 op: *op,
1110 expr: inner,
1111 }
1112 }
1113 Expr::If {
1114 condition,
1115 then_branch,
1116 else_branch,
1117 } => {
1118 let condition = Box::new(self.pass_constant_fold_expr(condition));
1119 let then_branch = self.pass_constant_fold_block(then_branch);
1120 let else_branch = else_branch
1121 .as_ref()
1122 .map(|e| Box::new(self.pass_constant_fold_expr(e)));
1123
1124 if let Some(cond) = self.as_bool(&condition) {
1126 self.stats.branches_simplified += 1;
1127 if cond {
1128 return Expr::Block(then_branch);
1129 } else if let Some(else_expr) = else_branch {
1130 return *else_expr;
1131 } else {
1132 return Expr::Literal(Literal::Bool(false));
1133 }
1134 }
1135
1136 Expr::If {
1137 condition,
1138 then_branch,
1139 else_branch,
1140 }
1141 }
1142 Expr::While { label, condition, body } => {
1143 let condition = Box::new(self.pass_constant_fold_expr(condition));
1144 let body = self.pass_constant_fold_block(body);
1145
1146 if let Some(false) = self.as_bool(&condition) {
1148 self.stats.branches_simplified += 1;
1149 return Expr::Block(Block {
1150 stmts: vec![],
1151 expr: None,
1152 });
1153 }
1154
1155 Expr::While { label: label.clone(), condition, body }
1156 }
1157 Expr::Block(block) => Expr::Block(self.pass_constant_fold_block(block)),
1158 Expr::Call { func, args } => {
1159 let args = args
1160 .iter()
1161 .map(|a| self.pass_constant_fold_expr(a))
1162 .collect();
1163 Expr::Call {
1164 func: func.clone(),
1165 args,
1166 }
1167 }
1168 Expr::Return(e) => Expr::Return(
1169 e.as_ref()
1170 .map(|e| Box::new(self.pass_constant_fold_expr(e))),
1171 ),
1172 Expr::Assign { target, value } => {
1173 let value = Box::new(self.pass_constant_fold_expr(value));
1174 Expr::Assign {
1175 target: target.clone(),
1176 value,
1177 }
1178 }
1179 Expr::Index { expr: e, index } => {
1180 let e = Box::new(self.pass_constant_fold_expr(e));
1181 let index = Box::new(self.pass_constant_fold_expr(index));
1182 Expr::Index { expr: e, index }
1183 }
1184 Expr::Array(elements) => {
1185 let elements = elements
1186 .iter()
1187 .map(|e| self.pass_constant_fold_expr(e))
1188 .collect();
1189 Expr::Array(elements)
1190 }
1191 other => other.clone(),
1192 }
1193 }
1194
1195 fn as_int(&self, expr: &Expr) -> Option<i64> {
1196 match expr {
1197 Expr::Literal(Literal::Int { value, .. }) => value.parse().ok(),
1198 Expr::Literal(Literal::Bool(b)) => Some(if *b { 1 } else { 0 }),
1199 _ => None,
1200 }
1201 }
1202
1203 fn as_bool(&self, expr: &Expr) -> Option<bool> {
1204 match expr {
1205 Expr::Literal(Literal::Bool(b)) => Some(*b),
1206 Expr::Literal(Literal::Int { value, .. }) => value.parse::<i64>().ok().map(|v| v != 0),
1207 _ => None,
1208 }
1209 }
1210
1211 fn fold_binary(&self, op: BinOp, l: i64, r: i64) -> Option<i64> {
1212 match op {
1213 BinOp::Add => Some(l.wrapping_add(r)),
1214 BinOp::Sub => Some(l.wrapping_sub(r)),
1215 BinOp::Mul => Some(l.wrapping_mul(r)),
1216 BinOp::Div if r != 0 => Some(l / r),
1217 BinOp::Rem if r != 0 => Some(l % r),
1218 BinOp::BitAnd => Some(l & r),
1219 BinOp::BitOr => Some(l | r),
1220 BinOp::BitXor => Some(l ^ r),
1221 BinOp::Shl => Some(l << (r & 63)),
1222 BinOp::Shr => Some(l >> (r & 63)),
1223 BinOp::Eq => Some(if l == r { 1 } else { 0 }),
1224 BinOp::Ne => Some(if l != r { 1 } else { 0 }),
1225 BinOp::Lt => Some(if l < r { 1 } else { 0 }),
1226 BinOp::Le => Some(if l <= r { 1 } else { 0 }),
1227 BinOp::Gt => Some(if l > r { 1 } else { 0 }),
1228 BinOp::Ge => Some(if l >= r { 1 } else { 0 }),
1229 BinOp::And => Some(if l != 0 && r != 0 { 1 } else { 0 }),
1230 BinOp::Or => Some(if l != 0 || r != 0 { 1 } else { 0 }),
1231 _ => None,
1232 }
1233 }
1234
1235 fn fold_unary(&self, op: UnaryOp, v: i64) -> Option<i64> {
1236 match op {
1237 UnaryOp::Neg => Some(-v),
1238 UnaryOp::Not => Some(if v == 0 { 1 } else { 0 }),
1239 _ => None,
1240 }
1241 }
1242
1243 fn pass_strength_reduce_block(&mut self, block: &Block) -> Block {
1248 let stmts = block
1249 .stmts
1250 .iter()
1251 .map(|s| self.pass_strength_reduce_stmt(s))
1252 .collect();
1253 let expr = block
1254 .expr
1255 .as_ref()
1256 .map(|e| Box::new(self.pass_strength_reduce_expr(e)));
1257 Block { stmts, expr }
1258 }
1259
1260 fn pass_strength_reduce_stmt(&mut self, stmt: &Stmt) -> Stmt {
1261 match stmt {
1262 Stmt::Let {
1263 pattern, ty, init, ..
1264 } => Stmt::Let {
1265 pattern: pattern.clone(),
1266 ty: ty.clone(),
1267 init: init.as_ref().map(|e| self.pass_strength_reduce_expr(e)),
1268 },
1269 Stmt::LetElse { pattern, ty, init, else_branch } => Stmt::LetElse {
1270 pattern: pattern.clone(),
1271 ty: ty.clone(),
1272 init: self.pass_strength_reduce_expr(init),
1273 else_branch: Box::new(self.pass_strength_reduce_expr(else_branch)),
1274 },
1275 Stmt::Expr(expr) => Stmt::Expr(self.pass_strength_reduce_expr(expr)),
1276 Stmt::Semi(expr) => Stmt::Semi(self.pass_strength_reduce_expr(expr)),
1277 Stmt::Item(item) => Stmt::Item(item.clone()),
1278 }
1279 }
1280
1281 fn pass_strength_reduce_expr(&mut self, expr: &Expr) -> Expr {
1282 match expr {
1283 Expr::Binary { op, left, right } => {
1284 let left = Box::new(self.pass_strength_reduce_expr(left));
1285 let right = Box::new(self.pass_strength_reduce_expr(right));
1286
1287 if *op == BinOp::Mul {
1289 if let Some(n) = self.as_int(&right) {
1290 if n > 0 && (n as u64).is_power_of_two() {
1291 self.stats.strength_reductions += 1;
1292 let shift = (n as u64).trailing_zeros() as i64;
1293 return Expr::Binary {
1294 op: BinOp::Shl,
1295 left,
1296 right: Box::new(Expr::Literal(Literal::Int {
1297 value: shift.to_string(),
1298 base: NumBase::Decimal,
1299 suffix: None,
1300 })),
1301 };
1302 }
1303 }
1304 if let Some(n) = self.as_int(&left) {
1305 if n > 0 && (n as u64).is_power_of_two() {
1306 self.stats.strength_reductions += 1;
1307 let shift = (n as u64).trailing_zeros() as i64;
1308 return Expr::Binary {
1309 op: BinOp::Shl,
1310 left: right,
1311 right: Box::new(Expr::Literal(Literal::Int {
1312 value: shift.to_string(),
1313 base: NumBase::Decimal,
1314 suffix: None,
1315 })),
1316 };
1317 }
1318 }
1319 }
1320
1321 if let Some(n) = self.as_int(&right) {
1323 match (op, n) {
1324 (BinOp::Add | BinOp::Sub | BinOp::BitOr | BinOp::BitXor, 0)
1325 | (BinOp::Mul | BinOp::Div, 1)
1326 | (BinOp::Shl | BinOp::Shr, 0) => {
1327 self.stats.strength_reductions += 1;
1328 return *left;
1329 }
1330 (BinOp::Mul, 0) | (BinOp::BitAnd, 0) => {
1331 self.stats.strength_reductions += 1;
1332 return Expr::Literal(Literal::Int {
1333 value: "0".to_string(),
1334 base: NumBase::Decimal,
1335 suffix: None,
1336 });
1337 }
1338 _ => {}
1339 }
1340 }
1341
1342 if let Some(n) = self.as_int(&left) {
1344 match (op, n) {
1345 (BinOp::Add | BinOp::BitOr | BinOp::BitXor, 0) | (BinOp::Mul, 1) => {
1346 self.stats.strength_reductions += 1;
1347 return *right;
1348 }
1349 (BinOp::Mul, 0) | (BinOp::BitAnd, 0) => {
1350 self.stats.strength_reductions += 1;
1351 return Expr::Literal(Literal::Int {
1352 value: "0".to_string(),
1353 base: NumBase::Decimal,
1354 suffix: None,
1355 });
1356 }
1357 _ => {}
1358 }
1359 }
1360
1361 Expr::Binary {
1362 op: op.clone(),
1363 left,
1364 right,
1365 }
1366 }
1367 Expr::Unary { op, expr: inner } => {
1368 let inner = Box::new(self.pass_strength_reduce_expr(inner));
1369
1370 if *op == UnaryOp::Neg {
1372 if let Expr::Unary {
1373 op: UnaryOp::Neg,
1374 expr: inner2,
1375 } = inner.as_ref()
1376 {
1377 self.stats.strength_reductions += 1;
1378 return *inner2.clone();
1379 }
1380 }
1381
1382 if *op == UnaryOp::Not {
1384 if let Expr::Unary {
1385 op: UnaryOp::Not,
1386 expr: inner2,
1387 } = inner.as_ref()
1388 {
1389 self.stats.strength_reductions += 1;
1390 return *inner2.clone();
1391 }
1392 }
1393
1394 Expr::Unary {
1395 op: *op,
1396 expr: inner,
1397 }
1398 }
1399 Expr::If {
1400 condition,
1401 then_branch,
1402 else_branch,
1403 } => {
1404 let condition = Box::new(self.pass_strength_reduce_expr(condition));
1405 let then_branch = self.pass_strength_reduce_block(then_branch);
1406 let else_branch = else_branch
1407 .as_ref()
1408 .map(|e| Box::new(self.pass_strength_reduce_expr(e)));
1409 Expr::If {
1410 condition,
1411 then_branch,
1412 else_branch,
1413 }
1414 }
1415 Expr::While { label, condition, body } => {
1416 let condition = Box::new(self.pass_strength_reduce_expr(condition));
1417 let body = self.pass_strength_reduce_block(body);
1418 Expr::While { label: label.clone(), condition, body }
1419 }
1420 Expr::Block(block) => Expr::Block(self.pass_strength_reduce_block(block)),
1421 Expr::Call { func, args } => {
1422 let args = args
1423 .iter()
1424 .map(|a| self.pass_strength_reduce_expr(a))
1425 .collect();
1426 Expr::Call {
1427 func: func.clone(),
1428 args,
1429 }
1430 }
1431 Expr::Return(e) => Expr::Return(
1432 e.as_ref()
1433 .map(|e| Box::new(self.pass_strength_reduce_expr(e))),
1434 ),
1435 Expr::Assign { target, value } => {
1436 let value = Box::new(self.pass_strength_reduce_expr(value));
1437 Expr::Assign {
1438 target: target.clone(),
1439 value,
1440 }
1441 }
1442 other => other.clone(),
1443 }
1444 }
1445
1446 fn pass_dead_code_block(&mut self, block: &Block) -> Block {
1451 let mut stmts = Vec::new();
1453 let mut found_return = false;
1454
1455 for stmt in &block.stmts {
1456 if found_return {
1457 self.stats.dead_code_eliminated += 1;
1458 continue;
1459 }
1460 let stmt = self.pass_dead_code_stmt(stmt);
1461 if self.stmt_returns(&stmt) {
1462 found_return = true;
1463 }
1464 stmts.push(stmt);
1465 }
1466
1467 let expr = if found_return {
1469 if block.expr.is_some() {
1470 self.stats.dead_code_eliminated += 1;
1471 }
1472 None
1473 } else {
1474 block
1475 .expr
1476 .as_ref()
1477 .map(|e| Box::new(self.pass_dead_code_expr(e)))
1478 };
1479
1480 Block { stmts, expr }
1481 }
1482
1483 fn pass_dead_code_stmt(&mut self, stmt: &Stmt) -> Stmt {
1484 match stmt {
1485 Stmt::Let {
1486 pattern, ty, init, ..
1487 } => Stmt::Let {
1488 pattern: pattern.clone(),
1489 ty: ty.clone(),
1490 init: init.as_ref().map(|e| self.pass_dead_code_expr(e)),
1491 },
1492 Stmt::LetElse { pattern, ty, init, else_branch } => Stmt::LetElse {
1493 pattern: pattern.clone(),
1494 ty: ty.clone(),
1495 init: self.pass_dead_code_expr(init),
1496 else_branch: Box::new(self.pass_dead_code_expr(else_branch)),
1497 },
1498 Stmt::Expr(expr) => Stmt::Expr(self.pass_dead_code_expr(expr)),
1499 Stmt::Semi(expr) => Stmt::Semi(self.pass_dead_code_expr(expr)),
1500 Stmt::Item(item) => Stmt::Item(item.clone()),
1501 }
1502 }
1503
1504 fn pass_dead_code_expr(&mut self, expr: &Expr) -> Expr {
1505 match expr {
1506 Expr::If {
1507 condition,
1508 then_branch,
1509 else_branch,
1510 } => {
1511 let condition = Box::new(self.pass_dead_code_expr(condition));
1512 let then_branch = self.pass_dead_code_block(then_branch);
1513 let else_branch = else_branch
1514 .as_ref()
1515 .map(|e| Box::new(self.pass_dead_code_expr(e)));
1516 Expr::If {
1517 condition,
1518 then_branch,
1519 else_branch,
1520 }
1521 }
1522 Expr::While { label, condition, body } => {
1523 let condition = Box::new(self.pass_dead_code_expr(condition));
1524 let body = self.pass_dead_code_block(body);
1525 Expr::While { label: label.clone(), condition, body }
1526 }
1527 Expr::Block(block) => Expr::Block(self.pass_dead_code_block(block)),
1528 other => other.clone(),
1529 }
1530 }
1531
1532 fn stmt_returns(&self, stmt: &Stmt) -> bool {
1533 match stmt {
1534 Stmt::Expr(expr) | Stmt::Semi(expr) => self.expr_returns(expr),
1535 _ => false,
1536 }
1537 }
1538
1539 fn expr_returns(&self, expr: &Expr) -> bool {
1540 match expr {
1541 Expr::Return(_) => true,
1542 Expr::Block(block) => {
1543 block.stmts.iter().any(|s| self.stmt_returns(s))
1544 || block
1545 .expr
1546 .as_ref()
1547 .map(|e| self.expr_returns(e))
1548 .unwrap_or(false)
1549 }
1550 _ => false,
1551 }
1552 }
1553
1554 fn pass_simplify_branches_block(&mut self, block: &Block) -> Block {
1559 let stmts = block
1560 .stmts
1561 .iter()
1562 .map(|s| self.pass_simplify_branches_stmt(s))
1563 .collect();
1564 let expr = block
1565 .expr
1566 .as_ref()
1567 .map(|e| Box::new(self.pass_simplify_branches_expr(e)));
1568 Block { stmts, expr }
1569 }
1570
1571 fn pass_simplify_branches_stmt(&mut self, stmt: &Stmt) -> Stmt {
1572 match stmt {
1573 Stmt::Let {
1574 pattern, ty, init, ..
1575 } => Stmt::Let {
1576 pattern: pattern.clone(),
1577 ty: ty.clone(),
1578 init: init.as_ref().map(|e| self.pass_simplify_branches_expr(e)),
1579 },
1580 Stmt::LetElse { pattern, ty, init, else_branch } => Stmt::LetElse {
1581 pattern: pattern.clone(),
1582 ty: ty.clone(),
1583 init: self.pass_simplify_branches_expr(init),
1584 else_branch: Box::new(self.pass_simplify_branches_expr(else_branch)),
1585 },
1586 Stmt::Expr(expr) => Stmt::Expr(self.pass_simplify_branches_expr(expr)),
1587 Stmt::Semi(expr) => Stmt::Semi(self.pass_simplify_branches_expr(expr)),
1588 Stmt::Item(item) => Stmt::Item(item.clone()),
1589 }
1590 }
1591
1592 fn pass_simplify_branches_expr(&mut self, expr: &Expr) -> Expr {
1593 match expr {
1594 Expr::If {
1595 condition,
1596 then_branch,
1597 else_branch,
1598 } => {
1599 let condition = Box::new(self.pass_simplify_branches_expr(condition));
1600 let then_branch = self.pass_simplify_branches_block(then_branch);
1601 let else_branch = else_branch
1602 .as_ref()
1603 .map(|e| Box::new(self.pass_simplify_branches_expr(e)));
1604
1605 if let Expr::Unary {
1607 op: UnaryOp::Not,
1608 expr: inner,
1609 } = condition.as_ref()
1610 {
1611 if let Some(else_expr) = &else_branch {
1612 self.stats.branches_simplified += 1;
1613 let new_else = Some(Box::new(Expr::Block(then_branch)));
1614 let new_then = match else_expr.as_ref() {
1615 Expr::Block(b) => b.clone(),
1616 other => Block {
1617 stmts: vec![],
1618 expr: Some(Box::new(other.clone())),
1619 },
1620 };
1621 return Expr::If {
1622 condition: inner.clone(),
1623 then_branch: new_then,
1624 else_branch: new_else,
1625 };
1626 }
1627 }
1628
1629 Expr::If {
1630 condition,
1631 then_branch,
1632 else_branch,
1633 }
1634 }
1635 Expr::While { label, condition, body } => {
1636 let condition = Box::new(self.pass_simplify_branches_expr(condition));
1637 let body = self.pass_simplify_branches_block(body);
1638 Expr::While { label: label.clone(), condition, body }
1639 }
1640 Expr::Block(block) => Expr::Block(self.pass_simplify_branches_block(block)),
1641 Expr::Binary { op, left, right } => {
1642 let left = Box::new(self.pass_simplify_branches_expr(left));
1643 let right = Box::new(self.pass_simplify_branches_expr(right));
1644 Expr::Binary {
1645 op: op.clone(),
1646 left,
1647 right,
1648 }
1649 }
1650 Expr::Unary { op, expr: inner } => {
1651 let inner = Box::new(self.pass_simplify_branches_expr(inner));
1652 Expr::Unary {
1653 op: *op,
1654 expr: inner,
1655 }
1656 }
1657 Expr::Call { func, args } => {
1658 let args = args
1659 .iter()
1660 .map(|a| self.pass_simplify_branches_expr(a))
1661 .collect();
1662 Expr::Call {
1663 func: func.clone(),
1664 args,
1665 }
1666 }
1667 Expr::Return(e) => Expr::Return(
1668 e.as_ref()
1669 .map(|e| Box::new(self.pass_simplify_branches_expr(e))),
1670 ),
1671 other => other.clone(),
1672 }
1673 }
1674
1675 fn should_inline(&self, func: &ast::Function) -> bool {
1681 if self.recursive_functions.contains(&func.name.name) {
1683 return false;
1684 }
1685
1686 if let Some(body) = &func.body {
1689 if self.contains_inline_asm_in_block(body) {
1690 return false;
1691 }
1692 }
1693
1694 if let Some(body) = &func.body {
1696 let stmt_count = self.count_stmts_in_block(body);
1697 stmt_count <= 10
1699 } else {
1700 false
1701 }
1702 }
1703
1704 fn contains_inline_asm_in_block(&self, block: &Block) -> bool {
1706 for stmt in &block.stmts {
1707 if self.contains_inline_asm_in_stmt(stmt) {
1708 return true;
1709 }
1710 }
1711 if let Some(ref expr) = block.expr {
1712 if self.contains_inline_asm_in_expr(expr) {
1713 return true;
1714 }
1715 }
1716 false
1717 }
1718
1719 fn contains_inline_asm_in_stmt(&self, stmt: &Stmt) -> bool {
1720 match stmt {
1721 Stmt::Expr(e) | Stmt::Semi(e) => self.contains_inline_asm_in_expr(e),
1722 Stmt::Let { init: Some(e), .. } => self.contains_inline_asm_in_expr(e),
1723 Stmt::LetElse { init, else_branch, .. } => {
1724 self.contains_inline_asm_in_expr(init) || self.contains_inline_asm_in_expr(else_branch)
1725 }
1726 _ => false,
1727 }
1728 }
1729
1730 fn contains_inline_asm_in_expr(&self, expr: &Expr) -> bool {
1731 match expr {
1732 Expr::InlineAsm(_) => true,
1733
1734 Expr::Unsafe(block) | Expr::Block(block) => self.contains_inline_asm_in_block(block),
1736 Expr::Loop { body, .. } | Expr::While { body, .. } | Expr::For { body, .. } => {
1737 self.contains_inline_asm_in_block(body)
1738 }
1739
1740 Expr::If { condition, then_branch, else_branch } => {
1742 self.contains_inline_asm_in_expr(condition)
1743 || self.contains_inline_asm_in_block(then_branch)
1744 || else_branch.as_ref().map_or(false, |e| self.contains_inline_asm_in_expr(e))
1745 }
1746 Expr::Match { expr, arms } => {
1747 self.contains_inline_asm_in_expr(expr)
1748 || arms.iter().any(|arm| self.contains_inline_asm_in_expr(&arm.body))
1749 }
1750
1751 Expr::Binary { left, right, .. } => {
1753 self.contains_inline_asm_in_expr(left) || self.contains_inline_asm_in_expr(right)
1754 }
1755 Expr::Unary { expr, .. } => self.contains_inline_asm_in_expr(expr),
1756 Expr::Assign { target, value } => {
1757 self.contains_inline_asm_in_expr(target) || self.contains_inline_asm_in_expr(value)
1758 }
1759
1760 Expr::Call { func, args } => {
1762 self.contains_inline_asm_in_expr(func)
1763 || args.iter().any(|a| self.contains_inline_asm_in_expr(a))
1764 }
1765 Expr::MethodCall { receiver, args, .. } => {
1766 self.contains_inline_asm_in_expr(receiver)
1767 || args.iter().any(|a| self.contains_inline_asm_in_expr(a))
1768 }
1769
1770 Expr::Field { expr, .. } | Expr::Index { expr, .. } | Expr::Cast { expr, .. } => {
1772 self.contains_inline_asm_in_expr(expr)
1773 }
1774
1775 Expr::Try(inner) | Expr::Await { expr: inner, .. } => {
1777 self.contains_inline_asm_in_expr(inner)
1778 }
1779 Expr::Return(Some(inner)) => self.contains_inline_asm_in_expr(inner),
1780
1781 Expr::Tuple(exprs) | Expr::Array(exprs) => {
1783 exprs.iter().any(|e| self.contains_inline_asm_in_expr(e))
1784 }
1785
1786 Expr::Closure { body, .. } => self.contains_inline_asm_in_expr(body),
1788
1789 Expr::Path(_) | Expr::Literal(_) | Expr::Return(None) |
1791 Expr::Break { .. } | Expr::Continue { .. } | Expr::Range { .. } |
1792 Expr::Struct { .. } => false,
1793
1794 _ => false,
1797 }
1798 }
1799
1800 fn count_stmts_in_block(&self, block: &Block) -> usize {
1802 let mut count = block.stmts.len();
1803 if block.expr.is_some() {
1804 count += 1;
1805 }
1806 for stmt in &block.stmts {
1808 count += self.count_stmts_in_stmt(stmt);
1809 }
1810 count
1811 }
1812
1813 fn count_stmts_in_stmt(&self, stmt: &Stmt) -> usize {
1814 match stmt {
1815 Stmt::Expr(e) | Stmt::Semi(e) => self.count_stmts_in_expr(e),
1816 Stmt::Let { init: Some(e), .. } => self.count_stmts_in_expr(e),
1817 _ => 0,
1818 }
1819 }
1820
1821 fn count_stmts_in_expr(&self, expr: &Expr) -> usize {
1822 match expr {
1823 Expr::If {
1824 then_branch,
1825 else_branch,
1826 ..
1827 } => {
1828 let mut count = self.count_stmts_in_block(then_branch);
1829 if let Some(else_expr) = else_branch {
1830 count += self.count_stmts_in_expr(else_expr);
1831 }
1832 count
1833 }
1834 Expr::While { body, .. } => self.count_stmts_in_block(body),
1835 Expr::Block(block) => self.count_stmts_in_block(block),
1836 _ => 0,
1837 }
1838 }
1839
1840 fn inline_call(&mut self, func: &ast::Function, args: &[Expr]) -> Option<Expr> {
1842 let body = func.body.as_ref()?;
1843
1844 let mut param_map: HashMap<String, Expr> = HashMap::new();
1846 for (param, arg) in func.params.iter().zip(args.iter()) {
1847 if let Pattern::Ident { name, .. } = ¶m.pattern {
1848 param_map.insert(name.name.clone(), arg.clone());
1849 }
1850 }
1851
1852 let inlined_body = self.substitute_params_in_block(body, ¶m_map);
1854
1855 self.stats.functions_inlined += 1;
1856
1857 if inlined_body.stmts.is_empty() {
1860 if let Some(expr) = inlined_body.expr {
1861 if let Expr::Return(Some(inner)) = expr.as_ref() {
1863 return Some(inner.as_ref().clone());
1864 }
1865 return Some(*expr);
1866 }
1867 }
1868
1869 Some(Expr::Block(inlined_body))
1870 }
1871
1872 fn substitute_params_in_block(
1874 &self,
1875 block: &Block,
1876 param_map: &HashMap<String, Expr>,
1877 ) -> Block {
1878 let stmts = block
1879 .stmts
1880 .iter()
1881 .map(|s| self.substitute_params_in_stmt(s, param_map))
1882 .collect();
1883 let expr = block
1884 .expr
1885 .as_ref()
1886 .map(|e| Box::new(self.substitute_params_in_expr(e, param_map)));
1887 Block { stmts, expr }
1888 }
1889
1890 fn substitute_params_in_stmt(&self, stmt: &Stmt, param_map: &HashMap<String, Expr>) -> Stmt {
1891 match stmt {
1892 Stmt::Let { pattern, ty, init } => Stmt::Let {
1893 pattern: pattern.clone(),
1894 ty: ty.clone(),
1895 init: init
1896 .as_ref()
1897 .map(|e| self.substitute_params_in_expr(e, param_map)),
1898 },
1899 Stmt::LetElse { pattern, ty, init, else_branch } => Stmt::LetElse {
1900 pattern: pattern.clone(),
1901 ty: ty.clone(),
1902 init: self.substitute_params_in_expr(init, param_map),
1903 else_branch: Box::new(self.substitute_params_in_expr(else_branch, param_map)),
1904 },
1905 Stmt::Expr(e) => Stmt::Expr(self.substitute_params_in_expr(e, param_map)),
1906 Stmt::Semi(e) => Stmt::Semi(self.substitute_params_in_expr(e, param_map)),
1907 Stmt::Item(item) => Stmt::Item(item.clone()),
1908 }
1909 }
1910
1911 fn substitute_params_in_expr(&self, expr: &Expr, param_map: &HashMap<String, Expr>) -> Expr {
1912 match expr {
1913 Expr::Path(path) => {
1914 if path.segments.len() == 1 {
1916 let name = &path.segments[0].ident.name;
1917 if let Some(arg) = param_map.get(name) {
1918 return arg.clone();
1919 }
1920 }
1921 expr.clone()
1922 }
1923 Expr::Binary { op, left, right } => Expr::Binary {
1924 op: op.clone(),
1925 left: Box::new(self.substitute_params_in_expr(left, param_map)),
1926 right: Box::new(self.substitute_params_in_expr(right, param_map)),
1927 },
1928 Expr::Unary { op, expr: inner } => Expr::Unary {
1929 op: *op,
1930 expr: Box::new(self.substitute_params_in_expr(inner, param_map)),
1931 },
1932 Expr::If {
1933 condition,
1934 then_branch,
1935 else_branch,
1936 } => Expr::If {
1937 condition: Box::new(self.substitute_params_in_expr(condition, param_map)),
1938 then_branch: self.substitute_params_in_block(then_branch, param_map),
1939 else_branch: else_branch
1940 .as_ref()
1941 .map(|e| Box::new(self.substitute_params_in_expr(e, param_map))),
1942 },
1943 Expr::While { label, condition, body } => Expr::While {
1944 label: label.clone(),
1945 condition: Box::new(self.substitute_params_in_expr(condition, param_map)),
1946 body: self.substitute_params_in_block(body, param_map),
1947 },
1948 Expr::Block(block) => Expr::Block(self.substitute_params_in_block(block, param_map)),
1949 Expr::Call { func, args } => Expr::Call {
1950 func: Box::new(self.substitute_params_in_expr(func, param_map)),
1951 args: args
1952 .iter()
1953 .map(|a| self.substitute_params_in_expr(a, param_map))
1954 .collect(),
1955 },
1956 Expr::Return(e) => Expr::Return(
1957 e.as_ref()
1958 .map(|e| Box::new(self.substitute_params_in_expr(e, param_map))),
1959 ),
1960 Expr::Assign { target, value } => Expr::Assign {
1961 target: target.clone(),
1962 value: Box::new(self.substitute_params_in_expr(value, param_map)),
1963 },
1964 Expr::Index { expr: e, index } => Expr::Index {
1965 expr: Box::new(self.substitute_params_in_expr(e, param_map)),
1966 index: Box::new(self.substitute_params_in_expr(index, param_map)),
1967 },
1968 Expr::Array(elements) => Expr::Array(
1969 elements
1970 .iter()
1971 .map(|e| self.substitute_params_in_expr(e, param_map))
1972 .collect(),
1973 ),
1974 Expr::Match { expr: match_expr, arms } => Expr::Match {
1975 expr: Box::new(self.substitute_params_in_expr(match_expr, param_map)),
1976 arms: arms
1977 .iter()
1978 .map(|arm| crate::ast::MatchArm {
1979 pattern: arm.pattern.clone(),
1980 guard: arm.guard.as_ref().map(|g| self.substitute_params_in_expr(g, param_map)),
1981 body: self.substitute_params_in_expr(&arm.body, param_map),
1982 })
1983 .collect(),
1984 },
1985 Expr::Field { expr: inner, field } => Expr::Field {
1987 expr: Box::new(self.substitute_params_in_expr(inner, param_map)),
1988 field: field.clone(),
1989 },
1990 Expr::MethodCall { receiver, method, args, type_args } => Expr::MethodCall {
1992 receiver: Box::new(self.substitute_params_in_expr(receiver, param_map)),
1993 method: method.clone(),
1994 args: args
1995 .iter()
1996 .map(|a| self.substitute_params_in_expr(a, param_map))
1997 .collect(),
1998 type_args: type_args.clone(),
1999 },
2000 other => other.clone(),
2001 }
2002 }
2003
2004 fn pass_inline_block(&mut self, block: &Block) -> Block {
2005 let stmts = block
2006 .stmts
2007 .iter()
2008 .map(|s| self.pass_inline_stmt(s))
2009 .collect();
2010 let expr = block
2011 .expr
2012 .as_ref()
2013 .map(|e| Box::new(self.pass_inline_expr(e)));
2014 Block { stmts, expr }
2015 }
2016
2017 fn pass_inline_stmt(&mut self, stmt: &Stmt) -> Stmt {
2018 match stmt {
2019 Stmt::Let { pattern, ty, init } => Stmt::Let {
2020 pattern: pattern.clone(),
2021 ty: ty.clone(),
2022 init: init.as_ref().map(|e| self.pass_inline_expr(e)),
2023 },
2024 Stmt::LetElse { pattern, ty, init, else_branch } => Stmt::LetElse {
2025 pattern: pattern.clone(),
2026 ty: ty.clone(),
2027 init: self.pass_inline_expr(init),
2028 else_branch: Box::new(self.pass_inline_expr(else_branch)),
2029 },
2030 Stmt::Expr(e) => Stmt::Expr(self.pass_inline_expr(e)),
2031 Stmt::Semi(e) => Stmt::Semi(self.pass_inline_expr(e)),
2032 Stmt::Item(item) => Stmt::Item(item.clone()),
2033 }
2034 }
2035
2036 fn pass_inline_expr(&mut self, expr: &Expr) -> Expr {
2037 match expr {
2038 Expr::Call { func, args } => {
2039 let args: Vec<Expr> = args.iter().map(|a| self.pass_inline_expr(a)).collect();
2041
2042 if let Expr::Path(path) = func.as_ref() {
2044 if path.segments.len() == 1 {
2045 let func_name = &path.segments[0].ident.name;
2046 if let Some(target_func) = self.functions.get(func_name).cloned() {
2047 if self.should_inline(&target_func)
2048 && args.len() == target_func.params.len()
2049 {
2050 if let Some(inlined) = self.inline_call(&target_func, &args) {
2051 return inlined;
2052 }
2053 }
2054 }
2055 }
2056 }
2057
2058 Expr::Call {
2059 func: func.clone(),
2060 args,
2061 }
2062 }
2063 Expr::Binary { op, left, right } => Expr::Binary {
2064 op: op.clone(),
2065 left: Box::new(self.pass_inline_expr(left)),
2066 right: Box::new(self.pass_inline_expr(right)),
2067 },
2068 Expr::Unary { op, expr: inner } => Expr::Unary {
2069 op: *op,
2070 expr: Box::new(self.pass_inline_expr(inner)),
2071 },
2072 Expr::If {
2073 condition,
2074 then_branch,
2075 else_branch,
2076 } => Expr::If {
2077 condition: Box::new(self.pass_inline_expr(condition)),
2078 then_branch: self.pass_inline_block(then_branch),
2079 else_branch: else_branch
2080 .as_ref()
2081 .map(|e| Box::new(self.pass_inline_expr(e))),
2082 },
2083 Expr::While { label, condition, body } => Expr::While {
2084 label: label.clone(),
2085 condition: Box::new(self.pass_inline_expr(condition)),
2086 body: self.pass_inline_block(body),
2087 },
2088 Expr::Block(block) => Expr::Block(self.pass_inline_block(block)),
2089 Expr::Return(e) => Expr::Return(e.as_ref().map(|e| Box::new(self.pass_inline_expr(e)))),
2090 Expr::Assign { target, value } => Expr::Assign {
2091 target: target.clone(),
2092 value: Box::new(self.pass_inline_expr(value)),
2093 },
2094 Expr::Index { expr: e, index } => Expr::Index {
2095 expr: Box::new(self.pass_inline_expr(e)),
2096 index: Box::new(self.pass_inline_expr(index)),
2097 },
2098 Expr::Array(elements) => {
2099 Expr::Array(elements.iter().map(|e| self.pass_inline_expr(e)).collect())
2100 }
2101 other => other.clone(),
2102 }
2103 }
2104
2105 fn pass_loop_unroll_block(&mut self, block: &Block) -> Block {
2111 let stmts = block
2112 .stmts
2113 .iter()
2114 .map(|s| self.pass_loop_unroll_stmt(s))
2115 .collect();
2116 let expr = block
2117 .expr
2118 .as_ref()
2119 .map(|e| Box::new(self.pass_loop_unroll_expr(e)));
2120 Block { stmts, expr }
2121 }
2122
2123 fn pass_loop_unroll_stmt(&mut self, stmt: &Stmt) -> Stmt {
2124 match stmt {
2125 Stmt::Let { pattern, ty, init } => Stmt::Let {
2126 pattern: pattern.clone(),
2127 ty: ty.clone(),
2128 init: init.as_ref().map(|e| self.pass_loop_unroll_expr(e)),
2129 },
2130 Stmt::LetElse { pattern, ty, init, else_branch } => Stmt::LetElse {
2131 pattern: pattern.clone(),
2132 ty: ty.clone(),
2133 init: self.pass_loop_unroll_expr(init),
2134 else_branch: Box::new(self.pass_loop_unroll_expr(else_branch)),
2135 },
2136 Stmt::Expr(e) => Stmt::Expr(self.pass_loop_unroll_expr(e)),
2137 Stmt::Semi(e) => Stmt::Semi(self.pass_loop_unroll_expr(e)),
2138 Stmt::Item(item) => Stmt::Item(item.clone()),
2139 }
2140 }
2141
2142 fn pass_loop_unroll_expr(&mut self, expr: &Expr) -> Expr {
2143 match expr {
2144 Expr::While { label, condition, body } => {
2145 if let Some(unrolled) = self.try_unroll_loop(condition, body) {
2147 self.stats.loops_optimized += 1;
2148 return unrolled;
2149 }
2150 Expr::While {
2152 label: label.clone(),
2153 condition: Box::new(self.pass_loop_unroll_expr(condition)),
2154 body: self.pass_loop_unroll_block(body),
2155 }
2156 }
2157 Expr::If {
2158 condition,
2159 then_branch,
2160 else_branch,
2161 } => Expr::If {
2162 condition: Box::new(self.pass_loop_unroll_expr(condition)),
2163 then_branch: self.pass_loop_unroll_block(then_branch),
2164 else_branch: else_branch
2165 .as_ref()
2166 .map(|e| Box::new(self.pass_loop_unroll_expr(e))),
2167 },
2168 Expr::Block(b) => Expr::Block(self.pass_loop_unroll_block(b)),
2169 Expr::Binary { op, left, right } => Expr::Binary {
2170 op: *op,
2171 left: Box::new(self.pass_loop_unroll_expr(left)),
2172 right: Box::new(self.pass_loop_unroll_expr(right)),
2173 },
2174 Expr::Unary { op, expr: inner } => Expr::Unary {
2175 op: *op,
2176 expr: Box::new(self.pass_loop_unroll_expr(inner)),
2177 },
2178 Expr::Call { func, args } => Expr::Call {
2179 func: func.clone(),
2180 args: args.iter().map(|a| self.pass_loop_unroll_expr(a)).collect(),
2181 },
2182 Expr::Return(e) => {
2183 Expr::Return(e.as_ref().map(|e| Box::new(self.pass_loop_unroll_expr(e))))
2184 }
2185 Expr::Assign { target, value } => Expr::Assign {
2186 target: target.clone(),
2187 value: Box::new(self.pass_loop_unroll_expr(value)),
2188 },
2189 Expr::Unsafe(block) => Expr::Unsafe(self.pass_loop_unroll_block(block)),
2190 other => other.clone(),
2191 }
2192 }
2193
2194 fn try_unroll_loop(&self, condition: &Expr, body: &Block) -> Option<Expr> {
2197 if self.contains_inline_asm_in_block(body) {
2199 return None;
2200 }
2201
2202 let (loop_var, upper_bound) = self.extract_loop_bounds(condition)?;
2204
2205 if upper_bound > 8 || upper_bound <= 0 {
2207 return None;
2208 }
2209
2210 if !self.body_has_simple_increment(&loop_var, body) {
2212 return None;
2213 }
2214
2215 let stmt_count = body.stmts.len();
2217 if stmt_count > 5 {
2218 return None;
2219 }
2220
2221 let mut unrolled_stmts: Vec<Stmt> = Vec::new();
2223
2224 for i in 0..upper_bound {
2225 let substituted_body = self.substitute_loop_var_in_block(body, &loop_var, i);
2227
2228 for stmt in &substituted_body.stmts {
2230 if !self.is_increment_stmt(&loop_var, stmt) {
2231 unrolled_stmts.push(stmt.clone());
2232 }
2233 }
2234 }
2235
2236 Some(Expr::Block(Block {
2238 stmts: unrolled_stmts,
2239 expr: None,
2240 }))
2241 }
2242
2243 fn extract_loop_bounds(&self, condition: &Expr) -> Option<(String, i64)> {
2245 if let Expr::Binary {
2246 op: BinOp::Lt,
2247 left,
2248 right,
2249 } = condition
2250 {
2251 if let Expr::Path(path) = left.as_ref() {
2253 if path.segments.len() == 1 {
2254 let var_name = path.segments[0].ident.name.clone();
2255 if let Some(bound) = self.as_int(right) {
2257 return Some((var_name, bound));
2258 }
2259 }
2260 }
2261 }
2262 None
2263 }
2264
2265 fn body_has_simple_increment(&self, loop_var: &str, body: &Block) -> bool {
2267 for stmt in &body.stmts {
2268 if self.is_increment_stmt(loop_var, stmt) {
2269 return true;
2270 }
2271 }
2272 false
2273 }
2274
2275 fn is_increment_stmt(&self, var_name: &str, stmt: &Stmt) -> bool {
2277 match stmt {
2278 Stmt::Semi(Expr::Assign { target, value })
2279 | Stmt::Expr(Expr::Assign { target, value }) => {
2280 if let Expr::Path(path) = target.as_ref() {
2282 if path.segments.len() == 1 && path.segments[0].ident.name == var_name {
2283 if let Expr::Binary {
2285 op: BinOp::Add,
2286 left,
2287 right,
2288 } = value.as_ref()
2289 {
2290 if let Expr::Path(lpath) = left.as_ref() {
2291 if lpath.segments.len() == 1
2292 && lpath.segments[0].ident.name == var_name
2293 {
2294 if let Some(1) = self.as_int(right) {
2295 return true;
2296 }
2297 }
2298 }
2299 }
2300 }
2301 }
2302 false
2303 }
2304 _ => false,
2305 }
2306 }
2307
2308 fn substitute_loop_var_in_block(&self, block: &Block, var_name: &str, value: i64) -> Block {
2310 let stmts = block
2311 .stmts
2312 .iter()
2313 .map(|s| self.substitute_loop_var_in_stmt(s, var_name, value))
2314 .collect();
2315 let expr = block
2316 .expr
2317 .as_ref()
2318 .map(|e| Box::new(self.substitute_loop_var_in_expr(e, var_name, value)));
2319 Block { stmts, expr }
2320 }
2321
2322 fn substitute_loop_var_in_stmt(&self, stmt: &Stmt, var_name: &str, value: i64) -> Stmt {
2323 match stmt {
2324 Stmt::Let { pattern, ty, init } => Stmt::Let {
2325 pattern: pattern.clone(),
2326 ty: ty.clone(),
2327 init: init
2328 .as_ref()
2329 .map(|e| self.substitute_loop_var_in_expr(e, var_name, value)),
2330 },
2331 Stmt::LetElse { pattern, ty, init, else_branch } => Stmt::LetElse {
2332 pattern: pattern.clone(),
2333 ty: ty.clone(),
2334 init: self.substitute_loop_var_in_expr(init, var_name, value),
2335 else_branch: Box::new(self.substitute_loop_var_in_expr(else_branch, var_name, value)),
2336 },
2337 Stmt::Expr(e) => Stmt::Expr(self.substitute_loop_var_in_expr(e, var_name, value)),
2338 Stmt::Semi(e) => Stmt::Semi(self.substitute_loop_var_in_expr(e, var_name, value)),
2339 Stmt::Item(item) => Stmt::Item(item.clone()),
2340 }
2341 }
2342
2343 fn substitute_loop_var_in_expr(&self, expr: &Expr, var_name: &str, value: i64) -> Expr {
2344 match expr {
2345 Expr::Path(path) => {
2346 if path.segments.len() == 1 && path.segments[0].ident.name == var_name {
2347 return Expr::Literal(Literal::Int {
2348 value: value.to_string(),
2349 base: NumBase::Decimal,
2350 suffix: None,
2351 });
2352 }
2353 expr.clone()
2354 }
2355 Expr::Binary { op, left, right } => Expr::Binary {
2356 op: *op,
2357 left: Box::new(self.substitute_loop_var_in_expr(left, var_name, value)),
2358 right: Box::new(self.substitute_loop_var_in_expr(right, var_name, value)),
2359 },
2360 Expr::Unary { op, expr: inner } => Expr::Unary {
2361 op: *op,
2362 expr: Box::new(self.substitute_loop_var_in_expr(inner, var_name, value)),
2363 },
2364 Expr::Call { func, args } => Expr::Call {
2365 func: Box::new(self.substitute_loop_var_in_expr(func, var_name, value)),
2366 args: args
2367 .iter()
2368 .map(|a| self.substitute_loop_var_in_expr(a, var_name, value))
2369 .collect(),
2370 },
2371 Expr::If {
2372 condition,
2373 then_branch,
2374 else_branch,
2375 } => Expr::If {
2376 condition: Box::new(self.substitute_loop_var_in_expr(condition, var_name, value)),
2377 then_branch: self.substitute_loop_var_in_block(then_branch, var_name, value),
2378 else_branch: else_branch
2379 .as_ref()
2380 .map(|e| Box::new(self.substitute_loop_var_in_expr(e, var_name, value))),
2381 },
2382 Expr::While { label, condition, body } => Expr::While {
2383 label: label.clone(),
2384 condition: Box::new(self.substitute_loop_var_in_expr(condition, var_name, value)),
2385 body: self.substitute_loop_var_in_block(body, var_name, value),
2386 },
2387 Expr::Block(b) => Expr::Block(self.substitute_loop_var_in_block(b, var_name, value)),
2388 Expr::Return(e) => Expr::Return(
2389 e.as_ref()
2390 .map(|e| Box::new(self.substitute_loop_var_in_expr(e, var_name, value))),
2391 ),
2392 Expr::Assign { target, value: v } => Expr::Assign {
2393 target: target.clone(),
2395 value: Box::new(self.substitute_loop_var_in_expr(v, var_name, value)),
2396 },
2397 Expr::Index { expr: e, index } => Expr::Index {
2398 expr: Box::new(self.substitute_loop_var_in_expr(e, var_name, value)),
2399 index: Box::new(self.substitute_loop_var_in_expr(index, var_name, value)),
2400 },
2401 Expr::Array(elements) => Expr::Array(
2402 elements
2403 .iter()
2404 .map(|e| self.substitute_loop_var_in_expr(e, var_name, value))
2405 .collect(),
2406 ),
2407 other => other.clone(),
2408 }
2409 }
2410
2411 fn pass_licm_block(&mut self, block: &Block) -> Block {
2417 let stmts = block.stmts.iter().map(|s| self.pass_licm_stmt(s)).collect();
2418 let expr = block
2419 .expr
2420 .as_ref()
2421 .map(|e| Box::new(self.pass_licm_expr(e)));
2422 Block { stmts, expr }
2423 }
2424
2425 fn pass_licm_stmt(&mut self, stmt: &Stmt) -> Stmt {
2426 match stmt {
2427 Stmt::Let { pattern, ty, init } => Stmt::Let {
2428 pattern: pattern.clone(),
2429 ty: ty.clone(),
2430 init: init.as_ref().map(|e| self.pass_licm_expr(e)),
2431 },
2432 Stmt::LetElse { pattern, ty, init, else_branch } => Stmt::LetElse {
2433 pattern: pattern.clone(),
2434 ty: ty.clone(),
2435 init: self.pass_licm_expr(init),
2436 else_branch: Box::new(self.pass_licm_expr(else_branch)),
2437 },
2438 Stmt::Expr(e) => Stmt::Expr(self.pass_licm_expr(e)),
2439 Stmt::Semi(e) => Stmt::Semi(self.pass_licm_expr(e)),
2440 Stmt::Item(item) => Stmt::Item(item.clone()),
2441 }
2442 }
2443
    /// Loop-invariant code motion for `expr`.
    ///
    /// For a `while` loop:
    /// 1. Loops whose body or condition contain inline asm are only recursed into.
    /// 2. Variables written in the body or the condition are collected into
    ///    `modified_vars` with `collect_modified_vars_*`.
    /// 3. `find_loop_invariants` fills `invariant_exprs` with `(key, expr)` pairs.
    /// 4. Each invariant is bound before the loop as `__licm_<cse_counter>` via
    ///    `make_cse_let`.
    /// 5. The body is rewritten with `replace_invariants_in_block`.
    ///
    /// The result is `{ let __licm_..; ...; while .. { .. } }` (no tail
    /// expression). `stats.loops_optimized` is incremented once per hoisted
    /// expression, not once per loop. Other expression kinds are recursed into
    /// structurally.
    ///
    /// NOTE(review): the hoisted `let`s run even if the loop executes zero times.
    /// That is only sound if `find_loop_invariants` selects side-effect-free,
    /// non-trapping expressions; confirm against its implementation.
    fn pass_licm_expr(&mut self, expr: &Expr) -> Expr {
        match expr {
            Expr::While { label, condition, body } => {
                // Never move code across inline asm; just recurse.
                if self.contains_inline_asm_in_block(body) || self.contains_inline_asm_in_expr(condition) {
                    return Expr::While {
                        label: label.clone(),
                        condition: Box::new(self.pass_licm_expr(condition)),
                        body: self.pass_licm_block(body),
                    };
                }

                // Anything written inside the loop (body or condition) is variant.
                let mut modified_vars = HashSet::new();
                self.collect_modified_vars_block(body, &mut modified_vars);

                self.collect_modified_vars_expr(condition, &mut modified_vars);

                let mut invariant_exprs: Vec<(String, Expr)> = Vec::new();
                self.find_loop_invariants(body, &modified_vars, &mut invariant_exprs);

                if invariant_exprs.is_empty() {
                    return Expr::While {
                        label: label.clone(),
                        condition: Box::new(self.pass_licm_expr(condition)),
                        body: self.pass_licm_block(body),
                    };
                }

                let mut pre_loop_stmts: Vec<Stmt> = Vec::new();
                // Maps each invariant's key to the temporary that now holds it.
                let mut substitution_map: HashMap<String, String> = HashMap::new();

                for (original_key, invariant_expr) in &invariant_exprs {
                    // Shares `cse_counter` with CSE so generated names never collide.
                    let var_name = format!("__licm_{}", self.cse_counter);
                    self.cse_counter += 1;

                    pre_loop_stmts.push(make_cse_let(&var_name, invariant_expr.clone()));
                    substitution_map.insert(original_key.clone(), var_name);
                    self.stats.loops_optimized += 1;
                }

                let new_body =
                    self.replace_invariants_in_block(body, &invariant_exprs, &substitution_map);

                // The condition is recursed into but not rewritten to use the temporaries.
                let new_while = Expr::While {
                    label: label.clone(),
                    condition: Box::new(self.pass_licm_expr(condition)),
                    body: self.pass_licm_block(&new_body),
                };

                pre_loop_stmts.push(Stmt::Expr(new_while));
                Expr::Block(Block {
                    stmts: pre_loop_stmts,
                    expr: None,
                })
            }
            Expr::If {
                condition,
                then_branch,
                else_branch,
            } => Expr::If {
                condition: Box::new(self.pass_licm_expr(condition)),
                then_branch: self.pass_licm_block(then_branch),
                else_branch: else_branch
                    .as_ref()
                    .map(|e| Box::new(self.pass_licm_expr(e))),
            },
            Expr::Block(b) => Expr::Block(self.pass_licm_block(b)),
            Expr::Binary { op, left, right } => Expr::Binary {
                op: *op,
                left: Box::new(self.pass_licm_expr(left)),
                right: Box::new(self.pass_licm_expr(right)),
            },
            Expr::Unary { op, expr: inner } => Expr::Unary {
                op: *op,
                expr: Box::new(self.pass_licm_expr(inner)),
            },
            Expr::Call { func, args } => Expr::Call {
                func: func.clone(),
                args: args.iter().map(|a| self.pass_licm_expr(a)).collect(),
            },
            Expr::Return(e) => Expr::Return(e.as_ref().map(|e| Box::new(self.pass_licm_expr(e)))),
            Expr::Assign { target, value } => Expr::Assign {
                target: target.clone(),
                value: Box::new(self.pass_licm_expr(value)),
            },
            Expr::Unsafe(inner) => Expr::Unsafe(self.pass_licm_block(inner)),
            other => other.clone(),
        }
    }
2542
2543 fn collect_modified_vars_block(&self, block: &Block, modified: &mut HashSet<String>) {
2545 for stmt in &block.stmts {
2546 self.collect_modified_vars_stmt(stmt, modified);
2547 }
2548 if let Some(expr) = &block.expr {
2549 self.collect_modified_vars_expr(expr, modified);
2550 }
2551 }
2552
2553 fn collect_modified_vars_stmt(&self, stmt: &Stmt, modified: &mut HashSet<String>) {
2554 match stmt {
2555 Stmt::Let { pattern, init, .. } => {
2556 if let Pattern::Ident { name, .. } = pattern {
2558 modified.insert(name.name.clone());
2559 }
2560 if let Some(e) = init {
2561 self.collect_modified_vars_expr(e, modified);
2562 }
2563 }
2564 Stmt::Expr(e) | Stmt::Semi(e) => self.collect_modified_vars_expr(e, modified),
2565 _ => {}
2566 }
2567 }
2568
2569 fn collect_modified_vars_expr(&self, expr: &Expr, modified: &mut HashSet<String>) {
2570 match expr {
2571 Expr::Assign { target, value } => {
2572 if let Expr::Path(path) = target.as_ref() {
2573 if path.segments.len() == 1 {
2574 modified.insert(path.segments[0].ident.name.clone());
2575 }
2576 }
2577 self.collect_modified_vars_expr(value, modified);
2578 }
2579 Expr::Binary { left, right, .. } => {
2580 self.collect_modified_vars_expr(left, modified);
2581 self.collect_modified_vars_expr(right, modified);
2582 }
2583 Expr::Unary { expr: inner, .. } => {
2584 self.collect_modified_vars_expr(inner, modified);
2585 }
2586 Expr::If {
2587 condition,
2588 then_branch,
2589 else_branch,
2590 } => {
2591 self.collect_modified_vars_expr(condition, modified);
2592 self.collect_modified_vars_block(then_branch, modified);
2593 if let Some(e) = else_branch {
2594 self.collect_modified_vars_expr(e, modified);
2595 }
2596 }
2597 Expr::While { label, condition, body } => {
2598 self.collect_modified_vars_expr(condition, modified);
2599 self.collect_modified_vars_block(body, modified);
2600 }
2601 Expr::Block(b) => self.collect_modified_vars_block(b, modified),
2602 Expr::Call { args, .. } => {
2603 for arg in args {
2604 self.collect_modified_vars_expr(arg, modified);
2605 }
2606 }
2607 Expr::Return(Some(e)) => self.collect_modified_vars_expr(e, modified),
2608 Expr::Unsafe(inner) => self.collect_modified_vars_block(inner, modified),
2609 Expr::InlineAsm(asm) => {
2610 for operand in &asm.outputs {
2612 if let Expr::Path(path) = &operand.expr {
2613 if path.segments.len() == 1 {
2614 modified.insert(path.segments[0].ident.name.clone());
2615 }
2616 }
2617 }
2618 }
2619 _ => {}
2620 }
2621 }
2622
2623 fn find_loop_invariants(
2625 &self,
2626 block: &Block,
2627 modified: &HashSet<String>,
2628 out: &mut Vec<(String, Expr)>,
2629 ) {
2630 for stmt in &block.stmts {
2631 self.find_loop_invariants_stmt(stmt, modified, out);
2632 }
2633 if let Some(expr) = &block.expr {
2634 self.find_loop_invariants_expr(expr, modified, out);
2635 }
2636 }
2637
2638 fn find_loop_invariants_stmt(
2639 &self,
2640 stmt: &Stmt,
2641 modified: &HashSet<String>,
2642 out: &mut Vec<(String, Expr)>,
2643 ) {
2644 match stmt {
2645 Stmt::Let { init: Some(e), .. } => self.find_loop_invariants_expr(e, modified, out),
2646 Stmt::Expr(e) | Stmt::Semi(e) => self.find_loop_invariants_expr(e, modified, out),
2647 _ => {}
2648 }
2649 }
2650
2651 fn find_loop_invariants_expr(
2652 &self,
2653 expr: &Expr,
2654 modified: &HashSet<String>,
2655 out: &mut Vec<(String, Expr)>,
2656 ) {
2657 match expr {
2659 Expr::Binary { left, right, .. } => {
2660 self.find_loop_invariants_expr(left, modified, out);
2661 self.find_loop_invariants_expr(right, modified, out);
2662 }
2663 Expr::Unary { expr: inner, .. } => {
2664 self.find_loop_invariants_expr(inner, modified, out);
2665 }
2666 Expr::Call { args, .. } => {
2667 for arg in args {
2668 self.find_loop_invariants_expr(arg, modified, out);
2669 }
2670 }
2671 Expr::Index { expr: e, index } => {
2672 self.find_loop_invariants_expr(e, modified, out);
2673 self.find_loop_invariants_expr(index, modified, out);
2674 }
2675 _ => {}
2676 }
2677
2678 if self.is_loop_invariant(expr, modified) && is_cse_worthy(expr) && is_pure_expr(expr) {
2680 let key = format!("{:?}", expr_hash(expr));
2681 if !out.iter().any(|(k, _)| k == &key) {
2683 out.push((key, expr.clone()));
2684 }
2685 }
2686 }
2687
2688 fn is_loop_invariant(&self, expr: &Expr, modified: &HashSet<String>) -> bool {
2690 match expr {
2691 Expr::Literal(_) => true,
2692 Expr::Path(path) => {
2693 if path.segments.len() == 1 {
2694 !modified.contains(&path.segments[0].ident.name)
2695 } else {
2696 true }
2698 }
2699 Expr::Binary { left, right, .. } => {
2700 self.is_loop_invariant(left, modified) && self.is_loop_invariant(right, modified)
2701 }
2702 Expr::Unary { expr: inner, .. } => self.is_loop_invariant(inner, modified),
2703 Expr::Index { expr: e, index } => {
2704 self.is_loop_invariant(e, modified) && self.is_loop_invariant(index, modified)
2705 }
2706 Expr::Call { .. } => false,
2708 _ => false,
2710 }
2711 }
2712
2713 fn replace_invariants_in_block(
2715 &self,
2716 block: &Block,
2717 invariants: &[(String, Expr)],
2718 subs: &HashMap<String, String>,
2719 ) -> Block {
2720 let stmts = block
2721 .stmts
2722 .iter()
2723 .map(|s| self.replace_invariants_in_stmt(s, invariants, subs))
2724 .collect();
2725 let expr = block
2726 .expr
2727 .as_ref()
2728 .map(|e| Box::new(self.replace_invariants_in_expr(e, invariants, subs)));
2729 Block { stmts, expr }
2730 }
2731
2732 fn replace_invariants_in_stmt(
2733 &self,
2734 stmt: &Stmt,
2735 invariants: &[(String, Expr)],
2736 subs: &HashMap<String, String>,
2737 ) -> Stmt {
2738 match stmt {
2739 Stmt::Let { pattern, ty, init } => Stmt::Let {
2740 pattern: pattern.clone(),
2741 ty: ty.clone(),
2742 init: init
2743 .as_ref()
2744 .map(|e| self.replace_invariants_in_expr(e, invariants, subs)),
2745 },
2746 Stmt::LetElse { pattern, ty, init, else_branch } => Stmt::LetElse {
2747 pattern: pattern.clone(),
2748 ty: ty.clone(),
2749 init: self.replace_invariants_in_expr(init, invariants, subs),
2750 else_branch: Box::new(self.replace_invariants_in_expr(else_branch, invariants, subs)),
2751 },
2752 Stmt::Expr(e) => Stmt::Expr(self.replace_invariants_in_expr(e, invariants, subs)),
2753 Stmt::Semi(e) => Stmt::Semi(self.replace_invariants_in_expr(e, invariants, subs)),
2754 Stmt::Item(item) => Stmt::Item(item.clone()),
2755 }
2756 }
2757
2758 fn replace_invariants_in_expr(
2759 &self,
2760 expr: &Expr,
2761 invariants: &[(String, Expr)],
2762 subs: &HashMap<String, String>,
2763 ) -> Expr {
2764 let key = format!("{:?}", expr_hash(expr));
2766 for (inv_key, inv_expr) in invariants {
2767 if &key == inv_key && expr_eq(expr, inv_expr) {
2768 if let Some(var_name) = subs.get(inv_key) {
2769 return Expr::Path(TypePath {
2770 segments: vec![PathSegment {
2771 ident: Ident {
2772 name: var_name.clone(),
2773 evidentiality: None,
2774 affect: None,
2775 span: Span { start: 0, end: 0 },
2776 },
2777 generics: None,
2778 }],
2779 });
2780 }
2781 }
2782 }
2783
2784 match expr {
2786 Expr::Binary { op, left, right } => Expr::Binary {
2787 op: *op,
2788 left: Box::new(self.replace_invariants_in_expr(left, invariants, subs)),
2789 right: Box::new(self.replace_invariants_in_expr(right, invariants, subs)),
2790 },
2791 Expr::Unary { op, expr: inner } => Expr::Unary {
2792 op: *op,
2793 expr: Box::new(self.replace_invariants_in_expr(inner, invariants, subs)),
2794 },
2795 Expr::Call { func, args } => Expr::Call {
2796 func: func.clone(),
2797 args: args
2798 .iter()
2799 .map(|a| self.replace_invariants_in_expr(a, invariants, subs))
2800 .collect(),
2801 },
2802 Expr::If {
2803 condition,
2804 then_branch,
2805 else_branch,
2806 } => Expr::If {
2807 condition: Box::new(self.replace_invariants_in_expr(condition, invariants, subs)),
2808 then_branch: self.replace_invariants_in_block(then_branch, invariants, subs),
2809 else_branch: else_branch
2810 .as_ref()
2811 .map(|e| Box::new(self.replace_invariants_in_expr(e, invariants, subs))),
2812 },
2813 Expr::While { label, condition, body } => Expr::While {
2814 label: label.clone(),
2815 condition: Box::new(self.replace_invariants_in_expr(condition, invariants, subs)),
2816 body: self.replace_invariants_in_block(body, invariants, subs),
2817 },
2818 Expr::Block(b) => Expr::Block(self.replace_invariants_in_block(b, invariants, subs)),
2819 Expr::Return(e) => Expr::Return(
2820 e.as_ref()
2821 .map(|e| Box::new(self.replace_invariants_in_expr(e, invariants, subs))),
2822 ),
2823 Expr::Assign { target, value } => Expr::Assign {
2824 target: target.clone(),
2825 value: Box::new(self.replace_invariants_in_expr(value, invariants, subs)),
2826 },
2827 Expr::Index { expr: e, index } => Expr::Index {
2828 expr: Box::new(self.replace_invariants_in_expr(e, invariants, subs)),
2829 index: Box::new(self.replace_invariants_in_expr(index, invariants, subs)),
2830 },
2831 other => other.clone(),
2832 }
2833 }
2834
2835 fn pass_cse_block(&mut self, block: &Block) -> Block {
2840 if self.contains_inline_asm_in_block(block) {
2843 return self.pass_cse_nested(block);
2844 }
2845
2846 let mut collected = Vec::new();
2848 collect_exprs_from_block(block, &mut collected);
2849
2850 let mut expr_counts: HashMap<u64, Vec<Expr>> = HashMap::new();
2852 for ce in &collected {
2853 let entry = expr_counts.entry(ce.hash).or_insert_with(Vec::new);
2854 let found = entry.iter().any(|e| expr_eq(e, &ce.expr));
2856 if !found {
2857 entry.push(ce.expr.clone());
2858 }
2859 }
2860
2861 let mut occurrence_counts: Vec<(Expr, usize)> = Vec::new();
2863 for ce in &collected {
2864 let existing = occurrence_counts
2866 .iter_mut()
2867 .find(|(e, _)| expr_eq(e, &ce.expr));
2868 if let Some((_, count)) = existing {
2869 *count += 1;
2870 } else {
2871 occurrence_counts.push((ce.expr.clone(), 1));
2872 }
2873 }
2874
2875 let mut declared_vars: HashSet<String> = HashSet::new();
2880 self.collect_declared_vars_in_block(block, &mut declared_vars);
2881
2882 let candidates: Vec<Expr> = occurrence_counts
2885 .into_iter()
2886 .filter(|(expr, count)| {
2887 if *count < 2 {
2888 return false;
2889 }
2890 !self.expr_references_vars(expr, &declared_vars)
2892 })
2893 .map(|(expr, _)| expr)
2894 .collect();
2895
2896 if candidates.is_empty() {
2897 return self.pass_cse_nested(block);
2899 }
2900
2901 let mut result_block = block.clone();
2903 let mut new_lets: Vec<Stmt> = Vec::new();
2904
2905 for expr in candidates {
2906 let var_name = format!("__cse_{}", self.cse_counter);
2907 self.cse_counter += 1;
2908
2909 new_lets.push(make_cse_let(&var_name, expr.clone()));
2911
2912 result_block = replace_in_block(&result_block, &expr, &var_name);
2914
2915 self.stats.expressions_deduplicated += 1;
2916 }
2917
2918 let mut final_stmts = new_lets;
2920 final_stmts.extend(result_block.stmts);
2921
2922 let result = Block {
2924 stmts: final_stmts,
2925 expr: result_block.expr,
2926 };
2927 self.pass_cse_nested(&result)
2928 }
2929
2930 fn collect_declared_vars_in_block(&self, block: &Block, vars: &mut HashSet<String>) {
2932 for stmt in &block.stmts {
2933 match stmt {
2934 Stmt::Let { pattern, init, .. } => {
2935 if let Pattern::Ident { name, .. } = pattern {
2937 vars.insert(name.name.clone());
2938 }
2939 if let Some(init_expr) = init {
2941 self.collect_declared_vars_in_expr(init_expr, vars);
2942 }
2943 }
2944 Stmt::LetElse { pattern, init, else_branch, .. } => {
2945 if let Pattern::Ident { name, .. } = pattern {
2946 vars.insert(name.name.clone());
2947 }
2948 self.collect_declared_vars_in_expr(init, vars);
2949 self.collect_declared_vars_in_expr(else_branch, vars);
2950 }
2951 Stmt::Expr(expr) | Stmt::Semi(expr) => {
2952 self.collect_declared_vars_in_expr(expr, vars);
2953 }
2954 Stmt::Item(_) => {}
2955 }
2956 }
2957 if let Some(expr) = &block.expr {
2958 self.collect_declared_vars_in_expr(expr, vars);
2959 }
2960 }
2961
2962 fn collect_declared_vars_in_expr(&self, expr: &Expr, vars: &mut HashSet<String>) {
2964 match expr {
2965 Expr::Block(inner_block) | Expr::Unsafe(inner_block) => {
2966 self.collect_declared_vars_in_block(inner_block, vars);
2967 }
2968 Expr::If { condition, then_branch, else_branch } => {
2969 self.collect_declared_vars_in_expr(condition, vars);
2970 self.collect_declared_vars_in_block(then_branch, vars);
2971 if let Some(else_expr) = else_branch {
2972 self.collect_declared_vars_in_expr(else_expr, vars);
2973 }
2974 }
2975 Expr::While { condition, body, .. } => {
2976 self.collect_declared_vars_in_expr(condition, vars);
2977 self.collect_declared_vars_in_block(body, vars);
2978 }
2979 Expr::Loop { body, .. } => {
2980 self.collect_declared_vars_in_block(body, vars);
2981 }
2982 Expr::For { pattern, iter, body, .. } => {
2983 if let Pattern::Ident { name, .. } = pattern {
2984 vars.insert(name.name.clone());
2985 }
2986 self.collect_declared_vars_in_expr(iter, vars);
2987 self.collect_declared_vars_in_block(body, vars);
2988 }
2989 Expr::Match { expr: match_expr, arms } => {
2990 self.collect_declared_vars_in_expr(match_expr, vars);
2991 for arm in arms {
2992 if let Some(guard) = &arm.guard {
2993 self.collect_declared_vars_in_expr(guard, vars);
2994 }
2995 self.collect_declared_vars_in_expr(&arm.body, vars);
2996 }
2997 }
2998 Expr::Binary { left, right, .. } => {
2999 self.collect_declared_vars_in_expr(left, vars);
3000 self.collect_declared_vars_in_expr(right, vars);
3001 }
3002 Expr::Unary { expr: inner, .. } => {
3003 self.collect_declared_vars_in_expr(inner, vars);
3004 }
3005 Expr::Call { func, args } => {
3006 self.collect_declared_vars_in_expr(func, vars);
3007 for arg in args {
3008 self.collect_declared_vars_in_expr(arg, vars);
3009 }
3010 }
3011 Expr::Assign { target, value } => {
3012 self.collect_declared_vars_in_expr(target, vars);
3013 self.collect_declared_vars_in_expr(value, vars);
3014 }
3015 Expr::Return(Some(inner)) => {
3016 self.collect_declared_vars_in_expr(inner, vars);
3017 }
3018 Expr::Closure { body, .. } => {
3019 self.collect_declared_vars_in_expr(body, vars);
3020 }
3021 _ => {}
3022 }
3023 }
3024
3025 fn expr_references_vars(&self, expr: &Expr, vars: &HashSet<String>) -> bool {
3027 match expr {
3028 Expr::Path(path) => {
3029 if path.segments.len() == 1 {
3030 vars.contains(&path.segments[0].ident.name)
3031 } else {
3032 false
3033 }
3034 }
3035 Expr::Binary { left, right, .. } => {
3036 self.expr_references_vars(left, vars) || self.expr_references_vars(right, vars)
3037 }
3038 Expr::Unary { expr: inner, .. } => self.expr_references_vars(inner, vars),
3039 Expr::Call { func, args } => {
3040 self.expr_references_vars(func, vars)
3041 || args.iter().any(|a| self.expr_references_vars(a, vars))
3042 }
3043 Expr::Index { expr: e, index } => {
3044 self.expr_references_vars(e, vars) || self.expr_references_vars(index, vars)
3045 }
3046 Expr::Field { expr: e, .. } => self.expr_references_vars(e, vars),
3047 Expr::MethodCall { receiver, args, .. } => {
3048 self.expr_references_vars(receiver, vars)
3049 || args.iter().any(|a| self.expr_references_vars(a, vars))
3050 }
3051 Expr::Cast { expr: e, .. } => self.expr_references_vars(e, vars),
3052 Expr::Deref(e) => self.expr_references_vars(e, vars),
3053 Expr::Tuple(elems) | Expr::Array(elems) => {
3054 elems.iter().any(|e| self.expr_references_vars(e, vars))
3055 }
3056 _ => false,
3058 }
3059 }
3060
3061 fn pass_cse_nested(&mut self, block: &Block) -> Block {
3063 let stmts = block
3064 .stmts
3065 .iter()
3066 .map(|stmt| self.pass_cse_stmt(stmt))
3067 .collect();
3068 let expr = block.expr.as_ref().map(|e| Box::new(self.pass_cse_expr(e)));
3069 Block { stmts, expr }
3070 }
3071
3072 fn pass_cse_stmt(&mut self, stmt: &Stmt) -> Stmt {
3073 match stmt {
3074 Stmt::Let { pattern, ty, init } => Stmt::Let {
3075 pattern: pattern.clone(),
3076 ty: ty.clone(),
3077 init: init.as_ref().map(|e| self.pass_cse_expr(e)),
3078 },
3079 Stmt::LetElse { pattern, ty, init, else_branch } => Stmt::LetElse {
3080 pattern: pattern.clone(),
3081 ty: ty.clone(),
3082 init: self.pass_cse_expr(init),
3083 else_branch: Box::new(self.pass_cse_expr(else_branch)),
3084 },
3085 Stmt::Expr(e) => Stmt::Expr(self.pass_cse_expr(e)),
3086 Stmt::Semi(e) => Stmt::Semi(self.pass_cse_expr(e)),
3087 Stmt::Item(item) => Stmt::Item(item.clone()),
3088 }
3089 }
3090
3091 fn pass_cse_expr(&mut self, expr: &Expr) -> Expr {
3092 match expr {
3093 Expr::If {
3094 condition,
3095 then_branch,
3096 else_branch,
3097 } => Expr::If {
3098 condition: Box::new(self.pass_cse_expr(condition)),
3099 then_branch: self.pass_cse_block(then_branch),
3100 else_branch: else_branch
3101 .as_ref()
3102 .map(|e| Box::new(self.pass_cse_expr(e))),
3103 },
3104 Expr::While { label, condition, body } => Expr::While {
3105 label: label.clone(),
3106 condition: Box::new(self.pass_cse_expr(condition)),
3107 body: self.pass_cse_block(body),
3108 },
3109 Expr::Block(b) => Expr::Block(self.pass_cse_block(b)),
3110 Expr::Binary { op, left, right } => Expr::Binary {
3111 op: *op,
3112 left: Box::new(self.pass_cse_expr(left)),
3113 right: Box::new(self.pass_cse_expr(right)),
3114 },
3115 Expr::Unary { op, expr: inner } => Expr::Unary {
3116 op: *op,
3117 expr: Box::new(self.pass_cse_expr(inner)),
3118 },
3119 Expr::Call { func, args } => Expr::Call {
3120 func: func.clone(),
3121 args: args.iter().map(|a| self.pass_cse_expr(a)).collect(),
3122 },
3123 Expr::Return(e) => Expr::Return(e.as_ref().map(|e| Box::new(self.pass_cse_expr(e)))),
3124 Expr::Assign { target, value } => Expr::Assign {
3125 target: target.clone(),
3126 value: Box::new(self.pass_cse_expr(value)),
3127 },
3128 other => other.clone(),
3129 }
3130 }
3131}
3132
3133fn expr_hash(expr: &Expr) -> u64 {
3139 use std::collections::hash_map::DefaultHasher;
3140 use std::hash::Hasher;
3141
3142 let mut hasher = DefaultHasher::new();
3143 expr_hash_recursive(expr, &mut hasher);
3144 hasher.finish()
3145}
3146
3147fn expr_hash_recursive<H: std::hash::Hasher>(expr: &Expr, hasher: &mut H) {
3148 use std::hash::Hash;
3149
3150 std::mem::discriminant(expr).hash(hasher);
3151
3152 match expr {
3153 Expr::Literal(lit) => match lit {
3154 Literal::Int { value, .. } => value.hash(hasher),
3155 Literal::Float { value, .. } => value.hash(hasher),
3156 Literal::String(s) => s.hash(hasher),
3157 Literal::Char(c) => c.hash(hasher),
3158 Literal::Bool(b) => b.hash(hasher),
3159 _ => {}
3160 },
3161 Expr::Path(path) => {
3162 for seg in &path.segments {
3163 seg.ident.name.hash(hasher);
3164 }
3165 }
3166 Expr::Binary { op, left, right } => {
3167 std::mem::discriminant(op).hash(hasher);
3168 expr_hash_recursive(left, hasher);
3169 expr_hash_recursive(right, hasher);
3170 }
3171 Expr::Unary { op, expr } => {
3172 std::mem::discriminant(op).hash(hasher);
3173 expr_hash_recursive(expr, hasher);
3174 }
3175 Expr::Call { func, args } => {
3176 expr_hash_recursive(func, hasher);
3177 args.len().hash(hasher);
3178 for arg in args {
3179 expr_hash_recursive(arg, hasher);
3180 }
3181 }
3182 Expr::Index { expr, index } => {
3183 expr_hash_recursive(expr, hasher);
3184 expr_hash_recursive(index, hasher);
3185 }
3186 _ => {}
3187 }
3188}
3189
3190fn is_pure_expr(expr: &Expr) -> bool {
3192 match expr {
3193 Expr::Literal(_) => true,
3194 Expr::Path(_) => true,
3195 Expr::Binary { left, right, .. } => is_pure_expr(left) && is_pure_expr(right),
3196 Expr::Unary { expr, .. } => is_pure_expr(expr),
3197 Expr::If {
3198 condition,
3199 then_branch,
3200 else_branch,
3201 } => {
3202 is_pure_expr(condition)
3203 && then_branch.stmts.is_empty()
3204 && then_branch
3205 .expr
3206 .as_ref()
3207 .map(|e| is_pure_expr(e))
3208 .unwrap_or(true)
3209 && else_branch
3210 .as_ref()
3211 .map(|e| is_pure_expr(e))
3212 .unwrap_or(true)
3213 }
3214 Expr::Index { expr, index } => is_pure_expr(expr) && is_pure_expr(index),
3215 Expr::Array(elements) => elements.iter().all(is_pure_expr),
3216 Expr::Call { .. } => false,
3218 Expr::Assign { .. } => false,
3219 Expr::Return(_) => false,
3220 _ => false,
3221 }
3222}
3223
3224fn is_cse_worthy(expr: &Expr) -> bool {
3226 match expr {
3227 Expr::Literal(_) => false,
3229 Expr::Path(_) => false,
3230 Expr::Binary { .. } => true,
3232 Expr::Unary { .. } => true,
3234 Expr::Call { .. } => false,
3236 Expr::Index { .. } => true,
3238 _ => false,
3239 }
3240}
3241
3242fn expr_eq(a: &Expr, b: &Expr) -> bool {
3244 match (a, b) {
3245 (Expr::Literal(la), Expr::Literal(lb)) => match (la, lb) {
3246 (Literal::Int { value: va, .. }, Literal::Int { value: vb, .. }) => va == vb,
3247 (Literal::Float { value: va, .. }, Literal::Float { value: vb, .. }) => va == vb,
3248 (Literal::String(sa), Literal::String(sb)) => sa == sb,
3249 (Literal::Char(ca), Literal::Char(cb)) => ca == cb,
3250 (Literal::Bool(ba), Literal::Bool(bb)) => ba == bb,
3251 _ => false,
3252 },
3253 (Expr::Path(pa), Expr::Path(pb)) => {
3254 pa.segments.len() == pb.segments.len()
3255 && pa
3256 .segments
3257 .iter()
3258 .zip(&pb.segments)
3259 .all(|(sa, sb)| sa.ident.name == sb.ident.name)
3260 }
3261 (
3262 Expr::Binary {
3263 op: oa,
3264 left: la,
3265 right: ra,
3266 },
3267 Expr::Binary {
3268 op: ob,
3269 left: lb,
3270 right: rb,
3271 },
3272 ) => oa == ob && expr_eq(la, lb) && expr_eq(ra, rb),
3273 (Expr::Unary { op: oa, expr: ea }, Expr::Unary { op: ob, expr: eb }) => {
3274 oa == ob && expr_eq(ea, eb)
3275 }
3276 (
3277 Expr::Index {
3278 expr: ea,
3279 index: ia,
3280 },
3281 Expr::Index {
3282 expr: eb,
3283 index: ib,
3284 },
3285 ) => expr_eq(ea, eb) && expr_eq(ia, ib),
3286 (Expr::Call { func: fa, args: aa }, Expr::Call { func: fb, args: ab }) => {
3287 expr_eq(fa, fb) && aa.len() == ab.len() && aa.iter().zip(ab).all(|(a, b)| expr_eq(a, b))
3288 }
3289 _ => false,
3290 }
3291}
3292
3293#[derive(Clone)]
3295struct CollectedExpr {
3296 expr: Expr,
3297 hash: u64,
3298}
3299
3300fn collect_exprs_from_expr(expr: &Expr, out: &mut Vec<CollectedExpr>) {
3302 match expr {
3304 Expr::Binary { left, right, .. } => {
3305 collect_exprs_from_expr(left, out);
3306 collect_exprs_from_expr(right, out);
3307 }
3308 Expr::Unary { expr: inner, .. } => {
3309 collect_exprs_from_expr(inner, out);
3310 }
3311 Expr::Index { expr: e, index } => {
3312 collect_exprs_from_expr(e, out);
3313 collect_exprs_from_expr(index, out);
3314 }
3315 Expr::Call { func, args } => {
3316 collect_exprs_from_expr(func, out);
3317 for arg in args {
3318 collect_exprs_from_expr(arg, out);
3319 }
3320 }
3321 Expr::If {
3322 condition,
3323 then_branch,
3324 else_branch,
3325 } => {
3326 collect_exprs_from_expr(condition, out);
3327 collect_exprs_from_block(then_branch, out);
3328 if let Some(else_expr) = else_branch {
3329 collect_exprs_from_expr(else_expr, out);
3330 }
3331 }
3332 Expr::While { label, condition, body } => {
3333 collect_exprs_from_expr(condition, out);
3334 collect_exprs_from_block(body, out);
3335 }
3336 Expr::Block(block) => {
3337 collect_exprs_from_block(block, out);
3338 }
3339 Expr::Return(Some(e)) => {
3340 collect_exprs_from_expr(e, out);
3341 }
3342 Expr::Assign { value, .. } => {
3343 collect_exprs_from_expr(value, out);
3344 }
3345 Expr::Array(elements) => {
3346 for e in elements {
3347 collect_exprs_from_expr(e, out);
3348 }
3349 }
3350 _ => {}
3351 }
3352
3353 if is_cse_worthy(expr) && is_pure_expr(expr) {
3355 out.push(CollectedExpr {
3356 expr: expr.clone(),
3357 hash: expr_hash(expr),
3358 });
3359 }
3360}
3361
3362fn collect_exprs_from_block(block: &Block, out: &mut Vec<CollectedExpr>) {
3364 for stmt in &block.stmts {
3365 match stmt {
3366 Stmt::Let { init: Some(e), .. } => collect_exprs_from_expr(e, out),
3367 Stmt::Expr(e) | Stmt::Semi(e) => collect_exprs_from_expr(e, out),
3368 _ => {}
3369 }
3370 }
3371 if let Some(e) = &block.expr {
3372 collect_exprs_from_expr(e, out);
3373 }
3374}
3375
3376fn replace_in_expr(expr: &Expr, target: &Expr, var_name: &str) -> Expr {
3378 if expr_eq(expr, target) {
3380 return Expr::Path(TypePath {
3381 segments: vec![PathSegment {
3382 ident: Ident {
3383 name: var_name.to_string(),
3384 evidentiality: None,
3385 affect: None,
3386 span: Span { start: 0, end: 0 },
3387 },
3388 generics: None,
3389 }],
3390 });
3391 }
3392
3393 match expr {
3395 Expr::Binary { op, left, right } => Expr::Binary {
3396 op: *op,
3397 left: Box::new(replace_in_expr(left, target, var_name)),
3398 right: Box::new(replace_in_expr(right, target, var_name)),
3399 },
3400 Expr::Unary { op, expr: inner } => Expr::Unary {
3401 op: *op,
3402 expr: Box::new(replace_in_expr(inner, target, var_name)),
3403 },
3404 Expr::Index { expr: e, index } => Expr::Index {
3405 expr: Box::new(replace_in_expr(e, target, var_name)),
3406 index: Box::new(replace_in_expr(index, target, var_name)),
3407 },
3408 Expr::Call { func, args } => Expr::Call {
3409 func: Box::new(replace_in_expr(func, target, var_name)),
3410 args: args
3411 .iter()
3412 .map(|a| replace_in_expr(a, target, var_name))
3413 .collect(),
3414 },
3415 Expr::If {
3416 condition,
3417 then_branch,
3418 else_branch,
3419 } => Expr::If {
3420 condition: Box::new(replace_in_expr(condition, target, var_name)),
3421 then_branch: replace_in_block(then_branch, target, var_name),
3422 else_branch: else_branch
3423 .as_ref()
3424 .map(|e| Box::new(replace_in_expr(e, target, var_name))),
3425 },
3426 Expr::While { label, condition, body } => Expr::While {
3427 label: label.clone(),
3428 condition: Box::new(replace_in_expr(condition, target, var_name)),
3429 body: replace_in_block(body, target, var_name),
3430 },
3431 Expr::Block(block) => Expr::Block(replace_in_block(block, target, var_name)),
3432 Expr::Return(e) => Expr::Return(
3433 e.as_ref()
3434 .map(|e| Box::new(replace_in_expr(e, target, var_name))),
3435 ),
3436 Expr::Assign { target: t, value } => Expr::Assign {
3437 target: t.clone(),
3438 value: Box::new(replace_in_expr(value, target, var_name)),
3439 },
3440 Expr::Array(elements) => Expr::Array(
3441 elements
3442 .iter()
3443 .map(|e| replace_in_expr(e, target, var_name))
3444 .collect(),
3445 ),
3446 other => other.clone(),
3447 }
3448}
3449
3450fn replace_in_block(block: &Block, target: &Expr, var_name: &str) -> Block {
3452 let stmts = block
3453 .stmts
3454 .iter()
3455 .map(|stmt| match stmt {
3456 Stmt::Let { pattern, ty, init } => Stmt::Let {
3457 pattern: pattern.clone(),
3458 ty: ty.clone(),
3459 init: init.as_ref().map(|e| replace_in_expr(e, target, var_name)),
3460 },
3461 Stmt::LetElse { pattern, ty, init, else_branch } => Stmt::LetElse {
3462 pattern: pattern.clone(),
3463 ty: ty.clone(),
3464 init: replace_in_expr(init, target, var_name),
3465 else_branch: Box::new(replace_in_expr(else_branch, target, var_name)),
3466 },
3467 Stmt::Expr(e) => Stmt::Expr(replace_in_expr(e, target, var_name)),
3468 Stmt::Semi(e) => Stmt::Semi(replace_in_expr(e, target, var_name)),
3469 Stmt::Item(item) => Stmt::Item(item.clone()),
3470 })
3471 .collect();
3472
3473 let expr = block
3474 .expr
3475 .as_ref()
3476 .map(|e| Box::new(replace_in_expr(e, target, var_name)));
3477
3478 Block { stmts, expr }
3479}
3480
3481fn make_cse_let(var_name: &str, expr: Expr) -> Stmt {
3483 Stmt::Let {
3484 pattern: Pattern::Ident {
3485 mutable: false,
3486 name: Ident {
3487 name: var_name.to_string(),
3488 evidentiality: None,
3489 affect: None,
3490 span: Span { start: 0, end: 0 },
3491 },
3492 evidentiality: None,
3493 },
3494 ty: None,
3495 init: Some(expr),
3496 }
3497}
3498
3499pub fn optimize(file: &ast::SourceFile, level: OptLevel) -> (ast::SourceFile, OptStats) {
3505 let mut optimizer = Optimizer::new(level);
3506 let optimized = optimizer.optimize_file(file);
3507 (optimized, optimizer.stats)
3508}
3509
#[cfg(test)]
mod tests {
    use super::*;

    /// Identifier with the same zero-width span the optimizer synthesizes.
    fn ident(name: &str) -> Ident {
        Ident {
            name: name.to_string(),
            evidentiality: None,
            affect: None,
            span: Span { start: 0, end: 0 },
        }
    }

    fn int_lit(v: i64) -> Expr {
        Expr::Literal(Literal::Int {
            value: v.to_string(),
            base: NumBase::Decimal,
            suffix: None,
        })
    }

    fn var(name: &str) -> Expr {
        Expr::Path(TypePath {
            segments: vec![PathSegment {
                ident: ident(name),
                generics: None,
            }],
        })
    }

    fn binary(op: BinOp, left: Expr, right: Expr) -> Expr {
        Expr::Binary {
            op,
            left: Box::new(left),
            right: Box::new(right),
        }
    }

    fn add(left: Expr, right: Expr) -> Expr {
        binary(BinOp::Add, left, right)
    }

    fn mul(left: Expr, right: Expr) -> Expr {
        binary(BinOp::Mul, left, right)
    }

    /// `let <name> = <init>;` with an immutable identifier pattern.
    fn let_stmt(name: &str, init: Expr) -> Stmt {
        Stmt::Let {
            pattern: Pattern::Ident {
                mutable: false,
                name: ident(name),
                evidentiality: None,
            },
            ty: None,
            init: Some(init),
        }
    }

    #[test]
    fn test_expr_hash_equal() {
        let first = add(var("a"), var("b"));
        let second = add(var("a"), var("b"));
        assert_eq!(expr_hash(&first), expr_hash(&second));
    }

    #[test]
    fn test_expr_hash_different() {
        let first = add(var("a"), var("b"));
        let second = add(var("a"), var("c"));
        assert_ne!(expr_hash(&first), expr_hash(&second));
    }

    #[test]
    fn test_expr_eq() {
        let ab = add(var("a"), var("b"));
        let ab_again = add(var("a"), var("b"));
        let ac = add(var("a"), var("c"));

        assert!(expr_eq(&ab, &ab_again));
        assert!(!expr_eq(&ab, &ac));
    }

    #[test]
    fn test_is_pure_expr() {
        assert!(is_pure_expr(&int_lit(42)));
        assert!(is_pure_expr(&var("x")));
        assert!(is_pure_expr(&add(var("a"), var("b"))));

        let call = Expr::Call {
            func: Box::new(var("print")),
            args: vec![int_lit(42)],
        };
        assert!(!is_pure_expr(&call));
    }

    #[test]
    fn test_is_cse_worthy() {
        assert!(!is_cse_worthy(&int_lit(42)));
        assert!(!is_cse_worthy(&var("x")));
        assert!(is_cse_worthy(&add(var("a"), var("b"))));
    }

    #[test]
    fn test_cse_basic() {
        // x = a + b; y = (a + b) * 2  =>  `a + b` is hoisted once.
        let a_plus_b = add(var("a"), var("b"));
        let block = Block {
            stmts: vec![
                let_stmt("x", a_plus_b.clone()),
                let_stmt("y", mul(a_plus_b, int_lit(2))),
            ],
            expr: None,
        };

        let mut optimizer = Optimizer::new(OptLevel::Standard);
        let result = optimizer.pass_cse_block(&block);

        // One hoisted `let __cse_0 = a + b;` followed by the two originals.
        assert_eq!(result.stmts.len(), 3);
        assert_eq!(optimizer.stats.expressions_deduplicated, 1);

        match &result.stmts[0] {
            Stmt::Let {
                pattern: Pattern::Ident { name, .. },
                ..
            } => assert_eq!(name.name, "__cse_0"),
            _ => panic!("Expected CSE let binding"),
        }
    }

    #[test]
    fn test_cse_no_duplicates() {
        let block = Block {
            stmts: vec![
                let_stmt("x", add(var("a"), var("b"))),
                let_stmt("y", add(var("c"), var("d"))),
            ],
            expr: None,
        };

        let mut optimizer = Optimizer::new(OptLevel::Standard);
        let result = optimizer.pass_cse_block(&block);

        // Nothing repeats, so the block is left as-is.
        assert_eq!(result.stmts.len(), 2);
        assert_eq!(optimizer.stats.expressions_deduplicated, 0);
    }
}