1use super::{CodeGen, CodeGenError, mangle_name};
92use crate::ast::{Statement, WordDef};
93use crate::types::{StackType, Type};
94use std::fmt::Write as _;
95
96#[derive(Debug, Clone, Copy, PartialEq, Eq)]
98pub enum RegisterType {
99 I64,
101 Double,
103}
104
105impl RegisterType {
106 pub fn from_type(ty: &Type) -> Option<Self> {
108 match ty {
109 Type::Int | Type::Bool => Some(RegisterType::I64),
110 Type::Float => Some(RegisterType::Double),
111 _ => None,
112 }
113 }
114
115 pub fn llvm_type(&self) -> &'static str {
117 match self {
118 RegisterType::I64 => "i64",
119 RegisterType::Double => "double",
120 }
121 }
122}
123
124#[derive(Debug, Clone)]
126pub struct SpecSignature {
127 pub inputs: Vec<RegisterType>,
129 pub outputs: Vec<RegisterType>,
131}
132
133impl SpecSignature {
134 pub fn suffix(&self) -> String {
138 if self.inputs.len() == 1 && self.outputs.len() == 1 {
139 match (self.inputs[0], self.outputs[0]) {
140 (RegisterType::I64, RegisterType::I64) => "_i64".to_string(),
141 (RegisterType::Double, RegisterType::Double) => "_f64".to_string(),
142 (RegisterType::I64, RegisterType::Double) => "_i64_to_f64".to_string(),
143 (RegisterType::Double, RegisterType::I64) => "_f64_to_i64".to_string(),
144 }
145 } else {
146 let mut suffix = String::new();
148 for ty in &self.inputs {
149 suffix.push('_');
150 suffix.push_str(match ty {
151 RegisterType::I64 => "i",
152 RegisterType::Double => "f",
153 });
154 }
155 suffix.push_str("_to");
156 for ty in &self.outputs {
157 suffix.push('_');
158 suffix.push_str(match ty {
159 RegisterType::I64 => "i",
160 RegisterType::Double => "f",
161 });
162 }
163 suffix
164 }
165 }
166
167 pub fn is_direct_call(&self) -> bool {
169 self.outputs.len() == 1
170 }
171
172 pub fn llvm_return_type(&self) -> String {
177 if self.outputs.len() == 1 {
178 self.outputs[0].llvm_type().to_string()
179 } else {
180 let types: Vec<_> = self.outputs.iter().map(|t| t.llvm_type()).collect();
181 format!("{{ {} }}", types.join(", "))
182 }
183 }
184}
185
186#[derive(Debug, Clone)]
191pub struct RegisterContext {
192 pub values: Vec<(String, RegisterType)>,
194}
195
196impl RegisterContext {
197 pub fn new() -> Self {
199 Self { values: Vec::new() }
200 }
201
202 pub fn from_params(params: &[(String, RegisterType)]) -> Self {
204 Self {
205 values: params.to_vec(),
206 }
207 }
208
209 pub fn push(&mut self, ssa_var: String, ty: RegisterType) {
211 self.values.push((ssa_var, ty));
212 }
213
214 pub fn pop(&mut self) -> Option<(String, RegisterType)> {
216 self.values.pop()
217 }
218
219 #[cfg_attr(not(test), allow(dead_code))]
221 pub fn len(&self) -> usize {
222 self.values.len()
223 }
224
225 pub fn dup(&mut self) {
229 if let Some((ssa, ty)) = self.values.last().cloned() {
230 self.values.push((ssa, ty));
231 }
232 }
233
234 pub fn drop(&mut self) {
236 self.values.pop();
237 }
238
239 pub fn swap(&mut self) {
241 let len = self.values.len();
242 if len >= 2 {
243 self.values.swap(len - 1, len - 2);
244 }
245 }
246
247 pub fn over(&mut self) {
249 let len = self.values.len();
250 if len >= 2 {
251 let a = self.values[len - 2].clone();
252 self.values.push(a);
253 }
254 }
255
256 pub fn rot(&mut self) {
258 let len = self.values.len();
259 if len >= 3 {
260 let a = self.values.remove(len - 3);
261 self.values.push(a);
262 }
263 }
264}
265
266impl Default for RegisterContext {
267 fn default() -> Self {
268 Self::new()
269 }
270}
271
/// Built-in word names that `is_statement_specializable` accepts inside the
/// body of a specialized word.
const SPECIALIZABLE_OPS: &[&str] = &[
    // Integer arithmetic (symbolic and spelled-out aliases).
    "i.+",
    "i.add",
    "i.-",
    "i.subtract",
    "i.*",
    "i.multiply",
    "i./",
    "i.divide",
    "i.%",
    "i.mod",
    // Bitwise operations and shifts.
    "band",
    "bor",
    "bxor",
    "bnot",
    "shl",
    "shr",
    // Bit counting.
    "popcount",
    "clz",
    "ctz",
    // Numeric conversions.
    "int->float",
    "float->int",
    // Boolean logic.
    "and",
    "or",
    "not",
    // Integer comparisons.
    "i.<",
    "i.lt",
    "i.>",
    "i.gt",
    "i.<=",
    "i.lte",
    "i.>=",
    "i.gte",
    "i.=",
    "i.eq",
    "i.<>",
    "i.neq",
    // Float arithmetic.
    "f.+",
    "f.add",
    "f.-",
    "f.subtract",
    "f.*",
    "f.multiply",
    "f./",
    "f.divide",
    // Float comparisons.
    "f.<",
    "f.lt",
    "f.>",
    "f.gt",
    "f.<=",
    "f.lte",
    "f.>=",
    "f.gte",
    "f.=",
    "f.eq",
    "f.<>",
    "f.neq",
    // Stack shuffling. `pick` and `roll` are additionally only accepted when the
    // statement right before them is an integer literal.
    "dup",
    "drop",
    "swap",
    "over",
    "rot",
    "nip",
    "tuck",
    "pick",
    "roll",
];
357
358impl CodeGen {
359 pub fn can_specialize(&self, word: &WordDef) -> Option<SpecSignature> {
361 let effect = word.effect.as_ref()?;
363
364 if !effect.is_pure() {
366 return None;
367 }
368
369 let inputs = Self::extract_register_types(&effect.inputs)?;
371 let outputs = Self::extract_register_types(&effect.outputs)?;
372
373 if inputs.is_empty() && outputs.is_empty() {
375 return None;
376 }
377
378 if outputs.is_empty() {
380 return None;
381 }
382
383 if !self.is_body_specializable(&word.body, &word.name) {
385 return None;
386 }
387
388 Some(SpecSignature { inputs, outputs })
389 }
390
391 fn extract_register_types(stack: &StackType) -> Option<Vec<RegisterType>> {
397 let mut types = Vec::new();
398 let mut current = stack;
399
400 loop {
401 match current {
402 StackType::Empty => break,
403 StackType::RowVar(_) => {
404 break;
408 }
409 StackType::Cons { rest, top } => {
410 let reg_ty = RegisterType::from_type(top)?;
411 types.push(reg_ty);
412 current = rest;
413 }
414 }
415 }
416
417 types.reverse();
419 Some(types)
420 }
421
422 fn is_body_specializable(&self, body: &[Statement], word_name: &str) -> bool {
427 let mut prev_was_int_literal = false;
428 for stmt in body {
429 if !self.is_statement_specializable(stmt, word_name, prev_was_int_literal) {
430 return false;
431 }
432 prev_was_int_literal = matches!(stmt, Statement::IntLiteral(_));
434 }
435 true
436 }
437
438 fn is_statement_specializable(
443 &self,
444 stmt: &Statement,
445 word_name: &str,
446 prev_was_int_literal: bool,
447 ) -> bool {
448 match stmt {
449 Statement::IntLiteral(_) => true,
451
452 Statement::FloatLiteral(_) => true,
454
455 Statement::BoolLiteral(_) => true,
457
458 Statement::StringLiteral(_) => false,
460
461 Statement::Symbol(_) => false,
463
464 Statement::Quotation { .. } => false,
466
467 Statement::Match { .. } => false,
469
470 Statement::WordCall { name, .. } => {
472 if name == word_name {
474 return true;
475 }
476
477 if (name == "pick" || name == "roll") && !prev_was_int_literal {
480 return false;
481 }
482
483 if SPECIALIZABLE_OPS.contains(&name.as_str()) {
485 return true;
486 }
487
488 if self.specialized_words.contains_key(name) {
490 return true;
491 }
492
493 false
495 }
496
497 Statement::If {
499 then_branch,
500 else_branch,
501 } => {
502 if !self.is_body_specializable(then_branch, word_name) {
503 return false;
504 }
505 if let Some(else_stmts) = else_branch
506 && !self.is_body_specializable(else_stmts, word_name)
507 {
508 return false;
509 }
510 true
511 }
512 }
513 }
514
515 pub fn codegen_specialized_word(
533 &mut self,
534 word: &WordDef,
535 sig: &SpecSignature,
536 ) -> Result<(), CodeGenError> {
537 let base_name = format!("seq_{}", mangle_name(&word.name));
538 let spec_name = format!("{}{}", base_name, sig.suffix());
539
540 let return_type = if sig.outputs.len() == 1 {
544 sig.outputs[0].llvm_type().to_string()
545 } else {
546 let types: Vec<_> = sig.outputs.iter().map(|t| t.llvm_type()).collect();
548 format!("{{ {} }}", types.join(", "))
549 };
550
551 let params: Vec<String> = sig
553 .inputs
554 .iter()
555 .enumerate()
556 .map(|(i, ty)| format!("{} %arg{}", ty.llvm_type(), i))
557 .collect();
558
559 writeln!(
560 &mut self.output,
561 "define {} @{}({}) {{",
562 return_type,
563 spec_name,
564 params.join(", ")
565 )?;
566 writeln!(&mut self.output, "entry:")?;
567
568 let initial_params: Vec<(String, RegisterType)> = sig
570 .inputs
571 .iter()
572 .enumerate()
573 .map(|(i, ty)| (format!("arg{}", i), *ty))
574 .collect();
575 let mut ctx = RegisterContext::from_params(&initial_params);
576
577 let body_len = word.body.len();
579 let mut prev_int_literal: Option<i64> = None;
580 for (i, stmt) in word.body.iter().enumerate() {
581 let is_last = i == body_len - 1;
582 self.codegen_specialized_statement(
583 &mut ctx,
584 stmt,
585 &word.name,
586 sig,
587 is_last,
588 &mut prev_int_literal,
589 )?;
590 }
591
592 writeln!(&mut self.output, "}}")?;
593 writeln!(&mut self.output)?;
594
595 self.specialized_words
597 .insert(word.name.clone(), sig.clone());
598
599 Ok(())
600 }
601
602 fn codegen_specialized_statement(
604 &mut self,
605 ctx: &mut RegisterContext,
606 stmt: &Statement,
607 word_name: &str,
608 sig: &SpecSignature,
609 is_last: bool,
610 prev_int_literal: &mut Option<i64>,
611 ) -> Result<(), CodeGenError> {
612 let prev_int = *prev_int_literal;
614 *prev_int_literal = None; match stmt {
617 Statement::IntLiteral(n) => {
618 let var = self.fresh_temp();
619 writeln!(&mut self.output, " %{} = add i64 0, {}", var, n)?;
620 ctx.push(var, RegisterType::I64);
621 *prev_int_literal = Some(*n); }
623
624 Statement::FloatLiteral(f) => {
625 let var = self.fresh_temp();
626 let bits = f.to_bits();
631 writeln!(
632 &mut self.output,
633 " %{} = bitcast i64 {} to double",
634 var, bits
635 )?;
636 ctx.push(var, RegisterType::Double);
637 }
638
639 Statement::BoolLiteral(b) => {
640 let var = self.fresh_temp();
641 let val = if *b { 1 } else { 0 };
642 writeln!(&mut self.output, " %{} = add i64 0, {}", var, val)?;
643 ctx.push(var, RegisterType::I64);
644 }
645
646 Statement::WordCall { name, .. } => {
647 self.codegen_specialized_word_call(ctx, name, word_name, sig, is_last, prev_int)?;
648 }
649
650 Statement::If {
651 then_branch,
652 else_branch,
653 } => {
654 self.codegen_specialized_if(
655 ctx,
656 then_branch,
657 else_branch.as_ref(),
658 word_name,
659 sig,
660 is_last,
661 )?;
662 }
663
664 Statement::StringLiteral(_)
666 | Statement::Symbol(_)
667 | Statement::Quotation { .. }
668 | Statement::Match { .. } => {
669 return Err(CodeGenError::Logic(format!(
670 "Non-specializable statement in specialized word: {:?}",
671 stmt
672 )));
673 }
674 }
675
676 let already_returns = match stmt {
679 Statement::If { .. } => true,
680 Statement::WordCall { name, .. } if name == word_name => true,
681 _ => false,
682 };
683 if is_last && !already_returns {
684 self.emit_specialized_return(ctx, sig)?;
685 }
686
687 Ok(())
688 }
689
690 fn codegen_specialized_word_call(
692 &mut self,
693 ctx: &mut RegisterContext,
694 name: &str,
695 word_name: &str,
696 sig: &SpecSignature,
697 is_last: bool,
698 prev_int: Option<i64>,
699 ) -> Result<(), CodeGenError> {
700 match name {
701 "dup" => ctx.dup(),
703 "drop" => ctx.drop(),
704 "swap" => ctx.swap(),
705 "over" => ctx.over(),
706 "rot" => ctx.rot(),
707 "nip" => {
708 ctx.swap();
710 ctx.drop();
711 }
712 "tuck" => {
713 ctx.dup();
715 let b = ctx.pop().unwrap();
716 let b2 = ctx.pop().unwrap();
717 let a = ctx.pop().unwrap();
718 ctx.push(b.0, b.1);
719 ctx.push(a.0, a.1);
720 ctx.push(b2.0, b2.1);
721 }
722 "pick" => {
723 let n = prev_int.ok_or_else(|| {
726 CodeGenError::Logic("pick requires constant N in specialized mode".to_string())
727 })?;
728 if n < 0 {
729 return Err(CodeGenError::Logic(format!(
730 "pick requires non-negative N, got {}",
731 n
732 )));
733 }
734 let n = n as usize;
735 ctx.pop();
737 let len = ctx.values.len();
739 if n >= len {
740 return Err(CodeGenError::Logic(format!(
741 "pick {} but only {} values in context",
742 n, len
743 )));
744 }
745 let (var, ty) = ctx.values[len - 1 - n].clone();
746 ctx.push(var, ty);
747 }
748 "roll" => {
749 let n = prev_int.ok_or_else(|| {
752 CodeGenError::Logic("roll requires constant N in specialized mode".to_string())
753 })?;
754 if n < 0 {
755 return Err(CodeGenError::Logic(format!(
756 "roll requires non-negative N, got {}",
757 n
758 )));
759 }
760 let n = n as usize;
761 ctx.pop();
763 let len = ctx.values.len();
765 if n >= len {
766 return Err(CodeGenError::Logic(format!(
767 "roll {} but only {} values in context",
768 n, len
769 )));
770 }
771 if n > 0 {
772 let val = ctx.values.remove(len - 1 - n);
773 ctx.values.push(val);
774 }
775 }
777
778 "i.+" | "i.add" => {
781 let (b, _) = ctx.pop().unwrap();
782 let (a, _) = ctx.pop().unwrap();
783 let result = self.fresh_temp();
784 writeln!(&mut self.output, " %{} = add i64 %{}, %{}", result, a, b)?;
785 ctx.push(result, RegisterType::I64);
786 }
787 "i.-" | "i.subtract" => {
788 let (b, _) = ctx.pop().unwrap();
789 let (a, _) = ctx.pop().unwrap();
790 let result = self.fresh_temp();
791 writeln!(&mut self.output, " %{} = sub i64 %{}, %{}", result, a, b)?;
792 ctx.push(result, RegisterType::I64);
793 }
794 "i.*" | "i.multiply" => {
795 let (b, _) = ctx.pop().unwrap();
796 let (a, _) = ctx.pop().unwrap();
797 let result = self.fresh_temp();
798 writeln!(&mut self.output, " %{} = mul i64 %{}, %{}", result, a, b)?;
799 ctx.push(result, RegisterType::I64);
800 }
801 "i./" | "i.divide" => {
802 self.emit_specialized_safe_div(ctx, "sdiv")?;
803 }
804 "i.%" | "i.mod" => {
805 self.emit_specialized_safe_div(ctx, "srem")?;
806 }
807
808 "band" => {
810 let (b, _) = ctx.pop().unwrap();
811 let (a, _) = ctx.pop().unwrap();
812 let result = self.fresh_temp();
813 writeln!(&mut self.output, " %{} = and i64 %{}, %{}", result, a, b)?;
814 ctx.push(result, RegisterType::I64);
815 }
816 "bor" => {
817 let (b, _) = ctx.pop().unwrap();
818 let (a, _) = ctx.pop().unwrap();
819 let result = self.fresh_temp();
820 writeln!(&mut self.output, " %{} = or i64 %{}, %{}", result, a, b)?;
821 ctx.push(result, RegisterType::I64);
822 }
823 "bxor" => {
824 let (b, _) = ctx.pop().unwrap();
825 let (a, _) = ctx.pop().unwrap();
826 let result = self.fresh_temp();
827 writeln!(&mut self.output, " %{} = xor i64 %{}, %{}", result, a, b)?;
828 ctx.push(result, RegisterType::I64);
829 }
830 "bnot" => {
831 let (a, _) = ctx.pop().unwrap();
832 let result = self.fresh_temp();
833 writeln!(&mut self.output, " %{} = xor i64 %{}, -1", result, a)?;
835 ctx.push(result, RegisterType::I64);
836 }
837 "shl" => {
838 self.emit_specialized_safe_shift(ctx, true)?;
839 }
840 "shr" => {
841 self.emit_specialized_safe_shift(ctx, false)?;
842 }
843
844 "popcount" => {
846 let (a, _) = ctx.pop().unwrap();
847 let result = self.fresh_temp();
848 writeln!(
849 &mut self.output,
850 " %{} = call i64 @llvm.ctpop.i64(i64 %{})",
851 result, a
852 )?;
853 ctx.push(result, RegisterType::I64);
854 }
855 "clz" => {
856 let (a, _) = ctx.pop().unwrap();
857 let result = self.fresh_temp();
858 writeln!(
860 &mut self.output,
861 " %{} = call i64 @llvm.ctlz.i64(i64 %{}, i1 false)",
862 result, a
863 )?;
864 ctx.push(result, RegisterType::I64);
865 }
866 "ctz" => {
867 let (a, _) = ctx.pop().unwrap();
868 let result = self.fresh_temp();
869 writeln!(
871 &mut self.output,
872 " %{} = call i64 @llvm.cttz.i64(i64 %{}, i1 false)",
873 result, a
874 )?;
875 ctx.push(result, RegisterType::I64);
876 }
877
878 "int->float" => {
880 let (a, _) = ctx.pop().unwrap();
881 let result = self.fresh_temp();
882 writeln!(
883 &mut self.output,
884 " %{} = sitofp i64 %{} to double",
885 result, a
886 )?;
887 ctx.push(result, RegisterType::Double);
888 }
889 "float->int" => {
890 let (a, _) = ctx.pop().unwrap();
891 let result = self.fresh_temp();
892 writeln!(
893 &mut self.output,
894 " %{} = fptosi double %{} to i64",
895 result, a
896 )?;
897 ctx.push(result, RegisterType::I64);
898 }
899
900 "and" => {
902 let (b, _) = ctx.pop().unwrap();
903 let (a, _) = ctx.pop().unwrap();
904 let result = self.fresh_temp();
905 writeln!(&mut self.output, " %{} = and i64 %{}, %{}", result, a, b)?;
906 ctx.push(result, RegisterType::I64);
907 }
908 "or" => {
909 let (b, _) = ctx.pop().unwrap();
910 let (a, _) = ctx.pop().unwrap();
911 let result = self.fresh_temp();
912 writeln!(&mut self.output, " %{} = or i64 %{}, %{}", result, a, b)?;
913 ctx.push(result, RegisterType::I64);
914 }
915 "not" => {
916 let (a, _) = ctx.pop().unwrap();
917 let result = self.fresh_temp();
918 writeln!(&mut self.output, " %{} = xor i64 %{}, 1", result, a)?;
920 ctx.push(result, RegisterType::I64);
921 }
922
923 "i.<" | "i.lt" => self.emit_specialized_icmp(ctx, "slt")?,
925 "i.>" | "i.gt" => self.emit_specialized_icmp(ctx, "sgt")?,
926 "i.<=" | "i.lte" => self.emit_specialized_icmp(ctx, "sle")?,
927 "i.>=" | "i.gte" => self.emit_specialized_icmp(ctx, "sge")?,
928 "i.=" | "i.eq" => self.emit_specialized_icmp(ctx, "eq")?,
929 "i.<>" | "i.neq" => self.emit_specialized_icmp(ctx, "ne")?,
930
931 "f.+" | "f.add" => {
933 let (b, _) = ctx.pop().unwrap();
934 let (a, _) = ctx.pop().unwrap();
935 let result = self.fresh_temp();
936 writeln!(
937 &mut self.output,
938 " %{} = fadd double %{}, %{}",
939 result, a, b
940 )?;
941 ctx.push(result, RegisterType::Double);
942 }
943 "f.-" | "f.subtract" => {
944 let (b, _) = ctx.pop().unwrap();
945 let (a, _) = ctx.pop().unwrap();
946 let result = self.fresh_temp();
947 writeln!(
948 &mut self.output,
949 " %{} = fsub double %{}, %{}",
950 result, a, b
951 )?;
952 ctx.push(result, RegisterType::Double);
953 }
954 "f.*" | "f.multiply" => {
955 let (b, _) = ctx.pop().unwrap();
956 let (a, _) = ctx.pop().unwrap();
957 let result = self.fresh_temp();
958 writeln!(
959 &mut self.output,
960 " %{} = fmul double %{}, %{}",
961 result, a, b
962 )?;
963 ctx.push(result, RegisterType::Double);
964 }
965 "f./" | "f.divide" => {
966 let (b, _) = ctx.pop().unwrap();
967 let (a, _) = ctx.pop().unwrap();
968 let result = self.fresh_temp();
969 writeln!(
970 &mut self.output,
971 " %{} = fdiv double %{}, %{}",
972 result, a, b
973 )?;
974 ctx.push(result, RegisterType::Double);
975 }
976
977 "f.<" | "f.lt" => self.emit_specialized_fcmp(ctx, "olt")?,
979 "f.>" | "f.gt" => self.emit_specialized_fcmp(ctx, "ogt")?,
980 "f.<=" | "f.lte" => self.emit_specialized_fcmp(ctx, "ole")?,
981 "f.>=" | "f.gte" => self.emit_specialized_fcmp(ctx, "oge")?,
982 "f.=" | "f.eq" => self.emit_specialized_fcmp(ctx, "oeq")?,
983 "f.<>" | "f.neq" => self.emit_specialized_fcmp(ctx, "one")?,
984
985 _ if name == word_name => {
987 self.emit_specialized_recursive_call(ctx, word_name, sig, is_last)?;
988 }
989
990 _ if self.specialized_words.contains_key(name) => {
992 self.emit_specialized_word_dispatch(ctx, name)?;
993 }
994
995 _ => {
996 return Err(CodeGenError::Logic(format!(
997 "Unhandled operation in specialized codegen: {}",
998 name
999 )));
1000 }
1001 }
1002 Ok(())
1003 }
1004
1005 fn emit_specialized_icmp(
1007 &mut self,
1008 ctx: &mut RegisterContext,
1009 cmp_op: &str,
1010 ) -> Result<(), CodeGenError> {
1011 let (b, _) = ctx.pop().unwrap();
1012 let (a, _) = ctx.pop().unwrap();
1013 let cmp_result = self.fresh_temp();
1014 let result = self.fresh_temp();
1015 writeln!(
1016 &mut self.output,
1017 " %{} = icmp {} i64 %{}, %{}",
1018 cmp_result, cmp_op, a, b
1019 )?;
1020 writeln!(
1021 &mut self.output,
1022 " %{} = zext i1 %{} to i64",
1023 result, cmp_result
1024 )?;
1025 ctx.push(result, RegisterType::I64);
1026 Ok(())
1027 }
1028
1029 fn emit_specialized_fcmp(
1031 &mut self,
1032 ctx: &mut RegisterContext,
1033 cmp_op: &str,
1034 ) -> Result<(), CodeGenError> {
1035 let (b, _) = ctx.pop().unwrap();
1036 let (a, _) = ctx.pop().unwrap();
1037 let cmp_result = self.fresh_temp();
1038 let result = self.fresh_temp();
1039 writeln!(
1040 &mut self.output,
1041 " %{} = fcmp {} double %{}, %{}",
1042 cmp_result, cmp_op, a, b
1043 )?;
1044 writeln!(
1045 &mut self.output,
1046 " %{} = zext i1 %{} to i64",
1047 result, cmp_result
1048 )?;
1049 ctx.push(result, RegisterType::I64);
1050 Ok(())
1051 }
1052
1053 fn emit_specialized_safe_div(
1062 &mut self,
1063 ctx: &mut RegisterContext,
1064 op: &str, ) -> Result<(), CodeGenError> {
1066 let (b, _) = ctx.pop().unwrap(); let (a, _) = ctx.pop().unwrap(); let is_zero = self.fresh_temp();
1071 writeln!(&mut self.output, " %{} = icmp eq i64 %{}, 0", is_zero, b)?;
1072
1073 let (check_overflow, is_overflow) = if op == "sdiv" {
1076 let is_int_min = self.fresh_temp();
1077 let is_neg_one = self.fresh_temp();
1078 let is_overflow = self.fresh_temp();
1079
1080 writeln!(
1082 &mut self.output,
1083 " %{} = icmp eq i64 %{}, -9223372036854775808",
1084 is_int_min, a
1085 )?;
1086 writeln!(
1088 &mut self.output,
1089 " %{} = icmp eq i64 %{}, -1",
1090 is_neg_one, b
1091 )?;
1092 writeln!(
1094 &mut self.output,
1095 " %{} = and i1 %{}, %{}",
1096 is_overflow, is_int_min, is_neg_one
1097 )?;
1098 (true, is_overflow)
1099 } else {
1100 (false, String::new())
1101 };
1102
1103 let ok_label = self.fresh_block("div_ok");
1105 let fail_label = self.fresh_block("div_fail");
1106 let merge_label = self.fresh_block("div_merge");
1107 let overflow_label = if check_overflow {
1108 self.fresh_block("div_overflow")
1109 } else {
1110 String::new()
1111 };
1112
1113 writeln!(
1115 &mut self.output,
1116 " br i1 %{}, label %{}, label %{}",
1117 is_zero,
1118 fail_label,
1119 if check_overflow {
1120 &overflow_label
1121 } else {
1122 &ok_label
1123 }
1124 )?;
1125
1126 if check_overflow {
1128 writeln!(&mut self.output, "{}:", overflow_label)?;
1129 writeln!(
1130 &mut self.output,
1131 " br i1 %{}, label %{}, label %{}",
1132 is_overflow, merge_label, ok_label
1133 )?;
1134 }
1135
1136 writeln!(&mut self.output, "{}:", ok_label)?;
1138 let ok_result = self.fresh_temp();
1139 writeln!(
1140 &mut self.output,
1141 " %{} = {} i64 %{}, %{}",
1142 ok_result, op, a, b
1143 )?;
1144 writeln!(&mut self.output, " br label %{}", merge_label)?;
1145
1146 writeln!(&mut self.output, "{}:", fail_label)?;
1148 writeln!(&mut self.output, " br label %{}", merge_label)?;
1149
1150 writeln!(&mut self.output, "{}:", merge_label)?;
1152 let result_phi = self.fresh_temp();
1153 let success_phi = self.fresh_temp();
1154
1155 if check_overflow {
1156 writeln!(
1159 &mut self.output,
1160 " %{} = phi i64 [ %{}, %{} ], [ 0, %{} ], [ -9223372036854775808, %{} ]",
1161 result_phi, ok_result, ok_label, fail_label, overflow_label
1162 )?;
1163 writeln!(
1164 &mut self.output,
1165 " %{} = phi i64 [ 1, %{} ], [ 0, %{} ], [ 1, %{} ]",
1166 success_phi, ok_label, fail_label, overflow_label
1167 )?;
1168 } else {
1169 writeln!(
1171 &mut self.output,
1172 " %{} = phi i64 [ %{}, %{} ], [ 0, %{} ]",
1173 result_phi, ok_result, ok_label, fail_label
1174 )?;
1175 writeln!(
1176 &mut self.output,
1177 " %{} = phi i64 [ 1, %{} ], [ 0, %{} ]",
1178 success_phi, ok_label, fail_label
1179 )?;
1180 }
1181
1182 ctx.push(result_phi, RegisterType::I64);
1185 ctx.push(success_phi, RegisterType::I64);
1186
1187 Ok(())
1188 }
1189
1190 fn emit_specialized_safe_shift(
1195 &mut self,
1196 ctx: &mut RegisterContext,
1197 is_left: bool, ) -> Result<(), CodeGenError> {
1199 let (b, _) = ctx.pop().unwrap(); let (a, _) = ctx.pop().unwrap(); let is_negative = self.fresh_temp();
1204 writeln!(
1205 &mut self.output,
1206 " %{} = icmp slt i64 %{}, 0",
1207 is_negative, b
1208 )?;
1209
1210 let is_too_large = self.fresh_temp();
1212 writeln!(
1213 &mut self.output,
1214 " %{} = icmp sge i64 %{}, 64",
1215 is_too_large, b
1216 )?;
1217
1218 let is_invalid = self.fresh_temp();
1220 writeln!(
1221 &mut self.output,
1222 " %{} = or i1 %{}, %{}",
1223 is_invalid, is_negative, is_too_large
1224 )?;
1225
1226 let safe_count = self.fresh_temp();
1228 writeln!(
1229 &mut self.output,
1230 " %{} = select i1 %{}, i64 0, i64 %{}",
1231 safe_count, is_invalid, b
1232 )?;
1233
1234 let shift_result = self.fresh_temp();
1236 let op = if is_left { "shl" } else { "lshr" };
1237 writeln!(
1238 &mut self.output,
1239 " %{} = {} i64 %{}, %{}",
1240 shift_result, op, a, safe_count
1241 )?;
1242
1243 let result = self.fresh_temp();
1245 writeln!(
1246 &mut self.output,
1247 " %{} = select i1 %{}, i64 0, i64 %{}",
1248 result, is_invalid, shift_result
1249 )?;
1250
1251 ctx.push(result, RegisterType::I64);
1252 Ok(())
1253 }
1254
1255 fn emit_specialized_recursive_call(
1263 &mut self,
1264 ctx: &mut RegisterContext,
1265 word_name: &str,
1266 sig: &SpecSignature,
1267 is_tail: bool,
1268 ) -> Result<(), CodeGenError> {
1269 let spec_name = format!("seq_{}{}", mangle_name(word_name), sig.suffix());
1270
1271 if ctx.values.len() < sig.inputs.len() {
1273 return Err(CodeGenError::Logic(format!(
1274 "Not enough values in context for recursive call to {}: need {}, have {}",
1275 word_name,
1276 sig.inputs.len(),
1277 ctx.values.len()
1278 )));
1279 }
1280
1281 let mut args = Vec::new();
1283 for _ in 0..sig.inputs.len() {
1284 args.push(ctx.pop().unwrap());
1285 }
1286 args.reverse(); let arg_strs: Vec<String> = args
1290 .iter()
1291 .map(|(var, ty)| format!("{} %{}", ty.llvm_type(), var))
1292 .collect();
1293
1294 let return_type = sig.llvm_return_type();
1295
1296 if is_tail {
1297 let result = self.fresh_temp();
1299 writeln!(
1300 &mut self.output,
1301 " %{} = musttail call {} @{}({})",
1302 result,
1303 return_type,
1304 spec_name,
1305 arg_strs.join(", ")
1306 )?;
1307 writeln!(&mut self.output, " ret {} %{}", return_type, result)?;
1308 } else {
1309 let result = self.fresh_temp();
1311 writeln!(
1312 &mut self.output,
1313 " %{} = call {} @{}({})",
1314 result,
1315 return_type,
1316 spec_name,
1317 arg_strs.join(", ")
1318 )?;
1319
1320 if sig.outputs.len() == 1 {
1321 ctx.push(result, sig.outputs[0]);
1323 } else {
1324 for (i, out_ty) in sig.outputs.iter().enumerate() {
1326 let extracted = self.fresh_temp();
1327 writeln!(
1328 &mut self.output,
1329 " %{} = extractvalue {} %{}, {}",
1330 extracted, return_type, result, i
1331 )?;
1332 ctx.push(extracted, *out_ty);
1333 }
1334 }
1335 }
1336
1337 Ok(())
1338 }
1339
1340 fn emit_specialized_word_dispatch(
1342 &mut self,
1343 ctx: &mut RegisterContext,
1344 name: &str,
1345 ) -> Result<(), CodeGenError> {
1346 let sig = self
1347 .specialized_words
1348 .get(name)
1349 .ok_or_else(|| CodeGenError::Logic(format!("Unknown specialized word: {}", name)))?
1350 .clone();
1351
1352 let spec_name = format!("seq_{}{}", mangle_name(name), sig.suffix());
1353
1354 let mut args = Vec::new();
1356 for _ in 0..sig.inputs.len() {
1357 args.push(ctx.pop().unwrap());
1358 }
1359 args.reverse();
1360
1361 let arg_strs: Vec<String> = args
1363 .iter()
1364 .map(|(var, ty)| format!("{} %{}", ty.llvm_type(), var))
1365 .collect();
1366
1367 let return_type = sig.llvm_return_type();
1368
1369 let result = self.fresh_temp();
1370 writeln!(
1371 &mut self.output,
1372 " %{} = call {} @{}({})",
1373 result,
1374 return_type,
1375 spec_name,
1376 arg_strs.join(", ")
1377 )?;
1378
1379 if sig.outputs.len() == 1 {
1380 ctx.push(result, sig.outputs[0]);
1382 } else {
1383 for (i, out_ty) in sig.outputs.iter().enumerate() {
1385 let extracted = self.fresh_temp();
1386 writeln!(
1387 &mut self.output,
1388 " %{} = extractvalue {} %{}, {}",
1389 extracted, return_type, result, i
1390 )?;
1391 ctx.push(extracted, *out_ty);
1392 }
1393 }
1394
1395 Ok(())
1396 }
1397
1398 fn emit_specialized_return(
1400 &mut self,
1401 ctx: &RegisterContext,
1402 sig: &SpecSignature,
1403 ) -> Result<(), CodeGenError> {
1404 let output_count = sig.outputs.len();
1405
1406 if output_count == 0 {
1407 writeln!(&mut self.output, " ret void")?;
1408 } else if output_count == 1 {
1409 let (var, ty) = ctx
1410 .values
1411 .last()
1412 .ok_or_else(|| CodeGenError::Logic("Empty context at return".to_string()))?;
1413 writeln!(&mut self.output, " ret {} %{}", ty.llvm_type(), var)?;
1414 } else {
1415 if ctx.values.len() < output_count {
1418 return Err(CodeGenError::Logic(format!(
1419 "Not enough values for multi-output return: need {}, have {}",
1420 output_count,
1421 ctx.values.len()
1422 )));
1423 }
1424
1425 let start_idx = ctx.values.len() - output_count;
1427 let return_values: Vec<_> = ctx.values[start_idx..].to_vec();
1428
1429 let struct_type = sig.llvm_return_type();
1431
1432 let mut current_struct = "undef".to_string();
1434 for (i, (var, ty)) in return_values.iter().enumerate() {
1435 let new_struct = self.fresh_temp();
1436 writeln!(
1437 &mut self.output,
1438 " %{} = insertvalue {} {}, {} %{}, {}",
1439 new_struct,
1440 struct_type,
1441 current_struct,
1442 ty.llvm_type(),
1443 var,
1444 i
1445 )?;
1446 current_struct = format!("%{}", new_struct);
1447 }
1448
1449 writeln!(&mut self.output, " ret {} {}", struct_type, current_struct)?;
1450 }
1451 Ok(())
1452 }
1453
1454 fn codegen_specialized_if(
1456 &mut self,
1457 ctx: &mut RegisterContext,
1458 then_branch: &[Statement],
1459 else_branch: Option<&Vec<Statement>>,
1460 word_name: &str,
1461 sig: &SpecSignature,
1462 is_last: bool,
1463 ) -> Result<(), CodeGenError> {
1464 let (cond_var, _) = ctx
1466 .pop()
1467 .ok_or_else(|| CodeGenError::Logic("Empty context at if condition".to_string()))?;
1468
1469 let cmp_result = self.fresh_temp();
1471 writeln!(
1472 &mut self.output,
1473 " %{} = icmp ne i64 %{}, 0",
1474 cmp_result, cond_var
1475 )?;
1476
1477 let then_label = self.fresh_block("if_then");
1479 let else_label = self.fresh_block("if_else");
1480 let merge_label = self.fresh_block("if_merge");
1481
1482 writeln!(
1483 &mut self.output,
1484 " br i1 %{}, label %{}, label %{}",
1485 cmp_result, then_label, else_label
1486 )?;
1487
1488 writeln!(&mut self.output, "{}:", then_label)?;
1490 let mut then_ctx = ctx.clone();
1491 let mut then_prev_int: Option<i64> = None;
1492 for (i, stmt) in then_branch.iter().enumerate() {
1493 let is_stmt_last = i == then_branch.len() - 1 && is_last;
1494 self.codegen_specialized_statement(
1495 &mut then_ctx,
1496 stmt,
1497 word_name,
1498 sig,
1499 is_stmt_last,
1500 &mut then_prev_int,
1501 )?;
1502 }
1503 if is_last && then_branch.is_empty() {
1505 self.emit_specialized_return(&then_ctx, sig)?;
1506 }
1507 let then_emitted_return = is_last;
1509 let then_pred = if then_emitted_return {
1510 None
1511 } else {
1512 writeln!(&mut self.output, " br label %{}", merge_label)?;
1513 Some(then_label.clone())
1514 };
1515
1516 writeln!(&mut self.output, "{}:", else_label)?;
1518 let mut else_ctx = ctx.clone();
1519 let mut else_prev_int: Option<i64> = None;
1520 if let Some(else_stmts) = else_branch {
1521 for (i, stmt) in else_stmts.iter().enumerate() {
1522 let is_stmt_last = i == else_stmts.len() - 1 && is_last;
1523 self.codegen_specialized_statement(
1524 &mut else_ctx,
1525 stmt,
1526 word_name,
1527 sig,
1528 is_stmt_last,
1529 &mut else_prev_int,
1530 )?;
1531 }
1532 }
1533 if is_last && (else_branch.is_none() || else_branch.as_ref().is_some_and(|b| b.is_empty()))
1535 {
1536 self.emit_specialized_return(&else_ctx, sig)?;
1537 }
1538 let else_emitted_return = is_last;
1540 let else_pred = if else_emitted_return {
1541 None
1542 } else {
1543 writeln!(&mut self.output, " br label %{}", merge_label)?;
1544 Some(else_label.clone())
1545 };
1546
1547 if then_pred.is_some() || else_pred.is_some() {
1549 writeln!(&mut self.output, "{}:", merge_label)?;
1550
1551 if let (Some(then_p), Some(else_p)) = (&then_pred, &else_pred) {
1553 if then_ctx.values.len() != else_ctx.values.len() {
1555 return Err(CodeGenError::Logic(format!(
1556 "Stack depth mismatch in if branches: then has {}, else has {}",
1557 then_ctx.values.len(),
1558 else_ctx.values.len()
1559 )));
1560 }
1561
1562 ctx.values.clear();
1563 for i in 0..then_ctx.values.len() {
1564 let (then_var, then_ty) = &then_ctx.values[i];
1565 let (else_var, else_ty) = &else_ctx.values[i];
1566
1567 if then_ty != else_ty {
1568 return Err(CodeGenError::Logic(format!(
1569 "Type mismatch at position {} in if branches: {:?} vs {:?}",
1570 i, then_ty, else_ty
1571 )));
1572 }
1573
1574 if then_var == else_var {
1576 ctx.push(then_var.clone(), *then_ty);
1577 } else {
1578 let phi_result = self.fresh_temp();
1579 writeln!(
1580 &mut self.output,
1581 " %{} = phi {} [ %{}, %{} ], [ %{}, %{} ]",
1582 phi_result,
1583 then_ty.llvm_type(),
1584 then_var,
1585 then_p,
1586 else_var,
1587 else_p
1588 )?;
1589 ctx.push(phi_result, *then_ty);
1590 }
1591 }
1592 } else if then_pred.is_some() {
1593 *ctx = then_ctx;
1595 } else {
1596 *ctx = else_ctx;
1598 }
1599
1600 if is_last && (then_pred.is_some() || else_pred.is_some()) {
1602 self.emit_specialized_return(ctx, sig)?;
1603 }
1604 }
1605
1606 Ok(())
1607 }
1608}
1609
#[cfg(test)]
mod tests {
    use super::*;

    /// Names of the SSA values in `ctx`, bottom of the stack first.
    fn stack_names(ctx: &RegisterContext) -> Vec<&str> {
        ctx.values.iter().map(|(name, _)| name.as_str()).collect()
    }

    /// `Int` and `Bool` share the integer register, `Float` uses the double
    /// register, and `String` has no register form.
    #[test]
    fn test_register_type_from_type() {
        let cases = [
            (Type::Int, Some(RegisterType::I64)),
            (Type::Bool, Some(RegisterType::I64)),
            (Type::Float, Some(RegisterType::Double)),
            (Type::String, None),
        ];
        for (ty, expected) in &cases {
            assert_eq!(RegisterType::from_type(ty), *expected);
        }
    }

    /// Uniform one-in/one-out signatures use the short fixed suffixes.
    #[test]
    fn test_spec_signature_suffix() {
        let int_sig = SpecSignature {
            inputs: vec![RegisterType::I64],
            outputs: vec![RegisterType::I64],
        };
        assert_eq!(int_sig.suffix(), "_i64");

        let float_sig = SpecSignature {
            inputs: vec![RegisterType::Double],
            outputs: vec![RegisterType::Double],
        };
        assert_eq!(float_sig.suffix(), "_f64");
    }

    /// `swap`, `dup` and `drop` rearrange the modelled stack as expected.
    #[test]
    fn test_register_context_stack_ops() {
        let mut ctx = RegisterContext::new();
        ctx.push("a".to_string(), RegisterType::I64);
        ctx.push("b".to_string(), RegisterType::I64);
        assert_eq!(ctx.len(), 2);

        ctx.swap();
        assert_eq!(stack_names(&ctx)[0], "b");
        assert_eq!(stack_names(&ctx)[1], "a");

        ctx.dup();
        assert_eq!(ctx.len(), 3);
        assert_eq!(stack_names(&ctx)[2], "a");

        ctx.drop();
        assert_eq!(ctx.len(), 2);
    }
}
1665}