1use super::{CodeGen, CodeGenError, mangle_name};
92use crate::ast::{Statement, WordDef};
93use crate::types::{StackType, Type};
94use std::fmt::Write as _;
95
96#[derive(Debug, Clone, Copy, PartialEq, Eq)]
98pub enum RegisterType {
99 I64,
101 Double,
103}
104
105impl RegisterType {
106 pub fn from_type(ty: &Type) -> Option<Self> {
108 match ty {
109 Type::Int | Type::Bool => Some(RegisterType::I64),
110 Type::Float => Some(RegisterType::Double),
111 _ => None,
112 }
113 }
114
115 pub fn llvm_type(&self) -> &'static str {
117 match self {
118 RegisterType::I64 => "i64",
119 RegisterType::Double => "double",
120 }
121 }
122}
123
124#[derive(Debug, Clone)]
126pub struct SpecSignature {
127 pub inputs: Vec<RegisterType>,
129 pub outputs: Vec<RegisterType>,
131}
132
133impl SpecSignature {
134 pub fn suffix(&self) -> String {
138 if self.inputs.len() == 1 && self.outputs.len() == 1 {
139 match (self.inputs[0], self.outputs[0]) {
140 (RegisterType::I64, RegisterType::I64) => "_i64".to_string(),
141 (RegisterType::Double, RegisterType::Double) => "_f64".to_string(),
142 (RegisterType::I64, RegisterType::Double) => "_i64_to_f64".to_string(),
143 (RegisterType::Double, RegisterType::I64) => "_f64_to_i64".to_string(),
144 }
145 } else {
146 let mut suffix = String::new();
148 for ty in &self.inputs {
149 suffix.push('_');
150 suffix.push_str(match ty {
151 RegisterType::I64 => "i",
152 RegisterType::Double => "f",
153 });
154 }
155 suffix.push_str("_to");
156 for ty in &self.outputs {
157 suffix.push('_');
158 suffix.push_str(match ty {
159 RegisterType::I64 => "i",
160 RegisterType::Double => "f",
161 });
162 }
163 suffix
164 }
165 }
166
167 pub fn is_direct_call(&self) -> bool {
169 self.outputs.len() == 1
170 }
171
172 pub fn llvm_return_type(&self) -> String {
177 if self.outputs.len() == 1 {
178 self.outputs[0].llvm_type().to_string()
179 } else {
180 let types: Vec<_> = self.outputs.iter().map(|t| t.llvm_type()).collect();
181 format!("{{ {} }}", types.join(", "))
182 }
183 }
184}
185
186#[derive(Debug, Clone)]
191pub struct RegisterContext {
192 pub values: Vec<(String, RegisterType)>,
194}
195
196impl RegisterContext {
197 pub fn new() -> Self {
199 Self { values: Vec::new() }
200 }
201
202 pub fn from_params(params: &[(String, RegisterType)]) -> Self {
204 Self {
205 values: params.to_vec(),
206 }
207 }
208
209 pub fn push(&mut self, ssa_var: String, ty: RegisterType) {
211 self.values.push((ssa_var, ty));
212 }
213
214 pub fn pop(&mut self) -> Option<(String, RegisterType)> {
216 self.values.pop()
217 }
218
219 #[cfg_attr(not(test), allow(dead_code))]
221 pub fn len(&self) -> usize {
222 self.values.len()
223 }
224
225 pub fn dup(&mut self) {
229 if let Some((ssa, ty)) = self.values.last().cloned() {
230 self.values.push((ssa, ty));
231 }
232 }
233
234 pub fn drop(&mut self) {
236 self.values.pop();
237 }
238
239 pub fn swap(&mut self) {
241 let len = self.values.len();
242 if len >= 2 {
243 self.values.swap(len - 1, len - 2);
244 }
245 }
246
247 pub fn over(&mut self) {
249 let len = self.values.len();
250 if len >= 2 {
251 let a = self.values[len - 2].clone();
252 self.values.push(a);
253 }
254 }
255
256 pub fn rot(&mut self) {
258 let len = self.values.len();
259 if len >= 3 {
260 let a = self.values.remove(len - 3);
261 self.values.push(a);
262 }
263 }
264}
265
266impl Default for RegisterContext {
267 fn default() -> Self {
268 Self::new()
269 }
270}
271
/// Built-in word names accepted by the specialization check
/// (`is_statement_specializable`), grouped by kind.
const SPECIALIZABLE_OPS: &[&str] = &[
    // Integer arithmetic (symbolic and spelled-out aliases).
    "i.+", "i.add", "i.-", "i.subtract", "i.*", "i.multiply",
    "i./", "i.divide", "i.%", "i.mod",
    // Bitwise operations and shifts.
    "band", "bor", "bxor", "bnot", "shl", "shr",
    // Bit counting.
    "popcount", "clz", "ctz",
    // Numeric conversions.
    "int->float", "float->int",
    // Boolean logic.
    "and", "or", "not",
    // Integer comparisons.
    "i.<", "i.lt", "i.>", "i.gt", "i.<=", "i.lte",
    "i.>=", "i.gte", "i.=", "i.eq", "i.<>", "i.neq",
    // Float arithmetic.
    "f.+", "f.add", "f.-", "f.subtract", "f.*", "f.multiply", "f./", "f.divide",
    // Float comparisons.
    "f.<", "f.lt", "f.>", "f.gt", "f.<=", "f.lte",
    "f.>=", "f.gte", "f.=", "f.eq", "f.<>", "f.neq",
    // Stack shuffling.
    "dup", "drop", "swap", "over", "rot", "nip", "tuck", "pick", "roll",
];
357
358impl CodeGen {
359 pub fn can_specialize(&self, word: &WordDef) -> Option<SpecSignature> {
361 let effect = word.effect.as_ref()?;
363
364 if !effect.is_pure() {
366 return None;
367 }
368
369 let inputs = Self::extract_register_types(&effect.inputs)?;
371 let outputs = Self::extract_register_types(&effect.outputs)?;
372
373 if inputs.is_empty() && outputs.is_empty() {
375 return None;
376 }
377
378 if outputs.is_empty() {
380 return None;
381 }
382
383 if !self.is_body_specializable(&word.body, &word.name) {
385 return None;
386 }
387
388 Some(SpecSignature { inputs, outputs })
389 }
390
391 fn extract_register_types(stack: &StackType) -> Option<Vec<RegisterType>> {
397 let mut types = Vec::new();
398 let mut current = stack;
399
400 loop {
401 match current {
402 StackType::Empty => break,
403 StackType::RowVar(_) => {
404 break;
408 }
409 StackType::Cons { rest, top } => {
410 let reg_ty = RegisterType::from_type(top)?;
411 types.push(reg_ty);
412 current = rest;
413 }
414 }
415 }
416
417 types.reverse();
419 Some(types)
420 }
421
422 fn is_body_specializable(&self, body: &[Statement], word_name: &str) -> bool {
427 let mut prev_was_int_literal = false;
428 for stmt in body {
429 if !self.is_statement_specializable(stmt, word_name, prev_was_int_literal) {
430 return false;
431 }
432 prev_was_int_literal = matches!(stmt, Statement::IntLiteral(_));
434 }
435 true
436 }
437
438 fn is_statement_specializable(
443 &self,
444 stmt: &Statement,
445 word_name: &str,
446 prev_was_int_literal: bool,
447 ) -> bool {
448 match stmt {
449 Statement::IntLiteral(_) => true,
451
452 Statement::FloatLiteral(_) => true,
454
455 Statement::BoolLiteral(_) => true,
457
458 Statement::StringLiteral(_) => false,
460
461 Statement::Symbol(_) => false,
463
464 Statement::Quotation { .. } => false,
466
467 Statement::Match { .. } => false,
469
470 Statement::WordCall { name, .. } => {
472 if name == word_name {
474 return true;
475 }
476
477 if (name == "pick" || name == "roll") && !prev_was_int_literal {
480 return false;
481 }
482
483 if SPECIALIZABLE_OPS.contains(&name.as_str()) {
485 return true;
486 }
487
488 if self.specialized_words.contains_key(name) {
490 return true;
491 }
492
493 false
495 }
496
497 Statement::If {
499 then_branch,
500 else_branch,
501 span: _,
502 } => {
503 if !self.is_body_specializable(then_branch, word_name) {
504 return false;
505 }
506 if let Some(else_stmts) = else_branch
507 && !self.is_body_specializable(else_stmts, word_name)
508 {
509 return false;
510 }
511 true
512 }
513 }
514 }
515
516 pub fn codegen_specialized_word(
534 &mut self,
535 word: &WordDef,
536 sig: &SpecSignature,
537 ) -> Result<(), CodeGenError> {
538 let base_name = format!("seq_{}", mangle_name(&word.name));
539 let spec_name = format!("{}{}", base_name, sig.suffix());
540
541 let return_type = if sig.outputs.len() == 1 {
545 sig.outputs[0].llvm_type().to_string()
546 } else {
547 let types: Vec<_> = sig.outputs.iter().map(|t| t.llvm_type()).collect();
549 format!("{{ {} }}", types.join(", "))
550 };
551
552 let params: Vec<String> = sig
554 .inputs
555 .iter()
556 .enumerate()
557 .map(|(i, ty)| format!("{} %arg{}", ty.llvm_type(), i))
558 .collect();
559
560 writeln!(
561 &mut self.output,
562 "define {} @{}({}) {{",
563 return_type,
564 spec_name,
565 params.join(", ")
566 )?;
567 writeln!(&mut self.output, "entry:")?;
568
569 let initial_params: Vec<(String, RegisterType)> = sig
571 .inputs
572 .iter()
573 .enumerate()
574 .map(|(i, ty)| (format!("arg{}", i), *ty))
575 .collect();
576 let mut ctx = RegisterContext::from_params(&initial_params);
577
578 let body_len = word.body.len();
580 let mut prev_int_literal: Option<i64> = None;
581 for (i, stmt) in word.body.iter().enumerate() {
582 let is_last = i == body_len - 1;
583 self.codegen_specialized_statement(
584 &mut ctx,
585 stmt,
586 &word.name,
587 sig,
588 is_last,
589 &mut prev_int_literal,
590 )?;
591 }
592
593 writeln!(&mut self.output, "}}")?;
594 writeln!(&mut self.output)?;
595
596 self.specialized_words
598 .insert(word.name.clone(), sig.clone());
599
600 Ok(())
601 }
602
603 fn codegen_specialized_statement(
605 &mut self,
606 ctx: &mut RegisterContext,
607 stmt: &Statement,
608 word_name: &str,
609 sig: &SpecSignature,
610 is_last: bool,
611 prev_int_literal: &mut Option<i64>,
612 ) -> Result<(), CodeGenError> {
613 let prev_int = *prev_int_literal;
615 *prev_int_literal = None; match stmt {
618 Statement::IntLiteral(n) => {
619 let var = self.fresh_temp();
620 writeln!(&mut self.output, " %{} = add i64 0, {}", var, n)?;
621 ctx.push(var, RegisterType::I64);
622 *prev_int_literal = Some(*n); }
624
625 Statement::FloatLiteral(f) => {
626 let var = self.fresh_temp();
627 let bits = f.to_bits();
632 writeln!(
633 &mut self.output,
634 " %{} = bitcast i64 {} to double",
635 var, bits
636 )?;
637 ctx.push(var, RegisterType::Double);
638 }
639
640 Statement::BoolLiteral(b) => {
641 let var = self.fresh_temp();
642 let val = if *b { 1 } else { 0 };
643 writeln!(&mut self.output, " %{} = add i64 0, {}", var, val)?;
644 ctx.push(var, RegisterType::I64);
645 }
646
647 Statement::WordCall { name, .. } => {
648 self.codegen_specialized_word_call(ctx, name, word_name, sig, is_last, prev_int)?;
649 }
650
651 Statement::If {
652 then_branch,
653 else_branch,
654 span: _,
655 } => {
656 self.codegen_specialized_if(
657 ctx,
658 then_branch,
659 else_branch.as_ref(),
660 word_name,
661 sig,
662 is_last,
663 )?;
664 }
665
666 Statement::StringLiteral(_)
668 | Statement::Symbol(_)
669 | Statement::Quotation { .. }
670 | Statement::Match { .. } => {
671 return Err(CodeGenError::Logic(format!(
672 "Non-specializable statement in specialized word: {:?}",
673 stmt
674 )));
675 }
676 }
677
678 let already_returns = match stmt {
681 Statement::If { .. } => true,
682 Statement::WordCall { name, .. } if name == word_name => true,
683 _ => false,
684 };
685 if is_last && !already_returns {
686 self.emit_specialized_return(ctx, sig)?;
687 }
688
689 Ok(())
690 }
691
692 fn codegen_specialized_word_call(
694 &mut self,
695 ctx: &mut RegisterContext,
696 name: &str,
697 word_name: &str,
698 sig: &SpecSignature,
699 is_last: bool,
700 prev_int: Option<i64>,
701 ) -> Result<(), CodeGenError> {
702 match name {
703 "dup" => ctx.dup(),
705 "drop" => ctx.drop(),
706 "swap" => ctx.swap(),
707 "over" => ctx.over(),
708 "rot" => ctx.rot(),
709 "nip" => {
710 ctx.swap();
712 ctx.drop();
713 }
714 "tuck" => {
715 ctx.dup();
717 let b = ctx.pop().unwrap();
718 let b2 = ctx.pop().unwrap();
719 let a = ctx.pop().unwrap();
720 ctx.push(b.0, b.1);
721 ctx.push(a.0, a.1);
722 ctx.push(b2.0, b2.1);
723 }
724 "pick" => {
725 let n = prev_int.ok_or_else(|| {
728 CodeGenError::Logic("pick requires constant N in specialized mode".to_string())
729 })?;
730 if n < 0 {
731 return Err(CodeGenError::Logic(format!(
732 "pick requires non-negative N, got {}",
733 n
734 )));
735 }
736 let n = n as usize;
737 ctx.pop();
739 let len = ctx.values.len();
741 if n >= len {
742 return Err(CodeGenError::Logic(format!(
743 "pick {} but only {} values in context",
744 n, len
745 )));
746 }
747 let (var, ty) = ctx.values[len - 1 - n].clone();
748 ctx.push(var, ty);
749 }
750 "roll" => {
751 let n = prev_int.ok_or_else(|| {
754 CodeGenError::Logic("roll requires constant N in specialized mode".to_string())
755 })?;
756 if n < 0 {
757 return Err(CodeGenError::Logic(format!(
758 "roll requires non-negative N, got {}",
759 n
760 )));
761 }
762 let n = n as usize;
763 ctx.pop();
765 let len = ctx.values.len();
767 if n >= len {
768 return Err(CodeGenError::Logic(format!(
769 "roll {} but only {} values in context",
770 n, len
771 )));
772 }
773 if n > 0 {
774 let val = ctx.values.remove(len - 1 - n);
775 ctx.values.push(val);
776 }
777 }
779
780 "i.+" | "i.add" => {
783 let (b, _) = ctx.pop().unwrap();
784 let (a, _) = ctx.pop().unwrap();
785 let result = self.fresh_temp();
786 writeln!(&mut self.output, " %{} = add i64 %{}, %{}", result, a, b)?;
787 ctx.push(result, RegisterType::I64);
788 }
789 "i.-" | "i.subtract" => {
790 let (b, _) = ctx.pop().unwrap();
791 let (a, _) = ctx.pop().unwrap();
792 let result = self.fresh_temp();
793 writeln!(&mut self.output, " %{} = sub i64 %{}, %{}", result, a, b)?;
794 ctx.push(result, RegisterType::I64);
795 }
796 "i.*" | "i.multiply" => {
797 let (b, _) = ctx.pop().unwrap();
798 let (a, _) = ctx.pop().unwrap();
799 let result = self.fresh_temp();
800 writeln!(&mut self.output, " %{} = mul i64 %{}, %{}", result, a, b)?;
801 ctx.push(result, RegisterType::I64);
802 }
803 "i./" | "i.divide" => {
804 self.emit_specialized_safe_div(ctx, "sdiv")?;
805 }
806 "i.%" | "i.mod" => {
807 self.emit_specialized_safe_div(ctx, "srem")?;
808 }
809
810 "band" => {
812 let (b, _) = ctx.pop().unwrap();
813 let (a, _) = ctx.pop().unwrap();
814 let result = self.fresh_temp();
815 writeln!(&mut self.output, " %{} = and i64 %{}, %{}", result, a, b)?;
816 ctx.push(result, RegisterType::I64);
817 }
818 "bor" => {
819 let (b, _) = ctx.pop().unwrap();
820 let (a, _) = ctx.pop().unwrap();
821 let result = self.fresh_temp();
822 writeln!(&mut self.output, " %{} = or i64 %{}, %{}", result, a, b)?;
823 ctx.push(result, RegisterType::I64);
824 }
825 "bxor" => {
826 let (b, _) = ctx.pop().unwrap();
827 let (a, _) = ctx.pop().unwrap();
828 let result = self.fresh_temp();
829 writeln!(&mut self.output, " %{} = xor i64 %{}, %{}", result, a, b)?;
830 ctx.push(result, RegisterType::I64);
831 }
832 "bnot" => {
833 let (a, _) = ctx.pop().unwrap();
834 let result = self.fresh_temp();
835 writeln!(&mut self.output, " %{} = xor i64 %{}, -1", result, a)?;
837 ctx.push(result, RegisterType::I64);
838 }
839 "shl" => {
840 self.emit_specialized_safe_shift(ctx, true)?;
841 }
842 "shr" => {
843 self.emit_specialized_safe_shift(ctx, false)?;
844 }
845
846 "popcount" => {
848 let (a, _) = ctx.pop().unwrap();
849 let result = self.fresh_temp();
850 writeln!(
851 &mut self.output,
852 " %{} = call i64 @llvm.ctpop.i64(i64 %{})",
853 result, a
854 )?;
855 ctx.push(result, RegisterType::I64);
856 }
857 "clz" => {
858 let (a, _) = ctx.pop().unwrap();
859 let result = self.fresh_temp();
860 writeln!(
862 &mut self.output,
863 " %{} = call i64 @llvm.ctlz.i64(i64 %{}, i1 false)",
864 result, a
865 )?;
866 ctx.push(result, RegisterType::I64);
867 }
868 "ctz" => {
869 let (a, _) = ctx.pop().unwrap();
870 let result = self.fresh_temp();
871 writeln!(
873 &mut self.output,
874 " %{} = call i64 @llvm.cttz.i64(i64 %{}, i1 false)",
875 result, a
876 )?;
877 ctx.push(result, RegisterType::I64);
878 }
879
880 "int->float" => {
882 let (a, _) = ctx.pop().unwrap();
883 let result = self.fresh_temp();
884 writeln!(
885 &mut self.output,
886 " %{} = sitofp i64 %{} to double",
887 result, a
888 )?;
889 ctx.push(result, RegisterType::Double);
890 }
891 "float->int" => {
892 let (a, _) = ctx.pop().unwrap();
893 let result = self.fresh_temp();
894 writeln!(
895 &mut self.output,
896 " %{} = fptosi double %{} to i64",
897 result, a
898 )?;
899 ctx.push(result, RegisterType::I64);
900 }
901
902 "and" => {
904 let (b, _) = ctx.pop().unwrap();
905 let (a, _) = ctx.pop().unwrap();
906 let result = self.fresh_temp();
907 writeln!(&mut self.output, " %{} = and i64 %{}, %{}", result, a, b)?;
908 ctx.push(result, RegisterType::I64);
909 }
910 "or" => {
911 let (b, _) = ctx.pop().unwrap();
912 let (a, _) = ctx.pop().unwrap();
913 let result = self.fresh_temp();
914 writeln!(&mut self.output, " %{} = or i64 %{}, %{}", result, a, b)?;
915 ctx.push(result, RegisterType::I64);
916 }
917 "not" => {
918 let (a, _) = ctx.pop().unwrap();
919 let result = self.fresh_temp();
920 writeln!(&mut self.output, " %{} = xor i64 %{}, 1", result, a)?;
922 ctx.push(result, RegisterType::I64);
923 }
924
925 "i.<" | "i.lt" => self.emit_specialized_icmp(ctx, "slt")?,
927 "i.>" | "i.gt" => self.emit_specialized_icmp(ctx, "sgt")?,
928 "i.<=" | "i.lte" => self.emit_specialized_icmp(ctx, "sle")?,
929 "i.>=" | "i.gte" => self.emit_specialized_icmp(ctx, "sge")?,
930 "i.=" | "i.eq" => self.emit_specialized_icmp(ctx, "eq")?,
931 "i.<>" | "i.neq" => self.emit_specialized_icmp(ctx, "ne")?,
932
933 "f.+" | "f.add" => {
935 let (b, _) = ctx.pop().unwrap();
936 let (a, _) = ctx.pop().unwrap();
937 let result = self.fresh_temp();
938 writeln!(
939 &mut self.output,
940 " %{} = fadd double %{}, %{}",
941 result, a, b
942 )?;
943 ctx.push(result, RegisterType::Double);
944 }
945 "f.-" | "f.subtract" => {
946 let (b, _) = ctx.pop().unwrap();
947 let (a, _) = ctx.pop().unwrap();
948 let result = self.fresh_temp();
949 writeln!(
950 &mut self.output,
951 " %{} = fsub double %{}, %{}",
952 result, a, b
953 )?;
954 ctx.push(result, RegisterType::Double);
955 }
956 "f.*" | "f.multiply" => {
957 let (b, _) = ctx.pop().unwrap();
958 let (a, _) = ctx.pop().unwrap();
959 let result = self.fresh_temp();
960 writeln!(
961 &mut self.output,
962 " %{} = fmul double %{}, %{}",
963 result, a, b
964 )?;
965 ctx.push(result, RegisterType::Double);
966 }
967 "f./" | "f.divide" => {
968 let (b, _) = ctx.pop().unwrap();
969 let (a, _) = ctx.pop().unwrap();
970 let result = self.fresh_temp();
971 writeln!(
972 &mut self.output,
973 " %{} = fdiv double %{}, %{}",
974 result, a, b
975 )?;
976 ctx.push(result, RegisterType::Double);
977 }
978
979 "f.<" | "f.lt" => self.emit_specialized_fcmp(ctx, "olt")?,
981 "f.>" | "f.gt" => self.emit_specialized_fcmp(ctx, "ogt")?,
982 "f.<=" | "f.lte" => self.emit_specialized_fcmp(ctx, "ole")?,
983 "f.>=" | "f.gte" => self.emit_specialized_fcmp(ctx, "oge")?,
984 "f.=" | "f.eq" => self.emit_specialized_fcmp(ctx, "oeq")?,
985 "f.<>" | "f.neq" => self.emit_specialized_fcmp(ctx, "one")?,
986
987 _ if name == word_name => {
989 self.emit_specialized_recursive_call(ctx, word_name, sig, is_last)?;
990 }
991
992 _ if self.specialized_words.contains_key(name) => {
994 self.emit_specialized_word_dispatch(ctx, name)?;
995 }
996
997 _ => {
998 return Err(CodeGenError::Logic(format!(
999 "Unhandled operation in specialized codegen: {}",
1000 name
1001 )));
1002 }
1003 }
1004 Ok(())
1005 }
1006
1007 fn emit_specialized_icmp(
1009 &mut self,
1010 ctx: &mut RegisterContext,
1011 cmp_op: &str,
1012 ) -> Result<(), CodeGenError> {
1013 let (b, _) = ctx.pop().unwrap();
1014 let (a, _) = ctx.pop().unwrap();
1015 let cmp_result = self.fresh_temp();
1016 let result = self.fresh_temp();
1017 writeln!(
1018 &mut self.output,
1019 " %{} = icmp {} i64 %{}, %{}",
1020 cmp_result, cmp_op, a, b
1021 )?;
1022 writeln!(
1023 &mut self.output,
1024 " %{} = zext i1 %{} to i64",
1025 result, cmp_result
1026 )?;
1027 ctx.push(result, RegisterType::I64);
1028 Ok(())
1029 }
1030
1031 fn emit_specialized_fcmp(
1033 &mut self,
1034 ctx: &mut RegisterContext,
1035 cmp_op: &str,
1036 ) -> Result<(), CodeGenError> {
1037 let (b, _) = ctx.pop().unwrap();
1038 let (a, _) = ctx.pop().unwrap();
1039 let cmp_result = self.fresh_temp();
1040 let result = self.fresh_temp();
1041 writeln!(
1042 &mut self.output,
1043 " %{} = fcmp {} double %{}, %{}",
1044 cmp_result, cmp_op, a, b
1045 )?;
1046 writeln!(
1047 &mut self.output,
1048 " %{} = zext i1 %{} to i64",
1049 result, cmp_result
1050 )?;
1051 ctx.push(result, RegisterType::I64);
1052 Ok(())
1053 }
1054
1055 fn emit_specialized_safe_div(
1064 &mut self,
1065 ctx: &mut RegisterContext,
1066 op: &str, ) -> Result<(), CodeGenError> {
1068 let (b, _) = ctx.pop().unwrap(); let (a, _) = ctx.pop().unwrap(); let is_zero = self.fresh_temp();
1073 writeln!(&mut self.output, " %{} = icmp eq i64 %{}, 0", is_zero, b)?;
1074
1075 let (check_overflow, is_overflow) = if op == "sdiv" {
1078 let is_int_min = self.fresh_temp();
1079 let is_neg_one = self.fresh_temp();
1080 let is_overflow = self.fresh_temp();
1081
1082 writeln!(
1084 &mut self.output,
1085 " %{} = icmp eq i64 %{}, -9223372036854775808",
1086 is_int_min, a
1087 )?;
1088 writeln!(
1090 &mut self.output,
1091 " %{} = icmp eq i64 %{}, -1",
1092 is_neg_one, b
1093 )?;
1094 writeln!(
1096 &mut self.output,
1097 " %{} = and i1 %{}, %{}",
1098 is_overflow, is_int_min, is_neg_one
1099 )?;
1100 (true, is_overflow)
1101 } else {
1102 (false, String::new())
1103 };
1104
1105 let ok_label = self.fresh_block("div_ok");
1107 let fail_label = self.fresh_block("div_fail");
1108 let merge_label = self.fresh_block("div_merge");
1109 let overflow_label = if check_overflow {
1110 self.fresh_block("div_overflow")
1111 } else {
1112 String::new()
1113 };
1114
1115 writeln!(
1117 &mut self.output,
1118 " br i1 %{}, label %{}, label %{}",
1119 is_zero,
1120 fail_label,
1121 if check_overflow {
1122 &overflow_label
1123 } else {
1124 &ok_label
1125 }
1126 )?;
1127
1128 if check_overflow {
1130 writeln!(&mut self.output, "{}:", overflow_label)?;
1131 writeln!(
1132 &mut self.output,
1133 " br i1 %{}, label %{}, label %{}",
1134 is_overflow, merge_label, ok_label
1135 )?;
1136 }
1137
1138 writeln!(&mut self.output, "{}:", ok_label)?;
1140 let ok_result = self.fresh_temp();
1141 writeln!(
1142 &mut self.output,
1143 " %{} = {} i64 %{}, %{}",
1144 ok_result, op, a, b
1145 )?;
1146 writeln!(&mut self.output, " br label %{}", merge_label)?;
1147
1148 writeln!(&mut self.output, "{}:", fail_label)?;
1150 writeln!(&mut self.output, " br label %{}", merge_label)?;
1151
1152 writeln!(&mut self.output, "{}:", merge_label)?;
1154 let result_phi = self.fresh_temp();
1155 let success_phi = self.fresh_temp();
1156
1157 if check_overflow {
1158 writeln!(
1161 &mut self.output,
1162 " %{} = phi i64 [ %{}, %{} ], [ 0, %{} ], [ -9223372036854775808, %{} ]",
1163 result_phi, ok_result, ok_label, fail_label, overflow_label
1164 )?;
1165 writeln!(
1166 &mut self.output,
1167 " %{} = phi i64 [ 1, %{} ], [ 0, %{} ], [ 1, %{} ]",
1168 success_phi, ok_label, fail_label, overflow_label
1169 )?;
1170 } else {
1171 writeln!(
1173 &mut self.output,
1174 " %{} = phi i64 [ %{}, %{} ], [ 0, %{} ]",
1175 result_phi, ok_result, ok_label, fail_label
1176 )?;
1177 writeln!(
1178 &mut self.output,
1179 " %{} = phi i64 [ 1, %{} ], [ 0, %{} ]",
1180 success_phi, ok_label, fail_label
1181 )?;
1182 }
1183
1184 ctx.push(result_phi, RegisterType::I64);
1187 ctx.push(success_phi, RegisterType::I64);
1188
1189 Ok(())
1190 }
1191
1192 fn emit_specialized_safe_shift(
1197 &mut self,
1198 ctx: &mut RegisterContext,
1199 is_left: bool, ) -> Result<(), CodeGenError> {
1201 let (b, _) = ctx.pop().unwrap(); let (a, _) = ctx.pop().unwrap(); let is_negative = self.fresh_temp();
1206 writeln!(
1207 &mut self.output,
1208 " %{} = icmp slt i64 %{}, 0",
1209 is_negative, b
1210 )?;
1211
1212 let is_too_large = self.fresh_temp();
1214 writeln!(
1215 &mut self.output,
1216 " %{} = icmp sge i64 %{}, 64",
1217 is_too_large, b
1218 )?;
1219
1220 let is_invalid = self.fresh_temp();
1222 writeln!(
1223 &mut self.output,
1224 " %{} = or i1 %{}, %{}",
1225 is_invalid, is_negative, is_too_large
1226 )?;
1227
1228 let safe_count = self.fresh_temp();
1230 writeln!(
1231 &mut self.output,
1232 " %{} = select i1 %{}, i64 0, i64 %{}",
1233 safe_count, is_invalid, b
1234 )?;
1235
1236 let shift_result = self.fresh_temp();
1238 let op = if is_left { "shl" } else { "lshr" };
1239 writeln!(
1240 &mut self.output,
1241 " %{} = {} i64 %{}, %{}",
1242 shift_result, op, a, safe_count
1243 )?;
1244
1245 let result = self.fresh_temp();
1247 writeln!(
1248 &mut self.output,
1249 " %{} = select i1 %{}, i64 0, i64 %{}",
1250 result, is_invalid, shift_result
1251 )?;
1252
1253 ctx.push(result, RegisterType::I64);
1254 Ok(())
1255 }
1256
1257 fn emit_specialized_recursive_call(
1265 &mut self,
1266 ctx: &mut RegisterContext,
1267 word_name: &str,
1268 sig: &SpecSignature,
1269 is_tail: bool,
1270 ) -> Result<(), CodeGenError> {
1271 let spec_name = format!("seq_{}{}", mangle_name(word_name), sig.suffix());
1272
1273 if ctx.values.len() < sig.inputs.len() {
1275 return Err(CodeGenError::Logic(format!(
1276 "Not enough values in context for recursive call to {}: need {}, have {}",
1277 word_name,
1278 sig.inputs.len(),
1279 ctx.values.len()
1280 )));
1281 }
1282
1283 let mut args = Vec::new();
1285 for _ in 0..sig.inputs.len() {
1286 args.push(ctx.pop().unwrap());
1287 }
1288 args.reverse(); let arg_strs: Vec<String> = args
1292 .iter()
1293 .map(|(var, ty)| format!("{} %{}", ty.llvm_type(), var))
1294 .collect();
1295
1296 let return_type = sig.llvm_return_type();
1297
1298 if is_tail {
1299 let result = self.fresh_temp();
1301 writeln!(
1302 &mut self.output,
1303 " %{} = musttail call {} @{}({})",
1304 result,
1305 return_type,
1306 spec_name,
1307 arg_strs.join(", ")
1308 )?;
1309 writeln!(&mut self.output, " ret {} %{}", return_type, result)?;
1310 } else {
1311 let result = self.fresh_temp();
1313 writeln!(
1314 &mut self.output,
1315 " %{} = call {} @{}({})",
1316 result,
1317 return_type,
1318 spec_name,
1319 arg_strs.join(", ")
1320 )?;
1321
1322 if sig.outputs.len() == 1 {
1323 ctx.push(result, sig.outputs[0]);
1325 } else {
1326 for (i, out_ty) in sig.outputs.iter().enumerate() {
1328 let extracted = self.fresh_temp();
1329 writeln!(
1330 &mut self.output,
1331 " %{} = extractvalue {} %{}, {}",
1332 extracted, return_type, result, i
1333 )?;
1334 ctx.push(extracted, *out_ty);
1335 }
1336 }
1337 }
1338
1339 Ok(())
1340 }
1341
1342 fn emit_specialized_word_dispatch(
1344 &mut self,
1345 ctx: &mut RegisterContext,
1346 name: &str,
1347 ) -> Result<(), CodeGenError> {
1348 let sig = self
1349 .specialized_words
1350 .get(name)
1351 .ok_or_else(|| CodeGenError::Logic(format!("Unknown specialized word: {}", name)))?
1352 .clone();
1353
1354 let spec_name = format!("seq_{}{}", mangle_name(name), sig.suffix());
1355
1356 let mut args = Vec::new();
1358 for _ in 0..sig.inputs.len() {
1359 args.push(ctx.pop().unwrap());
1360 }
1361 args.reverse();
1362
1363 let arg_strs: Vec<String> = args
1365 .iter()
1366 .map(|(var, ty)| format!("{} %{}", ty.llvm_type(), var))
1367 .collect();
1368
1369 let return_type = sig.llvm_return_type();
1370
1371 let result = self.fresh_temp();
1372 writeln!(
1373 &mut self.output,
1374 " %{} = call {} @{}({})",
1375 result,
1376 return_type,
1377 spec_name,
1378 arg_strs.join(", ")
1379 )?;
1380
1381 if sig.outputs.len() == 1 {
1382 ctx.push(result, sig.outputs[0]);
1384 } else {
1385 for (i, out_ty) in sig.outputs.iter().enumerate() {
1387 let extracted = self.fresh_temp();
1388 writeln!(
1389 &mut self.output,
1390 " %{} = extractvalue {} %{}, {}",
1391 extracted, return_type, result, i
1392 )?;
1393 ctx.push(extracted, *out_ty);
1394 }
1395 }
1396
1397 Ok(())
1398 }
1399
1400 fn emit_specialized_return(
1402 &mut self,
1403 ctx: &RegisterContext,
1404 sig: &SpecSignature,
1405 ) -> Result<(), CodeGenError> {
1406 let output_count = sig.outputs.len();
1407
1408 if output_count == 0 {
1409 writeln!(&mut self.output, " ret void")?;
1410 } else if output_count == 1 {
1411 let (var, ty) = ctx
1412 .values
1413 .last()
1414 .ok_or_else(|| CodeGenError::Logic("Empty context at return".to_string()))?;
1415 writeln!(&mut self.output, " ret {} %{}", ty.llvm_type(), var)?;
1416 } else {
1417 if ctx.values.len() < output_count {
1420 return Err(CodeGenError::Logic(format!(
1421 "Not enough values for multi-output return: need {}, have {}",
1422 output_count,
1423 ctx.values.len()
1424 )));
1425 }
1426
1427 let start_idx = ctx.values.len() - output_count;
1429 let return_values: Vec<_> = ctx.values[start_idx..].to_vec();
1430
1431 let struct_type = sig.llvm_return_type();
1433
1434 let mut current_struct = "undef".to_string();
1436 for (i, (var, ty)) in return_values.iter().enumerate() {
1437 let new_struct = self.fresh_temp();
1438 writeln!(
1439 &mut self.output,
1440 " %{} = insertvalue {} {}, {} %{}, {}",
1441 new_struct,
1442 struct_type,
1443 current_struct,
1444 ty.llvm_type(),
1445 var,
1446 i
1447 )?;
1448 current_struct = format!("%{}", new_struct);
1449 }
1450
1451 writeln!(&mut self.output, " ret {} {}", struct_type, current_struct)?;
1452 }
1453 Ok(())
1454 }
1455
1456 fn codegen_specialized_if(
1458 &mut self,
1459 ctx: &mut RegisterContext,
1460 then_branch: &[Statement],
1461 else_branch: Option<&Vec<Statement>>,
1462 word_name: &str,
1463 sig: &SpecSignature,
1464 is_last: bool,
1465 ) -> Result<(), CodeGenError> {
1466 let (cond_var, _) = ctx
1468 .pop()
1469 .ok_or_else(|| CodeGenError::Logic("Empty context at if condition".to_string()))?;
1470
1471 let cmp_result = self.fresh_temp();
1473 writeln!(
1474 &mut self.output,
1475 " %{} = icmp ne i64 %{}, 0",
1476 cmp_result, cond_var
1477 )?;
1478
1479 let then_label = self.fresh_block("if_then");
1481 let else_label = self.fresh_block("if_else");
1482 let merge_label = self.fresh_block("if_merge");
1483
1484 writeln!(
1485 &mut self.output,
1486 " br i1 %{}, label %{}, label %{}",
1487 cmp_result, then_label, else_label
1488 )?;
1489
1490 writeln!(&mut self.output, "{}:", then_label)?;
1492 let mut then_ctx = ctx.clone();
1493 let mut then_prev_int: Option<i64> = None;
1494 for (i, stmt) in then_branch.iter().enumerate() {
1495 let is_stmt_last = i == then_branch.len() - 1 && is_last;
1496 self.codegen_specialized_statement(
1497 &mut then_ctx,
1498 stmt,
1499 word_name,
1500 sig,
1501 is_stmt_last,
1502 &mut then_prev_int,
1503 )?;
1504 }
1505 if is_last && then_branch.is_empty() {
1507 self.emit_specialized_return(&then_ctx, sig)?;
1508 }
1509 let then_emitted_return = is_last;
1511 let then_pred = if then_emitted_return {
1512 None
1513 } else {
1514 writeln!(&mut self.output, " br label %{}", merge_label)?;
1515 Some(then_label.clone())
1516 };
1517
1518 writeln!(&mut self.output, "{}:", else_label)?;
1520 let mut else_ctx = ctx.clone();
1521 let mut else_prev_int: Option<i64> = None;
1522 if let Some(else_stmts) = else_branch {
1523 for (i, stmt) in else_stmts.iter().enumerate() {
1524 let is_stmt_last = i == else_stmts.len() - 1 && is_last;
1525 self.codegen_specialized_statement(
1526 &mut else_ctx,
1527 stmt,
1528 word_name,
1529 sig,
1530 is_stmt_last,
1531 &mut else_prev_int,
1532 )?;
1533 }
1534 }
1535 if is_last && (else_branch.is_none() || else_branch.as_ref().is_some_and(|b| b.is_empty()))
1537 {
1538 self.emit_specialized_return(&else_ctx, sig)?;
1539 }
1540 let else_emitted_return = is_last;
1542 let else_pred = if else_emitted_return {
1543 None
1544 } else {
1545 writeln!(&mut self.output, " br label %{}", merge_label)?;
1546 Some(else_label.clone())
1547 };
1548
1549 if then_pred.is_some() || else_pred.is_some() {
1551 writeln!(&mut self.output, "{}:", merge_label)?;
1552
1553 if let (Some(then_p), Some(else_p)) = (&then_pred, &else_pred) {
1555 if then_ctx.values.len() != else_ctx.values.len() {
1557 return Err(CodeGenError::Logic(format!(
1558 "Stack depth mismatch in if branches: then has {}, else has {}",
1559 then_ctx.values.len(),
1560 else_ctx.values.len()
1561 )));
1562 }
1563
1564 ctx.values.clear();
1565 for i in 0..then_ctx.values.len() {
1566 let (then_var, then_ty) = &then_ctx.values[i];
1567 let (else_var, else_ty) = &else_ctx.values[i];
1568
1569 if then_ty != else_ty {
1570 return Err(CodeGenError::Logic(format!(
1571 "Type mismatch at position {} in if branches: {:?} vs {:?}",
1572 i, then_ty, else_ty
1573 )));
1574 }
1575
1576 if then_var == else_var {
1578 ctx.push(then_var.clone(), *then_ty);
1579 } else {
1580 let phi_result = self.fresh_temp();
1581 writeln!(
1582 &mut self.output,
1583 " %{} = phi {} [ %{}, %{} ], [ %{}, %{} ]",
1584 phi_result,
1585 then_ty.llvm_type(),
1586 then_var,
1587 then_p,
1588 else_var,
1589 else_p
1590 )?;
1591 ctx.push(phi_result, *then_ty);
1592 }
1593 }
1594 } else if then_pred.is_some() {
1595 *ctx = then_ctx;
1597 } else {
1598 *ctx = else_ctx;
1600 }
1601
1602 if is_last && (then_pred.is_some() || else_pred.is_some()) {
1604 self.emit_specialized_return(ctx, sig)?;
1605 }
1606 }
1607
1608 Ok(())
1609 }
1610}
1611
#[cfg(test)]
mod tests {
    use super::*;

    /// Int and Bool map to `I64`, Float to `Double`, String has no class.
    #[test]
    fn test_register_type_from_type() {
        let expectations = [
            (Type::Int, Some(RegisterType::I64)),
            (Type::Bool, Some(RegisterType::I64)),
            (Type::Float, Some(RegisterType::Double)),
            (Type::String, None),
        ];
        for (ty, expected) in &expectations {
            assert_eq!(RegisterType::from_type(ty), *expected);
        }
    }

    /// One-in/one-out signatures of a single class use the short suffix.
    #[test]
    fn test_spec_signature_suffix() {
        let unary = |class: RegisterType| SpecSignature {
            inputs: vec![class],
            outputs: vec![class],
        };

        assert_eq!(unary(RegisterType::I64).suffix(), "_i64");
        assert_eq!(unary(RegisterType::Double).suffix(), "_f64");
    }

    /// swap, dup and drop rearrange the tracked SSA names as expected.
    #[test]
    fn test_register_context_stack_ops() {
        let mut stack = RegisterContext::new();
        for name in ["a", "b"] {
            stack.push(name.to_string(), RegisterType::I64);
        }
        assert_eq!(stack.len(), 2);

        // a b -> b a
        stack.swap();
        assert_eq!(stack.values[0].0, "b");
        assert_eq!(stack.values[1].0, "a");

        // b a -> b a a
        stack.dup();
        assert_eq!(stack.len(), 3);
        assert_eq!(stack.values[2].0, "a");

        // b a a -> b a
        stack.drop();
        assert_eq!(stack.len(), 2);
    }
}