1use crate::{
2 create_arg_demotion_pass, create_arg_pointee_mutability_tagger_pass, create_ccp_pass,
3 create_const_demotion_pass, create_const_folding_pass, create_cse_pass, create_dce_pass,
4 create_dom_fronts_pass, create_dominators_pass, create_escaped_symbols_pass,
5 create_fn_dedup_debug_profile_pass, create_fn_dedup_release_profile_pass,
6 create_fn_inline_pass, create_globals_dce_pass, create_mem2reg_pass, create_memcpyopt_pass,
7 create_misc_demotion_pass, create_module_printer_pass, create_module_verifier_pass,
8 create_postorder_pass, create_ret_demotion_pass, create_simplify_cfg_pass, create_sroa_pass,
9 Context, Function, IrError, Module, ARG_DEMOTION_NAME, ARG_POINTEE_MUTABILITY_TAGGER_NAME,
10 CCP_NAME, CONST_DEMOTION_NAME, CONST_FOLDING_NAME, CSE_NAME, DCE_NAME,
11 FN_DEDUP_DEBUG_PROFILE_NAME, FN_DEDUP_RELEASE_PROFILE_NAME, FN_INLINE_NAME, GLOBALS_DCE_NAME,
12 MEM2REG_NAME, MEMCPYOPT_NAME, MISC_DEMOTION_NAME, RET_DEMOTION_NAME, SIMPLIFY_CFG_NAME,
13 SROA_NAME,
14};
15use downcast_rs::{impl_downcast, Downcast};
16use rustc_hash::FxHashMap;
17use std::{
18 any::{type_name, TypeId},
19 collections::{hash_map, HashSet},
20};
21
/// Marker trait for values produced by analysis passes.
///
/// The `Downcast` supertrait (together with `impl_downcast!` below) lets a boxed
/// `dyn AnalysisResultT` be downcast back to its concrete type, which is how
/// `AnalysisResults::get_analysis_result` hands results back to passes.
pub trait AnalysisResultT: Downcast {}
impl_downcast!(AnalysisResultT);
/// A type-erased, heap-allocated analysis result as stored by `AnalysisResults`.
pub type AnalysisResult = Box<dyn AnalysisResultT>;
26
27pub trait PassScope {
29 fn get_arena_idx(&self) -> slotmap::DefaultKey;
30}
31impl PassScope for Module {
32 fn get_arena_idx(&self) -> slotmap::DefaultKey {
33 self.0
34 }
35}
36impl PassScope for Function {
37 fn get_arena_idx(&self) -> slotmap::DefaultKey {
38 self.0
39 }
40}
41
/// Whether a pass only inspects the IR or mutates it, together with the function that
/// runs it on a scope of type `S`.
#[derive(Clone)]
pub enum PassMutability<S: PassScope> {
    /// Reads the IR and produces a result that the pass manager caches per scope.
    Analysis(fn(&Context, analyses: &AnalysisResults, S) -> Result<AnalysisResult, IrError>),
    /// Mutates the IR; returns `Ok(true)` when something was changed, which makes the
    /// pass manager invalidate cached analysis results.
    Transform(fn(&mut Context, analyses: &AnalysisResults, S) -> Result<bool, IrError>),
}
50
/// A pass runner paired with the granularity it runs at.
#[derive(Clone)]
pub enum ScopedPass {
    /// Invoked once per module.
    ModulePass(PassMutability<Module>),
    /// Invoked once per function of each module.
    FunctionPass(PassMutability<Function>),
}
57
/// A pass that can be registered with a `PassManager`.
pub struct Pass {
    /// Unique name; the key under which the pass is registered and looked up.
    pub name: &'static str,
    /// Short human-readable description, used in help text and IR print headers.
    pub descr: &'static str,
    /// Names of (already registered) analysis passes whose results must be available
    /// before this pass runs.
    pub deps: Vec<&'static str>,
    /// The function that executes the pass, and the scope it runs at.
    pub runner: ScopedPass,
}
69
70impl Pass {
71 pub fn is_analysis(&self) -> bool {
72 match &self.runner {
73 ScopedPass::ModulePass(pm) => matches!(pm, PassMutability::Analysis(_)),
74 ScopedPass::FunctionPass(pm) => matches!(pm, PassMutability::Analysis(_)),
75 }
76 }
77
78 pub fn is_transform(&self) -> bool {
79 !self.is_analysis()
80 }
81
82 pub fn is_module_pass(&self) -> bool {
83 matches!(self.runner, ScopedPass::ModulePass(_))
84 }
85
86 pub fn is_function_pass(&self) -> bool {
87 matches!(self.runner, ScopedPass::FunctionPass(_))
88 }
89}
90
/// Cache of analysis results, keyed by the result's type and the scope it was computed for.
#[derive(Default)]
pub struct AnalysisResults {
    // Key: (TypeId of the concrete result, (TypeId of the scope type, scope arena key)).
    results: FxHashMap<(TypeId, (TypeId, slotmap::DefaultKey)), AnalysisResult>,
    // Pass name -> TypeId of the result that pass produced; filled in by `add_result`.
    name_typeid_map: FxHashMap<&'static str, TypeId>,
}
97
98impl AnalysisResults {
99 pub fn get_analysis_result<T: AnalysisResultT, S: PassScope + 'static>(&self, scope: S) -> &T {
102 self.results
103 .get(&(
104 TypeId::of::<T>(),
105 (TypeId::of::<S>(), scope.get_arena_idx()),
106 ))
107 .unwrap_or_else(|| {
108 panic!(
109 "Internal error. Analysis result {} unavailable for {} with idx {:?}",
110 type_name::<T>(),
111 type_name::<S>(),
112 scope.get_arena_idx()
113 )
114 })
115 .downcast_ref()
116 .expect("AnalysisResult: Incorrect type")
117 }
118
119 fn is_analysis_result_available<S: PassScope + 'static>(
121 &self,
122 name: &'static str,
123 scope: S,
124 ) -> bool {
125 self.name_typeid_map
126 .get(name)
127 .and_then(|result_typeid| {
128 self.results
129 .get(&(*result_typeid, (TypeId::of::<S>(), scope.get_arena_idx())))
130 })
131 .is_some()
132 }
133
134 fn add_result<S: PassScope + 'static>(
136 &mut self,
137 name: &'static str,
138 scope: S,
139 result: AnalysisResult,
140 ) {
141 let result_typeid = (*result).type_id();
142 self.results.insert(
143 (result_typeid, (TypeId::of::<S>(), scope.get_arena_idx())),
144 result,
145 );
146 self.name_typeid_map.insert(name, result_typeid);
147 }
148
149 fn invalidate_all_results_at_scope<S: PassScope + 'static>(&mut self, scope: S) {
151 self.results
152 .retain(|(_result_typeid, (scope_typeid, scope_idx)), _v| {
153 (*scope_typeid, *scope_idx) != (TypeId::of::<S>(), scope.get_arena_idx())
154 });
155 }
156}
157
/// Options controlling which IR snapshots `PassManager::run_with_print` prints.
#[derive(Debug)]
pub struct PrintPassesOpts {
    /// Print the IR before any pass runs.
    pub initial: bool,
    /// Print the IR after all passes have run.
    pub r#final: bool,
    /// Only print after a listed pass if that pass reported a modification.
    pub modified_only: bool,
    /// Names of the passes after which the IR is printed.
    pub passes: HashSet<String>,
}
171
/// Registry of passes plus the cache of analysis results they produce.
#[derive(Default)]
pub struct PassManager {
    // Registered passes, keyed by `Pass::name`.
    passes: FxHashMap<&'static str, Pass>,
    // Cached analysis results; entries are dropped when a transform reports a change.
    analyses: AnalysisResults,
}
177
178impl PassManager {
179 pub const OPTIMIZATION_PASSES: [&'static str; 14] = [
180 FN_INLINE_NAME,
181 SIMPLIFY_CFG_NAME,
182 SROA_NAME,
183 DCE_NAME,
184 GLOBALS_DCE_NAME,
185 FN_DEDUP_RELEASE_PROFILE_NAME,
186 FN_DEDUP_DEBUG_PROFILE_NAME,
187 MEM2REG_NAME,
188 MEMCPYOPT_NAME,
189 CONST_FOLDING_NAME,
190 ARG_DEMOTION_NAME,
191 CONST_DEMOTION_NAME,
192 RET_DEMOTION_NAME,
193 MISC_DEMOTION_NAME,
194 ];
195
196 pub fn register(&mut self, pass: Pass) -> &'static str {
198 for dep in &pass.deps {
199 if let Some(dep_t) = self.lookup_registered_pass(dep) {
200 if dep_t.is_transform() {
201 panic!(
202 "Pass {} cannot depend on a transformation pass {}",
203 pass.name, dep
204 );
205 }
206 if pass.is_function_pass() && dep_t.is_module_pass() {
207 panic!(
208 "Function pass {} cannot depend on module pass {}",
209 pass.name, dep
210 );
211 }
212 } else {
213 panic!(
214 "Pass {} depends on a (yet) unregistered pass {}",
215 pass.name, dep
216 );
217 }
218 }
219 let pass_name = pass.name;
220 match self.passes.entry(pass.name) {
221 hash_map::Entry::Occupied(_) => {
222 panic!("Trying to register an already registered pass");
223 }
224 hash_map::Entry::Vacant(entry) => {
225 entry.insert(pass);
226 }
227 }
228 pass_name
229 }
230
231 fn actually_run(&mut self, ir: &mut Context, pass: &'static str) -> Result<bool, IrError> {
232 let mut modified = false;
233
234 fn run_module_pass(
235 pm: &mut PassManager,
236 ir: &mut Context,
237 pass: &'static str,
238 module: Module,
239 ) -> Result<bool, IrError> {
240 let mut modified = false;
241 let pass_t = pm.passes.get(pass).expect("Unregistered pass");
242 for dep in pass_t.deps.clone() {
243 let dep_t = pm.passes.get(dep).expect("Unregistered dependent pass");
244 assert!(dep_t.is_analysis());
246 match dep_t.runner {
247 ScopedPass::ModulePass(_) => {
248 if !pm.analyses.is_analysis_result_available(dep, module) {
249 run_module_pass(pm, ir, dep, module)?;
250 }
251 }
252 ScopedPass::FunctionPass(_) => {
253 for f in module.function_iter(ir) {
254 if !pm.analyses.is_analysis_result_available(dep, f) {
255 run_function_pass(pm, ir, dep, f)?;
256 }
257 }
258 }
259 }
260 }
261
262 let pass_t = pm.passes.get(pass).expect("Unregistered pass");
264 let ScopedPass::ModulePass(mp) = pass_t.runner.clone() else {
265 panic!("Expected a module pass");
266 };
267 match mp {
268 PassMutability::Analysis(analysis) => {
269 let result = analysis(ir, &pm.analyses, module)?;
270 pm.analyses.add_result(pass, module, result);
271 }
272 PassMutability::Transform(transform) => {
273 if transform(ir, &pm.analyses, module)? {
274 pm.analyses.invalidate_all_results_at_scope(module);
275 for f in module.function_iter(ir) {
276 pm.analyses.invalidate_all_results_at_scope(f);
277 }
278 modified = true;
279 }
280 }
281 }
282
283 Ok(modified)
284 }
285
286 fn run_function_pass(
287 pm: &mut PassManager,
288 ir: &mut Context,
289 pass: &'static str,
290 function: Function,
291 ) -> Result<bool, IrError> {
292 let mut modified = false;
293 let pass_t = pm.passes.get(pass).expect("Unregistered pass");
294 for dep in pass_t.deps.clone() {
295 let dep_t = pm.passes.get(dep).expect("Unregistered dependent pass");
296 assert!(dep_t.is_analysis());
298 match dep_t.runner {
299 ScopedPass::ModulePass(_) => {
300 panic!(
301 "Function pass {} cannot depend on module pass {}",
302 pass, dep
303 )
304 }
305 ScopedPass::FunctionPass(_) => {
306 if !pm.analyses.is_analysis_result_available(dep, function) {
307 run_function_pass(pm, ir, dep, function)?;
308 };
309 }
310 }
311 }
312
313 let pass_t = pm.passes.get(pass).expect("Unregistered pass");
315 let ScopedPass::FunctionPass(fp) = pass_t.runner.clone() else {
316 panic!("Expected a function pass");
317 };
318 match fp {
319 PassMutability::Analysis(analysis) => {
320 let result = analysis(ir, &pm.analyses, function)?;
321 pm.analyses.add_result(pass, function, result);
322 }
323 PassMutability::Transform(transform) => {
324 if transform(ir, &pm.analyses, function)? {
325 pm.analyses.invalidate_all_results_at_scope(function);
326 modified = true;
327 }
328 }
329 }
330
331 Ok(modified)
332 }
333
334 for m in ir.module_iter() {
335 let pass_t = self.passes.get(pass).expect("Unregistered pass");
336 let pass_runner = pass_t.runner.clone();
337 match pass_runner {
338 ScopedPass::ModulePass(_) => {
339 modified |= run_module_pass(self, ir, pass, m)?;
340 }
341 ScopedPass::FunctionPass(_) => {
342 for f in m.function_iter(ir) {
343 modified |= run_function_pass(self, ir, pass, f)?;
344 }
345 }
346 }
347 }
348 Ok(modified)
349 }
350
351 pub fn run(&mut self, ir: &mut Context, passes: &PassGroup) -> Result<bool, IrError> {
353 let mut modified = false;
354 for pass in passes.flatten_pass_group() {
355 modified |= self.actually_run(ir, pass)?;
356 }
357 Ok(modified)
358 }
359
360 pub fn run_with_print(
363 &mut self,
364 ir: &mut Context,
365 passes: &PassGroup,
366 print_opts: &PrintPassesOpts,
367 ) -> Result<bool, IrError> {
368 fn ir_is_empty(ir: &Context) -> bool {
370 ir.functions.is_empty()
371 && ir.blocks.is_empty()
372 && ir.values.is_empty()
373 && ir.local_vars.is_empty()
374 }
375
376 fn print_ir_after_pass(ir: &Context, pass: &Pass) {
377 if !ir_is_empty(ir) {
378 println!("// IR: [{}] {}", pass.name, pass.descr);
379 println!("{ir}");
380 }
381 }
382
383 fn print_initial_or_final_ir(ir: &Context, initial_or_final: &'static str) {
384 if !ir_is_empty(ir) {
385 println!("// IR: {initial_or_final}");
386 println!("{ir}");
387 }
388 }
389
390 if print_opts.initial {
391 print_initial_or_final_ir(ir, "Initial");
392 }
393
394 let mut modified = false;
395 for pass in passes.flatten_pass_group() {
396 let modified_in_pass = self.actually_run(ir, pass)?;
397
398 if print_opts.passes.contains(pass) && (!print_opts.modified_only || modified_in_pass) {
399 print_ir_after_pass(ir, self.lookup_registered_pass(pass).unwrap());
400 }
401
402 modified |= modified_in_pass;
403 }
404
405 if print_opts.r#final {
406 print_initial_or_final_ir(ir, "Final");
407 }
408
409 Ok(modified)
410 }
411
412 pub fn lookup_registered_pass(&self, name: &str) -> Option<&Pass> {
414 self.passes.get(name)
415 }
416
417 pub fn help_text(&self) -> String {
418 let summary = self
419 .passes
420 .iter()
421 .map(|(name, pass)| format!(" {name:16} - {}", pass.descr))
422 .collect::<Vec<_>>()
423 .join("\n");
424
425 format!("Valid pass names are:\n\n{summary}",)
426 }
427}
428
/// An ordered, possibly nested, sequence of pass names; nested groups are flattened
/// depth-first when the group is run.
#[derive(Default)]
pub struct PassGroup(Vec<PassOrGroup>);
433
/// One element of a `PassGroup`.
pub enum PassOrGroup {
    /// A single pass, referenced by its registered name.
    Pass(&'static str),
    /// A nested group, whose passes run in place of this element.
    Group(PassGroup),
}
439
440impl PassGroup {
441 fn flatten_pass_group(&self) -> Vec<&'static str> {
443 let mut output = Vec::<&str>::new();
444 fn inner(output: &mut Vec<&str>, input: &PassGroup) {
445 for pass_or_group in &input.0 {
446 match pass_or_group {
447 PassOrGroup::Pass(pass) => output.push(pass),
448 PassOrGroup::Group(pg) => inner(output, pg),
449 }
450 }
451 }
452 inner(&mut output, self);
453 output
454 }
455
456 pub fn append_pass(&mut self, pass: &'static str) {
458 self.0.push(PassOrGroup::Pass(pass));
459 }
460
461 pub fn append_group(&mut self, group: PassGroup) {
463 self.0.push(PassOrGroup::Group(group));
464 }
465}
466
467pub fn register_known_passes(pm: &mut PassManager) {
469 pm.register(create_postorder_pass());
471 pm.register(create_dominators_pass());
472 pm.register(create_dom_fronts_pass());
473 pm.register(create_escaped_symbols_pass());
474 pm.register(create_module_printer_pass());
475 pm.register(create_module_verifier_pass());
476 pm.register(create_arg_pointee_mutability_tagger_pass());
478 pm.register(create_fn_dedup_release_profile_pass());
479 pm.register(create_fn_dedup_debug_profile_pass());
480 pm.register(create_mem2reg_pass());
481 pm.register(create_sroa_pass());
482 pm.register(create_fn_inline_pass());
483 pm.register(create_const_folding_pass());
484 pm.register(create_ccp_pass());
485 pm.register(create_simplify_cfg_pass());
486 pm.register(create_globals_dce_pass());
487 pm.register(create_dce_pass());
488 pm.register(create_cse_pass());
489 pm.register(create_arg_demotion_pass());
490 pm.register(create_const_demotion_pass());
491 pm.register(create_ret_demotion_pass());
492 pm.register(create_misc_demotion_pass());
493 pm.register(create_memcpyopt_pass());
494}
495
496pub fn create_o1_pass_group() -> PassGroup {
497 let mut o1 = PassGroup::default();
499 o1.append_pass(MEM2REG_NAME);
501 o1.append_pass(FN_DEDUP_RELEASE_PROFILE_NAME);
502 o1.append_pass(FN_INLINE_NAME);
503 o1.append_pass(ARG_POINTEE_MUTABILITY_TAGGER_NAME);
504 o1.append_pass(SIMPLIFY_CFG_NAME);
505 o1.append_pass(GLOBALS_DCE_NAME);
506 o1.append_pass(DCE_NAME);
507 o1.append_pass(FN_INLINE_NAME);
508 o1.append_pass(ARG_POINTEE_MUTABILITY_TAGGER_NAME);
509 o1.append_pass(CCP_NAME);
510 o1.append_pass(CONST_FOLDING_NAME);
511 o1.append_pass(SIMPLIFY_CFG_NAME);
512 o1.append_pass(CSE_NAME);
513 o1.append_pass(CONST_FOLDING_NAME);
514 o1.append_pass(SIMPLIFY_CFG_NAME);
515 o1.append_pass(GLOBALS_DCE_NAME);
516 o1.append_pass(DCE_NAME);
517 o1.append_pass(FN_DEDUP_RELEASE_PROFILE_NAME);
518
519 o1
520}
521
522pub fn insert_after_each(pg: PassGroup, pass: &'static str) -> PassGroup {
528 fn insert_after_each_rec(pg: PassGroup, pass: &'static str) -> Vec<PassOrGroup> {
529 pg.0.into_iter()
530 .flat_map(|p_o_g| match p_o_g {
531 PassOrGroup::Group(group) => vec![PassOrGroup::Group(PassGroup(
532 insert_after_each_rec(group, pass),
533 ))],
534 PassOrGroup::Pass(_) => vec![p_o_g, PassOrGroup::Pass(pass)],
535 })
536 .collect()
537 }
538
539 PassGroup(insert_after_each_rec(pg, pass))
540}