swamp_code_gen/
call.rs

1/*
2 * Copyright (c) Peter Bjorklund. All rights reserved. https://github.com/swamp/swamp
3 * Licensed under the MIT License. See LICENSE in the project root for license information.
4 */
5//! `CodeBuilder` helper functions for function calls and arguments.
6
7use crate::code_bld::CodeBuilder;
8use crate::ctx::Context;
9use crate::err::Error;
10use crate::reg_pool::RegisterPool;
11use crate::state::FunctionFixup;
12use crate::{
13    ArgumentAndTempScope, MAX_REGISTER_INDEX_FOR_PARAMETERS, RepresentationOfRegisters,
14    SpilledRegisterRegion, err,
15};
16use source_map_node::Node;
17use std::collections::HashSet;
18use swamp_semantic::{ArgumentExpression, InternalFunctionDefinitionRef, pretty_module_name};
19use swamp_types::TypeKind;
20use swamp_types::prelude::Signature;
21use swamp_vm_isa::REG_ON_FRAME_SIZE;
22use swamp_vm_types::FrameMemoryRegion;
23use swamp_vm_types::types::{BasicTypeRef, Place, TypedRegister, VmType, pointer_type};
24
/// A deferred register move: an argument value staged in a scratch register
/// that must later be copied into its canonical ABI register just before the
/// call instruction is emitted.
pub struct CopyArgument {
    // Final ABI register (r0-r6) the value must occupy at call time.
    pub canonical_target: TypedRegister,
    // Scratch register currently holding the value.
    pub source_temporary: TypedRegister,
}
/// Everything `emit_arguments` produces that `emit_post_call` needs to undo:
/// the spilled-register scope to restore, and the registers whose post-call
/// contents must be copied back into caller-visible locations.
pub struct EmitArgumentInfo {
    // Spilled ABI and scratch register regions to restore after the call.
    pub argument_and_temp_scope: ArgumentAndTempScope,
    // Registers the callee writes (r0 and by-value mutable args) that must be
    // copied back to their final destinations after registers are restored.
    pub copy_back_of_registers_mutated_by_callee: Vec<MutableReturnReg>,
}
33
/// A pending copy-back: after the call, the value left in `parameter_reg` by
/// the callee is written to `target_location_after_call`.
pub struct MutableReturnReg {
    // Caller-visible place (register or memory) that receives the value.
    pub target_location_after_call: Place,
    // ABI register the callee wrote (e.g. r0 for primitive returns).
    pub parameter_reg: TypedRegister,
}
38
39impl CodeBuilder<'_> {
40    // TODO: for mutable arguments we want to leave output part of the spilling
41    // - store abi registers (do not save r0 if that is the destination for the call itself, since we want to clobber it in that case)
42    // - store temp registers (including the newly created replacement base regs)
43    // - call (callee is allowed to clobber r0, r1, r2, r128)
44    // - restore temp registers (so our r128 is valid again)
45    // - issue copy backs of the mutables scalars. (replacement) base reg is valid, and the r1 and r2 are clobbered, which is good.
46    // - restore abi registers
47    // - copy back of r0 for immutable scalars to target register,  unless the target was r0 itself, in that case it was left out of store abi register mask
48    pub fn spill_required_registers(&mut self, node: &Node, comment: &str) -> ArgumentAndTempScope {
49        const ABI_ARGUMENT_RETURN_AND_ARGUMENT_REGISTERS: usize =
50            MAX_REGISTER_INDEX_FOR_PARAMETERS as usize + 1; // r0-r6
51        const ABI_ARGUMENT_MASK: u8 =
52            ((1u16 << ABI_ARGUMENT_RETURN_AND_ARGUMENT_REGISTERS) - 1) as u8;
53
54        let abi_parameter_frame_memory_region = self.temp_frame_space_for_register(
55            ABI_ARGUMENT_RETURN_AND_ARGUMENT_REGISTERS as u8,
56            &format!("emit abi arguments r0-r6 {comment}"),
57        );
58
59        self.builder.add_st_masked_regs_to_frame(
60            abi_parameter_frame_memory_region.addr,
61            ABI_ARGUMENT_MASK,
62            node,
63            "spill masked registers to stack frame memory.",
64        );
65
66        let abi_parameter_region = SpilledRegisterRegion {
67            registers: RepresentationOfRegisters::Mask(ABI_ARGUMENT_MASK),
68            frame_memory_region: abi_parameter_frame_memory_region,
69        };
70
71        let (first_temp_register_index, temp_register_probable_live_count) =
72            self.temp_registers.start_index_and_number_of_allocated();
73        debug_assert_eq!(first_temp_register_index, 128);
74
75        let temp_register_region = if temp_register_probable_live_count > 0 {
76            let temp_register_frame_memory_region = self.temp_frame_space_for_register(temp_register_probable_live_count, &format!("emit temp arguments from r{first_temp_register_index} count:{temp_register_probable_live_count} {comment}"));
77            let temp_register_region = SpilledRegisterRegion {
78                registers: RepresentationOfRegisters::Range {
79                    start_reg: first_temp_register_index,
80                    count: temp_register_probable_live_count,
81                },
82                frame_memory_region: temp_register_frame_memory_region,
83            };
84
85            self.builder.add_st_contiguous_regs_to_frame(
86                temp_register_frame_memory_region,
87                first_temp_register_index,
88                temp_register_probable_live_count,
89                node,
90                "spill contiguous range of registers to stack frame memory",
91            );
92            Some(temp_register_region)
93        } else {
94            None
95        };
96
97        ArgumentAndTempScope {
98            argument_registers: abi_parameter_region,
99            scratch_registers: temp_register_region,
100        }
101    }
102
    /// Emits the code that places one call argument into `argument_to_use`.
    ///
    /// * `BorrowMutableReference` — for parameter types flagged for copy-back,
    ///   the current value is loaded by value and the original location is
    ///   recorded in `copy_back_phase_one` so the callee's mutation can be
    ///   written back after the call; otherwise the argument is passed as an
    ///   absolute pointer to the lvalue.
    /// * `MaterializedExpression` — rvalues that need a memory location are
    ///   materialized into temporary storage and passed by pointer; others are
    ///   evaluated directly into the register.
    /// * `Expression` — evaluated straight into the register.
    ///
    /// `target_canonical_argument_register` is the final ABI register; it may
    /// differ from `argument_to_use` when the value is staged in a scratch
    /// register first.
    fn emit_single_argument(
        &mut self,
        argument_expr: &ArgumentExpression,
        argument_to_use: &TypedRegister,
        target_canonical_argument_register: &TypedRegister,
        parameter_basic_type: &BasicTypeRef,
        copy_back_phase_one: &mut Vec<MutableReturnReg>,
        node: &Node,
        ctx: &Context,
    ) {
        match argument_expr {
            ArgumentExpression::BorrowMutableReference(lvalue) => {
                let original_destination = self.emit_lvalue_address(lvalue, ctx);

                if parameter_basic_type.should_be_copied_back_when_mutable_arg_or_return() {
                    // Load the primitive from memory
                    self.emit_transfer_value_to_register(
                        argument_to_use,
                        &original_destination,
                        node,
                        "must get primitive from lvalue and pass as copy back (by value)",
                    );

                    // Add a copy back to the original location (base register will be restored by spill/restore)
                    copy_back_phase_one.push(MutableReturnReg {
                        target_location_after_call: original_destination,
                        parameter_reg: target_canonical_argument_register.clone(),
                    });
                } else {
                    // Pass by reference: flatten the lvalue into an absolute
                    // pointer the callee can mutate through directly.
                    let flattened_source_pointer_reg = self
                        .emit_compute_effective_address_to_register(
                            &original_destination,
                            node,
                            "flattened into absolute pointer",
                        );
                    self.builder.add_mov_reg(
                        argument_to_use,
                        &flattened_source_pointer_reg,
                        node,
                        "copy absolute address",
                    );
                }
            }
            ArgumentExpression::MaterializedExpression(expr) => {
                if Self::rvalue_needs_memory_location_to_materialize_in(
                    &mut self.state.layout_cache,
                    expr,
                ) {
                    // Use the helper function to get a pointer to the temporary storage
                    let temp_ptr = self.emit_scalar_rvalue_or_pointer_to_temporary(expr, ctx, true);

                    self.builder.add_mov_reg(
                        argument_to_use,
                        &temp_ptr,
                        node,
                        "copy temporary storage address to argument register",
                    );
                } else {
                    self.emit_expression_into_register(
                        argument_to_use,
                        expr,
                        "argument expression into specific argument register",
                        ctx,
                    );
                }
            }

            ArgumentExpression::Expression(expr) => {
                // Normal case: expression can be materialized directly into register
                self.emit_expression_into_register(
                    argument_to_use,
                    expr,
                    "argument expression into specific argument register",
                    ctx,
                );
            }
            // NOTE(review): catch-all panic — if `ArgumentExpression` has only
            // the three variants above, this arm is unreachable and hides
            // future variants from exhaustiveness checking; verify.
            _ => panic!("what kind of argument is it"),
        }
    }
182
183    pub(crate) fn emit_arguments(
184        &mut self,
185        output_place: &Place,
186        node: &Node,
187        signature: &Signature,
188        self_variable: Option<&TypedRegister>,
189        arguments: &[ArgumentExpression],
190        is_host_call: bool,
191        ctx: &Context,
192    ) -> EmitArgumentInfo {
193        let mut copy_back_operations: Vec<MutableReturnReg> = Vec::new();
194        let has_return_value = !matches!(&*signature.return_type.kind, TypeKind::Unit);
195
196        // Step 1: Spill live registers before we start using ABI registers
197        let spill_scope = self.spill_required_registers(node, "spill before emit arguments");
198
199        assert!(
200            signature.parameters.len() <= MAX_REGISTER_INDEX_FOR_PARAMETERS.into(),
201            "signature is wrong {signature:?}"
202        );
203
204        // Step 3: Prepare argument registers and handle temporary register conflicts
205        let mut temp_to_abi_copies = Vec::new();
206        let mut argument_registers = RegisterPool::new(1, 6); // r1-r6 for arguments
207        let mut return_copy_arg: Option<CopyArgument> = None;
208        // Step 5: Handle return value setup
209        if has_return_value {
210            let return_basic_type = self.state.layout_cache.layout(&signature.return_type);
211
212            if return_basic_type.needs_hidden_pointer_as_return() {
213                // For aggregates: initialize the destination space first, then set up r0 as pointer to destination
214                let return_pointer_reg = self.emit_compute_effective_address_to_register(
215                    output_place,
216                    node,
217                    "r0: create an absolute pointer to r0 if needed",
218                );
219
220                let temp_reg = self.temp_registers.allocate(
221                    VmType::new_contained_in_register(pointer_type()),
222                    &format!("temporary argument for r0 '{}'", return_pointer_reg.comment),
223                );
224
225                self.builder.add_mov_reg(
226                    &temp_reg.register,
227                    &return_pointer_reg,
228                    node,
229                    "stash sret (dest addr) into temp",
230                );
231
232                let target_canonical_return_register =
233                    TypedRegister::new_vm_type(0, return_pointer_reg.ty);
234                let copy_argument = CopyArgument {
235                    canonical_target: target_canonical_return_register,
236                    source_temporary: temp_reg.register,
237                };
238                return_copy_arg = Some(copy_argument);
239            } else {
240                // For primitives: add r0 to copy-back list (function writes to r0, we copy to destination)
241                let r0 =
242                    TypedRegister::new_vm_type(0, VmType::new_unknown_placement(return_basic_type));
243                copy_back_operations.push(MutableReturnReg {
244                    target_location_after_call: output_place.clone(),
245                    parameter_reg: r0,
246                });
247            }
248        }
249
250        for (index_in_signature, type_for_parameter) in signature.parameters.iter().enumerate() {
251            let parameter_basic_type = self
252                .state
253                .layout_cache
254                .layout(&type_for_parameter.resolved_type);
255            let target_canonical_argument_register = argument_registers.alloc_register(
256                VmType::new_unknown_placement(parameter_basic_type.clone()),
257                &format!("{index_in_signature}:{}", type_for_parameter.name),
258            );
259
260            let argument_to_use = if self.argument_needs_to_be_in_a_temporary_register_first(
261                &target_canonical_argument_register,
262            ) {
263                let temp_reg = self.temp_registers.allocate(
264                    target_canonical_argument_register.ty.clone(),
265                    &format!(
266                        "temporary argument for '{}'",
267                        target_canonical_argument_register.comment
268                    ),
269                );
270                let copy_argument = CopyArgument {
271                    canonical_target: target_canonical_argument_register.clone(),
272                    source_temporary: temp_reg.register.clone(),
273                };
274                temp_to_abi_copies.push(copy_argument);
275                temp_reg.register
276            } else {
277                target_canonical_argument_register.clone()
278            };
279
280            // Handle self variable (first parameter) vs regular arguments
281            if index_in_signature == 0 && self_variable.is_some() {
282                let self_reg = self_variable.as_ref().unwrap();
283                if self_reg.index != argument_to_use.index {
284                    self.builder.add_mov_reg(
285                        &argument_to_use,
286                        self_reg,
287                        node,
288                        &format!(
289                            "move self_variable ({}) to first argument register",
290                            self_reg.ty
291                        ),
292                    );
293                }
294            } else {
295                // Regular argument - get from arguments array
296                let argument_vector_index = if self_variable.is_some() {
297                    index_in_signature - 1
298                } else {
299                    index_in_signature
300                };
301                let argument_expr_or_location = &arguments[argument_vector_index];
302
303                self.emit_single_argument(
304                    argument_expr_or_location,
305                    &argument_to_use,
306                    &target_canonical_argument_register,
307                    &parameter_basic_type,
308                    &mut copy_back_operations,
309                    node,
310                    ctx,
311                );
312            }
313        }
314
315        // Step 4: Copy from temporary registers to final ABI argument registers
316        if let Some(return_reg_copy_argument) = return_copy_arg {
317            self.builder.add_mov_reg(
318                &return_reg_copy_argument.canonical_target,
319                &return_reg_copy_argument.source_temporary,
320                node,
321                &"copy r0 in place before arguments".to_string(),
322            );
323        }
324
325        for (index, copy_argument) in temp_to_abi_copies.iter().enumerate() {
326            let parameter_in_signature = &signature.parameters[index];
327            self.builder.add_mov_reg(
328                &copy_argument.canonical_target,
329                &copy_argument.source_temporary,
330                node,
331                &format!(
332                    "copy argument {index} ({}) in place from temporary '{}'",
333                    parameter_in_signature.name, copy_argument.source_temporary.comment
334                ),
335            );
336        }
337
338        EmitArgumentInfo {
339            argument_and_temp_scope: spill_scope,
340            copy_back_of_registers_mutated_by_callee: copy_back_operations,
341        }
342    }
343
344    pub(crate) fn emit_post_call(
345        &mut self,
346        spilled_arguments: EmitArgumentInfo,
347        node: &Node,
348        comment: &str,
349    ) {
350        // Phase 1: Save current mutable parameter values to temporary safe space before registers get clobbered
351        let mut temp_saved_values = Vec::new();
352        for copy_back in &spilled_arguments.copy_back_of_registers_mutated_by_callee {
353            let temp_reg = self.temp_registers.allocate(
354                copy_back.parameter_reg.ty.clone(),
355                &format!(
356                    "temp save for copy-back of {}",
357                    copy_back.parameter_reg.comment
358                ),
359            );
360
361            self.builder.add_mov_reg(
362                temp_reg.register(),
363                &copy_back.parameter_reg,
364                node,
365                &format!(
366                    "save {} to temp before register restoration",
367                    copy_back.parameter_reg
368                ),
369            );
370
371            temp_saved_values.push((temp_reg, copy_back));
372        }
373
374        // Phase 2: Restore all spilled registers (temp registers first, then variables, then arguments)
375        if let Some(scratch_region) = spilled_arguments.argument_and_temp_scope.scratch_registers {
376            self.emit_restore_region(scratch_region, &HashSet::new(), node, comment);
377        }
378
379        // Restore argument registers - cool thing with this approach is that we don't have to bother with restoring some of them
380        self.emit_restore_region(
381            spilled_arguments.argument_and_temp_scope.argument_registers,
382            &HashSet::new(),
383            node,
384            comment,
385        );
386
387        // Phase 3: Copy from temporary safe registers to the final destinations
388        for (temp_reg, copy_back) in temp_saved_values {
389            let temp_source = Place::Register(temp_reg.register().clone());
390            self.emit_copy_value_between_places(
391                &copy_back.target_location_after_call,
392                &temp_source,
393                node,
394                "copy-back from temp to final destination",
395            );
396        }
397    }
398
    /// Restores a previously spilled register region from stack-frame memory,
    /// skipping any register listed in `output_destination_registers`.
    ///
    /// For `Individual` and `Range` representations, runs of consecutive
    /// register indices are coalesced into single contiguous-load
    /// instructions; `Mask` regions restore via one masked load with the
    /// skipped bits cleared.
    #[allow(clippy::too_many_lines)]
    pub fn emit_restore_region(
        &mut self,
        region: SpilledRegisterRegion,
        output_destination_registers: &HashSet<u8>, // TODO: Remove this
        node: &Node,
        comment: &str,
    ) {
        match region.registers {
            RepresentationOfRegisters::Individual(spilled_registers_list) => {
                if !spilled_registers_list.is_empty() {
                    let mut sorted_regs = spilled_registers_list;
                    sorted_regs.sort_by_key(|reg| reg.index);

                    // Filter out registers that are in output_destination_registers
                    let filtered_regs: Vec<_> = sorted_regs
                        .into_iter()
                        .filter(|reg| !output_destination_registers.contains(&reg.index))
                        .collect();

                    if !filtered_regs.is_empty() {
                        let mut i = 0;
                        while i < filtered_regs.len() {
                            let seq_start_idx = i;
                            let start_reg = filtered_regs[i].index;
                            let mut seq_length = 1;

                            // Extend the run while the next register index is
                            // exactly one above the current.
                            while i + 1 < filtered_regs.len()
                                && filtered_regs[i + 1].index == filtered_regs[i].index + 1
                            {
                                seq_length += 1;
                                i += 1;
                            }

                            // Frame offset of this run relative to the lowest
                            // spilled register.
                            // NOTE(review): assumes the spill laid registers
                            // out contiguously by index from the lowest one,
                            // and that filtering never removes that lowest
                            // register (otherwise offsets shift) — verify
                            // against the spill site. The `else 0` branch is
                            // what the formula yields anyway at seq_start_idx
                            // == 0.
                            let memory_offset = if seq_start_idx > 0 {
                                (filtered_regs[seq_start_idx].index - filtered_regs[0].index)
                                    as usize
                                    * REG_ON_FRAME_SIZE.0 as usize
                            } else {
                                0
                            };

                            // NOTE(review): `size` covers one register even
                            // though `seq_length` registers are loaded;
                            // presumably the load instruction ignores the
                            // region size — confirm.
                            let specific_mem_location = FrameMemoryRegion {
                                addr: region.frame_memory_region.addr
                                    + swamp_vm_isa::MemoryOffset(memory_offset as u32),
                                size: REG_ON_FRAME_SIZE,
                            };

                            self.builder.add_ld_contiguous_regs_from_frame(
                                start_reg,
                                specific_mem_location,
                                seq_length,
                                node,
                                &format!(
                                    "restoring r{}-r{} (sequence) {comment}",
                                    start_reg,
                                    start_reg + seq_length - 1
                                ),
                            );

                            i += 1;
                        }
                    }
                }
            }

            RepresentationOfRegisters::Mask(original_spill_mask) => {
                let mut mask_to_actually_restore = original_spill_mask;

                // Drop every spilled register that the caller wants left
                // untouched (e.g. the call's own output register).
                for i in 0..8 {
                    let reg_idx = i as u8;
                    if (original_spill_mask >> i) & 1 != 0
                        && output_destination_registers.contains(&reg_idx)
                    {
                        mask_to_actually_restore &= !(1 << i); // Clear the bit: don't restore this one
                    }
                }

                if mask_to_actually_restore != 0 {
                    self.builder.add_ld_masked_regs_from_frame(
                        mask_to_actually_restore,
                        region.frame_memory_region,
                        node,
                        &format!("restore registers using mask {comment}"),
                    );
                }
            }
            RepresentationOfRegisters::Range { start_reg, count } => {
                let base_mem_addr_of_spilled_range = region.frame_memory_region.addr;

                // Find contiguous sequences of registers that need to be restored
                let mut i = 0;
                while i < count {
                    // Skip over registers the caller wants left untouched.
                    while i < count && output_destination_registers.contains(&(start_reg + i)) {
                        i += 1;
                    }

                    if i < count {
                        let seq_start_reg = start_reg + i;
                        // Offset within the spilled range: one slot per
                        // register, in range order.
                        let seq_start_offset = (i as usize) * REG_ON_FRAME_SIZE.0 as usize;
                        let mut seq_length = 1;

                        // Extend the run until the next register is excluded
                        // or the range ends.
                        while i + seq_length < count
                            && !output_destination_registers.contains(&(start_reg + i + seq_length))
                        {
                            seq_length += 1;
                        }

                        // NOTE(review): `size` covers one register even though
                        // `seq_length` registers are loaded — confirm the
                        // loader ignores the region size.
                        let specific_mem_location = FrameMemoryRegion {
                            addr: base_mem_addr_of_spilled_range
                                + swamp_vm_isa::MemoryOffset(seq_start_offset as u32),
                            size: REG_ON_FRAME_SIZE,
                        };

                        self.builder.add_ld_contiguous_regs_from_frame(
                            seq_start_reg,
                            specific_mem_location,
                            seq_length,
                            node,
                            &format!(
                                "restoring spilled contiguous range of registers from stack frame r{}-r{} {comment}",
                                seq_start_reg,
                                seq_start_reg + seq_length - 1
                            ),
                        );

                        i += seq_length;
                    }
                }
            }
        }
    }
531
532    pub(crate) fn emit_call(
533        &mut self,
534        node: &Node,
535        internal_fn: &InternalFunctionDefinitionRef,
536        comment: &str,
537    ) {
538        let function_name = internal_fn.associated_with_type.as_ref().map_or_else(
539            || {
540                format!(
541                    "{}::{}",
542                    pretty_module_name(&internal_fn.defined_in_module_path),
543                    internal_fn.assigned_name
544                )
545            },
546            |associated_with_type| {
547                format!(
548                    "{}::{}:{}",
549                    pretty_module_name(&internal_fn.defined_in_module_path),
550                    associated_with_type,
551                    internal_fn.assigned_name
552                )
553            },
554        );
555        let call_comment = &format!("calling `{function_name}` ({comment})",);
556
557        let patch_position = self.builder.add_call_placeholder(node, call_comment);
558        self.state.function_fixups.push(FunctionFixup {
559            patch_position,
560            fn_id: internal_fn.program_unique_id,
561            internal_function_definition: internal_fn.clone(),
562        });
563        //}
564    }
565    pub(crate) fn emit_internal_call(
566        &mut self,
567        target_reg: &Place,
568        node: &Node,
569        internal_fn: &InternalFunctionDefinitionRef,
570        arguments: &Vec<ArgumentExpression>,
571        ctx: &Context,
572    ) {
573        let argument_info = self.emit_arguments(
574            target_reg,
575            node,
576            &internal_fn.signature,
577            None,
578            arguments,
579            false,
580            ctx,
581        );
582
583        self.emit_call(node, internal_fn, "call"); // will be fixed up later
584
585        if !matches!(&*internal_fn.signature.return_type.kind, TypeKind::Never) {
586            self.emit_post_call(argument_info, node, "restore spilled after call");
587        }
588    }
589
    /// Whether this argument's value must be staged in a scratch register
    /// before being moved to its canonical ABI register.
    ///
    /// If you're not on the last argument for "Outer Function", you have to put
    /// that value away somewhere safe for a bit. Otherwise, when you're figuring out the next arguments,
    /// you might accidentally overwrite it.
    /// But if you are on the last argument, you can just drop it right where it needs to go.
    ///
    /// Currently always `true` (see TODO below); `reg` is ignored until a
    /// real liveness check is implemented.
    const fn argument_needs_to_be_in_a_temporary_register_first(
        &self,
        reg: &TypedRegister,
    ) -> bool {
        // TODO: for now just assume it is
        true
    }
601
602    fn add_error(&mut self, error_kind: err::ErrorKind, node: &Node) {
603        self.errors.push(Error {
604            node: node.clone(),
605            kind: error_kind,
606        });
607    }
608}