Skip to main content

seqc/codegen/specialization/
codegen_word.rs

1//! The top-level specialized codegen: walking a word definition, dispatching
2//! on statement kind, generating the `define … { … }` prologue, lowering
3//! `if / else`, and emitting returns. The per-operation and per-call
4//! emitters live in sibling files (`codegen_ops`, `codegen_safe_math`,
5//! `codegen_calls`).
6
7use super::CodeGen;
8use super::context::RegisterContext;
9use super::types::{RegisterType, SpecSignature};
10use crate::ast::{Statement, WordDef};
11use crate::codegen::CodeGenError;
12use crate::codegen::mangle_name;
13use std::fmt::Write as _;
14
15impl CodeGen {
16    /// Generate a specialized version of a word.
17    ///
18    /// This creates a register-based function that passes values directly in
19    /// CPU registers instead of through the tagged pointer stack.
20    ///
21    /// The generated function:
22    /// - Takes primitive arguments directly (i64 for Int/Bool, double for Float)
23    /// - Returns the result in a register (not via stack pointer)
24    /// - Uses `musttail` for recursive calls to guarantee TCO
25    /// - Handles control flow with phi nodes for value merging
26    pub fn codegen_specialized_word(
27        &mut self,
28        word: &WordDef,
29        sig: &SpecSignature,
30    ) -> Result<(), CodeGenError> {
31        let base_name = format!("seq_{}", mangle_name(&word.name));
32        let spec_name = format!("{}{}", base_name, sig.suffix());
33
34        // Generate function signature
35        // For single output: define i64 @name(i64 %arg0) {
36        // For multiple outputs: define { i64, i64 } @name(i64 %arg0, i64 %arg1) {
37        let return_type = if sig.outputs.len() == 1 {
38            sig.outputs[0].llvm_type().to_string()
39        } else {
40            let types: Vec<_> = sig.outputs.iter().map(|t| t.llvm_type()).collect();
41            format!("{{ {} }}", types.join(", "))
42        };
43
44        let params: Vec<String> = sig
45            .inputs
46            .iter()
47            .enumerate()
48            .map(|(i, ty)| format!("{} %arg{}", ty.llvm_type(), i))
49            .collect();
50
51        writeln!(
52            &mut self.output,
53            "define {} @{}({}) {{",
54            return_type,
55            spec_name,
56            params.join(", ")
57        )?;
58        writeln!(&mut self.output, "entry:")?;
59
60        let initial_params: Vec<(String, RegisterType)> = sig
61            .inputs
62            .iter()
63            .enumerate()
64            .map(|(i, ty)| (format!("arg{}", i), *ty))
65            .collect();
66        let mut ctx = RegisterContext::from_params(&initial_params);
67
68        let body_len = word.body.len();
69        let mut prev_int_literal: Option<i64> = None;
70        for (i, stmt) in word.body.iter().enumerate() {
71            let is_last = i == body_len - 1;
72            self.codegen_specialized_statement(
73                &mut ctx,
74                stmt,
75                &word.name,
76                sig,
77                is_last,
78                &mut prev_int_literal,
79            )?;
80        }
81
82        writeln!(&mut self.output, "}}")?;
83        writeln!(&mut self.output)?;
84
85        // Record that this word is specialized
86        self.specialized_words
87            .insert(word.name.clone(), sig.clone());
88
89        Ok(())
90    }
91
92    /// Generate specialized code for a single statement
93    pub(super) fn codegen_specialized_statement(
94        &mut self,
95        ctx: &mut RegisterContext,
96        stmt: &Statement,
97        word_name: &str,
98        sig: &SpecSignature,
99        is_last: bool,
100        prev_int_literal: &mut Option<i64>,
101    ) -> Result<(), CodeGenError> {
102        // Track previous int literal for pick/roll optimization
103        let prev_int = *prev_int_literal;
104        *prev_int_literal = None; // Reset unless this is an IntLiteral
105
106        match stmt {
107            Statement::IntLiteral(n) => {
108                let var = self.fresh_temp();
109                writeln!(&mut self.output, "  %{} = add i64 0, {}", var, n)?;
110                ctx.push(var, RegisterType::I64);
111                *prev_int_literal = Some(*n); // Track for next statement
112            }
113
114            Statement::FloatLiteral(f) => {
115                let var = self.fresh_temp();
116                // Use bitcast from integer bits for exact IEEE 754 representation.
117                // This avoids precision loss from decimal string conversion (e.g., 0.1
118                // cannot be exactly represented in binary floating point). By storing
119                // the raw bit pattern and using bitcast, we preserve the exact value.
120                let bits = f.to_bits();
121                writeln!(
122                    &mut self.output,
123                    "  %{} = bitcast i64 {} to double",
124                    var, bits
125                )?;
126                ctx.push(var, RegisterType::Double);
127            }
128
129            Statement::BoolLiteral(b) => {
130                let var = self.fresh_temp();
131                let val = if *b { 1 } else { 0 };
132                writeln!(&mut self.output, "  %{} = add i64 0, {}", var, val)?;
133                ctx.push(var, RegisterType::I64);
134            }
135
136            Statement::WordCall { name, .. } => {
137                self.codegen_specialized_word_call(ctx, name, word_name, sig, is_last, prev_int)?;
138            }
139
140            Statement::If {
141                then_branch,
142                else_branch,
143                span: _,
144            } => {
145                self.codegen_specialized_if(
146                    ctx,
147                    then_branch,
148                    else_branch.as_ref(),
149                    word_name,
150                    sig,
151                    is_last,
152                )?;
153            }
154
155            // These shouldn't appear in specializable words (checked in can_specialize)
156            Statement::StringLiteral(_)
157            | Statement::Symbol(_)
158            | Statement::Quotation { .. }
159            | Statement::Match { .. } => {
160                return Err(CodeGenError::Logic(format!(
161                    "Non-specializable statement in specialized word: {:?}",
162                    stmt
163                )));
164            }
165        }
166
167        // Emit return if this is the last statement and it's not a control flow op
168        // that already emits returns (like if, or recursive calls)
169        let already_returns = match stmt {
170            Statement::If { .. } => true,
171            Statement::WordCall { name, .. } if name == word_name => true,
172            _ => false,
173        };
174        if is_last && !already_returns {
175            self.emit_specialized_return(ctx, sig)?;
176        }
177
178        Ok(())
179    }
180
181    /// Emit return statement for specialized function
182    pub(super) fn emit_specialized_return(
183        &mut self,
184        ctx: &RegisterContext,
185        sig: &SpecSignature,
186    ) -> Result<(), CodeGenError> {
187        let output_count = sig.outputs.len();
188
189        if output_count == 0 {
190            writeln!(&mut self.output, "  ret void")?;
191        } else if output_count == 1 {
192            let (var, ty) = ctx
193                .values
194                .last()
195                .ok_or_else(|| CodeGenError::Logic("Empty context at return".to_string()))?;
196            writeln!(&mut self.output, "  ret {} %{}", ty.llvm_type(), var)?;
197        } else {
198            // Multi-output: build struct return.
199            // Values in context are bottom-to-top, matching sig.outputs order.
200            if ctx.values.len() < output_count {
201                return Err(CodeGenError::Logic(format!(
202                    "Not enough values for multi-output return: need {}, have {}",
203                    output_count,
204                    ctx.values.len()
205                )));
206            }
207
208            let start_idx = ctx.values.len() - output_count;
209            let return_values: Vec<_> = ctx.values[start_idx..].to_vec();
210
211            let struct_type = sig.llvm_return_type();
212
213            let mut current_struct = "undef".to_string();
214            for (i, (var, ty)) in return_values.iter().enumerate() {
215                let new_struct = self.fresh_temp();
216                writeln!(
217                    &mut self.output,
218                    "  %{} = insertvalue {} {}, {} %{}, {}",
219                    new_struct,
220                    struct_type,
221                    current_struct,
222                    ty.llvm_type(),
223                    var,
224                    i
225                )?;
226                current_struct = format!("%{}", new_struct);
227            }
228
229            writeln!(&mut self.output, "  ret {} {}", struct_type, current_struct)?;
230        }
231        Ok(())
232    }
233
234    /// Generate specialized if/else statement
235    pub(super) fn codegen_specialized_if(
236        &mut self,
237        ctx: &mut RegisterContext,
238        then_branch: &[Statement],
239        else_branch: Option<&Vec<Statement>>,
240        word_name: &str,
241        sig: &SpecSignature,
242        is_last: bool,
243    ) -> Result<(), CodeGenError> {
244        // Pop condition
245        let (cond_var, _) = ctx
246            .pop()
247            .ok_or_else(|| CodeGenError::Logic("Empty context at if condition".to_string()))?;
248
249        let cmp_result = self.fresh_temp();
250        writeln!(
251            &mut self.output,
252            "  %{} = icmp ne i64 %{}, 0",
253            cmp_result, cond_var
254        )?;
255
256        let then_label = self.fresh_block("if_then");
257        let else_label = self.fresh_block("if_else");
258        let merge_label = self.fresh_block("if_merge");
259
260        writeln!(
261            &mut self.output,
262            "  br i1 %{}, label %{}, label %{}",
263            cmp_result, then_label, else_label
264        )?;
265
266        // Then branch
267        writeln!(&mut self.output, "{}:", then_label)?;
268        let mut then_ctx = ctx.clone();
269        let mut then_prev_int: Option<i64> = None;
270        for (i, stmt) in then_branch.iter().enumerate() {
271            let is_stmt_last = i == then_branch.len() - 1 && is_last;
272            self.codegen_specialized_statement(
273                &mut then_ctx,
274                stmt,
275                word_name,
276                sig,
277                is_stmt_last,
278                &mut then_prev_int,
279            )?;
280        }
281        // If the then branch is empty and this is the last statement, emit return
282        if is_last && then_branch.is_empty() {
283            self.emit_specialized_return(&then_ctx, sig)?;
284        }
285        let then_emitted_return = is_last;
286        let then_pred = if then_emitted_return {
287            None
288        } else {
289            writeln!(&mut self.output, "  br label %{}", merge_label)?;
290            Some(then_label.clone())
291        };
292
293        // Else branch
294        writeln!(&mut self.output, "{}:", else_label)?;
295        let mut else_ctx = ctx.clone();
296        let mut else_prev_int: Option<i64> = None;
297        if let Some(else_stmts) = else_branch {
298            for (i, stmt) in else_stmts.iter().enumerate() {
299                let is_stmt_last = i == else_stmts.len() - 1 && is_last;
300                self.codegen_specialized_statement(
301                    &mut else_ctx,
302                    stmt,
303                    word_name,
304                    sig,
305                    is_stmt_last,
306                    &mut else_prev_int,
307                )?;
308            }
309        }
310        // If the else branch is empty (or None) and this is the last statement, emit return
311        if is_last && (else_branch.is_none() || else_branch.as_ref().is_some_and(|b| b.is_empty()))
312        {
313            self.emit_specialized_return(&else_ctx, sig)?;
314        }
315        let else_emitted_return = is_last;
316        let else_pred = if else_emitted_return {
317            None
318        } else {
319            writeln!(&mut self.output, "  br label %{}", merge_label)?;
320            Some(else_label.clone())
321        };
322
323        // Merge block with phi nodes if either branch continues
324        if then_pred.is_some() || else_pred.is_some() {
325            writeln!(&mut self.output, "{}:", merge_label)?;
326
327            if let (Some(then_p), Some(else_p)) = (&then_pred, &else_pred) {
328                // Both branches continue - merge all values with phi nodes
329                if then_ctx.values.len() != else_ctx.values.len() {
330                    return Err(CodeGenError::Logic(format!(
331                        "Stack depth mismatch in if branches: then has {}, else has {}",
332                        then_ctx.values.len(),
333                        else_ctx.values.len()
334                    )));
335                }
336
337                ctx.values.clear();
338                for i in 0..then_ctx.values.len() {
339                    let (then_var, then_ty) = &then_ctx.values[i];
340                    let (else_var, else_ty) = &else_ctx.values[i];
341
342                    if then_ty != else_ty {
343                        return Err(CodeGenError::Logic(format!(
344                            "Type mismatch at position {} in if branches: {:?} vs {:?}",
345                            i, then_ty, else_ty
346                        )));
347                    }
348
349                    if then_var == else_var {
350                        ctx.push(then_var.clone(), *then_ty);
351                    } else {
352                        let phi_result = self.fresh_temp();
353                        writeln!(
354                            &mut self.output,
355                            "  %{} = phi {} [ %{}, %{} ], [ %{}, %{} ]",
356                            phi_result,
357                            then_ty.llvm_type(),
358                            then_var,
359                            then_p,
360                            else_var,
361                            else_p
362                        )?;
363                        ctx.push(phi_result, *then_ty);
364                    }
365                }
366            } else if then_pred.is_some() {
367                *ctx = then_ctx;
368            } else {
369                *ctx = else_ctx;
370            }
371
372            if is_last && (then_pred.is_some() || else_pred.is_some()) {
373                self.emit_specialized_return(ctx, sig)?;
374            }
375        }
376
377        Ok(())
378    }
379}