sway_ir/optimize/
memcpyopt.rs

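//! Optimisations related to mem_copy.
//! - replace a `store` directly from a `load` with a `mem_copy_val`.
//! - copy propagation between local symbols (before that rewrite) and of `memcpy`s
//!   within a block (after it).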
3
4use indexmap::IndexMap;
5use itertools::{Either, Itertools};
6use rustc_hash::{FxHashMap, FxHashSet};
7use sway_types::{FxIndexMap, FxIndexSet};
8
9use crate::{
10    get_gep_symbol, get_loaded_symbols, get_referred_symbol, get_referred_symbols,
11    get_stored_symbols, memory_utils, AnalysisResults, Block, Context, EscapedSymbols,
12    FuelVmInstruction, Function, InstOp, Instruction, InstructionInserter, IrError, LocalVar, Pass,
13    PassMutability, ReferredSymbols, ScopedPass, Symbol, Type, Value, ValueDatum,
14    ESCAPED_SYMBOLS_NAME,
15};
16
17pub const MEMCPYOPT_NAME: &str = "memcpyopt";
18
19pub fn create_memcpyopt_pass() -> Pass {
20    Pass {
21        name: MEMCPYOPT_NAME,
22        descr: "Optimizations related to MemCopy instructions",
23        deps: vec![ESCAPED_SYMBOLS_NAME],
24        runner: ScopedPass::FunctionPass(PassMutability::Transform(mem_copy_opt)),
25    }
26}
27
28pub fn mem_copy_opt(
29    context: &mut Context,
30    analyses: &AnalysisResults,
31    function: Function,
32) -> Result<bool, IrError> {
33    let mut modified = false;
34    modified |= local_copy_prop_prememcpy(context, analyses, function)?;
35    modified |= load_store_to_memcopy(context, function)?;
36    modified |= local_copy_prop(context, analyses, function)?;
37
38    Ok(modified)
39}
40
41fn local_copy_prop_prememcpy(
42    context: &mut Context,
43    analyses: &AnalysisResults,
44    function: Function,
45) -> Result<bool, IrError> {
46    struct InstInfo {
47        // The block containing the instruction.
48        block: Block,
49        // Relative (use only for comparison) position of instruction in `block`.
50        pos: usize,
51    }
52
53    // If the analysis result is incomplete we cannot do any safe optimizations here.
54    // Calculating the candidates below relies on complete result of an escape analysis.
55    let escaped_symbols = match analyses.get_analysis_result(function) {
56        EscapedSymbols::Complete(syms) => syms,
57        EscapedSymbols::Incomplete(_) => return Ok(false),
58    };
59
60    // All instructions that load from the `Symbol`.
61    let mut loads_map = FxHashMap::<Symbol, Vec<Value>>::default();
62    // All instructions that store to the `Symbol`.
63    let mut stores_map = FxHashMap::<Symbol, Vec<Value>>::default();
64    // All load and store instructions.
65    let mut instr_info_map = FxHashMap::<Value, InstInfo>::default();
66
67    for (pos, (block, inst)) in function.instruction_iter(context).enumerate() {
68        let info = || InstInfo { block, pos };
69        let inst_e = inst.get_instruction(context).unwrap();
70        match inst_e {
71            Instruction {
72                op: InstOp::Load(src_val_ptr),
73                ..
74            } => {
75                if let Some(local) = get_referred_symbol(context, *src_val_ptr) {
76                    loads_map
77                        .entry(local)
78                        .and_modify(|loads| loads.push(inst))
79                        .or_insert(vec![inst]);
80                    instr_info_map.insert(inst, info());
81                }
82            }
83            Instruction {
84                op: InstOp::Store { dst_val_ptr, .. },
85                ..
86            } => {
87                if let Some(local) = get_referred_symbol(context, *dst_val_ptr) {
88                    stores_map
89                        .entry(local)
90                        .and_modify(|stores| stores.push(inst))
91                        .or_insert(vec![inst]);
92                    instr_info_map.insert(inst, info());
93                }
94            }
95            _ => (),
96        }
97    }
98
99    let mut to_delete = FxHashSet::<Value>::default();
100    // Candidates for replacements. The map's key `Symbol` is the
101    // destination `Symbol` that can be replaced with the
102    // map's value `Symbol`, the source.
103    // Replacement is possible (among other criteria explained below)
104    // only if the Store of the source is the only storing to the destination.
105    let candidates: FxHashMap<Symbol, Symbol> = function
106        .instruction_iter(context)
107        .enumerate()
108        .filter_map(|(pos, (block, instr_val))| {
109            // 1. Go through all the Store instructions whose source is
110            // a Load instruction...
111            instr_val
112                .get_instruction(context)
113                .and_then(|instr| {
114                    // Is the instruction a Store?
115                    if let Instruction {
116                        op:
117                            InstOp::Store {
118                                dst_val_ptr,
119                                stored_val,
120                            },
121                        ..
122                    } = instr
123                    {
124                        get_gep_symbol(context, *dst_val_ptr).and_then(|dst_local| {
125                            stored_val
126                                .get_instruction(context)
127                                .map(|src_instr| (src_instr, stored_val, dst_local))
128                        })
129                    } else {
130                        None
131                    }
132                })
133                .and_then(|(src_instr, stored_val, dst_local)| {
134                    // Is the Store source a Load?
135                    if let Instruction {
136                        op: InstOp::Load(src_val_ptr),
137                        ..
138                    } = src_instr
139                    {
140                        get_gep_symbol(context, *src_val_ptr)
141                            .map(|src_local| (stored_val, dst_local, src_local))
142                    } else {
143                        None
144                    }
145                })
146                .and_then(|(src_load, dst_local, src_local)| {
147                    // 2. ... and pick the (dest_local, src_local) pairs that fulfill the
148                    //    below criteria, in other words, where `dest_local` can be
149                    //    replaced with `src_local`.
150                    let (temp_empty1, temp_empty2, temp_empty3) = (vec![], vec![], vec![]);
151                    let dst_local_stores = stores_map.get(&dst_local).unwrap_or(&temp_empty1);
152                    let src_local_stores = stores_map.get(&src_local).unwrap_or(&temp_empty2);
153                    let dst_local_loads = loads_map.get(&dst_local).unwrap_or(&temp_empty3);
154                    // This must be the only store of dst_local.
155                    if dst_local_stores.len() != 1 || dst_local_stores[0] != instr_val
156                        ||
157                        // All stores of src_local must be in the same block, prior to src_load.
158                        !src_local_stores.iter().all(|store_val|{
159                            let instr_info = instr_info_map.get(store_val).unwrap();
160                            let src_load_info = instr_info_map.get(src_load).unwrap();
161                            instr_info.block == block && instr_info.pos < src_load_info.pos
162                        })
163                        ||
164                        // All loads of dst_local must be after this instruction, in the same block.
165                        !dst_local_loads.iter().all(|load_val| {
166                            let instr_info = instr_info_map.get(load_val).unwrap();
167                            instr_info.block == block && instr_info.pos > pos
168                        })
169                        // We don't deal with symbols that escape.
170                        || escaped_symbols.contains(&dst_local)
171                        || escaped_symbols.contains(&src_local)
172                        // We don't deal part copies.
173                        || dst_local.get_type(context) != src_local.get_type(context)
174                        // We don't replace the destination when it's an arg.
175                        || matches!(dst_local, Symbol::Arg(_))
176                    {
177                        None
178                    } else {
179                        to_delete.insert(instr_val);
180                        Some((dst_local, src_local))
181                    }
182                })
183        })
184        .collect();
185
186    // If we have A replaces B and B replaces C, then A must replace C also.
187    // Recursively searches for the final replacement for the `local`.
188    // Returns `None` if the `local` cannot be replaced.
189    fn get_replace_with(candidates: &FxHashMap<Symbol, Symbol>, local: &Symbol) -> Option<Symbol> {
190        candidates
191            .get(local)
192            .map(|replace_with| get_replace_with(candidates, replace_with).unwrap_or(*replace_with))
193    }
194
195    // If the source is an Arg, we replace uses of destination with Arg.
196    // Otherwise (`get_local`), we replace the local symbol in-place.
197    enum ReplaceWith {
198        InPlaceLocal(LocalVar),
199        Value(Value),
200    }
201
202    // Because we can't borrow context for both iterating and replacing, do it in 2 steps.
203    // `replaces` are the original GetLocal instructions with the corresponding replacements
204    // of their arguments.
205    let replaces: Vec<_> = function
206        .instruction_iter(context)
207        .filter_map(|(_block, value)| match value.get_instruction(context) {
208            Some(Instruction {
209                op: InstOp::GetLocal(local),
210                ..
211            }) => get_replace_with(&candidates, &Symbol::Local(*local)).map(|replace_with| {
212                (
213                    value,
214                    match replace_with {
215                        Symbol::Local(local) => ReplaceWith::InPlaceLocal(local),
216                        Symbol::Arg(ba) => {
217                            ReplaceWith::Value(ba.block.get_arg(context, ba.idx).unwrap())
218                        }
219                    },
220                )
221            }),
222            _ => None,
223        })
224        .collect();
225
226    let mut value_replace = FxHashMap::<Value, Value>::default();
227    for (value, replace_with) in replaces.into_iter() {
228        match replace_with {
229            ReplaceWith::InPlaceLocal(replacement_var) => {
230                let Some(&Instruction {
231                    op: InstOp::GetLocal(redundant_var),
232                    parent,
233                }) = value.get_instruction(context)
234                else {
235                    panic!("earlier match now fails");
236                };
237                if redundant_var.is_mutable(context) {
238                    replacement_var.set_mutable(context, true);
239                }
240                value.replace(
241                    context,
242                    ValueDatum::Instruction(Instruction {
243                        op: InstOp::GetLocal(replacement_var),
244                        parent,
245                    }),
246                )
247            }
248            ReplaceWith::Value(replace_with) => {
249                value_replace.insert(value, replace_with);
250            }
251        }
252    }
253    function.replace_values(context, &value_replace, None);
254
255    // Delete stores to the replaced local.
256    let blocks: Vec<Block> = function.block_iter(context).collect();
257    for block in blocks {
258        block.remove_instructions(context, |value| to_delete.contains(&value));
259    }
260    Ok(true)
261}
262
263// Deconstruct a memcpy into (dst_val_ptr, src_val_ptr, copy_len).
264fn deconstruct_memcpy(context: &Context, inst: Value) -> Option<(Value, Value, u64)> {
265    match inst.get_instruction(context).unwrap() {
266        Instruction {
267            op:
268                InstOp::MemCopyBytes {
269                    dst_val_ptr,
270                    src_val_ptr,
271                    byte_len,
272                },
273            ..
274        } => Some((*dst_val_ptr, *src_val_ptr, *byte_len)),
275        Instruction {
276            op:
277                InstOp::MemCopyVal {
278                    dst_val_ptr,
279                    src_val_ptr,
280                },
281            ..
282        } => Some((
283            *dst_val_ptr,
284            *src_val_ptr,
285            memory_utils::pointee_size(context, *dst_val_ptr),
286        )),
287        _ => None,
288    }
289}
290
291/// Copy propagation of `memcpy`s within a block.
292fn local_copy_prop(
293    context: &mut Context,
294    analyses: &AnalysisResults,
295    function: Function,
296) -> Result<bool, IrError> {
297    // If the analysis result is incomplete we cannot do any safe optimizations here.
298    // The `gen_new_copy` and `process_load` functions below rely on the fact that the
299    // analyzed symbols do not escape, something we cannot guarantee in case of
300    // an incomplete collection of escaped symbols.
301    let escaped_symbols = match analyses.get_analysis_result(function) {
302        EscapedSymbols::Complete(syms) => syms,
303        EscapedSymbols::Incomplete(_) => return Ok(false),
304    };
305
306    // Currently (as we scan a block) available `memcpy`s.
307    let mut available_copies: FxHashSet<Value>;
308    // Map a symbol to the available `memcpy`s of which it's a source.
309    let mut src_to_copies: FxIndexMap<Symbol, FxIndexSet<Value>>;
310    // Map a symbol to the available `memcpy`s of which it's a destination.
311    // (multiple `memcpy`s for the same destination may be available when
312    // they are partial / field writes, and don't alias).
313    let mut dest_to_copies: FxIndexMap<Symbol, FxIndexSet<Value>>;
314
315    // If a value (symbol) is found to be defined, remove it from our tracking.
316    fn kill_defined_symbol(
317        context: &Context,
318        value: Value,
319        len: u64,
320        available_copies: &mut FxHashSet<Value>,
321        src_to_copies: &mut FxIndexMap<Symbol, FxIndexSet<Value>>,
322        dest_to_copies: &mut FxIndexMap<Symbol, FxIndexSet<Value>>,
323    ) {
324        match get_referred_symbols(context, value) {
325            ReferredSymbols::Complete(rs) => {
326                for sym in rs {
327                    if let Some(copies) = src_to_copies.get_mut(&sym) {
328                        for copy in &*copies {
329                            let (_, src_ptr, copy_size) = deconstruct_memcpy(context, *copy)
330                                .expect("Expected copy instruction");
331                            if memory_utils::may_alias(context, value, len, src_ptr, copy_size) {
332                                available_copies.remove(copy);
333                            }
334                        }
335                    }
336                    if let Some(copies) = dest_to_copies.get_mut(&sym) {
337                        for copy in &*copies {
338                            let (dest_ptr, copy_size) = match copy.get_instruction(context).unwrap()
339                            {
340                                Instruction {
341                                    op:
342                                        InstOp::MemCopyBytes {
343                                            dst_val_ptr,
344                                            src_val_ptr: _,
345                                            byte_len,
346                                        },
347                                    ..
348                                } => (*dst_val_ptr, *byte_len),
349                                Instruction {
350                                    op:
351                                        InstOp::MemCopyVal {
352                                            dst_val_ptr,
353                                            src_val_ptr: _,
354                                        },
355                                    ..
356                                } => (
357                                    *dst_val_ptr,
358                                    memory_utils::pointee_size(context, *dst_val_ptr),
359                                ),
360                                _ => panic!("Unexpected copy instruction"),
361                            };
362                            if memory_utils::may_alias(context, value, len, dest_ptr, copy_size) {
363                                available_copies.remove(copy);
364                            }
365                        }
366                    }
367                }
368                // Update src_to_copies and dest_to_copies to remove every copy not in available_copies.
369                src_to_copies.retain(|_, copies| {
370                    copies.retain(|copy| available_copies.contains(copy));
371                    !copies.is_empty()
372                });
373                dest_to_copies.retain(|_, copies| {
374                    copies.retain(|copy| available_copies.contains(copy));
375                    !copies.is_empty()
376                });
377            }
378            ReferredSymbols::Incomplete(_) => {
379                // The only safe thing we can do is to clear all information.
380                available_copies.clear();
381                src_to_copies.clear();
382                dest_to_copies.clear();
383            }
384        }
385    }
386
387    #[allow(clippy::too_many_arguments)]
388    fn gen_new_copy(
389        context: &Context,
390        escaped_symbols: &FxHashSet<Symbol>,
391        copy_inst: Value,
392        dst_val_ptr: Value,
393        src_val_ptr: Value,
394        available_copies: &mut FxHashSet<Value>,
395        src_to_copies: &mut FxIndexMap<Symbol, FxIndexSet<Value>>,
396        dest_to_copies: &mut FxIndexMap<Symbol, FxIndexSet<Value>>,
397    ) {
398        if let (Some(dst_sym), Some(src_sym)) = (
399            get_gep_symbol(context, dst_val_ptr),
400            get_gep_symbol(context, src_val_ptr),
401        ) {
402            if escaped_symbols.contains(&dst_sym) || escaped_symbols.contains(&src_sym) {
403                return;
404            }
405            dest_to_copies
406                .entry(dst_sym)
407                .and_modify(|set| {
408                    set.insert(copy_inst);
409                })
410                .or_insert([copy_inst].into_iter().collect());
411            src_to_copies
412                .entry(src_sym)
413                .and_modify(|set| {
414                    set.insert(copy_inst);
415                })
416                .or_insert([copy_inst].into_iter().collect());
417            available_copies.insert(copy_inst);
418        }
419    }
420
421    struct ReplGep {
422        base: Symbol,
423        elem_ptr_ty: Type,
424        indices: Vec<Value>,
425    }
426    enum Replacement {
427        OldGep(Value),
428        NewGep(ReplGep),
429    }
430
431    fn process_load(
432        context: &Context,
433        escaped_symbols: &FxHashSet<Symbol>,
434        inst: Value,
435        src_val_ptr: Value,
436        dest_to_copies: &FxIndexMap<Symbol, FxIndexSet<Value>>,
437        replacements: &mut FxHashMap<Value, (Value, Replacement)>,
438    ) -> bool {
439        // For every `memcpy` that src_val_ptr is a destination of,
440        // check if we can do the load from the source of that memcpy.
441        if let Some(src_sym) = get_referred_symbol(context, src_val_ptr) {
442            if escaped_symbols.contains(&src_sym) {
443                return false;
444            }
445            for memcpy in dest_to_copies
446                .get(&src_sym)
447                .iter()
448                .flat_map(|set| set.iter())
449            {
450                let (dst_ptr_memcpy, src_ptr_memcpy, copy_len) =
451                    deconstruct_memcpy(context, *memcpy).expect("Expected copy instruction");
452                // If the location where we're loading from exactly matches the destination of
453                // the memcpy, just load from the source pointer of the memcpy.
454                // TODO: In both the arms below, we check that the pointer type
455                // matches. This isn't really needed as the copy happens and the
456                // data we want is safe to access. But we just don't know how to
457                // generate the right GEP always. So that's left for another day.
458                if memory_utils::must_alias(
459                    context,
460                    src_val_ptr,
461                    memory_utils::pointee_size(context, src_val_ptr),
462                    dst_ptr_memcpy,
463                    copy_len,
464                ) {
465                    // Replace src_val_ptr with src_ptr_memcpy.
466                    if src_val_ptr.get_type(context) == src_ptr_memcpy.get_type(context) {
467                        replacements
468                            .insert(inst, (src_val_ptr, Replacement::OldGep(src_ptr_memcpy)));
469                        return true;
470                    }
471                } else {
472                    // if the memcpy copies the entire symbol, we could
473                    // insert a new GEP from the source of the memcpy.
474                    if let (Some(memcpy_src_sym), Some(memcpy_dst_sym), Some(new_indices)) = (
475                        get_gep_symbol(context, src_ptr_memcpy),
476                        get_gep_symbol(context, dst_ptr_memcpy),
477                        memory_utils::combine_indices(context, src_val_ptr),
478                    ) {
479                        let memcpy_src_sym_type = memcpy_src_sym
480                            .get_type(context)
481                            .get_pointee_type(context)
482                            .unwrap();
483                        let memcpy_dst_sym_type = memcpy_dst_sym
484                            .get_type(context)
485                            .get_pointee_type(context)
486                            .unwrap();
487                        if memcpy_src_sym_type == memcpy_dst_sym_type
488                            && memcpy_dst_sym_type.size(context).in_bytes() == copy_len
489                        {
490                            replacements.insert(
491                                inst,
492                                (
493                                    src_val_ptr,
494                                    Replacement::NewGep(ReplGep {
495                                        base: memcpy_src_sym,
496                                        elem_ptr_ty: src_val_ptr.get_type(context).unwrap(),
497                                        indices: new_indices,
498                                    }),
499                                ),
500                            );
501                            return true;
502                        }
503                    }
504                }
505            }
506        }
507
508        false
509    }
510
511    let mut modified = false;
512    for block in function.block_iter(context) {
513        // A `memcpy` itself has a `load`, so we can `process_load` on it.
514        // If now, we've marked the source of this `memcpy` for optimization,
515        // it itself cannot be "generated" as a new candidate `memcpy`.
516        // This is the reason we run a loop on the block till there's no more
517        // optimization possible. We could track just the changes and do it
518        // all in one go, but that would complicate the algorithm. So I've
519        // marked this as a TODO for now (#4600).
520        loop {
521            available_copies = FxHashSet::default();
522            src_to_copies = IndexMap::default();
523            dest_to_copies = IndexMap::default();
524
525            // Replace the load/memcpy source pointer with something else.
526            let mut replacements = FxHashMap::default();
527
528            fn kill_escape_args(
529                context: &Context,
530                args: &Vec<Value>,
531                available_copies: &mut FxHashSet<Value>,
532                src_to_copies: &mut FxIndexMap<Symbol, FxIndexSet<Value>>,
533                dest_to_copies: &mut FxIndexMap<Symbol, FxIndexSet<Value>>,
534            ) {
535                for arg in args {
536                    match get_referred_symbols(context, *arg) {
537                        ReferredSymbols::Complete(rs) => {
538                            let max_size = rs
539                                .iter()
540                                .filter_map(|sym| {
541                                    sym.get_type(context)
542                                        .get_pointee_type(context)
543                                        .map(|pt| pt.size(context).in_bytes())
544                                })
545                                .max()
546                                .unwrap_or(0);
547                            kill_defined_symbol(
548                                context,
549                                *arg,
550                                max_size,
551                                available_copies,
552                                src_to_copies,
553                                dest_to_copies,
554                            );
555                        }
556                        ReferredSymbols::Incomplete(_) => {
557                            // The only safe thing we can do is to clear all information.
558                            available_copies.clear();
559                            src_to_copies.clear();
560                            dest_to_copies.clear();
561
562                            break;
563                        }
564                    }
565                }
566            }
567
568            for inst in block.instruction_iter(context) {
569                match inst.get_instruction(context).unwrap() {
570                    Instruction {
571                        op: InstOp::Call(callee, args),
572                        ..
573                    } => {
574                        let (immutable_args, mutable_args): (Vec<_>, Vec<_>) =
575                            args.iter().enumerate().partition_map(|(arg_idx, arg)| {
576                                if callee.is_arg_immutable(context, arg_idx) {
577                                    Either::Left(*arg)
578                                } else {
579                                    Either::Right(*arg)
580                                }
581                            });
582                        // whichever args may get mutated, we kill them.
583                        kill_escape_args(
584                            context,
585                            &mutable_args,
586                            &mut available_copies,
587                            &mut src_to_copies,
588                            &mut dest_to_copies,
589                        );
590                        // args that aren't mutated can be treated as a "load" (for the purposes
591                        // of optimization).
592                        for arg in immutable_args {
593                            process_load(
594                                context,
595                                escaped_symbols,
596                                inst,
597                                arg,
598                                &dest_to_copies,
599                                &mut replacements,
600                            );
601                        }
602                    }
603                    Instruction {
604                        op: InstOp::AsmBlock(_, args),
605                        ..
606                    } => {
607                        let args = args.iter().filter_map(|arg| arg.initializer).collect();
608                        kill_escape_args(
609                            context,
610                            &args,
611                            &mut available_copies,
612                            &mut src_to_copies,
613                            &mut dest_to_copies,
614                        );
615                    }
616                    Instruction {
617                        op: InstOp::IntToPtr(_, _),
618                        ..
619                    } => {
620                        // The only safe thing we can do is to clear all information.
621                        available_copies.clear();
622                        src_to_copies.clear();
623                        dest_to_copies.clear();
624                    }
625                    Instruction {
626                        op: InstOp::Load(src_val_ptr),
627                        ..
628                    } => {
629                        process_load(
630                            context,
631                            escaped_symbols,
632                            inst,
633                            *src_val_ptr,
634                            &dest_to_copies,
635                            &mut replacements,
636                        );
637                    }
638                    Instruction {
639                        op: InstOp::MemCopyBytes { .. } | InstOp::MemCopyVal { .. },
640                        ..
641                    } => {
642                        let (dst_val_ptr, src_val_ptr, copy_len) =
643                            deconstruct_memcpy(context, inst).expect("Expected copy instruction");
644                        kill_defined_symbol(
645                            context,
646                            dst_val_ptr,
647                            copy_len,
648                            &mut available_copies,
649                            &mut src_to_copies,
650                            &mut dest_to_copies,
651                        );
652                        // If this memcpy itself can be optimized, we do just that, and not "gen" a new one.
653                        if !process_load(
654                            context,
655                            escaped_symbols,
656                            inst,
657                            src_val_ptr,
658                            &dest_to_copies,
659                            &mut replacements,
660                        ) {
661                            gen_new_copy(
662                                context,
663                                escaped_symbols,
664                                inst,
665                                dst_val_ptr,
666                                src_val_ptr,
667                                &mut available_copies,
668                                &mut src_to_copies,
669                                &mut dest_to_copies,
670                            );
671                        }
672                    }
673                    Instruction {
674                        op:
675                            InstOp::Store {
676                                dst_val_ptr,
677                                stored_val: _,
678                            },
679                        ..
680                    } => {
681                        kill_defined_symbol(
682                            context,
683                            *dst_val_ptr,
684                            memory_utils::pointee_size(context, *dst_val_ptr),
685                            &mut available_copies,
686                            &mut src_to_copies,
687                            &mut dest_to_copies,
688                        );
689                    }
690                    Instruction {
691                        op:
692                            InstOp::FuelVm(
693                                FuelVmInstruction::WideBinaryOp { result, .. }
694                                | FuelVmInstruction::WideUnaryOp { result, .. }
695                                | FuelVmInstruction::WideModularOp { result, .. }
696                                | FuelVmInstruction::StateLoadQuadWord {
697                                    load_val: result, ..
698                                },
699                            ),
700                        ..
701                    } => {
702                        kill_defined_symbol(
703                            context,
704                            *result,
705                            memory_utils::pointee_size(context, *result),
706                            &mut available_copies,
707                            &mut src_to_copies,
708                            &mut dest_to_copies,
709                        );
710                    }
711                    _ => (),
712                }
713            }
714
715            if replacements.is_empty() {
716                break;
717            } else {
718                modified = true;
719            }
720
721            // If we have any NewGep replacements, insert those new GEPs into the block.
722            // Since the new instructions need to be just before the value load that they're
723            // going to be used in, we copy all the instructions into a new vec
724            // and just replace the contents of the basic block.
725            let mut new_insts = vec![];
726            for inst in block.instruction_iter(context) {
727                if let Some(replacement) = replacements.remove(&inst) {
728                    let (to_replace, replacement) = match replacement {
729                        (to_replace, Replacement::OldGep(v)) => (to_replace, v),
730                        (
731                            to_replace,
732                            Replacement::NewGep(ReplGep {
733                                base,
734                                elem_ptr_ty,
735                                indices,
736                            }),
737                        ) => {
738                            let base = match base {
739                                Symbol::Local(local) => {
740                                    let base = Value::new_instruction(
741                                        context,
742                                        block,
743                                        InstOp::GetLocal(local),
744                                    );
745                                    new_insts.push(base);
746                                    base
747                                }
748                                Symbol::Arg(block_arg) => {
749                                    block_arg.block.get_arg(context, block_arg.idx).unwrap()
750                                }
751                            };
752                            let v = Value::new_instruction(
753                                context,
754                                block,
755                                InstOp::GetElemPtr {
756                                    base,
757                                    elem_ptr_ty,
758                                    indices,
759                                },
760                            );
761                            new_insts.push(v);
762                            (to_replace, v)
763                        }
764                    };
765                    match inst.get_instruction_mut(context) {
766                        Some(Instruction {
767                            op: InstOp::Load(ref mut src_val_ptr),
768                            ..
769                        })
770                        | Some(Instruction {
771                            op:
772                                InstOp::MemCopyBytes {
773                                    ref mut src_val_ptr,
774                                    ..
775                                },
776                            ..
777                        })
778                        | Some(Instruction {
779                            op:
780                                InstOp::MemCopyVal {
781                                    ref mut src_val_ptr,
782                                    ..
783                                },
784                            ..
785                        }) => {
786                            assert!(to_replace == *src_val_ptr);
787                            *src_val_ptr = replacement
788                        }
789                        Some(Instruction {
790                            op: InstOp::Call(_callee, args),
791                            ..
792                        }) => {
793                            for arg in args {
794                                if *arg == to_replace {
795                                    *arg = replacement;
796                                }
797                            }
798                        }
799                        _ => panic!("Unexpected instruction type"),
800                    }
801                }
802                new_insts.push(inst);
803            }
804
805            // Replace the basic block contents with what we just built.
806            block.take_body(context, new_insts);
807        }
808    }
809
810    Ok(modified)
811}
812
813struct Candidate {
814    load_val: Value,
815    store_val: Value,
816    dst_ptr: Value,
817    src_ptr: Value,
818}
819
820enum CandidateKind {
821    /// If aggregates are clobbered b/w a load and the store, we still need to,
822    /// for correctness (because asmgen cannot handle aggregate loads and stores)
823    /// do the memcpy. So we insert a memcpy to a temporary stack location right after
824    /// the load, and memcpy it to the store pointer at the point of store.
825    ClobberedNoncopyType(Candidate),
826    NonClobbered(Candidate),
827}
828
829/// Starting backwards from `end_inst`, till we reach `start_inst` or the entry block,
830/// is `scrutiny_ptr` (or an alias of it) stored to (i.e., clobbered)?
831/// Also checks that there is no overlap (common symbols) between
832/// `no_overlap_ptr` and `scrutiny_ptr`.
833fn is_clobbered(
834    context: &Context,
835    start_inst: &Value,
836    end_inst: &Value,
837    no_overlap_ptr: &Value,
838    scrutiny_ptr: &Value,
839) -> bool {
840    let end_block = end_inst.get_instruction(context).unwrap().parent;
841    let entry_block = end_block.get_function(context).get_entry_block(context);
842
843    let mut iter = end_block
844        .instruction_iter(context)
845        .rev()
846        .skip_while(|i| i != end_inst);
847    assert!(iter.next().unwrap() == *end_inst);
848
849    let ReferredSymbols::Complete(scrutiny_symbols) = get_referred_symbols(context, *scrutiny_ptr)
850    else {
851        return true;
852    };
853
854    let ReferredSymbols::Complete(no_overlap_symbols) =
855        get_referred_symbols(context, *no_overlap_ptr)
856    else {
857        return true;
858    };
859
860    // If the two pointers may have an overlap, we'll end up generating a mcp
861    // with overlapping source/destination which is not allowed.
862    if scrutiny_symbols
863        .intersection(&no_overlap_symbols)
864        .next()
865        .is_some()
866    {
867        return true;
868    }
869
870    // Scan backwards till we encounter start_val, checking if
871    // any store aliases with scrutiny_ptr.
872    let mut worklist: Vec<(Block, Box<dyn Iterator<Item = Value>>)> =
873        vec![(end_block, Box::new(iter))];
874    let mut visited = FxHashSet::default();
875    'next_job: while let Some((block, iter)) = worklist.pop() {
876        visited.insert(block);
877        for inst in iter {
878            if inst == *start_inst || inst == *end_inst {
879                // We don't need to go beyond either start_inst or end_inst.
880                continue 'next_job;
881            }
882            let stored_syms = get_stored_symbols(context, inst);
883            if let ReferredSymbols::Complete(syms) = stored_syms {
884                if syms.iter().any(|sym| scrutiny_symbols.contains(sym)) {
885                    return true;
886                }
887            } else {
888                return true;
889            }
890        }
891
892        if entry_block == block {
893            // We've reached the entry block. If any of the scrutiny_symbols
894            // is an argument, then we consider it clobbered.
895            if scrutiny_symbols
896                .iter()
897                .any(|sym| matches!(sym, Symbol::Arg(_)))
898            {
899                return true;
900            }
901        }
902
903        for pred in block.pred_iter(context) {
904            if !visited.contains(pred) {
905                worklist.push((
906                    *pred,
907                    Box::new(pred.instruction_iter(context).rev().skip_while(|_| false)),
908                ));
909            }
910        }
911    }
912
913    false
914}
915
916// This is a copy of sway_core::asm_generation::fuel::fuel_asm_builder::FuelAsmBuilder::is_copy_type.
917fn is_copy_type(ty: &Type, context: &Context) -> bool {
918    ty.is_unit(context)
919        || ty.is_never(context)
920        || ty.is_bool(context)
921        || ty.is_ptr(context)
922        || ty.get_uint_width(context).map(|x| x < 256).unwrap_or(false)
923}
924
925fn load_store_to_memcopy(context: &mut Context, function: Function) -> Result<bool, IrError> {
926    // Find any `store`s of `load`s.  These can be replaced with `mem_copy` and are especially
927    // important for non-copy types on architectures which don't support loading them.
928    let candidates = function
929        .instruction_iter(context)
930        .filter_map(|(_, store_instr_val)| {
931            store_instr_val
932                .get_instruction(context)
933                .and_then(|instr| {
934                    // Is the instruction a Store?
935                    if let Instruction {
936                        op:
937                            InstOp::Store {
938                                dst_val_ptr,
939                                stored_val,
940                            },
941                        ..
942                    } = instr
943                    {
944                        stored_val
945                            .get_instruction(context)
946                            .map(|src_instr| (*stored_val, src_instr, dst_val_ptr))
947                    } else {
948                        None
949                    }
950                })
951                .and_then(|(src_instr_val, src_instr, dst_val_ptr)| {
952                    // Is the Store source a Load?
953                    if let Instruction {
954                        op: InstOp::Load(src_val_ptr),
955                        ..
956                    } = src_instr
957                    {
958                        Some(Candidate {
959                            load_val: src_instr_val,
960                            store_val: store_instr_val,
961                            dst_ptr: *dst_val_ptr,
962                            src_ptr: *src_val_ptr,
963                        })
964                    } else {
965                        None
966                    }
967                })
968                .and_then(|candidate @ Candidate { dst_ptr, .. }| {
969                    // Check that there's no path from load_val to store_val that might overwrite src_ptr.
970                    if !is_clobbered(
971                        context,
972                        &candidate.load_val,
973                        &candidate.store_val,
974                        &candidate.dst_ptr,
975                        &candidate.src_ptr,
976                    ) {
977                        Some(CandidateKind::NonClobbered(candidate))
978                    } else if !is_copy_type(&dst_ptr.match_ptr_type(context).unwrap(), context) {
979                        Some(CandidateKind::ClobberedNoncopyType(candidate))
980                    } else {
981                        None
982                    }
983                })
984        })
985        .collect::<Vec<_>>();
986
987    if candidates.is_empty() {
988        return Ok(false);
989    }
990
991    for candidate in candidates {
992        match candidate {
993            CandidateKind::ClobberedNoncopyType(Candidate {
994                load_val,
995                store_val,
996                dst_ptr,
997                src_ptr,
998            }) => {
999                let load_block = load_val.get_instruction(context).unwrap().parent;
1000                let temp = function.new_unique_local_var(
1001                    context,
1002                    "__aggr_memcpy_0".into(),
1003                    src_ptr.match_ptr_type(context).unwrap(),
1004                    None,
1005                    true,
1006                );
1007                let temp_local =
1008                    Value::new_instruction(context, load_block, InstOp::GetLocal(temp));
1009                let to_temp = Value::new_instruction(
1010                    context,
1011                    load_block,
1012                    InstOp::MemCopyVal {
1013                        dst_val_ptr: temp_local,
1014                        src_val_ptr: src_ptr,
1015                    },
1016                );
1017                let mut inserter = InstructionInserter::new(
1018                    context,
1019                    load_block,
1020                    crate::InsertionPosition::After(load_val),
1021                );
1022                inserter.insert_slice(&[temp_local, to_temp]);
1023
1024                let store_block = store_val.get_instruction(context).unwrap().parent;
1025                let mem_copy_val = Value::new_instruction(
1026                    context,
1027                    store_block,
1028                    InstOp::MemCopyVal {
1029                        dst_val_ptr: dst_ptr,
1030                        src_val_ptr: temp_local,
1031                    },
1032                );
1033                store_block.replace_instruction(context, store_val, mem_copy_val, true)?;
1034            }
1035            CandidateKind::NonClobbered(Candidate {
1036                dst_ptr: dst_val_ptr,
1037                src_ptr: src_val_ptr,
1038                store_val,
1039                ..
1040            }) => {
1041                let store_block = store_val.get_instruction(context).unwrap().parent;
1042                let mem_copy_val = Value::new_instruction(
1043                    context,
1044                    store_block,
1045                    InstOp::MemCopyVal {
1046                        dst_val_ptr,
1047                        src_val_ptr,
1048                    },
1049                );
1050                store_block.replace_instruction(context, store_val, mem_copy_val, true)?;
1051            }
1052        }
1053    }
1054
1055    Ok(true)
1056}
1057
1058pub const MEMCPYPROP_REVERSE_NAME: &str = "memcpyprop_reverse";
1059
1060pub fn create_memcpyprop_reverse_pass() -> Pass {
1061    Pass {
1062        name: MEMCPYPROP_REVERSE_NAME,
1063        descr: "Copy propagation of MemCpy instructions",
1064        deps: vec![],
1065        runner: ScopedPass::FunctionPass(PassMutability::Transform(copy_prop_reverse)),
1066    }
1067}
1068
1069/// Copy propagation of `memcpy`s, replacing source with destination.
1070fn copy_prop_reverse(
1071    context: &mut Context,
1072    _analyses: &AnalysisResults,
1073    function: Function,
1074) -> Result<bool, IrError> {
1075    let mut modified = false;
1076
1077    // let's first compute the definitions and uses of every symbol.
1078    let mut stores_map: FxHashMap<Symbol, Vec<Value>> = FxHashMap::default();
1079    let mut loads_map: FxHashMap<Symbol, Vec<Value>> = FxHashMap::default();
1080    for (_block, instr_val) in function.instruction_iter(context) {
1081        let stored_syms = get_stored_symbols(context, instr_val);
1082        let stored_syms = match stored_syms {
1083            ReferredSymbols::Complete(syms) => syms,
1084            ReferredSymbols::Incomplete(_) => return Ok(false),
1085        };
1086        let loaded_syms = get_loaded_symbols(context, instr_val);
1087        let loaded_syms = match loaded_syms {
1088            ReferredSymbols::Complete(syms) => syms,
1089            ReferredSymbols::Incomplete(_) => return Ok(false),
1090        };
1091        for sym in stored_syms {
1092            stores_map.entry(sym).or_default().push(instr_val);
1093        }
1094        for sym in loaded_syms {
1095            loads_map.entry(sym).or_default().push(instr_val);
1096        }
1097    }
1098
1099    let mut candidates = vec![];
1100
1101    for (_block, inst) in function.instruction_iter(context) {
1102        let Some((dst_ptr, src_ptr, byte_len)) = deconstruct_memcpy(context, inst) else {
1103            continue;
1104        };
1105
1106        if dst_ptr.get_type(context) != src_ptr.get_type(context) {
1107            continue;
1108        }
1109
1110        // We can replace the source of this memcpy with the destination
1111        // if:
1112        // 1. All uses of the destination symbol are dominated by this memcpy.
1113        // 2. All uses of the source symbol are dominated by this memcpy.
1114
1115        let dst_sym = match get_referred_symbols(context, dst_ptr) {
1116            ReferredSymbols::Complete(syms) if syms.len() == 1 => syms.into_iter().next().unwrap(),
1117            _ => continue,
1118        };
1119        let src_sym = match get_referred_symbols(context, src_ptr) {
1120            ReferredSymbols::Complete(syms) if syms.len() == 1 => syms.into_iter().next().unwrap(),
1121            _ => continue,
1122        };
1123
1124        if dst_sym.get_type(context) != src_sym.get_type(context) {
1125            continue;
1126        }
1127
1128        // We don't deal with partial memcpys
1129        if dst_sym
1130            .get_type(context)
1131            .get_pointee_type(context)
1132            .expect("All symbols must be pointer types")
1133            .size(context)
1134            .in_bytes()
1135            != byte_len
1136        {
1137            continue;
1138        }
1139
1140        // For every use of the source symbol:
1141        //   starting from the use, walk backwards till we reach this memcpy,
1142        //   checking if there's a store to an alias of the destination symbol in that path.
1143        let source_uses_not_clobbered = loads_map
1144            .get(&src_sym)
1145            .map(|uses| {
1146                uses.iter().all(|use_val: &Value| {
1147                    *use_val == inst || !is_clobbered(context, &inst, use_val, &src_ptr, &dst_ptr)
1148                })
1149            })
1150            .unwrap_or(true);
1151
1152        // For every use of the destination symbol:
1153        //   starting from the use, walk backwards till we reach this memcpy,
1154        //   checking if there's a store to an alias of the source symbol in that path.
1155        let destination_uses_not_clobbered = loads_map
1156            .get(&dst_sym)
1157            .map(|uses| {
1158                uses.iter()
1159                    .all(|use_val| !is_clobbered(context, &inst, use_val, &dst_ptr, &src_ptr))
1160            })
1161            .unwrap_or(true);
1162
1163        if source_uses_not_clobbered && destination_uses_not_clobbered {
1164            candidates.push((inst, dst_sym, src_sym));
1165        }
1166    }
1167
1168    if candidates.is_empty() {
1169        return Ok(false);
1170    }
1171
1172    let mut to_delete: FxHashSet<Value> = FxHashSet::default();
1173    let mut src_to_dst: FxHashMap<Symbol, Symbol> = FxHashMap::default();
1174
1175    for (inst, dst_sym, src_sym) in candidates {
1176        match src_sym {
1177            Symbol::Arg(_) => {
1178                // Args are mostly copied to locals before actually being used.
1179                // So we don't handle them for now. Handling them would require
1180                // handling more instructions where they can be used, which probably
1181                // isn't worth it.
1182                continue;
1183            }
1184            Symbol::Local(local) => {
1185                if local.get_initializer(context).is_some() {
1186                    // If the source is a local and it has an initializer, we run into trouble
1187                    // 1. If the destination (after transitive closure below) is not a local,
1188                    //    we cannot initialize it with the source's initializer.
1189                    // 2. If the destination is a local, but it already has an initializer (by itself
1190                    //    or by another source in the chain), we cannot initialize it with this initializer.
1191                    continue;
1192                }
1193                match src_to_dst.entry(src_sym) {
1194                    std::collections::hash_map::Entry::Vacant(e) => {
1195                        e.insert(dst_sym);
1196                    }
1197                    std::collections::hash_map::Entry::Occupied(e) => {
1198                        if *e.get() != dst_sym {
1199                            // src_sym is copied to two different dst_syms. We cannot optimize this.
1200                            continue;
1201                        }
1202                    }
1203                }
1204                to_delete.insert(inst);
1205            }
1206        }
1207    }
1208
1209    // Take a transitive closure of src_to_dst.
1210    {
1211        let mut changed = true;
1212        let mut cycle_detected = false;
1213        while changed {
1214            changed = false;
1215            src_to_dst.clone().iter().for_each(|(src, dst)| {
1216                if let Some(next_dst) = src_to_dst.get(dst) {
1217                    // Cycle detection
1218                    if *next_dst == *src {
1219                        cycle_detected = true;
1220                        return;
1221                    }
1222                    src_to_dst.insert(*src, *next_dst);
1223                    changed = true;
1224                }
1225            });
1226        }
1227        if cycle_detected {
1228            // We cannot optimize in presence of cycles.
1229            return Ok(modified);
1230        }
1231    }
1232
1233    // Gather the get_local instructions that need to be replaced.
1234    let mut repl_locals = vec![];
1235    for (_block, inst) in function.instruction_iter(context) {
1236        match inst.get_instruction(context).unwrap() {
1237            Instruction {
1238                op: InstOp::GetLocal(sym),
1239                ..
1240            } => {
1241                if let Some(dst) = src_to_dst.get(&Symbol::Local(*sym)) {
1242                    repl_locals.push((inst, *dst));
1243                }
1244            }
1245            _ => {
1246                // Any access to a local begins with a GetLocal, we can ignore the rest
1247                // (unless we support Symbol::Arg above).
1248            }
1249        }
1250    }
1251
1252    if repl_locals.is_empty() {
1253        return Ok(modified);
1254    }
1255    modified = true;
1256
1257    let mut value_replacements = FxHashMap::default();
1258    for (to_repl, repl_with) in repl_locals {
1259        let Instruction {
1260            op: InstOp::GetLocal(sym),
1261            ..
1262        } = to_repl.get_instruction_mut(context).unwrap()
1263        else {
1264            panic!("Expected GetLocal instruction");
1265        };
1266        match repl_with {
1267            Symbol::Local(dst_local) => {
1268                // We just modify this GetLocal in-place.
1269                *sym = dst_local;
1270            }
1271            Symbol::Arg(arg) => {
1272                // The get_local needs to be replaced with the right argument Value.
1273                value_replacements.insert(to_repl, arg.as_value(context));
1274            }
1275        }
1276    }
1277
1278    // Replace get_locals with the right values.
1279    function.replace_values(context, &value_replacements, None);
1280
1281    // In instances such as
1282    //        (1) b <- a
1283    //        /         \
1284    // (2) x <- b    (3):  x <- a
1285    // when we decide to eliminate (1) and (2), i.e., both `b` and `a` end up
1286    // being replaced by `x`, (3) will end up becoming `x <- x`. We need to
1287    // clean these up.
1288    for (_, inst) in function.instruction_iter(context) {
1289        let Some((dst_ptr, src_ptr, _byte_len)) = deconstruct_memcpy(context, inst) else {
1290            continue;
1291        };
1292
1293        let dst_sym = match get_referred_symbols(context, dst_ptr) {
1294            ReferredSymbols::Complete(syms) if syms.len() == 1 => syms.into_iter().next().unwrap(),
1295            _ => continue,
1296        };
1297        let src_sym = match get_referred_symbols(context, src_ptr) {
1298            ReferredSymbols::Complete(syms) if syms.len() == 1 => syms.into_iter().next().unwrap(),
1299            _ => continue,
1300        };
1301
1302        if dst_sym == src_sym {
1303            to_delete.insert(inst);
1304        }
1305    }
1306
1307    function.remove_instructions(context, |v| to_delete.contains(&v));
1308
1309    Ok(modified)
1310}