1use super::CodeGen;
8use super::context::RegisterContext;
9use super::types::{RegisterType, SpecSignature};
10use crate::ast::{Statement, WordDef};
11use crate::codegen::CodeGenError;
12use crate::codegen::mangle_name;
13use std::fmt::Write as _;
14
15impl CodeGen {
16 pub fn codegen_specialized_word(
27 &mut self,
28 word: &WordDef,
29 sig: &SpecSignature,
30 ) -> Result<(), CodeGenError> {
31 let base_name = format!("seq_{}", mangle_name(&word.name));
32 let spec_name = format!("{}{}", base_name, sig.suffix());
33
34 let return_type = if sig.outputs.len() == 1 {
38 sig.outputs[0].llvm_type().to_string()
39 } else {
40 let types: Vec<_> = sig.outputs.iter().map(|t| t.llvm_type()).collect();
41 format!("{{ {} }}", types.join(", "))
42 };
43
44 let params: Vec<String> = sig
45 .inputs
46 .iter()
47 .enumerate()
48 .map(|(i, ty)| format!("{} %arg{}", ty.llvm_type(), i))
49 .collect();
50
51 writeln!(
52 &mut self.output,
53 "define {} @{}({}) {{",
54 return_type,
55 spec_name,
56 params.join(", ")
57 )?;
58 writeln!(&mut self.output, "entry:")?;
59
60 let initial_params: Vec<(String, RegisterType)> = sig
61 .inputs
62 .iter()
63 .enumerate()
64 .map(|(i, ty)| (format!("arg{}", i), *ty))
65 .collect();
66 let mut ctx = RegisterContext::from_params(&initial_params);
67
68 let body_len = word.body.len();
69 let mut prev_int_literal: Option<i64> = None;
70 for (i, stmt) in word.body.iter().enumerate() {
71 let is_last = i == body_len - 1;
72 self.codegen_specialized_statement(
73 &mut ctx,
74 stmt,
75 &word.name,
76 sig,
77 is_last,
78 &mut prev_int_literal,
79 )?;
80 }
81
82 writeln!(&mut self.output, "}}")?;
83 writeln!(&mut self.output)?;
84
85 self.specialized_words
87 .insert(word.name.clone(), sig.clone());
88
89 Ok(())
90 }
91
92 pub(super) fn codegen_specialized_statement(
94 &mut self,
95 ctx: &mut RegisterContext,
96 stmt: &Statement,
97 word_name: &str,
98 sig: &SpecSignature,
99 is_last: bool,
100 prev_int_literal: &mut Option<i64>,
101 ) -> Result<(), CodeGenError> {
102 let prev_int = *prev_int_literal;
104 *prev_int_literal = None; match stmt {
107 Statement::IntLiteral(n) => {
108 let var = self.fresh_temp();
109 writeln!(&mut self.output, " %{} = add i64 0, {}", var, n)?;
110 ctx.push(var, RegisterType::I64);
111 *prev_int_literal = Some(*n); }
113
114 Statement::FloatLiteral(f) => {
115 let var = self.fresh_temp();
116 let bits = f.to_bits();
121 writeln!(
122 &mut self.output,
123 " %{} = bitcast i64 {} to double",
124 var, bits
125 )?;
126 ctx.push(var, RegisterType::Double);
127 }
128
129 Statement::BoolLiteral(b) => {
130 let var = self.fresh_temp();
131 let val = if *b { 1 } else { 0 };
132 writeln!(&mut self.output, " %{} = add i64 0, {}", var, val)?;
133 ctx.push(var, RegisterType::I64);
134 }
135
136 Statement::WordCall { name, .. } => {
137 self.codegen_specialized_word_call(ctx, name, word_name, sig, is_last, prev_int)?;
138 }
139
140 Statement::If {
141 then_branch,
142 else_branch,
143 span: _,
144 } => {
145 self.codegen_specialized_if(
146 ctx,
147 then_branch,
148 else_branch.as_ref(),
149 word_name,
150 sig,
151 is_last,
152 )?;
153 }
154
155 Statement::StringLiteral(_)
157 | Statement::Symbol(_)
158 | Statement::Quotation { .. }
159 | Statement::Match { .. } => {
160 return Err(CodeGenError::Logic(format!(
161 "Non-specializable statement in specialized word: {:?}",
162 stmt
163 )));
164 }
165 }
166
167 let already_returns = match stmt {
170 Statement::If { .. } => true,
171 Statement::WordCall { name, .. } if name == word_name => true,
172 _ => false,
173 };
174 if is_last && !already_returns {
175 self.emit_specialized_return(ctx, sig)?;
176 }
177
178 Ok(())
179 }
180
181 pub(super) fn emit_specialized_return(
183 &mut self,
184 ctx: &RegisterContext,
185 sig: &SpecSignature,
186 ) -> Result<(), CodeGenError> {
187 let output_count = sig.outputs.len();
188
189 if output_count == 0 {
190 writeln!(&mut self.output, " ret void")?;
191 } else if output_count == 1 {
192 let (var, ty) = ctx
193 .values
194 .last()
195 .ok_or_else(|| CodeGenError::Logic("Empty context at return".to_string()))?;
196 writeln!(&mut self.output, " ret {} %{}", ty.llvm_type(), var)?;
197 } else {
198 if ctx.values.len() < output_count {
201 return Err(CodeGenError::Logic(format!(
202 "Not enough values for multi-output return: need {}, have {}",
203 output_count,
204 ctx.values.len()
205 )));
206 }
207
208 let start_idx = ctx.values.len() - output_count;
209 let return_values: Vec<_> = ctx.values[start_idx..].to_vec();
210
211 let struct_type = sig.llvm_return_type();
212
213 let mut current_struct = "undef".to_string();
214 for (i, (var, ty)) in return_values.iter().enumerate() {
215 let new_struct = self.fresh_temp();
216 writeln!(
217 &mut self.output,
218 " %{} = insertvalue {} {}, {} %{}, {}",
219 new_struct,
220 struct_type,
221 current_struct,
222 ty.llvm_type(),
223 var,
224 i
225 )?;
226 current_struct = format!("%{}", new_struct);
227 }
228
229 writeln!(&mut self.output, " ret {} {}", struct_type, current_struct)?;
230 }
231 Ok(())
232 }
233
234 pub(super) fn codegen_specialized_if(
236 &mut self,
237 ctx: &mut RegisterContext,
238 then_branch: &[Statement],
239 else_branch: Option<&Vec<Statement>>,
240 word_name: &str,
241 sig: &SpecSignature,
242 is_last: bool,
243 ) -> Result<(), CodeGenError> {
244 let (cond_var, _) = ctx
246 .pop()
247 .ok_or_else(|| CodeGenError::Logic("Empty context at if condition".to_string()))?;
248
249 let cmp_result = self.fresh_temp();
250 writeln!(
251 &mut self.output,
252 " %{} = icmp ne i64 %{}, 0",
253 cmp_result, cond_var
254 )?;
255
256 let then_label = self.fresh_block("if_then");
257 let else_label = self.fresh_block("if_else");
258 let merge_label = self.fresh_block("if_merge");
259
260 writeln!(
261 &mut self.output,
262 " br i1 %{}, label %{}, label %{}",
263 cmp_result, then_label, else_label
264 )?;
265
266 writeln!(&mut self.output, "{}:", then_label)?;
268 let mut then_ctx = ctx.clone();
269 let mut then_prev_int: Option<i64> = None;
270 for (i, stmt) in then_branch.iter().enumerate() {
271 let is_stmt_last = i == then_branch.len() - 1 && is_last;
272 self.codegen_specialized_statement(
273 &mut then_ctx,
274 stmt,
275 word_name,
276 sig,
277 is_stmt_last,
278 &mut then_prev_int,
279 )?;
280 }
281 if is_last && then_branch.is_empty() {
283 self.emit_specialized_return(&then_ctx, sig)?;
284 }
285 let then_emitted_return = is_last;
286 let then_pred = if then_emitted_return {
287 None
288 } else {
289 writeln!(&mut self.output, " br label %{}", merge_label)?;
290 Some(then_label.clone())
291 };
292
293 writeln!(&mut self.output, "{}:", else_label)?;
295 let mut else_ctx = ctx.clone();
296 let mut else_prev_int: Option<i64> = None;
297 if let Some(else_stmts) = else_branch {
298 for (i, stmt) in else_stmts.iter().enumerate() {
299 let is_stmt_last = i == else_stmts.len() - 1 && is_last;
300 self.codegen_specialized_statement(
301 &mut else_ctx,
302 stmt,
303 word_name,
304 sig,
305 is_stmt_last,
306 &mut else_prev_int,
307 )?;
308 }
309 }
310 if is_last && (else_branch.is_none() || else_branch.as_ref().is_some_and(|b| b.is_empty()))
312 {
313 self.emit_specialized_return(&else_ctx, sig)?;
314 }
315 let else_emitted_return = is_last;
316 let else_pred = if else_emitted_return {
317 None
318 } else {
319 writeln!(&mut self.output, " br label %{}", merge_label)?;
320 Some(else_label.clone())
321 };
322
323 if then_pred.is_some() || else_pred.is_some() {
325 writeln!(&mut self.output, "{}:", merge_label)?;
326
327 if let (Some(then_p), Some(else_p)) = (&then_pred, &else_pred) {
328 if then_ctx.values.len() != else_ctx.values.len() {
330 return Err(CodeGenError::Logic(format!(
331 "Stack depth mismatch in if branches: then has {}, else has {}",
332 then_ctx.values.len(),
333 else_ctx.values.len()
334 )));
335 }
336
337 ctx.values.clear();
338 for i in 0..then_ctx.values.len() {
339 let (then_var, then_ty) = &then_ctx.values[i];
340 let (else_var, else_ty) = &else_ctx.values[i];
341
342 if then_ty != else_ty {
343 return Err(CodeGenError::Logic(format!(
344 "Type mismatch at position {} in if branches: {:?} vs {:?}",
345 i, then_ty, else_ty
346 )));
347 }
348
349 if then_var == else_var {
350 ctx.push(then_var.clone(), *then_ty);
351 } else {
352 let phi_result = self.fresh_temp();
353 writeln!(
354 &mut self.output,
355 " %{} = phi {} [ %{}, %{} ], [ %{}, %{} ]",
356 phi_result,
357 then_ty.llvm_type(),
358 then_var,
359 then_p,
360 else_var,
361 else_p
362 )?;
363 ctx.push(phi_result, *then_ty);
364 }
365 }
366 } else if then_pred.is_some() {
367 *ctx = then_ctx;
368 } else {
369 *ctx = else_ctx;
370 }
371
372 if is_last && (then_pred.is_some() || else_pred.is_some()) {
373 self.emit_specialized_return(ctx, sig)?;
374 }
375 }
376
377 Ok(())
378 }
379}