sway_ir/optimize/
memcpyopt.rs

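//! Optimisations related to mem_copy.
//! - propagate copies between local symbols before any `mem_copy` is introduced.
//! - replace a `store` directly from a `load` with a `mem_copy_val`.
//! - propagate copies made by `mem_copy`s within a block.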
3
4use indexmap::IndexMap;
5use itertools::{Either, Itertools};
6use rustc_hash::{FxHashMap, FxHashSet};
7use sway_types::{FxIndexMap, FxIndexSet};
8
9use crate::{
10    get_gep_symbol, get_referred_symbol, get_referred_symbols, get_stored_symbols, memory_utils,
11    AnalysisResults, Block, Context, EscapedSymbols, FuelVmInstruction, Function, InstOp,
12    Instruction, InstructionInserter, IrError, LocalVar, Pass, PassMutability, ReferredSymbols,
13    ScopedPass, Symbol, Type, Value, ValueDatum, ESCAPED_SYMBOLS_NAME,
14};
15
16pub const MEMCPYOPT_NAME: &str = "memcpyopt";
17
18pub fn create_memcpyopt_pass() -> Pass {
19    Pass {
20        name: MEMCPYOPT_NAME,
21        descr: "Optimizations related to MemCopy instructions",
22        deps: vec![ESCAPED_SYMBOLS_NAME],
23        runner: ScopedPass::FunctionPass(PassMutability::Transform(mem_copy_opt)),
24    }
25}
26
27pub fn mem_copy_opt(
28    context: &mut Context,
29    analyses: &AnalysisResults,
30    function: Function,
31) -> Result<bool, IrError> {
32    let mut modified = false;
33    modified |= local_copy_prop_prememcpy(context, analyses, function)?;
34    modified |= load_store_to_memcopy(context, function)?;
35    modified |= local_copy_prop(context, analyses, function)?;
36
37    Ok(modified)
38}
39
40fn local_copy_prop_prememcpy(
41    context: &mut Context,
42    analyses: &AnalysisResults,
43    function: Function,
44) -> Result<bool, IrError> {
45    struct InstInfo {
46        // The block containing the instruction.
47        block: Block,
48        // Relative (use only for comparison) position of instruction in `block`.
49        pos: usize,
50    }
51
52    // If the analysis result is incomplete we cannot do any safe optimizations here.
53    // Calculating the candidates below relies on complete result of an escape analysis.
54    let escaped_symbols = match analyses.get_analysis_result(function) {
55        EscapedSymbols::Complete(syms) => syms,
56        EscapedSymbols::Incomplete(_) => return Ok(false),
57    };
58
59    // All instructions that load from the `Symbol`.
60    let mut loads_map = FxHashMap::<Symbol, Vec<Value>>::default();
61    // All instructions that store to the `Symbol`.
62    let mut stores_map = FxHashMap::<Symbol, Vec<Value>>::default();
63    // All load and store instructions.
64    let mut instr_info_map = FxHashMap::<Value, InstInfo>::default();
65
66    for (pos, (block, inst)) in function.instruction_iter(context).enumerate() {
67        let info = || InstInfo { block, pos };
68        let inst_e = inst.get_instruction(context).unwrap();
69        match inst_e {
70            Instruction {
71                op: InstOp::Load(src_val_ptr),
72                ..
73            } => {
74                if let Some(local) = get_referred_symbol(context, *src_val_ptr) {
75                    loads_map
76                        .entry(local)
77                        .and_modify(|loads| loads.push(inst))
78                        .or_insert(vec![inst]);
79                    instr_info_map.insert(inst, info());
80                }
81            }
82            Instruction {
83                op: InstOp::Store { dst_val_ptr, .. },
84                ..
85            } => {
86                if let Some(local) = get_referred_symbol(context, *dst_val_ptr) {
87                    stores_map
88                        .entry(local)
89                        .and_modify(|stores| stores.push(inst))
90                        .or_insert(vec![inst]);
91                    instr_info_map.insert(inst, info());
92                }
93            }
94            _ => (),
95        }
96    }
97
98    let mut to_delete = FxHashSet::<Value>::default();
99    // Candidates for replacements. The map's key `Symbol` is the
100    // destination `Symbol` that can be replaced with the
101    // map's value `Symbol`, the source.
102    // Replacement is possible (among other criteria explained below)
103    // only if the Store of the source is the only storing to the destination.
104    let candidates: FxHashMap<Symbol, Symbol> = function
105        .instruction_iter(context)
106        .enumerate()
107        .filter_map(|(pos, (block, instr_val))| {
108            // 1. Go through all the Store instructions whose source is
109            // a Load instruction...
110            instr_val
111                .get_instruction(context)
112                .and_then(|instr| {
113                    // Is the instruction a Store?
114                    if let Instruction {
115                        op:
116                            InstOp::Store {
117                                dst_val_ptr,
118                                stored_val,
119                            },
120                        ..
121                    } = instr
122                    {
123                        get_gep_symbol(context, *dst_val_ptr).and_then(|dst_local| {
124                            stored_val
125                                .get_instruction(context)
126                                .map(|src_instr| (src_instr, stored_val, dst_local))
127                        })
128                    } else {
129                        None
130                    }
131                })
132                .and_then(|(src_instr, stored_val, dst_local)| {
133                    // Is the Store source a Load?
134                    if let Instruction {
135                        op: InstOp::Load(src_val_ptr),
136                        ..
137                    } = src_instr
138                    {
139                        get_gep_symbol(context, *src_val_ptr)
140                            .map(|src_local| (stored_val, dst_local, src_local))
141                    } else {
142                        None
143                    }
144                })
145                .and_then(|(src_load, dst_local, src_local)| {
146                    // 2. ... and pick the (dest_local, src_local) pairs that fulfill the
147                    //    below criteria, in other words, where `dest_local` can be
148                    //    replaced with `src_local`.
149                    let (temp_empty1, temp_empty2, temp_empty3) = (vec![], vec![], vec![]);
150                    let dst_local_stores = stores_map.get(&dst_local).unwrap_or(&temp_empty1);
151                    let src_local_stores = stores_map.get(&src_local).unwrap_or(&temp_empty2);
152                    let dst_local_loads = loads_map.get(&dst_local).unwrap_or(&temp_empty3);
153                    // This must be the only store of dst_local.
154                    if dst_local_stores.len() != 1 || dst_local_stores[0] != instr_val
155                        ||
156                        // All stores of src_local must be in the same block, prior to src_load.
157                        !src_local_stores.iter().all(|store_val|{
158                            let instr_info = instr_info_map.get(store_val).unwrap();
159                            let src_load_info = instr_info_map.get(src_load).unwrap();
160                            instr_info.block == block && instr_info.pos < src_load_info.pos
161                        })
162                        ||
163                        // All loads of dst_local must be after this instruction, in the same block.
164                        !dst_local_loads.iter().all(|load_val| {
165                            let instr_info = instr_info_map.get(load_val).unwrap();
166                            instr_info.block == block && instr_info.pos > pos
167                        })
168                        // We don't deal with symbols that escape.
169                        || escaped_symbols.contains(&dst_local)
170                        || escaped_symbols.contains(&src_local)
171                        // We don't deal part copies.
172                        || dst_local.get_type(context) != src_local.get_type(context)
173                        // We don't replace the destination when it's an arg.
174                        || matches!(dst_local, Symbol::Arg(_))
175                    {
176                        None
177                    } else {
178                        to_delete.insert(instr_val);
179                        Some((dst_local, src_local))
180                    }
181                })
182        })
183        .collect();
184
185    // If we have A replaces B and B replaces C, then A must replace C also.
186    // Recursively searches for the final replacement for the `local`.
187    // Returns `None` if the `local` cannot be replaced.
188    fn get_replace_with(candidates: &FxHashMap<Symbol, Symbol>, local: &Symbol) -> Option<Symbol> {
189        candidates
190            .get(local)
191            .map(|replace_with| get_replace_with(candidates, replace_with).unwrap_or(*replace_with))
192    }
193
194    // If the source is an Arg, we replace uses of destination with Arg.
195    // Otherwise (`get_local`), we replace the local symbol in-place.
196    enum ReplaceWith {
197        InPlaceLocal(LocalVar),
198        Value(Value),
199    }
200
201    // Because we can't borrow context for both iterating and replacing, do it in 2 steps.
202    // `replaces` are the original GetLocal instructions with the corresponding replacements
203    // of their arguments.
204    let replaces: Vec<_> = function
205        .instruction_iter(context)
206        .filter_map(|(_block, value)| match value.get_instruction(context) {
207            Some(Instruction {
208                op: InstOp::GetLocal(local),
209                ..
210            }) => get_replace_with(&candidates, &Symbol::Local(*local)).map(|replace_with| {
211                (
212                    value,
213                    match replace_with {
214                        Symbol::Local(local) => ReplaceWith::InPlaceLocal(local),
215                        Symbol::Arg(ba) => {
216                            ReplaceWith::Value(ba.block.get_arg(context, ba.idx).unwrap())
217                        }
218                    },
219                )
220            }),
221            _ => None,
222        })
223        .collect();
224
225    let mut value_replace = FxHashMap::<Value, Value>::default();
226    for (value, replace_with) in replaces.into_iter() {
227        match replace_with {
228            ReplaceWith::InPlaceLocal(replacement_var) => {
229                let Some(&Instruction {
230                    op: InstOp::GetLocal(redundant_var),
231                    parent,
232                }) = value.get_instruction(context)
233                else {
234                    panic!("earlier match now fails");
235                };
236                if redundant_var.is_mutable(context) {
237                    replacement_var.set_mutable(context, true);
238                }
239                value.replace(
240                    context,
241                    ValueDatum::Instruction(Instruction {
242                        op: InstOp::GetLocal(replacement_var),
243                        parent,
244                    }),
245                )
246            }
247            ReplaceWith::Value(replace_with) => {
248                value_replace.insert(value, replace_with);
249            }
250        }
251    }
252    function.replace_values(context, &value_replace, None);
253
254    // Delete stores to the replaced local.
255    let blocks: Vec<Block> = function.block_iter(context).collect();
256    for block in blocks {
257        block.remove_instructions(context, |value| to_delete.contains(&value));
258    }
259    Ok(true)
260}
261
262/// Copy propagation of `memcpy`s within a block.
263fn local_copy_prop(
264    context: &mut Context,
265    analyses: &AnalysisResults,
266    function: Function,
267) -> Result<bool, IrError> {
268    // If the analysis result is incomplete we cannot do any safe optimizations here.
269    // The `gen_new_copy` and `process_load` functions below rely on the fact that the
270    // analyzed symbols do not escape, something we cannot guarantee in case of
271    // an incomplete collection of escaped symbols.
272    let escaped_symbols = match analyses.get_analysis_result(function) {
273        EscapedSymbols::Complete(syms) => syms,
274        EscapedSymbols::Incomplete(_) => return Ok(false),
275    };
276
277    // Currently (as we scan a block) available `memcpy`s.
278    let mut available_copies: FxHashSet<Value>;
279    // Map a symbol to the available `memcpy`s of which it's a source.
280    let mut src_to_copies: FxIndexMap<Symbol, FxIndexSet<Value>>;
281    // Map a symbol to the available `memcpy`s of which it's a destination.
282    // (multiple `memcpy`s for the same destination may be available when
283    // they are partial / field writes, and don't alias).
284    let mut dest_to_copies: FxIndexMap<Symbol, FxIndexSet<Value>>;
285
286    // If a value (symbol) is found to be defined, remove it from our tracking.
287    fn kill_defined_symbol(
288        context: &Context,
289        value: Value,
290        len: u64,
291        available_copies: &mut FxHashSet<Value>,
292        src_to_copies: &mut FxIndexMap<Symbol, FxIndexSet<Value>>,
293        dest_to_copies: &mut FxIndexMap<Symbol, FxIndexSet<Value>>,
294    ) {
295        match get_referred_symbols(context, value) {
296            ReferredSymbols::Complete(rs) => {
297                for sym in rs {
298                    if let Some(copies) = src_to_copies.get_mut(&sym) {
299                        for copy in &*copies {
300                            let (_, src_ptr, copy_size) = deconstruct_memcpy(context, *copy);
301                            if memory_utils::may_alias(context, value, len, src_ptr, copy_size) {
302                                available_copies.remove(copy);
303                            }
304                        }
305                    }
306                    if let Some(copies) = dest_to_copies.get_mut(&sym) {
307                        for copy in &*copies {
308                            let (dest_ptr, copy_size) = match copy.get_instruction(context).unwrap()
309                            {
310                                Instruction {
311                                    op:
312                                        InstOp::MemCopyBytes {
313                                            dst_val_ptr,
314                                            src_val_ptr: _,
315                                            byte_len,
316                                        },
317                                    ..
318                                } => (*dst_val_ptr, *byte_len),
319                                Instruction {
320                                    op:
321                                        InstOp::MemCopyVal {
322                                            dst_val_ptr,
323                                            src_val_ptr: _,
324                                        },
325                                    ..
326                                } => (
327                                    *dst_val_ptr,
328                                    memory_utils::pointee_size(context, *dst_val_ptr),
329                                ),
330                                _ => panic!("Unexpected copy instruction"),
331                            };
332                            if memory_utils::may_alias(context, value, len, dest_ptr, copy_size) {
333                                available_copies.remove(copy);
334                            }
335                        }
336                    }
337                }
338                // Update src_to_copies and dest_to_copies to remove every copy not in available_copies.
339                src_to_copies.retain(|_, copies| {
340                    copies.retain(|copy| available_copies.contains(copy));
341                    !copies.is_empty()
342                });
343                dest_to_copies.retain(|_, copies| {
344                    copies.retain(|copy| available_copies.contains(copy));
345                    !copies.is_empty()
346                });
347            }
348            ReferredSymbols::Incomplete(_) => {
349                // The only safe thing we can do is to clear all information.
350                available_copies.clear();
351                src_to_copies.clear();
352                dest_to_copies.clear();
353            }
354        }
355    }
356
357    #[allow(clippy::too_many_arguments)]
358    fn gen_new_copy(
359        context: &Context,
360        escaped_symbols: &FxHashSet<Symbol>,
361        copy_inst: Value,
362        dst_val_ptr: Value,
363        src_val_ptr: Value,
364        available_copies: &mut FxHashSet<Value>,
365        src_to_copies: &mut FxIndexMap<Symbol, FxIndexSet<Value>>,
366        dest_to_copies: &mut FxIndexMap<Symbol, FxIndexSet<Value>>,
367    ) {
368        if let (Some(dst_sym), Some(src_sym)) = (
369            get_gep_symbol(context, dst_val_ptr),
370            get_gep_symbol(context, src_val_ptr),
371        ) {
372            if escaped_symbols.contains(&dst_sym) || escaped_symbols.contains(&src_sym) {
373                return;
374            }
375            dest_to_copies
376                .entry(dst_sym)
377                .and_modify(|set| {
378                    set.insert(copy_inst);
379                })
380                .or_insert([copy_inst].into_iter().collect());
381            src_to_copies
382                .entry(src_sym)
383                .and_modify(|set| {
384                    set.insert(copy_inst);
385                })
386                .or_insert([copy_inst].into_iter().collect());
387            available_copies.insert(copy_inst);
388        }
389    }
390
391    // Deconstruct a memcpy into (dst_val_ptr, src_val_ptr, copy_len).
392    fn deconstruct_memcpy(context: &Context, inst: Value) -> (Value, Value, u64) {
393        match inst.get_instruction(context).unwrap() {
394            Instruction {
395                op:
396                    InstOp::MemCopyBytes {
397                        dst_val_ptr,
398                        src_val_ptr,
399                        byte_len,
400                    },
401                ..
402            } => (*dst_val_ptr, *src_val_ptr, *byte_len),
403            Instruction {
404                op:
405                    InstOp::MemCopyVal {
406                        dst_val_ptr,
407                        src_val_ptr,
408                    },
409                ..
410            } => (
411                *dst_val_ptr,
412                *src_val_ptr,
413                memory_utils::pointee_size(context, *dst_val_ptr),
414            ),
415            _ => unreachable!("Only memcpy instructions handled"),
416        }
417    }
418
419    struct ReplGep {
420        base: Symbol,
421        elem_ptr_ty: Type,
422        indices: Vec<Value>,
423    }
424    enum Replacement {
425        OldGep(Value),
426        NewGep(ReplGep),
427    }
428
429    fn process_load(
430        context: &Context,
431        escaped_symbols: &FxHashSet<Symbol>,
432        inst: Value,
433        src_val_ptr: Value,
434        dest_to_copies: &FxIndexMap<Symbol, FxIndexSet<Value>>,
435        replacements: &mut FxHashMap<Value, (Value, Replacement)>,
436    ) -> bool {
437        // For every `memcpy` that src_val_ptr is a destination of,
438        // check if we can do the load from the source of that memcpy.
439        if let Some(src_sym) = get_referred_symbol(context, src_val_ptr) {
440            if escaped_symbols.contains(&src_sym) {
441                return false;
442            }
443            for memcpy in dest_to_copies
444                .get(&src_sym)
445                .iter()
446                .flat_map(|set| set.iter())
447            {
448                let (dst_ptr_memcpy, src_ptr_memcpy, copy_len) =
449                    deconstruct_memcpy(context, *memcpy);
450                // If the location where we're loading from exactly matches the destination of
451                // the memcpy, just load from the source pointer of the memcpy.
452                // TODO: In both the arms below, we check that the pointer type
453                // matches. This isn't really needed as the copy happens and the
454                // data we want is safe to access. But we just don't know how to
455                // generate the right GEP always. So that's left for another day.
456                if memory_utils::must_alias(
457                    context,
458                    src_val_ptr,
459                    memory_utils::pointee_size(context, src_val_ptr),
460                    dst_ptr_memcpy,
461                    copy_len,
462                ) {
463                    // Replace src_val_ptr with src_ptr_memcpy.
464                    if src_val_ptr.get_type(context) == src_ptr_memcpy.get_type(context) {
465                        replacements
466                            .insert(inst, (src_val_ptr, Replacement::OldGep(src_ptr_memcpy)));
467                        return true;
468                    }
469                } else {
470                    // if the memcpy copies the entire symbol, we could
471                    // insert a new GEP from the source of the memcpy.
472                    if let (Some(memcpy_src_sym), Some(memcpy_dst_sym), Some(new_indices)) = (
473                        get_gep_symbol(context, src_ptr_memcpy),
474                        get_gep_symbol(context, dst_ptr_memcpy),
475                        memory_utils::combine_indices(context, src_val_ptr),
476                    ) {
477                        let memcpy_src_sym_type = memcpy_src_sym
478                            .get_type(context)
479                            .get_pointee_type(context)
480                            .unwrap();
481                        let memcpy_dst_sym_type = memcpy_dst_sym
482                            .get_type(context)
483                            .get_pointee_type(context)
484                            .unwrap();
485                        if memcpy_src_sym_type == memcpy_dst_sym_type
486                            && memcpy_dst_sym_type.size(context).in_bytes() == copy_len
487                        {
488                            replacements.insert(
489                                inst,
490                                (
491                                    src_val_ptr,
492                                    Replacement::NewGep(ReplGep {
493                                        base: memcpy_src_sym,
494                                        elem_ptr_ty: src_val_ptr.get_type(context).unwrap(),
495                                        indices: new_indices,
496                                    }),
497                                ),
498                            );
499                            return true;
500                        }
501                    }
502                }
503            }
504        }
505
506        false
507    }
508
509    let mut modified = false;
510    for block in function.block_iter(context) {
511        // A `memcpy` itself has a `load`, so we can `process_load` on it.
512        // If now, we've marked the source of this `memcpy` for optimization,
513        // it itself cannot be "generated" as a new candidate `memcpy`.
514        // This is the reason we run a loop on the block till there's no more
515        // optimization possible. We could track just the changes and do it
516        // all in one go, but that would complicate the algorithm. So I've
517        // marked this as a TODO for now (#4600).
518        loop {
519            available_copies = FxHashSet::default();
520            src_to_copies = IndexMap::default();
521            dest_to_copies = IndexMap::default();
522
523            // Replace the load/memcpy source pointer with something else.
524            let mut replacements = FxHashMap::default();
525
526            fn kill_escape_args(
527                context: &Context,
528                args: &Vec<Value>,
529                available_copies: &mut FxHashSet<Value>,
530                src_to_copies: &mut FxIndexMap<Symbol, FxIndexSet<Value>>,
531                dest_to_copies: &mut FxIndexMap<Symbol, FxIndexSet<Value>>,
532            ) {
533                for arg in args {
534                    match get_referred_symbols(context, *arg) {
535                        ReferredSymbols::Complete(rs) => {
536                            let max_size = rs
537                                .iter()
538                                .filter_map(|sym| {
539                                    sym.get_type(context)
540                                        .get_pointee_type(context)
541                                        .map(|pt| pt.size(context).in_bytes())
542                                })
543                                .max()
544                                .unwrap_or(0);
545                            kill_defined_symbol(
546                                context,
547                                *arg,
548                                max_size,
549                                available_copies,
550                                src_to_copies,
551                                dest_to_copies,
552                            );
553                        }
554                        ReferredSymbols::Incomplete(_) => {
555                            // The only safe thing we can do is to clear all information.
556                            available_copies.clear();
557                            src_to_copies.clear();
558                            dest_to_copies.clear();
559
560                            break;
561                        }
562                    }
563                }
564            }
565
566            for inst in block.instruction_iter(context) {
567                match inst.get_instruction(context).unwrap() {
568                    Instruction {
569                        op: InstOp::Call(callee, args),
570                        ..
571                    } => {
572                        let (immutable_args, mutable_args): (Vec<_>, Vec<_>) =
573                            args.iter().enumerate().partition_map(|(arg_idx, arg)| {
574                                if callee.is_arg_immutable(context, arg_idx) {
575                                    Either::Left(*arg)
576                                } else {
577                                    Either::Right(*arg)
578                                }
579                            });
580                        // whichever args may get mutated, we kill them.
581                        kill_escape_args(
582                            context,
583                            &mutable_args,
584                            &mut available_copies,
585                            &mut src_to_copies,
586                            &mut dest_to_copies,
587                        );
588                        // args that aren't mutated can be treated as a "load" (for the purposes
589                        // of optimization).
590                        for arg in immutable_args {
591                            process_load(
592                                context,
593                                escaped_symbols,
594                                inst,
595                                arg,
596                                &dest_to_copies,
597                                &mut replacements,
598                            );
599                        }
600                    }
601                    Instruction {
602                        op: InstOp::AsmBlock(_, args),
603                        ..
604                    } => {
605                        let args = args.iter().filter_map(|arg| arg.initializer).collect();
606                        kill_escape_args(
607                            context,
608                            &args,
609                            &mut available_copies,
610                            &mut src_to_copies,
611                            &mut dest_to_copies,
612                        );
613                    }
614                    Instruction {
615                        op: InstOp::IntToPtr(_, _),
616                        ..
617                    } => {
618                        // The only safe thing we can do is to clear all information.
619                        available_copies.clear();
620                        src_to_copies.clear();
621                        dest_to_copies.clear();
622                    }
623                    Instruction {
624                        op: InstOp::Load(src_val_ptr),
625                        ..
626                    } => {
627                        process_load(
628                            context,
629                            escaped_symbols,
630                            inst,
631                            *src_val_ptr,
632                            &dest_to_copies,
633                            &mut replacements,
634                        );
635                    }
636                    Instruction {
637                        op: InstOp::MemCopyBytes { .. } | InstOp::MemCopyVal { .. },
638                        ..
639                    } => {
640                        let (dst_val_ptr, src_val_ptr, copy_len) =
641                            deconstruct_memcpy(context, inst);
642                        kill_defined_symbol(
643                            context,
644                            dst_val_ptr,
645                            copy_len,
646                            &mut available_copies,
647                            &mut src_to_copies,
648                            &mut dest_to_copies,
649                        );
650                        // If this memcpy itself can be optimized, we do just that, and not "gen" a new one.
651                        if !process_load(
652                            context,
653                            escaped_symbols,
654                            inst,
655                            src_val_ptr,
656                            &dest_to_copies,
657                            &mut replacements,
658                        ) {
659                            gen_new_copy(
660                                context,
661                                escaped_symbols,
662                                inst,
663                                dst_val_ptr,
664                                src_val_ptr,
665                                &mut available_copies,
666                                &mut src_to_copies,
667                                &mut dest_to_copies,
668                            );
669                        }
670                    }
671                    Instruction {
672                        op:
673                            InstOp::Store {
674                                dst_val_ptr,
675                                stored_val: _,
676                            },
677                        ..
678                    } => {
679                        kill_defined_symbol(
680                            context,
681                            *dst_val_ptr,
682                            memory_utils::pointee_size(context, *dst_val_ptr),
683                            &mut available_copies,
684                            &mut src_to_copies,
685                            &mut dest_to_copies,
686                        );
687                    }
688                    Instruction {
689                        op:
690                            InstOp::FuelVm(
691                                FuelVmInstruction::WideBinaryOp { result, .. }
692                                | FuelVmInstruction::WideUnaryOp { result, .. }
693                                | FuelVmInstruction::WideModularOp { result, .. }
694                                | FuelVmInstruction::StateLoadQuadWord {
695                                    load_val: result, ..
696                                },
697                            ),
698                        ..
699                    } => {
700                        kill_defined_symbol(
701                            context,
702                            *result,
703                            memory_utils::pointee_size(context, *result),
704                            &mut available_copies,
705                            &mut src_to_copies,
706                            &mut dest_to_copies,
707                        );
708                    }
709                    _ => (),
710                }
711            }
712
713            if replacements.is_empty() {
714                break;
715            } else {
716                modified = true;
717            }
718
719            // If we have any NewGep replacements, insert those new GEPs into the block.
720            // Since the new instructions need to be just before the value load that they're
721            // going to be used in, we copy all the instructions into a new vec
722            // and just replace the contents of the basic block.
723            let mut new_insts = vec![];
724            for inst in block.instruction_iter(context) {
725                if let Some(replacement) = replacements.remove(&inst) {
726                    let (to_replace, replacement) = match replacement {
727                        (to_replace, Replacement::OldGep(v)) => (to_replace, v),
728                        (
729                            to_replace,
730                            Replacement::NewGep(ReplGep {
731                                base,
732                                elem_ptr_ty,
733                                indices,
734                            }),
735                        ) => {
736                            let base = match base {
737                                Symbol::Local(local) => {
738                                    let base = Value::new_instruction(
739                                        context,
740                                        block,
741                                        InstOp::GetLocal(local),
742                                    );
743                                    new_insts.push(base);
744                                    base
745                                }
746                                Symbol::Arg(block_arg) => {
747                                    block_arg.block.get_arg(context, block_arg.idx).unwrap()
748                                }
749                            };
750                            let v = Value::new_instruction(
751                                context,
752                                block,
753                                InstOp::GetElemPtr {
754                                    base,
755                                    elem_ptr_ty,
756                                    indices,
757                                },
758                            );
759                            new_insts.push(v);
760                            (to_replace, v)
761                        }
762                    };
763                    match inst.get_instruction_mut(context) {
764                        Some(Instruction {
765                            op: InstOp::Load(ref mut src_val_ptr),
766                            ..
767                        })
768                        | Some(Instruction {
769                            op:
770                                InstOp::MemCopyBytes {
771                                    ref mut src_val_ptr,
772                                    ..
773                                },
774                            ..
775                        })
776                        | Some(Instruction {
777                            op:
778                                InstOp::MemCopyVal {
779                                    ref mut src_val_ptr,
780                                    ..
781                                },
782                            ..
783                        }) => {
784                            assert!(to_replace == *src_val_ptr);
785                            *src_val_ptr = replacement
786                        }
787                        Some(Instruction {
788                            op: InstOp::Call(_callee, args),
789                            ..
790                        }) => {
791                            for arg in args {
792                                if *arg == to_replace {
793                                    *arg = replacement;
794                                }
795                            }
796                        }
797                        _ => panic!("Unexpected instruction type"),
798                    }
799                }
800                new_insts.push(inst);
801            }
802
803            // Replace the basic block contents with what we just built.
804            block.take_body(context, new_insts);
805        }
806    }
807
808    Ok(modified)
809}
810
811struct Candidate {
812    load_val: Value,
813    store_val: Value,
814    dst_ptr: Value,
815    src_ptr: Value,
816}
817
818enum CandidateKind {
819    /// If aggregates are clobbered b/w a load and the store, we still need to,
820    /// for correctness (because asmgen cannot handle aggregate loads and stores)
821    /// do the memcpy. So we insert a memcpy to a temporary stack location right after
822    /// the load, and memcpy it to the store pointer at the point of store.
823    ClobberedNoncopyType(Candidate),
824    NonClobbered(Candidate),
825}
826
827// Is (an alias of) src_ptr clobbered on any path from load_val to store_val?
828fn is_clobbered(
829    context: &Context,
830    Candidate {
831        load_val,
832        store_val,
833        dst_ptr,
834        src_ptr,
835    }: &Candidate,
836) -> bool {
837    let store_block = store_val.get_instruction(context).unwrap().parent;
838
839    let mut iter = store_block
840        .instruction_iter(context)
841        .rev()
842        .skip_while(|i| i != store_val);
843    assert!(iter.next().unwrap() == *store_val);
844
845    let ReferredSymbols::Complete(src_symbols) = get_referred_symbols(context, *src_ptr) else {
846        return true;
847    };
848
849    let ReferredSymbols::Complete(dst_symbols) = get_referred_symbols(context, *dst_ptr) else {
850        return true;
851    };
852
853    // If the source and destination may have an overlap, we'll end up generating a mcp
854    // with overlapping source/destination which is not allowed.
855    if src_symbols.intersection(&dst_symbols).next().is_some() {
856        return true;
857    }
858
859    // Scan backwards till we encounter load_val, checking if
860    // any store aliases with src_ptr.
861    let mut worklist: Vec<(Block, Box<dyn Iterator<Item = Value>>)> =
862        vec![(store_block, Box::new(iter))];
863    let mut visited = FxHashSet::default();
864    'next_job: while let Some((block, iter)) = worklist.pop() {
865        visited.insert(block);
866        for inst in iter {
867            if inst == *load_val || inst == *store_val {
868                // We don't need to go beyond either the source load or the candidate store.
869                continue 'next_job;
870            }
871            let stored_syms = get_stored_symbols(context, inst);
872            if let ReferredSymbols::Complete(syms) = stored_syms {
873                if syms.iter().any(|sym| src_symbols.contains(sym)) {
874                    return true;
875                }
876            } else {
877                return true;
878            }
879        }
880        for pred in block.pred_iter(context) {
881            if !visited.contains(pred) {
882                worklist.push((
883                    *pred,
884                    Box::new(pred.instruction_iter(context).rev().skip_while(|_| false)),
885                ));
886            }
887        }
888    }
889
890    false
891}
892
893// This is a copy of sway_core::asm_generation::fuel::fuel_asm_builder::FuelAsmBuilder::is_copy_type.
894fn is_copy_type(ty: &Type, context: &Context) -> bool {
895    ty.is_unit(context)
896        || ty.is_never(context)
897        || ty.is_bool(context)
898        || ty.is_ptr(context)
899        || ty.get_uint_width(context).map(|x| x < 256).unwrap_or(false)
900}
901
902fn load_store_to_memcopy(context: &mut Context, function: Function) -> Result<bool, IrError> {
903    // Find any `store`s of `load`s.  These can be replaced with `mem_copy` and are especially
904    // important for non-copy types on architectures which don't support loading them.
905    let candidates = function
906        .instruction_iter(context)
907        .filter_map(|(_, store_instr_val)| {
908            store_instr_val
909                .get_instruction(context)
910                .and_then(|instr| {
911                    // Is the instruction a Store?
912                    if let Instruction {
913                        op:
914                            InstOp::Store {
915                                dst_val_ptr,
916                                stored_val,
917                            },
918                        ..
919                    } = instr
920                    {
921                        stored_val
922                            .get_instruction(context)
923                            .map(|src_instr| (*stored_val, src_instr, dst_val_ptr))
924                    } else {
925                        None
926                    }
927                })
928                .and_then(|(src_instr_val, src_instr, dst_val_ptr)| {
929                    // Is the Store source a Load?
930                    if let Instruction {
931                        op: InstOp::Load(src_val_ptr),
932                        ..
933                    } = src_instr
934                    {
935                        Some(Candidate {
936                            load_val: src_instr_val,
937                            store_val: store_instr_val,
938                            dst_ptr: *dst_val_ptr,
939                            src_ptr: *src_val_ptr,
940                        })
941                    } else {
942                        None
943                    }
944                })
945                .and_then(|candidate @ Candidate { dst_ptr, .. }| {
946                    // Check that there's no path from load_val to store_val that might overwrite src_ptr.
947                    if !is_clobbered(context, &candidate) {
948                        Some(CandidateKind::NonClobbered(candidate))
949                    } else if !is_copy_type(&dst_ptr.match_ptr_type(context).unwrap(), context) {
950                        Some(CandidateKind::ClobberedNoncopyType(candidate))
951                    } else {
952                        None
953                    }
954                })
955        })
956        .collect::<Vec<_>>();
957
958    if candidates.is_empty() {
959        return Ok(false);
960    }
961
962    for candidate in candidates {
963        match candidate {
964            CandidateKind::ClobberedNoncopyType(Candidate {
965                load_val,
966                store_val,
967                dst_ptr,
968                src_ptr,
969            }) => {
970                let load_block = load_val.get_instruction(context).unwrap().parent;
971                let temp = function.new_unique_local_var(
972                    context,
973                    "__aggr_memcpy_0".into(),
974                    src_ptr.match_ptr_type(context).unwrap(),
975                    None,
976                    true,
977                );
978                let temp_local =
979                    Value::new_instruction(context, load_block, InstOp::GetLocal(temp));
980                let to_temp = Value::new_instruction(
981                    context,
982                    load_block,
983                    InstOp::MemCopyVal {
984                        dst_val_ptr: temp_local,
985                        src_val_ptr: src_ptr,
986                    },
987                );
988                let mut inserter = InstructionInserter::new(
989                    context,
990                    load_block,
991                    crate::InsertionPosition::After(load_val),
992                );
993                inserter.insert_slice(&[temp_local, to_temp]);
994
995                let store_block = store_val.get_instruction(context).unwrap().parent;
996                let mem_copy_val = Value::new_instruction(
997                    context,
998                    store_block,
999                    InstOp::MemCopyVal {
1000                        dst_val_ptr: dst_ptr,
1001                        src_val_ptr: temp_local,
1002                    },
1003                );
1004                store_block.replace_instruction(context, store_val, mem_copy_val, true)?;
1005            }
1006            CandidateKind::NonClobbered(Candidate {
1007                dst_ptr: dst_val_ptr,
1008                src_ptr: src_val_ptr,
1009                store_val,
1010                ..
1011            }) => {
1012                let store_block = store_val.get_instruction(context).unwrap().parent;
1013                let mem_copy_val = Value::new_instruction(
1014                    context,
1015                    store_block,
1016                    InstOp::MemCopyVal {
1017                        dst_val_ptr,
1018                        src_val_ptr,
1019                    },
1020                );
1021                store_block.replace_instruction(context, store_val, mem_copy_val, true)?;
1022            }
1023        }
1024    }
1025
1026    Ok(true)
1027}