sway_ir/optimize/
sroa.rs

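//! Scalar Replacement of Aggregates (SROA): splits non-escaping local aggregates
//! into one local per scalar field, so that other optimizations such as mem2reg
//! can treat them as any other SSA value.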
2
3use rustc_hash::{FxHashMap, FxHashSet};
4
5use crate::{
6    combine_indices, get_gep_referred_symbols, get_loaded_ptr_values, get_stored_ptr_values,
7    pointee_size, AnalysisResults, Constant, ConstantValue, Context, EscapedSymbols, Function,
8    InstOp, IrError, LocalVar, Pass, PassMutability, ScopedPass, Symbol, Type, Value,
9    ESCAPED_SYMBOLS_NAME,
10};
11
12pub const SROA_NAME: &str = "sroa";
13
14pub fn create_sroa_pass() -> Pass {
15    Pass {
16        name: SROA_NAME,
17        descr: "Scalar replacement of aggregates",
18        deps: vec![ESCAPED_SYMBOLS_NAME],
19        runner: ScopedPass::FunctionPass(PassMutability::Transform(sroa)),
20    }
21}
22
23// Split at a local aggregate variable into its constituent scalars.
24// Returns a map from the offset of each scalar field to the new local created for it.
25fn split_aggregate(
26    context: &mut Context,
27    function: Function,
28    local_aggr: LocalVar,
29) -> FxHashMap<u32, LocalVar> {
30    let ty = local_aggr
31        .get_type(context)
32        .get_pointee_type(context)
33        .expect("Local not a pointer");
34    assert!(ty.is_aggregate(context));
35    let mut res = FxHashMap::default();
36    let aggr_base_name = function
37        .lookup_local_name(context, &local_aggr)
38        .cloned()
39        .unwrap_or("".to_string());
40
41    fn split_type(
42        context: &mut Context,
43        function: Function,
44        aggr_base_name: &String,
45        map: &mut FxHashMap<u32, LocalVar>,
46        ty: Type,
47        initializer: Option<Constant>,
48        base_off: &mut u32,
49    ) {
50        fn constant_index(context: &mut Context, c: &Constant, idx: usize) -> Constant {
51            match &c.get_content(context).value {
52                ConstantValue::Array(cs) | ConstantValue::Struct(cs) => Constant::unique(
53                    context,
54                    cs.get(idx)
55                        .expect("Malformed initializer. Cannot index into sub-initializer")
56                        .clone(),
57                ),
58                _ => panic!("Expected only array or struct const initializers"),
59            }
60        }
61        if !super::target_fuel::is_demotable_type(context, &ty) {
62            let ty_size: u32 = ty.size(context).in_bytes().try_into().unwrap();
63            let name = aggr_base_name.clone() + &base_off.to_string();
64            let scalarised_local =
65                function.new_unique_local_var(context, name, ty, initializer, false);
66            map.insert(*base_off, scalarised_local);
67
68            *base_off += ty_size;
69        } else {
70            let mut i = 0;
71            while let Some(member_ty) = ty.get_indexed_type(context, &[i]) {
72                let initializer = initializer
73                    .as_ref()
74                    .map(|c| constant_index(context, c, i as usize));
75                split_type(
76                    context,
77                    function,
78                    aggr_base_name,
79                    map,
80                    member_ty,
81                    initializer,
82                    base_off,
83                );
84
85                if ty.is_struct(context) {
86                    *base_off = crate::size_bytes_round_up_to_word_alignment!(*base_off);
87                }
88
89                i += 1;
90            }
91        }
92    }
93
94    let mut base_off = 0;
95    split_type(
96        context,
97        function,
98        &aggr_base_name,
99        &mut res,
100        ty,
101        local_aggr.get_initializer(context).cloned(),
102        &mut base_off,
103    );
104    res
105}
106
107/// Promote aggregates to scalars, so that other optimizations
108/// such as mem2reg can treat them as any other SSA value.
109pub fn sroa(
110    context: &mut Context,
111    analyses: &AnalysisResults,
112    function: Function,
113) -> Result<bool, IrError> {
114    let escaped_symbols: &EscapedSymbols = analyses.get_analysis_result(function);
115    let candidates = candidate_symbols(context, escaped_symbols, function);
116
117    if candidates.is_empty() {
118        return Ok(false);
119    }
120    // We now split each candidate into constituent scalar variables.
121    let offset_scalar_map: FxHashMap<Symbol, FxHashMap<u32, LocalVar>> = candidates
122        .iter()
123        .map(|sym| {
124            let Symbol::Local(local_aggr) = sym else {
125                panic!("Expected only local candidates")
126            };
127            (*sym, split_aggregate(context, function, *local_aggr))
128        })
129        .collect();
130
131    let mut scalar_replacements = FxHashMap::<Value, Value>::default();
132
133    for block in function.block_iter(context) {
134        let mut new_insts = Vec::new();
135        for inst in block.instruction_iter(context) {
136            if let InstOp::MemCopyVal {
137                dst_val_ptr,
138                src_val_ptr,
139            } = inst.get_instruction(context).unwrap().op
140            {
141                let src_syms = get_gep_referred_symbols(context, src_val_ptr);
142                let dst_syms = get_gep_referred_symbols(context, dst_val_ptr);
143
144                // If neither source nor dest needs rewriting, we skip.
145                let src_sym = src_syms
146                    .iter()
147                    .next()
148                    .filter(|src_sym| candidates.contains(src_sym));
149                let dst_sym = dst_syms
150                    .iter()
151                    .next()
152                    .filter(|dst_sym| candidates.contains(dst_sym));
153                if src_sym.is_none() && dst_sym.is_none() {
154                    new_insts.push(inst);
155                    continue;
156                }
157
158                struct ElmDetail {
159                    offset: u32,
160                    r#type: Type,
161                    indices: Vec<u32>,
162                }
163
164                // compute the offsets at which each (nested) field in our pointee type is at.
165                fn calc_elm_details(
166                    context: &Context,
167                    details: &mut Vec<ElmDetail>,
168                    ty: Type,
169                    base_off: &mut u32,
170                    base_index: &mut Vec<u32>,
171                ) {
172                    if !super::target_fuel::is_demotable_type(context, &ty) {
173                        let ty_size: u32 = ty.size(context).in_bytes().try_into().unwrap();
174                        details.push(ElmDetail {
175                            offset: *base_off,
176                            r#type: ty,
177                            indices: base_index.clone(),
178                        });
179                        *base_off += ty_size;
180                    } else {
181                        assert!(ty.is_aggregate(context));
182                        base_index.push(0);
183                        let mut i = 0;
184                        while let Some(member_ty) = ty.get_indexed_type(context, &[i]) {
185                            calc_elm_details(context, details, member_ty, base_off, base_index);
186                            i += 1;
187                            *base_index.last_mut().unwrap() += 1;
188
189                            if ty.is_struct(context) {
190                                *base_off =
191                                    crate::size_bytes_round_up_to_word_alignment!(*base_off);
192                            }
193                        }
194                        base_index.pop();
195                    }
196                }
197                let mut local_base_offset = 0;
198                let mut local_base_index = vec![];
199                let mut elm_details = vec![];
200                calc_elm_details(
201                    context,
202                    &mut elm_details,
203                    src_val_ptr
204                        .get_type(context)
205                        .unwrap()
206                        .get_pointee_type(context)
207                        .expect("Unable to determine pointee type of pointer"),
208                    &mut local_base_offset,
209                    &mut local_base_index,
210                );
211
212                // Handle the source pointer first.
213                let mut elm_local_map = FxHashMap::default();
214                if let Some(src_sym) = src_sym {
215                    // The source symbol is a candidate. So it has been split into scalars.
216                    // Load each of these into a SSA variable.
217                    let base_offset = combine_indices(context, src_val_ptr)
218                        .and_then(|indices| {
219                            src_sym
220                                .get_type(context)
221                                .get_pointee_type(context)
222                                .and_then(|pointee_ty| {
223                                    pointee_ty.get_value_indexed_offset(context, &indices)
224                                })
225                        })
226                        .expect("Source of memcpy was incorrectly identified as a candidate.")
227                        as u32;
228                    for detail in elm_details.iter() {
229                        let elm_offset = detail.offset;
230                        let actual_offset = elm_offset + base_offset;
231                        let remapped_var = offset_scalar_map
232                            .get(src_sym)
233                            .unwrap()
234                            .get(&actual_offset)
235                            .unwrap();
236                        let scalarized_local =
237                            Value::new_instruction(context, block, InstOp::GetLocal(*remapped_var));
238                        let load =
239                            Value::new_instruction(context, block, InstOp::Load(scalarized_local));
240                        elm_local_map.insert(elm_offset, load);
241                        new_insts.push(scalarized_local);
242                        new_insts.push(load);
243                    }
244                } else {
245                    // The source symbol is not a candidate. So it won't be split into scalars.
246                    // We must use GEPs to load each individual element into an SSA variable.
247                    for ElmDetail {
248                        offset,
249                        r#type,
250                        indices,
251                    } in &elm_details
252                    {
253                        let elm_addr = if indices.is_empty() {
254                            // We're looking at a pointer to a scalar, so no GEP needed.
255                            src_val_ptr
256                        } else {
257                            let elm_index_values = indices
258                                .iter()
259                                .map(|&index| Value::new_u64_constant(context, index.into()))
260                                .collect();
261                            let elem_ptr_ty = Type::new_typed_pointer(context, *r#type);
262                            let gep = Value::new_instruction(
263                                context,
264                                block,
265                                InstOp::GetElemPtr {
266                                    base: src_val_ptr,
267                                    elem_ptr_ty,
268                                    indices: elm_index_values,
269                                },
270                            );
271                            new_insts.push(gep);
272                            gep
273                        };
274                        let load = Value::new_instruction(context, block, InstOp::Load(elm_addr));
275                        elm_local_map.insert(*offset, load);
276                        new_insts.push(load);
277                    }
278                }
279                if let Some(dst_sym) = dst_sym {
280                    // The dst symbol is a candidate. So it has been split into scalars.
281                    // Store to each of these from the SSA variable we created above.
282                    let base_offset = combine_indices(context, dst_val_ptr)
283                        .and_then(|indices| {
284                            dst_sym
285                                .get_type(context)
286                                .get_pointee_type(context)
287                                .and_then(|pointee_ty| {
288                                    pointee_ty.get_value_indexed_offset(context, &indices)
289                                })
290                        })
291                        .expect("Source of memcpy was incorrectly identified as a candidate.")
292                        as u32;
293                    for detail in elm_details.iter() {
294                        let elm_offset = detail.offset;
295                        let actual_offset = elm_offset + base_offset;
296                        let remapped_var = offset_scalar_map
297                            .get(dst_sym)
298                            .unwrap()
299                            .get(&actual_offset)
300                            .unwrap();
301                        let scalarized_local =
302                            Value::new_instruction(context, block, InstOp::GetLocal(*remapped_var));
303                        let loaded_source = elm_local_map
304                            .get(&elm_offset)
305                            .expect("memcpy source not loaded");
306                        let store = Value::new_instruction(
307                            context,
308                            block,
309                            InstOp::Store {
310                                dst_val_ptr: scalarized_local,
311                                stored_val: *loaded_source,
312                            },
313                        );
314                        new_insts.push(scalarized_local);
315                        new_insts.push(store);
316                    }
317                } else {
318                    // The dst symbol is not a candidate. So it won't be split into scalars.
319                    // We must use GEPs to store to each individual element from its SSA variable.
320                    for ElmDetail {
321                        offset,
322                        r#type,
323                        indices,
324                    } in elm_details
325                    {
326                        let elm_addr = if indices.is_empty() {
327                            // We're looking at a pointer to a scalar, so no GEP needed.
328                            dst_val_ptr
329                        } else {
330                            let elm_index_values = indices
331                                .iter()
332                                .map(|&index| Value::new_u64_constant(context, index.into()))
333                                .collect();
334                            let elem_ptr_ty = Type::new_typed_pointer(context, r#type);
335                            let gep = Value::new_instruction(
336                                context,
337                                block,
338                                InstOp::GetElemPtr {
339                                    base: dst_val_ptr,
340                                    elem_ptr_ty,
341                                    indices: elm_index_values,
342                                },
343                            );
344                            new_insts.push(gep);
345                            gep
346                        };
347                        let loaded_source = elm_local_map
348                            .get(&offset)
349                            .expect("memcpy source not loaded");
350                        let store = Value::new_instruction(
351                            context,
352                            block,
353                            InstOp::Store {
354                                dst_val_ptr: elm_addr,
355                                stored_val: *loaded_source,
356                            },
357                        );
358                        new_insts.push(store);
359                    }
360                }
361
362                // We've handled the memcpy. it's been replaced with other instructions.
363                continue;
364            }
365            let loaded_pointers = get_loaded_ptr_values(context, inst);
366            let stored_pointers = get_stored_ptr_values(context, inst);
367
368            for ptr in loaded_pointers.iter().chain(stored_pointers.iter()) {
369                let syms = get_gep_referred_symbols(context, *ptr);
370                if let Some(sym) = syms
371                    .iter()
372                    .next()
373                    .filter(|sym| syms.len() == 1 && candidates.contains(sym))
374                {
375                    let Some(offset) = combine_indices(context, *ptr).and_then(|indices| {
376                        sym.get_type(context)
377                            .get_pointee_type(context)
378                            .and_then(|pointee_ty| {
379                                pointee_ty.get_value_indexed_offset(context, &indices)
380                            })
381                    }) else {
382                        continue;
383                    };
384                    let remapped_var = offset_scalar_map
385                        .get(sym)
386                        .unwrap()
387                        .get(&(offset as u32))
388                        .unwrap();
389                    let scalarized_local =
390                        Value::new_instruction(context, block, InstOp::GetLocal(*remapped_var));
391                    new_insts.push(scalarized_local);
392                    scalar_replacements.insert(*ptr, scalarized_local);
393                }
394            }
395            new_insts.push(inst);
396        }
397        block.take_body(context, new_insts);
398    }
399
400    function.replace_values(context, &scalar_replacements, None);
401
402    Ok(true)
403}
404
405// Is the aggregate type something that we can handle?
406fn is_processable_aggregate(context: &Context, ty: Type) -> bool {
407    fn check_sub_types(context: &Context, ty: Type) -> bool {
408        match ty.get_content(context) {
409            crate::TypeContent::Unit => true,
410            crate::TypeContent::Bool => true,
411            crate::TypeContent::Uint(width) => *width <= 64,
412            crate::TypeContent::B256 => false,
413            crate::TypeContent::Array(elm_ty, _) => check_sub_types(context, *elm_ty),
414            crate::TypeContent::Union(_) => false,
415            crate::TypeContent::Struct(fields) => {
416                fields.iter().all(|ty| check_sub_types(context, *ty))
417            }
418            crate::TypeContent::Slice => false,
419            crate::TypeContent::TypedSlice(..) => false,
420            crate::TypeContent::Pointer => true,
421            crate::TypeContent::TypedPointer(_) => true,
422            crate::TypeContent::StringSlice => false,
423            crate::TypeContent::StringArray(_) => false,
424            crate::TypeContent::Never => false,
425        }
426    }
427    ty.is_aggregate(context) && check_sub_types(context, ty)
428}
429
430// Filter out candidates that may not be profitable to scalarise.
431// This can be tuned in detail in the future when we have real benchmarks.
432fn profitability(context: &Context, function: Function, candidates: &mut FxHashSet<Symbol>) {
433    // If a candidate is sufficiently big and there's at least one memcpy
434    // accessing a big part of it, it may not be wise to scalarise it.
435    for (_, inst) in function.instruction_iter(context) {
436        if let InstOp::MemCopyVal {
437            dst_val_ptr,
438            src_val_ptr,
439        } = inst.get_instruction(context).unwrap().op
440        {
441            if pointee_size(context, dst_val_ptr) > 200 {
442                for sym in get_gep_referred_symbols(context, dst_val_ptr)
443                    .union(&get_gep_referred_symbols(context, src_val_ptr))
444                {
445                    candidates.remove(sym);
446                }
447            }
448        }
449    }
450}
451
452/// Only the following aggregates can be scalarised:
453/// 1. Does not escape.
454/// 2. Is always accessed via a scalar (register sized) field.
455///    i.e., The entire aggregate or a sub-aggregate isn't loaded / stored.
456///    (with an exception of `mem_copy_val` which we can handle).
457/// 3. Never accessed via non-const indexing.
458/// 4. Not aliased via a pointer that may point to more than one symbol.
459fn candidate_symbols(
460    context: &Context,
461    escaped_symbols: &EscapedSymbols,
462    function: Function,
463) -> FxHashSet<Symbol> {
464    let escaped_symbols = match escaped_symbols {
465        EscapedSymbols::Complete(syms) => syms,
466        EscapedSymbols::Incomplete(_) => return FxHashSet::<_>::default(),
467    };
468
469    let mut candidates: FxHashSet<Symbol> = function
470        .locals_iter(context)
471        .filter_map(|(_, l)| {
472            let sym = Symbol::Local(*l);
473            (!escaped_symbols.contains(&sym)
474                && l.get_type(context)
475                    .get_pointee_type(context)
476                    .is_some_and(|pointee_ty| is_processable_aggregate(context, pointee_ty)))
477            .then_some(sym)
478        })
479        .collect();
480
481    // We walk the function to remove from `candidates`, any local that is
482    // 1. accessed by a bigger-than-register sized load / store.
483    //    (we make an exception for load / store in `mem_copy_val` as that can be handled).
484    // 2. OR accessed via a non-const indexing.
485    // 3. OR aliased to a pointer that may point to more than one symbol.
486    for (_, inst) in function.instruction_iter(context) {
487        let loaded_pointers = get_loaded_ptr_values(context, inst);
488        let stored_pointers = get_stored_ptr_values(context, inst);
489
490        let inst = inst.get_instruction(context).unwrap();
491        for ptr in loaded_pointers.iter().chain(stored_pointers.iter()) {
492            let syms = get_gep_referred_symbols(context, *ptr);
493            if syms.len() != 1 {
494                for sym in &syms {
495                    candidates.remove(sym);
496                }
497                continue;
498            }
499            if combine_indices(context, *ptr)
500                .is_some_and(|indices| indices.iter().any(|idx| !idx.is_constant(context)))
501                || ptr.match_ptr_type(context).is_some_and(|pointee_ty| {
502                    super::target_fuel::is_demotable_type(context, &pointee_ty)
503                        && !matches!(inst.op, InstOp::MemCopyVal { .. })
504                })
505            {
506                candidates.remove(syms.iter().next().unwrap());
507            }
508        }
509    }
510
511    profitability(context, function, &mut candidates);
512
513    candidates
514}