1use crate::{
2 create_arg_demotion_pass, create_arg_pointee_mutability_tagger_pass, create_ccp_pass,
3 create_const_demotion_pass, create_const_folding_pass, create_cse_pass, create_dce_pass,
4 create_dom_fronts_pass, create_dominators_pass, create_escaped_symbols_pass,
5 create_fn_dedup_debug_profile_pass, create_fn_dedup_release_profile_pass,
6 create_fn_inline_pass, create_globals_dce_pass, create_mem2reg_pass, create_memcpyopt_pass,
7 create_memcpyprop_reverse_pass, create_misc_demotion_pass, create_module_printer_pass,
8 create_module_verifier_pass, create_postorder_pass, create_ret_demotion_pass,
9 create_simplify_cfg_pass, create_sroa_pass, Context, Function, IrError, Module,
10 ARG_DEMOTION_NAME, ARG_POINTEE_MUTABILITY_TAGGER_NAME, CCP_NAME, CONST_DEMOTION_NAME,
11 CONST_FOLDING_NAME, CSE_NAME, DCE_NAME, FN_DEDUP_DEBUG_PROFILE_NAME,
12 FN_DEDUP_RELEASE_PROFILE_NAME, FN_INLINE_NAME, GLOBALS_DCE_NAME, MEM2REG_NAME, MEMCPYOPT_NAME,
13 MEMCPYPROP_REVERSE_NAME, MISC_DEMOTION_NAME, RET_DEMOTION_NAME, SIMPLIFY_CFG_NAME, SROA_NAME,
14};
15use downcast_rs::{impl_downcast, Downcast};
16use rustc_hash::FxHashMap;
17use std::{
18 any::{type_name, TypeId},
19 collections::{hash_map, HashSet},
20};
21
/// Marker trait for values produced by analysis passes.
///
/// Results are stored type-erased as [`AnalysisResult`] and recovered with `downcast_ref`
/// (see `AnalysisResults::get_analysis_result`), which is why `Downcast` is a supertrait.
pub trait AnalysisResultT: Downcast {}
// Provides the downcasting helpers (e.g. `downcast_ref`) on `dyn AnalysisResultT`.
impl_downcast!(AnalysisResultT);
/// A boxed, type-erased analysis result as stored in `AnalysisResults`.
pub type AnalysisResult = Box<dyn AnalysisResultT>;
26
27pub trait PassScope {
29 fn get_arena_idx(&self) -> slotmap::DefaultKey;
30}
31impl PassScope for Module {
32 fn get_arena_idx(&self) -> slotmap::DefaultKey {
33 self.0
34 }
35}
36impl PassScope for Function {
37 fn get_arena_idx(&self) -> slotmap::DefaultKey {
38 self.0
39 }
40}
41
/// Whether a pass only inspects the IR or may change it, together with the function that
/// implements it for scope `S`.
#[derive(Clone)]
pub enum PassMutability<S: PassScope> {
    /// Reads the IR and produces an [`AnalysisResult`] for the given scope.
    Analysis(fn(&Context, analyses: &AnalysisResults, S) -> Result<AnalysisResult, IrError>),
    /// May mutate the IR. Returns `Ok(true)` when it modified something; the pass manager
    /// then invalidates cached analysis results for the affected scope.
    Transform(fn(&mut Context, analyses: &AnalysisResults, S) -> Result<bool, IrError>),
}
50
/// The granularity at which a pass runs: once per module, or once per function.
#[derive(Clone)]
pub enum ScopedPass {
    /// Runs once for each module in the context.
    ModulePass(PassMutability<Module>),
    /// Runs once for each function of each module in the context.
    FunctionPass(PassMutability<Function>),
}
57
/// A pass that can be registered with a [`PassManager`].
pub struct Pass {
    /// Unique name used to register, look up and schedule the pass.
    pub name: &'static str,
    /// Short human-readable description, shown in help text and IR print headers.
    pub descr: &'static str,
    /// Names of the analysis passes whose results must be computed before this pass runs.
    /// `PassManager::register` requires each of these to already be registered.
    pub deps: Vec<&'static str>,
    /// The implementation, tagged with its scope and mutability.
    pub runner: ScopedPass,
}
69
70impl Pass {
71 pub fn is_analysis(&self) -> bool {
72 match &self.runner {
73 ScopedPass::ModulePass(pm) => matches!(pm, PassMutability::Analysis(_)),
74 ScopedPass::FunctionPass(pm) => matches!(pm, PassMutability::Analysis(_)),
75 }
76 }
77
78 pub fn is_transform(&self) -> bool {
79 !self.is_analysis()
80 }
81
82 pub fn is_module_pass(&self) -> bool {
83 matches!(self.runner, ScopedPass::ModulePass(_))
84 }
85
86 pub fn is_function_pass(&self) -> bool {
87 matches!(self.runner, ScopedPass::FunctionPass(_))
88 }
89}
90
/// Cache of analysis results, keyed by result type and by the scope (module or function)
/// each result was computed for.
#[derive(Default)]
pub struct AnalysisResults {
    // (result TypeId, (scope TypeId, scope arena key)) -> result.
    results: FxHashMap<(TypeId, (TypeId, slotmap::DefaultKey)), AnalysisResult>,
    // Analysis pass name -> TypeId of the result it produced; filled in by `add_result`.
    name_typeid_map: FxHashMap<&'static str, TypeId>,
}
97
98impl AnalysisResults {
99 pub fn get_analysis_result<T: AnalysisResultT, S: PassScope + 'static>(&self, scope: S) -> &T {
102 self.results
103 .get(&(
104 TypeId::of::<T>(),
105 (TypeId::of::<S>(), scope.get_arena_idx()),
106 ))
107 .unwrap_or_else(|| {
108 panic!(
109 "Internal error. Analysis result {} unavailable for {} with idx {:?}",
110 type_name::<T>(),
111 type_name::<S>(),
112 scope.get_arena_idx()
113 )
114 })
115 .downcast_ref()
116 .expect("AnalysisResult: Incorrect type")
117 }
118
119 fn is_analysis_result_available<S: PassScope + 'static>(
121 &self,
122 name: &'static str,
123 scope: S,
124 ) -> bool {
125 self.name_typeid_map
126 .get(name)
127 .and_then(|result_typeid| {
128 self.results
129 .get(&(*result_typeid, (TypeId::of::<S>(), scope.get_arena_idx())))
130 })
131 .is_some()
132 }
133
134 fn add_result<S: PassScope + 'static>(
136 &mut self,
137 name: &'static str,
138 scope: S,
139 result: AnalysisResult,
140 ) {
141 let result_typeid = (*result).type_id();
142 self.results.insert(
143 (result_typeid, (TypeId::of::<S>(), scope.get_arena_idx())),
144 result,
145 );
146 self.name_typeid_map.insert(name, result_typeid);
147 }
148
149 fn invalidate_all_results_at_scope<S: PassScope + 'static>(&mut self, scope: S) {
151 self.results
152 .retain(|(_result_typeid, (scope_typeid, scope_idx)), _v| {
153 (*scope_typeid, *scope_idx) != (TypeId::of::<S>(), scope.get_arena_idx())
154 });
155 }
156}
157
/// Controls when `PassManager::run_with_print_verify` prints the IR.
#[derive(Debug)]
pub struct PrintPassesOpts {
    /// Print the IR before any pass runs.
    pub initial: bool,
    /// Print the IR after all passes have run.
    pub r#final: bool,
    /// After a pass listed in `passes`, print only if that pass reported a modification.
    pub modified_only: bool,
    /// Names of the passes after which the IR is printed.
    pub passes: HashSet<String>,
}
171
/// Controls when `PassManager::run_with_print_verify` verifies the IR.
#[derive(Debug)]
pub struct VerifyPassesOpts {
    /// Verify the IR before any pass runs.
    pub initial: bool,
    /// Verify the IR after all passes have run.
    pub r#final: bool,
    /// After a pass listed in `passes`, verify only if that pass reported a modification.
    pub modified_only: bool,
    /// Names of the passes after which the IR is verified.
    pub passes: HashSet<String>,
}
185
/// Registry of passes, plus the cache of analysis results shared by the passes it runs.
#[derive(Default)]
pub struct PassManager {
    // Registered passes, keyed by `Pass::name`.
    passes: FxHashMap<&'static str, Pass>,
    // Cached analysis results, invalidated when transforms report modifications.
    analyses: AnalysisResults,
}
191
192impl PassManager {
193 pub const OPTIMIZATION_PASSES: [&'static str; 15] = [
194 FN_INLINE_NAME,
195 SIMPLIFY_CFG_NAME,
196 SROA_NAME,
197 DCE_NAME,
198 GLOBALS_DCE_NAME,
199 FN_DEDUP_RELEASE_PROFILE_NAME,
200 FN_DEDUP_DEBUG_PROFILE_NAME,
201 MEM2REG_NAME,
202 MEMCPYOPT_NAME,
203 MEMCPYPROP_REVERSE_NAME,
204 CONST_FOLDING_NAME,
205 ARG_DEMOTION_NAME,
206 CONST_DEMOTION_NAME,
207 RET_DEMOTION_NAME,
208 MISC_DEMOTION_NAME,
209 ];
210
211 pub fn register(&mut self, pass: Pass) -> &'static str {
213 for dep in &pass.deps {
214 if let Some(dep_t) = self.lookup_registered_pass(dep) {
215 if dep_t.is_transform() {
216 panic!(
217 "Pass {} cannot depend on a transformation pass {}",
218 pass.name, dep
219 );
220 }
221 if pass.is_function_pass() && dep_t.is_module_pass() {
222 panic!(
223 "Function pass {} cannot depend on module pass {}",
224 pass.name, dep
225 );
226 }
227 } else {
228 panic!(
229 "Pass {} depends on a (yet) unregistered pass {}",
230 pass.name, dep
231 );
232 }
233 }
234 let pass_name = pass.name;
235 match self.passes.entry(pass.name) {
236 hash_map::Entry::Occupied(_) => {
237 panic!("Trying to register an already registered pass");
238 }
239 hash_map::Entry::Vacant(entry) => {
240 entry.insert(pass);
241 }
242 }
243 pass_name
244 }
245
246 fn actually_run(&mut self, ir: &mut Context, pass: &'static str) -> Result<bool, IrError> {
247 let mut modified = false;
248
249 fn run_module_pass(
250 pm: &mut PassManager,
251 ir: &mut Context,
252 pass: &'static str,
253 module: Module,
254 ) -> Result<bool, IrError> {
255 let mut modified = false;
256 let pass_t = pm.passes.get(pass).expect("Unregistered pass");
257 for dep in pass_t.deps.clone() {
258 let dep_t = pm.passes.get(dep).expect("Unregistered dependent pass");
259 assert!(dep_t.is_analysis());
261 match dep_t.runner {
262 ScopedPass::ModulePass(_) => {
263 if !pm.analyses.is_analysis_result_available(dep, module) {
264 run_module_pass(pm, ir, dep, module)?;
265 }
266 }
267 ScopedPass::FunctionPass(_) => {
268 for f in module.function_iter(ir) {
269 if !pm.analyses.is_analysis_result_available(dep, f) {
270 run_function_pass(pm, ir, dep, f)?;
271 }
272 }
273 }
274 }
275 }
276
277 let pass_t = pm.passes.get(pass).expect("Unregistered pass");
279 let ScopedPass::ModulePass(mp) = pass_t.runner.clone() else {
280 panic!("Expected a module pass");
281 };
282 match mp {
283 PassMutability::Analysis(analysis) => {
284 let result = analysis(ir, &pm.analyses, module)?;
285 pm.analyses.add_result(pass, module, result);
286 }
287 PassMutability::Transform(transform) => {
288 if transform(ir, &pm.analyses, module)? {
289 pm.analyses.invalidate_all_results_at_scope(module);
290 for f in module.function_iter(ir) {
291 pm.analyses.invalidate_all_results_at_scope(f);
292 }
293 modified = true;
294 }
295 }
296 }
297
298 Ok(modified)
299 }
300
301 fn run_function_pass(
302 pm: &mut PassManager,
303 ir: &mut Context,
304 pass: &'static str,
305 function: Function,
306 ) -> Result<bool, IrError> {
307 let mut modified = false;
308 let pass_t = pm.passes.get(pass).expect("Unregistered pass");
309 for dep in pass_t.deps.clone() {
310 let dep_t = pm.passes.get(dep).expect("Unregistered dependent pass");
311 assert!(dep_t.is_analysis());
313 match dep_t.runner {
314 ScopedPass::ModulePass(_) => {
315 panic!("Function pass {pass} cannot depend on module pass {dep}")
316 }
317 ScopedPass::FunctionPass(_) => {
318 if !pm.analyses.is_analysis_result_available(dep, function) {
319 run_function_pass(pm, ir, dep, function)?;
320 };
321 }
322 }
323 }
324
325 let pass_t = pm.passes.get(pass).expect("Unregistered pass");
327 let ScopedPass::FunctionPass(fp) = pass_t.runner.clone() else {
328 panic!("Expected a function pass");
329 };
330 match fp {
331 PassMutability::Analysis(analysis) => {
332 let result = analysis(ir, &pm.analyses, function)?;
333 pm.analyses.add_result(pass, function, result);
334 }
335 PassMutability::Transform(transform) => {
336 if transform(ir, &pm.analyses, function)? {
337 pm.analyses.invalidate_all_results_at_scope(function);
338 modified = true;
339 }
340 }
341 }
342
343 Ok(modified)
344 }
345
346 for m in ir.module_iter() {
347 let pass_t = self.passes.get(pass).expect("Unregistered pass");
348 let pass_runner = pass_t.runner.clone();
349 match pass_runner {
350 ScopedPass::ModulePass(_) => {
351 modified |= run_module_pass(self, ir, pass, m)?;
352 }
353 ScopedPass::FunctionPass(_) => {
354 for f in m.function_iter(ir) {
355 modified |= run_function_pass(self, ir, pass, f)?;
356 }
357 }
358 }
359 }
360 Ok(modified)
361 }
362
363 pub fn run(&mut self, ir: &mut Context, passes: &PassGroup) -> Result<bool, IrError> {
365 let mut modified = false;
366 for pass in passes.flatten_pass_group() {
367 modified |= self.actually_run(ir, pass)?;
368 }
369 Ok(modified)
370 }
371
372 pub fn run_with_print_verify(
375 &mut self,
376 ir: &mut Context,
377 passes: &PassGroup,
378 print_opts: &PrintPassesOpts,
379 verify_opts: &VerifyPassesOpts,
380 ) -> Result<bool, IrError> {
381 fn ir_is_empty(ir: &Context) -> bool {
383 ir.functions.is_empty()
384 && ir.blocks.is_empty()
385 && ir.values.is_empty()
386 && ir.local_vars.is_empty()
387 }
388
389 fn print_ir_after_pass(ir: &Context, pass: &Pass) {
390 if !ir_is_empty(ir) {
391 println!("// IR: [{}] {}", pass.name, pass.descr);
392 println!("{ir}");
393 }
394 }
395
396 fn print_initial_or_final_ir(ir: &Context, initial_or_final: &'static str) {
397 if !ir_is_empty(ir) {
398 println!("// IR: {initial_or_final}");
399 println!("{ir}");
400 }
401 }
402
403 if print_opts.initial {
404 print_initial_or_final_ir(ir, "Initial");
405 }
406
407 if verify_opts.initial {
408 ir.verify()?;
409 }
410
411 let mut modified = false;
412 for pass in passes.flatten_pass_group() {
413 let modified_in_pass = self.actually_run(ir, pass)?;
414
415 if print_opts.passes.contains(pass) && (!print_opts.modified_only || modified_in_pass) {
416 print_ir_after_pass(ir, self.lookup_registered_pass(pass).unwrap());
417 }
418
419 modified |= modified_in_pass;
420 if verify_opts.passes.contains(pass) && (!verify_opts.modified_only || modified_in_pass)
421 {
422 ir.verify()?;
423 }
424 }
425
426 if print_opts.r#final {
427 print_initial_or_final_ir(ir, "Final");
428 }
429
430 if verify_opts.r#final {
431 ir.verify()?;
432 }
433
434 Ok(modified)
435 }
436
437 pub fn lookup_registered_pass(&self, name: &str) -> Option<&Pass> {
439 self.passes.get(name)
440 }
441
442 pub fn help_text(&self) -> String {
443 let summary = self
444 .passes
445 .iter()
446 .map(|(name, pass)| format!(" {name:16} - {}", pass.descr))
447 .collect::<Vec<_>>()
448 .join("\n");
449
450 format!("Valid pass names are:\n\n{summary}",)
451 }
452}
453
/// An ordered, possibly nested, sequence of pass names to run.
#[derive(Default)]
pub struct PassGroup(Vec<PassOrGroup>);

/// One entry of a [`PassGroup`]: either a single pass name or a nested group.
pub enum PassOrGroup {
    Pass(&'static str),
    Group(PassGroup),
}

impl PassGroup {
    /// Flattens this group, depth-first, into the ordered list of pass names it contains.
    fn flatten_pass_group(&self) -> Vec<&'static str> {
        let mut flat = Vec::new();
        self.flatten_into(&mut flat);
        flat
    }

    /// Appends this group's pass names to `flat`, expanding nested groups in place.
    fn flatten_into(&self, flat: &mut Vec<&'static str>) {
        for entry in &self.0 {
            match entry {
                PassOrGroup::Pass(name) => flat.push(*name),
                PassOrGroup::Group(nested) => nested.flatten_into(flat),
            }
        }
    }

    /// Adds the pass named `pass` at the end of this group.
    pub fn append_pass(&mut self, pass: &'static str) {
        let entry = PassOrGroup::Pass(pass);
        self.0.push(entry);
    }

    /// Adds `group` as a nested group at the end of this group.
    pub fn append_group(&mut self, group: PassGroup) {
        let entry = PassOrGroup::Group(group);
        self.0.push(entry);
    }
}
491
492pub fn register_known_passes(pm: &mut PassManager) {
494 pm.register(create_postorder_pass());
496 pm.register(create_dominators_pass());
497 pm.register(create_dom_fronts_pass());
498 pm.register(create_escaped_symbols_pass());
499 pm.register(create_module_printer_pass());
500 pm.register(create_module_verifier_pass());
501 pm.register(create_arg_pointee_mutability_tagger_pass());
503 pm.register(create_fn_dedup_release_profile_pass());
504 pm.register(create_fn_dedup_debug_profile_pass());
505 pm.register(create_mem2reg_pass());
506 pm.register(create_sroa_pass());
507 pm.register(create_fn_inline_pass());
508 pm.register(create_const_folding_pass());
509 pm.register(create_ccp_pass());
510 pm.register(create_simplify_cfg_pass());
511 pm.register(create_globals_dce_pass());
512 pm.register(create_dce_pass());
513 pm.register(create_cse_pass());
514 pm.register(create_arg_demotion_pass());
515 pm.register(create_const_demotion_pass());
516 pm.register(create_ret_demotion_pass());
517 pm.register(create_misc_demotion_pass());
518 pm.register(create_memcpyopt_pass());
519 pm.register(create_memcpyprop_reverse_pass());
520}
521
522pub fn create_o1_pass_group() -> PassGroup {
523 let mut o1 = PassGroup::default();
525 o1.append_pass(MEM2REG_NAME);
527 o1.append_pass(FN_DEDUP_RELEASE_PROFILE_NAME);
528 o1.append_pass(FN_INLINE_NAME);
529 o1.append_pass(ARG_POINTEE_MUTABILITY_TAGGER_NAME);
530 o1.append_pass(SIMPLIFY_CFG_NAME);
531 o1.append_pass(GLOBALS_DCE_NAME);
532 o1.append_pass(DCE_NAME);
533 o1.append_pass(FN_INLINE_NAME);
534 o1.append_pass(ARG_POINTEE_MUTABILITY_TAGGER_NAME);
535 o1.append_pass(CCP_NAME);
536 o1.append_pass(CONST_FOLDING_NAME);
537 o1.append_pass(SIMPLIFY_CFG_NAME);
538 o1.append_pass(CSE_NAME);
539 o1.append_pass(CONST_FOLDING_NAME);
540 o1.append_pass(SIMPLIFY_CFG_NAME);
541 o1.append_pass(GLOBALS_DCE_NAME);
542 o1.append_pass(DCE_NAME);
543 o1.append_pass(FN_DEDUP_RELEASE_PROFILE_NAME);
544
545 o1
546}
547
548pub fn insert_after_each(pg: PassGroup, pass: &'static str) -> PassGroup {
554 fn insert_after_each_rec(pg: PassGroup, pass: &'static str) -> Vec<PassOrGroup> {
555 pg.0.into_iter()
556 .flat_map(|p_o_g| match p_o_g {
557 PassOrGroup::Group(group) => vec![PassOrGroup::Group(PassGroup(
558 insert_after_each_rec(group, pass),
559 ))],
560 PassOrGroup::Pass(_) => vec![p_o_g, PassOrGroup::Pass(pass)],
561 })
562 .collect()
563 }
564
565 PassGroup(insert_after_each_rec(pg, pass))
566}