//! `sway_ir/pass_manager.rs`
//!
//! Registration and execution of analysis and transformation passes over the IR.

1use crate::{
2    create_arg_demotion_pass, create_arg_pointee_mutability_tagger_pass, create_ccp_pass,
3    create_const_demotion_pass, create_const_folding_pass, create_cse_pass, create_dce_pass,
4    create_dom_fronts_pass, create_dominators_pass, create_escaped_symbols_pass,
5    create_fn_dedup_debug_profile_pass, create_fn_dedup_release_profile_pass,
6    create_fn_inline_pass, create_globals_dce_pass, create_init_aggr_lowering_pass,
7    create_mem2reg_pass, create_memcpyopt_pass, create_memcpyprop_reverse_pass,
8    create_misc_demotion_pass, create_module_printer_pass, create_module_verifier_pass,
9    create_postorder_pass, create_ret_demotion_pass, create_simplify_cfg_pass, create_sroa_pass,
10    Context, Function, IrError, Module, ARG_DEMOTION_NAME, ARG_POINTEE_MUTABILITY_TAGGER_NAME,
11    CCP_NAME, CONST_DEMOTION_NAME, CONST_FOLDING_NAME, CSE_NAME, DCE_NAME,
12    FN_DEDUP_DEBUG_PROFILE_NAME, FN_DEDUP_RELEASE_PROFILE_NAME, FN_INLINE_NAME, GLOBALS_DCE_NAME,
13    INIT_AGGR_LOWERING_NAME, MEM2REG_NAME, MEMCPYOPT_NAME, MEMCPYPROP_REVERSE_NAME,
14    MISC_DEMOTION_NAME, RET_DEMOTION_NAME, SIMPLIFY_CFG_NAME, SROA_NAME,
15};
16use downcast_rs::{impl_downcast, Downcast};
17use rustc_hash::FxHashMap;
18use std::{
19    any::{type_name, TypeId},
20    collections::{hash_map, HashSet},
21};
22
23/// Result of an analysis. Consumers must downcast it to the concrete result type.
///
/// The trait declares no methods of its own; its only requirement is `Downcast`.
24pub trait AnalysisResultT: Downcast {}
// Macro from `downcast_rs`, invoked for `AnalysisResultT`; presumably this is what
// provides the `downcast_ref` call used when results are read back (confirm against
// the `downcast_rs` documentation).
25impl_downcast!(AnalysisResultT);
/// Owned, type-erased analysis result, as stored in the analysis result cache.
26pub type AnalysisResult = Box<dyn AnalysisResultT>;
27
28/// Program scope over which a pass executes.
///
/// Implemented in this file for `Module` and `Function`.
29pub trait PassScope {
    /// Returns the key that identifies this scope instance. Together with the scope's
    /// type it is used as part of the lookup key for cached analysis results.
30    fn get_arena_idx(&self) -> slotmap::DefaultKey;
31}
// A module is identified by the key it wraps (field `0`).
32impl PassScope for Module {
33    fn get_arena_idx(&self) -> slotmap::DefaultKey {
34        self.0
35    }
36}
// A function is identified by the key it wraps (field `0`).
37impl PassScope for Function {
38    fn get_arena_idx(&self) -> slotmap::DefaultKey {
39        self.0
40    }
41}
42
43/// Is a pass an Analysis or a Transformation over the IR?
///
/// `S` is the scope (see `PassScope`) on which the pass function is invoked.
44#[derive(Clone)]
45pub enum PassMutability<S: PassScope> {
46    /// An analysis pass, producing an analysis result.
    ///
    /// The function receives the IR context immutably, the analysis results that are
    /// already available, and the scope to analyze.
47    Analysis(fn(&Context, analyses: &AnalysisResults, S) -> Result<AnalysisResult, IrError>),
48    /// A pass over the IR that can possibly modify it.
    ///
    /// The function receives the IR context mutably. The pass manager treats an
    /// `Ok(true)` return value as "the IR was modified" and invalidates cached analysis
    /// results in response.
49    Transform(fn(&mut Context, analyses: &AnalysisResults, S) -> Result<bool, IrError>),
50}
51
52/// A concrete version of [PassScope].
///
/// Pairs a pass function with the scope kind it runs on.
53#[derive(Clone)]
54pub enum ScopedPass {
    /// A pass that is invoked once per module.
55    ModulePass(PassMutability<Module>),
    /// A pass that is invoked once per function.
56    FunctionPass(PassMutability<Function>),
57}
58
59/// An analysis or transformation pass.
60pub struct Pass {
61    /// Pass identifier.
62    pub name: &'static str,
63    /// A short description.
64    pub descr: &'static str,
65    /// Other passes that this pass depends on.
66    pub deps: Vec<&'static str>,
67    /// The executor.
68    pub runner: ScopedPass,
69}
70
71impl Pass {
72    pub fn is_analysis(&self) -> bool {
73        match &self.runner {
74            ScopedPass::ModulePass(pm) => matches!(pm, PassMutability::Analysis(_)),
75            ScopedPass::FunctionPass(pm) => matches!(pm, PassMutability::Analysis(_)),
76        }
77    }
78
79    pub fn is_transform(&self) -> bool {
80        !self.is_analysis()
81    }
82
83    pub fn is_module_pass(&self) -> bool {
84        matches!(self.runner, ScopedPass::ModulePass(_))
85    }
86
87    pub fn is_function_pass(&self) -> bool {
88        matches!(self.runner, ScopedPass::FunctionPass(_))
89    }
90}
91
92#[derive(Default)]
93pub struct AnalysisResults {
94    // Hash from (AnalysisResultT, (PassScope, Scope Identity)) to an actual result.
95    results: FxHashMap<(TypeId, (TypeId, slotmap::DefaultKey)), AnalysisResult>,
96    name_typeid_map: FxHashMap<&'static str, TypeId>,
97}
98
99impl AnalysisResults {
100    /// Get the results of an analysis.
101    /// Example analyses.get_analysis_result::<DomTreeAnalysis>(foo).
102    pub fn get_analysis_result<T: AnalysisResultT, S: PassScope + 'static>(&self, scope: S) -> &T {
103        self.results
104            .get(&(
105                TypeId::of::<T>(),
106                (TypeId::of::<S>(), scope.get_arena_idx()),
107            ))
108            .unwrap_or_else(|| {
109                panic!(
110                    "Internal error. Analysis result {} unavailable for {} with idx {:?}",
111                    type_name::<T>(),
112                    type_name::<S>(),
113                    scope.get_arena_idx()
114                )
115            })
116            .downcast_ref()
117            .expect("AnalysisResult: Incorrect type")
118    }
119
120    /// Is an analysis result available at the given scope?
121    fn is_analysis_result_available<S: PassScope + 'static>(
122        &self,
123        name: &'static str,
124        scope: S,
125    ) -> bool {
126        self.name_typeid_map
127            .get(name)
128            .and_then(|result_typeid| {
129                self.results
130                    .get(&(*result_typeid, (TypeId::of::<S>(), scope.get_arena_idx())))
131            })
132            .is_some()
133    }
134
135    /// Add a new result.
136    fn add_result<S: PassScope + 'static>(
137        &mut self,
138        name: &'static str,
139        scope: S,
140        result: AnalysisResult,
141    ) {
142        let result_typeid = (*result).type_id();
143        self.results.insert(
144            (result_typeid, (TypeId::of::<S>(), scope.get_arena_idx())),
145            result,
146        );
147        self.name_typeid_map.insert(name, result_typeid);
148    }
149
150    /// Invalidate all results at a given scope.
151    fn invalidate_all_results_at_scope<S: PassScope + 'static>(&mut self, scope: S) {
152        self.results
153            .retain(|(_result_typeid, (scope_typeid, scope_idx)), _v| {
154                (*scope_typeid, *scope_idx) != (TypeId::of::<S>(), scope.get_arena_idx())
155            });
156    }
157}
158
/// Options for printing [Pass]es in case of running them with printing requested.
///
/// Note that states of IR can always be printed by injecting the module printer pass
/// and just running the passes. That approach however offers less control over the
/// printing. E.g., requiring the printing to happen only if the previous passes
/// modified the IR cannot be done by simply injecting a module printer.
///
/// The [Default] value requests no printing at all.
#[derive(Debug, Clone, Default)]
pub struct PrintPassesOpts {
    /// Print the IR before any pass is run.
    pub initial: bool,
    /// Print the IR after all passes have run.
    pub r#final: bool,
    /// Print the IR after a selected pass only if that pass modified the IR.
    pub modified_only: bool,
    /// Names of the passes after which the IR is printed.
    pub passes: HashSet<String>,
}
172
/// Options for verifying [Pass]es in case of running them with verifying requested.
///
/// Note that states of IR can always be verified by injecting the module verifier pass
/// and just running the passes. That approach however offers less control over the
/// verification. E.g., requiring the verification to happen only if the previous passes
/// modified the IR cannot be done by simply injecting a module verifier.
///
/// The [Default] value requests no verification at all.
#[derive(Debug, Clone, Default)]
pub struct VerifyPassesOpts {
    /// Verify the IR before any pass is run.
    pub initial: bool,
    /// Verify the IR after all passes have run.
    pub r#final: bool,
    /// Verify the IR after a selected pass only if that pass modified the IR.
    pub modified_only: bool,
    /// Names of the passes after which the IR is verified.
    pub passes: HashSet<String>,
}
186
187#[derive(Default)]
188pub struct PassManager {
189    passes: FxHashMap<&'static str, Pass>,
190    analyses: AnalysisResults,
191}
192
193impl PassManager {
194    pub const OPTIMIZATION_PASSES: [&'static str; 16] = [
195        FN_INLINE_NAME,
196        SIMPLIFY_CFG_NAME,
197        SROA_NAME,
198        DCE_NAME,
199        GLOBALS_DCE_NAME,
200        FN_DEDUP_RELEASE_PROFILE_NAME,
201        FN_DEDUP_DEBUG_PROFILE_NAME,
202        MEM2REG_NAME,
203        MEMCPYOPT_NAME,
204        MEMCPYPROP_REVERSE_NAME,
205        CONST_FOLDING_NAME,
206        ARG_DEMOTION_NAME,
207        CONST_DEMOTION_NAME,
208        RET_DEMOTION_NAME,
209        MISC_DEMOTION_NAME,
210        INIT_AGGR_LOWERING_NAME,
211    ];
212
213    /// Register a pass. Should be called only once for each pass.
214    pub fn register(&mut self, pass: Pass) -> &'static str {
215        for dep in &pass.deps {
216            if let Some(dep_t) = self.lookup_registered_pass(dep) {
217                if dep_t.is_transform() {
218                    panic!(
219                        "Pass {} cannot depend on a transformation pass {}",
220                        pass.name, dep
221                    );
222                }
223                if pass.is_function_pass() && dep_t.is_module_pass() {
224                    panic!(
225                        "Function pass {} cannot depend on module pass {}",
226                        pass.name, dep
227                    );
228                }
229            } else {
230                panic!(
231                    "Pass {} depends on a (yet) unregistered pass {}",
232                    pass.name, dep
233                );
234            }
235        }
236        let pass_name = pass.name;
237        match self.passes.entry(pass.name) {
238            hash_map::Entry::Occupied(_) => {
239                panic!("Trying to register an already registered pass");
240            }
241            hash_map::Entry::Vacant(entry) => {
242                entry.insert(pass);
243            }
244        }
245        pass_name
246    }
247
248    fn actually_run(&mut self, ir: &mut Context, pass: &'static str) -> Result<bool, IrError> {
249        let mut modified = false;
250
251        fn run_module_pass(
252            pm: &mut PassManager,
253            ir: &mut Context,
254            pass: &'static str,
255            module: Module,
256        ) -> Result<bool, IrError> {
257            let mut modified = false;
258            let pass_t = pm.passes.get(pass).expect("Unregistered pass");
259            for dep in pass_t.deps.clone() {
260                let dep_t = pm.passes.get(dep).expect("Unregistered dependent pass");
261                // If pass registration allows transformations as dependents, we could remove this I guess.
262                assert!(dep_t.is_analysis());
263                match dep_t.runner {
264                    ScopedPass::ModulePass(_) => {
265                        if !pm.analyses.is_analysis_result_available(dep, module) {
266                            run_module_pass(pm, ir, dep, module)?;
267                        }
268                    }
269                    ScopedPass::FunctionPass(_) => {
270                        for f in module.function_iter(ir) {
271                            if !pm.analyses.is_analysis_result_available(dep, f) {
272                                run_function_pass(pm, ir, dep, f)?;
273                            }
274                        }
275                    }
276                }
277            }
278
279            // Get the pass again to satisfy the borrow checker.
280            let pass_t = pm.passes.get(pass).expect("Unregistered pass");
281            let ScopedPass::ModulePass(mp) = pass_t.runner.clone() else {
282                panic!("Expected a module pass");
283            };
284            match mp {
285                PassMutability::Analysis(analysis) => {
286                    let result = analysis(ir, &pm.analyses, module)?;
287                    pm.analyses.add_result(pass, module, result);
288                }
289                PassMutability::Transform(transform) => {
290                    if transform(ir, &pm.analyses, module)? {
291                        pm.analyses.invalidate_all_results_at_scope(module);
292                        for f in module.function_iter(ir) {
293                            pm.analyses.invalidate_all_results_at_scope(f);
294                        }
295                        modified = true;
296                    }
297                }
298            }
299
300            Ok(modified)
301        }
302
303        fn run_function_pass(
304            pm: &mut PassManager,
305            ir: &mut Context,
306            pass: &'static str,
307            function: Function,
308        ) -> Result<bool, IrError> {
309            let mut modified = false;
310            let pass_t = pm.passes.get(pass).expect("Unregistered pass");
311            for dep in pass_t.deps.clone() {
312                let dep_t = pm.passes.get(dep).expect("Unregistered dependent pass");
313                // If pass registration allows transformations as dependents, we could remove this I guess.
314                assert!(dep_t.is_analysis());
315                match dep_t.runner {
316                    ScopedPass::ModulePass(_) => {
317                        panic!("Function pass {pass} cannot depend on module pass {dep}")
318                    }
319                    ScopedPass::FunctionPass(_) => {
320                        if !pm.analyses.is_analysis_result_available(dep, function) {
321                            run_function_pass(pm, ir, dep, function)?;
322                        };
323                    }
324                }
325            }
326
327            // Get the pass again to satisfy the borrow checker.
328            let pass_t = pm.passes.get(pass).expect("Unregistered pass");
329            let ScopedPass::FunctionPass(fp) = pass_t.runner.clone() else {
330                panic!("Expected a function pass");
331            };
332            match fp {
333                PassMutability::Analysis(analysis) => {
334                    let result = analysis(ir, &pm.analyses, function)?;
335                    pm.analyses.add_result(pass, function, result);
336                }
337                PassMutability::Transform(transform) => {
338                    if transform(ir, &pm.analyses, function)? {
339                        pm.analyses.invalidate_all_results_at_scope(function);
340                        modified = true;
341                    }
342                }
343            }
344
345            Ok(modified)
346        }
347
348        for m in ir.module_iter() {
349            let pass_t = self.passes.get(pass).expect("Unregistered pass");
350            let pass_runner = pass_t.runner.clone();
351            match pass_runner {
352                ScopedPass::ModulePass(_) => {
353                    modified |= run_module_pass(self, ir, pass, m)?;
354                }
355                ScopedPass::FunctionPass(_) => {
356                    for f in m.function_iter(ir) {
357                        modified |= run_function_pass(self, ir, pass, f)?;
358                    }
359                }
360            }
361        }
362        Ok(modified)
363    }
364
365    /// Run the `passes` and return true if the `passes` modify the initial `ir`.
366    pub fn run(&mut self, ir: &mut Context, passes: &PassGroup) -> Result<bool, IrError> {
367        let mut modified = false;
368        for pass in passes.flatten_pass_group() {
369            modified |= self.actually_run(ir, pass)?;
370        }
371        Ok(modified)
372    }
373
374    /// Run the `passes` and return true if the `passes` modify the initial `ir`.
375    /// The IR states are printed and verified according to the options provided.
376    pub fn run_with_print_verify(
377        &mut self,
378        ir: &mut Context,
379        passes: &PassGroup,
380        print_opts: &PrintPassesOpts,
381        verify_opts: &VerifyPassesOpts,
382    ) -> Result<bool, IrError> {
383        // Empty IRs are result of compiling dependencies. We don't want to print those.
384        fn ir_is_empty(ir: &Context) -> bool {
385            ir.functions.is_empty()
386                && ir.blocks.is_empty()
387                && ir.values.is_empty()
388                && ir.local_vars.is_empty()
389        }
390
391        fn print_ir_after_pass(ir: &Context, pass: &Pass) {
392            if !ir_is_empty(ir) {
393                println!("// IR: [{}] {}", pass.name, pass.descr);
394                println!("{ir}");
395            }
396        }
397
398        fn print_initial_or_final_ir(ir: &Context, initial_or_final: &'static str) {
399            if !ir_is_empty(ir) {
400                println!("// IR: {initial_or_final}");
401                println!("{ir}");
402            }
403        }
404
405        if print_opts.initial {
406            print_initial_or_final_ir(ir, "Initial");
407        }
408
409        if verify_opts.initial {
410            ir.verify()?;
411        }
412
413        let mut modified = false;
414        for pass in passes.flatten_pass_group() {
415            let modified_in_pass = self.actually_run(ir, pass)?;
416
417            if print_opts.passes.contains(pass) && (!print_opts.modified_only || modified_in_pass) {
418                print_ir_after_pass(ir, self.lookup_registered_pass(pass).unwrap());
419            }
420
421            modified |= modified_in_pass;
422            if verify_opts.passes.contains(pass) && (!verify_opts.modified_only || modified_in_pass)
423            {
424                ir.verify()?;
425            }
426        }
427
428        if print_opts.r#final {
429            print_initial_or_final_ir(ir, "Final");
430        }
431
432        if verify_opts.r#final {
433            ir.verify()?;
434        }
435
436        Ok(modified)
437    }
438
439    /// Get reference to a registered pass.
440    pub fn lookup_registered_pass(&self, name: &str) -> Option<&Pass> {
441        self.passes.get(name)
442    }
443
444    pub fn help_text(&self) -> String {
445        let summary = self
446            .passes
447            .iter()
448            .map(|(name, pass)| format!("  {name:16} - {}", pass.descr))
449            .collect::<Vec<_>>()
450            .join("\n");
451
452        format!("Valid pass names are:\n\n{summary}",)
453    }
454}
455
/// An ordered collection of passes.
/// Entries can be individual passes or nested groups.
#[derive(Default)]
pub struct PassGroup(Vec<PassOrGroup>);

/// A single entry of a [PassGroup]: either one pass, or a nested group of passes.
pub enum PassOrGroup {
    Pass(&'static str),
    Group(PassGroup),
}

impl PassGroup {
    // Produce the names of all passes in this group and its nested groups,
    // in the order in which they are listed.
    fn flatten_pass_group(&self) -> Vec<&'static str> {
        // Depth-first walk that appends the pass names of `group` to `flat`.
        fn collect_names(group: &PassGroup, flat: &mut Vec<&'static str>) {
            for entry in &group.0 {
                match entry {
                    PassOrGroup::Pass(name) => flat.push(*name),
                    PassOrGroup::Group(nested) => collect_names(nested, flat),
                }
            }
        }

        let mut flat = Vec::new();
        collect_names(self, &mut flat);
        flat
    }

    /// Add `pass` as the last entry of this group.
    pub fn append_pass(&mut self, pass: &'static str) {
        let entry = PassOrGroup::Pass(pass);
        self.0.push(entry);
    }

    /// Add `group` as the last entry of this group, keeping it as a nested group.
    pub fn append_group(&mut self, group: PassGroup) {
        let entry = PassOrGroup::Group(group);
        self.0.push(entry);
    }
}
493
494/// A convenience utility to register known passes.
495pub fn register_known_passes(pm: &mut PassManager) {
496    // Analysis passes.
497    pm.register(create_postorder_pass());
498    pm.register(create_dominators_pass());
499    pm.register(create_dom_fronts_pass());
500    pm.register(create_escaped_symbols_pass());
501    pm.register(create_module_printer_pass());
502    pm.register(create_module_verifier_pass());
503
504    // Lowering passes.
505    pm.register(create_init_aggr_lowering_pass());
506
507    // Optimization passes.
508    pm.register(create_arg_pointee_mutability_tagger_pass());
509    pm.register(create_fn_dedup_release_profile_pass());
510    pm.register(create_fn_dedup_debug_profile_pass());
511    pm.register(create_mem2reg_pass());
512    pm.register(create_sroa_pass());
513    pm.register(create_fn_inline_pass());
514    pm.register(create_const_folding_pass());
515    pm.register(create_ccp_pass());
516    pm.register(create_simplify_cfg_pass());
517    pm.register(create_globals_dce_pass());
518    pm.register(create_dce_pass());
519    pm.register(create_cse_pass());
520    pm.register(create_arg_demotion_pass());
521    pm.register(create_const_demotion_pass());
522    pm.register(create_ret_demotion_pass());
523    pm.register(create_misc_demotion_pass());
524    pm.register(create_memcpyopt_pass());
525    pm.register(create_memcpyprop_reverse_pass());
526}
527
528pub fn create_o1_pass_group() -> PassGroup {
529    // Create a create_ccp_passo specify which passes we want to run now.
530    let mut o1 = PassGroup::default();
531    // Configure to run our passes.
532    o1.append_pass(MEM2REG_NAME);
533    o1.append_pass(FN_DEDUP_RELEASE_PROFILE_NAME);
534    o1.append_pass(FN_INLINE_NAME);
535    o1.append_pass(ARG_POINTEE_MUTABILITY_TAGGER_NAME);
536    o1.append_pass(SIMPLIFY_CFG_NAME);
537    o1.append_pass(GLOBALS_DCE_NAME);
538    o1.append_pass(DCE_NAME);
539    o1.append_pass(FN_INLINE_NAME);
540    o1.append_pass(ARG_POINTEE_MUTABILITY_TAGGER_NAME);
541    o1.append_pass(CCP_NAME);
542    o1.append_pass(CONST_FOLDING_NAME);
543    o1.append_pass(SIMPLIFY_CFG_NAME);
544    o1.append_pass(CSE_NAME);
545    o1.append_pass(CONST_FOLDING_NAME);
546    o1.append_pass(SIMPLIFY_CFG_NAME);
547    o1.append_pass(GLOBALS_DCE_NAME);
548    o1.append_pass(DCE_NAME);
549    o1.append_pass(FN_DEDUP_RELEASE_PROFILE_NAME);
550
551    o1
552}
553
554/// Utility to insert a pass after every pass in the given group `pg`.
555/// It preserves the `pg` group's structure. This means if `pg` has subgroups
556/// and those have subgroups, the resulting [PassGroup] will have the
557/// same subgroups, but with the `pass` inserted after every pass in every
558/// subgroup, as well as all passes outside of any groups.
559pub fn insert_after_each(pg: PassGroup, pass: &'static str) -> PassGroup {
560    fn insert_after_each_rec(pg: PassGroup, pass: &'static str) -> Vec<PassOrGroup> {
561        pg.0.into_iter()
562            .flat_map(|p_o_g| match p_o_g {
563                PassOrGroup::Group(group) => vec![PassOrGroup::Group(PassGroup(
564                    insert_after_each_rec(group, pass),
565                ))],
566                PassOrGroup::Pass(_) => vec![p_o_g, PassOrGroup::Pass(pass)],
567            })
568            .collect()
569    }
570
571    PassGroup(insert_after_each_rec(pg, pass))
572}